warp-lang 1.5.1__py3-none-macosx_10_13_universal2.whl → 1.6.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release has been flagged as potentially problematic.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (123)
  1. warp/__init__.py +5 -0
  2. warp/autograd.py +414 -191
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +40 -12
  6. warp/build_dll.py +13 -6
  7. warp/builtins.py +1076 -480
  8. warp/codegen.py +240 -119
  9. warp/config.py +1 -1
  10. warp/context.py +298 -84
  11. warp/examples/assets/square_cloth.usd +0 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +27 -18
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
  14. warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
  15. warp/examples/core/example_torch.py +18 -34
  16. warp/examples/fem/example_apic_fluid.py +1 -0
  17. warp/examples/fem/example_mixed_elasticity.py +1 -1
  18. warp/examples/optim/example_bounce.py +1 -1
  19. warp/examples/optim/example_cloth_throw.py +1 -1
  20. warp/examples/optim/example_diffray.py +4 -15
  21. warp/examples/optim/example_drone.py +1 -1
  22. warp/examples/optim/example_softbody_properties.py +392 -0
  23. warp/examples/optim/example_trajectory.py +1 -3
  24. warp/examples/optim/example_walker.py +5 -0
  25. warp/examples/sim/example_cartpole.py +0 -2
  26. warp/examples/sim/example_cloth_self_contact.py +260 -0
  27. warp/examples/sim/example_granular_collision_sdf.py +4 -5
  28. warp/examples/sim/example_jacobian_ik.py +0 -2
  29. warp/examples/sim/example_quadruped.py +5 -2
  30. warp/examples/tile/example_tile_cholesky.py +79 -0
  31. warp/examples/tile/example_tile_convolution.py +2 -2
  32. warp/examples/tile/example_tile_fft.py +2 -2
  33. warp/examples/tile/example_tile_filtering.py +3 -3
  34. warp/examples/tile/example_tile_matmul.py +4 -4
  35. warp/examples/tile/example_tile_mlp.py +12 -12
  36. warp/examples/tile/example_tile_nbody.py +180 -0
  37. warp/examples/tile/example_tile_walker.py +319 -0
  38. warp/math.py +147 -0
  39. warp/native/array.h +12 -0
  40. warp/native/builtin.h +0 -1
  41. warp/native/bvh.cpp +149 -70
  42. warp/native/bvh.cu +287 -68
  43. warp/native/bvh.h +195 -85
  44. warp/native/clang/clang.cpp +5 -1
  45. warp/native/cuda_util.cpp +35 -0
  46. warp/native/cuda_util.h +5 -0
  47. warp/native/exports.h +40 -40
  48. warp/native/intersect.h +17 -0
  49. warp/native/mat.h +41 -0
  50. warp/native/mathdx.cpp +19 -0
  51. warp/native/mesh.cpp +25 -8
  52. warp/native/mesh.cu +153 -101
  53. warp/native/mesh.h +482 -403
  54. warp/native/quat.h +40 -0
  55. warp/native/solid_angle.h +7 -0
  56. warp/native/sort.cpp +85 -0
  57. warp/native/sort.cu +34 -0
  58. warp/native/sort.h +3 -1
  59. warp/native/spatial.h +11 -0
  60. warp/native/tile.h +1185 -664
  61. warp/native/tile_reduce.h +8 -6
  62. warp/native/vec.h +41 -0
  63. warp/native/warp.cpp +8 -1
  64. warp/native/warp.cu +263 -40
  65. warp/native/warp.h +19 -5
  66. warp/optim/linear.py +22 -4
  67. warp/render/render_opengl.py +124 -59
  68. warp/sim/__init__.py +6 -1
  69. warp/sim/collide.py +270 -26
  70. warp/sim/integrator_euler.py +25 -7
  71. warp/sim/integrator_featherstone.py +154 -35
  72. warp/sim/integrator_vbd.py +842 -40
  73. warp/sim/model.py +111 -53
  74. warp/stubs.py +248 -115
  75. warp/tape.py +28 -30
  76. warp/tests/aux_test_module_unload.py +15 -0
  77. warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
  78. warp/tests/test_array.py +74 -0
  79. warp/tests/test_assert.py +242 -0
  80. warp/tests/test_codegen.py +14 -61
  81. warp/tests/test_collision.py +2 -2
  82. warp/tests/test_examples.py +9 -0
  83. warp/tests/test_grad_debug.py +87 -2
  84. warp/tests/test_hash_grid.py +1 -1
  85. warp/tests/test_ipc.py +116 -0
  86. warp/tests/test_mat.py +138 -167
  87. warp/tests/test_math.py +47 -1
  88. warp/tests/test_matmul.py +11 -7
  89. warp/tests/test_matmul_lite.py +4 -4
  90. warp/tests/test_mesh.py +84 -60
  91. warp/tests/test_mesh_query_aabb.py +165 -0
  92. warp/tests/test_mesh_query_point.py +328 -286
  93. warp/tests/test_mesh_query_ray.py +134 -121
  94. warp/tests/test_mlp.py +2 -2
  95. warp/tests/test_operators.py +43 -0
  96. warp/tests/test_overwrite.py +2 -2
  97. warp/tests/test_quat.py +77 -0
  98. warp/tests/test_reload.py +29 -0
  99. warp/tests/test_sim_grad_bounce_linear.py +204 -0
  100. warp/tests/test_static.py +16 -0
  101. warp/tests/test_tape.py +25 -0
  102. warp/tests/test_tile.py +134 -191
  103. warp/tests/test_tile_load.py +356 -0
  104. warp/tests/test_tile_mathdx.py +61 -8
  105. warp/tests/test_tile_mlp.py +17 -17
  106. warp/tests/test_tile_reduce.py +24 -18
  107. warp/tests/test_tile_shared_memory.py +66 -17
  108. warp/tests/test_tile_view.py +165 -0
  109. warp/tests/test_torch.py +35 -0
  110. warp/tests/test_utils.py +36 -24
  111. warp/tests/test_vec.py +110 -0
  112. warp/tests/unittest_suites.py +29 -4
  113. warp/tests/unittest_utils.py +30 -11
  114. warp/thirdparty/unittest_parallel.py +2 -2
  115. warp/types.py +409 -99
  116. warp/utils.py +9 -5
  117. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/METADATA +68 -44
  118. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/RECORD +121 -110
  119. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/WHEEL +1 -1
  120. warp/examples/benchmarks/benchmark_tile.py +0 -179
  121. warp/native/tile_gemm.h +0 -341
  122. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/LICENSE.md +0 -0
  123. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -5,6 +5,8 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
