warp-lang 1.5.1-py3-none-macosx_10_13_universal2.whl → 1.6.1-py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release of warp-lang has been flagged by the registry as potentially problematic.
Files changed (131)
  1. warp/__init__.py +5 -0
  2. warp/autograd.py +414 -191
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +40 -12
  6. warp/build_dll.py +13 -6
  7. warp/builtins.py +1077 -481
  8. warp/codegen.py +250 -122
  9. warp/config.py +65 -21
  10. warp/context.py +500 -149
  11. warp/examples/assets/square_cloth.usd +0 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +27 -18
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
  14. warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
  15. warp/examples/core/example_marching_cubes.py +1 -1
  16. warp/examples/core/example_mesh.py +1 -1
  17. warp/examples/core/example_torch.py +18 -34
  18. warp/examples/core/example_wave.py +1 -1
  19. warp/examples/fem/example_apic_fluid.py +1 -0
  20. warp/examples/fem/example_mixed_elasticity.py +1 -1
  21. warp/examples/optim/example_bounce.py +1 -1
  22. warp/examples/optim/example_cloth_throw.py +1 -1
  23. warp/examples/optim/example_diffray.py +4 -15
  24. warp/examples/optim/example_drone.py +1 -1
  25. warp/examples/optim/example_softbody_properties.py +392 -0
  26. warp/examples/optim/example_trajectory.py +1 -3
  27. warp/examples/optim/example_walker.py +5 -0
  28. warp/examples/sim/example_cartpole.py +0 -2
  29. warp/examples/sim/example_cloth_self_contact.py +314 -0
  30. warp/examples/sim/example_granular_collision_sdf.py +4 -5
  31. warp/examples/sim/example_jacobian_ik.py +0 -2
  32. warp/examples/sim/example_quadruped.py +5 -2
  33. warp/examples/tile/example_tile_cholesky.py +79 -0
  34. warp/examples/tile/example_tile_convolution.py +2 -2
  35. warp/examples/tile/example_tile_fft.py +2 -2
  36. warp/examples/tile/example_tile_filtering.py +3 -3
  37. warp/examples/tile/example_tile_matmul.py +4 -4
  38. warp/examples/tile/example_tile_mlp.py +12 -12
  39. warp/examples/tile/example_tile_nbody.py +191 -0
  40. warp/examples/tile/example_tile_walker.py +319 -0
  41. warp/math.py +147 -0
  42. warp/native/array.h +12 -0
  43. warp/native/builtin.h +0 -1
  44. warp/native/bvh.cpp +149 -70
  45. warp/native/bvh.cu +287 -68
  46. warp/native/bvh.h +195 -85
  47. warp/native/clang/clang.cpp +6 -2
  48. warp/native/crt.h +1 -0
  49. warp/native/cuda_util.cpp +35 -0
  50. warp/native/cuda_util.h +5 -0
  51. warp/native/exports.h +40 -40
  52. warp/native/intersect.h +17 -0
  53. warp/native/mat.h +57 -3
  54. warp/native/mathdx.cpp +19 -0
  55. warp/native/mesh.cpp +25 -8
  56. warp/native/mesh.cu +153 -101
  57. warp/native/mesh.h +482 -403
  58. warp/native/quat.h +40 -0
  59. warp/native/solid_angle.h +7 -0
  60. warp/native/sort.cpp +85 -0
  61. warp/native/sort.cu +34 -0
  62. warp/native/sort.h +3 -1
  63. warp/native/spatial.h +11 -0
  64. warp/native/tile.h +1189 -664
  65. warp/native/tile_reduce.h +8 -6
  66. warp/native/vec.h +41 -0
  67. warp/native/warp.cpp +8 -1
  68. warp/native/warp.cu +263 -40
  69. warp/native/warp.h +19 -5
  70. warp/optim/linear.py +22 -4
  71. warp/render/render_opengl.py +132 -59
  72. warp/render/render_usd.py +10 -2
  73. warp/sim/__init__.py +6 -1
  74. warp/sim/collide.py +289 -32
  75. warp/sim/import_urdf.py +20 -5
  76. warp/sim/integrator_euler.py +25 -7
  77. warp/sim/integrator_featherstone.py +147 -35
  78. warp/sim/integrator_vbd.py +842 -40
  79. warp/sim/model.py +173 -112
  80. warp/sim/render.py +2 -2
  81. warp/stubs.py +249 -116
  82. warp/tape.py +28 -30
  83. warp/tests/aux_test_module_unload.py +15 -0
  84. warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
  85. warp/tests/test_array.py +100 -0
  86. warp/tests/test_assert.py +242 -0
  87. warp/tests/test_codegen.py +14 -61
  88. warp/tests/test_collision.py +8 -8
  89. warp/tests/test_examples.py +16 -1
  90. warp/tests/test_grad_debug.py +87 -2
  91. warp/tests/test_hash_grid.py +1 -1
  92. warp/tests/test_ipc.py +116 -0
  93. warp/tests/test_launch.py +77 -26
  94. warp/tests/test_mat.py +213 -168
  95. warp/tests/test_math.py +47 -1
  96. warp/tests/test_matmul.py +11 -7
  97. warp/tests/test_matmul_lite.py +4 -4
  98. warp/tests/test_mesh.py +84 -60
  99. warp/tests/test_mesh_query_aabb.py +165 -0
  100. warp/tests/test_mesh_query_point.py +328 -286
  101. warp/tests/test_mesh_query_ray.py +134 -121
  102. warp/tests/test_mlp.py +2 -2
  103. warp/tests/test_operators.py +43 -0
  104. warp/tests/test_overwrite.py +6 -5
  105. warp/tests/test_quat.py +77 -0
  106. warp/tests/test_reload.py +29 -0
  107. warp/tests/test_sim_grad_bounce_linear.py +204 -0
  108. warp/tests/test_static.py +16 -0
  109. warp/tests/test_tape.py +25 -0
  110. warp/tests/test_tile.py +134 -191
  111. warp/tests/test_tile_load.py +399 -0
  112. warp/tests/test_tile_mathdx.py +61 -8
  113. warp/tests/test_tile_mlp.py +17 -17
  114. warp/tests/test_tile_reduce.py +24 -18
  115. warp/tests/test_tile_shared_memory.py +66 -17
  116. warp/tests/test_tile_view.py +165 -0
  117. warp/tests/test_torch.py +35 -0
  118. warp/tests/test_utils.py +36 -24
  119. warp/tests/test_vec.py +110 -0
  120. warp/tests/unittest_suites.py +29 -4
  121. warp/tests/unittest_utils.py +30 -11
  122. warp/thirdparty/unittest_parallel.py +5 -2
  123. warp/types.py +419 -111
  124. warp/utils.py +9 -5
  125. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/METADATA +86 -45
  126. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/RECORD +129 -118
  127. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/WHEEL +1 -1
  128. warp/examples/benchmarks/benchmark_tile.py +0 -179
  129. warp/native/tile_gemm.h +0 -341
  130. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/LICENSE.md +0 -0
  131. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -5,6 +5,8 @@
  # distribution of this software and related documentation without an express
  # license agreement from NVIDIA CORPORATION is strictly prohibited.

+ from __future__ import annotations
+
  import ast
  import ctypes
  import errno
@@ -32,6 +34,7 @@ import warp
  import warp.build
  import warp.codegen
  import warp.config
+ from warp.types import launch_bounds_t

  # represents either a built-in or user-defined function

@@ -393,7 +396,8 @@ class Function:
  if not warp.codegen.func_match_args(f, arg_types, kwarg_types):
  continue

- if len(f.input_types) != len(arg_types):
+ acceptable_arg_num = len(f.input_types) - len(f.defaults) <= len(arg_types) <= len(f.input_types)
+ if not acceptable_arg_num:
  continue

  # try to match the given types to the function template types
@@ -410,6 +414,10 @@ class Function:

  arg_names = f.input_types.keys()
  overload_annotations = dict(zip(arg_names, arg_types))
+ # add defaults
+ for k, d in f.defaults.items():
+ if k not in overload_annotations:
+ overload_annotations[k] = warp.codegen.strip_reference(warp.codegen.get_arg_type(d))

  ovl = shallowcopy(f)
  ovl.adj = warp.codegen.Adjoint(f.func, overload_annotations)
@@ -753,8 +761,15 @@ def func(f):
  scope_locals = inspect.currentframe().f_back.f_locals

  m = get_module(f.__module__)
+ doc = getattr(f, "__doc__", "") or ""
  Function(
- func=f, key=name, namespace="", module=m, value_func=None, scope_locals=scope_locals
+ func=f,
+ key=name,
+ namespace="",
+ module=m,
+ value_func=None,
+ scope_locals=scope_locals,
+ doc=doc.strip(),
  ) # value_type not known yet, will be inferred during Adjoint.build()

  # use the top of the list of overloads for this key
@@ -1059,7 +1074,8 @@ def overload(kernel, arg_types=Union[None, Dict[str, Any], List[Any]]):
  raise RuntimeError("wp.overload() called with invalid argument!")


- builtin_functions = {}
+ # native functions that are part of the Warp API
+ builtin_functions: Dict[str, Function] = {}


  def get_generic_vtypes():
@@ -1328,6 +1344,28 @@ def add_builtin(
  setattr(warp, key, func)


+ def register_api_function(
+ function: Function,
+ group: str = "Other",
+ hidden=False,
+ ):
+ """Main entry point to register a Warp Python function to be part of the Warp API and appear in the documentation.
+
+ Args:
+ function (Function): Warp function to be registered.
+ group (str): Classification used for the documentation.
+ input_types (Mapping[str, Any]): Signature of the user-facing function.
+ Variadic arguments are supported by prefixing the parameter names
+ with asterisks as in `*args` and `**kwargs`. Generic arguments are
+ supported with types such as `Any`, `Float`, `Scalar`, etc.
+ value_type (Any): Type returned by the function.
+ hidden (bool): Whether to add that function into the documentation.
+ """
+ function.group = group
+ function.hidden = hidden
+ builtin_functions[function.key] = function
+
+
  # global dictionary of modules
  user_modules = {}

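The register_api_function() entry point above lets plain Python helpers (rather than native builtins) be listed in Warp's generated API documentation through the builtin_functions table. A minimal usage sketch, assuming that @wp.func returns the warp.context.Function wrapper the hook expects; the helper and the group name below are purely illustrative::

    import warp as wp
    import warp.context

    @wp.func
    def double_sin(x: float) -> float:
        # hypothetical helper built from existing Warp builtins
        return 2.0 * wp.sin(x)

    # surface the helper in the generated docs under an illustrative group name
    warp.context.register_api_function(double_sin, group="Utility")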
@@ -1561,6 +1599,7 @@ class ModuleBuilder:
  self.options = options
  self.module = module
  self.deferred_functions = []
+ self.fatbins = {} # map from <some identifier> to fatbins, to add at link time
  self.ltoirs = {} # map from lto symbol to lto binary
  self.ltoirs_decl = {} # map from lto symbol to lto forward declaration

@@ -1675,7 +1714,7 @@ class ModuleBuilder:

  for kernel in self.kernels:
  source += warp.codegen.codegen_kernel(kernel, device=device, options=self.options)
- source += warp.codegen.codegen_module(kernel, device=device)
+ source += warp.codegen.codegen_module(kernel, device=device, options=self.options)

  # add headers
  if device == "cpu":
@@ -1728,20 +1767,26 @@ class ModuleExec:

  name = kernel.get_mangled_name()

+ options = dict(kernel.module.options)
+ options.update(kernel.options)
+
  if self.device.is_cuda:
  forward_name = name + "_cuda_kernel_forward"
  forward_kernel = runtime.core.cuda_get_kernel(
  self.device.context, self.handle, forward_name.encode("utf-8")
  )

- backward_name = name + "_cuda_kernel_backward"
- backward_kernel = runtime.core.cuda_get_kernel(
- self.device.context, self.handle, backward_name.encode("utf-8")
- )
+ if options["enable_backward"]:
+ backward_name = name + "_cuda_kernel_backward"
+ backward_kernel = runtime.core.cuda_get_kernel(
+ self.device.context, self.handle, backward_name.encode("utf-8")
+ )
+ else:
+ backward_kernel = None

  # look up the required shared memory size for each kernel from module metadata
  forward_smem_bytes = self.meta[forward_name + "_smem_bytes"]
- backward_smem_bytes = self.meta[backward_name + "_smem_bytes"]
+ backward_smem_bytes = self.meta[backward_name + "_smem_bytes"] if options["enable_backward"] else 0

  # configure kernels maximum shared memory size
  max_smem_bytes = runtime.core.cuda_get_max_shared_memory(self.device.context)
@@ -1751,9 +1796,6 @@ class ModuleExec:
  f"Warning: Failed to configure kernel dynamic shared memory for this device, tried to configure {forward_name} kernel for {forward_smem_bytes} bytes, but maximum available is {max_smem_bytes}"
  )

- options = dict(kernel.module.options)
- options.update(kernel.options)
-
  if options["enable_backward"] and not runtime.core.cuda_configure_kernel_shared_memory(
  backward_kernel, backward_smem_bytes
  ):
@@ -1768,9 +1810,14 @@
  forward = (
  func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_forward").encode("utf-8"))) or None
  )
- backward = (
- func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8"))) or None
- )
+
+ if options["enable_backward"]:
+ backward = (
+ func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8")))
+ or None
+ )
+ else:
+ backward = None

  hooks = KernelHooks(forward, backward)

@@ -1803,13 +1850,13 @@ class Module:
  self._live_kernels = weakref.WeakSet()

  # executable modules currently loaded
- self.execs = {} # (device.context: ModuleExec)
+ self.execs = {} # ((device.context, blockdim): ModuleExec)

  # set of device contexts where the build has failed
  self.failed_builds = set()

- # hash data, including the module hash
- self.hasher = None
+ # hash data, including the module hash. Module may store multiple hashes (one per block_dim used)
+ self.hashers = {}

  # LLVM executable modules are identified using strings. Since it's possible for multiple
  # executable versions to be loaded at the same time, we need a way to ensure uniqueness.
@@ -1822,6 +1869,8 @@ class Module:
  "max_unroll": warp.config.max_unroll,
  "enable_backward": warp.config.enable_backward,
  "fast_math": False,
+ "fuse_fp": True,
+ "lineinfo": False,
  "cuda_output": None, # supported values: "ptx", "cubin", or None (automatic)
  "mode": warp.config.mode,
  "block_dim": 256,
@@ -1965,28 +2014,27 @@
  def hash_module(self):
  # compute latest hash
- self.hasher = ModuleHasher(self)
- return self.hasher.get_module_hash()
+ block_dim = self.options["block_dim"]
+ self.hashers[block_dim] = ModuleHasher(self)
+ return self.hashers[block_dim].get_module_hash()

  def load(self, device, block_dim=None) -> ModuleExec:
  device = runtime.get_device(device)

- # re-compile module if tile size (blockdim) changes
- # todo: it would be better to have a method such as `module.get_kernel(block_dim=N)`
- # that can return a single kernel instance with a given block size
+ # update module options if launching with a new block dim
  if block_dim is not None:
- if self.options["block_dim"] != block_dim:
- self.unload()
  self.options["block_dim"] = block_dim

+ active_block_dim = self.options["block_dim"]
+
  # compute the hash if needed
- if self.hasher is None:
- self.hasher = ModuleHasher(self)
+ if active_block_dim not in self.hashers:
+ self.hashers[active_block_dim] = ModuleHasher(self)

  # check if executable module is already loaded and not stale
- exec = self.execs.get(device.context)
+ exec = self.execs.get((device.context, active_block_dim))
  if exec is not None:
- if exec.module_hash == self.hasher.module_hash:
+ if exec.module_hash == self.hashers[active_block_dim].get_module_hash():
  return exec

  # quietly avoid repeated build attempts to reduce error spew
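Because the executable cache and the hashers are now keyed by block dimension as well as device context, alternating between block sizes no longer unloads and rebuilds the module; each (device, block_dim) pair keeps its own ModuleExec. A hedged sketch of what this enables at the wp.launch() level (the kernel is illustrative)::

    import warp as wp

    @wp.kernel
    def add_one(a: wp.array(dtype=float)):
        i = wp.tid()
        a[i] = a[i] + 1.0

    a = wp.zeros(1024, dtype=float, device="cuda:0")

    # each block size maps to its own cached executable, so switching
    # back and forth does not trigger repeated recompilation
    wp.launch(add_one, dim=a.size, inputs=[a], block_dim=128)
    wp.launch(add_one, dim=a.size, inputs=[a], block_dim=256)
    wp.launch(add_one, dim=a.size, inputs=[a], block_dim=128)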
@@ -1994,10 +2042,11 @@
  return None

  module_name = "wp_" + self.name
- module_hash = self.hasher.module_hash
+ module_hash = self.hashers[active_block_dim].get_module_hash()

  # use a unique module path using the module short hash
- module_dir = os.path.join(warp.config.kernel_cache_dir, f"{module_name}_{module_hash.hex()[:7]}")
+ module_name_short = f"{module_name}_{module_hash.hex()[:7]}"
+ module_dir = os.path.join(warp.config.kernel_cache_dir, module_name_short)

  with warp.ScopedTimer(
  f"Module {self.name} {module_hash.hex()[:7]} load on device '{device}'", active=not warp.config.quiet
@@ -2005,7 +2054,7 @@
  # -----------------------------------------------------------
  # determine output paths
  if device.is_cpu:
- output_name = "module_codegen.o"
+ output_name = f"{module_name_short}.o"
  output_arch = None

  elif device.is_cuda:
@@ -2025,10 +2074,10 @@

  if use_ptx:
  output_arch = min(device.arch, warp.config.ptx_target_arch)
- output_name = f"module_codegen.sm{output_arch}.ptx"
+ output_name = f"{module_name_short}.sm{output_arch}.ptx"
  else:
  output_arch = device.arch
- output_name = f"module_codegen.sm{output_arch}.cubin"
+ output_name = f"{module_name_short}.sm{output_arch}.cubin"

  # final object binary path
  binary_path = os.path.join(module_dir, output_name)
@@ -2050,7 +2099,7 @@
  # Some of the Tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
  "output_arch": output_arch,
  }
- builder = ModuleBuilder(self, builder_options, hasher=self.hasher)
+ builder = ModuleBuilder(self, builder_options, hasher=self.hashers[active_block_dim])

  # create a temporary (process unique) dir for build outputs before moving to the binary dir
  build_dir = os.path.join(
@@ -2066,7 +2115,7 @@
  if device.is_cpu:
  # build
  try:
- source_code_path = os.path.join(build_dir, "module_codegen.cpp")
+ source_code_path = os.path.join(build_dir, f"{module_name_short}.cpp")

  # write cpp sources
  cpp_source = builder.codegen("cpu")
@@ -2084,6 +2133,7 @@
  mode=self.options["mode"],
  fast_math=self.options["fast_math"],
  verify_fp=warp.config.verify_fp,
+ fuse_fp=self.options["fuse_fp"],
  )

  except Exception as e:
@@ -2094,7 +2144,7 @@
  elif device.is_cuda:
  # build
  try:
- source_code_path = os.path.join(build_dir, "module_codegen.cu")
+ source_code_path = os.path.join(build_dir, f"{module_name_short}.cu")

  # write cuda sources
  cu_source = builder.codegen("cuda")
@@ -2111,9 +2161,12 @@
  output_arch,
  output_path,
  config=self.options["mode"],
- fast_math=self.options["fast_math"],
  verify_fp=warp.config.verify_fp,
+ fast_math=self.options["fast_math"],
+ fuse_fp=self.options["fuse_fp"],
+ lineinfo=self.options["lineinfo"],
  ltoirs=builder.ltoirs.values(),
+ fatbins=builder.fatbins.values(),
  )

  except Exception as e:
@@ -2125,7 +2178,7 @@
  # build meta data

  meta = builder.build_meta()
- meta_path = os.path.join(build_dir, "module_codegen.meta")
+ meta_path = os.path.join(build_dir, f"{module_name_short}.meta")

  with open(meta_path, "w") as meta_file:
  json.dump(meta, meta_file)
@@ -2189,7 +2242,7 @@
  # -----------------------------------------------------------
  # Load CPU or CUDA binary

- meta_path = os.path.join(module_dir, "module_codegen.meta")
+ meta_path = os.path.join(module_dir, f"{module_name_short}.meta")
  with open(meta_path, "r") as meta_file:
  meta = json.load(meta_file)

@@ -2199,13 +2252,13 @@
  self.cpu_exec_id += 1
  runtime.llvm.load_obj(binary_path.encode("utf-8"), module_handle.encode("utf-8"))
  module_exec = ModuleExec(module_handle, module_hash, device, meta)
- self.execs[None] = module_exec
+ self.execs[(None, active_block_dim)] = module_exec

  elif device.is_cuda:
  cuda_module = warp.build.load_cuda(binary_path, device)
  if cuda_module is not None:
  module_exec = ModuleExec(cuda_module, module_hash, device, meta)
- self.execs[device.context] = module_exec
+ self.execs[(device.context, active_block_dim)] = module_exec
  else:
  module_load_timer.extra_msg = " (error)"
  raise Exception(f"Failed to load CUDA module '{self.name}'")
@@ -2227,14 +2280,14 @@

  def mark_modified(self):
  # clear hash data
- self.hasher = None
+ self.hashers = {}

  # clear build failures
  self.failed_builds = set()

  # lookup kernel entry points based on name, called after compilation / module load
  def get_kernel_hooks(self, kernel, device):
- module_exec = self.execs.get(device.context)
+ module_exec = self.execs.get((device.context, self.options["block_dim"]))
  if module_exec is not None:
  return module_exec.get_kernel_hooks(kernel)
  else:
@@ -2353,6 +2406,7 @@ class Event:
  DEFAULT = 0x0
  BLOCKING_SYNC = 0x1
  DISABLE_TIMING = 0x2
+ INTERPROCESS = 0x4

  def __new__(cls, *args, **kwargs):
  """Creates a new event instance."""
@@ -2360,7 +2414,9 @@ class Event:
  instance.owner = False
  return instance

- def __init__(self, device: "Devicelike" = None, cuda_event=None, enable_timing: bool = False):
+ def __init__(
+ self, device: "Devicelike" = None, cuda_event=None, enable_timing: bool = False, interprocess: bool = False
+ ):
  """Initializes the event on a CUDA device.

  Args:
@@ -2372,6 +2428,12 @@
  :func:`~warp.get_event_elapsed_time` can be used to measure the
  time between two events created with ``enable_timing=True`` and
  recorded onto streams.
+ interprocess: If ``True`` this event may be used as an interprocess event.
+
+ Raises:
+ RuntimeError: The event could not be created.
+ ValueError: The combination of ``enable_timing=True`` and
+ ``interprocess=True`` is not allowed.
  """

  device = get_device(device)
@@ -2386,11 +2448,48 @@
  flags = Event.Flags.DEFAULT
  if not enable_timing:
  flags |= Event.Flags.DISABLE_TIMING
+ if interprocess:
+ if enable_timing:
+ raise ValueError("The combination of 'enable_timing=True' and 'interprocess=True' is not allowed.")
+ flags |= Event.Flags.INTERPROCESS
+
  self.cuda_event = runtime.core.cuda_event_create(device.context, flags)
  if not self.cuda_event:
  raise RuntimeError(f"Failed to create event on device {device}")
  self.owner = True

+ def ipc_handle(self) -> bytes:
+ """Return a CUDA IPC handle of the event as a 64-byte ``bytes`` object.
+
+ The event must have been created with ``interprocess=True`` in order to
+ obtain a valid interprocess handle.
+
+ IPC is currently only supported on Linux.
+
+ Example:
+ Create an event and get its IPC handle::
+
+ e1 = wp.Event(interprocess=True)
+ event_handle = e1.ipc_handle()
+
+ Raises:
+ RuntimeError: Device does not support IPC.
+ """
+
+ if self.device.is_ipc_supported is not False:
+ # Allocate a buffer for the data (64-element char array)
+ ipc_handle_buffer = (ctypes.c_char * 64)()
+
+ warp.context.runtime.core.cuda_ipc_get_event_handle(self.device.context, self.cuda_event, ipc_handle_buffer)
+
+ if ipc_handle_buffer.raw == bytes(64):
+ warp.utils.warn("IPC event handle appears to be invalid. Was interprocess=True used?")
+
+ return ipc_handle_buffer.raw
+
+ else:
+ raise RuntimeError(f"Device {self.device} does not support IPC.")
+
  def __del__(self):
  if not self.owner:
  return
@@ -2538,23 +2637,27 @@ class Device:
  """A device to allocate Warp arrays and to launch kernels on.

  Attributes:
- ordinal: A Warp-specific integer label for the device. ``-1`` for CPU devices.
- name: A string label for the device. By default, CPU devices will be named according to the processor name,
+ ordinal (int): A Warp-specific label for the device. ``-1`` for CPU devices.
+ name (str): A label for the device. By default, CPU devices will be named according to the processor name,
  or ``"CPU"`` if the processor name cannot be determined.
- arch: An integer representing the compute capability version number calculated as
- ``10 * major + minor``. ``0`` for CPU devices.
- is_uva: A boolean indicating whether the device supports unified addressing.
+ arch (int): The compute capability version number calculated as ``10 * major + minor``.
+ ``0`` for CPU devices.
+ is_uva (bool): Indicates whether the device supports unified addressing.
  ``False`` for CPU devices.
- is_cubin_supported: A boolean indicating whether Warp's version of NVRTC can directly
+ is_cubin_supported (bool): Indicates whether Warp's version of NVRTC can directly
  generate CUDA binary files (cubin) for this device's architecture. ``False`` for CPU devices.
- is_mempool_supported: A boolean indicating whether the device supports using the
- ``cuMemAllocAsync`` and ``cuMemPool`` family of APIs for stream-ordered memory allocations. ``False`` for
- CPU devices.
- is_primary: A boolean indicating whether this device's CUDA context is also the
- device's primary context.
- uuid: A string representing the UUID of the CUDA device. The UUID is in the same format used by
- ``nvidia-smi -L``. ``None`` for CPU devices.
- pci_bus_id: A string identifier for the CUDA device in the format ``[domain]:[bus]:[device]``, in which
+ is_mempool_supported (bool): Indicates whether the device supports using the ``cuMemAllocAsync`` and
+ ``cuMemPool`` family of APIs for stream-ordered memory allocations. ``False`` for CPU devices.
+ is_ipc_supported (Optional[bool]): Indicates whether the device supports IPC.
+
+ - ``True`` if supported.
+ - ``False`` if not supported.
+ - ``None`` if IPC support could not be determined (e.g. CUDA 11).
+
+ is_primary (bool): Indicates whether this device's CUDA context is also the device's primary context.
+ uuid (str): The UUID of the CUDA device. The UUID is in the same format used by ``nvidia-smi -L``.
+ ``None`` for CPU devices.
+ pci_bus_id (str): An identifier for the CUDA device in the format ``[domain]:[bus]:[device]``, in which
  ``domain``, ``bus``, and ``device`` are all hexadecimal values. ``None`` for CPU devices.
  """

@@ -2587,6 +2690,7 @@ class Device:
  self.is_uva = False
  self.is_mempool_supported = False
  self.is_mempool_enabled = False
+ self.is_ipc_supported = False # TODO: Support IPC for CPU arrays
  self.is_cubin_supported = False
  self.uuid = None
  self.pci_bus_id = None
@@ -2602,8 +2706,14 @@
  # CUDA device
  self.name = runtime.core.cuda_device_get_name(ordinal).decode()
  self.arch = runtime.core.cuda_device_get_arch(ordinal)
- self.is_uva = runtime.core.cuda_device_is_uva(ordinal)
- self.is_mempool_supported = runtime.core.cuda_device_is_mempool_supported(ordinal)
+ self.is_uva = runtime.core.cuda_device_is_uva(ordinal) > 0
+ self.is_mempool_supported = runtime.core.cuda_device_is_mempool_supported(ordinal) > 0
+ if platform.system() == "Linux":
+ # Use None when IPC support cannot be determined
+ ipc_support_api_query = runtime.core.cuda_device_is_ipc_supported(ordinal)
+ self.is_ipc_supported = bool(ipc_support_api_query) if ipc_support_api_query >= 0 else None
+ else:
+ self.is_ipc_supported = False
  if warp.config.enable_mempools_at_init:
  # enable if supported
  self.is_mempool_enabled = self.is_mempool_supported
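Note that is_ipc_supported is tri-state after this change: True or False when the driver query succeeds, None when support cannot be determined (e.g. CUDA 11), and always False on non-Linux platforms. A small hedged sketch of how a caller might branch on it::

    import warp as wp

    device = wp.get_device("cuda:0")

    if device.is_ipc_supported:
        # supported: safe to create interprocess-capable events
        ev = wp.Event(device, interprocess=True)
    elif device.is_ipc_supported is None:
        # undetermined (e.g. CUDA 11): treat conservatively
        print(f"IPC support could not be determined on {device}")
    else:
        print(f"IPC is not supported on {device}")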
@@ -3084,6 +3194,9 @@ class Runtime:
  self.core.radix_sort_pairs_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
  self.core.radix_sort_pairs_int_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]

+ self.core.radix_sort_pairs_float_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
+ self.core.radix_sort_pairs_float_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
+
  self.core.runlength_encode_int_host.argtypes = [
  ctypes.c_uint64,
  ctypes.c_uint64,
@@ -3100,10 +3213,16 @@
  ]

  self.core.bvh_create_host.restype = ctypes.c_uint64
- self.core.bvh_create_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
+ self.core.bvh_create_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_int]

  self.core.bvh_create_device.restype = ctypes.c_uint64
- self.core.bvh_create_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
+ self.core.bvh_create_device.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_int,
+ ctypes.c_int,
+ ]

  self.core.bvh_destroy_host.argtypes = [ctypes.c_uint64]
  self.core.bvh_destroy_device.argtypes = [ctypes.c_uint64]
@@ -3119,6 +3238,7 @@
  ctypes.c_int,
  ctypes.c_int,
  ctypes.c_int,
+ ctypes.c_int,
  ]

  self.core.mesh_create_device.restype = ctypes.c_uint64
@@ -3130,6 +3250,7 @@
  ctypes.c_int,
  ctypes.c_int,
  ctypes.c_int,
+ ctypes.c_int,
  ]

  self.core.mesh_destroy_host.argtypes = [ctypes.c_uint64]
@@ -3367,6 +3488,8 @@
  self.core.cuda_device_is_uva.restype = ctypes.c_int
  self.core.cuda_device_is_mempool_supported.argtypes = [ctypes.c_int]
  self.core.cuda_device_is_mempool_supported.restype = ctypes.c_int
+ self.core.cuda_device_is_ipc_supported.argtypes = [ctypes.c_int]
+ self.core.cuda_device_is_ipc_supported.restype = ctypes.c_int
  self.core.cuda_device_set_mempool_release_threshold.argtypes = [ctypes.c_int, ctypes.c_uint64]
  self.core.cuda_device_set_mempool_release_threshold.restype = ctypes.c_int
  self.core.cuda_device_get_mempool_release_threshold.argtypes = [ctypes.c_int]
@@ -3420,6 +3543,22 @@
  self.core.cuda_set_mempool_access_enabled.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
  self.core.cuda_set_mempool_access_enabled.restype = ctypes.c_int

+ # inter-process communication
+ self.core.cuda_ipc_get_mem_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
+ self.core.cuda_ipc_get_mem_handle.restype = None
+ self.core.cuda_ipc_open_mem_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
+ self.core.cuda_ipc_open_mem_handle.restype = ctypes.c_void_p
+ self.core.cuda_ipc_close_mem_handle.argtypes = [ctypes.c_void_p]
+ self.core.cuda_ipc_close_mem_handle.restype = None
+ self.core.cuda_ipc_get_event_handle.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.POINTER(ctypes.c_char),
+ ]
+ self.core.cuda_ipc_get_event_handle.restype = None
+ self.core.cuda_ipc_open_event_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
+ self.core.cuda_ipc_open_event_handle.restype = ctypes.c_void_p
+
  self.core.cuda_stream_create.argtypes = [ctypes.c_void_p, ctypes.c_int]
  self.core.cuda_stream_create.restype = ctypes.c_void_p
  self.core.cuda_stream_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
@@ -3467,6 +3606,7 @@

  self.core.cuda_compile_program.argtypes = [
  ctypes.c_char_p, # cuda_src
+ ctypes.c_char_p, # program name
  ctypes.c_int, # arch
  ctypes.c_char_p, # include_dir
  ctypes.c_int, # num_cuda_include_dirs
@@ -3475,10 +3615,13 @@
  ctypes.c_bool, # verbose
  ctypes.c_bool, # verify_fp
  ctypes.c_bool, # fast_math
+ ctypes.c_bool, # fuse_fp
+ ctypes.c_bool, # lineinfo
  ctypes.c_char_p, # output_path
  ctypes.c_size_t, # num_ltoirs
  ctypes.POINTER(ctypes.c_char_p), # ltoirs
  ctypes.POINTER(ctypes.c_size_t), # ltoir_sizes
+ ctypes.POINTER(ctypes.c_int), # ltoir_input_types, each of type nvJitLinkInputType
  ]
  self.core.cuda_compile_program.restype = ctypes.c_size_t

@@ -3518,6 +3661,22 @@
  ]
  self.core.cuda_compile_dot.restype = ctypes.c_bool

+ self.core.cuda_compile_solver.argtypes = [
+ ctypes.c_char_p, # universal fatbin
+ ctypes.c_char_p, # lto
+ ctypes.c_char_p, # function name
+ ctypes.c_int, # num include dirs
+ ctypes.POINTER(ctypes.c_char_p), # include dirs
+ ctypes.c_char_p, # mathdx include dir
+ ctypes.c_int, # arch
+ ctypes.c_int, # M
+ ctypes.c_int, # N
+ ctypes.c_int, # precision
+ ctypes.c_int, # fill_mode
+ ctypes.c_int, # num threads
+ ]
+ self.core.cuda_compile_fft.restype = ctypes.c_bool
+
  self.core.cuda_load_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
  self.core.cuda_load_module.restype = ctypes.c_void_p

@@ -4868,6 +5027,40 @@ def from_numpy(
  )


+ def event_from_ipc_handle(handle, device: "Devicelike" = None) -> Event:
+ """Create an event from an IPC handle.
+
+ Args:
+ handle: The interprocess event handle for an existing CUDA event.
+ device (Devicelike): Device to associate with the array.
+
+ Returns:
+ An event created from the interprocess event handle ``handle``.
+
+ Raises:
+ RuntimeError: IPC is not supported on ``device``.
+ """
+
+ try:
+ # Performance note: try first, ask questions later
+ device = warp.context.runtime.get_device(device)
+ except Exception:
+ # Fallback to using the public API for retrieving the device,
+ # which takes take of initializing Warp if needed.
+ device = warp.context.get_device(device)
+
+ if device.is_ipc_supported is False:
+ raise RuntimeError(f"IPC is not supported on device {device}.")
+
+ event = Event(
+ device=device, cuda_event=warp.context.runtime.core.cuda_ipc_open_event_handle(device.context, handle)
+ )
+ # Events created from IPC handles must be freed with cuEventDestroy
+ event.owner = True
+
+ return event
+
+
  # given a kernel destination argument type and a value convert
  # to a c-type that can be passed to a kernel
  def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
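Together with Event.ipc_handle() above, event_from_ipc_handle() completes the round trip: one process exports a 64-byte handle, another opens it and synchronizes against the original event. A hedged sketch using multiprocessing, assuming the function is re-exported at the top level as wp.event_from_ipc_handle (it is defined in warp.context) and that IPC is available (Linux only)::

    import multiprocessing as mp

    import warp as wp


    def child(queue):
        wp.init()
        handle = queue.get()  # 64-byte handle produced by the parent
        ev = wp.event_from_ipc_handle(handle, device="cuda:0")
        # make the child's stream wait on work recorded in the parent
        wp.get_stream("cuda:0").wait_event(ev)


    if __name__ == "__main__":
        wp.init()
        ev = wp.Event("cuda:0", interprocess=True)
        wp.record_event(ev)

        ctx = mp.get_context("spawn")
        queue = ctx.Queue()
        proc = ctx.Process(target=child, args=(queue,))
        proc.start()
        queue.put(ev.ipc_handle())
        proc.join()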
@@ -4949,6 +5142,9 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):

  # try to convert to a value type (vec3, mat33, etc)
  elif issubclass(arg_type, ctypes.Array):
+ # simple value types don't have gradient arrays, but native built-in signatures still expect a non-null adjoint value of the correct type
+ if value is None and adjoint:
+ return arg_type(0)
  if warp.types.types_equal(type(value), arg_type):
  return value
  else:
@@ -4958,9 +5154,6 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
  except Exception as e:
  raise ValueError(f"Failed to convert argument for param {arg_name} to {type_str(arg_type)}") from e

- elif isinstance(value, bool):
- return ctypes.c_bool(value)
-
  elif isinstance(value, arg_type):
  try:
  # try to pack as a scalar type
@@ -4975,6 +5168,9 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
  ) from e

  else:
+ # scalar args don't have gradient arrays, but native built-in signatures still expect a non-null scalar adjoint
+ if value is None and adjoint:
+ return arg_type._type_(0)
  try:
  # try to pack as a scalar type
  if arg_type is warp.types.float16:
@@ -4992,8 +5188,23 @@
  # represents all data required for a kernel launch
  # so that launches can be replayed quickly, use `wp.launch(..., record_cmd=True)`
  class Launch:
+ """Represents all data required for a kernel launch so that launches can be replayed quickly.
+
+ Users should not directly instantiate this class, instead use
+ ``wp.launch(..., record_cmd=True)`` to record a launch.
+ """
+
  def __init__(
- self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None, max_blocks=0, block_dim=256
+ self,
+ kernel,
+ device: Device,
+ hooks: Optional[KernelHooks] = None,
+ params: Optional[Sequence[Any]] = None,
+ params_addr: Optional[Sequence[ctypes.c_void_p]] = None,
+ bounds: Optional[launch_bounds_t] = None,
+ max_blocks: int = 0,
+ block_dim: int = 256,
+ adjoint: bool = False,
  ):
  # retain the module executable so it doesn't get unloaded
  self.module_exec = kernel.module.load(device)
@@ -5006,13 +5217,14 @@

  # if not specified set a zero bound
  if not bounds:
- bounds = warp.types.launch_bounds_t(0)
+ bounds = launch_bounds_t(0)

  # if not specified then build a list of default value params for args
  if not params:
  params = []
  params.append(bounds)

+ # Pack forward parameters
  for a in kernel.adj.args:
  if isinstance(a.type, warp.types.array):
  params.append(a.type.__ctype__())
@@ -5021,6 +5233,18 @@
  else:
  params.append(pack_arg(kernel, a.type, a.label, 0, device, False))

+ # Pack adjoint parameters if adjoint=True
+ if adjoint:
+ for a in kernel.adj.args:
+ if isinstance(a.type, warp.types.array):
+ params.append(a.type.__ctype__())
+ elif isinstance(a.type, warp.codegen.Struct):
+ params.append(a.type().__ctype__())
+ else:
+ # For primitive types in adjoint mode, initialize with 0
+ params.append(pack_arg(kernel, a.type, a.label, 0, device, True))
+
+ # Create array of parameter addresses
  kernel_args = [ctypes.c_void_p(ctypes.addressof(x)) for x in params]
  kernel_params = (ctypes.c_void_p * len(kernel_args))(*kernel_args)

@@ -5030,13 +5254,30 @@
  self.hooks = hooks
  self.params = params
  self.params_addr = params_addr
- self.device = device
- self.bounds = bounds
- self.max_blocks = max_blocks
- self.block_dim = block_dim
+ self.device: Device = device
+ """The device to launch on.
+ This should not be changed after the launch object is created.
+ """
+
+ self.bounds: launch_bounds_t = bounds
+ """The launch bounds. Update with :meth:`set_dim`."""

- def set_dim(self, dim):
- self.bounds = warp.types.launch_bounds_t(dim)
+ self.max_blocks: int = max_blocks
+ """The maximum number of CUDA thread blocks to use."""
+
+ self.block_dim: int = block_dim
+ """The number of threads per block."""
+
+ self.adjoint: bool = adjoint
+ """Whether to run the adjoint kernel instead of the forward kernel."""
+
+ def set_dim(self, dim: Union[int, List[int], Tuple[int, ...]]):
+ """Set the launch dimensions.
+
+ Args:
+ dim: The dimensions of the launch.
+ """
+ self.bounds = launch_bounds_t(dim)

  # launch bounds always at index 0
  self.params[0] = self.bounds
@@ -5045,22 +5286,36 @@
  if self.params_addr:
  self.params_addr[0] = ctypes.c_void_p(ctypes.addressof(self.bounds))

- # set kernel param at an index, will convert to ctype as necessary
- def set_param_at_index(self, index, value):
+ def set_param_at_index(self, index: int, value: Any, adjoint: bool = False):
+ """Set a kernel parameter at an index.
+
+ Args:
+ index: The index of the param to set.
+ value: The value to set the param to.
+ """
  arg_type = self.kernel.adj.args[index].type
  arg_name = self.kernel.adj.args[index].label

- carg = pack_arg(self.kernel, arg_type, arg_name, value, self.device, False)
+ carg = pack_arg(self.kernel, arg_type, arg_name, value, self.device, adjoint)

- self.params[index + 1] = carg
+ if adjoint:
+ params_index = index + len(self.kernel.adj.args) + 1
+ else:
+ params_index = index + 1
+
+ self.params[params_index] = carg

  # for CUDA kernels we need to update the address to each arg
  if self.params_addr:
- self.params_addr[index + 1] = ctypes.c_void_p(ctypes.addressof(carg))
+ self.params_addr[params_index] = ctypes.c_void_p(ctypes.addressof(carg))
+
+ def set_param_at_index_from_ctype(self, index: int, value: Union[ctypes.Structure, int, float]):
+ """Set a kernel parameter at an index without any type conversion.

- # set kernel param at an index without any type conversion
- # args must be passed as ctypes or basic int / float types
- def set_param_at_index_from_ctype(self, index, value):
+ Args:
+ index: The index of the param to set.
+ value: The value to set the param to.
+ """
  if isinstance(value, ctypes.Structure):
  # not sure how to directly assign struct->struct without reallocating using ctypes
  self.params[index + 1] = value
@@ -5072,32 +5327,62 @@
  else:
  self.params[index + 1].__init__(value)

- # set kernel param by argument name
- def set_param_by_name(self, name, value):
+ def set_param_by_name(self, name: str, value: Any, adjoint: bool = False):
+ """Set a kernel parameter by argument name.
+
+ Args:
+ name: The name of the argument to set.
+ value: The value to set the argument to.
+ adjoint: If ``True``, set the adjoint of this parameter instead of the forward parameter.
+ """
  for i, arg in enumerate(self.kernel.adj.args):
  if arg.label == name:
- self.set_param_at_index(i, value)
+ self.set_param_at_index(i, value, adjoint)
+ return

- # set kernel param by argument name with no type conversions
- def set_param_by_name_from_ctype(self, name, value):
+ raise ValueError(f"Argument '{name}' not found in kernel '{self.kernel.key}'")
+
+ def set_param_by_name_from_ctype(self, name: str, value: ctypes.Structure):
+ """Set a kernel parameter by argument name with no type conversions.
+
+ Args:
+ name: The name of the argument to set.
+ value: The value to set the argument to.
+ """
  # lookup argument index
  for i, arg in enumerate(self.kernel.adj.args):
  if arg.label == name:
  self.set_param_at_index_from_ctype(i, value)

- # set all params
- def set_params(self, values):
+ def set_params(self, values: Sequence[Any]):
+ """Set all parameters.
+
+ Args:
+ values: A list of values to set the params to.
+ """
  for i, v in enumerate(values):
  self.set_param_at_index(i, v)

- # set all params without performing type-conversions
- def set_params_from_ctypes(self, values):
+ def set_params_from_ctypes(self, values: Sequence[ctypes.Structure]):
+ """Set all parameters without performing type-conversions.
+
+ Args:
+ values: A list of ctypes or basic int / float types.
+ """
  for i, v in enumerate(values):
  self.set_param_at_index_from_ctype(i, v)

- def launch(self, stream=None) -> Any:
+ def launch(self, stream: Optional[Stream] = None) -> None:
+ """Launch the kernel.
+
+ Args:
+ stream: The stream to launch on.
+ """
  if self.device.is_cpu:
- self.hooks.forward(*self.params)
+ if self.adjoint:
+ self.hooks.backward(*self.params)
+ else:
+ self.hooks.forward(*self.params)
  else:
  if stream is None:
  stream = self.device.stream
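With the adjoint-aware parameter slots and the documented set_param_* methods above, a launch can be recorded once and then replayed with updated arguments, which is the intended use of wp.launch(..., record_cmd=True). A hedged sketch following the docstrings in this hunk (the kernel and values are illustrative)::

    import warp as wp

    @wp.kernel
    def scale(a: wp.array(dtype=float), s: float):
        i = wp.tid()
        a[i] = a[i] * s

    a = wp.ones(1024, dtype=float, device="cuda:0")

    # record the launch instead of executing it immediately
    cmd = wp.launch(scale, dim=a.size, inputs=[a, 2.0], record_cmd=True)

    # replay it several times, updating only the scalar argument by name
    for s in (2.0, 3.0, 4.0):
        cmd.set_param_by_name("s", s)
        cmd.launch()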
@@ -5110,32 +5395,44 @@
  if graph is not None:
  graph.retain_module_exec(self.module_exec)

- runtime.core.cuda_launch_kernel(
- self.device.context,
- self.hooks.forward,
- self.bounds.size,
- self.max_blocks,
- self.block_dim,
- self.hooks.forward_smem_bytes,
- self.params_addr,
- stream.cuda_stream,
- )
+ if self.adjoint:
+ runtime.core.cuda_launch_kernel(
+ self.device.context,
+ self.hooks.backward,
+ self.bounds.size,
+ self.max_blocks,
+ self.block_dim,
+ self.hooks.backward_smem_bytes,
+ self.params_addr,
+ stream.cuda_stream,
+ )
+ else:
+ runtime.core.cuda_launch_kernel(
+ self.device.context,
+ self.hooks.forward,
+ self.bounds.size,
+ self.max_blocks,
+ self.block_dim,
+ self.hooks.forward_smem_bytes,
+ self.params_addr,
+ stream.cuda_stream,
+ )


  def launch(
  kernel,
- dim: Tuple[int],
+ dim: Union[int, Sequence[int]],
  inputs: Sequence = [],
  outputs: Sequence = [],
  adj_inputs: Sequence = [],
  adj_outputs: Sequence = [],
  device: Devicelike = None,
- stream: Stream = None,
- adjoint=False,
- record_tape=True,
- record_cmd=False,
- max_blocks=0,
- block_dim=256,
+ stream: Optional[Stream] = None,
+ adjoint: bool = False,
+ record_tape: bool = True,
+ record_cmd: bool = False,
+ max_blocks: int = 0,
+ block_dim: int = 256,
  ):
  """Launch a Warp kernel on the target device

@@ -5143,18 +5440,23 @@ def launch(

  Args:
  kernel: The name of a Warp kernel function, decorated with the ``@wp.kernel`` decorator
- dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints with max of 4 dimensions
+ dim: The number of threads to launch the kernel, can be an integer or a
+ sequence of integers with a maximum of 4 dimensions.
  inputs: The input parameters to the kernel (optional)
  outputs: The output parameters (optional)
  adj_inputs: The adjoint inputs (optional)
  adj_outputs: The adjoint outputs (optional)
- device: The device to launch on (optional)
- stream: The stream to launch on (optional)
- adjoint: Whether to run forward or backward pass (typically use False)
- record_tape: When true the launch will be recorded the global wp.Tape() object when present
- record_cmd: When True the launch will be returned as a ``Launch`` command object, the launch will not occur until the user calls ``cmd.launch()``
- max_blocks: The maximum number of CUDA thread blocks to use. Only has an effect for CUDA kernel launches.
- If negative or zero, the maximum hardware value will be used.
+ device: The device to launch on.
+ stream: The stream to launch on.
+ adjoint: Whether to run forward or backward pass (typically use ``False``).
+ record_tape: When ``True``, the launch will be recorded the global
+ :class:`wp.Tape() <warp.Tape>` object when present.
+ record_cmd: When ``True``, the launch will return a :class:`Launch`
+ object. The launch will not occur until the user calls
+ :meth:`Launch.launch()`.
+ max_blocks: The maximum number of CUDA thread blocks to use.
+ Only has an effect for CUDA kernel launches.
+ If negative or zero, the maximum hardware value will be used.
  block_dim: The number of threads per block.
  """

@@ -5175,7 +5477,7 @@
  print(f"kernel: {kernel.key} dim: {dim} inputs: {inputs} outputs: {outputs} device: {device}")

  # construct launch bounds
- bounds = warp.types.launch_bounds_t(dim)
+ bounds = launch_bounds_t(dim)

  if bounds.size > 0:
  # first param is the number of threads
@@ -5232,6 +5534,17 @@
  f"Failed to find backward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
  )

+ if record_cmd:
+ launch = Launch(
+ kernel=kernel,
+ hooks=hooks,
+ params=params,
+ params_addr=None,
+ bounds=bounds,
+ device=device,
+ adjoint=adjoint,
+ )
+ return launch
  hooks.backward(*params)

  else:
@@ -5242,7 +5555,13 @@

  if record_cmd:
  launch = Launch(
- kernel=kernel, hooks=hooks, params=params, params_addr=None, bounds=bounds, device=device
+ kernel=kernel,
+ hooks=hooks,
+ params=params,
+ params_addr=None,
+ bounds=bounds,
+ device=device,
+ adjoint=adjoint,
  )
  return launch
  else:
@@ -5269,16 +5588,30 @@
  f"Failed to find backward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
  )

- runtime.core.cuda_launch_kernel(
- device.context,
- hooks.backward,
- bounds.size,
- max_blocks,
- block_dim,
- hooks.backward_smem_bytes,
- kernel_params,
- stream.cuda_stream,
- )
+ if record_cmd:
+ launch = Launch(
+ kernel=kernel,
+ hooks=hooks,
+ params=params,
+ params_addr=kernel_params,
+ bounds=bounds,
+ device=device,
+ max_blocks=max_blocks,
+ block_dim=block_dim,
+ adjoint=adjoint,
+ )
+ return launch
+ else:
+ runtime.core.cuda_launch_kernel(
+ device.context,
+ hooks.backward,
+ bounds.size,
+ max_blocks,
+ block_dim,
+ hooks.backward_smem_bytes,
+ kernel_params,
+ stream.cuda_stream,
+ )

  else:
  if hooks.forward is None:
@@ -5298,7 +5631,6 @@
  block_dim=block_dim,
  )
  return launch
-
  else:
  # launch
  runtime.core.cuda_launch_kernel(
@@ -6034,14 +6366,19 @@ def export_functions_rst(file): # pragma: no cover
  # build dictionary of all functions by group
  groups = {}

- for _k, f in builtin_functions.items():
+ functions = list(builtin_functions.values())
+
+ for f in functions:
  # build dict of groups
  if f.group not in groups:
  groups[f.group] = []

- # append all overloads to the group
- for o in f.overloads:
- groups[f.group].append(o)
+ if hasattr(f, "overloads"):
+ # append all overloads to the group
+ for o in f.overloads:
+ groups[f.group].append(o)
+ else:
+ groups[f.group].append(f)

  # Keep track of what function and query types have been written
  written_functions = set()
@@ -6061,6 +6398,10 @@ def export_functions_rst(file): # pragma: no cover
  print("---------------", file=file)

  for f in g:
+ if f.func:
+ # f is a Warp function written in Python, we can use autofunction
+ print(f".. autofunction:: {f.func.__module__}.{f.key}", file=file)
+ continue
  for f_prefix, query_type in query_types:
  if f.key.startswith(f_prefix) and query_type not in written_query_types:
  print(f".. autoclass:: {query_type}", file=file)
@@ -6118,24 +6459,32 @@ def export_stubs(file): # pragma: no cover
  print(header, file=file)
  print(file=file)

- for k, g in builtin_functions.items():
- for f in g.overloads:
- args = ", ".join(f"{k}: {type_str(v)}" for k, v in f.input_types.items())
+ def add_stub(f):
+ args = ", ".join(f"{k}: {type_str(v)}" for k, v in f.input_types.items())

- return_str = ""
+ return_str = ""

- if f.hidden: # or f.generic:
- continue
+ if f.hidden: # or f.generic:
+ return

+ return_type = f.value_type
+ if f.value_func:
  return_type = f.value_func(None, None)
- if return_type:
- return_str = " -> " + type_str(return_type)
-
- print("@over", file=file)
- print(f"def {f.key}({args}){return_str}:", file=file)
- print(f' """{f.doc}', file=file)
- print(' """', file=file)
- print(" ...\n\n", file=file)
+ if return_type:
+ return_str = " -> " + type_str(return_type)
+
+ print("@over", file=file)
+ print(f"def {f.key}({args}){return_str}:", file=file)
+ print(f' """{f.doc}', file=file)
+ print(' """', file=file)
+ print(" ...\n\n", file=file)
+
+ for g in builtin_functions.values():
+ if hasattr(g, "overloads"):
+ for f in g.overloads:
+ add_stub(f)
+ else:
+ add_stub(g)


  def export_builtins(file: io.TextIOBase): # pragma: no cover
@@ -6161,6 +6510,8 @@ def export_builtins(file: io.TextIOBase): # pragma: no cover
  file.write('extern "C" {\n\n')

  for k, g in builtin_functions.items():
+ if not hasattr(g, "overloads"):
+ continue
  for f in g.overloads:
  if not f.export or f.generic:
  continue