warp-lang 1.5.0-py3-none-win_amd64.whl → 1.6.0-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (132)
  1. warp/__init__.py +5 -0
  2. warp/autograd.py +414 -191
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +40 -12
  6. warp/build_dll.py +13 -6
  7. warp/builtins.py +1124 -497
  8. warp/codegen.py +261 -136
  9. warp/config.py +1 -1
  10. warp/context.py +357 -119
  11. warp/examples/assets/square_cloth.usd +0 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +27 -18
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
  14. warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
  15. warp/examples/core/example_torch.py +18 -34
  16. warp/examples/fem/example_apic_fluid.py +1 -0
  17. warp/examples/fem/example_mixed_elasticity.py +1 -1
  18. warp/examples/optim/example_bounce.py +1 -1
  19. warp/examples/optim/example_cloth_throw.py +1 -1
  20. warp/examples/optim/example_diffray.py +4 -15
  21. warp/examples/optim/example_drone.py +1 -1
  22. warp/examples/optim/example_softbody_properties.py +392 -0
  23. warp/examples/optim/example_trajectory.py +1 -3
  24. warp/examples/optim/example_walker.py +5 -0
  25. warp/examples/sim/example_cartpole.py +0 -2
  26. warp/examples/sim/example_cloth.py +3 -1
  27. warp/examples/sim/example_cloth_self_contact.py +260 -0
  28. warp/examples/sim/example_granular_collision_sdf.py +4 -5
  29. warp/examples/sim/example_jacobian_ik.py +0 -2
  30. warp/examples/sim/example_quadruped.py +5 -2
  31. warp/examples/tile/example_tile_cholesky.py +79 -0
  32. warp/examples/tile/example_tile_convolution.py +2 -2
  33. warp/examples/tile/example_tile_fft.py +2 -2
  34. warp/examples/tile/example_tile_filtering.py +3 -3
  35. warp/examples/tile/example_tile_matmul.py +4 -4
  36. warp/examples/tile/example_tile_mlp.py +12 -12
  37. warp/examples/tile/example_tile_nbody.py +180 -0
  38. warp/examples/tile/example_tile_walker.py +319 -0
  39. warp/fem/geometry/geometry.py +0 -2
  40. warp/math.py +147 -0
  41. warp/native/array.h +12 -0
  42. warp/native/builtin.h +0 -1
  43. warp/native/bvh.cpp +149 -70
  44. warp/native/bvh.cu +287 -68
  45. warp/native/bvh.h +195 -85
  46. warp/native/clang/clang.cpp +5 -1
  47. warp/native/coloring.cpp +5 -1
  48. warp/native/cuda_util.cpp +91 -53
  49. warp/native/cuda_util.h +5 -0
  50. warp/native/exports.h +40 -40
  51. warp/native/intersect.h +17 -0
  52. warp/native/mat.h +41 -0
  53. warp/native/mathdx.cpp +19 -0
  54. warp/native/mesh.cpp +25 -8
  55. warp/native/mesh.cu +153 -101
  56. warp/native/mesh.h +482 -403
  57. warp/native/quat.h +40 -0
  58. warp/native/solid_angle.h +7 -0
  59. warp/native/sort.cpp +85 -0
  60. warp/native/sort.cu +34 -0
  61. warp/native/sort.h +3 -1
  62. warp/native/spatial.h +11 -0
  63. warp/native/tile.h +1187 -669
  64. warp/native/tile_reduce.h +8 -6
  65. warp/native/vec.h +41 -0
  66. warp/native/warp.cpp +8 -1
  67. warp/native/warp.cu +263 -40
  68. warp/native/warp.h +19 -5
  69. warp/optim/linear.py +22 -4
  70. warp/render/render_opengl.py +130 -64
  71. warp/sim/__init__.py +6 -1
  72. warp/sim/collide.py +270 -26
  73. warp/sim/import_urdf.py +8 -8
  74. warp/sim/integrator_euler.py +25 -7
  75. warp/sim/integrator_featherstone.py +154 -35
  76. warp/sim/integrator_vbd.py +842 -40
  77. warp/sim/model.py +134 -72
  78. warp/sparse.py +1 -1
  79. warp/stubs.py +265 -132
  80. warp/tape.py +28 -30
  81. warp/tests/aux_test_module_unload.py +15 -0
  82. warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
  83. warp/tests/test_array.py +74 -0
  84. warp/tests/test_assert.py +242 -0
  85. warp/tests/test_codegen.py +14 -61
  86. warp/tests/test_collision.py +2 -2
  87. warp/tests/test_coloring.py +12 -2
  88. warp/tests/test_examples.py +12 -1
  89. warp/tests/test_func.py +21 -4
  90. warp/tests/test_grad_debug.py +87 -2
  91. warp/tests/test_hash_grid.py +1 -1
  92. warp/tests/test_ipc.py +116 -0
  93. warp/tests/test_lerp.py +13 -87
  94. warp/tests/test_mat.py +138 -167
  95. warp/tests/test_math.py +47 -1
  96. warp/tests/test_matmul.py +17 -16
  97. warp/tests/test_matmul_lite.py +10 -15
  98. warp/tests/test_mesh.py +84 -60
  99. warp/tests/test_mesh_query_aabb.py +165 -0
  100. warp/tests/test_mesh_query_point.py +328 -286
  101. warp/tests/test_mesh_query_ray.py +134 -121
  102. warp/tests/test_mlp.py +2 -2
  103. warp/tests/test_operators.py +43 -0
  104. warp/tests/test_overwrite.py +47 -2
  105. warp/tests/test_quat.py +77 -0
  106. warp/tests/test_reload.py +29 -0
  107. warp/tests/test_sim_grad_bounce_linear.py +204 -0
  108. warp/tests/test_smoothstep.py +17 -83
  109. warp/tests/test_static.py +19 -3
  110. warp/tests/test_tape.py +25 -0
  111. warp/tests/test_tile.py +178 -191
  112. warp/tests/test_tile_load.py +356 -0
  113. warp/tests/test_tile_mathdx.py +61 -8
  114. warp/tests/test_tile_mlp.py +17 -17
  115. warp/tests/test_tile_reduce.py +24 -18
  116. warp/tests/test_tile_shared_memory.py +66 -17
  117. warp/tests/test_tile_view.py +165 -0
  118. warp/tests/test_torch.py +35 -0
  119. warp/tests/test_utils.py +36 -24
  120. warp/tests/test_vec.py +110 -0
  121. warp/tests/unittest_suites.py +29 -4
  122. warp/tests/unittest_utils.py +30 -13
  123. warp/thirdparty/unittest_parallel.py +2 -2
  124. warp/types.py +411 -101
  125. warp/utils.py +10 -7
  126. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/METADATA +92 -69
  127. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/RECORD +130 -119
  128. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/WHEEL +1 -1
  129. warp/examples/benchmarks/benchmark_tile.py +0 -179
  130. warp/native/tile_gemm.h +0 -341
  131. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/LICENSE.md +0 -0
  132. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -5,8 +5,11 @@
  # distribution of this software and related documentation without an express
  # license agreement from NVIDIA CORPORATION is strictly prohibited.

+ from __future__ import annotations
+
  import ast
  import ctypes
+ import errno
  import functools
  import hashlib
  import inspect
@@ -17,6 +20,7 @@ import operator
  import os
  import platform
  import sys
+ import time
  import types
  import typing
  import weakref
@@ -238,24 +242,23 @@ class Function:
  # in a way that is compatible with Python's semantics.
  signature_params = []
  signature_default_param_kind = inspect.Parameter.POSITIONAL_OR_KEYWORD
- for param_name in self.input_types.keys():
- if param_name.startswith("**"):
- param_name = param_name[2:]
+ for raw_param_name in self.input_types.keys():
+ if raw_param_name.startswith("**"):
+ param_name = raw_param_name[2:]
  param_kind = inspect.Parameter.VAR_KEYWORD
- elif param_name.startswith("*"):
- param_name = param_name[1:]
+ elif raw_param_name.startswith("*"):
+ param_name = raw_param_name[1:]
  param_kind = inspect.Parameter.VAR_POSITIONAL

  # Once a variadic argument like `*args` is found, any following
  # arguments need to be passed using keywords.
  signature_default_param_kind = inspect.Parameter.KEYWORD_ONLY
  else:
+ param_name = raw_param_name
  param_kind = signature_default_param_kind

- param = param = inspect.Parameter(
- param_name,
- param_kind,
- default=self.defaults.get(param_name, inspect.Parameter.empty),
+ param = inspect.Parameter(
+ param_name, param_kind, default=self.defaults.get(param_name, inspect.Parameter.empty)
  )
  signature_params.append(param)
  self.signature = inspect.Signature(signature_params)
@@ -294,22 +297,22 @@ class Function:

  if hasattr(self, "user_overloads") and len(self.user_overloads):
  # user-defined function with overloads
+ bound_args = self.signature.bind(*args, **kwargs)
+ if self.defaults:
+ warp.codegen.apply_defaults(bound_args, self.defaults)

- if len(kwargs):
- raise RuntimeError(
- f"Error calling function '{self.key}', keyword arguments are not supported for user-defined overloads."
- )
+ arguments = tuple(bound_args.arguments.values())

  # try and find a matching overload
  for overload in self.user_overloads.values():
- if len(overload.input_types) != len(args):
+ if len(overload.input_types) != len(arguments):
  continue
  template_types = list(overload.input_types.values())
  arg_names = list(overload.input_types.keys())
  try:
  # attempt to unify argument types with function template types
- warp.types.infer_argument_types(args, template_types, arg_names)
- return overload.func(*args)
+ warp.types.infer_argument_types(arguments, template_types, arg_names)
+ return overload.func(*arguments)
  except Exception:
  continue

@@ -392,7 +395,8 @@ class Function:
  if not warp.codegen.func_match_args(f, arg_types, kwarg_types):
  continue

- if len(f.input_types) != len(arg_types):
+ acceptable_arg_num = len(f.input_types) - len(f.defaults) <= len(arg_types) <= len(f.input_types)
+ if not acceptable_arg_num:
  continue

  # try to match the given types to the function template types
@@ -409,6 +413,10 @@ class Function:

  arg_names = f.input_types.keys()
  overload_annotations = dict(zip(arg_names, arg_types))
+ # add defaults
+ for k, d in f.defaults.items():
+ if k not in overload_annotations:
+ overload_annotations[k] = warp.codegen.strip_reference(warp.codegen.get_arg_type(d))

  ovl = shallowcopy(f)
  ovl.adj = warp.codegen.Adjoint(f.func, overload_annotations)
@@ -509,11 +517,10 @@ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
  if elem_count != arg_type._length_:
  return (False, None)

- # Retrieve the element type of the sequence while ensuring
- # that it's homogeneous.
+ # Retrieve the element type of the sequence while ensuring that it's homogeneous.
  elem_type = type(arr[0])
- for i in range(1, elem_count):
- if type(arr[i]) is not elem_type:
+ for array_index in range(1, elem_count):
+ if type(arr[array_index]) is not elem_type:
  raise ValueError("All array elements must share the same type.")

  expected_elem_type = arg_type._wp_scalar_type_
@@ -543,10 +550,10 @@ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
  c_param = arg_type()
  if warp.types.type_is_matrix(arg_type):
  rows, cols = arg_type._shape_
- for i in range(rows):
- idx_start = i * cols
+ for row_index in range(rows):
+ idx_start = row_index * cols
  idx_end = idx_start + cols
- c_param[i] = arr[idx_start:idx_end]
+ c_param[row_index] = arr[idx_start:idx_end]
  else:
  c_param[:] = arr

@@ -753,8 +760,15 @@ def func(f):
  scope_locals = inspect.currentframe().f_back.f_locals

  m = get_module(f.__module__)
+ doc = getattr(f, "__doc__", "") or ""
  Function(
- func=f, key=name, namespace="", module=m, value_func=None, scope_locals=scope_locals
+ func=f,
+ key=name,
+ namespace="",
+ module=m,
+ value_func=None,
+ scope_locals=scope_locals,
+ doc=doc.strip(),
  ) # value_type not known yet, will be inferred during Adjoint.build()

  # use the top of the list of overloads for this key
@@ -1059,7 +1073,8 @@ def overload(kernel, arg_types=Union[None, Dict[str, Any], List[Any]]):
  raise RuntimeError("wp.overload() called with invalid argument!")


- builtin_functions = {}
+ # native functions that are part of the Warp API
+ builtin_functions: Dict[str, Function] = {}


  def get_generic_vtypes():
@@ -1239,16 +1254,16 @@ def add_builtin(
  typelists.append(l)

  for arg_types in itertools.product(*typelists):
- arg_types = dict(zip(input_types.keys(), arg_types))
+ concrete_arg_types = dict(zip(input_types.keys(), arg_types))

  # Some of these argument lists won't work, eg if the function is mul(), we won't be
  # able to do a matrix vector multiplication for a mat22 and a vec3. The `constraint`
  # function determines which combinations are valid:
  if constraint:
- if constraint(arg_types) is False:
+ if constraint(concrete_arg_types) is False:
  continue

- return_type = value_func(arg_types, None)
+ return_type = value_func(concrete_arg_types, None)

  # The return_type might just be vector_t(length=3,dtype=wp.float32), so we've got to match that
  # in the list of hard coded types so it knows it's returning one of them:
@@ -1266,7 +1281,7 @@ def add_builtin(
  # finally we can generate a function call for these concrete types:
  add_builtin(
  key,
- input_types=arg_types,
+ input_types=concrete_arg_types,
  value_type=return_type,
  value_func=value_func if return_type is Any else None,
  export_func=export_func,
@@ -1328,6 +1343,28 @@ def add_builtin(
  setattr(warp, key, func)


+ def register_api_function(
+ function: Function,
+ group: str = "Other",
+ hidden=False,
+ ):
+ """Main entry point to register a Warp Python function to be part of the Warp API and appear in the documentation.
+
+ Args:
+ function (Function): Warp function to be registered.
+ group (str): Classification used for the documentation.
+ input_types (Mapping[str, Any]): Signature of the user-facing function.
+ Variadic arguments are supported by prefixing the parameter names
+ with asterisks as in `*args` and `**kwargs`. Generic arguments are
+ supported with types such as `Any`, `Float`, `Scalar`, etc.
+ value_type (Any): Type returned by the function.
+ hidden (bool): Whether to add that function into the documentation.
+ """
+ function.group = group
+ function.hidden = hidden
+ builtin_functions[function.key] = function
+
+
  # global dictionary of modules
  user_modules = {}

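A minimal sketch of how this registration hook can surface a Python-level Warp function in the generated API reference; the helper below and its doc group are hypothetical, not part of the package::

    import warp as wp
    import warp.context

    # Hypothetical helper written as a regular Warp function.
    @wp.func
    def length_sq3(v: wp.vec3) -> float:
        """Squared length of a 3D vector."""
        return wp.dot(v, v)

    # Tag it with a documentation group and add it to the registry consumed by
    # export_functions_rst() and export_stubs() further down in this file.
    warp.context.register_api_function(length_sq3, group="Vector Math")
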
@@ -1561,6 +1598,7 @@ class ModuleBuilder:
  self.options = options
  self.module = module
  self.deferred_functions = []
+ self.fatbins = {} # map from <some identifier> to fatbins, to add at link time
  self.ltoirs = {} # map from lto symbol to lto binary
  self.ltoirs_decl = {} # map from lto symbol to lto forward declaration

@@ -1675,7 +1713,7 @@ class ModuleBuilder:

  for kernel in self.kernels:
  source += warp.codegen.codegen_kernel(kernel, device=device, options=self.options)
- source += warp.codegen.codegen_module(kernel, device=device)
+ source += warp.codegen.codegen_module(kernel, device=device, options=self.options)

  # add headers
  if device == "cpu":
@@ -1728,20 +1766,26 @@ class ModuleExec:

  name = kernel.get_mangled_name()

+ options = dict(kernel.module.options)
+ options.update(kernel.options)
+
  if self.device.is_cuda:
  forward_name = name + "_cuda_kernel_forward"
  forward_kernel = runtime.core.cuda_get_kernel(
  self.device.context, self.handle, forward_name.encode("utf-8")
  )

- backward_name = name + "_cuda_kernel_backward"
- backward_kernel = runtime.core.cuda_get_kernel(
- self.device.context, self.handle, backward_name.encode("utf-8")
- )
+ if options["enable_backward"]:
+ backward_name = name + "_cuda_kernel_backward"
+ backward_kernel = runtime.core.cuda_get_kernel(
+ self.device.context, self.handle, backward_name.encode("utf-8")
+ )
+ else:
+ backward_kernel = None

  # look up the required shared memory size for each kernel from module metadata
  forward_smem_bytes = self.meta[forward_name + "_smem_bytes"]
- backward_smem_bytes = self.meta[backward_name + "_smem_bytes"]
+ backward_smem_bytes = self.meta[backward_name + "_smem_bytes"] if options["enable_backward"] else 0

  # configure kernels maximum shared memory size
  max_smem_bytes = runtime.core.cuda_get_max_shared_memory(self.device.context)
@@ -1751,9 +1795,6 @@ class ModuleExec:
  f"Warning: Failed to configure kernel dynamic shared memory for this device, tried to configure {forward_name} kernel for {forward_smem_bytes} bytes, but maximum available is {max_smem_bytes}"
  )

- options = dict(kernel.module.options)
- options.update(kernel.options)
-
  if options["enable_backward"] and not runtime.core.cuda_configure_kernel_shared_memory(
  backward_kernel, backward_smem_bytes
  ):
@@ -1768,9 +1809,14 @@ class ModuleExec:
  forward = (
  func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_forward").encode("utf-8"))) or None
  )
- backward = (
- func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8"))) or None
- )
+
+ if options["enable_backward"]:
+ backward = (
+ func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8")))
+ or None
+ )
+ else:
+ backward = None

  hooks = KernelHooks(forward, backward)

@@ -1803,13 +1849,13 @@ class Module:
  self._live_kernels = weakref.WeakSet()

  # executable modules currently loaded
- self.execs = {} # (device.context: ModuleExec)
+ self.execs = {} # ((device.context, blockdim): ModuleExec)

  # set of device contexts where the build has failed
  self.failed_builds = set()

- # hash data, including the module hash
- self.hasher = None
+ # hash data, including the module hash. Module may store multiple hashes (one per block_dim used)
+ self.hashers = {}

  # LLVM executable modules are identified using strings. Since it's possible for multiple
  # executable versions to be loaded at the same time, we need a way to ensure uniqueness.
@@ -1822,6 +1868,8 @@ class Module:
  "max_unroll": warp.config.max_unroll,
  "enable_backward": warp.config.enable_backward,
  "fast_math": False,
+ "fuse_fp": True,
+ "lineinfo": False,
  "cuda_output": None, # supported values: "ptx", "cubin", or None (automatic)
  "mode": warp.config.mode,
  "block_dim": 256,
@@ -1965,28 +2013,27 @@ class Module:

  def hash_module(self):
  # compute latest hash
- self.hasher = ModuleHasher(self)
- return self.hasher.get_module_hash()
+ block_dim = self.options["block_dim"]
+ self.hashers[block_dim] = ModuleHasher(self)
+ return self.hashers[block_dim].get_module_hash()

  def load(self, device, block_dim=None) -> ModuleExec:
  device = runtime.get_device(device)

- # re-compile module if tile size (blockdim) changes
- # todo: it would be better to have a method such as `module.get_kernel(block_dim=N)`
- # that can return a single kernel instance with a given block size
+ # update module options if launching with a new block dim
  if block_dim is not None:
- if self.options["block_dim"] != block_dim:
- self.unload()
  self.options["block_dim"] = block_dim

+ active_block_dim = self.options["block_dim"]
+
  # compute the hash if needed
- if self.hasher is None:
- self.hasher = ModuleHasher(self)
+ if active_block_dim not in self.hashers:
+ self.hashers[active_block_dim] = ModuleHasher(self)

  # check if executable module is already loaded and not stale
- exec = self.execs.get(device.context)
+ exec = self.execs.get((device.context, active_block_dim))
  if exec is not None:
- if exec.module_hash == self.hasher.module_hash:
+ if exec.module_hash == self.hashers[active_block_dim].get_module_hash():
  return exec

  # quietly avoid repeated build attempts to reduce error spew
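Because executable modules are now cached per (device context, block_dim), switching block sizes at launch no longer unloads and rebuilds the module; a rough sketch, assuming the block_dim launch argument that this diff threads through wp.launch()::

    import warp as wp

    @wp.kernel
    def scale(a: wp.array(dtype=float), s: float):
        i = wp.tid()
        a[i] = a[i] * s

    a = wp.zeros(1 << 20, dtype=float)

    # Both block-size variants stay loaded side by side under different cache keys.
    wp.launch(scale, dim=a.shape[0], inputs=[a, 2.0], block_dim=128)
    wp.launch(scale, dim=a.shape[0], inputs=[a, 0.5], block_dim=256)
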
@@ -1994,10 +2041,11 @@ class Module:
  return None

  module_name = "wp_" + self.name
- module_hash = self.hasher.module_hash
+ module_hash = self.hashers[active_block_dim].get_module_hash()

  # use a unique module path using the module short hash
- module_dir = os.path.join(warp.config.kernel_cache_dir, f"{module_name}_{module_hash.hex()[:7]}")
+ module_name_short = f"{module_name}_{module_hash.hex()[:7]}"
+ module_dir = os.path.join(warp.config.kernel_cache_dir, module_name_short)

  with warp.ScopedTimer(
  f"Module {self.name} {module_hash.hex()[:7]} load on device '{device}'", active=not warp.config.quiet
@@ -2005,7 +2053,7 @@ class Module:
  # -----------------------------------------------------------
  # determine output paths
  if device.is_cpu:
- output_name = "module_codegen.o"
+ output_name = f"{module_name_short}.o"
  output_arch = None

  elif device.is_cuda:
@@ -2025,10 +2073,10 @@

  if use_ptx:
  output_arch = min(device.arch, warp.config.ptx_target_arch)
- output_name = f"module_codegen.sm{output_arch}.ptx"
+ output_name = f"{module_name_short}.sm{output_arch}.ptx"
  else:
  output_arch = device.arch
- output_name = f"module_codegen.sm{output_arch}.cubin"
+ output_name = f"{module_name_short}.sm{output_arch}.cubin"

  # final object binary path
  binary_path = os.path.join(module_dir, output_name)
@@ -2050,7 +2098,7 @@
  # Some of the Tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
  "output_arch": output_arch,
  }
- builder = ModuleBuilder(self, builder_options, hasher=self.hasher)
+ builder = ModuleBuilder(self, builder_options, hasher=self.hashers[active_block_dim])

  # create a temporary (process unique) dir for build outputs before moving to the binary dir
  build_dir = os.path.join(
@@ -2066,7 +2114,7 @@
  if device.is_cpu:
  # build
  try:
- source_code_path = os.path.join(build_dir, "module_codegen.cpp")
+ source_code_path = os.path.join(build_dir, f"{module_name_short}.cpp")

  # write cpp sources
  cpp_source = builder.codegen("cpu")
@@ -2084,6 +2132,7 @@
  mode=self.options["mode"],
  fast_math=self.options["fast_math"],
  verify_fp=warp.config.verify_fp,
+ fuse_fp=self.options["fuse_fp"],
  )

  except Exception as e:
@@ -2094,7 +2143,7 @@
  elif device.is_cuda:
  # build
  try:
- source_code_path = os.path.join(build_dir, "module_codegen.cu")
+ source_code_path = os.path.join(build_dir, f"{module_name_short}.cu")

  # write cuda sources
  cu_source = builder.codegen("cuda")
@@ -2111,9 +2160,12 @@
  output_arch,
  output_path,
  config=self.options["mode"],
- fast_math=self.options["fast_math"],
  verify_fp=warp.config.verify_fp,
+ fast_math=self.options["fast_math"],
+ fuse_fp=self.options["fuse_fp"],
+ lineinfo=self.options["lineinfo"],
  ltoirs=builder.ltoirs.values(),
+ fatbins=builder.fatbins.values(),
  )

  except Exception as e:
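The new fuse_fp and lineinfo options passed into the builds above default to True and False respectively (see the module options earlier in this diff); a short sketch of toggling them per module, assuming the standard wp.set_module_options() API (the kernel is illustrative)::

    import warp as wp

    # Disable fused multiply-add contraction (for more reproducible math) and
    # emit line info so profilers can map generated code back to source.
    wp.set_module_options({"fuse_fp": False, "lineinfo": True})

    @wp.kernel
    def axpy(x: wp.array(dtype=float), y: wp.array(dtype=float), a: float):
        i = wp.tid()
        y[i] = a * x[i] + y[i]

    x, y = wp.ones(8, dtype=float), wp.zeros(8, dtype=float)
    wp.launch(axpy, dim=8, inputs=[x, y, 2.0])
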
@@ -2125,7 +2177,7 @@
  # build meta data

  meta = builder.build_meta()
- meta_path = os.path.join(build_dir, "module_codegen.meta")
+ meta_path = os.path.join(build_dir, f"{module_name_short}.meta")

  with open(meta_path, "w") as meta_file:
  json.dump(meta, meta_file)
@@ -2133,12 +2185,34 @@
  # -----------------------------------------------------------
  # update cache

- try:
- # Copy process-specific build directory to a process-independent location
- os.rename(build_dir, module_dir)
- except (OSError, FileExistsError):
- # another process likely updated the module dir first
- pass
+ def safe_rename(src, dst, attempts=5, delay=0.1):
+ for i in range(attempts):
+ try:
+ os.rename(src, dst)
+ return
+ except FileExistsError:
+ return
+ except OSError as e:
+ if e.errno == errno.ENOTEMPTY:
+ # if directory exists we assume another process
+ # got there first, in which case we will copy
+ # our output to the directory manually in second step
+ return
+ else:
+ # otherwise assume directory creation failed e.g.: access denied
+ # on Windows we see occasional failures to rename directories due to
+ # some process holding a lock on a file to be moved to workaround
+ # this we make multiple attempts to rename with some delay
+ if i < attempts - 1:
+ time.sleep(delay)
+ else:
+ print(
+ f"Could not update Warp cache with module binaries, trying to rename {build_dir} to {module_dir}, error {e}"
+ )
+ raise e
+
+ # try to move process outputs to cache
+ safe_rename(build_dir, module_dir)

  if os.path.exists(module_dir):
  if not os.path.exists(binary_path):
@@ -2167,7 +2241,7 @@
  # -----------------------------------------------------------
  # Load CPU or CUDA binary

- meta_path = os.path.join(module_dir, "module_codegen.meta")
+ meta_path = os.path.join(module_dir, f"{module_name_short}.meta")
  with open(meta_path, "r") as meta_file:
  meta = json.load(meta_file)

@@ -2177,13 +2251,13 @@
  self.cpu_exec_id += 1
  runtime.llvm.load_obj(binary_path.encode("utf-8"), module_handle.encode("utf-8"))
  module_exec = ModuleExec(module_handle, module_hash, device, meta)
- self.execs[None] = module_exec
+ self.execs[(None, active_block_dim)] = module_exec

  elif device.is_cuda:
  cuda_module = warp.build.load_cuda(binary_path, device)
  if cuda_module is not None:
  module_exec = ModuleExec(cuda_module, module_hash, device, meta)
- self.execs[device.context] = module_exec
+ self.execs[(device.context, active_block_dim)] = module_exec
  else:
  module_load_timer.extra_msg = " (error)"
  raise Exception(f"Failed to load CUDA module '{self.name}'")
@@ -2205,14 +2279,14 @@

  def mark_modified(self):
  # clear hash data
- self.hasher = None
+ self.hashers = {}

  # clear build failures
  self.failed_builds = set()

  # lookup kernel entry points based on name, called after compilation / module load
  def get_kernel_hooks(self, kernel, device):
- module_exec = self.execs.get(device.context)
+ module_exec = self.execs.get((device.context, self.options["block_dim"]))
  if module_exec is not None:
  return module_exec.get_kernel_hooks(kernel)
  else:
@@ -2331,6 +2405,7 @@ class Event:
  DEFAULT = 0x0
  BLOCKING_SYNC = 0x1
  DISABLE_TIMING = 0x2
+ INTERPROCESS = 0x4

  def __new__(cls, *args, **kwargs):
  """Creates a new event instance."""
@@ -2338,7 +2413,9 @@ class Event:
  instance.owner = False
  return instance

- def __init__(self, device: "Devicelike" = None, cuda_event=None, enable_timing: bool = False):
+ def __init__(
+ self, device: "Devicelike" = None, cuda_event=None, enable_timing: bool = False, interprocess: bool = False
+ ):
  """Initializes the event on a CUDA device.

  Args:
@@ -2350,6 +2427,12 @@ class Event:
  :func:`~warp.get_event_elapsed_time` can be used to measure the
  time between two events created with ``enable_timing=True`` and
  recorded onto streams.
+ interprocess: If ``True`` this event may be used as an interprocess event.
+
+ Raises:
+ RuntimeError: The event could not be created.
+ ValueError: The combination of ``enable_timing=True`` and
+ ``interprocess=True`` is not allowed.
  """

  device = get_device(device)
@@ -2364,11 +2447,48 @@ class Event:
  flags = Event.Flags.DEFAULT
  if not enable_timing:
  flags |= Event.Flags.DISABLE_TIMING
+ if interprocess:
+ if enable_timing:
+ raise ValueError("The combination of 'enable_timing=True' and 'interprocess=True' is not allowed.")
+ flags |= Event.Flags.INTERPROCESS
+
  self.cuda_event = runtime.core.cuda_event_create(device.context, flags)
  if not self.cuda_event:
  raise RuntimeError(f"Failed to create event on device {device}")
  self.owner = True

+ def ipc_handle(self) -> bytes:
+ """Return a CUDA IPC handle of the event as a 64-byte ``bytes`` object.
+
+ The event must have been created with ``interprocess=True`` in order to
+ obtain a valid interprocess handle.
+
+ IPC is currently only supported on Linux.
+
+ Example:
+ Create an event and get its IPC handle::
+
+ e1 = wp.Event(interprocess=True)
+ event_handle = e1.ipc_handle()
+
+ Raises:
+ RuntimeError: Device does not support IPC.
+ """
+
+ if self.device.is_ipc_supported is not False:
+ # Allocate a buffer for the data (64-element char array)
+ ipc_handle_buffer = (ctypes.c_char * 64)()
+
+ warp.context.runtime.core.cuda_ipc_get_event_handle(self.device.context, self.cuda_event, ipc_handle_buffer)
+
+ if ipc_handle_buffer.raw == bytes(64):
+ warp.utils.warn("IPC event handle appears to be invalid. Was interprocess=True used?")
+
+ return ipc_handle_buffer.raw
+
+ else:
+ raise RuntimeError(f"Device {self.device} does not support IPC.")
+
  def __del__(self):
  if not self.owner:
  return
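Producer-side sketch of the interprocess event path added above (Linux only; pairs with event_from_ipc_handle() introduced later in this diff)::

    import warp as wp

    wp.init()

    # Create an IPC-capable event, record it on the current stream, and obtain
    # the 64-byte handle that can be passed to another process (e.g. over a pipe).
    event = wp.Event(device="cuda:0", interprocess=True)
    wp.record_event(event)
    handle = event.ipc_handle()
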
@@ -2516,23 +2636,27 @@ class Device:
  """A device to allocate Warp arrays and to launch kernels on.

  Attributes:
- ordinal: A Warp-specific integer label for the device. ``-1`` for CPU devices.
- name: A string label for the device. By default, CPU devices will be named according to the processor name,
+ ordinal (int): A Warp-specific label for the device. ``-1`` for CPU devices.
+ name (str): A label for the device. By default, CPU devices will be named according to the processor name,
  or ``"CPU"`` if the processor name cannot be determined.
- arch: An integer representing the compute capability version number calculated as
- ``10 * major + minor``. ``0`` for CPU devices.
- is_uva: A boolean indicating whether the device supports unified addressing.
+ arch (int): The compute capability version number calculated as ``10 * major + minor``.
+ ``0`` for CPU devices.
+ is_uva (bool): Indicates whether the device supports unified addressing.
  ``False`` for CPU devices.
- is_cubin_supported: A boolean indicating whether Warp's version of NVRTC can directly
+ is_cubin_supported (bool): Indicates whether Warp's version of NVRTC can directly
  generate CUDA binary files (cubin) for this device's architecture. ``False`` for CPU devices.
- is_mempool_supported: A boolean indicating whether the device supports using the
- ``cuMemAllocAsync`` and ``cuMemPool`` family of APIs for stream-ordered memory allocations. ``False`` for
- CPU devices.
- is_primary: A boolean indicating whether this device's CUDA context is also the
- device's primary context.
- uuid: A string representing the UUID of the CUDA device. The UUID is in the same format used by
- ``nvidia-smi -L``. ``None`` for CPU devices.
- pci_bus_id: A string identifier for the CUDA device in the format ``[domain]:[bus]:[device]``, in which
+ is_mempool_supported (bool): Indicates whether the device supports using the ``cuMemAllocAsync`` and
+ ``cuMemPool`` family of APIs for stream-ordered memory allocations. ``False`` for CPU devices.
+ is_ipc_supported (Optional[bool]): Indicates whether the device supports IPC.
+
+ - ``True`` if supported.
+ - ``False`` if not supported.
+ - ``None`` if IPC support could not be determined (e.g. CUDA 11).
+
+ is_primary (bool): Indicates whether this device's CUDA context is also the device's primary context.
+ uuid (str): The UUID of the CUDA device. The UUID is in the same format used by ``nvidia-smi -L``.
+ ``None`` for CPU devices.
+ pci_bus_id (str): An identifier for the CUDA device in the format ``[domain]:[bus]:[device]``, in which
  ``domain``, ``bus``, and ``device`` are all hexadecimal values. ``None`` for CPU devices.
  """

@@ -2565,6 +2689,7 @@ class Device:
  self.is_uva = False
  self.is_mempool_supported = False
  self.is_mempool_enabled = False
+ self.is_ipc_supported = False # TODO: Support IPC for CPU arrays
  self.is_cubin_supported = False
  self.uuid = None
  self.pci_bus_id = None
@@ -2580,8 +2705,14 @@
  # CUDA device
  self.name = runtime.core.cuda_device_get_name(ordinal).decode()
  self.arch = runtime.core.cuda_device_get_arch(ordinal)
- self.is_uva = runtime.core.cuda_device_is_uva(ordinal)
- self.is_mempool_supported = runtime.core.cuda_device_is_mempool_supported(ordinal)
+ self.is_uva = runtime.core.cuda_device_is_uva(ordinal) > 0
+ self.is_mempool_supported = runtime.core.cuda_device_is_mempool_supported(ordinal) > 0
+ if platform.system() == "Linux":
+ # Use None when IPC support cannot be determined
+ ipc_support_api_query = runtime.core.cuda_device_is_ipc_supported(ordinal)
+ self.is_ipc_supported = bool(ipc_support_api_query) if ipc_support_api_query >= 0 else None
+ else:
+ self.is_ipc_supported = False
  if warp.config.enable_mempools_at_init:
  # enable if supported
  self.is_mempool_enabled = self.is_mempool_supported
@@ -3062,6 +3193,9 @@ class Runtime:
  self.core.radix_sort_pairs_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
  self.core.radix_sort_pairs_int_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]

+ self.core.radix_sort_pairs_float_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
+ self.core.radix_sort_pairs_float_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
+
  self.core.runlength_encode_int_host.argtypes = [
  ctypes.c_uint64,
  ctypes.c_uint64,
@@ -3078,10 +3212,16 @@
  ]

  self.core.bvh_create_host.restype = ctypes.c_uint64
- self.core.bvh_create_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
+ self.core.bvh_create_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_int]

  self.core.bvh_create_device.restype = ctypes.c_uint64
- self.core.bvh_create_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
+ self.core.bvh_create_device.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_int,
+ ctypes.c_int,
+ ]

  self.core.bvh_destroy_host.argtypes = [ctypes.c_uint64]
  self.core.bvh_destroy_device.argtypes = [ctypes.c_uint64]
@@ -3097,6 +3237,7 @@
  ctypes.c_int,
  ctypes.c_int,
  ctypes.c_int,
+ ctypes.c_int,
  ]

  self.core.mesh_create_device.restype = ctypes.c_uint64
@@ -3108,6 +3249,7 @@
  ctypes.c_int,
  ctypes.c_int,
  ctypes.c_int,
+ ctypes.c_int,
  ]

  self.core.mesh_destroy_host.argtypes = [ctypes.c_uint64]
@@ -3345,6 +3487,8 @@ class Runtime:
  self.core.cuda_device_is_uva.restype = ctypes.c_int
  self.core.cuda_device_is_mempool_supported.argtypes = [ctypes.c_int]
  self.core.cuda_device_is_mempool_supported.restype = ctypes.c_int
+ self.core.cuda_device_is_ipc_supported.argtypes = [ctypes.c_int]
+ self.core.cuda_device_is_ipc_supported.restype = ctypes.c_int
  self.core.cuda_device_set_mempool_release_threshold.argtypes = [ctypes.c_int, ctypes.c_uint64]
  self.core.cuda_device_set_mempool_release_threshold.restype = ctypes.c_int
  self.core.cuda_device_get_mempool_release_threshold.argtypes = [ctypes.c_int]
@@ -3398,6 +3542,22 @@
  self.core.cuda_set_mempool_access_enabled.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
  self.core.cuda_set_mempool_access_enabled.restype = ctypes.c_int

+ # inter-process communication
+ self.core.cuda_ipc_get_mem_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
+ self.core.cuda_ipc_get_mem_handle.restype = None
+ self.core.cuda_ipc_open_mem_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
+ self.core.cuda_ipc_open_mem_handle.restype = ctypes.c_void_p
+ self.core.cuda_ipc_close_mem_handle.argtypes = [ctypes.c_void_p]
+ self.core.cuda_ipc_close_mem_handle.restype = None
+ self.core.cuda_ipc_get_event_handle.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.POINTER(ctypes.c_char),
+ ]
+ self.core.cuda_ipc_get_event_handle.restype = None
+ self.core.cuda_ipc_open_event_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
+ self.core.cuda_ipc_open_event_handle.restype = ctypes.c_void_p
+
  self.core.cuda_stream_create.argtypes = [ctypes.c_void_p, ctypes.c_int]
  self.core.cuda_stream_create.restype = ctypes.c_void_p
  self.core.cuda_stream_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
@@ -3445,6 +3605,7 @@ class Runtime:

  self.core.cuda_compile_program.argtypes = [
  ctypes.c_char_p, # cuda_src
+ ctypes.c_char_p, # program name
  ctypes.c_int, # arch
  ctypes.c_char_p, # include_dir
  ctypes.c_int, # num_cuda_include_dirs
@@ -3453,10 +3614,13 @@
  ctypes.c_bool, # verbose
  ctypes.c_bool, # verify_fp
  ctypes.c_bool, # fast_math
+ ctypes.c_bool, # fuse_fp
+ ctypes.c_bool, # lineinfo
  ctypes.c_char_p, # output_path
  ctypes.c_size_t, # num_ltoirs
  ctypes.POINTER(ctypes.c_char_p), # ltoirs
  ctypes.POINTER(ctypes.c_size_t), # ltoir_sizes
+ ctypes.POINTER(ctypes.c_int), # ltoir_input_types, each of type nvJitLinkInputType
  ]
  self.core.cuda_compile_program.restype = ctypes.c_size_t

@@ -3496,6 +3660,22 @@
  ]
  self.core.cuda_compile_dot.restype = ctypes.c_bool

+ self.core.cuda_compile_solver.argtypes = [
+ ctypes.c_char_p, # universal fatbin
+ ctypes.c_char_p, # lto
+ ctypes.c_char_p, # function name
+ ctypes.c_int, # num include dirs
+ ctypes.POINTER(ctypes.c_char_p), # include dirs
+ ctypes.c_char_p, # mathdx include dir
+ ctypes.c_int, # arch
+ ctypes.c_int, # M
+ ctypes.c_int, # N
+ ctypes.c_int, # precision
+ ctypes.c_int, # fill_mode
+ ctypes.c_int, # num threads
+ ]
+ self.core.cuda_compile_fft.restype = ctypes.c_bool
+
  self.core.cuda_load_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
  self.core.cuda_load_module.restype = ctypes.c_void_p

@@ -4074,7 +4254,7 @@ def set_mempool_enabled(device: Devicelike, enable: bool) -> None:
  They should generally be enabled, but there is a rare caveat. Copying data between different GPUs
  may fail during graph capture if the memory was allocated using pooled allocators and memory pool
  access is not enabled between the two GPUs. This is an internal CUDA limitation that is not related
- to Warp. The preferred solution is to enable memory pool access using `warp.set_mempool_access_enabled()`.
+ to Warp. The preferred solution is to enable memory pool access using :func:`set_mempool_access_enabled`.
  If peer access is not supported, then the default CUDA allocators must be used to pre-allocate the memory
  prior to graph capture.
  """
@@ -4846,6 +5026,40 @@ def from_numpy(
  )


+ def event_from_ipc_handle(handle, device: "Devicelike" = None) -> Event:
+ """Create an event from an IPC handle.
+
+ Args:
+ handle: The interprocess event handle for an existing CUDA event.
+ device (Devicelike): Device to associate with the array.
+
+ Returns:
+ An event created from the interprocess event handle ``handle``.
+
+ Raises:
+ RuntimeError: IPC is not supported on ``device``.
+ """
+
+ try:
+ # Performance note: try first, ask questions later
+ device = warp.context.runtime.get_device(device)
+ except Exception:
+ # Fallback to using the public API for retrieving the device,
+ # which takes take of initializing Warp if needed.
+ device = warp.context.get_device(device)
+
+ if device.is_ipc_supported is False:
+ raise RuntimeError(f"IPC is not supported on device {device}.")
+
+ event = Event(
+ device=device, cuda_event=warp.context.runtime.core.cuda_ipc_open_event_handle(device.context, handle)
+ )
+ # Events created from IPC handles must be freed with cuEventDestroy
+ event.owner = True
+
+ return event
+
+
  # given a kernel destination argument type and a value convert
  # to a c-type that can be passed to a kernel
  def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
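Consumer-side counterpart to the producer sketch earlier: rebuild the event from the received handle and make a stream wait on it (again a sketch; assumes the module-level function above is exported as wp.event_from_ipc_handle)::

    import warp as wp

    def wait_on_remote_event(handle: bytes):
        wp.init()
        # Reconstruct the event in this process from the 64-byte IPC handle.
        event = wp.event_from_ipc_handle(handle, device="cuda:0")
        # Block subsequent work on cuda:0's stream until the producer's event fires.
        wp.get_stream("cuda:0").wait_event(event)
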
@@ -4927,6 +5141,9 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):

  # try to convert to a value type (vec3, mat33, etc)
  elif issubclass(arg_type, ctypes.Array):
+ # simple value types don't have gradient arrays, but native built-in signatures still expect a non-null adjoint value of the correct type
+ if value is None and adjoint:
+ return arg_type(0)
  if warp.types.types_equal(type(value), arg_type):
  return value
  else:
@@ -4936,9 +5153,6 @@
  except Exception as e:
  raise ValueError(f"Failed to convert argument for param {arg_name} to {type_str(arg_type)}") from e

- elif isinstance(value, bool):
- return ctypes.c_bool(value)
-
  elif isinstance(value, arg_type):
  try:
  # try to pack as a scalar type
@@ -4953,6 +5167,9 @@
  ) from e

  else:
+ # scalar args don't have gradient arrays, but native built-in signatures still expect a non-null scalar adjoint
+ if value is None and adjoint:
+ return arg_type._type_(0)
  try:
  # try to pack as a scalar type
  if arg_type is warp.types.float16:
@@ -5272,6 +5489,8 @@ def launch(
  params_addr=kernel_params,
  bounds=bounds,
  device=device,
+ max_blocks=max_blocks,
+ block_dim=block_dim,
  )
  return launch

@@ -5355,7 +5574,7 @@ def launch_tiled(*args, **kwargs):
  kwargs["dim"] = dim + [kwargs["block_dim"]]

  # forward to original launch method
- launch(*args, **kwargs)
+ return launch(*args, **kwargs)


  def synchronize():
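With launch_tiled() now forwarding launch()'s return value, a recorded launch can be captured and replayed; a rough sketch assuming wp.launch's record_cmd option (kernel and sizes are illustrative)::

    import warp as wp

    @wp.kernel
    def scale2(a: wp.array2d(dtype=float), s: float):
        i, j = wp.tid()
        a[i, j] = a[i, j] * s

    a = wp.zeros((16, 64), dtype=float)

    # launch_tiled() appends block_dim to dim, so this executes as a [16, 64] launch.
    cmd = wp.launch_tiled(scale2, dim=[16], inputs=[a, 2.0], block_dim=64, record_cmd=True)
    cmd.launch()
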
@@ -6010,14 +6229,19 @@ def export_functions_rst(file): # pragma: no cover
  # build dictionary of all functions by group
  groups = {}

- for _k, f in builtin_functions.items():
+ functions = list(builtin_functions.values())
+
+ for f in functions:
  # build dict of groups
  if f.group not in groups:
  groups[f.group] = []

- # append all overloads to the group
- for o in f.overloads:
- groups[f.group].append(o)
+ if hasattr(f, "overloads"):
+ # append all overloads to the group
+ for o in f.overloads:
+ groups[f.group].append(o)
+ else:
+ groups[f.group].append(f)

  # Keep track of what function and query types have been written
  written_functions = set()
@@ -6037,6 +6261,10 @@ def export_functions_rst(file): # pragma: no cover
  print("---------------", file=file)

  for f in g:
+ if f.func:
+ # f is a Warp function written in Python, we can use autofunction
+ print(f".. autofunction:: {f.func.__module__}.{f.key}", file=file)
+ continue
  for f_prefix, query_type in query_types:
  if f.key.startswith(f_prefix) and query_type not in written_query_types:
  print(f".. autoclass:: {query_type}", file=file)
@@ -6094,24 +6322,32 @@ def export_stubs(file): # pragma: no cover
  print(header, file=file)
  print(file=file)

- for k, g in builtin_functions.items():
- for f in g.overloads:
- args = ", ".join(f"{k}: {type_str(v)}" for k, v in f.input_types.items())
+ def add_stub(f):
+ args = ", ".join(f"{k}: {type_str(v)}" for k, v in f.input_types.items())

- return_str = ""
+ return_str = ""

- if f.hidden: # or f.generic:
- continue
+ if f.hidden: # or f.generic:
+ return

+ return_type = f.value_type
+ if f.value_func:
  return_type = f.value_func(None, None)
- if return_type:
- return_str = " -> " + type_str(return_type)
-
- print("@over", file=file)
- print(f"def {f.key}({args}){return_str}:", file=file)
- print(f' """{f.doc}', file=file)
- print(' """', file=file)
- print(" ...\n\n", file=file)
+ if return_type:
+ return_str = " -> " + type_str(return_type)
+
+ print("@over", file=file)
+ print(f"def {f.key}({args}){return_str}:", file=file)
+ print(f' """{f.doc}', file=file)
+ print(' """', file=file)
+ print(" ...\n\n", file=file)
+
+ for g in builtin_functions.values():
+ if hasattr(g, "overloads"):
+ for f in g.overloads:
+ add_stub(f)
+ else:
+ add_stub(g)


  def export_builtins(file: io.TextIOBase): # pragma: no cover
@@ -6137,6 +6373,8 @@ def export_builtins(file: io.TextIOBase): # pragma: no cover
  file.write('extern "C" {\n\n')

  for k, g in builtin_functions.items():
+ if not hasattr(g, "overloads"):
+ continue
  for f in g.overloads:
  if not f.export or f.generic:
  continue