+ from __future__ import annotations
9
+
8
10
  import ast
9
11
  import ctypes
10
12
  import errno
@@ -393,7 +395,8 @@ class Function:
393
395
  if not warp.codegen.func_match_args(f, arg_types, kwarg_types):
394
396
  continue
395
397
 
396
- if len(f.input_types) != len(arg_types):
398
+ acceptable_arg_num = len(f.input_types) - len(f.defaults) <= len(arg_types) <= len(f.input_types)
399
+ if not acceptable_arg_num:
397
400
  continue
398
401
 
399
402
  # try to match the given types to the function template types
@@ -410,6 +413,10 @@ class Function:
410
413
 
411
414
  arg_names = f.input_types.keys()
412
415
  overload_annotations = dict(zip(arg_names, arg_types))
416
+ # add defaults
417
+ for k, d in f.defaults.items():
418
+ if k not in overload_annotations:
419
+ overload_annotations[k] = warp.codegen.strip_reference(warp.codegen.get_arg_type(d))
413
420
 
414
421
  ovl = shallowcopy(f)
415
422
  ovl.adj = warp.codegen.Adjoint(f.func, overload_annotations)
@@ -753,8 +760,15 @@ def func(f):
753
760
  scope_locals = inspect.currentframe().f_back.f_locals
754
761
 
755
762
  m = get_module(f.__module__)
763
+ doc = getattr(f, "__doc__", "") or ""
756
764
  Function(
757
- func=f, key=name, namespace="", module=m, value_func=None, scope_locals=scope_locals
765
+ func=f,
766
+ key=name,
767
+ namespace="",
768
+ module=m,
769
+ value_func=None,
770
+ scope_locals=scope_locals,
771
+ doc=doc.strip(),
758
772
  ) # value_type not known yet, will be inferred during Adjoint.build()
759
773
 
760
774
  # use the top of the list of overloads for this key
@@ -1059,7 +1073,8 @@ def overload(kernel, arg_types=Union[None, Dict[str, Any], List[Any]]):
1059
1073
  raise RuntimeError("wp.overload() called with invalid argument!")
1060
1074
 
1061
1075
 
1062
- builtin_functions = {}
1076
+ # native functions that are part of the Warp API
1077
+ builtin_functions: Dict[str, Function] = {}
1063
1078
 
1064
1079
 
1065
1080
  def get_generic_vtypes():
@@ -1328,6 +1343,28 @@ def add_builtin(
1328
1343
  setattr(warp, key, func)
1329
1344
 
1330
1345
 
1346
+ def register_api_function(
1347
+ function: Function,
1348
+ group: str = "Other",
1349
+ hidden=False,
1350
+ ):
1351
+ """Main entry point to register a Warp Python function to be part of the Warp API and appear in the documentation.
1352
+
1353
+ Args:
1354
+ function (Function): Warp function to be registered.
1355
+ group (str): Classification used for the documentation.
1356
+ input_types (Mapping[str, Any]): Signature of the user-facing function.
1357
+ Variadic arguments are supported by prefixing the parameter names
1358
+ with asterisks as in `*args` and `**kwargs`. Generic arguments are
1359
+ supported with types such as `Any`, `Float`, `Scalar`, etc.
1360
+ value_type (Any): Type returned by the function.
1361
+ hidden (bool): Whether to add that function into the documentation.
1362
+ """
1363
+ function.group = group
1364
+ function.hidden = hidden
1365
+ builtin_functions[function.key] = function
1366
+
1367
+
1331
1368
  # global dictionary of modules
1332
1369
  user_modules = {}
1333
1370
 
@@ -1561,6 +1598,7 @@ class ModuleBuilder:
1561
1598
  self.options = options
1562
1599
  self.module = module
1563
1600
  self.deferred_functions = []
1601
+ self.fatbins = {} # map from <some identifier> to fatbins, to add at link time
1564
1602
  self.ltoirs = {} # map from lto symbol to lto binary
1565
1603
  self.ltoirs_decl = {} # map from lto symbol to lto forward declaration
1566
1604
 
@@ -1675,7 +1713,7 @@ class ModuleBuilder:
1675
1713
 
1676
1714
  for kernel in self.kernels:
1677
1715
  source += warp.codegen.codegen_kernel(kernel, device=device, options=self.options)
1678
- source += warp.codegen.codegen_module(kernel, device=device)
1716
+ source += warp.codegen.codegen_module(kernel, device=device, options=self.options)
1679
1717
 
1680
1718
  # add headers
1681
1719
  if device == "cpu":
@@ -1728,20 +1766,26 @@ class ModuleExec:
1728
1766
 
1729
1767
  name = kernel.get_mangled_name()
1730
1768
 
1769
+ options = dict(kernel.module.options)
1770
+ options.update(kernel.options)
1771
+
1731
1772
  if self.device.is_cuda:
1732
1773
  forward_name = name + "_cuda_kernel_forward"
1733
1774
  forward_kernel = runtime.core.cuda_get_kernel(
1734
1775
  self.device.context, self.handle, forward_name.encode("utf-8")
1735
1776
  )
1736
1777
 
1737
- backward_name = name + "_cuda_kernel_backward"
1738
- backward_kernel = runtime.core.cuda_get_kernel(
1739
- self.device.context, self.handle, backward_name.encode("utf-8")
1740
- )
1778
+ if options["enable_backward"]:
1779
+ backward_name = name + "_cuda_kernel_backward"
1780
+ backward_kernel = runtime.core.cuda_get_kernel(
1781
+ self.device.context, self.handle, backward_name.encode("utf-8")
1782
+ )
1783
+ else:
1784
+ backward_kernel = None
1741
1785
 
1742
1786
  # look up the required shared memory size for each kernel from module metadata
1743
1787
  forward_smem_bytes = self.meta[forward_name + "_smem_bytes"]
1744
- backward_smem_bytes = self.meta[backward_name + "_smem_bytes"]
1788
+ backward_smem_bytes = self.meta[backward_name + "_smem_bytes"] if options["enable_backward"] else 0
1745
1789
 
1746
1790
  # configure kernels maximum shared memory size
1747
1791
  max_smem_bytes = runtime.core.cuda_get_max_shared_memory(self.device.context)
@@ -1751,9 +1795,6 @@ class ModuleExec:
1751
1795
  f"Warning: Failed to configure kernel dynamic shared memory for this device, tried to configure {forward_name} kernel for {forward_smem_bytes} bytes, but maximum available is {max_smem_bytes}"
1752
1796
  )
1753
1797
 
1754
- options = dict(kernel.module.options)
1755
- options.update(kernel.options)
1756
-
1757
1798
  if options["enable_backward"] and not runtime.core.cuda_configure_kernel_shared_memory(
1758
1799
  backward_kernel, backward_smem_bytes
1759
1800
  ):
@@ -1768,9 +1809,14 @@ class ModuleExec:
1768
1809
  forward = (
1769
1810
  func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_forward").encode("utf-8"))) or None
1770
1811
  )
1771
- backward = (
1772
- func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8"))) or None
1773
- )
1812
+
1813
+ if options["enable_backward"]:
1814
+ backward = (
1815
+ func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8")))
1816
+ or None
1817
+ )
1818
+ else:
1819
+ backward = None
1774
1820
 
1775
1821
  hooks = KernelHooks(forward, backward)
1776
1822
 
@@ -1803,13 +1849,13 @@ class Module:
1803
1849
  self._live_kernels = weakref.WeakSet()
1804
1850
 
1805
1851
  # executable modules currently loaded
1806
- self.execs = {} # (device.context: ModuleExec)
1852
+ self.execs = {} # ((device.context, blockdim): ModuleExec)
1807
1853
 
1808
1854
  # set of device contexts where the build has failed
1809
1855
  self.failed_builds = set()
1810
1856
 
1811
- # hash data, including the module hash
1812
- self.hasher = None
1857
+ # hash data, including the module hash. Module may store multiple hashes (one per block_dim used)
1858
+ self.hashers = {}
1813
1859
 
1814
1860
  # LLVM executable modules are identified using strings. Since it's possible for multiple
1815
1861
  # executable versions to be loaded at the same time, we need a way to ensure uniqueness.
@@ -1822,6 +1868,8 @@ class Module:
1822
1868
  "max_unroll": warp.config.max_unroll,
1823
1869
  "enable_backward": warp.config.enable_backward,
1824
1870
  "fast_math": False,
1871
+ "fuse_fp": True,
1872
+ "lineinfo": False,
1825
1873
  "cuda_output": None, # supported values: "ptx", "cubin", or None (automatic)
1826
1874
  "mode": warp.config.mode,
1827
1875
  "block_dim": 256,
@@ -1965,28 +2013,27 @@ class Module:
1965
2013
 
1966
2014
  def hash_module(self):
1967
2015
  # compute latest hash
1968
- self.hasher = ModuleHasher(self)
1969
- return self.hasher.get_module_hash()
2016
+ block_dim = self.options["block_dim"]
2017
+ self.hashers[block_dim] = ModuleHasher(self)
2018
+ return self.hashers[block_dim].get_module_hash()
1970
2019
 
1971
2020
  def load(self, device, block_dim=None) -> ModuleExec:
1972
2021
  device = runtime.get_device(device)
1973
2022
 
1974
- # re-compile module if tile size (blockdim) changes
1975
- # todo: it would be better to have a method such as `module.get_kernel(block_dim=N)`
1976
- # that can return a single kernel instance with a given block size
2023
+ # update module options if launching with a new block dim
1977
2024
  if block_dim is not None:
1978
- if self.options["block_dim"] != block_dim:
1979
- self.unload()
1980
2025
  self.options["block_dim"] = block_dim
1981
2026
 
2027
+ active_block_dim = self.options["block_dim"]
2028
+
1982
2029
  # compute the hash if needed
1983
- if self.hasher is None:
1984
- self.hasher = ModuleHasher(self)
2030
+ if active_block_dim not in self.hashers:
2031
+ self.hashers[active_block_dim] = ModuleHasher(self)
1985
2032
 
1986
2033
  # check if executable module is already loaded and not stale
1987
- exec = self.execs.get(device.context)
2034
+ exec = self.execs.get((device.context, active_block_dim))
1988
2035
  if exec is not None:
1989
- if exec.module_hash == self.hasher.module_hash:
2036
+ if exec.module_hash == self.hashers[active_block_dim].get_module_hash():
1990
2037
  return exec
1991
2038
 
1992
2039
  # quietly avoid repeated build attempts to reduce error spew
@@ -1994,10 +2041,11 @@ class Module:
1994
2041
  return None
1995
2042
 
1996
2043
  module_name = "wp_" + self.name
1997
- module_hash = self.hasher.module_hash
2044
+ module_hash = self.hashers[active_block_dim].get_module_hash()
1998
2045
 
1999
2046
  # use a unique module path using the module short hash
2000
- module_dir = os.path.join(warp.config.kernel_cache_dir, f"{module_name}_{module_hash.hex()[:7]}")
2047
+ module_name_short = f"{module_name}_{module_hash.hex()[:7]}"
2048
+ module_dir = os.path.join(warp.config.kernel_cache_dir, module_name_short)
2001
2049
 
2002
2050
  with warp.ScopedTimer(
2003
2051
  f"Module {self.name} {module_hash.hex()[:7]} load on device '{device}'", active=not warp.config.quiet
@@ -2005,7 +2053,7 @@ class Module:
2005
2053
  # -----------------------------------------------------------
2006
2054
  # determine output paths
2007
2055
  if device.is_cpu:
2008
- output_name = "module_codegen.o"
2056
+ output_name = f"{module_name_short}.o"
2009
2057
  output_arch = None
2010
2058
 
2011
2059
  elif device.is_cuda:
@@ -2025,10 +2073,10 @@ class Module:
2025
2073
 
2026
2074
  if use_ptx:
2027
2075
  output_arch = min(device.arch, warp.config.ptx_target_arch)
2028
- output_name = f"module_codegen.sm{output_arch}.ptx"
2076
+ output_name = f"{module_name_short}.sm{output_arch}.ptx"
2029
2077
  else:
2030
2078
  output_arch = device.arch
2031
- output_name = f"module_codegen.sm{output_arch}.cubin"
2079
+ output_name = f"{module_name_short}.sm{output_arch}.cubin"
2032
2080
 
2033
2081
  # final object binary path
2034
2082
  binary_path = os.path.join(module_dir, output_name)
@@ -2050,7 +2098,7 @@ class Module:
2050
2098
  # Some of the Tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
2051
2099
  "output_arch": output_arch,
2052
2100
  }
2053
- builder = ModuleBuilder(self, builder_options, hasher=self.hasher)
2101
+ builder = ModuleBuilder(self, builder_options, hasher=self.hashers[active_block_dim])
2054
2102
 
2055
2103
  # create a temporary (process unique) dir for build outputs before moving to the binary dir
2056
2104
  build_dir = os.path.join(
@@ -2066,7 +2114,7 @@ class Module:
2066
2114
  if device.is_cpu:
2067
2115
  # build
2068
2116
  try:
2069
- source_code_path = os.path.join(build_dir, "module_codegen.cpp")
2117
+ source_code_path = os.path.join(build_dir, f"{module_name_short}.cpp")
2070
2118
 
2071
2119
  # write cpp sources
2072
2120
  cpp_source = builder.codegen("cpu")
@@ -2084,6 +2132,7 @@ class Module:
2084
2132
  mode=self.options["mode"],
2085
2133
  fast_math=self.options["fast_math"],
2086
2134
  verify_fp=warp.config.verify_fp,
2135
+ fuse_fp=self.options["fuse_fp"],
2087
2136
  )
2088
2137
 
2089
2138
  except Exception as e:
@@ -2094,7 +2143,7 @@ class Module:
2094
2143
  elif device.is_cuda:
2095
2144
  # build
2096
2145
  try:
2097
- source_code_path = os.path.join(build_dir, "module_codegen.cu")
2146
+ source_code_path = os.path.join(build_dir, f"{module_name_short}.cu")
2098
2147
 
2099
2148
  # write cuda sources
2100
2149
  cu_source = builder.codegen("cuda")
@@ -2111,9 +2160,12 @@ class Module:
2111
2160
  output_arch,
2112
2161
  output_path,
2113
2162
  config=self.options["mode"],
2114
- fast_math=self.options["fast_math"],
2115
2163
  verify_fp=warp.config.verify_fp,
2164
+ fast_math=self.options["fast_math"],
2165
+ fuse_fp=self.options["fuse_fp"],
2166
+ lineinfo=self.options["lineinfo"],
2116
2167
  ltoirs=builder.ltoirs.values(),
2168
+ fatbins=builder.fatbins.values(),
2117
2169
  )
2118
2170
 
2119
2171
  except Exception as e:
@@ -2125,7 +2177,7 @@ class Module:
2125
2177
  # build meta data
2126
2178
 
2127
2179
  meta = builder.build_meta()
2128
- meta_path = os.path.join(build_dir, "module_codegen.meta")
2180
+ meta_path = os.path.join(build_dir, f"{module_name_short}.meta")
2129
2181
 
2130
2182
  with open(meta_path, "w") as meta_file:
2131
2183
  json.dump(meta, meta_file)
@@ -2189,7 +2241,7 @@ class Module:
2189
2241
  # -----------------------------------------------------------
2190
2242
  # Load CPU or CUDA binary
2191
2243
 
2192
- meta_path = os.path.join(module_dir, "module_codegen.meta")
2244
+ meta_path = os.path.join(module_dir, f"{module_name_short}.meta")
2193
2245
  with open(meta_path, "r") as meta_file:
2194
2246
  meta = json.load(meta_file)
2195
2247
 
@@ -2199,13 +2251,13 @@ class Module:
2199
2251
  self.cpu_exec_id += 1
2200
2252
  runtime.llvm.load_obj(binary_path.encode("utf-8"), module_handle.encode("utf-8"))
2201
2253
  module_exec = ModuleExec(module_handle, module_hash, device, meta)
2202
- self.execs[None] = module_exec
2254
+ self.execs[(None, active_block_dim)] = module_exec
2203
2255
 
2204
2256
  elif device.is_cuda:
2205
2257
  cuda_module = warp.build.load_cuda(binary_path, device)
2206
2258
  if cuda_module is not None:
2207
2259
  module_exec = ModuleExec(cuda_module, module_hash, device, meta)
2208
- self.execs[device.context] = module_exec
2260
+ self.execs[(device.context, active_block_dim)] = module_exec
2209
2261
  else:
2210
2262
  module_load_timer.extra_msg = " (error)"
2211
2263
  raise Exception(f"Failed to load CUDA module '{self.name}'")
@@ -2227,14 +2279,14 @@ class Module:
2227
2279
 
2228
2280
  def mark_modified(self):
2229
2281
  # clear hash data
2230
- self.hasher = None
2282
+ self.hashers = {}
2231
2283
 
2232
2284
  # clear build failures
2233
2285
  self.failed_builds = set()
2234
2286
 
2235
2287
  # lookup kernel entry points based on name, called after compilation / module load
2236
2288
  def get_kernel_hooks(self, kernel, device):
2237
- module_exec = self.execs.get(device.context)
2289
+ module_exec = self.execs.get((device.context, self.options["block_dim"]))
2238
2290
  if module_exec is not None:
2239
2291
  return module_exec.get_kernel_hooks(kernel)
2240
2292
  else:
@@ -2353,6 +2405,7 @@ class Event:
2353
2405
  DEFAULT = 0x0
2354
2406
  BLOCKING_SYNC = 0x1
2355
2407
  DISABLE_TIMING = 0x2
2408
+ INTERPROCESS = 0x4
2356
2409
 
2357
2410
  def __new__(cls, *args, **kwargs):
2358
2411
  """Creates a new event instance."""
@@ -2360,7 +2413,9 @@ class Event:
2360
2413
  instance.owner = False
2361
2414
  return instance
2362
2415
 
2363
- def __init__(self, device: "Devicelike" = None, cuda_event=None, enable_timing: bool = False):
2416
+ def __init__(
2417
+ self, device: "Devicelike" = None, cuda_event=None, enable_timing: bool = False, interprocess: bool = False
2418
+ ):
2364
2419
  """Initializes the event on a CUDA device.
2365
2420
 
2366
2421
  Args:
@@ -2372,6 +2427,12 @@ class Event:
2372
2427
  :func:`~warp.get_event_elapsed_time` can be used to measure the
2373
2428
  time between two events created with ``enable_timing=True`` and
2374
2429
  recorded onto streams.
2430
+ interprocess: If ``True`` this event may be used as an interprocess event.
2431
+
2432
+ Raises:
2433
+ RuntimeError: The event could not be created.
2434
+ ValueError: The combination of ``enable_timing=True`` and
2435
+ ``interprocess=True`` is not allowed.
2375
2436
  """
2376
2437
 
2377
2438
  device = get_device(device)
@@ -2386,11 +2447,48 @@ class Event:
2386
2447
  flags = Event.Flags.DEFAULT
2387
2448
  if not enable_timing:
2388
2449
  flags |= Event.Flags.DISABLE_TIMING
2450
+ if interprocess:
2451
+ if enable_timing:
2452
+ raise ValueError("The combination of 'enable_timing=True' and 'interprocess=True' is not allowed.")
2453
+ flags |= Event.Flags.INTERPROCESS
2454
+
2389
2455
  self.cuda_event = runtime.core.cuda_event_create(device.context, flags)
2390
2456
  if not self.cuda_event:
2391
2457
  raise RuntimeError(f"Failed to create event on device {device}")
2392
2458
  self.owner = True
2393
2459
 
2460
+ def ipc_handle(self) -> bytes:
2461
+ """Return a CUDA IPC handle of the event as a 64-byte ``bytes`` object.
2462
+
2463
+ The event must have been created with ``interprocess=True`` in order to
2464
+ obtain a valid interprocess handle.
2465
+
2466
+ IPC is currently only supported on Linux.
2467
+
2468
+ Example:
2469
+ Create an event and get its IPC handle::
2470
+
2471
+ e1 = wp.Event(interprocess=True)
2472
+ event_handle = e1.ipc_handle()
2473
+
2474
+ Raises:
2475
+ RuntimeError: Device does not support IPC.
2476
+ """
2477
+
2478
+ if self.device.is_ipc_supported is not False:
2479
+ # Allocate a buffer for the data (64-element char array)
2480
+ ipc_handle_buffer = (ctypes.c_char * 64)()
2481
+
2482
+ warp.context.runtime.core.cuda_ipc_get_event_handle(self.device.context, self.cuda_event, ipc_handle_buffer)
2483
+
2484
+ if ipc_handle_buffer.raw == bytes(64):
2485
+ warp.utils.warn("IPC event handle appears to be invalid. Was interprocess=True used?")
2486
+
2487
+ return ipc_handle_buffer.raw
2488
+
2489
+ else:
2490
+ raise RuntimeError(f"Device {self.device} does not support IPC.")
2491
+
2394
2492
  def __del__(self):
2395
2493
  if not self.owner:
2396
2494
  return
@@ -2538,23 +2636,27 @@ class Device:
2538
2636
  """A device to allocate Warp arrays and to launch kernels on.
2539
2637
 
2540
2638
  Attributes:
2541
- ordinal: A Warp-specific integer label for the device. ``-1`` for CPU devices.
2542
- name: A string label for the device. By default, CPU devices will be named according to the processor name,
2639
+ ordinal (int): A Warp-specific label for the device. ``-1`` for CPU devices.
2640
+ name (str): A label for the device. By default, CPU devices will be named according to the processor name,
2543
2641
  or ``"CPU"`` if the processor name cannot be determined.
2544
- arch: An integer representing the compute capability version number calculated as
2545
- ``10 * major + minor``. ``0`` for CPU devices.
2546
- is_uva: A boolean indicating whether the device supports unified addressing.
2642
+ arch (int): The compute capability version number calculated as ``10 * major + minor``.
2643
+ ``0`` for CPU devices.
2644
+ is_uva (bool): Indicates whether the device supports unified addressing.
2547
2645
  ``False`` for CPU devices.
2548
- is_cubin_supported: A boolean indicating whether Warp's version of NVRTC can directly
2646
+ is_cubin_supported (bool): Indicates whether Warp's version of NVRTC can directly
2549
2647
  generate CUDA binary files (cubin) for this device's architecture. ``False`` for CPU devices.
2550
- is_mempool_supported: A boolean indicating whether the device supports using the
2551
- ``cuMemAllocAsync`` and ``cuMemPool`` family of APIs for stream-ordered memory allocations. ``False`` for
2552
- CPU devices.
2553
- is_primary: A boolean indicating whether this device's CUDA context is also the
2554
- device's primary context.
2555
- uuid: A string representing the UUID of the CUDA device. The UUID is in the same format used by
2556
- ``nvidia-smi -L``. ``None`` for CPU devices.
2557
- pci_bus_id: A string identifier for the CUDA device in the format ``[domain]:[bus]:[device]``, in which
2648
+ is_mempool_supported (bool): Indicates whether the device supports using the ``cuMemAllocAsync`` and
2649
+ ``cuMemPool`` family of APIs for stream-ordered memory allocations. ``False`` for CPU devices.
2650
+ is_ipc_supported (Optional[bool]): Indicates whether the device supports IPC.
2651
+
2652
+ - ``True`` if supported.
2653
+ - ``False`` if not supported.
2654
+ - ``None`` if IPC support could not be determined (e.g. CUDA 11).
2655
+
2656
+ is_primary (bool): Indicates whether this device's CUDA context is also the device's primary context.
2657
+ uuid (str): The UUID of the CUDA device. The UUID is in the same format used by ``nvidia-smi -L``.
2658
+ ``None`` for CPU devices.
2659
+ pci_bus_id (str): An identifier for the CUDA device in the format ``[domain]:[bus]:[device]``, in which
2558
2660
  ``domain``, ``bus``, and ``device`` are all hexadecimal values. ``None`` for CPU devices.
2559
2661
  """
2560
2662
 
@@ -2587,6 +2689,7 @@ class Device:
2587
2689
  self.is_uva = False
2588
2690
  self.is_mempool_supported = False
2589
2691
  self.is_mempool_enabled = False
2692
+ self.is_ipc_supported = False # TODO: Support IPC for CPU arrays
2590
2693
  self.is_cubin_supported = False
2591
2694
  self.uuid = None
2592
2695
  self.pci_bus_id = None
@@ -2602,8 +2705,14 @@ class Device:
2602
2705
  # CUDA device
2603
2706
  self.name = runtime.core.cuda_device_get_name(ordinal).decode()
2604
2707
  self.arch = runtime.core.cuda_device_get_arch(ordinal)
2605
- self.is_uva = runtime.core.cuda_device_is_uva(ordinal)
2606
- self.is_mempool_supported = runtime.core.cuda_device_is_mempool_supported(ordinal)
2708
+ self.is_uva = runtime.core.cuda_device_is_uva(ordinal) > 0
2709
+ self.is_mempool_supported = runtime.core.cuda_device_is_mempool_supported(ordinal) > 0
2710
+ if platform.system() == "Linux":
2711
+ # Use None when IPC support cannot be determined
2712
+ ipc_support_api_query = runtime.core.cuda_device_is_ipc_supported(ordinal)
2713
+ self.is_ipc_supported = bool(ipc_support_api_query) if ipc_support_api_query >= 0 else None
2714
+ else:
2715
+ self.is_ipc_supported = False
2607
2716
  if warp.config.enable_mempools_at_init:
2608
2717
  # enable if supported
2609
2718
  self.is_mempool_enabled = self.is_mempool_supported
@@ -3084,6 +3193,9 @@ class Runtime:
3084
3193
  self.core.radix_sort_pairs_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3085
3194
  self.core.radix_sort_pairs_int_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3086
3195
 
3196
+ self.core.radix_sort_pairs_float_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3197
+ self.core.radix_sort_pairs_float_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3198
+
3087
3199
  self.core.runlength_encode_int_host.argtypes = [
3088
3200
  ctypes.c_uint64,
3089
3201
  ctypes.c_uint64,
@@ -3100,10 +3212,16 @@ class Runtime:
3100
3212
  ]
3101
3213
 
3102
3214
  self.core.bvh_create_host.restype = ctypes.c_uint64
3103
- self.core.bvh_create_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
3215
+ self.core.bvh_create_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_int]
3104
3216
 
3105
3217
  self.core.bvh_create_device.restype = ctypes.c_uint64
3106
- self.core.bvh_create_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
3218
+ self.core.bvh_create_device.argtypes = [
3219
+ ctypes.c_void_p,
3220
+ ctypes.c_void_p,
3221
+ ctypes.c_void_p,
3222
+ ctypes.c_int,
3223
+ ctypes.c_int,
3224
+ ]
3107
3225
 
3108
3226
  self.core.bvh_destroy_host.argtypes = [ctypes.c_uint64]
3109
3227
  self.core.bvh_destroy_device.argtypes = [ctypes.c_uint64]
@@ -3119,6 +3237,7 @@ class Runtime:
3119
3237
  ctypes.c_int,
3120
3238
  ctypes.c_int,
3121
3239
  ctypes.c_int,
3240
+ ctypes.c_int,
3122
3241
  ]
3123
3242
 
3124
3243
  self.core.mesh_create_device.restype = ctypes.c_uint64
@@ -3130,6 +3249,7 @@ class Runtime:
3130
3249
  ctypes.c_int,
3131
3250
  ctypes.c_int,
3132
3251
  ctypes.c_int,
3252
+ ctypes.c_int,
3133
3253
  ]
3134
3254
 
3135
3255
  self.core.mesh_destroy_host.argtypes = [ctypes.c_uint64]
@@ -3367,6 +3487,8 @@ class Runtime:
3367
3487
  self.core.cuda_device_is_uva.restype = ctypes.c_int
3368
3488
  self.core.cuda_device_is_mempool_supported.argtypes = [ctypes.c_int]
3369
3489
  self.core.cuda_device_is_mempool_supported.restype = ctypes.c_int
3490
+ self.core.cuda_device_is_ipc_supported.argtypes = [ctypes.c_int]
3491
+ self.core.cuda_device_is_ipc_supported.restype = ctypes.c_int
3370
3492
  self.core.cuda_device_set_mempool_release_threshold.argtypes = [ctypes.c_int, ctypes.c_uint64]
3371
3493
  self.core.cuda_device_set_mempool_release_threshold.restype = ctypes.c_int
3372
3494
  self.core.cuda_device_get_mempool_release_threshold.argtypes = [ctypes.c_int]
@@ -3420,6 +3542,22 @@ class Runtime:
3420
3542
  self.core.cuda_set_mempool_access_enabled.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
3421
3543
  self.core.cuda_set_mempool_access_enabled.restype = ctypes.c_int
3422
3544
 
3545
+ # inter-process communication
3546
+ self.core.cuda_ipc_get_mem_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
3547
+ self.core.cuda_ipc_get_mem_handle.restype = None
3548
+ self.core.cuda_ipc_open_mem_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
3549
+ self.core.cuda_ipc_open_mem_handle.restype = ctypes.c_void_p
3550
+ self.core.cuda_ipc_close_mem_handle.argtypes = [ctypes.c_void_p]
3551
+ self.core.cuda_ipc_close_mem_handle.restype = None
3552
+ self.core.cuda_ipc_get_event_handle.argtypes = [
3553
+ ctypes.c_void_p,
3554
+ ctypes.c_void_p,
3555
+ ctypes.POINTER(ctypes.c_char),
3556
+ ]
3557
+ self.core.cuda_ipc_get_event_handle.restype = None
3558
+ self.core.cuda_ipc_open_event_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
3559
+ self.core.cuda_ipc_open_event_handle.restype = ctypes.c_void_p
3560
+
3423
3561
  self.core.cuda_stream_create.argtypes = [ctypes.c_void_p, ctypes.c_int]
3424
3562
  self.core.cuda_stream_create.restype = ctypes.c_void_p
3425
3563
  self.core.cuda_stream_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
@@ -3467,6 +3605,7 @@ class Runtime:
3467
3605
 
3468
3606
  self.core.cuda_compile_program.argtypes = [
3469
3607
  ctypes.c_char_p, # cuda_src
3608
+ ctypes.c_char_p, # program name
3470
3609
  ctypes.c_int, # arch
3471
3610
  ctypes.c_char_p, # include_dir
3472
3611
  ctypes.c_int, # num_cuda_include_dirs
@@ -3475,10 +3614,13 @@ class Runtime:
3475
3614
  ctypes.c_bool, # verbose
3476
3615
  ctypes.c_bool, # verify_fp
3477
3616
  ctypes.c_bool, # fast_math
3617
+ ctypes.c_bool, # fuse_fp
3618
+ ctypes.c_bool, # lineinfo
3478
3619
  ctypes.c_char_p, # output_path
3479
3620
  ctypes.c_size_t, # num_ltoirs
3480
3621
  ctypes.POINTER(ctypes.c_char_p), # ltoirs
3481
3622
  ctypes.POINTER(ctypes.c_size_t), # ltoir_sizes
3623
+ ctypes.POINTER(ctypes.c_int), # ltoir_input_types, each of type nvJitLinkInputType
3482
3624
  ]
3483
3625
  self.core.cuda_compile_program.restype = ctypes.c_size_t
3484
3626
 
@@ -3518,6 +3660,22 @@ class Runtime:
3518
3660
  ]
3519
3661
  self.core.cuda_compile_dot.restype = ctypes.c_bool
3520
3662
 
3663
+ self.core.cuda_compile_solver.argtypes = [
3664
+ ctypes.c_char_p, # universal fatbin
3665
+ ctypes.c_char_p, # lto
3666
+ ctypes.c_char_p, # function name
3667
+ ctypes.c_int, # num include dirs
3668
+ ctypes.POINTER(ctypes.c_char_p), # include dirs
3669
+ ctypes.c_char_p, # mathdx include dir
3670
+ ctypes.c_int, # arch
3671
+ ctypes.c_int, # M
3672
+ ctypes.c_int, # N
3673
+ ctypes.c_int, # precision
3674
+ ctypes.c_int, # fill_mode
3675
+ ctypes.c_int, # num threads
3676
+ ]
3677
+ self.core.cuda_compile_fft.restype = ctypes.c_bool
3678
+
3521
3679
  self.core.cuda_load_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
3522
3680
  self.core.cuda_load_module.restype = ctypes.c_void_p
3523
3681
 
@@ -4868,6 +5026,40 @@ def from_numpy(
4868
5026
  )
4869
5027
 
4870
5028
 
5029
+ def event_from_ipc_handle(handle, device: "Devicelike" = None) -> Event:
5030
+ """Create an event from an IPC handle.
5031
+
5032
+ Args:
5033
+ handle: The interprocess event handle for an existing CUDA event.
5034
+ device (Devicelike): Device to associate with the array.
5035
+
5036
+ Returns:
5037
+ An event created from the interprocess event handle ``handle``.
5038
+
5039
+ Raises:
5040
+ RuntimeError: IPC is not supported on ``device``.
5041
+ """
5042
+
5043
+ try:
5044
+ # Performance note: try first, ask questions later
5045
+ device = warp.context.runtime.get_device(device)
5046
+ except Exception:
5047
+ # Fallback to using the public API for retrieving the device,
5048
+ # which takes take of initializing Warp if needed.
5049
+ device = warp.context.get_device(device)
5050
+
5051
+ if device.is_ipc_supported is False:
5052
+ raise RuntimeError(f"IPC is not supported on device {device}.")
5053
+
5054
+ event = Event(
5055
+ device=device, cuda_event=warp.context.runtime.core.cuda_ipc_open_event_handle(device.context, handle)
5056
+ )
5057
+ # Events created from IPC handles must be freed with cuEventDestroy
5058
+ event.owner = True
5059
+
5060
+ return event
5061
+
5062
+
4871
5063
  # given a kernel destination argument type and a value convert
4872
5064
  # to a c-type that can be passed to a kernel
4873
5065
  def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
@@ -4949,6 +5141,9 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
4949
5141
 
4950
5142
  # try to convert to a value type (vec3, mat33, etc)
4951
5143
  elif issubclass(arg_type, ctypes.Array):
5144
+ # simple value types don't have gradient arrays, but native built-in signatures still expect a non-null adjoint value of the correct type
5145
+ if value is None and adjoint:
5146
+ return arg_type(0)
4952
5147
  if warp.types.types_equal(type(value), arg_type):
4953
5148
  return value
4954
5149
  else:
@@ -4958,9 +5153,6 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
4958
5153
  except Exception as e:
4959
5154
  raise ValueError(f"Failed to convert argument for param {arg_name} to {type_str(arg_type)}") from e
4960
5155
 
4961
- elif isinstance(value, bool):
4962
- return ctypes.c_bool(value)
4963
-
4964
5156
  elif isinstance(value, arg_type):
4965
5157
  try:
4966
5158
  # try to pack as a scalar type
@@ -4975,6 +5167,9 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
4975
5167
  ) from e
4976
5168
 
4977
5169
  else:
5170
+ # scalar args don't have gradient arrays, but native built-in signatures still expect a non-null scalar adjoint
5171
+ if value is None and adjoint:
5172
+ return arg_type._type_(0)
4978
5173
  try:
4979
5174
  # try to pack as a scalar type
4980
5175
  if arg_type is warp.types.float16:
@@ -6034,14 +6229,19 @@ def export_functions_rst(file): # pragma: no cover
6034
6229
  # build dictionary of all functions by group
6035
6230
  groups = {}
6036
6231
 
6037
- for _k, f in builtin_functions.items():
6232
+ functions = list(builtin_functions.values())
6233
+
6234
+ for f in functions:
6038
6235
  # build dict of groups
6039
6236
  if f.group not in groups:
6040
6237
  groups[f.group] = []
6041
6238
 
6042
- # append all overloads to the group
6043
- for o in f.overloads:
6044
- groups[f.group].append(o)
6239
+ if hasattr(f, "overloads"):
6240
+ # append all overloads to the group
6241
+ for o in f.overloads:
6242
+ groups[f.group].append(o)
6243
+ else:
6244
+ groups[f.group].append(f)
6045
6245
 
6046
6246
  # Keep track of what function and query types have been written
6047
6247
  written_functions = set()
@@ -6061,6 +6261,10 @@ def export_functions_rst(file): # pragma: no cover
6061
6261
  print("---------------", file=file)
6062
6262
 
6063
6263
  for f in g:
6264
+ if f.func:
6265
+ # f is a Warp function written in Python, we can use autofunction
6266
+ print(f".. autofunction:: {f.func.__module__}.{f.key}", file=file)
6267
+ continue
6064
6268
  for f_prefix, query_type in query_types:
6065
6269
  if f.key.startswith(f_prefix) and query_type not in written_query_types:
6066
6270
  print(f".. autoclass:: {query_type}", file=file)
@@ -6118,24 +6322,32 @@ def export_stubs(file): # pragma: no cover
6118
6322
  print(header, file=file)
6119
6323
  print(file=file)
6120
6324
 
6121
- for k, g in builtin_functions.items():
6122
- for f in g.overloads:
6123
- args = ", ".join(f"{k}: {type_str(v)}" for k, v in f.input_types.items())
6325
+ def add_stub(f):
6326
+ args = ", ".join(f"{k}: {type_str(v)}" for k, v in f.input_types.items())
6124
6327
 
6125
- return_str = ""
6328
+ return_str = ""
6126
6329
 
6127
- if f.hidden: # or f.generic:
6128
- continue
6330
+ if f.hidden: # or f.generic:
6331
+ return
6129
6332
 
6333
+ return_type = f.value_type
6334
+ if f.value_func:
6130
6335
  return_type = f.value_func(None, None)
6131
- if return_type:
6132
- return_str = " -> " + type_str(return_type)
6133
-
6134
- print("@over", file=file)
6135
- print(f"def {f.key}({args}){return_str}:", file=file)
6136
- print(f' """{f.doc}', file=file)
6137
- print(' """', file=file)
6138
- print(" ...\n\n", file=file)
6336
+ if return_type:
6337
+ return_str = " -> " + type_str(return_type)
6338
+
6339
+ print("@over", file=file)
6340
+ print(f"def {f.key}({args}){return_str}:", file=file)
6341
+ print(f' """{f.doc}', file=file)
6342
+ print(' """', file=file)
6343
+ print(" ...\n\n", file=file)
6344
+
6345
+ for g in builtin_functions.values():
6346
+ if hasattr(g, "overloads"):
6347
+ for f in g.overloads:
6348
+ add_stub(f)
6349
+ else:
6350
+ add_stub(g)
6139
6351
 
6140
6352
 
6141
6353
  def export_builtins(file: io.TextIOBase): # pragma: no cover
@@ -6161,6 +6373,8 @@ def export_builtins(file: io.TextIOBase): # pragma: no cover
6161
6373
  file.write('extern "C" {\n\n')
6162
6374
 
6163
6375
  for k, g in builtin_functions.items():
6376
+ if not hasattr(g, "overloads"):
6377
+ continue
6164
6378
  for f in g.overloads:
6165
6379
  if not f.export or f.generic:
6166
6380
  continue