warp-lang 1.8.0-py3-none-manylinux_2_34_aarch64.whl → 1.9.0-py3-none-manylinux_2_34_aarch64.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release.

Files changed (153)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +48 -63
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +327 -209
  9. warp/config.py +1 -1
  10. warp/context.py +1363 -800
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_callable.py +34 -4
  18. warp/examples/interop/example_jax_kernel.py +27 -1
  19. warp/fabric.py +1 -1
  20. warp/fem/cache.py +27 -19
  21. warp/fem/domain.py +2 -2
  22. warp/fem/field/nodal_field.py +2 -2
  23. warp/fem/field/virtual.py +266 -166
  24. warp/fem/geometry/geometry.py +5 -5
  25. warp/fem/integrate.py +200 -91
  26. warp/fem/space/restriction.py +4 -0
  27. warp/fem/space/shape/tet_shape_function.py +3 -10
  28. warp/jax_experimental/custom_call.py +1 -1
  29. warp/jax_experimental/ffi.py +203 -54
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +103 -8
  32. warp/native/builtin.h +90 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +13 -3
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +42 -11
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +4 -4
  48. warp/native/mat.h +1913 -119
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +5 -3
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +337 -16
  59. warp/native/rand.h +7 -7
  60. warp/native/range.h +7 -1
  61. warp/native/reduce.cpp +10 -10
  62. warp/native/reduce.cu +13 -14
  63. warp/native/runlength_encode.cpp +2 -2
  64. warp/native/runlength_encode.cu +5 -5
  65. warp/native/scan.cpp +3 -3
  66. warp/native/scan.cu +4 -4
  67. warp/native/sort.cpp +10 -10
  68. warp/native/sort.cu +22 -22
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +14 -14
  71. warp/native/spatial.h +366 -17
  72. warp/native/svd.h +23 -8
  73. warp/native/temp_buffer.h +2 -2
  74. warp/native/tile.h +303 -70
  75. warp/native/tile_radix_sort.h +5 -1
  76. warp/native/tile_reduce.h +16 -25
  77. warp/native/tuple.h +2 -2
  78. warp/native/vec.h +385 -18
  79. warp/native/volume.cpp +54 -54
  80. warp/native/volume.cu +1 -1
  81. warp/native/volume.h +2 -1
  82. warp/native/volume_builder.cu +30 -37
  83. warp/native/warp.cpp +150 -149
  84. warp/native/warp.cu +337 -193
  85. warp/native/warp.h +227 -226
  86. warp/optim/linear.py +736 -271
  87. warp/render/imgui_manager.py +289 -0
  88. warp/render/render_opengl.py +137 -57
  89. warp/render/render_usd.py +0 -1
  90. warp/sim/collide.py +1 -2
  91. warp/sim/graph_coloring.py +2 -2
  92. warp/sim/integrator_vbd.py +10 -2
  93. warp/sparse.py +559 -176
  94. warp/tape.py +2 -0
  95. warp/tests/aux_test_module_aot.py +7 -0
  96. warp/tests/cuda/test_async.py +3 -3
  97. warp/tests/cuda/test_conditional_captures.py +101 -0
  98. warp/tests/geometry/test_marching_cubes.py +233 -12
  99. warp/tests/sim/test_cloth.py +89 -6
  100. warp/tests/sim/test_coloring.py +82 -7
  101. warp/tests/test_array.py +56 -5
  102. warp/tests/test_assert.py +53 -0
  103. warp/tests/test_atomic_cas.py +127 -114
  104. warp/tests/test_codegen.py +3 -2
  105. warp/tests/test_context.py +8 -15
  106. warp/tests/test_enum.py +136 -0
  107. warp/tests/test_examples.py +2 -2
  108. warp/tests/test_fem.py +45 -2
  109. warp/tests/test_fixedarray.py +229 -0
  110. warp/tests/test_func.py +18 -15
  111. warp/tests/test_future_annotations.py +7 -5
  112. warp/tests/test_linear_solvers.py +30 -0
  113. warp/tests/test_map.py +1 -1
  114. warp/tests/test_mat.py +1540 -378
  115. warp/tests/test_mat_assign_copy.py +178 -0
  116. warp/tests/test_mat_constructors.py +574 -0
  117. warp/tests/test_module_aot.py +287 -0
  118. warp/tests/test_print.py +69 -0
  119. warp/tests/test_quat.py +162 -34
  120. warp/tests/test_quat_assign_copy.py +145 -0
  121. warp/tests/test_reload.py +2 -1
  122. warp/tests/test_sparse.py +103 -0
  123. warp/tests/test_spatial.py +140 -34
  124. warp/tests/test_spatial_assign_copy.py +160 -0
  125. warp/tests/test_static.py +48 -0
  126. warp/tests/test_struct.py +43 -3
  127. warp/tests/test_tape.py +38 -0
  128. warp/tests/test_types.py +0 -20
  129. warp/tests/test_vec.py +216 -441
  130. warp/tests/test_vec_assign_copy.py +143 -0
  131. warp/tests/test_vec_constructors.py +325 -0
  132. warp/tests/tile/test_tile.py +206 -152
  133. warp/tests/tile/test_tile_cholesky.py +605 -0
  134. warp/tests/tile/test_tile_load.py +169 -0
  135. warp/tests/tile/test_tile_mathdx.py +2 -558
  136. warp/tests/tile/test_tile_matmul.py +179 -0
  137. warp/tests/tile/test_tile_mlp.py +1 -1
  138. warp/tests/tile/test_tile_reduce.py +100 -11
  139. warp/tests/tile/test_tile_shared_memory.py +16 -16
  140. warp/tests/tile/test_tile_sort.py +59 -55
  141. warp/tests/unittest_suites.py +16 -0
  142. warp/tests/walkthrough_debug.py +1 -1
  143. warp/thirdparty/unittest_parallel.py +108 -9
  144. warp/types.py +554 -264
  145. warp/utils.py +68 -86
  146. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  147. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/RECORD +150 -138
  148. warp/native/marching.cpp +0 -19
  149. warp/native/marching.cu +0 -514
  150. warp/native/marching.h +0 -19
  151. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  152. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  153. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -26,13 +26,28 @@ import json
26
26
  import operator
27
27
  import os
28
28
  import platform
29
+ import shutil
29
30
  import sys
30
31
  import types
31
32
  import typing
32
33
  import weakref
33
34
  from copy import copy as shallowcopy
34
35
  from pathlib import Path
35
- from typing import Any, Callable, Dict, List, Literal, Mapping, Sequence, Tuple, TypeVar, Union, get_args, get_origin
36
+ from typing import (
37
+ Any,
38
+ Callable,
39
+ Dict,
40
+ Iterable,
41
+ List,
42
+ Literal,
43
+ Mapping,
44
+ Sequence,
45
+ Tuple,
46
+ TypeVar,
47
+ Union,
48
+ get_args,
49
+ get_origin,
50
+ )
36
51
 
37
52
  import numpy as np
38
53
 
@@ -327,39 +342,25 @@ class Function:
327
342
  warp.codegen.apply_defaults(bound_args, self.defaults)
328
343
 
329
344
  arguments = tuple(bound_args.arguments.values())
330
-
331
- # Store the last runtime error we encountered from a function execution
332
- last_execution_error = None
345
+ arg_types = tuple(warp.codegen.get_arg_type(x) for x in arguments)
333
346
 
334
347
  # try and find a matching overload
335
348
  for overload in self.user_overloads.values():
336
349
  if len(overload.input_types) != len(arguments):
337
350
  continue
351
+
352
+ if not warp.codegen.func_match_args(overload, arg_types, {}):
353
+ continue
354
+
338
355
  template_types = list(overload.input_types.values())
339
356
  arg_names = list(overload.input_types.keys())
340
- try:
341
- # attempt to unify argument types with function template types
342
- warp.types.infer_argument_types(arguments, template_types, arg_names)
343
- return overload.func(*arguments)
344
- except Exception as e:
345
- # The function was callable but threw an error during its execution.
346
- # This might be the intended overload, but it failed, or it might be the wrong overload.
347
- # We save this specific error and continue, just in case another overload later in the
348
- # list is a better match and doesn't fail.
349
- last_execution_error = e
350
- continue
351
357
 
352
- if last_execution_error:
353
- # Raise a new, more contextual RuntimeError, but link it to the
354
- # original error that was caught. This preserves the original
355
- # traceback and error type for easier debugging.
356
- raise RuntimeError(
357
- f"Error calling function '{self.key}'. No version succeeded. "
358
- f"See above for the error from the last version that was tried."
359
- ) from last_execution_error
360
- else:
361
- # We got here without ever calling an overload.func
362
- raise RuntimeError(f"Error calling function '{self.key}', no overload found for arguments {args}")
358
+ # attempt to unify argument types with function template types
359
+ warp.types.infer_argument_types(arguments, template_types, arg_names)
360
+ return overload.func(*arguments)
361
+
362
+ # We got here without ever calling an overload.func
363
+ raise RuntimeError(f"Error calling function '{self.key}', no overload found for arguments {args}")
363
364
 
364
365
  # user-defined function with no overloads
365
366
  if self.func is None:
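
The hunk above reworks how Python-scope calls to user @wp.func overloads are resolved: argument types are now matched up front via warp.codegen.func_match_args, instead of executing each overload in turn and swallowing its runtime errors. A minimal sketch of the calling pattern this affects, using hypothetical overloads:

    import warp as wp

    @wp.func
    def magnitude(x: float):
        return wp.abs(x)

    @wp.func
    def magnitude(v: wp.vec3):
        return wp.length(v)

    # Calling a user function from Python dispatches on the declared argument
    # types; a non-matching overload is now skipped instead of being executed
    # and having its error caught.
    print(magnitude(-3.0))                    # float overload -> 3.0
    print(magnitude(wp.vec3(3.0, 4.0, 0.0)))  # vec3 overload  -> 5.0
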
@@ -385,7 +386,7 @@ class Function:
385
386
  def mangle(self) -> str:
386
387
  """Build a mangled name for the C-exported function, e.g.: `builtin_normalize_vec3()`."""
387
388
 
388
- name = "builtin_" + self.key
389
+ name = "wp_builtin_" + self.key
389
390
 
390
391
  # Runtime arguments that are to be passed to the function, not its template signature.
391
392
  if self.export_func is not None:
@@ -475,6 +476,25 @@ class Function:
475
476
  # failed to find overload
476
477
  return None
477
478
 
479
+ def build(self, builder: ModuleBuilder | None):
480
+ self.adj.build(builder)
481
+
482
+ # complete the function return type after we have analyzed it (inferred from return statement in ast)
483
+ if not self.value_func:
484
+
485
+ def wrap(adj):
486
+ def value_type(arg_types, arg_values):
487
+ if adj.return_var is None or len(adj.return_var) == 0:
488
+ return None
489
+ if len(adj.return_var) == 1:
490
+ return adj.return_var[0].type
491
+ else:
492
+ return [v.type for v in adj.return_var]
493
+
494
+ return value_type
495
+
496
+ self.value_func = wrap(self.adj)
497
+
478
498
  def __repr__(self):
479
499
  inputs_str = ", ".join([f"{k}: {warp.types.type_repr(v)}" for k, v in self.input_types.items()])
480
500
  return f"<Function {self.key}({inputs_str})>"
@@ -807,14 +827,17 @@ class Kernel:
807
827
  sig = warp.types.get_signature(arg_types, func_name=self.key)
808
828
  return self.overloads.get(sig)
809
829
 
810
- def get_mangled_name(self):
811
- if self.hash is None:
812
- raise RuntimeError(f"Missing hash for kernel {self.key} in module {self.module.name}")
830
+ def get_mangled_name(self) -> str:
831
+ if self.module.options["strip_hash"]:
832
+ return self.key
833
+ else:
834
+ if self.hash is None:
835
+ raise RuntimeError(f"Missing hash for kernel {self.key} in module {self.module.name}")
813
836
 
814
- # TODO: allow customizing the number of hash characters used
815
- hash_suffix = self.hash.hex()[:8]
837
+ # TODO: allow customizing the number of hash characters used
838
+ hash_suffix = self.hash.hex()[:8]
816
839
 
817
- return f"{self.key}_{hash_suffix}"
840
+ return f"{self.key}_{hash_suffix}"
818
841
 
819
842
  def __call__(self, *args, **kwargs):
820
843
  # we implement this function only to ensure Kernel is a callable object
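
The hunk above makes Kernel.get_mangled_name honor a new per-module "strip_hash" option (registered further down in this diff), returning the plain kernel key instead of appending an 8-character hash suffix. A minimal sketch of enabling it, assuming the option is set through the regular module-options API like "enable_backward" or "block_dim":

    import warp as wp

    # Assumption: "strip_hash" is toggled via wp.set_module_options like other
    # per-module options.
    wp.set_module_options({"strip_hash": True})

    @wp.kernel
    def scale(a: wp.array(dtype=float), s: float):
        i = wp.tid()
        a[i] = a[i] * s

    # With strip_hash enabled, the mangled kernel name is just "scale" and the
    # kernel-cache directory is "wp_<module>" without the hash suffix.
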
@@ -1597,6 +1620,9 @@ class ModuleHasher:
1597
1620
  # line directives, e.g. for Nsight Compute
1598
1621
  ch.update(bytes(ctypes.c_int(warp.config.line_directives)))
1599
1622
 
1623
+ # whether to use `assign_copy` instead of `assign_inplace`
1624
+ ch.update(bytes(ctypes.c_int(warp.config.enable_vector_component_overwrites)))
1625
+
1600
1626
  # build config
1601
1627
  ch.update(bytes(warp.config.mode, "utf-8"))
1602
1628
 
@@ -1692,7 +1718,7 @@ class ModuleHasher:
1692
1718
  ch.update(bytes(name, "utf-8"))
1693
1719
  ch.update(self.get_constant_bytes(value))
1694
1720
 
1695
- # hash wp.static() expressions that were evaluated at declaration time
1721
+ # hash wp.static() expressions
1696
1722
  for k, v in adj.static_expressions.items():
1697
1723
  ch.update(bytes(k, "utf-8"))
1698
1724
  if isinstance(v, Function):
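
The hashing hunks above fold the new warp.config.enable_vector_component_overwrites flag into the module hash (the flag switches vector component writes from assign_inplace to assign_copy) and broaden the hashing of wp.static() expressions, so changing either triggers a rebuild. A minimal sketch of the flag in use; an assumption here is that, like other warp.config switches, it must be set before the affected kernels are built:

    import warp as wp

    wp.config.enable_vector_component_overwrites = True

    @wp.kernel
    def set_x(points: wp.array(dtype=wp.vec3)):
        i = wp.tid()
        p = points[i]
        p[0] = 1.0   # component write, generated as assign_copy when the flag is on
        points[i] = p
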
@@ -1784,6 +1810,9 @@ class ModuleBuilder:
1784
1810
  self.structs[struct] = None
1785
1811
 
1786
1812
  def build_kernel(self, kernel):
1813
+ if kernel.options.get("enable_backward", True):
1814
+ kernel.adj.used_by_backward_kernel = True
1815
+
1787
1816
  kernel.adj.build(self)
1788
1817
 
1789
1818
  if kernel.adj.return_var is not None:
@@ -1794,23 +1823,7 @@ class ModuleBuilder:
1794
1823
  if func in self.functions:
1795
1824
  return
1796
1825
  else:
1797
- func.adj.build(self)
1798
-
1799
- # complete the function return type after we have analyzed it (inferred from return statement in ast)
1800
- if not func.value_func:
1801
-
1802
- def wrap(adj):
1803
- def value_type(arg_types, arg_values):
1804
- if adj.return_var is None or len(adj.return_var) == 0:
1805
- return None
1806
- if len(adj.return_var) == 1:
1807
- return adj.return_var[0].type
1808
- else:
1809
- return [v.type for v in adj.return_var]
1810
-
1811
- return value_type
1812
-
1813
- func.value_func = wrap(func.adj)
1826
+ func.build(self)
1814
1827
 
1815
1828
  # use dict to preserve import order
1816
1829
  self.functions[func] = None
@@ -1830,10 +1843,11 @@ class ModuleBuilder:
1830
1843
  source = ""
1831
1844
 
1832
1845
  # code-gen LTO forward declarations
1833
- source += 'extern "C" {\n'
1834
- for fwd in self.ltoirs_decl.values():
1835
- source += fwd + "\n"
1836
- source += "}\n"
1846
+ if len(self.ltoirs_decl) > 0:
1847
+ source += 'extern "C" {\n'
1848
+ for fwd in self.ltoirs_decl.values():
1849
+ source += fwd + "\n"
1850
+ source += "}\n"
1837
1851
 
1838
1852
  # code-gen structs
1839
1853
  visited_structs = set()
@@ -1898,9 +1912,9 @@ class ModuleExec:
1898
1912
  if self.device.is_cuda:
1899
1913
  # use CUDA context guard to avoid side effects during garbage collection
1900
1914
  with self.device.context_guard:
1901
- runtime.core.cuda_unload_module(self.device.context, self.handle)
1915
+ runtime.core.wp_cuda_unload_module(self.device.context, self.handle)
1902
1916
  else:
1903
- runtime.llvm.unload_obj(self.handle.encode("utf-8"))
1917
+ runtime.llvm.wp_unload_obj(self.handle.encode("utf-8"))
1904
1918
 
1905
1919
  # lookup and cache kernel entry points
1906
1920
  def get_kernel_hooks(self, kernel) -> KernelHooks:
@@ -1918,13 +1932,13 @@ class ModuleExec:
1918
1932
 
1919
1933
  if self.device.is_cuda:
1920
1934
  forward_name = name + "_cuda_kernel_forward"
1921
- forward_kernel = runtime.core.cuda_get_kernel(
1935
+ forward_kernel = runtime.core.wp_cuda_get_kernel(
1922
1936
  self.device.context, self.handle, forward_name.encode("utf-8")
1923
1937
  )
1924
1938
 
1925
1939
  if options["enable_backward"]:
1926
1940
  backward_name = name + "_cuda_kernel_backward"
1927
- backward_kernel = runtime.core.cuda_get_kernel(
1941
+ backward_kernel = runtime.core.wp_cuda_get_kernel(
1928
1942
  self.device.context, self.handle, backward_name.encode("utf-8")
1929
1943
  )
1930
1944
  else:
@@ -1935,14 +1949,14 @@ class ModuleExec:
1935
1949
  backward_smem_bytes = self.meta[backward_name + "_smem_bytes"] if options["enable_backward"] else 0
1936
1950
 
1937
1951
  # configure kernels maximum shared memory size
1938
- max_smem_bytes = runtime.core.cuda_get_max_shared_memory(self.device.context)
1952
+ max_smem_bytes = runtime.core.wp_cuda_get_max_shared_memory(self.device.context)
1939
1953
 
1940
- if not runtime.core.cuda_configure_kernel_shared_memory(forward_kernel, forward_smem_bytes):
1954
+ if not runtime.core.wp_cuda_configure_kernel_shared_memory(forward_kernel, forward_smem_bytes):
1941
1955
  print(
1942
1956
  f"Warning: Failed to configure kernel dynamic shared memory for this device, tried to configure {forward_name} kernel for {forward_smem_bytes} bytes, but maximum available is {max_smem_bytes}"
1943
1957
  )
1944
1958
 
1945
- if options["enable_backward"] and not runtime.core.cuda_configure_kernel_shared_memory(
1959
+ if options["enable_backward"] and not runtime.core.wp_cuda_configure_kernel_shared_memory(
1946
1960
  backward_kernel, backward_smem_bytes
1947
1961
  ):
1948
1962
  print(
@@ -1954,12 +1968,13 @@ class ModuleExec:
1954
1968
  else:
1955
1969
  func = ctypes.CFUNCTYPE(None)
1956
1970
  forward = (
1957
- func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_forward").encode("utf-8"))) or None
1971
+ func(runtime.llvm.wp_lookup(self.handle.encode("utf-8"), (name + "_cpu_forward").encode("utf-8")))
1972
+ or None
1958
1973
  )
1959
1974
 
1960
1975
  if options["enable_backward"]:
1961
1976
  backward = (
1962
- func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8")))
1977
+ func(runtime.llvm.wp_lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8")))
1963
1978
  or None
1964
1979
  )
1965
1980
  else:
@@ -1971,6 +1986,25 @@ class ModuleExec:
1971
1986
  return hooks
1972
1987
 
1973
1988
 
1989
+ def _check_and_raise_long_path_error(e: FileNotFoundError):
1990
+ """Check if the error is due to a Windows long path and provide work-around instructions if it is.
1991
+
1992
+ ``FileNotFoundError.filename`` may legitimately be ``None`` when the originating
1993
+ API does not supply a path. Guard against that to avoid masking the original
1994
+ error with a ``TypeError``.
1995
+ """
1996
+ filename = getattr(e, "filename", None)
1997
+
1998
+ # Fast-exit when this is clearly not a legacy-path limitation:
1999
+ if filename is None or len(filename) < 260 or os.name != "nt" or filename.startswith("\\\\?\\"):
2000
+ raise e
2001
+
2002
+ raise RuntimeError(
2003
+ f"File path '{e.filename}' exceeds 259 characters, long-path support is required for this operation. "
2004
+ "See https://learn.microsoft.com/en-us/windows/win32/fileio/maximum-file-path-limitation for more information."
2005
+ ) from e
2006
+
2007
+
1974
2008
  # -----------------------------------------------------
1975
2009
  # stores all functions and kernels for a Python module
1976
2010
  # creates a hash of the function to use for checking
@@ -2011,6 +2045,9 @@ class Module:
2011
2045
  # is retained and later reloaded with the same hash.
2012
2046
  self.cpu_exec_id = 0
2013
2047
 
2048
+ # Indicates whether the module has functions or kernels with unresolved static expressions.
2049
+ self.has_unresolved_static_expressions = False
2050
+
2014
2051
  self.options = {
2015
2052
  "max_unroll": warp.config.max_unroll,
2016
2053
  "enable_backward": warp.config.enable_backward,
@@ -2018,9 +2055,10 @@ class Module:
2018
2055
  "fuse_fp": True,
2019
2056
  "lineinfo": warp.config.lineinfo,
2020
2057
  "cuda_output": None, # supported values: "ptx", "cubin", or None (automatic)
2021
- "mode": warp.config.mode,
2058
+ "mode": None,
2022
2059
  "block_dim": 256,
2023
2060
  "compile_time_trace": warp.config.compile_time_trace,
2061
+ "strip_hash": False,
2024
2062
  }
2025
2063
 
2026
2064
  # Module dependencies are determined by scanning each function
@@ -2047,6 +2085,10 @@ class Module:
2047
2085
  # track all kernel objects, even if they are duplicates
2048
2086
  self._live_kernels.add(kernel)
2049
2087
 
2088
+ # Check for unresolved static expressions in the kernel.
2089
+ if kernel.adj.has_unresolved_static_expressions:
2090
+ self.has_unresolved_static_expressions = True
2091
+
2050
2092
  self.find_references(kernel.adj)
2051
2093
 
2052
2094
  # for a reload of module on next launch
@@ -2106,6 +2148,10 @@ class Module:
2106
2148
  del func_existing.user_overloads[k]
2107
2149
  func_existing.add_overload(func)
2108
2150
 
2151
+ # Check for unresolved static expressions in the function.
2152
+ if func.adj.has_unresolved_static_expressions:
2153
+ self.has_unresolved_static_expressions = True
2154
+
2109
2155
  self.find_references(func.adj)
2110
2156
 
2111
2157
  # for a reload of module on next launch
@@ -2159,224 +2205,419 @@ class Module:
2159
2205
  if isinstance(arg.type, warp.codegen.Struct) and arg.type.module is not None:
2160
2206
  add_ref(arg.type.module)
2161
2207
 
2162
- def hash_module(self):
2208
+ def hash_module(self) -> bytes:
2209
+ """Get the hash of the module for the current block_dim.
2210
+
2211
+ This function always creates a new `ModuleHasher` instance and computes the hash.
2212
+ """
2163
2213
  # compute latest hash
2164
2214
  block_dim = self.options["block_dim"]
2165
2215
  self.hashers[block_dim] = ModuleHasher(self)
2166
2216
  return self.hashers[block_dim].get_module_hash()
2167
2217
 
2168
- def load(self, device, block_dim=None) -> ModuleExec:
2169
- device = runtime.get_device(device)
2170
-
2171
- # update module options if launching with a new block dim
2172
- if block_dim is not None:
2173
- self.options["block_dim"] = block_dim
2218
+ def get_module_hash(self, block_dim: int | None = None) -> bytes:
2219
+ """Get the hash of the module for the current block_dim.
2174
2220
 
2175
- active_block_dim = self.options["block_dim"]
2221
+ If a hash has not been computed for the current block_dim, it will be computed and cached.
2222
+ """
2223
+ if block_dim is None:
2224
+ block_dim = self.options["block_dim"]
2225
+
2226
+ if self.has_unresolved_static_expressions:
2227
+ # The module hash currently does not account for unresolved static expressions
2228
+ # (only static expressions evaluated at declaration time so far).
2229
+ # We need to generate the code for the functions and kernels that have
2230
+ # unresolved static expressions and then compute the module hash again.
2231
+ builder_options = {
2232
+ **self.options,
2233
+ "output_arch": None,
2234
+ }
2235
+ # build functions, kernels to resolve static expressions
2236
+ _ = ModuleBuilder(self, builder_options)
2237
+
2238
+ self.has_unresolved_static_expressions = False
2176
2239
 
2177
2240
  # compute the hash if needed
2178
- if active_block_dim not in self.hashers:
2179
- self.hashers[active_block_dim] = ModuleHasher(self)
2241
+ if block_dim not in self.hashers:
2242
+ self.hashers[block_dim] = ModuleHasher(self)
2180
2243
 
2181
- # check if executable module is already loaded and not stale
2182
- exec = self.execs.get((device.context, active_block_dim))
2183
- if exec is not None:
2184
- if exec.module_hash == self.hashers[active_block_dim].get_module_hash():
2185
- return exec
2244
+ return self.hashers[block_dim].get_module_hash()
2186
2245
 
2187
- # quietly avoid repeated build attempts to reduce error spew
2188
- if device.context in self.failed_builds:
2246
+ def _use_ptx(self, device) -> bool:
2247
+ # determine whether to use PTX or CUBIN
2248
+ if device.is_cubin_supported:
2249
+ # get user preference specified either per module or globally
2250
+ preferred_cuda_output = self.options.get("cuda_output") or warp.config.cuda_output
2251
+ if preferred_cuda_output is not None:
2252
+ use_ptx = preferred_cuda_output == "ptx"
2253
+ else:
2254
+ # determine automatically: older drivers may not be able to handle PTX generated using newer
2255
+ # CUDA Toolkits, in which case we fall back on generating CUBIN modules
2256
+ use_ptx = runtime.driver_version >= runtime.toolkit_version
2257
+ else:
2258
+ # CUBIN not an option, must use PTX (e.g. CUDA Toolkit too old)
2259
+ use_ptx = True
2260
+
2261
+ return use_ptx
2262
+
2263
+ def get_module_identifier(self) -> str:
2264
+ """Get an abbreviated module name to use for directories and files in the cache.
2265
+
2266
+ Depending on the setting of the ``"strip_hash"`` option for this module,
2267
+ the module identifier might include a content-dependent hash as a suffix.
2268
+ """
2269
+ if self.options["strip_hash"]:
2270
+ module_name_short = f"wp_{self.name}"
2271
+ else:
2272
+ module_hash = self.get_module_hash()
2273
+ module_name_short = f"wp_{self.name}_{module_hash.hex()[:7]}"
2274
+
2275
+ return module_name_short
2276
+
2277
+ def get_compile_arch(self, device: Device | None = None) -> int | None:
2278
+ if device is None:
2279
+ device = runtime.get_device()
2280
+
2281
+ if device.is_cpu:
2189
2282
  return None
2190
2283
 
2191
- module_name = "wp_" + self.name
2192
- module_hash = self.hashers[active_block_dim].get_module_hash()
2284
+ if self._use_ptx(device):
2285
+ # use the default PTX arch if the device supports it
2286
+ if warp.config.ptx_target_arch is not None:
2287
+ output_arch = min(device.arch, warp.config.ptx_target_arch)
2288
+ else:
2289
+ output_arch = min(device.arch, runtime.default_ptx_arch)
2290
+ else:
2291
+ output_arch = device.arch
2292
+
2293
+ return output_arch
2193
2294
 
2194
- # use a unique module path using the module short hash
2195
- module_name_short = f"{module_name}_{module_hash.hex()[:7]}"
2196
- module_dir = os.path.join(warp.config.kernel_cache_dir, module_name_short)
2295
+ def get_compile_output_name(
2296
+ self, device: Device | None, output_arch: int | None = None, use_ptx: bool | None = None
2297
+ ) -> str:
2298
+ """Get the filename to use for the compiled module binary.
2197
2299
 
2198
- with warp.ScopedTimer(
2199
- f"Module {self.name} {module_hash.hex()[:7]} load on device '{device}'", active=not warp.config.quiet
2200
- ) as module_load_timer:
2201
- # -----------------------------------------------------------
2202
- # determine output paths
2203
- if device.is_cpu:
2204
- output_name = f"{module_name_short}.o"
2205
- output_arch = None
2300
+ This is only the filename, e.g. ``wp___main___0340cd1.sm86.ptx``.
2301
+ It should be used to form a path.
2302
+ """
2303
+ module_name_short = self.get_module_identifier()
2206
2304
 
2207
- elif device.is_cuda:
2208
- # determine whether to use PTX or CUBIN
2209
- if device.is_cubin_supported:
2210
- # get user preference specified either per module or globally
2211
- preferred_cuda_output = self.options.get("cuda_output") or warp.config.cuda_output
2212
- if preferred_cuda_output is not None:
2213
- use_ptx = preferred_cuda_output == "ptx"
2214
- else:
2215
- # determine automatically: older drivers may not be able to handle PTX generated using newer
2216
- # CUDA Toolkits, in which case we fall back on generating CUBIN modules
2217
- use_ptx = runtime.driver_version >= runtime.toolkit_version
2218
- else:
2219
- # CUBIN not an option, must use PTX (e.g. CUDA Toolkit too old)
2220
- use_ptx = True
2305
+ if device and device.is_cpu:
2306
+ return f"{module_name_short}.o"
2221
2307
 
2222
- if use_ptx:
2223
- # use the default PTX arch if the device supports it
2224
- if warp.config.ptx_target_arch is not None:
2225
- output_arch = min(device.arch, warp.config.ptx_target_arch)
2226
- else:
2227
- output_arch = min(device.arch, runtime.default_ptx_arch)
2228
- output_name = f"{module_name_short}.sm{output_arch}.ptx"
2229
- else:
2230
- output_arch = device.arch
2231
- output_name = f"{module_name_short}.sm{output_arch}.cubin"
2308
+ # For CUDA compilation, we must have an architecture.
2309
+ final_arch = output_arch
2310
+ if final_arch is None:
2311
+ if device:
2312
+ # Infer the architecture from the device
2313
+ final_arch = self.get_compile_arch(device)
2314
+ else:
2315
+ raise ValueError(
2316
+ "Either 'device' or 'output_arch' must be provided to determine compilation architecture"
2317
+ )
2318
+
2319
+ # Determine if we should compile to PTX or CUBIN
2320
+ if use_ptx is None:
2321
+ if device:
2322
+ use_ptx = self._use_ptx(device)
2323
+ else:
2324
+ init()
2325
+ use_ptx = final_arch not in runtime.nvrtc_supported_archs
2326
+
2327
+ if use_ptx:
2328
+ output_name = f"{module_name_short}.sm{final_arch}.ptx"
2329
+ else:
2330
+ output_name = f"{module_name_short}.sm{final_arch}.cubin"
2331
+
2332
+ return output_name
2333
+
2334
+ def get_meta_name(self) -> str:
2335
+ """Get the filename to use for the module metadata file.
2336
+
2337
+ This is only the filename. It should be used to form a path.
2338
+ """
2339
+ return f"{self.get_module_identifier()}.meta"
2340
+
2341
+ def compile(
2342
+ self,
2343
+ device: Device | None = None,
2344
+ output_dir: str | os.PathLike | None = None,
2345
+ output_name: str | None = None,
2346
+ output_arch: int | None = None,
2347
+ use_ptx: bool | None = None,
2348
+ ) -> None:
2349
+ """Compile this module for a specific device.
2350
+
2351
+ Note that this function only generates and compiles code. The resulting
2352
+ binary is not loaded into the runtime.
2353
+
2354
+ Args:
2355
+ device: The device to compile the module for.
2356
+ output_dir: The directory to write the compiled module to.
2357
+ output_name: The name of the compiled module binary file.
2358
+ output_arch: The architecture to compile the module for.
2359
+ """
2360
+ if output_arch is None:
2361
+ output_arch = self.get_compile_arch(device) # Will remain at None if device is CPU
2362
+
2363
+ if output_name is None:
2364
+ output_name = self.get_compile_output_name(device, output_arch, use_ptx)
2365
+
2366
+ builder_options = {
2367
+ **self.options,
2368
+ # Some of the tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
2369
+ "output_arch": output_arch,
2370
+ }
2371
+ builder = ModuleBuilder(
2372
+ self,
2373
+ builder_options,
2374
+ hasher=self.hashers.get(self.options["block_dim"], None),
2375
+ )
2376
+
2377
+ # create a temporary (process unique) dir for build outputs before moving to the binary dir
2378
+ module_name_short = self.get_module_identifier()
2379
+
2380
+ if output_dir is None:
2381
+ output_dir = os.path.join(warp.config.kernel_cache_dir, f"{module_name_short}")
2382
+ else:
2383
+ output_dir = os.fspath(output_dir)
2384
+
2385
+ meta_path = os.path.join(output_dir, self.get_meta_name())
2232
2386
 
2387
+ build_dir = os.path.normpath(output_dir) + f"_p{os.getpid()}"
2388
+
2389
+ # dir may exist from previous attempts / runs / archs
2390
+ Path(build_dir).mkdir(parents=True, exist_ok=True)
2391
+
2392
+ mode = self.options["mode"] if self.options["mode"] is not None else warp.config.mode
2393
+
2394
+ # build CPU
2395
+ if output_arch is None:
2396
+ # build
2397
+ try:
2398
+ source_code_path = os.path.join(build_dir, f"{module_name_short}.cpp")
2399
+
2400
+ # write cpp sources
2401
+ cpp_source = builder.codegen("cpu")
2402
+
2403
+ with open(source_code_path, "w") as cpp_file:
2404
+ cpp_file.write(cpp_source)
2405
+
2406
+ output_path = os.path.join(build_dir, output_name)
2407
+
2408
+ # build object code
2409
+ with warp.ScopedTimer("Compile x86", active=warp.config.verbose):
2410
+ warp.build.build_cpu(
2411
+ output_path,
2412
+ source_code_path,
2413
+ mode=mode,
2414
+ fast_math=self.options["fast_math"],
2415
+ verify_fp=warp.config.verify_fp,
2416
+ fuse_fp=self.options["fuse_fp"],
2417
+ )
2418
+
2419
+ except Exception as e:
2420
+ if isinstance(e, FileNotFoundError):
2421
+ _check_and_raise_long_path_error(e)
2422
+
2423
+ self.failed_builds.add(None)
2424
+
2425
+ raise (e)
2426
+
2427
+ else:
2428
+ # build
2429
+ try:
2430
+ source_code_path = os.path.join(build_dir, f"{module_name_short}.cu")
2431
+
2432
+ # write cuda sources
2433
+ cu_source = builder.codegen("cuda")
2434
+
2435
+ with open(source_code_path, "w") as cu_file:
2436
+ cu_file.write(cu_source)
2437
+
2438
+ output_path = os.path.join(build_dir, output_name)
2439
+
2440
+ # generate PTX or CUBIN
2441
+ with warp.ScopedTimer(
2442
+ f"Compile CUDA (arch={builder_options['output_arch']}, mode={mode}, block_dim={self.options['block_dim']})",
2443
+ active=warp.config.verbose,
2444
+ ):
2445
+ warp.build.build_cuda(
2446
+ source_code_path,
2447
+ builder_options["output_arch"],
2448
+ output_path,
2449
+ config=mode,
2450
+ verify_fp=warp.config.verify_fp,
2451
+ fast_math=self.options["fast_math"],
2452
+ fuse_fp=self.options["fuse_fp"],
2453
+ lineinfo=self.options["lineinfo"],
2454
+ compile_time_trace=self.options["compile_time_trace"],
2455
+ ltoirs=builder.ltoirs.values(),
2456
+ fatbins=builder.fatbins.values(),
2457
+ )
2458
+
2459
+ except Exception as e:
2460
+ if isinstance(e, FileNotFoundError):
2461
+ _check_and_raise_long_path_error(e)
2462
+
2463
+ if device:
2464
+ self.failed_builds.add(device.context)
2465
+
2466
+ raise (e)
2467
+
2468
+ # ------------------------------------------------------------
2469
+ # build meta data
2470
+
2471
+ meta = builder.build_meta()
2472
+ output_meta_path = os.path.join(build_dir, self.get_meta_name())
2473
+
2474
+ with open(output_meta_path, "w") as meta_file:
2475
+ json.dump(meta, meta_file)
2476
+
2477
+ # -----------------------------------------------------------
2478
+ # update cache
2479
+
2480
+ # try to move process outputs to cache
2481
+ warp.build.safe_rename(build_dir, output_dir)
2482
+
2483
+ if os.path.exists(output_dir):
2233
2484
  # final object binary path
2234
- binary_path = os.path.join(module_dir, output_name)
2485
+ binary_path = os.path.join(output_dir, output_name)
2235
2486
 
2236
- # -----------------------------------------------------------
2237
- # check cache and build if necessary
2487
+ if not os.path.exists(binary_path) or self.options["strip_hash"]:
2488
+ # copy our output file to the destination module
2489
+ # this is necessary in case different processes
2490
+ # have different GPU architectures / devices
2491
+ try:
2492
+ os.rename(output_path, binary_path)
2493
+ except (OSError, FileExistsError):
2494
+ # another process likely updated the module dir first
2495
+ pass
2238
2496
 
2239
- build_dir = None
2497
+ if not os.path.exists(meta_path) or self.options["strip_hash"]:
2498
+ # copy our output file to the destination module
2499
+ # this is necessary in case different processes
2500
+ # have different GPU architectures / devices
2501
+ try:
2502
+ os.rename(output_meta_path, meta_path)
2503
+ except (OSError, FileExistsError):
2504
+ # another process likely updated the module dir first
2505
+ pass
2240
2506
 
2241
- # we always want to build if binary doesn't exist yet
2242
- # and we want to rebuild if we are not caching kernels or if we are tracking array access
2243
- if (
2244
- not os.path.exists(binary_path)
2245
- or not warp.config.cache_kernels
2246
- or warp.config.verify_autograd_array_access
2247
- ):
2248
- builder_options = {
2249
- **self.options,
2250
- # Some of the tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
2251
- "output_arch": output_arch,
2252
- }
2253
- builder = ModuleBuilder(self, builder_options, hasher=self.hashers[active_block_dim])
2254
-
2255
- # create a temporary (process unique) dir for build outputs before moving to the binary dir
2256
- build_dir = os.path.join(
2257
- warp.config.kernel_cache_dir, f"{module_name}_{module_hash.hex()[:7]}_p{os.getpid()}"
2258
- )
2507
+ try:
2508
+ final_source_path = os.path.join(output_dir, os.path.basename(source_code_path))
2509
+ if not os.path.exists(final_source_path) or self.options["strip_hash"]:
2510
+ os.rename(source_code_path, final_source_path)
2511
+ except (OSError, FileExistsError):
2512
+ # another process likely updated the module dir first
2513
+ pass
2514
+ except Exception as e:
2515
+ # We don't need source_code_path to be copied successfully to proceed, so warn and keep running
2516
+ warp.utils.warn(f"Exception when renaming {source_code_path}: {e}")
2259
2517
 
2260
- # dir may exist from previous attempts / runs / archs
2261
- Path(build_dir).mkdir(parents=True, exist_ok=True)
2518
+ # clean up build_dir used for this process regardless
2519
+ shutil.rmtree(build_dir, ignore_errors=True)
2262
2520
 
2263
- module_load_timer.extra_msg = " (compiled)" # For wp.ScopedTimer informational purposes
2521
+ def load(
2522
+ self,
2523
+ device,
2524
+ block_dim: int | None = None,
2525
+ binary_path: os.PathLike | None = None,
2526
+ output_arch: int | None = None,
2527
+ meta_path: os.PathLike | None = None,
2528
+ ) -> ModuleExec | None:
2529
+ device = runtime.get_device(device)
2264
2530
 
2265
- # build CPU
2266
- if device.is_cpu:
2267
- # build
2268
- try:
2269
- source_code_path = os.path.join(build_dir, f"{module_name_short}.cpp")
2531
+ # update module options if launching with a new block dim
2532
+ if block_dim is not None:
2533
+ self.options["block_dim"] = block_dim
2270
2534
 
2271
- # write cpp sources
2272
- cpp_source = builder.codegen("cpu")
2535
+ active_block_dim = self.options["block_dim"]
2273
2536
 
2274
- with open(source_code_path, "w") as cpp_file:
2275
- cpp_file.write(cpp_source)
2537
+ # check if executable module is already loaded and not stale
2538
+ exec = self.execs.get((device.context, active_block_dim))
2539
+ if exec is not None:
2540
+ if self.options["strip_hash"] or (exec.module_hash == self.get_module_hash(active_block_dim)):
2541
+ return exec
2276
2542
 
2277
- output_path = os.path.join(build_dir, output_name)
2543
+ # quietly avoid repeated build attempts to reduce error spew
2544
+ if device.context in self.failed_builds:
2545
+ return None
2278
2546
 
2279
- # build object code
2280
- with warp.ScopedTimer("Compile x86", active=warp.config.verbose):
2281
- warp.build.build_cpu(
2282
- output_path,
2283
- source_code_path,
2284
- mode=self.options["mode"],
2285
- fast_math=self.options["fast_math"],
2286
- verify_fp=warp.config.verify_fp,
2287
- fuse_fp=self.options["fuse_fp"],
2288
- )
2547
+ module_hash = self.get_module_hash(active_block_dim)
2289
2548
 
2290
- except Exception as e:
2291
- self.failed_builds.add(None)
2292
- module_load_timer.extra_msg = " (error)"
2293
- raise (e)
2549
+ # use a unique module path using the module short hash
2550
+ module_name_short = self.get_module_identifier()
2294
2551
 
2295
- elif device.is_cuda:
2296
- # build
2297
- try:
2298
- source_code_path = os.path.join(build_dir, f"{module_name_short}.cu")
2299
-
2300
- # write cuda sources
2301
- cu_source = builder.codegen("cuda")
2302
-
2303
- with open(source_code_path, "w") as cu_file:
2304
- cu_file.write(cu_source)
2305
-
2306
- output_path = os.path.join(build_dir, output_name)
2307
-
2308
- # generate PTX or CUBIN
2309
- with warp.ScopedTimer("Compile CUDA", active=warp.config.verbose):
2310
- warp.build.build_cuda(
2311
- source_code_path,
2312
- output_arch,
2313
- output_path,
2314
- config=self.options["mode"],
2315
- verify_fp=warp.config.verify_fp,
2316
- fast_math=self.options["fast_math"],
2317
- fuse_fp=self.options["fuse_fp"],
2318
- lineinfo=self.options["lineinfo"],
2319
- compile_time_trace=self.options["compile_time_trace"],
2320
- ltoirs=builder.ltoirs.values(),
2321
- fatbins=builder.fatbins.values(),
2322
- )
2552
+ module_load_timer_name = (
2553
+ f"Module {self.name} {module_hash.hex()[:7]} load on device '{device}'"
2554
+ if self.options["strip_hash"] is False
2555
+ else f"Module {self.name} load on device '{device}'"
2556
+ )
2323
2557
 
2324
- except Exception as e:
2325
- self.failed_builds.add(device.context)
2326
- module_load_timer.extra_msg = " (error)"
2327
- raise (e)
2558
+ if warp.config.verbose:
2559
+ module_load_timer_name += f" (block_dim={active_block_dim})"
2328
2560
 
2329
- # ------------------------------------------------------------
2330
- # build meta data
2561
+ with warp.ScopedTimer(module_load_timer_name, active=not warp.config.quiet) as module_load_timer:
2562
+ # -----------------------------------------------------------
2563
+ # Determine binary path and build if necessary
2331
2564
 
2332
- meta = builder.build_meta()
2333
- meta_path = os.path.join(build_dir, f"{module_name_short}.meta")
2565
+ if binary_path:
2566
+ # We will never re-codegen or re-compile in this situation
2567
+ # The expected files must already exist
2334
2568
 
2335
- with open(meta_path, "w") as meta_file:
2336
- json.dump(meta, meta_file)
2569
+ if device.is_cuda and output_arch is None:
2570
+ raise ValueError("'output_arch' must be provided if a 'binary_path' is provided")
2337
2571
 
2338
- # -----------------------------------------------------------
2339
- # update cache
2572
+ if meta_path is None:
2573
+ raise ValueError("'meta_path' must be provided if a 'binary_path' is provided")
2340
2574
 
2341
- # try to move process outputs to cache
2342
- warp.build.safe_rename(build_dir, module_dir)
2575
+ if not os.path.exists(binary_path):
2576
+ module_load_timer.extra_msg = " (error)"
2577
+ raise FileNotFoundError(f"Binary file {binary_path} does not exist")
2578
+ else:
2579
+ module_load_timer.extra_msg = " (cached)"
2580
+ else:
2581
+ # we will build if binary doesn't exist yet
2582
+ # we will rebuild if we are not caching kernels or if we are tracking array access
2343
2583
 
2344
- if os.path.exists(module_dir):
2345
- if not os.path.exists(binary_path):
2346
- # copy our output file to the destination module
2347
- # this is necessary in case different processes
2348
- # have different GPU architectures / devices
2349
- try:
2350
- os.rename(output_path, binary_path)
2351
- except (OSError, FileExistsError):
2352
- # another process likely updated the module dir first
2353
- pass
2584
+ output_name = self.get_compile_output_name(device)
2585
+ output_arch = self.get_compile_arch(device)
2586
+
2587
+ module_dir = os.path.join(warp.config.kernel_cache_dir, module_name_short)
2588
+ meta_path = os.path.join(module_dir, self.get_meta_name())
2589
+ # final object binary path
2590
+ binary_path = os.path.join(module_dir, output_name)
2354
2591
 
2592
+ if (
2593
+ not os.path.exists(binary_path)
2594
+ or not warp.config.cache_kernels
2595
+ or warp.config.verify_autograd_array_access
2596
+ ):
2355
2597
  try:
2356
- final_source_path = os.path.join(module_dir, os.path.basename(source_code_path))
2357
- if not os.path.exists(final_source_path):
2358
- os.rename(source_code_path, final_source_path)
2359
- except (OSError, FileExistsError):
2360
- # another process likely updated the module dir first
2361
- pass
2598
+ self.compile(device, module_dir, output_name, output_arch)
2362
2599
  except Exception as e:
2363
- # We don't need source_code_path to be copied successfully to proceed, so warn and keep running
2364
- warp.utils.warn(f"Exception when renaming {source_code_path}: {e}")
2365
- else:
2366
- module_load_timer.extra_msg = " (cached)" # For wp.ScopedTimer informational purposes
2600
+ module_load_timer.extra_msg = " (error)"
2601
+ raise (e)
2602
+
2603
+ module_load_timer.extra_msg = " (compiled)"
2604
+ else:
2605
+ module_load_timer.extra_msg = " (cached)"
2367
2606
 
2368
2607
  # -----------------------------------------------------------
2369
2608
  # Load CPU or CUDA binary
2370
2609
 
2371
- meta_path = os.path.join(module_dir, f"{module_name_short}.meta")
2372
- with open(meta_path) as meta_file:
2373
- meta = json.load(meta_file)
2610
+ if os.path.exists(meta_path):
2611
+ with open(meta_path) as meta_file:
2612
+ meta = json.load(meta_file)
2613
+ else:
2614
+ raise FileNotFoundError(f"Module metadata file {meta_path} was not found in the cache")
2374
2615
 
2375
2616
  if device.is_cpu:
2376
2617
  # LLVM modules are identified using strings, so we need to ensure uniqueness
2377
- module_handle = f"{module_name}_{self.cpu_exec_id}"
2618
+ module_handle = f"wp_{self.name}_{self.cpu_exec_id}"
2378
2619
  self.cpu_exec_id += 1
2379
- runtime.llvm.load_obj(binary_path.encode("utf-8"), module_handle.encode("utf-8"))
2620
+ runtime.llvm.wp_load_obj(binary_path.encode("utf-8"), module_handle.encode("utf-8"))
2380
2621
  module_exec = ModuleExec(module_handle, module_hash, device, meta)
2381
2622
  self.execs[(None, active_block_dim)] = module_exec
2382
2623
 
@@ -2389,12 +2630,6 @@ class Module:
2389
2630
  module_load_timer.extra_msg = " (error)"
2390
2631
  raise Exception(f"Failed to load CUDA module '{self.name}'")
2391
2632
 
2392
- if build_dir:
2393
- import shutil
2394
-
2395
- # clean up build_dir used for this process regardless
2396
- shutil.rmtree(build_dir, ignore_errors=True)
2397
-
2398
2633
  return module_exec
2399
2634
 
2400
2635
  def unload(self):
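
The hunks above split code generation and compilation out of Module.load() into a new Module.compile() method and allow load() to consume a precompiled binary via binary_path, output_arch, and meta_path. A minimal ahead-of-time sketch built on those entry points; the module name, output directory, and device choice are hypothetical:

    import os

    import warp as wp

    import my_kernels  # hypothetical Python module containing @wp.kernel definitions

    wp.init()
    device = wp.get_device("cuda:0")
    module = wp.context.get_module(my_kernels.__name__)

    # Ahead-of-time: generate code and compile a binary into a chosen directory.
    arch = module.get_compile_arch(device)
    name = module.get_compile_output_name(device, arch)
    module.compile(device=device, output_dir="aot_out", output_name=name, output_arch=arch)

    # Later, possibly in another process: load the prebuilt binary directly, skipping codegen.
    module.load(
        device,
        binary_path=os.path.join("aot_out", name),
        output_arch=arch,
        meta_path=os.path.join("aot_out", module.get_meta_name()),
    )
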
@@ -2430,13 +2665,13 @@ class CpuDefaultAllocator:
2430
2665
  self.deleter = lambda ptr, size: self.free(ptr, size)
2431
2666
 
2432
2667
  def alloc(self, size_in_bytes):
2433
- ptr = runtime.core.alloc_host(size_in_bytes)
2668
+ ptr = runtime.core.wp_alloc_host(size_in_bytes)
2434
2669
  if not ptr:
2435
2670
  raise RuntimeError(f"Failed to allocate {size_in_bytes} bytes on device 'cpu'")
2436
2671
  return ptr
2437
2672
 
2438
2673
  def free(self, ptr, size_in_bytes):
2439
- runtime.core.free_host(ptr)
2674
+ runtime.core.wp_free_host(ptr)
2440
2675
 
2441
2676
 
2442
2677
  class CpuPinnedAllocator:
@@ -2445,13 +2680,13 @@ class CpuPinnedAllocator:
2445
2680
  self.deleter = lambda ptr, size: self.free(ptr, size)
2446
2681
 
2447
2682
  def alloc(self, size_in_bytes):
2448
- ptr = runtime.core.alloc_pinned(size_in_bytes)
2683
+ ptr = runtime.core.wp_alloc_pinned(size_in_bytes)
2449
2684
  if not ptr:
2450
2685
  raise RuntimeError(f"Failed to allocate {size_in_bytes} bytes on device '{self.device}'")
2451
2686
  return ptr
2452
2687
 
2453
2688
  def free(self, ptr, size_in_bytes):
2454
- runtime.core.free_pinned(ptr)
2689
+ runtime.core.wp_free_pinned(ptr)
2455
2690
 
2456
2691
 
2457
2692
  class CudaDefaultAllocator:
@@ -2461,7 +2696,7 @@ class CudaDefaultAllocator:
2461
2696
  self.deleter = lambda ptr, size: self.free(ptr, size)
2462
2697
 
2463
2698
  def alloc(self, size_in_bytes):
2464
- ptr = runtime.core.alloc_device_default(self.device.context, size_in_bytes)
2699
+ ptr = runtime.core.wp_alloc_device_default(self.device.context, size_in_bytes)
2465
2700
  # If the allocation fails, check if graph capture is active to raise an informative error.
2466
2701
  # We delay the capture check to avoid overhead.
2467
2702
  if not ptr:
@@ -2483,7 +2718,7 @@ class CudaDefaultAllocator:
2483
2718
  return ptr
2484
2719
 
2485
2720
  def free(self, ptr, size_in_bytes):
2486
- runtime.core.free_device_default(self.device.context, ptr)
2721
+ runtime.core.wp_free_device_default(self.device.context, ptr)
2487
2722
 
2488
2723
 
2489
2724
  class CudaMempoolAllocator:
@@ -2494,13 +2729,13 @@ class CudaMempoolAllocator:
2494
2729
  self.deleter = lambda ptr, size: self.free(ptr, size)
2495
2730
 
2496
2731
  def alloc(self, size_in_bytes):
2497
- ptr = runtime.core.alloc_device_async(self.device.context, size_in_bytes)
2732
+ ptr = runtime.core.wp_alloc_device_async(self.device.context, size_in_bytes)
2498
2733
  if not ptr:
2499
2734
  raise RuntimeError(f"Failed to allocate {size_in_bytes} bytes on device '{self.device}'")
2500
2735
  return ptr
2501
2736
 
2502
2737
  def free(self, ptr, size_in_bytes):
2503
- runtime.core.free_device_async(self.device.context, ptr)
2738
+ runtime.core.wp_free_device_async(self.device.context, ptr)
2504
2739
 
2505
2740
 
2506
2741
  class ContextGuard:
@@ -2509,15 +2744,15 @@ class ContextGuard:
2509
2744
 
2510
2745
  def __enter__(self):
2511
2746
  if self.device.is_cuda:
2512
- runtime.core.cuda_context_push_current(self.device.context)
2747
+ runtime.core.wp_cuda_context_push_current(self.device.context)
2513
2748
  elif is_cuda_driver_initialized():
2514
- self.saved_context = runtime.core.cuda_context_get_current()
2749
+ self.saved_context = runtime.core.wp_cuda_context_get_current()
2515
2750
 
2516
2751
  def __exit__(self, exc_type, exc_value, traceback):
2517
2752
  if self.device.is_cuda:
2518
- runtime.core.cuda_context_pop_current()
2753
+ runtime.core.wp_cuda_context_pop_current()
2519
2754
  elif is_cuda_driver_initialized():
2520
- runtime.core.cuda_context_set_current(self.saved_context)
2755
+ runtime.core.wp_cuda_context_set_current(self.saved_context)
2521
2756
 
2522
2757
 
2523
2758
  class Event:
@@ -2580,7 +2815,7 @@ class Event:
2580
2815
  raise ValueError("The combination of 'enable_timing=True' and 'interprocess=True' is not allowed.")
2581
2816
  flags |= Event.Flags.INTERPROCESS
2582
2817
 
2583
- self.cuda_event = runtime.core.cuda_event_create(device.context, flags)
2818
+ self.cuda_event = runtime.core.wp_cuda_event_create(device.context, flags)
2584
2819
  if not self.cuda_event:
2585
2820
  raise RuntimeError(f"Failed to create event on device {device}")
2586
2821
  self.owner = True
@@ -2607,7 +2842,9 @@ class Event:
2607
2842
  # Allocate a buffer for the data (64-element char array)
2608
2843
  ipc_handle_buffer = (ctypes.c_char * 64)()
2609
2844
 
2610
- warp.context.runtime.core.cuda_ipc_get_event_handle(self.device.context, self.cuda_event, ipc_handle_buffer)
2845
+ warp.context.runtime.core.wp_cuda_ipc_get_event_handle(
2846
+ self.device.context, self.cuda_event, ipc_handle_buffer
2847
+ )
2611
2848
 
2612
2849
  if ipc_handle_buffer.raw == bytes(64):
2613
2850
  warp.utils.warn("IPC event handle appears to be invalid. Was interprocess=True used?")
@@ -2624,7 +2861,7 @@ class Event:
2624
2861
  This property may not be accessed during a graph capture on any stream.
2625
2862
  """
2626
2863
 
2627
- result_code = runtime.core.cuda_event_query(self.cuda_event)
2864
+ result_code = runtime.core.wp_cuda_event_query(self.cuda_event)
2628
2865
 
2629
2866
  return result_code == 0
2630
2867
 
@@ -2632,7 +2869,7 @@ class Event:
2632
2869
  if not self.owner:
2633
2870
  return
2634
2871
 
2635
- runtime.core.cuda_event_destroy(self.cuda_event)
2872
+ runtime.core.wp_cuda_event_destroy(self.cuda_event)
2636
2873
 
2637
2874
 
2638
2875
  class Stream:
@@ -2682,12 +2919,12 @@ class Stream:
2682
2919
  # we pass cuda_stream through kwargs because cuda_stream=None is actually a valid value (CUDA default stream)
2683
2920
  if "cuda_stream" in kwargs:
2684
2921
  self.cuda_stream = kwargs["cuda_stream"]
2685
- device.runtime.core.cuda_stream_register(device.context, self.cuda_stream)
2922
+ device.runtime.core.wp_cuda_stream_register(device.context, self.cuda_stream)
2686
2923
  else:
2687
2924
  if not isinstance(priority, int):
2688
2925
  raise TypeError("Stream priority must be an integer.")
2689
2926
  clamped_priority = max(-1, min(priority, 0)) # Only support two priority levels
2690
- self.cuda_stream = device.runtime.core.cuda_stream_create(device.context, clamped_priority)
2927
+ self.cuda_stream = device.runtime.core.wp_cuda_stream_create(device.context, clamped_priority)
2691
2928
 
2692
2929
  if not self.cuda_stream:
2693
2930
  raise RuntimeError(f"Failed to create stream on device {device}")
@@ -2698,9 +2935,9 @@ class Stream:
2698
2935
  return
2699
2936
 
2700
2937
  if self.owner:
2701
- runtime.core.cuda_stream_destroy(self.device.context, self.cuda_stream)
2938
+ runtime.core.wp_cuda_stream_destroy(self.device.context, self.cuda_stream)
2702
2939
  else:
2703
- runtime.core.cuda_stream_unregister(self.device.context, self.cuda_stream)
2940
+ runtime.core.wp_cuda_stream_unregister(self.device.context, self.cuda_stream)
2704
2941
 
2705
2942
  @property
2706
2943
  def cached_event(self) -> Event:
@@ -2726,7 +2963,7 @@ class Stream:
2726
2963
  f"Event from device {event.device} cannot be recorded on stream from device {self.device}"
2727
2964
  )
2728
2965
 
2729
- runtime.core.cuda_event_record(event.cuda_event, self.cuda_stream, event.enable_timing)
2966
+ runtime.core.wp_cuda_event_record(event.cuda_event, self.cuda_stream, event.enable_timing)
2730
2967
 
2731
2968
  return event
2732
2969
 
@@ -2735,7 +2972,7 @@ class Stream:
2735
2972
 
2736
2973
  This function does not block the host thread.
2737
2974
  """
2738
- runtime.core.cuda_stream_wait_event(self.cuda_stream, event.cuda_event)
2975
+ runtime.core.wp_cuda_stream_wait_event(self.cuda_stream, event.cuda_event)
2739
2976
 
2740
2977
  def wait_stream(self, other_stream: Stream, event: Event | None = None):
2741
2978
  """Records an event on `other_stream` and makes this stream wait on it.
@@ -2758,7 +2995,7 @@ class Stream:
2758
2995
  if event is None:
2759
2996
  event = other_stream.cached_event
2760
2997
 
2761
- runtime.core.cuda_stream_wait_stream(self.cuda_stream, other_stream.cuda_stream, event.cuda_event)
2998
+ runtime.core.wp_cuda_stream_wait_stream(self.cuda_stream, other_stream.cuda_stream, event.cuda_event)
2762
2999
 
2763
3000
  @property
2764
3001
  def is_complete(self) -> bool:
@@ -2767,19 +3004,19 @@ class Stream:
2767
3004
  This property may not be accessed during a graph capture on any stream.
2768
3005
  """
2769
3006
 
2770
- result_code = runtime.core.cuda_stream_query(self.cuda_stream)
3007
+ result_code = runtime.core.wp_cuda_stream_query(self.cuda_stream)
2771
3008
 
2772
3009
  return result_code == 0
2773
3010
 
2774
3011
  @property
2775
3012
  def is_capturing(self) -> bool:
2776
3013
  """A boolean indicating whether a graph capture is currently ongoing on this stream."""
2777
- return bool(runtime.core.cuda_stream_is_capturing(self.cuda_stream))
3014
+ return bool(runtime.core.wp_cuda_stream_is_capturing(self.cuda_stream))
2778
3015
 
2779
3016
  @property
2780
3017
  def priority(self) -> int:
2781
3018
  """An integer representing the priority of the stream."""
2782
- return runtime.core.cuda_stream_get_priority(self.cuda_stream)
3019
+ return runtime.core.wp_cuda_stream_get_priority(self.cuda_stream)
2783
3020
 
2784
3021
 
2785
3022
  class Device:
@@ -2848,22 +3085,22 @@ class Device:
2848
3085
  self.pci_bus_id = None
2849
3086
 
2850
3087
  # TODO: add more device-specific dispatch functions
2851
- self.memset = runtime.core.memset_host
2852
- self.memtile = runtime.core.memtile_host
3088
+ self.memset = runtime.core.wp_memset_host
3089
+ self.memtile = runtime.core.wp_memtile_host
2853
3090
 
2854
3091
  self.default_allocator = CpuDefaultAllocator(self)
2855
3092
  self.pinned_allocator = CpuPinnedAllocator(self)
2856
3093
 
2857
- elif ordinal >= 0 and ordinal < runtime.core.cuda_device_get_count():
3094
+ elif ordinal >= 0 and ordinal < runtime.core.wp_cuda_device_get_count():
2858
3095
  # CUDA device
2859
- self.name = runtime.core.cuda_device_get_name(ordinal).decode()
2860
- self.arch = runtime.core.cuda_device_get_arch(ordinal)
2861
- self.sm_count = runtime.core.cuda_device_get_sm_count(ordinal)
2862
- self.is_uva = runtime.core.cuda_device_is_uva(ordinal) > 0
2863
- self.is_mempool_supported = runtime.core.cuda_device_is_mempool_supported(ordinal) > 0
3096
+ self.name = runtime.core.wp_cuda_device_get_name(ordinal).decode()
3097
+ self.arch = runtime.core.wp_cuda_device_get_arch(ordinal)
3098
+ self.sm_count = runtime.core.wp_cuda_device_get_sm_count(ordinal)
3099
+ self.is_uva = runtime.core.wp_cuda_device_is_uva(ordinal) > 0
3100
+ self.is_mempool_supported = runtime.core.wp_cuda_device_is_mempool_supported(ordinal) > 0
2864
3101
  if platform.system() == "Linux":
2865
3102
  # Use None when IPC support cannot be determined
2866
- ipc_support_api_query = runtime.core.cuda_device_is_ipc_supported(ordinal)
3103
+ ipc_support_api_query = runtime.core.wp_cuda_device_is_ipc_supported(ordinal)
2867
3104
  self.is_ipc_supported = bool(ipc_support_api_query) if ipc_support_api_query >= 0 else None
2868
3105
  else:
2869
3106
  self.is_ipc_supported = False
@@ -2875,13 +3112,13 @@ class Device:
                 self.is_mempool_enabled = False
 
             uuid_buffer = (ctypes.c_char * 16)()
-            runtime.core.cuda_device_get_uuid(ordinal, uuid_buffer)
+            runtime.core.wp_cuda_device_get_uuid(ordinal, uuid_buffer)
             uuid_byte_str = bytes(uuid_buffer).hex()
             self.uuid = f"GPU-{uuid_byte_str[0:8]}-{uuid_byte_str[8:12]}-{uuid_byte_str[12:16]}-{uuid_byte_str[16:20]}-{uuid_byte_str[20:]}"
 
-            pci_domain_id = runtime.core.cuda_device_get_pci_domain_id(ordinal)
-            pci_bus_id = runtime.core.cuda_device_get_pci_bus_id(ordinal)
-            pci_device_id = runtime.core.cuda_device_get_pci_device_id(ordinal)
+            pci_domain_id = runtime.core.wp_cuda_device_get_pci_domain_id(ordinal)
+            pci_bus_id = runtime.core.wp_cuda_device_get_pci_bus_id(ordinal)
+            pci_device_id = runtime.core.wp_cuda_device_get_pci_device_id(ordinal)
             # This is (mis)named to correspond to the naming of cudaDeviceGetPCIBusId
             self.pci_bus_id = f"{pci_domain_id:08X}:{pci_bus_id:02X}:{pci_device_id:02X}"
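The two f-strings above are plain string formatting over the raw values returned by the renamed UUID and PCI queries: the 16-byte UUID is hex-encoded and sliced into the usual GPU-xxxxxxxx-xxxx-... shape, and the PCI IDs are rendered as domain:bus:device. A standalone illustration with dummy values in place of the native calls:

    uuid_bytes = bytes(range(16))  # stand-in for the 16-byte UUID filled in by wp_cuda_device_get_uuid
    s = uuid_bytes.hex()
    uuid = f"GPU-{s[0:8]}-{s[8:12]}-{s[12:16]}-{s[16:20]}-{s[20:]}"
    pci_domain_id, pci_bus_id, pci_device_id = 0, 1, 0
    pci = f"{pci_domain_id:08X}:{pci_bus_id:02X}:{pci_device_id:02X}"
    print(uuid)  # GPU-00010203-0405-0607-0809-0a0b0c0d0e0f
    print(pci)   # 00000000:01:00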
 
@@ -2905,8 +3142,8 @@ class Device:
             self._init_streams()
 
             # TODO: add more device-specific dispatch functions
-            self.memset = lambda ptr, value, size: runtime.core.memset_device(self.context, ptr, value, size)
-            self.memtile = lambda ptr, src, srcsize, reps: runtime.core.memtile_device(
+            self.memset = lambda ptr, value, size: runtime.core.wp_memset_device(self.context, ptr, value, size)
+            self.memtile = lambda ptr, src, srcsize, reps: runtime.core.wp_memtile_device(
                 self.context, ptr, src, srcsize, reps
             )
 
@@ -2965,15 +3202,15 @@ class Device:
             return self._context
         elif self.is_primary:
             # acquire primary context on demand
-            prev_context = runtime.core.cuda_context_get_current()
-            self._context = self.runtime.core.cuda_device_get_primary_context(self.ordinal)
+            prev_context = runtime.core.wp_cuda_context_get_current()
+            self._context = self.runtime.core.wp_cuda_device_get_primary_context(self.ordinal)
             if self._context is None:
-                runtime.core.cuda_context_set_current(prev_context)
+                runtime.core.wp_cuda_context_set_current(prev_context)
                 raise RuntimeError(f"Failed to acquire primary context for device {self}")
             self.runtime.context_map[self._context] = self
             # initialize streams
             self._init_streams()
-            runtime.core.cuda_context_set_current(prev_context)
+            runtime.core.wp_cuda_context_set_current(prev_context)
             return self._context
 
     @property
@@ -3017,7 +3254,7 @@ class Device:
             if stream.device != self:
                 raise RuntimeError(f"Stream from device {stream.device} cannot be used on device {self}")
 
-            self.runtime.core.cuda_context_set_stream(self.context, stream.cuda_stream, int(sync))
+            self.runtime.core.wp_cuda_context_set_stream(self.context, stream.cuda_stream, int(sync))
             self._stream = stream
         else:
             raise RuntimeError(f"Device {self} is not a CUDA device")
@@ -3035,7 +3272,7 @@ class Device:
         """
        if self.is_cuda:
            total_mem = ctypes.c_size_t()
-           self.runtime.core.cuda_device_get_memory_info(self.ordinal, None, ctypes.byref(total_mem))
+           self.runtime.core.wp_cuda_device_get_memory_info(self.ordinal, None, ctypes.byref(total_mem))
            return total_mem.value
        else:
            # TODO: cpu
@@ -3049,7 +3286,7 @@ class Device:
         """
        if self.is_cuda:
            free_mem = ctypes.c_size_t()
-           self.runtime.core.cuda_device_get_memory_info(self.ordinal, ctypes.byref(free_mem), None)
+           self.runtime.core.wp_cuda_device_get_memory_info(self.ordinal, ctypes.byref(free_mem), None)
            return free_mem.value
        else:
            # TODO: cpu
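Both memory hunks funnel through the single wp_cuda_device_get_memory_info export, passing None for whichever of the free/total fields the caller does not need; the CPU branches remain TODO. A hedged usage sketch, assuming these are the documented total_memory / free_memory device properties:

    import warp as wp

    wp.init()
    d = wp.get_device("cuda:0")
    # property names are assumed from the Warp documentation
    print(f"free {d.free_memory / 1e9:.2f} GB of {d.total_memory / 1e9:.2f} GB")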
@@ -3076,7 +3313,7 @@ class Device:
 
     def make_current(self):
         if self.context is not None:
-            self.runtime.core.cuda_context_set_current(self.context)
+            self.runtime.core.wp_cuda_context_set_current(self.context)
 
     def can_access(self, other):
         # TODO: this function should be redesigned in terms of (device, resource).
@@ -3102,11 +3339,7 @@ class Graph:
         self.capture_id = capture_id
         self.module_execs: set[ModuleExec] = set()
         self.graph_exec: ctypes.c_void_p | None = None
-
         self.graph: ctypes.c_void_p | None = None
-        self.has_conditional = (
-            False  # Track if there are conditional nodes in the graph since they are not allowed in child graphs
-        )
 
     def __del__(self):
         if not hasattr(self, "graph") or not hasattr(self, "device") or not self.graph:
@@ -3114,9 +3347,9 @@ class Graph:
 
         # use CUDA context guard to avoid side effects during garbage collection
         with self.device.context_guard:
-            runtime.core.cuda_graph_destroy(self.device.context, self.graph)
+            runtime.core.wp_cuda_graph_destroy(self.device.context, self.graph)
             if hasattr(self, "graph_exec") and self.graph_exec is not None:
-                runtime.core.cuda_graph_exec_destroy(self.device.context, self.graph_exec)
+                runtime.core.wp_cuda_graph_exec_destroy(self.device.context, self.graph_exec)
 
     # retain executable CUDA modules used by this graph, which prevents them from being unloaded
     def retain_module_exec(self, module_exec: ModuleExec):
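Graph drops its Python-side has_conditional flag (later in this diff a wp_cuda_graph_check_conditional_body entry point is bound instead), and __del__ now releases both the captured graph and its executable instance through the wp_-prefixed destroy calls. For context, a minimal capture/replay sketch with the public API, assuming a CUDA device:

    import warp as wp

    wp.init()

    @wp.kernel
    def scale(a: wp.array(dtype=float), s: float):
        i = wp.tid()
        a[i] = a[i] * s

    a = wp.ones(1024, dtype=float, device="cuda:0")
    wp.capture_begin(device="cuda:0")
    wp.launch(scale, dim=a.shape, inputs=[a, 2.0], device="cuda:0")
    graph = wp.capture_end(device="cuda:0")  # returns the Graph wrapper shown above
    wp.capture_launch(graph)                 # replays the recorded work without re-capturing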
@@ -3128,6 +3361,14 @@ class Runtime:
         if sys.version_info < (3, 9):
             warp.utils.warn(f"Python 3.9 or newer is recommended for running Warp, detected {sys.version_info}")
 
+        if platform.system() == "Darwin" and platform.machine() == "x86_64":
+            warp.utils.warn(
+                "Support for Warp on Intel-based macOS is deprecated and will be removed in the near future. "
+                "Apple Silicon-based Macs will continue to be supported.",
+                DeprecationWarning,
+                stacklevel=3,
+            )
+
         bin_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bin")
 
         if os.name == "nt":
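New in this hunk: initializing Warp on Intel-based macOS now emits a DeprecationWarning (assuming warp.utils.warn forwards to the standard warnings module). Projects that promote warnings to errors can acknowledge just this message without silencing other deprecations:

    import warnings

    warnings.filterwarnings(
        "once",  # show it a single time instead of failing under -W error
        message=r"Support for Warp on Intel-based macOS is deprecated",
        category=DeprecationWarning,
    )

    import warp as wp

    wp.init()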
@@ -3150,7 +3391,7 @@ class Runtime:
         if os.path.exists(llvm_lib):
             self.llvm = self.load_dll(llvm_lib)
             # setup c-types for warp-clang.dll
-            self.llvm.lookup.restype = ctypes.c_uint64
+            self.llvm.wp_lookup.restype = ctypes.c_uint64
         else:
             self.llvm = None
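The large hunk that follows re-declares every native entry point under its wp_-prefixed export name; the mechanism is unchanged and is plain ctypes: fetch the symbol from the loaded shared library, then pin down its argtypes and restype. A generic, self-contained sketch of that pattern on a POSIX system, with libm's cos standing in for a wp_* export from warp.so:

    import ctypes
    import ctypes.util

    libm = ctypes.CDLL(ctypes.util.find_library("m"))
    libm.cos.argtypes = [ctypes.c_double]  # declare the C signature once...
    libm.cos.restype = ctypes.c_double
    print(libm.cos(0.0))                   # ...then calls are converted and type-checked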
@@ -3159,83 +3400,83 @@ class Runtime:
3159
3400
 
3160
3401
  # setup c-types for warp.dll
3161
3402
  try:
3162
- self.core.get_error_string.argtypes = []
3163
- self.core.get_error_string.restype = ctypes.c_char_p
3164
- self.core.set_error_output_enabled.argtypes = [ctypes.c_int]
3165
- self.core.set_error_output_enabled.restype = None
3166
- self.core.is_error_output_enabled.argtypes = []
3167
- self.core.is_error_output_enabled.restype = ctypes.c_int
3168
-
3169
- self.core.alloc_host.argtypes = [ctypes.c_size_t]
3170
- self.core.alloc_host.restype = ctypes.c_void_p
3171
- self.core.alloc_pinned.argtypes = [ctypes.c_size_t]
3172
- self.core.alloc_pinned.restype = ctypes.c_void_p
3173
- self.core.alloc_device.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
3174
- self.core.alloc_device.restype = ctypes.c_void_p
3175
- self.core.alloc_device_default.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
3176
- self.core.alloc_device_default.restype = ctypes.c_void_p
3177
- self.core.alloc_device_async.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
3178
- self.core.alloc_device_async.restype = ctypes.c_void_p
3179
-
3180
- self.core.float_to_half_bits.argtypes = [ctypes.c_float]
3181
- self.core.float_to_half_bits.restype = ctypes.c_uint16
3182
- self.core.half_bits_to_float.argtypes = [ctypes.c_uint16]
3183
- self.core.half_bits_to_float.restype = ctypes.c_float
3184
-
3185
- self.core.free_host.argtypes = [ctypes.c_void_p]
3186
- self.core.free_host.restype = None
3187
- self.core.free_pinned.argtypes = [ctypes.c_void_p]
3188
- self.core.free_pinned.restype = None
3189
- self.core.free_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3190
- self.core.free_device.restype = None
3191
- self.core.free_device_default.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3192
- self.core.free_device_default.restype = None
3193
- self.core.free_device_async.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3194
- self.core.free_device_async.restype = None
3195
-
3196
- self.core.memset_host.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
3197
- self.core.memset_host.restype = None
3198
- self.core.memset_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
3199
- self.core.memset_device.restype = None
3200
-
3201
- self.core.memtile_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_size_t]
3202
- self.core.memtile_host.restype = None
3203
- self.core.memtile_device.argtypes = [
3403
+ self.core.wp_get_error_string.argtypes = []
3404
+ self.core.wp_get_error_string.restype = ctypes.c_char_p
3405
+ self.core.wp_set_error_output_enabled.argtypes = [ctypes.c_int]
3406
+ self.core.wp_set_error_output_enabled.restype = None
3407
+ self.core.wp_is_error_output_enabled.argtypes = []
3408
+ self.core.wp_is_error_output_enabled.restype = ctypes.c_int
3409
+
3410
+ self.core.wp_alloc_host.argtypes = [ctypes.c_size_t]
3411
+ self.core.wp_alloc_host.restype = ctypes.c_void_p
3412
+ self.core.wp_alloc_pinned.argtypes = [ctypes.c_size_t]
3413
+ self.core.wp_alloc_pinned.restype = ctypes.c_void_p
3414
+ self.core.wp_alloc_device.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
3415
+ self.core.wp_alloc_device.restype = ctypes.c_void_p
3416
+ self.core.wp_alloc_device_default.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
3417
+ self.core.wp_alloc_device_default.restype = ctypes.c_void_p
3418
+ self.core.wp_alloc_device_async.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
3419
+ self.core.wp_alloc_device_async.restype = ctypes.c_void_p
3420
+
3421
+ self.core.wp_float_to_half_bits.argtypes = [ctypes.c_float]
3422
+ self.core.wp_float_to_half_bits.restype = ctypes.c_uint16
3423
+ self.core.wp_half_bits_to_float.argtypes = [ctypes.c_uint16]
3424
+ self.core.wp_half_bits_to_float.restype = ctypes.c_float
3425
+
3426
+ self.core.wp_free_host.argtypes = [ctypes.c_void_p]
3427
+ self.core.wp_free_host.restype = None
3428
+ self.core.wp_free_pinned.argtypes = [ctypes.c_void_p]
3429
+ self.core.wp_free_pinned.restype = None
3430
+ self.core.wp_free_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3431
+ self.core.wp_free_device.restype = None
3432
+ self.core.wp_free_device_default.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3433
+ self.core.wp_free_device_default.restype = None
3434
+ self.core.wp_free_device_async.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3435
+ self.core.wp_free_device_async.restype = None
3436
+
3437
+ self.core.wp_memset_host.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
3438
+ self.core.wp_memset_host.restype = None
3439
+ self.core.wp_memset_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
3440
+ self.core.wp_memset_device.restype = None
3441
+
3442
+ self.core.wp_memtile_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_size_t]
3443
+ self.core.wp_memtile_host.restype = None
3444
+ self.core.wp_memtile_device.argtypes = [
3204
3445
  ctypes.c_void_p,
3205
3446
  ctypes.c_void_p,
3206
3447
  ctypes.c_void_p,
3207
3448
  ctypes.c_size_t,
3208
3449
  ctypes.c_size_t,
3209
3450
  ]
3210
- self.core.memtile_device.restype = None
3451
+ self.core.wp_memtile_device.restype = None
3211
3452
 
3212
- self.core.memcpy_h2h.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t]
3213
- self.core.memcpy_h2h.restype = ctypes.c_bool
3214
- self.core.memcpy_h2d.argtypes = [
3453
+ self.core.wp_memcpy_h2h.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t]
3454
+ self.core.wp_memcpy_h2h.restype = ctypes.c_bool
3455
+ self.core.wp_memcpy_h2d.argtypes = [
3215
3456
  ctypes.c_void_p,
3216
3457
  ctypes.c_void_p,
3217
3458
  ctypes.c_void_p,
3218
3459
  ctypes.c_size_t,
3219
3460
  ctypes.c_void_p,
3220
3461
  ]
3221
- self.core.memcpy_h2d.restype = ctypes.c_bool
3222
- self.core.memcpy_d2h.argtypes = [
3462
+ self.core.wp_memcpy_h2d.restype = ctypes.c_bool
3463
+ self.core.wp_memcpy_d2h.argtypes = [
3223
3464
  ctypes.c_void_p,
3224
3465
  ctypes.c_void_p,
3225
3466
  ctypes.c_void_p,
3226
3467
  ctypes.c_size_t,
3227
3468
  ctypes.c_void_p,
3228
3469
  ]
3229
- self.core.memcpy_d2h.restype = ctypes.c_bool
3230
- self.core.memcpy_d2d.argtypes = [
3470
+ self.core.wp_memcpy_d2h.restype = ctypes.c_bool
3471
+ self.core.wp_memcpy_d2d.argtypes = [
3231
3472
  ctypes.c_void_p,
3232
3473
  ctypes.c_void_p,
3233
3474
  ctypes.c_void_p,
3234
3475
  ctypes.c_size_t,
3235
3476
  ctypes.c_void_p,
3236
3477
  ]
3237
- self.core.memcpy_d2d.restype = ctypes.c_bool
3238
- self.core.memcpy_p2p.argtypes = [
3478
+ self.core.wp_memcpy_d2d.restype = ctypes.c_bool
3479
+ self.core.wp_memcpy_p2p.argtypes = [
3239
3480
  ctypes.c_void_p,
3240
3481
  ctypes.c_void_p,
3241
3482
  ctypes.c_void_p,
@@ -3243,17 +3484,17 @@ class Runtime:
3243
3484
  ctypes.c_size_t,
3244
3485
  ctypes.c_void_p,
3245
3486
  ]
3246
- self.core.memcpy_p2p.restype = ctypes.c_bool
3487
+ self.core.wp_memcpy_p2p.restype = ctypes.c_bool
3247
3488
 
3248
- self.core.array_copy_host.argtypes = [
3489
+ self.core.wp_array_copy_host.argtypes = [
3249
3490
  ctypes.c_void_p,
3250
3491
  ctypes.c_void_p,
3251
3492
  ctypes.c_int,
3252
3493
  ctypes.c_int,
3253
3494
  ctypes.c_int,
3254
3495
  ]
3255
- self.core.array_copy_host.restype = ctypes.c_bool
3256
- self.core.array_copy_device.argtypes = [
3496
+ self.core.wp_array_copy_host.restype = ctypes.c_bool
3497
+ self.core.wp_array_copy_device.argtypes = [
3257
3498
  ctypes.c_void_p,
3258
3499
  ctypes.c_void_p,
3259
3500
  ctypes.c_void_p,
@@ -3261,41 +3502,41 @@ class Runtime:
3261
3502
  ctypes.c_int,
3262
3503
  ctypes.c_int,
3263
3504
  ]
3264
- self.core.array_copy_device.restype = ctypes.c_bool
3505
+ self.core.wp_array_copy_device.restype = ctypes.c_bool
3265
3506
 
3266
- self.core.array_fill_host.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_int]
3267
- self.core.array_fill_host.restype = None
3268
- self.core.array_fill_device.argtypes = [
3507
+ self.core.wp_array_fill_host.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_int]
3508
+ self.core.wp_array_fill_host.restype = None
3509
+ self.core.wp_array_fill_device.argtypes = [
3269
3510
  ctypes.c_void_p,
3270
3511
  ctypes.c_void_p,
3271
3512
  ctypes.c_int,
3272
3513
  ctypes.c_void_p,
3273
3514
  ctypes.c_int,
3274
3515
  ]
3275
- self.core.array_fill_device.restype = None
3516
+ self.core.wp_array_fill_device.restype = None
3276
3517
 
3277
- self.core.array_sum_double_host.argtypes = [
3518
+ self.core.wp_array_sum_double_host.argtypes = [
3278
3519
  ctypes.c_uint64,
3279
3520
  ctypes.c_uint64,
3280
3521
  ctypes.c_int,
3281
3522
  ctypes.c_int,
3282
3523
  ctypes.c_int,
3283
3524
  ]
3284
- self.core.array_sum_float_host.argtypes = [
3525
+ self.core.wp_array_sum_float_host.argtypes = [
3285
3526
  ctypes.c_uint64,
3286
3527
  ctypes.c_uint64,
3287
3528
  ctypes.c_int,
3288
3529
  ctypes.c_int,
3289
3530
  ctypes.c_int,
3290
3531
  ]
3291
- self.core.array_sum_double_device.argtypes = [
3532
+ self.core.wp_array_sum_double_device.argtypes = [
3292
3533
  ctypes.c_uint64,
3293
3534
  ctypes.c_uint64,
3294
3535
  ctypes.c_int,
3295
3536
  ctypes.c_int,
3296
3537
  ctypes.c_int,
3297
3538
  ]
3298
- self.core.array_sum_float_device.argtypes = [
3539
+ self.core.wp_array_sum_float_device.argtypes = [
3299
3540
  ctypes.c_uint64,
3300
3541
  ctypes.c_uint64,
3301
3542
  ctypes.c_int,
@@ -3303,7 +3544,7 @@ class Runtime:
3303
3544
  ctypes.c_int,
3304
3545
  ]
3305
3546
 
3306
- self.core.array_inner_double_host.argtypes = [
3547
+ self.core.wp_array_inner_double_host.argtypes = [
3307
3548
  ctypes.c_uint64,
3308
3549
  ctypes.c_uint64,
3309
3550
  ctypes.c_uint64,
@@ -3312,7 +3553,7 @@ class Runtime:
3312
3553
  ctypes.c_int,
3313
3554
  ctypes.c_int,
3314
3555
  ]
3315
- self.core.array_inner_float_host.argtypes = [
3556
+ self.core.wp_array_inner_float_host.argtypes = [
3316
3557
  ctypes.c_uint64,
3317
3558
  ctypes.c_uint64,
3318
3559
  ctypes.c_uint64,
@@ -3321,7 +3562,7 @@ class Runtime:
3321
3562
  ctypes.c_int,
3322
3563
  ctypes.c_int,
3323
3564
  ]
3324
- self.core.array_inner_double_device.argtypes = [
3565
+ self.core.wp_array_inner_double_device.argtypes = [
3325
3566
  ctypes.c_uint64,
3326
3567
  ctypes.c_uint64,
3327
3568
  ctypes.c_uint64,
@@ -3330,7 +3571,7 @@ class Runtime:
3330
3571
  ctypes.c_int,
3331
3572
  ctypes.c_int,
3332
3573
  ]
3333
- self.core.array_inner_float_device.argtypes = [
3574
+ self.core.wp_array_inner_float_device.argtypes = [
3334
3575
  ctypes.c_uint64,
3335
3576
  ctypes.c_uint64,
3336
3577
  ctypes.c_uint64,
@@ -3340,21 +3581,36 @@ class Runtime:
3340
3581
  ctypes.c_int,
3341
3582
  ]
3342
3583
 
3343
- self.core.array_scan_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
3344
- self.core.array_scan_float_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
3345
- self.core.array_scan_int_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
3346
- self.core.array_scan_float_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
3584
+ self.core.wp_array_scan_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
3585
+ self.core.wp_array_scan_float_host.argtypes = [
3586
+ ctypes.c_uint64,
3587
+ ctypes.c_uint64,
3588
+ ctypes.c_int,
3589
+ ctypes.c_bool,
3590
+ ]
3591
+ self.core.wp_array_scan_int_device.argtypes = [
3592
+ ctypes.c_uint64,
3593
+ ctypes.c_uint64,
3594
+ ctypes.c_int,
3595
+ ctypes.c_bool,
3596
+ ]
3597
+ self.core.wp_array_scan_float_device.argtypes = [
3598
+ ctypes.c_uint64,
3599
+ ctypes.c_uint64,
3600
+ ctypes.c_int,
3601
+ ctypes.c_bool,
3602
+ ]
3347
3603
 
3348
- self.core.radix_sort_pairs_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3349
- self.core.radix_sort_pairs_int_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3604
+ self.core.wp_radix_sort_pairs_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3605
+ self.core.wp_radix_sort_pairs_int_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3350
3606
 
3351
- self.core.radix_sort_pairs_float_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3352
- self.core.radix_sort_pairs_float_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3607
+ self.core.wp_radix_sort_pairs_float_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3608
+ self.core.wp_radix_sort_pairs_float_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3353
3609
 
3354
- self.core.radix_sort_pairs_int64_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3355
- self.core.radix_sort_pairs_int64_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3610
+ self.core.wp_radix_sort_pairs_int64_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3611
+ self.core.wp_radix_sort_pairs_int64_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3356
3612
 
3357
- self.core.segmented_sort_pairs_int_host.argtypes = [
3613
+ self.core.wp_segmented_sort_pairs_int_host.argtypes = [
3358
3614
  ctypes.c_uint64,
3359
3615
  ctypes.c_uint64,
3360
3616
  ctypes.c_int,
@@ -3362,7 +3618,7 @@ class Runtime:
3362
3618
  ctypes.c_uint64,
3363
3619
  ctypes.c_int,
3364
3620
  ]
3365
- self.core.segmented_sort_pairs_int_device.argtypes = [
3621
+ self.core.wp_segmented_sort_pairs_int_device.argtypes = [
3366
3622
  ctypes.c_uint64,
3367
3623
  ctypes.c_uint64,
3368
3624
  ctypes.c_int,
@@ -3371,7 +3627,7 @@ class Runtime:
3371
3627
  ctypes.c_int,
3372
3628
  ]
3373
3629
 
3374
- self.core.segmented_sort_pairs_float_host.argtypes = [
3630
+ self.core.wp_segmented_sort_pairs_float_host.argtypes = [
3375
3631
  ctypes.c_uint64,
3376
3632
  ctypes.c_uint64,
3377
3633
  ctypes.c_int,
@@ -3379,7 +3635,7 @@ class Runtime:
3379
3635
  ctypes.c_uint64,
3380
3636
  ctypes.c_int,
3381
3637
  ]
3382
- self.core.segmented_sort_pairs_float_device.argtypes = [
3638
+ self.core.wp_segmented_sort_pairs_float_device.argtypes = [
3383
3639
  ctypes.c_uint64,
3384
3640
  ctypes.c_uint64,
3385
3641
  ctypes.c_int,
@@ -3388,14 +3644,14 @@ class Runtime:
3388
3644
  ctypes.c_int,
3389
3645
  ]
3390
3646
 
3391
- self.core.runlength_encode_int_host.argtypes = [
3647
+ self.core.wp_runlength_encode_int_host.argtypes = [
3392
3648
  ctypes.c_uint64,
3393
3649
  ctypes.c_uint64,
3394
3650
  ctypes.c_uint64,
3395
3651
  ctypes.c_uint64,
3396
3652
  ctypes.c_int,
3397
3653
  ]
3398
- self.core.runlength_encode_int_device.argtypes = [
3654
+ self.core.wp_runlength_encode_int_device.argtypes = [
3399
3655
  ctypes.c_uint64,
3400
3656
  ctypes.c_uint64,
3401
3657
  ctypes.c_uint64,
@@ -3403,11 +3659,11 @@ class Runtime:
3403
3659
  ctypes.c_int,
3404
3660
  ]
3405
3661
 
3406
- self.core.bvh_create_host.restype = ctypes.c_uint64
3407
- self.core.bvh_create_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_int]
3662
+ self.core.wp_bvh_create_host.restype = ctypes.c_uint64
3663
+ self.core.wp_bvh_create_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_int]
3408
3664
 
3409
- self.core.bvh_create_device.restype = ctypes.c_uint64
3410
- self.core.bvh_create_device.argtypes = [
3665
+ self.core.wp_bvh_create_device.restype = ctypes.c_uint64
3666
+ self.core.wp_bvh_create_device.argtypes = [
3411
3667
  ctypes.c_void_p,
3412
3668
  ctypes.c_void_p,
3413
3669
  ctypes.c_void_p,
@@ -3415,14 +3671,14 @@ class Runtime:
3415
3671
  ctypes.c_int,
3416
3672
  ]
3417
3673
 
3418
- self.core.bvh_destroy_host.argtypes = [ctypes.c_uint64]
3419
- self.core.bvh_destroy_device.argtypes = [ctypes.c_uint64]
3674
+ self.core.wp_bvh_destroy_host.argtypes = [ctypes.c_uint64]
3675
+ self.core.wp_bvh_destroy_device.argtypes = [ctypes.c_uint64]
3420
3676
 
3421
- self.core.bvh_refit_host.argtypes = [ctypes.c_uint64]
3422
- self.core.bvh_refit_device.argtypes = [ctypes.c_uint64]
3677
+ self.core.wp_bvh_refit_host.argtypes = [ctypes.c_uint64]
3678
+ self.core.wp_bvh_refit_device.argtypes = [ctypes.c_uint64]
3423
3679
 
3424
- self.core.mesh_create_host.restype = ctypes.c_uint64
3425
- self.core.mesh_create_host.argtypes = [
3680
+ self.core.wp_mesh_create_host.restype = ctypes.c_uint64
3681
+ self.core.wp_mesh_create_host.argtypes = [
3426
3682
  warp.types.array_t,
3427
3683
  warp.types.array_t,
3428
3684
  warp.types.array_t,
@@ -3432,8 +3688,8 @@ class Runtime:
3432
3688
  ctypes.c_int,
3433
3689
  ]
3434
3690
 
3435
- self.core.mesh_create_device.restype = ctypes.c_uint64
3436
- self.core.mesh_create_device.argtypes = [
3691
+ self.core.wp_mesh_create_device.restype = ctypes.c_uint64
3692
+ self.core.wp_mesh_create_device.argtypes = [
3437
3693
  ctypes.c_void_p,
3438
3694
  warp.types.array_t,
3439
3695
  warp.types.array_t,
@@ -3444,61 +3700,61 @@ class Runtime:
3444
3700
  ctypes.c_int,
3445
3701
  ]
3446
3702
 
3447
- self.core.mesh_destroy_host.argtypes = [ctypes.c_uint64]
3448
- self.core.mesh_destroy_device.argtypes = [ctypes.c_uint64]
3703
+ self.core.wp_mesh_destroy_host.argtypes = [ctypes.c_uint64]
3704
+ self.core.wp_mesh_destroy_device.argtypes = [ctypes.c_uint64]
3449
3705
 
3450
- self.core.mesh_refit_host.argtypes = [ctypes.c_uint64]
3451
- self.core.mesh_refit_device.argtypes = [ctypes.c_uint64]
3706
+ self.core.wp_mesh_refit_host.argtypes = [ctypes.c_uint64]
3707
+ self.core.wp_mesh_refit_device.argtypes = [ctypes.c_uint64]
3452
3708
 
3453
- self.core.mesh_set_points_host.argtypes = [ctypes.c_uint64, warp.types.array_t]
3454
- self.core.mesh_set_points_device.argtypes = [ctypes.c_uint64, warp.types.array_t]
3709
+ self.core.wp_mesh_set_points_host.argtypes = [ctypes.c_uint64, warp.types.array_t]
3710
+ self.core.wp_mesh_set_points_device.argtypes = [ctypes.c_uint64, warp.types.array_t]
3455
3711
 
3456
- self.core.mesh_set_velocities_host.argtypes = [ctypes.c_uint64, warp.types.array_t]
3457
- self.core.mesh_set_velocities_device.argtypes = [ctypes.c_uint64, warp.types.array_t]
3712
+ self.core.wp_mesh_set_velocities_host.argtypes = [ctypes.c_uint64, warp.types.array_t]
3713
+ self.core.wp_mesh_set_velocities_device.argtypes = [ctypes.c_uint64, warp.types.array_t]
3458
3714
 
3459
- self.core.hash_grid_create_host.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
3460
- self.core.hash_grid_create_host.restype = ctypes.c_uint64
3461
- self.core.hash_grid_destroy_host.argtypes = [ctypes.c_uint64]
3462
- self.core.hash_grid_update_host.argtypes = [ctypes.c_uint64, ctypes.c_float, ctypes.c_void_p]
3463
- self.core.hash_grid_reserve_host.argtypes = [ctypes.c_uint64, ctypes.c_int]
3715
+ self.core.wp_hash_grid_create_host.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
3716
+ self.core.wp_hash_grid_create_host.restype = ctypes.c_uint64
3717
+ self.core.wp_hash_grid_destroy_host.argtypes = [ctypes.c_uint64]
3718
+ self.core.wp_hash_grid_update_host.argtypes = [ctypes.c_uint64, ctypes.c_float, ctypes.c_void_p]
3719
+ self.core.wp_hash_grid_reserve_host.argtypes = [ctypes.c_uint64, ctypes.c_int]
3464
3720
 
3465
- self.core.hash_grid_create_device.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_int, ctypes.c_int]
3466
- self.core.hash_grid_create_device.restype = ctypes.c_uint64
3467
- self.core.hash_grid_destroy_device.argtypes = [ctypes.c_uint64]
3468
- self.core.hash_grid_update_device.argtypes = [ctypes.c_uint64, ctypes.c_float, ctypes.c_void_p]
3469
- self.core.hash_grid_reserve_device.argtypes = [ctypes.c_uint64, ctypes.c_int]
3721
+ self.core.wp_hash_grid_create_device.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_int, ctypes.c_int]
3722
+ self.core.wp_hash_grid_create_device.restype = ctypes.c_uint64
3723
+ self.core.wp_hash_grid_destroy_device.argtypes = [ctypes.c_uint64]
3724
+ self.core.wp_hash_grid_update_device.argtypes = [ctypes.c_uint64, ctypes.c_float, ctypes.c_void_p]
3725
+ self.core.wp_hash_grid_reserve_device.argtypes = [ctypes.c_uint64, ctypes.c_int]
3470
3726
 
3471
- self.core.volume_create_host.argtypes = [ctypes.c_void_p, ctypes.c_uint64, ctypes.c_bool, ctypes.c_bool]
3472
- self.core.volume_create_host.restype = ctypes.c_uint64
3473
- self.core.volume_get_tiles_host.argtypes = [
3727
+ self.core.wp_volume_create_host.argtypes = [ctypes.c_void_p, ctypes.c_uint64, ctypes.c_bool, ctypes.c_bool]
3728
+ self.core.wp_volume_create_host.restype = ctypes.c_uint64
3729
+ self.core.wp_volume_get_tiles_host.argtypes = [
3474
3730
  ctypes.c_uint64,
3475
3731
  ctypes.c_void_p,
3476
3732
  ]
3477
- self.core.volume_get_voxels_host.argtypes = [
3733
+ self.core.wp_volume_get_voxels_host.argtypes = [
3478
3734
  ctypes.c_uint64,
3479
3735
  ctypes.c_void_p,
3480
3736
  ]
3481
- self.core.volume_destroy_host.argtypes = [ctypes.c_uint64]
3737
+ self.core.wp_volume_destroy_host.argtypes = [ctypes.c_uint64]
3482
3738
 
3483
- self.core.volume_create_device.argtypes = [
3739
+ self.core.wp_volume_create_device.argtypes = [
3484
3740
  ctypes.c_void_p,
3485
3741
  ctypes.c_void_p,
3486
3742
  ctypes.c_uint64,
3487
3743
  ctypes.c_bool,
3488
3744
  ctypes.c_bool,
3489
3745
  ]
3490
- self.core.volume_create_device.restype = ctypes.c_uint64
3491
- self.core.volume_get_tiles_device.argtypes = [
3746
+ self.core.wp_volume_create_device.restype = ctypes.c_uint64
3747
+ self.core.wp_volume_get_tiles_device.argtypes = [
3492
3748
  ctypes.c_uint64,
3493
3749
  ctypes.c_void_p,
3494
3750
  ]
3495
- self.core.volume_get_voxels_device.argtypes = [
3751
+ self.core.wp_volume_get_voxels_device.argtypes = [
3496
3752
  ctypes.c_uint64,
3497
3753
  ctypes.c_void_p,
3498
3754
  ]
3499
- self.core.volume_destroy_device.argtypes = [ctypes.c_uint64]
3755
+ self.core.wp_volume_destroy_device.argtypes = [ctypes.c_uint64]
3500
3756
 
3501
- self.core.volume_from_tiles_device.argtypes = [
3757
+ self.core.wp_volume_from_tiles_device.argtypes = [
3502
3758
  ctypes.c_void_p,
3503
3759
  ctypes.c_void_p,
3504
3760
  ctypes.c_int,
@@ -3509,8 +3765,8 @@ class Runtime:
3509
3765
  ctypes.c_uint32,
3510
3766
  ctypes.c_char_p,
3511
3767
  ]
3512
- self.core.volume_from_tiles_device.restype = ctypes.c_uint64
3513
- self.core.volume_index_from_tiles_device.argtypes = [
3768
+ self.core.wp_volume_from_tiles_device.restype = ctypes.c_uint64
3769
+ self.core.wp_volume_index_from_tiles_device.argtypes = [
3514
3770
  ctypes.c_void_p,
3515
3771
  ctypes.c_void_p,
3516
3772
  ctypes.c_int,
@@ -3518,8 +3774,8 @@ class Runtime:
3518
3774
  ctypes.c_float * 3,
3519
3775
  ctypes.c_bool,
3520
3776
  ]
3521
- self.core.volume_index_from_tiles_device.restype = ctypes.c_uint64
3522
- self.core.volume_from_active_voxels_device.argtypes = [
3777
+ self.core.wp_volume_index_from_tiles_device.restype = ctypes.c_uint64
3778
+ self.core.wp_volume_from_active_voxels_device.argtypes = [
3523
3779
  ctypes.c_void_p,
3524
3780
  ctypes.c_void_p,
3525
3781
  ctypes.c_int,
@@ -3527,25 +3783,25 @@ class Runtime:
3527
3783
  ctypes.c_float * 3,
3528
3784
  ctypes.c_bool,
3529
3785
  ]
3530
- self.core.volume_from_active_voxels_device.restype = ctypes.c_uint64
3786
+ self.core.wp_volume_from_active_voxels_device.restype = ctypes.c_uint64
3531
3787
 
3532
- self.core.volume_get_buffer_info.argtypes = [
3788
+ self.core.wp_volume_get_buffer_info.argtypes = [
3533
3789
  ctypes.c_uint64,
3534
3790
  ctypes.POINTER(ctypes.c_void_p),
3535
3791
  ctypes.POINTER(ctypes.c_uint64),
3536
3792
  ]
3537
- self.core.volume_get_voxel_size.argtypes = [
3793
+ self.core.wp_volume_get_voxel_size.argtypes = [
3538
3794
  ctypes.c_uint64,
3539
3795
  ctypes.POINTER(ctypes.c_float),
3540
3796
  ctypes.POINTER(ctypes.c_float),
3541
3797
  ctypes.POINTER(ctypes.c_float),
3542
3798
  ]
3543
- self.core.volume_get_tile_and_voxel_count.argtypes = [
3799
+ self.core.wp_volume_get_tile_and_voxel_count.argtypes = [
3544
3800
  ctypes.c_uint64,
3545
3801
  ctypes.POINTER(ctypes.c_uint32),
3546
3802
  ctypes.POINTER(ctypes.c_uint64),
3547
3803
  ]
3548
- self.core.volume_get_grid_info.argtypes = [
3804
+ self.core.wp_volume_get_grid_info.argtypes = [
3549
3805
  ctypes.c_uint64,
3550
3806
  ctypes.POINTER(ctypes.c_uint64),
3551
3807
  ctypes.POINTER(ctypes.c_uint32),
@@ -3554,12 +3810,12 @@ class Runtime:
3554
3810
  ctypes.c_float * 9,
3555
3811
  ctypes.c_char * 16,
3556
3812
  ]
3557
- self.core.volume_get_grid_info.restype = ctypes.c_char_p
3558
- self.core.volume_get_blind_data_count.argtypes = [
3813
+ self.core.wp_volume_get_grid_info.restype = ctypes.c_char_p
3814
+ self.core.wp_volume_get_blind_data_count.argtypes = [
3559
3815
  ctypes.c_uint64,
3560
3816
  ]
3561
- self.core.volume_get_blind_data_count.restype = ctypes.c_uint64
3562
- self.core.volume_get_blind_data_info.argtypes = [
3817
+ self.core.wp_volume_get_blind_data_count.restype = ctypes.c_uint64
3818
+ self.core.wp_volume_get_blind_data_info.argtypes = [
3563
3819
  ctypes.c_uint64,
3564
3820
  ctypes.c_uint32,
3565
3821
  ctypes.POINTER(ctypes.c_void_p),
@@ -3567,7 +3823,7 @@ class Runtime:
3567
3823
  ctypes.POINTER(ctypes.c_uint32),
3568
3824
  ctypes.c_char * 16,
3569
3825
  ]
3570
- self.core.volume_get_blind_data_info.restype = ctypes.c_char_p
3826
+ self.core.wp_volume_get_blind_data_info.restype = ctypes.c_char_p
3571
3827
 
3572
3828
  bsr_matrix_from_triplets_argtypes = [
3573
3829
  ctypes.c_int, # block_size
@@ -3589,8 +3845,8 @@ class Runtime:
3589
3845
  ctypes.c_void_p, # bsr_nnz_event
3590
3846
  ]
3591
3847
 
3592
- self.core.bsr_matrix_from_triplets_host.argtypes = bsr_matrix_from_triplets_argtypes
3593
- self.core.bsr_matrix_from_triplets_device.argtypes = bsr_matrix_from_triplets_argtypes
3848
+ self.core.wp_bsr_matrix_from_triplets_host.argtypes = bsr_matrix_from_triplets_argtypes
3849
+ self.core.wp_bsr_matrix_from_triplets_device.argtypes = bsr_matrix_from_triplets_argtypes
3594
3850
 
3595
3851
  bsr_transpose_argtypes = [
3596
3852
  ctypes.c_int, # row_count
@@ -3602,228 +3858,232 @@ class Runtime:
3602
3858
  ctypes.POINTER(ctypes.c_int), # transposed_bsr_columns
3603
3859
  ctypes.POINTER(ctypes.c_int), # src to dest block map
3604
3860
  ]
3605
- self.core.bsr_transpose_host.argtypes = bsr_transpose_argtypes
3606
- self.core.bsr_transpose_device.argtypes = bsr_transpose_argtypes
3607
-
3608
- self.core.is_cuda_enabled.argtypes = None
3609
- self.core.is_cuda_enabled.restype = ctypes.c_int
3610
- self.core.is_cuda_compatibility_enabled.argtypes = None
3611
- self.core.is_cuda_compatibility_enabled.restype = ctypes.c_int
3612
- self.core.is_mathdx_enabled.argtypes = None
3613
- self.core.is_mathdx_enabled.restype = ctypes.c_int
3614
-
3615
- self.core.cuda_driver_version.argtypes = None
3616
- self.core.cuda_driver_version.restype = ctypes.c_int
3617
- self.core.cuda_toolkit_version.argtypes = None
3618
- self.core.cuda_toolkit_version.restype = ctypes.c_int
3619
- self.core.cuda_driver_is_initialized.argtypes = None
3620
- self.core.cuda_driver_is_initialized.restype = ctypes.c_bool
3621
-
3622
- self.core.nvrtc_supported_arch_count.argtypes = None
3623
- self.core.nvrtc_supported_arch_count.restype = ctypes.c_int
3624
- self.core.nvrtc_supported_archs.argtypes = [ctypes.POINTER(ctypes.c_int)]
3625
- self.core.nvrtc_supported_archs.restype = None
3626
-
3627
- self.core.cuda_device_get_count.argtypes = None
3628
- self.core.cuda_device_get_count.restype = ctypes.c_int
3629
- self.core.cuda_device_get_primary_context.argtypes = [ctypes.c_int]
3630
- self.core.cuda_device_get_primary_context.restype = ctypes.c_void_p
3631
- self.core.cuda_device_get_name.argtypes = [ctypes.c_int]
3632
- self.core.cuda_device_get_name.restype = ctypes.c_char_p
3633
- self.core.cuda_device_get_arch.argtypes = [ctypes.c_int]
3634
- self.core.cuda_device_get_arch.restype = ctypes.c_int
3635
- self.core.cuda_device_get_sm_count.argtypes = [ctypes.c_int]
3636
- self.core.cuda_device_get_sm_count.restype = ctypes.c_int
3637
- self.core.cuda_device_is_uva.argtypes = [ctypes.c_int]
3638
- self.core.cuda_device_is_uva.restype = ctypes.c_int
3639
- self.core.cuda_device_is_mempool_supported.argtypes = [ctypes.c_int]
3640
- self.core.cuda_device_is_mempool_supported.restype = ctypes.c_int
3641
- self.core.cuda_device_is_ipc_supported.argtypes = [ctypes.c_int]
3642
- self.core.cuda_device_is_ipc_supported.restype = ctypes.c_int
3643
- self.core.cuda_device_set_mempool_release_threshold.argtypes = [ctypes.c_int, ctypes.c_uint64]
3644
- self.core.cuda_device_set_mempool_release_threshold.restype = ctypes.c_int
3645
- self.core.cuda_device_get_mempool_release_threshold.argtypes = [ctypes.c_int]
3646
- self.core.cuda_device_get_mempool_release_threshold.restype = ctypes.c_uint64
3647
- self.core.cuda_device_get_mempool_used_mem_current.argtypes = [ctypes.c_int]
3648
- self.core.cuda_device_get_mempool_used_mem_current.restype = ctypes.c_uint64
3649
- self.core.cuda_device_get_mempool_used_mem_high.argtypes = [ctypes.c_int]
3650
- self.core.cuda_device_get_mempool_used_mem_high.restype = ctypes.c_uint64
3651
- self.core.cuda_device_get_memory_info.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p]
3652
- self.core.cuda_device_get_memory_info.restype = None
3653
- self.core.cuda_device_get_uuid.argtypes = [ctypes.c_int, ctypes.c_char * 16]
3654
- self.core.cuda_device_get_uuid.restype = None
3655
- self.core.cuda_device_get_pci_domain_id.argtypes = [ctypes.c_int]
3656
- self.core.cuda_device_get_pci_domain_id.restype = ctypes.c_int
3657
- self.core.cuda_device_get_pci_bus_id.argtypes = [ctypes.c_int]
3658
- self.core.cuda_device_get_pci_bus_id.restype = ctypes.c_int
3659
- self.core.cuda_device_get_pci_device_id.argtypes = [ctypes.c_int]
3660
- self.core.cuda_device_get_pci_device_id.restype = ctypes.c_int
3661
-
3662
- self.core.cuda_context_get_current.argtypes = None
3663
- self.core.cuda_context_get_current.restype = ctypes.c_void_p
3664
- self.core.cuda_context_set_current.argtypes = [ctypes.c_void_p]
3665
- self.core.cuda_context_set_current.restype = None
3666
- self.core.cuda_context_push_current.argtypes = [ctypes.c_void_p]
3667
- self.core.cuda_context_push_current.restype = None
3668
- self.core.cuda_context_pop_current.argtypes = None
3669
- self.core.cuda_context_pop_current.restype = None
3670
- self.core.cuda_context_create.argtypes = [ctypes.c_int]
3671
- self.core.cuda_context_create.restype = ctypes.c_void_p
3672
- self.core.cuda_context_destroy.argtypes = [ctypes.c_void_p]
3673
- self.core.cuda_context_destroy.restype = None
3674
- self.core.cuda_context_synchronize.argtypes = [ctypes.c_void_p]
3675
- self.core.cuda_context_synchronize.restype = None
3676
- self.core.cuda_context_check.argtypes = [ctypes.c_void_p]
3677
- self.core.cuda_context_check.restype = ctypes.c_uint64
3678
-
3679
- self.core.cuda_context_get_device_ordinal.argtypes = [ctypes.c_void_p]
3680
- self.core.cuda_context_get_device_ordinal.restype = ctypes.c_int
3681
- self.core.cuda_context_is_primary.argtypes = [ctypes.c_void_p]
3682
- self.core.cuda_context_is_primary.restype = ctypes.c_int
3683
- self.core.cuda_context_get_stream.argtypes = [ctypes.c_void_p]
3684
- self.core.cuda_context_get_stream.restype = ctypes.c_void_p
3685
- self.core.cuda_context_set_stream.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
3686
- self.core.cuda_context_set_stream.restype = None
3861
+ self.core.wp_bsr_transpose_host.argtypes = bsr_transpose_argtypes
3862
+ self.core.wp_bsr_transpose_device.argtypes = bsr_transpose_argtypes
3863
+
3864
+ self.core.wp_is_cuda_enabled.argtypes = None
3865
+ self.core.wp_is_cuda_enabled.restype = ctypes.c_int
3866
+ self.core.wp_is_cuda_compatibility_enabled.argtypes = None
3867
+ self.core.wp_is_cuda_compatibility_enabled.restype = ctypes.c_int
3868
+ self.core.wp_is_mathdx_enabled.argtypes = None
3869
+ self.core.wp_is_mathdx_enabled.restype = ctypes.c_int
3870
+
3871
+ self.core.wp_cuda_driver_version.argtypes = None
3872
+ self.core.wp_cuda_driver_version.restype = ctypes.c_int
3873
+ self.core.wp_cuda_toolkit_version.argtypes = None
3874
+ self.core.wp_cuda_toolkit_version.restype = ctypes.c_int
3875
+ self.core.wp_cuda_driver_is_initialized.argtypes = None
3876
+ self.core.wp_cuda_driver_is_initialized.restype = ctypes.c_bool
3877
+
3878
+ self.core.wp_nvrtc_supported_arch_count.argtypes = None
3879
+ self.core.wp_nvrtc_supported_arch_count.restype = ctypes.c_int
3880
+ self.core.wp_nvrtc_supported_archs.argtypes = [ctypes.POINTER(ctypes.c_int)]
3881
+ self.core.wp_nvrtc_supported_archs.restype = None
3882
+
3883
+ self.core.wp_cuda_device_get_count.argtypes = None
3884
+ self.core.wp_cuda_device_get_count.restype = ctypes.c_int
3885
+ self.core.wp_cuda_device_get_primary_context.argtypes = [ctypes.c_int]
3886
+ self.core.wp_cuda_device_get_primary_context.restype = ctypes.c_void_p
3887
+ self.core.wp_cuda_device_get_name.argtypes = [ctypes.c_int]
3888
+ self.core.wp_cuda_device_get_name.restype = ctypes.c_char_p
3889
+ self.core.wp_cuda_device_get_arch.argtypes = [ctypes.c_int]
3890
+ self.core.wp_cuda_device_get_arch.restype = ctypes.c_int
3891
+ self.core.wp_cuda_device_get_sm_count.argtypes = [ctypes.c_int]
3892
+ self.core.wp_cuda_device_get_sm_count.restype = ctypes.c_int
3893
+ self.core.wp_cuda_device_is_uva.argtypes = [ctypes.c_int]
3894
+ self.core.wp_cuda_device_is_uva.restype = ctypes.c_int
3895
+ self.core.wp_cuda_device_is_mempool_supported.argtypes = [ctypes.c_int]
3896
+ self.core.wp_cuda_device_is_mempool_supported.restype = ctypes.c_int
3897
+ self.core.wp_cuda_device_is_ipc_supported.argtypes = [ctypes.c_int]
3898
+ self.core.wp_cuda_device_is_ipc_supported.restype = ctypes.c_int
3899
+ self.core.wp_cuda_device_set_mempool_release_threshold.argtypes = [ctypes.c_int, ctypes.c_uint64]
3900
+ self.core.wp_cuda_device_set_mempool_release_threshold.restype = ctypes.c_int
3901
+ self.core.wp_cuda_device_get_mempool_release_threshold.argtypes = [ctypes.c_int]
3902
+ self.core.wp_cuda_device_get_mempool_release_threshold.restype = ctypes.c_uint64
3903
+ self.core.wp_cuda_device_get_mempool_used_mem_current.argtypes = [ctypes.c_int]
3904
+ self.core.wp_cuda_device_get_mempool_used_mem_current.restype = ctypes.c_uint64
3905
+ self.core.wp_cuda_device_get_mempool_used_mem_high.argtypes = [ctypes.c_int]
3906
+ self.core.wp_cuda_device_get_mempool_used_mem_high.restype = ctypes.c_uint64
3907
+ self.core.wp_cuda_device_get_memory_info.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p]
3908
+ self.core.wp_cuda_device_get_memory_info.restype = None
3909
+ self.core.wp_cuda_device_get_uuid.argtypes = [ctypes.c_int, ctypes.c_char * 16]
3910
+ self.core.wp_cuda_device_get_uuid.restype = None
3911
+ self.core.wp_cuda_device_get_pci_domain_id.argtypes = [ctypes.c_int]
3912
+ self.core.wp_cuda_device_get_pci_domain_id.restype = ctypes.c_int
3913
+ self.core.wp_cuda_device_get_pci_bus_id.argtypes = [ctypes.c_int]
3914
+ self.core.wp_cuda_device_get_pci_bus_id.restype = ctypes.c_int
3915
+ self.core.wp_cuda_device_get_pci_device_id.argtypes = [ctypes.c_int]
3916
+ self.core.wp_cuda_device_get_pci_device_id.restype = ctypes.c_int
3917
+
3918
+ self.core.wp_cuda_context_get_current.argtypes = None
3919
+ self.core.wp_cuda_context_get_current.restype = ctypes.c_void_p
3920
+ self.core.wp_cuda_context_set_current.argtypes = [ctypes.c_void_p]
3921
+ self.core.wp_cuda_context_set_current.restype = None
3922
+ self.core.wp_cuda_context_push_current.argtypes = [ctypes.c_void_p]
3923
+ self.core.wp_cuda_context_push_current.restype = None
3924
+ self.core.wp_cuda_context_pop_current.argtypes = None
3925
+ self.core.wp_cuda_context_pop_current.restype = None
3926
+ self.core.wp_cuda_context_create.argtypes = [ctypes.c_int]
3927
+ self.core.wp_cuda_context_create.restype = ctypes.c_void_p
3928
+ self.core.wp_cuda_context_destroy.argtypes = [ctypes.c_void_p]
3929
+ self.core.wp_cuda_context_destroy.restype = None
3930
+ self.core.wp_cuda_context_synchronize.argtypes = [ctypes.c_void_p]
3931
+ self.core.wp_cuda_context_synchronize.restype = None
3932
+ self.core.wp_cuda_context_check.argtypes = [ctypes.c_void_p]
3933
+ self.core.wp_cuda_context_check.restype = ctypes.c_uint64
3934
+
3935
+ self.core.wp_cuda_context_get_device_ordinal.argtypes = [ctypes.c_void_p]
3936
+ self.core.wp_cuda_context_get_device_ordinal.restype = ctypes.c_int
3937
+ self.core.wp_cuda_context_is_primary.argtypes = [ctypes.c_void_p]
3938
+ self.core.wp_cuda_context_is_primary.restype = ctypes.c_int
3939
+ self.core.wp_cuda_context_get_stream.argtypes = [ctypes.c_void_p]
3940
+ self.core.wp_cuda_context_get_stream.restype = ctypes.c_void_p
3941
+ self.core.wp_cuda_context_set_stream.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
3942
+ self.core.wp_cuda_context_set_stream.restype = None
3687
3943
 
3688
3944
  # peer access
3689
- self.core.cuda_is_peer_access_supported.argtypes = [ctypes.c_int, ctypes.c_int]
3690
- self.core.cuda_is_peer_access_supported.restype = ctypes.c_int
3691
- self.core.cuda_is_peer_access_enabled.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3692
- self.core.cuda_is_peer_access_enabled.restype = ctypes.c_int
3693
- self.core.cuda_set_peer_access_enabled.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
3694
- self.core.cuda_set_peer_access_enabled.restype = ctypes.c_int
3695
- self.core.cuda_is_mempool_access_enabled.argtypes = [ctypes.c_int, ctypes.c_int]
3696
- self.core.cuda_is_mempool_access_enabled.restype = ctypes.c_int
3697
- self.core.cuda_set_mempool_access_enabled.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
3698
- self.core.cuda_set_mempool_access_enabled.restype = ctypes.c_int
3945
+ self.core.wp_cuda_is_peer_access_supported.argtypes = [ctypes.c_int, ctypes.c_int]
3946
+ self.core.wp_cuda_is_peer_access_supported.restype = ctypes.c_int
3947
+ self.core.wp_cuda_is_peer_access_enabled.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3948
+ self.core.wp_cuda_is_peer_access_enabled.restype = ctypes.c_int
3949
+ self.core.wp_cuda_set_peer_access_enabled.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
3950
+ self.core.wp_cuda_set_peer_access_enabled.restype = ctypes.c_int
3951
+ self.core.wp_cuda_is_mempool_access_enabled.argtypes = [ctypes.c_int, ctypes.c_int]
3952
+ self.core.wp_cuda_is_mempool_access_enabled.restype = ctypes.c_int
3953
+ self.core.wp_cuda_set_mempool_access_enabled.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
3954
+ self.core.wp_cuda_set_mempool_access_enabled.restype = ctypes.c_int
3699
3955
 
3700
3956
  # inter-process communication
3701
- self.core.cuda_ipc_get_mem_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
3702
- self.core.cuda_ipc_get_mem_handle.restype = None
3703
- self.core.cuda_ipc_open_mem_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
3704
- self.core.cuda_ipc_open_mem_handle.restype = ctypes.c_void_p
3705
- self.core.cuda_ipc_close_mem_handle.argtypes = [ctypes.c_void_p]
3706
- self.core.cuda_ipc_close_mem_handle.restype = None
3707
- self.core.cuda_ipc_get_event_handle.argtypes = [
3957
+ self.core.wp_cuda_ipc_get_mem_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
3958
+ self.core.wp_cuda_ipc_get_mem_handle.restype = None
3959
+ self.core.wp_cuda_ipc_open_mem_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
3960
+ self.core.wp_cuda_ipc_open_mem_handle.restype = ctypes.c_void_p
3961
+ self.core.wp_cuda_ipc_close_mem_handle.argtypes = [ctypes.c_void_p]
3962
+ self.core.wp_cuda_ipc_close_mem_handle.restype = None
3963
+ self.core.wp_cuda_ipc_get_event_handle.argtypes = [
3708
3964
  ctypes.c_void_p,
3709
3965
  ctypes.c_void_p,
3710
3966
  ctypes.POINTER(ctypes.c_char),
3711
3967
  ]
3712
- self.core.cuda_ipc_get_event_handle.restype = None
3713
- self.core.cuda_ipc_open_event_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
3714
- self.core.cuda_ipc_open_event_handle.restype = ctypes.c_void_p
3715
-
3716
- self.core.cuda_stream_create.argtypes = [ctypes.c_void_p, ctypes.c_int]
3717
- self.core.cuda_stream_create.restype = ctypes.c_void_p
3718
- self.core.cuda_stream_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3719
- self.core.cuda_stream_destroy.restype = None
3720
- self.core.cuda_stream_query.argtypes = [ctypes.c_void_p]
3721
- self.core.cuda_stream_query.restype = ctypes.c_int
3722
- self.core.cuda_stream_register.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3723
- self.core.cuda_stream_register.restype = None
3724
- self.core.cuda_stream_unregister.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3725
- self.core.cuda_stream_unregister.restype = None
3726
- self.core.cuda_stream_synchronize.argtypes = [ctypes.c_void_p]
3727
- self.core.cuda_stream_synchronize.restype = None
3728
- self.core.cuda_stream_wait_event.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3729
- self.core.cuda_stream_wait_event.restype = None
3730
- self.core.cuda_stream_wait_stream.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
3731
- self.core.cuda_stream_wait_stream.restype = None
3732
- self.core.cuda_stream_is_capturing.argtypes = [ctypes.c_void_p]
3733
- self.core.cuda_stream_is_capturing.restype = ctypes.c_int
3734
- self.core.cuda_stream_get_capture_id.argtypes = [ctypes.c_void_p]
3735
- self.core.cuda_stream_get_capture_id.restype = ctypes.c_uint64
3736
- self.core.cuda_stream_get_priority.argtypes = [ctypes.c_void_p]
3737
- self.core.cuda_stream_get_priority.restype = ctypes.c_int
3738
-
3739
- self.core.cuda_event_create.argtypes = [ctypes.c_void_p, ctypes.c_uint]
3740
- self.core.cuda_event_create.restype = ctypes.c_void_p
3741
- self.core.cuda_event_destroy.argtypes = [ctypes.c_void_p]
3742
- self.core.cuda_event_destroy.restype = None
3743
- self.core.cuda_event_query.argtypes = [ctypes.c_void_p]
3744
- self.core.cuda_event_query.restype = ctypes.c_int
3745
- self.core.cuda_event_record.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_bool]
3746
- self.core.cuda_event_record.restype = None
3747
- self.core.cuda_event_synchronize.argtypes = [ctypes.c_void_p]
3748
- self.core.cuda_event_synchronize.restype = None
3749
- self.core.cuda_event_elapsed_time.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3750
- self.core.cuda_event_elapsed_time.restype = ctypes.c_float
3751
-
3752
- self.core.cuda_graph_begin_capture.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
3753
- self.core.cuda_graph_begin_capture.restype = ctypes.c_bool
3754
- self.core.cuda_graph_end_capture.argtypes = [
3968
+ self.core.wp_cuda_ipc_get_event_handle.restype = None
3969
+ self.core.wp_cuda_ipc_open_event_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
3970
+ self.core.wp_cuda_ipc_open_event_handle.restype = ctypes.c_void_p
3971
+
3972
+ self.core.wp_cuda_stream_create.argtypes = [ctypes.c_void_p, ctypes.c_int]
3973
+ self.core.wp_cuda_stream_create.restype = ctypes.c_void_p
3974
+ self.core.wp_cuda_stream_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3975
+ self.core.wp_cuda_stream_destroy.restype = None
3976
+ self.core.wp_cuda_stream_query.argtypes = [ctypes.c_void_p]
3977
+ self.core.wp_cuda_stream_query.restype = ctypes.c_int
3978
+ self.core.wp_cuda_stream_register.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3979
+ self.core.wp_cuda_stream_register.restype = None
3980
+ self.core.wp_cuda_stream_unregister.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3981
+ self.core.wp_cuda_stream_unregister.restype = None
3982
+ self.core.wp_cuda_stream_synchronize.argtypes = [ctypes.c_void_p]
3983
+ self.core.wp_cuda_stream_synchronize.restype = None
3984
+ self.core.wp_cuda_stream_wait_event.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3985
+ self.core.wp_cuda_stream_wait_event.restype = None
3986
+ self.core.wp_cuda_stream_wait_stream.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
3987
+ self.core.wp_cuda_stream_wait_stream.restype = None
3988
+ self.core.wp_cuda_stream_is_capturing.argtypes = [ctypes.c_void_p]
3989
+ self.core.wp_cuda_stream_is_capturing.restype = ctypes.c_int
3990
+ self.core.wp_cuda_stream_get_capture_id.argtypes = [ctypes.c_void_p]
3991
+ self.core.wp_cuda_stream_get_capture_id.restype = ctypes.c_uint64
3992
+ self.core.wp_cuda_stream_get_priority.argtypes = [ctypes.c_void_p]
3993
+ self.core.wp_cuda_stream_get_priority.restype = ctypes.c_int
3994
+
3995
+ self.core.wp_cuda_event_create.argtypes = [ctypes.c_void_p, ctypes.c_uint]
3996
+ self.core.wp_cuda_event_create.restype = ctypes.c_void_p
3997
+ self.core.wp_cuda_event_destroy.argtypes = [ctypes.c_void_p]
3998
+ self.core.wp_cuda_event_destroy.restype = None
3999
+ self.core.wp_cuda_event_query.argtypes = [ctypes.c_void_p]
4000
+ self.core.wp_cuda_event_query.restype = ctypes.c_int
4001
+ self.core.wp_cuda_event_record.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_bool]
4002
+ self.core.wp_cuda_event_record.restype = None
4003
+ self.core.wp_cuda_event_synchronize.argtypes = [ctypes.c_void_p]
4004
+ self.core.wp_cuda_event_synchronize.restype = None
4005
+ self.core.wp_cuda_event_elapsed_time.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
4006
+ self.core.wp_cuda_event_elapsed_time.restype = ctypes.c_float
4007
+
4008
+ self.core.wp_cuda_graph_begin_capture.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
4009
+ self.core.wp_cuda_graph_begin_capture.restype = ctypes.c_bool
4010
+ self.core.wp_cuda_graph_end_capture.argtypes = [
3755
4011
  ctypes.c_void_p,
3756
4012
  ctypes.c_void_p,
3757
4013
  ctypes.POINTER(ctypes.c_void_p),
3758
4014
  ]
3759
- self.core.cuda_graph_end_capture.restype = ctypes.c_bool
4015
+ self.core.wp_cuda_graph_end_capture.restype = ctypes.c_bool
3760
4016
 
3761
- self.core.cuda_graph_create_exec.argtypes = [
4017
+ self.core.wp_cuda_graph_create_exec.argtypes = [
4018
+ ctypes.c_void_p,
3762
4019
  ctypes.c_void_p,
3763
4020
  ctypes.c_void_p,
3764
4021
  ctypes.POINTER(ctypes.c_void_p),
3765
4022
  ]
3766
- self.core.cuda_graph_create_exec.restype = ctypes.c_bool
4023
+ self.core.wp_cuda_graph_create_exec.restype = ctypes.c_bool
3767
4024
 
3768
- self.core.capture_debug_dot_print.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_uint32]
3769
- self.core.capture_debug_dot_print.restype = ctypes.c_bool
4025
+ self.core.wp_capture_debug_dot_print.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_uint32]
4026
+ self.core.wp_capture_debug_dot_print.restype = ctypes.c_bool
3770
4027
 
3771
- self.core.cuda_graph_launch.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3772
- self.core.cuda_graph_launch.restype = ctypes.c_bool
3773
- self.core.cuda_graph_exec_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3774
- self.core.cuda_graph_exec_destroy.restype = ctypes.c_bool
4028
+ self.core.wp_cuda_graph_launch.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
4029
+ self.core.wp_cuda_graph_launch.restype = ctypes.c_bool
4030
+ self.core.wp_cuda_graph_exec_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
4031
+ self.core.wp_cuda_graph_exec_destroy.restype = ctypes.c_bool
3775
4032
 
3776
- self.core.cuda_graph_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3777
- self.core.cuda_graph_destroy.restype = ctypes.c_bool
4033
+ self.core.wp_cuda_graph_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
4034
+ self.core.wp_cuda_graph_destroy.restype = ctypes.c_bool
3778
4035
 
3779
- self.core.cuda_graph_insert_if_else.argtypes = [
4036
+ self.core.wp_cuda_graph_insert_if_else.argtypes = [
3780
4037
  ctypes.c_void_p,
3781
4038
  ctypes.c_void_p,
3782
4039
  ctypes.POINTER(ctypes.c_int),
3783
4040
  ctypes.POINTER(ctypes.c_void_p),
3784
4041
  ctypes.POINTER(ctypes.c_void_p),
3785
4042
  ]
3786
- self.core.cuda_graph_insert_if_else.restype = ctypes.c_bool
4043
+ self.core.wp_cuda_graph_insert_if_else.restype = ctypes.c_bool
3787
4044
 
3788
- self.core.cuda_graph_insert_while.argtypes = [
4045
+ self.core.wp_cuda_graph_insert_while.argtypes = [
3789
4046
  ctypes.c_void_p,
3790
4047
  ctypes.c_void_p,
3791
4048
  ctypes.POINTER(ctypes.c_int),
3792
4049
  ctypes.POINTER(ctypes.c_void_p),
3793
4050
  ctypes.POINTER(ctypes.c_uint64),
3794
4051
  ]
3795
- self.core.cuda_graph_insert_while.restype = ctypes.c_bool
4052
+ self.core.wp_cuda_graph_insert_while.restype = ctypes.c_bool
3796
4053
 
3797
- self.core.cuda_graph_set_condition.argtypes = [
4054
+ self.core.wp_cuda_graph_set_condition.argtypes = [
3798
4055
  ctypes.c_void_p,
3799
4056
  ctypes.c_void_p,
3800
4057
  ctypes.POINTER(ctypes.c_int),
3801
4058
  ctypes.c_uint64,
3802
4059
  ]
3803
- self.core.cuda_graph_set_condition.restype = ctypes.c_bool
4060
+ self.core.wp_cuda_graph_set_condition.restype = ctypes.c_bool
3804
4061
 
3805
- self.core.cuda_graph_pause_capture.argtypes = [
4062
+ self.core.wp_cuda_graph_pause_capture.argtypes = [
3806
4063
  ctypes.c_void_p,
3807
4064
  ctypes.c_void_p,
3808
4065
  ctypes.POINTER(ctypes.c_void_p),
3809
4066
  ]
3810
- self.core.cuda_graph_pause_capture.restype = ctypes.c_bool
4067
+ self.core.wp_cuda_graph_pause_capture.restype = ctypes.c_bool
3811
4068
 
3812
- self.core.cuda_graph_resume_capture.argtypes = [
4069
+ self.core.wp_cuda_graph_resume_capture.argtypes = [
3813
4070
  ctypes.c_void_p,
3814
4071
  ctypes.c_void_p,
3815
4072
  ctypes.c_void_p,
3816
4073
  ]
3817
- self.core.cuda_graph_resume_capture.restype = ctypes.c_bool
4074
+ self.core.wp_cuda_graph_resume_capture.restype = ctypes.c_bool
3818
4075
 
3819
- self.core.cuda_graph_insert_child_graph.argtypes = [
4076
+ self.core.wp_cuda_graph_insert_child_graph.argtypes = [
3820
4077
  ctypes.c_void_p,
3821
4078
  ctypes.c_void_p,
3822
4079
  ctypes.c_void_p,
3823
4080
  ]
3824
- self.core.cuda_graph_insert_child_graph.restype = ctypes.c_bool
4081
+ self.core.wp_cuda_graph_insert_child_graph.restype = ctypes.c_bool
3825
4082
 
3826
- self.core.cuda_compile_program.argtypes = [
4083
+ self.core.wp_cuda_graph_check_conditional_body.argtypes = [ctypes.c_void_p]
4084
+ self.core.wp_cuda_graph_check_conditional_body.restype = ctypes.c_bool
4085
+
4086
+ self.core.wp_cuda_compile_program.argtypes = [
3827
4087
  ctypes.c_char_p, # cuda_src
3828
4088
  ctypes.c_char_p, # program name
3829
4089
  ctypes.c_int, # arch
@@ -3843,9 +4103,9 @@ class Runtime:
3843
4103
  ctypes.POINTER(ctypes.c_size_t), # ltoir_sizes
3844
4104
  ctypes.POINTER(ctypes.c_int), # ltoir_input_types, each of type nvJitLinkInputType
3845
4105
  ]
3846
- self.core.cuda_compile_program.restype = ctypes.c_size_t
4106
+ self.core.wp_cuda_compile_program.restype = ctypes.c_size_t
3847
4107
 
3848
- self.core.cuda_compile_fft.argtypes = [
4108
+ self.core.wp_cuda_compile_fft.argtypes = [
3849
4109
  ctypes.c_char_p, # lto
3850
4110
  ctypes.c_char_p, # function name
3851
4111
  ctypes.c_int, # num include dirs
@@ -3858,9 +4118,9 @@ class Runtime:
3858
4118
  ctypes.c_int, # precision
3859
4119
  ctypes.POINTER(ctypes.c_int), # smem (out)
3860
4120
  ]
3861
- self.core.cuda_compile_fft.restype = ctypes.c_bool
4121
+ self.core.wp_cuda_compile_fft.restype = ctypes.c_bool
3862
4122
 
3863
- self.core.cuda_compile_dot.argtypes = [
4123
+ self.core.wp_cuda_compile_dot.argtypes = [
3864
4124
  ctypes.c_char_p, # lto
3865
4125
  ctypes.c_char_p, # function name
3866
4126
  ctypes.c_int, # num include dirs
@@ -3879,9 +4139,9 @@ class Runtime:
3879
4139
  ctypes.c_int, # c_arrangement
3880
4140
  ctypes.c_int, # num threads
3881
4141
  ]
3882
- self.core.cuda_compile_dot.restype = ctypes.c_bool
4142
+ self.core.wp_cuda_compile_dot.restype = ctypes.c_bool
3883
4143
 
3884
- self.core.cuda_compile_solver.argtypes = [
4144
+ self.core.wp_cuda_compile_solver.argtypes = [
3885
4145
  ctypes.c_char_p, # universal fatbin
3886
4146
  ctypes.c_char_p, # lto
3887
4147
  ctypes.c_char_p, # function name
@@ -3901,24 +4161,24 @@ class Runtime:
3901
4161
  ctypes.c_int, # fill_mode
3902
4162
  ctypes.c_int, # num threads
3903
4163
  ]
3904
- self.core.cuda_compile_solver.restype = ctypes.c_bool
4164
+ self.core.wp_cuda_compile_solver.restype = ctypes.c_bool
3905
4165
 
3906
- self.core.cuda_load_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
3907
- self.core.cuda_load_module.restype = ctypes.c_void_p
4166
+ self.core.wp_cuda_load_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
4167
+ self.core.wp_cuda_load_module.restype = ctypes.c_void_p
3908
4168
 
3909
- self.core.cuda_unload_module.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3910
- self.core.cuda_unload_module.restype = None
4169
+ self.core.wp_cuda_unload_module.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
4170
+ self.core.wp_cuda_unload_module.restype = None
3911
4171
 
3912
- self.core.cuda_get_kernel.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_char_p]
3913
- self.core.cuda_get_kernel.restype = ctypes.c_void_p
4172
+ self.core.wp_cuda_get_kernel.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_char_p]
4173
+ self.core.wp_cuda_get_kernel.restype = ctypes.c_void_p
3914
4174
 
3915
- self.core.cuda_get_max_shared_memory.argtypes = [ctypes.c_void_p]
3916
- self.core.cuda_get_max_shared_memory.restype = ctypes.c_int
4175
+ self.core.wp_cuda_get_max_shared_memory.argtypes = [ctypes.c_void_p]
4176
+ self.core.wp_cuda_get_max_shared_memory.restype = ctypes.c_int
3917
4177
 
3918
- self.core.cuda_configure_kernel_shared_memory.argtypes = [ctypes.c_void_p, ctypes.c_int]
3919
- self.core.cuda_configure_kernel_shared_memory.restype = ctypes.c_bool
4178
+ self.core.wp_cuda_configure_kernel_shared_memory.argtypes = [ctypes.c_void_p, ctypes.c_int]
4179
+ self.core.wp_cuda_configure_kernel_shared_memory.restype = ctypes.c_bool
3920
4180
 
3921
- self.core.cuda_launch_kernel.argtypes = [
4181
+ self.core.wp_cuda_launch_kernel.argtypes = [
3922
4182
  ctypes.c_void_p,
3923
4183
  ctypes.c_void_p,
3924
4184
  ctypes.c_size_t,
@@ -3928,54 +4188,54 @@ class Runtime:
3928
4188
  ctypes.POINTER(ctypes.c_void_p),
3929
4189
  ctypes.c_void_p,
3930
4190
  ]
3931
- self.core.cuda_launch_kernel.restype = ctypes.c_size_t
4191
+ self.core.wp_cuda_launch_kernel.restype = ctypes.c_size_t
3932
4192
 
3933
- self.core.cuda_graphics_map.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3934
- self.core.cuda_graphics_map.restype = None
3935
- self.core.cuda_graphics_unmap.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3936
- self.core.cuda_graphics_unmap.restype = None
3937
- self.core.cuda_graphics_device_ptr_and_size.argtypes = [
4193
+ self.core.wp_cuda_graphics_map.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
4194
+ self.core.wp_cuda_graphics_map.restype = None
4195
+ self.core.wp_cuda_graphics_unmap.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
4196
+ self.core.wp_cuda_graphics_unmap.restype = None
4197
+ self.core.wp_cuda_graphics_device_ptr_and_size.argtypes = [
3938
4198
  ctypes.c_void_p,
3939
4199
  ctypes.c_void_p,
3940
4200
  ctypes.POINTER(ctypes.c_uint64),
3941
4201
  ctypes.POINTER(ctypes.c_size_t),
3942
4202
  ]
3943
- self.core.cuda_graphics_device_ptr_and_size.restype = None
3944
- self.core.cuda_graphics_register_gl_buffer.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint]
3945
- self.core.cuda_graphics_register_gl_buffer.restype = ctypes.c_void_p
3946
- self.core.cuda_graphics_unregister_resource.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3947
- self.core.cuda_graphics_unregister_resource.restype = None
3948
-
3949
- self.core.cuda_timing_begin.argtypes = [ctypes.c_int]
3950
- self.core.cuda_timing_begin.restype = None
3951
- self.core.cuda_timing_get_result_count.argtypes = []
3952
- self.core.cuda_timing_get_result_count.restype = int
3953
- self.core.cuda_timing_end.argtypes = []
3954
- self.core.cuda_timing_end.restype = None
3955
-
3956
- self.core.graph_coloring.argtypes = [
4203
+ self.core.wp_cuda_graphics_device_ptr_and_size.restype = None
4204
+ self.core.wp_cuda_graphics_register_gl_buffer.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint]
4205
+ self.core.wp_cuda_graphics_register_gl_buffer.restype = ctypes.c_void_p
4206
+ self.core.wp_cuda_graphics_unregister_resource.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
4207
+ self.core.wp_cuda_graphics_unregister_resource.restype = None
4208
+
4209
+ self.core.wp_cuda_timing_begin.argtypes = [ctypes.c_int]
4210
+ self.core.wp_cuda_timing_begin.restype = None
4211
+ self.core.wp_cuda_timing_get_result_count.argtypes = []
4212
+ self.core.wp_cuda_timing_get_result_count.restype = int
4213
+ self.core.wp_cuda_timing_end.argtypes = []
4214
+ self.core.wp_cuda_timing_end.restype = None
4215
+
4216
+ self.core.wp_graph_coloring.argtypes = [
3957
4217
  ctypes.c_int,
3958
4218
  warp.types.array_t,
3959
4219
  ctypes.c_int,
3960
4220
  warp.types.array_t,
3961
4221
  ]
3962
- self.core.graph_coloring.restype = ctypes.c_int
4222
+ self.core.wp_graph_coloring.restype = ctypes.c_int
3963
4223
 
3964
- self.core.balance_coloring.argtypes = [
4224
+ self.core.wp_balance_coloring.argtypes = [
3965
4225
  ctypes.c_int,
3966
4226
  warp.types.array_t,
3967
4227
  ctypes.c_int,
3968
4228
  ctypes.c_float,
3969
4229
  warp.types.array_t,
3970
4230
  ]
3971
- self.core.balance_coloring.restype = ctypes.c_float
4231
+ self.core.wp_balance_coloring.restype = ctypes.c_float
3972
4232
 
3973
- self.core.init.restype = ctypes.c_int
4233
+ self.core.wp_init.restype = ctypes.c_int
3974
4234
 
3975
4235
  except AttributeError as e:
3976
4236
  raise RuntimeError(f"Setting C-types for {warp_lib} failed. It may need rebuilding.") from e
3977
4237
 
3978
- error = self.core.init()
4238
+ error = self.core.wp_init()
3979
4239
 
3980
4240
  if error != 0:
3981
4241
  raise Exception("Warp initialization failed")
@@ -3991,8 +4251,8 @@ class Runtime:
3991
4251
  self.device_map["cpu"] = self.cpu_device
3992
4252
  self.context_map[None] = self.cpu_device
3993
4253
 
3994
- self.is_cuda_enabled = bool(self.core.is_cuda_enabled())
3995
- self.is_cuda_compatibility_enabled = bool(self.core.is_cuda_compatibility_enabled())
4254
+ self.is_cuda_enabled = bool(self.core.wp_is_cuda_enabled())
4255
+ self.is_cuda_compatibility_enabled = bool(self.core.wp_is_cuda_compatibility_enabled())
3996
4256
 
3997
4257
  self.toolkit_version = None # CTK version used to build the core lib
3998
4258
  self.driver_version = None # installed driver version
@@ -4005,12 +4265,15 @@ class Runtime:
4005
4265
 
4006
4266
  if self.is_cuda_enabled:
4007
4267
  # get CUDA Toolkit and driver versions
4008
- toolkit_version = self.core.cuda_toolkit_version()
4009
- driver_version = self.core.cuda_driver_version()
4010
-
4011
- # save versions as tuples, e.g., (12, 4)
4268
+ toolkit_version = self.core.wp_cuda_toolkit_version()
4012
4269
  self.toolkit_version = (toolkit_version // 1000, (toolkit_version % 1000) // 10)
4013
- self.driver_version = (driver_version // 1000, (driver_version % 1000) // 10)
4270
+
4271
+ if self.core.wp_cuda_driver_is_initialized():
4272
+ # save versions as tuples, e.g., (12, 4)
4273
+ driver_version = self.core.wp_cuda_driver_version()
4274
+ self.driver_version = (driver_version // 1000, (driver_version % 1000) // 10)
4275
+ else:
4276
+ self.driver_version = None
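The driver version is now queried only when wp_cuda_driver_is_initialized() reports success; both the toolkit and driver versions arrive as integers of the form major * 1000 + minor * 10 and are stored as (major, minor) tuples. A small standalone example of that decoding (pure arithmetic, no Warp calls):

    def decode_cuda_version(version: int) -> tuple[int, int]:
        # e.g. 12040 -> (12, 4), 11080 -> (11, 8)
        return (version // 1000, (version % 1000) // 10)

    assert decode_cuda_version(12040) == (12, 4)
    assert decode_cuda_version(11080) == (11, 8)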
4014
4277
 
4015
4278
  # determine minimum required driver version
4016
4279
  if self.is_cuda_compatibility_enabled:
@@ -4024,18 +4287,18 @@ class Runtime:
4024
4287
  self.min_driver_version = self.toolkit_version
4025
4288
 
4026
4289
  # determine if the installed driver is sufficient
4027
- if self.driver_version >= self.min_driver_version:
4290
+ if self.driver_version is not None and self.driver_version >= self.min_driver_version:
4028
4291
  # get all architectures supported by NVRTC
4029
- num_archs = self.core.nvrtc_supported_arch_count()
4292
+ num_archs = self.core.wp_nvrtc_supported_arch_count()
4030
4293
  if num_archs > 0:
4031
4294
  archs = (ctypes.c_int * num_archs)()
4032
- self.core.nvrtc_supported_archs(archs)
4295
+ self.core.wp_nvrtc_supported_archs(archs)
4033
4296
  self.nvrtc_supported_archs = set(archs)
4034
4297
  else:
4035
4298
  self.nvrtc_supported_archs = set()
4036
4299
 
4037
4300
  # get CUDA device count
4038
- cuda_device_count = self.core.cuda_device_get_count()
4301
+ cuda_device_count = self.core.wp_cuda_device_get_count()
4039
4302
 
4040
4303
  # register primary CUDA devices
4041
4304
  for i in range(cuda_device_count):
@@ -4052,7 +4315,7 @@ class Runtime:
4052
4315
  # set default device
4053
4316
  if cuda_device_count > 0:
4054
4317
  # stick with the current cuda context, if one is bound
4055
- initial_context = self.core.cuda_context_get_current()
4318
+ initial_context = self.core.wp_cuda_context_get_current()
4056
4319
  if initial_context is not None:
4057
4320
  self.set_default_device("cuda")
4058
4321
  # if this is a non-primary context that was just registered, update the device count
@@ -4066,9 +4329,14 @@ class Runtime:
4066
4329
  # Update the default PTX architecture based on devices present in the system.
4067
4330
  # Use the lowest architecture among devices that meet the minimum architecture requirement.
4068
4331
  # Devices below the required minimum will use the highest architecture they support.
4069
- eligible_archs = [d.arch for d in self.cuda_devices if d.arch >= self.default_ptx_arch]
4070
- if eligible_archs:
4071
- self.default_ptx_arch = min(eligible_archs)
4332
+ try:
4333
+ self.default_ptx_arch = min(
4334
+ d.arch
4335
+ for d in self.cuda_devices
4336
+ if d.arch >= self.default_ptx_arch and d.arch in self.nvrtc_supported_archs
4337
+ )
4338
+ except ValueError:
4339
+ pass # no eligible NVRTC-supported arch ≥ default, retain existing
4072
4340
  else:
4073
4341
  # CUDA not available
4074
4342
  self.set_default_device("cpu")
@@ -4100,6 +4368,8 @@ class Runtime:
4100
4368
  if not self.is_cuda_enabled:
4101
4369
  # Warp was compiled without CUDA support
4102
4370
  greeting.append(" CUDA not enabled in this build")
4371
+ elif self.driver_version is None:
4372
+ greeting.append(" CUDA driver not found or failed to initialize")
4103
4373
  elif self.driver_version < self.min_driver_version:
4104
4374
  # insufficient CUDA driver version
4105
4375
  greeting.append(
@@ -4143,7 +4413,7 @@ class Runtime:
4143
4413
  access_vector.append(1)
4144
4414
  else:
4145
4415
  peer_device = self.cuda_devices[j]
4146
- can_access = self.core.cuda_is_peer_access_supported(
4416
+ can_access = self.core.wp_cuda_is_peer_access_supported(
4147
4417
  target_device.ordinal, peer_device.ordinal
4148
4418
  )
4149
4419
  access_vector.append(can_access)
@@ -4168,7 +4438,7 @@ class Runtime:
4168
4438
 
4169
4439
  if cuda_device_count > 0:
4170
4440
  # ensure initialization did not change the initial context (e.g. querying available memory)
4171
- self.core.cuda_context_set_current(initial_context)
4441
+ self.core.wp_cuda_context_set_current(initial_context)
4172
4442
 
4173
4443
  # detect possible misconfiguration of the system
4174
4444
  devices_without_uva = []
@@ -4196,7 +4466,7 @@ class Runtime:
4196
4466
  elif self.is_cuda_enabled:
4197
4467
  # Report a warning about insufficient driver version. The warning should appear even in quiet mode
4198
4468
  # when the greeting message is suppressed. Also try to provide guidance for resolving the situation.
4199
- if self.driver_version < self.min_driver_version:
4469
+ if self.driver_version is not None and self.driver_version < self.min_driver_version:
4200
4470
  msg = []
4201
4471
  msg.append("\n Insufficient CUDA driver version.")
4202
4472
  msg.append(
@@ -4207,7 +4477,7 @@ class Runtime:
4207
4477
  warp.utils.warn("\n ".join(msg))
4208
4478
 
4209
4479
  def get_error_string(self):
4210
- return self.core.get_error_string().decode("utf-8")
4480
+ return self.core.wp_get_error_string().decode("utf-8")
4211
4481
 
4212
4482
  def load_dll(self, dll_path):
4213
4483
  try:
@@ -4243,21 +4513,21 @@ class Runtime:
4243
4513
  self.default_device = self.get_device(ident)
4244
4514
 
4245
4515
  def get_current_cuda_device(self) -> Device:
4246
- current_context = self.core.cuda_context_get_current()
4516
+ current_context = self.core.wp_cuda_context_get_current()
4247
4517
  if current_context is not None:
4248
4518
  current_device = self.context_map.get(current_context)
4249
4519
  if current_device is not None:
4250
4520
  # this is a known device
4251
4521
  return current_device
4252
- elif self.core.cuda_context_is_primary(current_context):
4522
+ elif self.core.wp_cuda_context_is_primary(current_context):
4253
4523
  # this is a primary context that we haven't used yet
4254
- ordinal = self.core.cuda_context_get_device_ordinal(current_context)
4524
+ ordinal = self.core.wp_cuda_context_get_device_ordinal(current_context)
4255
4525
  device = self.cuda_devices[ordinal]
4256
4526
  self.context_map[current_context] = device
4257
4527
  return device
4258
4528
  else:
4259
4529
  # this is an unseen non-primary context, register it as a new device with a unique alias
4260
- ordinal = self.core.cuda_context_get_device_ordinal(current_context)
4530
+ ordinal = self.core.wp_cuda_context_get_device_ordinal(current_context)
4261
4531
  alias = f"cuda:{ordinal}.{self.cuda_custom_context_count[ordinal]}"
4262
4532
  self.cuda_custom_context_count[ordinal] += 1
4263
4533
  return self.map_cuda_device(alias, current_context)
@@ -4280,7 +4550,7 @@ class Runtime:
4280
4550
 
4281
4551
  def map_cuda_device(self, alias, context=None) -> Device:
4282
4552
  if context is None:
4283
- context = self.core.cuda_context_get_current()
4553
+ context = self.core.wp_cuda_context_get_current()
4284
4554
  if context is None:
4285
4555
  raise RuntimeError(f"Unable to determine CUDA context for device alias '{alias}'")
4286
4556
 
@@ -4302,10 +4572,10 @@ class Runtime:
4302
4572
  # it's an unmapped context
4303
4573
 
4304
4574
  # get the device ordinal
4305
- ordinal = self.core.cuda_context_get_device_ordinal(context)
4575
+ ordinal = self.core.wp_cuda_context_get_device_ordinal(context)
4306
4576
 
4307
4577
  # check if this is a primary context (we could get here if it's a device that hasn't been used yet)
4308
- if self.core.cuda_context_is_primary(context):
4578
+ if self.core.wp_cuda_context_is_primary(context):
4309
4579
  # rename the device
4310
4580
  device = self.cuda_primary_devices[ordinal]
4311
4581
  return self.rename_device(device, alias)
@@ -4336,7 +4606,7 @@ class Runtime:
4336
4606
  if not device.is_cuda:
4337
4607
  return
4338
4608
 
4339
- err = self.core.cuda_context_check(device.context)
4609
+ err = self.core.wp_cuda_context_check(device.context)
4340
4610
  if err != 0:
4341
4611
  raise RuntimeError(f"CUDA error detected: {err}")
4342
4612
 
@@ -4368,7 +4638,7 @@ def is_cuda_driver_initialized() -> bool:
4368
4638
  """
4369
4639
  init()
4370
4640
 
4371
- return runtime.core.cuda_driver_is_initialized()
4641
+ return runtime.core.wp_cuda_driver_is_initialized()
4372
4642
 
4373
4643
 
4374
4644
  def get_devices() -> list[Device]:
@@ -4576,7 +4846,7 @@ def set_mempool_release_threshold(device: Devicelike, threshold: int | float) ->
4576
4846
  elif threshold > 0 and threshold <= 1:
4577
4847
  threshold = int(threshold * device.total_memory)
4578
4848
 
4579
- if not runtime.core.cuda_device_set_mempool_release_threshold(device.ordinal, threshold):
4849
+ if not runtime.core.wp_cuda_device_set_mempool_release_threshold(device.ordinal, threshold):
4580
4850
  raise RuntimeError(f"Failed to set memory pool release threshold for device {device}")
4581
4851
 
4582
4852
 
@@ -4606,7 +4876,7 @@ def get_mempool_release_threshold(device: Devicelike = None) -> int:
4606
4876
  if not device.is_mempool_supported:
4607
4877
  raise RuntimeError(f"Device {device} does not support memory pools")
4608
4878
 
4609
- return runtime.core.cuda_device_get_mempool_release_threshold(device.ordinal)
4879
+ return runtime.core.wp_cuda_device_get_mempool_release_threshold(device.ordinal)
4610
4880
 
4611
4881
 
4612
4882
  def get_mempool_used_mem_current(device: Devicelike = None) -> int:
@@ -4635,7 +4905,7 @@ def get_mempool_used_mem_current(device: Devicelike = None) -> int:
4635
4905
  if not device.is_mempool_supported:
4636
4906
  raise RuntimeError(f"Device {device} does not support memory pools")
4637
4907
 
4638
- return runtime.core.cuda_device_get_mempool_used_mem_current(device.ordinal)
4908
+ return runtime.core.wp_cuda_device_get_mempool_used_mem_current(device.ordinal)
4639
4909
 
4640
4910
 
4641
4911
  def get_mempool_used_mem_high(device: Devicelike = None) -> int:
@@ -4664,7 +4934,7 @@ def get_mempool_used_mem_high(device: Devicelike = None) -> int:
4664
4934
  if not device.is_mempool_supported:
4665
4935
  raise RuntimeError(f"Device {device} does not support memory pools")
4666
4936
 
4667
- return runtime.core.cuda_device_get_mempool_used_mem_high(device.ordinal)
4937
+ return runtime.core.wp_cuda_device_get_mempool_used_mem_high(device.ordinal)
4668
4938
 
4669
4939
 
4670
4940
  def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike) -> bool:
@@ -4685,7 +4955,7 @@ def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike)
4685
4955
  if not target_device.is_cuda or not peer_device.is_cuda:
4686
4956
  return False
4687
4957
 
4688
- return bool(runtime.core.cuda_is_peer_access_supported(target_device.ordinal, peer_device.ordinal))
4958
+ return bool(runtime.core.wp_cuda_is_peer_access_supported(target_device.ordinal, peer_device.ordinal))
4689
4959
 
4690
4960
 
4691
4961
  def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike) -> bool:
@@ -4706,7 +4976,7 @@ def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike) -
4706
4976
  if not target_device.is_cuda or not peer_device.is_cuda:
4707
4977
  return False
4708
4978
 
4709
- return bool(runtime.core.cuda_is_peer_access_enabled(target_device.context, peer_device.context))
4979
+ return bool(runtime.core.wp_cuda_is_peer_access_enabled(target_device.context, peer_device.context))
4710
4980
 
4711
4981
 
4712
4982
  def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool) -> None:
@@ -4736,7 +5006,7 @@ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike,
4736
5006
  else:
4737
5007
  return
4738
5008
 
4739
- if not runtime.core.cuda_set_peer_access_enabled(target_device.context, peer_device.context, int(enable)):
5009
+ if not runtime.core.wp_cuda_set_peer_access_enabled(target_device.context, peer_device.context, int(enable)):
4740
5010
  action = "enable" if enable else "disable"
4741
5011
  raise RuntimeError(f"Failed to {action} peer access from device {peer_device} to device {target_device}")
4742
5012
 
@@ -4777,7 +5047,7 @@ def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike
4777
5047
  if not peer_device.is_cuda or not target_device.is_cuda or not target_device.is_mempool_supported:
4778
5048
  return False
4779
5049
 
4780
- return bool(runtime.core.cuda_is_mempool_access_enabled(target_device.ordinal, peer_device.ordinal))
5050
+ return bool(runtime.core.wp_cuda_is_mempool_access_enabled(target_device.ordinal, peer_device.ordinal))
4781
5051
 
4782
5052
 
4783
5053
  def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool) -> None:
@@ -4810,7 +5080,7 @@ def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelik
4810
5080
  else:
4811
5081
  return
4812
5082
 
4813
- if not runtime.core.cuda_set_mempool_access_enabled(target_device.ordinal, peer_device.ordinal, int(enable)):
5083
+ if not runtime.core.wp_cuda_set_mempool_access_enabled(target_device.ordinal, peer_device.ordinal, int(enable)):
4814
5084
  action = "enable" if enable else "disable"
4815
5085
  raise RuntimeError(f"Failed to {action} memory pool access from device {peer_device} to device {target_device}")
4816
5086
 
@@ -4891,7 +5161,7 @@ def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: bo
4891
5161
  if synchronize:
4892
5162
  synchronize_event(end_event)
4893
5163
 
4894
- return runtime.core.cuda_event_elapsed_time(start_event.cuda_event, end_event.cuda_event)
5164
+ return runtime.core.wp_cuda_event_elapsed_time(start_event.cuda_event, end_event.cuda_event)
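get_event_elapsed_time simply forwards to the wp_-prefixed native call, so usage is unchanged. A hedged timing sketch, assuming a CUDA device and events created with enable_timing=True:

    import warp as wp

    wp.init()

    with wp.ScopedDevice("cuda:0"):
        start = wp.Event(enable_timing=True)
        end = wp.Event(enable_timing=True)

        wp.record_event(start)
        # ... enqueue kernel launches or copies here ...
        wp.record_event(end)

        ms = wp.get_event_elapsed_time(start, end, synchronize=True)
        print(f"elapsed: {ms:.3f} ms")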
4895
5165
 
4896
5166
 
4897
5167
  def wait_stream(other_stream: Stream, event: Event | None = None):
@@ -4985,7 +5255,7 @@ class RegisteredGLBuffer:
4985
5255
  self.context = self.device.context
4986
5256
  self.flags = flags
4987
5257
  self.fallback_to_copy = fallback_to_copy
4988
- self.resource = runtime.core.cuda_graphics_register_gl_buffer(self.context, gl_buffer_id, flags)
5258
+ self.resource = runtime.core.wp_cuda_graphics_register_gl_buffer(self.context, gl_buffer_id, flags)
4989
5259
  if self.resource is None:
4990
5260
  if self.fallback_to_copy:
4991
5261
  self.warp_buffer = None
@@ -5004,7 +5274,7 @@ class RegisteredGLBuffer:
5004
5274
 
5005
5275
  # use CUDA context guard to avoid side effects during garbage collection
5006
5276
  with self.device.context_guard:
5007
- runtime.core.cuda_graphics_unregister_resource(self.context, self.resource)
5277
+ runtime.core.wp_cuda_graphics_unregister_resource(self.context, self.resource)
5008
5278
 
5009
5279
  def map(self, dtype, shape) -> warp.array:
5010
5280
  """Map the OpenGL buffer to a Warp array.
@@ -5017,10 +5287,10 @@ class RegisteredGLBuffer:
5017
5287
  A Warp array object representing the mapped OpenGL buffer.
5018
5288
  """
5019
5289
  if self.resource is not None:
5020
- runtime.core.cuda_graphics_map(self.context, self.resource)
5290
+ runtime.core.wp_cuda_graphics_map(self.context, self.resource)
5021
5291
  ptr = ctypes.c_uint64(0)
5022
5292
  size = ctypes.c_size_t(0)
5023
- runtime.core.cuda_graphics_device_ptr_and_size(
5293
+ runtime.core.wp_cuda_graphics_device_ptr_and_size(
5024
5294
  self.context, self.resource, ctypes.byref(ptr), ctypes.byref(size)
5025
5295
  )
5026
5296
  return warp.array(ptr=ptr.value, dtype=dtype, shape=shape, device=self.device)
@@ -5045,7 +5315,7 @@ class RegisteredGLBuffer:
5045
5315
  def unmap(self):
5046
5316
  """Unmap the OpenGL buffer."""
5047
5317
  if self.resource is not None:
5048
- runtime.core.cuda_graphics_unmap(self.context, self.resource)
5318
+ runtime.core.wp_cuda_graphics_unmap(self.context, self.resource)
5049
5319
  elif self.fallback_to_copy:
5050
5320
  if self.warp_buffer is None:
5051
5321
  raise RuntimeError("RegisteredGLBuffer first has to be mapped")
@@ -5401,7 +5671,7 @@ def event_from_ipc_handle(handle, device: Devicelike = None) -> Event:
5401
5671
  raise RuntimeError(f"IPC is not supported on device {device}.")
5402
5672
 
5403
5673
  event = Event(
5404
- device=device, cuda_event=warp.context.runtime.core.cuda_ipc_open_event_handle(device.context, handle)
5674
+ device=device, cuda_event=warp.context.runtime.core.wp_cuda_ipc_open_event_handle(device.context, handle)
5405
5675
  )
5406
5676
  # Events created from IPC handles must be freed with cuEventDestroy
5407
5677
  event.owner = True
@@ -5533,6 +5803,44 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
5533
5803
  ) from e
5534
5804
 
5535
5805
 
5806
+ # invoke a CPU kernel by passing the parameters as a ctypes structure
5807
+ def invoke(kernel, hooks, params: Sequence[Any], adjoint: bool):
5808
+ fields = []
5809
+
5810
+ for i in range(0, len(kernel.adj.args)):
5811
+ arg_name = kernel.adj.args[i].label
5812
+ field = (arg_name, type(params[1 + i])) # skip the first argument, which is the launch bounds
5813
+ fields.append(field)
5814
+
5815
+ ArgsStruct = type("ArgsStruct", (ctypes.Structure,), {"_fields_": fields})
5816
+
5817
+ args = ArgsStruct()
5818
+ for i, field in enumerate(fields):
5819
+ name = field[0]
5820
+ setattr(args, name, params[1 + i])
5821
+
5822
+ if not adjoint:
5823
+ hooks.forward(params[0], ctypes.byref(args))
5824
+
5825
+ # for adjoint kernels the adjoint arguments are passed through a second struct
5826
+ else:
5827
+ adj_fields = []
5828
+
5829
+ for i in range(0, len(kernel.adj.args)):
5830
+ arg_name = kernel.adj.args[i].label
5831
+ field = (arg_name, type(params[1 + len(fields) + i])) # adjoint args follow the launch bounds and the forward args
5832
+ adj_fields.append(field)
5833
+
5834
+ AdjArgsStruct = type("AdjArgsStruct", (ctypes.Structure,), {"_fields_": adj_fields})
5835
+
5836
+ adj_args = AdjArgsStruct()
5837
+ for i, field in enumerate(adj_fields):
5838
+ name = field[0]
5839
+ setattr(adj_args, name, params[1 + len(fields) + i])
5840
+
5841
+ hooks.backward(params[0], ctypes.byref(args), ctypes.byref(adj_args))
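The new invoke() helper builds a ctypes.Structure type on the fly from the kernel's argument names and the already-packed ctypes values, then passes an instance by reference to the CPU hook. A self-contained sketch of that technique, with illustrative field names rather than Warp's real hook signature:

    import ctypes

    # packed ctypes values, analogous to the entries of `params` after the launch bounds
    packed = {"n": ctypes.c_int(3), "scale": ctypes.c_float(2.0)}

    # build a Structure type whose fields mirror the packed argument types
    fields = [(name, type(value)) for name, value in packed.items()]
    ArgsStruct = type("ArgsStruct", (ctypes.Structure,), {"_fields_": fields})

    args = ArgsStruct()
    for name, value in packed.items():
        setattr(args, name, value)

    # a native entry point would receive ctypes.byref(args); here we just read the fields back
    print(args.n, args.scale)  # 3 2.0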
5842
+
5843
+
5536
5844
  class Launch:
5537
5845
  """Represents all data required for a kernel launch so that launches can be replayed quickly.
5538
5846
 
@@ -5725,24 +6033,21 @@ class Launch:
5725
6033
  stream: The stream to launch on.
5726
6034
  """
5727
6035
  if self.device.is_cpu:
5728
- if self.adjoint:
5729
- self.hooks.backward(*self.params)
5730
- else:
5731
- self.hooks.forward(*self.params)
6036
+ invoke(self.kernel, self.hooks, self.params, self.adjoint)
5732
6037
  else:
5733
6038
  if stream is None:
5734
6039
  stream = self.device.stream
5735
6040
 
5736
6041
  # If the stream is capturing, we retain the CUDA module so that it doesn't get unloaded
5737
6042
  # before the captured graph is released.
5738
- if len(runtime.captures) > 0 and runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
5739
- capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
6043
+ if len(runtime.captures) > 0 and runtime.core.wp_cuda_stream_is_capturing(stream.cuda_stream):
6044
+ capture_id = runtime.core.wp_cuda_stream_get_capture_id(stream.cuda_stream)
5740
6045
  graph = runtime.captures.get(capture_id)
5741
6046
  if graph is not None:
5742
6047
  graph.retain_module_exec(self.module_exec)
5743
6048
 
5744
6049
  if self.adjoint:
5745
- runtime.core.cuda_launch_kernel(
6050
+ runtime.core.wp_cuda_launch_kernel(
5746
6051
  self.device.context,
5747
6052
  self.hooks.backward,
5748
6053
  self.bounds.size,
@@ -5753,7 +6058,7 @@ class Launch:
5753
6058
  stream.cuda_stream,
5754
6059
  )
5755
6060
  else:
5756
- runtime.core.cuda_launch_kernel(
6061
+ runtime.core.wp_cuda_launch_kernel(
5757
6062
  self.device.context,
5758
6063
  self.hooks.forward,
5759
6064
  self.bounds.size,
@@ -5872,7 +6177,7 @@ def launch(
5872
6177
  # late bind
5873
6178
  hooks = module_exec.get_kernel_hooks(kernel)
5874
6179
 
5875
- pack_args(fwd_args, params)
6180
+ pack_args(fwd_args, params, adjoint=False)
5876
6181
  pack_args(adj_args, params, adjoint=True)
5877
6182
 
5878
6183
  # run kernel
@@ -5883,38 +6188,25 @@ def launch(
5883
6188
  f"Failed to find backward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
5884
6189
  )
5885
6190
 
5886
- if record_cmd:
5887
- launch = Launch(
5888
- kernel=kernel,
5889
- hooks=hooks,
5890
- params=params,
5891
- params_addr=None,
5892
- bounds=bounds,
5893
- device=device,
5894
- adjoint=adjoint,
5895
- )
5896
- return launch
5897
- hooks.backward(*params)
5898
-
5899
6191
  else:
5900
6192
  if hooks.forward is None:
5901
6193
  raise RuntimeError(
5902
6194
  f"Failed to find forward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
5903
6195
  )
5904
6196
 
5905
- if record_cmd:
5906
- launch = Launch(
5907
- kernel=kernel,
5908
- hooks=hooks,
5909
- params=params,
5910
- params_addr=None,
5911
- bounds=bounds,
5912
- device=device,
5913
- adjoint=adjoint,
5914
- )
5915
- return launch
5916
- else:
5917
- hooks.forward(*params)
6197
+ if record_cmd:
6198
+ launch = Launch(
6199
+ kernel=kernel,
6200
+ hooks=hooks,
6201
+ params=params,
6202
+ params_addr=None,
6203
+ bounds=bounds,
6204
+ device=device,
6205
+ adjoint=adjoint,
6206
+ )
6207
+ return launch
6208
+
6209
+ invoke(kernel, hooks, params, adjoint)
5918
6210
 
5919
6211
  else:
5920
6212
  kernel_args = [ctypes.c_void_p(ctypes.addressof(x)) for x in params]
@@ -5925,8 +6217,8 @@ def launch(
5925
6217
 
5926
6218
  # If the stream is capturing, we retain the CUDA module so that it doesn't get unloaded
5927
6219
  # before the captured graph is released.
5928
- if len(runtime.captures) > 0 and runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
5929
- capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
6220
+ if len(runtime.captures) > 0 and runtime.core.wp_cuda_stream_is_capturing(stream.cuda_stream):
6221
+ capture_id = runtime.core.wp_cuda_stream_get_capture_id(stream.cuda_stream)
5930
6222
  graph = runtime.captures.get(capture_id)
5931
6223
  if graph is not None:
5932
6224
  graph.retain_module_exec(module_exec)
@@ -5951,7 +6243,7 @@ def launch(
5951
6243
  )
5952
6244
  return launch
5953
6245
  else:
5954
- runtime.core.cuda_launch_kernel(
6246
+ runtime.core.wp_cuda_launch_kernel(
5955
6247
  device.context,
5956
6248
  hooks.backward,
5957
6249
  bounds.size,
@@ -5982,7 +6274,7 @@ def launch(
5982
6274
  return launch
5983
6275
  else:
5984
6276
  # launch
5985
- runtime.core.cuda_launch_kernel(
6277
+ runtime.core.wp_cuda_launch_kernel(
5986
6278
  device.context,
5987
6279
  hooks.forward,
5988
6280
  bounds.size,
@@ -6084,7 +6376,7 @@ def synchronize():
6084
6376
 
6085
6377
  if is_cuda_driver_initialized():
6086
6378
  # save the original context to avoid side effects
6087
- saved_context = runtime.core.cuda_context_get_current()
6379
+ saved_context = runtime.core.wp_cuda_context_get_current()
6088
6380
 
6089
6381
  # TODO: only synchronize devices that have outstanding work
6090
6382
  for device in runtime.cuda_devices:
@@ -6093,10 +6385,10 @@ def synchronize():
6093
6385
  if device.is_capturing:
6094
6386
  raise RuntimeError(f"Cannot synchronize device {device} while graph capture is active")
6095
6387
 
6096
- runtime.core.cuda_context_synchronize(device.context)
6388
+ runtime.core.wp_cuda_context_synchronize(device.context)
6097
6389
 
6098
6390
  # restore the original context to avoid side effects
6099
- runtime.core.cuda_context_set_current(saved_context)
6391
+ runtime.core.wp_cuda_context_set_current(saved_context)
6100
6392
 
6101
6393
 
6102
6394
  def synchronize_device(device: Devicelike = None):
@@ -6114,7 +6406,7 @@ def synchronize_device(device: Devicelike = None):
6114
6406
  if device.is_capturing:
6115
6407
  raise RuntimeError(f"Cannot synchronize device {device} while graph capture is active")
6116
6408
 
6117
- runtime.core.cuda_context_synchronize(device.context)
6409
+ runtime.core.wp_cuda_context_synchronize(device.context)
6118
6410
 
6119
6411
 
6120
6412
  def synchronize_stream(stream_or_device: Stream | Devicelike | None = None):
@@ -6132,7 +6424,7 @@ def synchronize_stream(stream_or_device: Stream | Devicelike | None = None):
6132
6424
  else:
6133
6425
  stream = runtime.get_device(stream_or_device).stream
6134
6426
 
6135
- runtime.core.cuda_stream_synchronize(stream.cuda_stream)
6427
+ runtime.core.wp_cuda_stream_synchronize(stream.cuda_stream)
6136
6428
 
6137
6429
 
6138
6430
  def synchronize_event(event: Event):
@@ -6144,20 +6436,25 @@ def synchronize_event(event: Event):
6144
6436
  event: Event to wait for.
6145
6437
  """
6146
6438
 
6147
- runtime.core.cuda_event_synchronize(event.cuda_event)
6439
+ runtime.core.wp_cuda_event_synchronize(event.cuda_event)
6148
6440
 
6149
6441
 
6150
- def force_load(device: Device | str | list[Device] | list[str] | None = None, modules: list[Module] | None = None):
6442
+ def force_load(
6443
+ device: Device | str | list[Device] | list[str] | None = None,
6444
+ modules: list[Module] | None = None,
6445
+ block_dim: int | None = None,
6446
+ ):
6151
6447
  """Force user-defined kernels to be compiled and loaded
6152
6448
 
6153
6449
  Args:
6154
6450
  device: The device or list of devices to load the modules on. If None, load on all devices.
6155
6451
  modules: List of modules to load. If None, load all imported modules.
6452
+ block_dim: The number of threads per block (always 1 for "cpu" devices).
6156
6453
  """
6157
6454
 
6158
6455
  if is_cuda_driver_initialized():
6159
6456
  # save original context to avoid side effects
6160
- saved_context = runtime.core.cuda_context_get_current()
6457
+ saved_context = runtime.core.wp_cuda_context_get_current()
6161
6458
 
6162
6459
  if device is None:
6163
6460
  devices = get_devices()
@@ -6171,22 +6468,26 @@ def force_load(device: Device | str | list[Device] | list[str] | None = None, mo
6171
6468
 
6172
6469
  for d in devices:
6173
6470
  for m in modules:
6174
- m.load(d)
6471
+ m.load(d, block_dim=block_dim)
6175
6472
 
6176
6473
  if is_cuda_available():
6177
6474
  # restore original context to avoid side effects
6178
- runtime.core.cuda_context_set_current(saved_context)
6475
+ runtime.core.wp_cuda_context_set_current(saved_context)
6179
6476
 
6180
6477
 
6181
6478
  def load_module(
6182
- module: Module | types.ModuleType | str | None = None, device: Device | str | None = None, recursive: bool = False
6479
+ module: Module | types.ModuleType | str | None = None,
6480
+ device: Device | str | None = None,
6481
+ recursive: bool = False,
6482
+ block_dim: int | None = None,
6183
6483
  ):
6184
- """Force user-defined module to be compiled and loaded
6484
+ """Force a user-defined module to be compiled and loaded
6185
6485
 
6186
6486
  Args:
6187
6487
  module: The module to load. If None, load the current module.
6188
6488
  device: The device to load the modules on. If None, load on all devices.
6189
6489
  recursive: Whether to load submodules. E.g., if the given module is `warp.sim`, this will also load `warp.sim.model`, `warp.sim.articulation`, etc.
6490
+ block_dim: The number of threads per block (always 1 for "cpu" devices).
6190
6491
 
6191
6492
  Note: A module must be imported before it can be loaded by this function.
6192
6493
  """
@@ -6207,9 +6508,13 @@ def load_module(
6207
6508
  modules = []
6208
6509
 
6209
6510
  # add the given module, if found
6210
- m = user_modules.get(module_name)
6211
- if m is not None:
6212
- modules.append(m)
6511
+ if isinstance(module, Module):
6512
+ # this ensures that we can load "unique" or procedural modules, which aren't added to `user_modules` by name
6513
+ modules.append(module)
6514
+ else:
6515
+ m = user_modules.get(module_name)
6516
+ if m is not None:
6517
+ modules.append(m)
6213
6518
 
6214
6519
  # add submodules, if recursive
6215
6520
  if recursive:
@@ -6218,7 +6523,203 @@ def load_module(
6218
6523
  if name.startswith(prefix):
6219
6524
  modules.append(mod)
6220
6525
 
6221
- force_load(device=device, modules=modules)
6526
+ force_load(device=device, modules=modules, block_dim=block_dim)
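force_load() and load_module() now forward an optional block_dim to Module.load(), so a module can be compiled and loaded for a specific CUDA block size up front. A hedged usage sketch:

    import warp as wp

    wp.init()

    # compile and load the calling module for cuda:0 with 256 threads per block
    # (block_dim is always 1 on CPU devices)
    wp.load_module(device="cuda:0", block_dim=256)

    # or preload every imported module on all devices
    wp.force_load(block_dim=64)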
6527
+
6528
+
6529
+ def _resolve_module(module: Module | types.ModuleType | str) -> Module:
6530
+ """Resolve a module from a string, Module, or types.ModuleType.
6531
+
6532
+ Args:
6533
+ module: The module to resolve.
6534
+
6535
+ Returns:
6536
+ The resolved module.
6537
+
6538
+ Raises:
6539
+ TypeError: If the module argument is not a Module, a types.ModuleType, or a string.
6540
+ """
6541
+
6542
+ if isinstance(module, str):
6543
+ module_object = get_module(module)
6544
+ elif isinstance(module, Module):
6545
+ module_object = module
6546
+ elif isinstance(module, types.ModuleType):
6547
+ module_object = get_module(module.__name__)
6548
+ else:
6549
+ raise TypeError(f"Argument 'module' must be a Module or a string, got {type(module)}")
6550
+
6551
+ return module_object
6552
+
6553
+
6554
+ def compile_aot_module(
6555
+ module: Module | types.ModuleType | str,
6556
+ device: Device | str | list[Device] | list[str] | None = None,
6557
+ arch: int | Iterable[int] | None = None,
6558
+ module_dir: str | os.PathLike | None = None,
6559
+ use_ptx: bool | None = None,
6560
+ strip_hash: bool | None = None,
6561
+ ) -> None:
6562
+ """Compile a module (ahead of time) for a given device.
6563
+
6564
+ Args:
6565
+ module: The module to compile.
6566
+ device: The device or devices to compile the module for. If ``None``,
6567
+ and ``arch`` is not specified, compile the module for the current device.
6568
+ arch: The architecture or architectures to compile the module for. If ``None``,
6569
+ the architecture to compile for will be inferred from the current device.
6570
+ module_dir: The directory to save the source, meta, and compiled files to.
6571
+ If not specified, the module will be compiled to the default cache directory.
6572
+ use_ptx: Whether to compile the module to PTX. This setting is only used
6573
+ when compiling modules for the GPU. If ``None``, Warp will decide an
6574
+ appropriate setting based on the runtime environment.
6575
+ strip_hash: Whether to strip the hash from the module and kernel names.
6576
+ Setting this value to ``True`` or ``False`` will update the module's
6577
+ ``"strip_hash"`` option. If left at ``None``, the current value will
6578
+ be used.
6579
+
6580
+ Warning: Do not enable ``strip_hash`` for modules that contain generic
6581
+ kernels. Generic kernels compile to multiple overloads, and the
6582
+ per-overload hash is required to distinguish them. Stripping the hash
6583
+ in this case will cause the module to fail to compile.
6584
+
6585
+ Raises:
6586
+ TypeError: If the module argument is not a Module, a types.ModuleType, or a string.
6587
+ """
6588
+
6589
+ if is_cuda_driver_initialized():
6590
+ # save original context to avoid side effects
6591
+ saved_context = runtime.core.wp_cuda_context_get_current()
6592
+
6593
+ module_object = _resolve_module(module)
6594
+
6595
+ if strip_hash is not None:
6596
+ module_object.options["strip_hash"] = strip_hash
6597
+
6598
+ if device is None and arch:
6599
+ # User provided no device, but an arch, so we will not compile for the default device
6600
+ devices = []
6601
+ elif isinstance(device, list):
6602
+ devices = [get_device(device_item) for device_item in device]
6603
+ else:
6604
+ devices = [get_device(device)]
6605
+
6606
+ for d in devices:
6607
+ module_object.compile(d, module_dir, use_ptx=use_ptx)
6608
+
6609
+ if arch:
6610
+ if isinstance(arch, str) or not hasattr(arch, "__iter__"):
6611
+ arch = [arch]
6612
+
6613
+ for arch_value in arch:
6614
+ module_object.compile(None, module_dir, output_arch=arch_value, use_ptx=use_ptx)
6615
+
6616
+ if is_cuda_available():
6617
+ # restore original context to avoid side effects
6618
+ runtime.core.wp_cuda_context_set_current(saved_context)
6619
+
6620
+
6621
+ def load_aot_module(
6622
+ module: Module | types.ModuleType | str,
6623
+ device: Device | str | list[Device] | list[str] | None = None,
6624
+ arch: int | None = None,
6625
+ module_dir: str | os.PathLike | None = None,
6626
+ use_ptx: bool | None = None,
6627
+ strip_hash: bool = False,
6628
+ ) -> None:
6629
+ """Load a previously compiled module (ahead of time).
6630
+
6631
+ Args:
6632
+ module: The module to load.
6633
+ device: The device or devices to load the module on. If ``None``,
6634
+ load the module for the current device.
6635
+ arch: The architecture to load the module for on all devices.
6636
+ If ``None``, the architecture to load for will be inferred from the
6637
+ current device.
6638
+ module_dir: The directory to load the module from.
6639
+ If not specified, the module will be loaded from the default cache directory.
6640
+ use_ptx: Whether to load the module from PTX. This setting is only used
6641
+ when loading modules for the GPU. If ``None`` on a CUDA device, Warp will
6642
+ try both PTX and CUBIN (PTX first) and load the first that exists.
6643
+ If neither exists, a ``FileNotFoundError`` is raised listing all
6644
+ attempted paths.
6645
+ strip_hash: Whether to strip the hash from the module and kernel names.
6646
+ Setting this value to ``True`` or ``False`` will update the module's
6647
+ ``"strip_hash"`` option. If left at ``None``, the current value will
6648
+ be used.
6649
+
6650
+ Warning: Do not enable ``strip_hash`` for modules that contain generic
6651
+ kernels. Generic kernels compile to multiple overloads, and the
6652
+ per-overload hash is required to distinguish them. Stripping the hash
6653
+ in this case will cause the module to fail to compile.
6654
+
6655
+ Raises:
6656
+ FileNotFoundError: If no matching binary is found. When ``use_ptx`` is
6657
+ ``None`` on a CUDA device, both PTX and CUBIN candidates are tried
6658
+ before raising.
6659
+ TypeError: If the module argument is not a Module, a types.ModuleType, or a string.
6660
+ """
6661
+
6662
+ if is_cuda_driver_initialized():
6663
+ # save original context to avoid side effects
6664
+ saved_context = runtime.core.wp_cuda_context_get_current()
6665
+
6666
+ if device is None:
6667
+ devices = [runtime.get_device()]
6668
+ elif isinstance(device, list):
6669
+ devices = [get_device(device_item) for device_item in device]
6670
+ else:
6671
+ devices = [get_device(device)]
6672
+
6673
+ module_object = _resolve_module(module)
6674
+
6675
+ if strip_hash is not None:
6676
+ module_object.options["strip_hash"] = strip_hash
6677
+
6678
+ if module_dir is None:
6679
+ module_dir = os.path.join(warp.config.kernel_cache_dir, module_object.get_module_identifier())
6680
+ else:
6681
+ module_dir = os.fspath(module_dir)
6682
+
6683
+ for d in devices:
6684
+ # Identify the files in the cache to load
6685
+ if arch is None:
6686
+ output_arch = module_object.get_compile_arch(d)
6687
+ else:
6688
+ output_arch = arch
6689
+
6690
+ meta_path = os.path.join(module_dir, module_object.get_meta_name())
6691
+
6692
+ # Determine candidate binaries to try
6693
+ tried_paths = []
6694
+ binary_path = None
6695
+ if d.is_cuda and use_ptx is None:
6696
+ candidate_flags = (True, False) # try PTX first, then CUBIN
6697
+ else:
6698
+ candidate_flags = (use_ptx,)
6699
+
6700
+ for candidate_use_ptx in candidate_flags:
6701
+ candidate_path = os.path.join(
6702
+ module_dir, module_object.get_compile_output_name(d, output_arch, candidate_use_ptx)
6703
+ )
6704
+ tried_paths.append(candidate_path)
6705
+ if os.path.exists(candidate_path):
6706
+ binary_path = candidate_path
6707
+ break
6708
+
6709
+ if binary_path is None:
6710
+ raise FileNotFoundError(f"Binary file not found. Tried: {', '.join(tried_paths)}")
6711
+
6712
+ module_object.load(
6713
+ d,
6714
+ block_dim=module_object.options["block_dim"],
6715
+ binary_path=binary_path,
6716
+ output_arch=output_arch,
6717
+ meta_path=meta_path,
6718
+ )
6719
+
6720
+ if is_cuda_available():
6721
+ # restore original context to avoid side effects
6722
+ runtime.core.wp_cuda_context_set_current(saved_context)
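compile_aot_module() and load_aot_module() together form the ahead-of-time path: compile a module's binaries into a chosen directory, then load them later without recompiling. A hedged sketch of the round trip, assuming the two functions are re-exported at the package level (they are defined in warp.context) and that the target GPU matches one of the compiled architectures; the module name and directory are placeholders:

    import warp as wp
    import my_kernels  # hypothetical module containing @wp.kernel definitions

    wp.init()

    # build CUDA binaries for sm_86 and sm_90 only (no device given, so the
    # current device is not compiled for)
    wp.compile_aot_module(my_kernels, arch=[86, 90], module_dir="aot_cache")

    # later, e.g. on a deployment machine: load the prebuilt binaries
    wp.load_aot_module(my_kernels, device="cuda:0", module_dir="aot_cache")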
6222
6723
 
6223
6724
 
6224
6725
  def set_module_options(options: dict[str, Any], module: Any = None):
@@ -6255,6 +6756,40 @@ def get_module_options(module: Any = None) -> dict[str, Any]:
6255
6756
  return get_module(m.__name__).options
6256
6757
 
6257
6758
 
6759
+ def _unregister_capture(device: Device, stream: Stream, graph: Graph):
6760
+ """Unregister a graph capture from the device and runtime.
6761
+
6762
+ This should be called when a graph capture is no longer active, either because it completed or was paused.
6763
+ The graph should only be registered while it is actively capturing.
6764
+
6765
+ Args:
6766
+ device: The CUDA device the graph was being captured on
6767
+ stream: The CUDA stream the graph was being captured on
6768
+ graph: The Graph object that was being captured
6769
+ """
6770
+ del device.captures[stream]
6771
+ del runtime.captures[graph.capture_id]
6772
+
6773
+
6774
+ def _register_capture(device: Device, stream: Stream, graph: Graph, capture_id: int):
6775
+ """Register a graph capture with the device and runtime.
6776
+
6777
+ Makes the graph discoverable through its capture_id so that retain_module_exec() can be called
6778
+ when launching kernels during graph capture. This ensures modules are retained until graph execution completes.
6779
+
6780
+ Args:
6781
+ device: The CUDA device the graph is being captured on
6782
+ stream: The CUDA stream the graph is being captured on
6783
+ graph: The Graph object being captured
6784
+ capture_id: Unique identifier for this graph capture
6785
+ """
6786
+ # add to ongoing captures on the device
6787
+ device.captures[stream] = graph
6788
+
6789
+ # add to lookup table by globally unique capture id
6790
+ runtime.captures[capture_id] = graph
6791
+
6792
+
6258
6793
  def capture_begin(
6259
6794
  device: Devicelike = None,
6260
6795
  stream: Stream | None = None,
@@ -6314,17 +6849,13 @@ def capture_begin(
6314
6849
  if force_module_load:
6315
6850
  force_load(device)
6316
6851
 
6317
- if not runtime.core.cuda_graph_begin_capture(device.context, stream.cuda_stream, int(external)):
6852
+ if not runtime.core.wp_cuda_graph_begin_capture(device.context, stream.cuda_stream, int(external)):
6318
6853
  raise RuntimeError(runtime.get_error_string())
6319
6854
 
6320
- capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
6855
+ capture_id = runtime.core.wp_cuda_stream_get_capture_id(stream.cuda_stream)
6321
6856
  graph = Graph(device, capture_id)
6322
6857
 
6323
- # add to ongoing captures on the device
6324
- device.captures[stream] = graph
6325
-
6326
- # add to lookup table by globally unique capture id
6327
- runtime.captures[capture_id] = graph
6858
+ _register_capture(device, stream, graph, capture_id)
6328
6859
 
6329
6860
 
6330
6861
  def capture_end(device: Devicelike = None, stream: Stream | None = None) -> Graph:
@@ -6352,12 +6883,11 @@ def capture_end(device: Devicelike = None, stream: Stream | None = None) -> Grap
6352
6883
  if graph is None:
6353
6884
  raise RuntimeError("Graph capture is not active on this stream")
6354
6885
 
6355
- del device.captures[stream]
6356
- del runtime.captures[graph.capture_id]
6886
+ _unregister_capture(device, stream, graph)
6357
6887
 
6358
6888
  # get the graph executable
6359
6889
  g = ctypes.c_void_p()
6360
- result = runtime.core.cuda_graph_end_capture(device.context, stream.cuda_stream, ctypes.byref(g))
6890
+ result = runtime.core.wp_cuda_graph_end_capture(device.context, stream.cuda_stream, ctypes.byref(g))
6361
6891
 
6362
6892
  if not result:
6363
6893
  # A concrete error should've already been reported, so we don't need to go into details here
@@ -6378,7 +6908,7 @@ def capture_debug_dot_print(graph: Graph, path: str, verbose: bool = False):
6378
6908
  path: Path to save the DOT file
6379
6909
  verbose: Whether to include additional debug information in the output
6380
6910
  """
6381
- if not runtime.core.capture_debug_dot_print(graph.graph, path.encode(), 0 if verbose else 1):
6911
+ if not runtime.core.wp_capture_debug_dot_print(graph.graph, path.encode(), 0 if verbose else 1):
6382
6912
  raise RuntimeError(f"Graph debug dot print error: {runtime.get_error_string()}")
6383
6913
 
6384
6914
 
@@ -6393,7 +6923,7 @@ def assert_conditional_graph_support():
6393
6923
  raise RuntimeError("Conditional graph nodes require CUDA driver 12.4+")
6394
6924
 
6395
6925
 
6396
- def capture_pause(device: Devicelike = None, stream: Stream | None = None) -> ctypes.c_void_p:
6926
+ def capture_pause(device: Devicelike = None, stream: Stream | None = None) -> Graph:
6397
6927
  if stream is not None:
6398
6928
  device = stream.device
6399
6929
  else:
@@ -6402,14 +6932,24 @@ def capture_pause(device: Devicelike = None, stream: Stream | None = None) -> ct
6402
6932
  raise RuntimeError("Must be a CUDA device")
6403
6933
  stream = device.stream
6404
6934
 
6405
- graph = ctypes.c_void_p()
6406
- if not runtime.core.cuda_graph_pause_capture(device.context, stream.cuda_stream, ctypes.byref(graph)):
6935
+ # get the graph being captured
6936
+ graph = device.captures.get(stream)
6937
+
6938
+ if graph is None:
6939
+ raise RuntimeError("Graph capture is not active on this stream")
6940
+
6941
+ _unregister_capture(device, stream, graph)
6942
+
6943
+ g = ctypes.c_void_p()
6944
+ if not runtime.core.wp_cuda_graph_pause_capture(device.context, stream.cuda_stream, ctypes.byref(g)):
6407
6945
  raise RuntimeError(runtime.get_error_string())
6408
6946
 
6947
+ graph.graph = g
6948
+
6409
6949
  return graph
6410
6950
 
6411
6951
 
6412
- def capture_resume(graph: ctypes.c_void_p, device: Devicelike = None, stream: Stream | None = None):
6952
+ def capture_resume(graph: Graph, device: Devicelike = None, stream: Stream | None = None):
6413
6953
  if stream is not None:
6414
6954
  device = stream.device
6415
6955
  else:
@@ -6418,9 +6958,14 @@ def capture_resume(graph: ctypes.c_void_p, device: Devicelike = None, stream: St
6418
6958
  raise RuntimeError("Must be a CUDA device")
6419
6959
  stream = device.stream
6420
6960
 
6421
- if not runtime.core.cuda_graph_resume_capture(device.context, stream.cuda_stream, graph):
6961
+ if not runtime.core.wp_cuda_graph_resume_capture(device.context, stream.cuda_stream, graph.graph):
6422
6962
  raise RuntimeError(runtime.get_error_string())
6423
6963
 
6964
+ capture_id = runtime.core.wp_cuda_stream_get_capture_id(stream.cuda_stream)
6965
+ graph.capture_id = capture_id
6966
+
6967
+ _register_capture(device, stream, graph, capture_id)
6968
+
6424
6969
 
6425
6970
  # reusable pinned readback buffer for conditions
6426
6971
  condition_host = None
@@ -6499,15 +7044,13 @@ def capture_if(
6499
7044
 
6500
7045
  return
6501
7046
 
6502
- graph.has_conditional = True
6503
-
6504
7047
  # ensure conditional graph nodes are supported
6505
7048
  assert_conditional_graph_support()
6506
7049
 
6507
7050
  # insert conditional node
6508
7051
  graph_on_true = ctypes.c_void_p()
6509
7052
  graph_on_false = ctypes.c_void_p()
6510
- if not runtime.core.cuda_graph_insert_if_else(
7053
+ if not runtime.core.wp_cuda_graph_insert_if_else(
6511
7054
  device.context,
6512
7055
  stream.cuda_stream,
6513
7056
  ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
@@ -6518,18 +7061,19 @@ def capture_if(
6518
7061
 
6519
7062
  # pause capturing parent graph
6520
7063
  main_graph = capture_pause(stream=stream)
7064
+ # store the pointer to the cuda graph to restore it later
7065
+ main_graph_ptr = main_graph.graph
6521
7066
 
6522
7067
  # capture if-graph
6523
7068
  if on_true is not None:
6524
- capture_resume(graph_on_true, stream=stream)
7069
+ # temporarily repurpose the main_graph python object such that all dependencies
7070
+ # added through retain_module_exec() end up in the correct python graph object
7071
+ main_graph.graph = graph_on_true
7072
+ capture_resume(main_graph, stream=stream)
6525
7073
  if isinstance(on_true, Callable):
6526
7074
  on_true(**kwargs)
6527
7075
  elif isinstance(on_true, Graph):
6528
- if on_true.has_conditional:
6529
- raise RuntimeError(
6530
- "The on_true graph contains conditional nodes, which are not allowed in child graphs"
6531
- )
6532
- if not runtime.core.cuda_graph_insert_child_graph(
7076
+ if not runtime.core.wp_cuda_graph_insert_child_graph(
6533
7077
  device.context,
6534
7078
  stream.cuda_stream,
6535
7079
  on_true.graph,
@@ -6539,17 +7083,20 @@ def capture_if(
6539
7083
  raise TypeError("on_true must be a Callable or a Graph")
6540
7084
  capture_pause(stream=stream)
6541
7085
 
7086
+ # check the if-body graph
7087
+ if not runtime.core.wp_cuda_graph_check_conditional_body(graph_on_true):
7088
+ raise RuntimeError(runtime.get_error_string())
7089
+
6542
7090
  # capture else-graph
6543
7091
  if on_false is not None:
6544
- capture_resume(graph_on_false, stream=stream)
7092
+ # temporarily repurpose the main_graph python object such that all dependencies
7093
+ # added through retain_module_exec() end up in the correct python graph object
7094
+ main_graph.graph = graph_on_false
7095
+ capture_resume(main_graph, stream=stream)
6545
7096
  if isinstance(on_false, Callable):
6546
7097
  on_false(**kwargs)
6547
7098
  elif isinstance(on_false, Graph):
6548
- if on_false.has_conditional:
6549
- raise RuntimeError(
6550
- "The on_false graph contains conditional nodes, which are not allowed in child graphs"
6551
- )
6552
- if not runtime.core.cuda_graph_insert_child_graph(
7099
+ if not runtime.core.wp_cuda_graph_insert_child_graph(
6553
7100
  device.context,
6554
7101
  stream.cuda_stream,
6555
7102
  on_false.graph,
@@ -6559,6 +7106,13 @@ def capture_if(
6559
7106
  raise TypeError("on_false must be a Callable or a Graph")
6560
7107
  capture_pause(stream=stream)
6561
7108
 
7109
+ # check the else-body graph
7110
+ if not runtime.core.wp_cuda_graph_check_conditional_body(graph_on_false):
7111
+ raise RuntimeError(runtime.get_error_string())
7112
+
7113
+ # restore the main graph to its original state
7114
+ main_graph.graph = main_graph_ptr
7115
+
6562
7116
  # resume capturing parent graph
6563
7117
  capture_resume(main_graph, stream=stream)
6564
7118
 
@@ -6622,15 +7176,13 @@ def capture_while(condition: warp.array(dtype=int), while_body: Callable | Graph
6622
7176
 
6623
7177
  return
6624
7178
 
6625
- graph.has_conditional = True
6626
-
6627
7179
  # ensure conditional graph nodes are supported
6628
7180
  assert_conditional_graph_support()
6629
7181
 
6630
7182
  # insert conditional while-node
6631
7183
  body_graph = ctypes.c_void_p()
6632
7184
  cond_handle = ctypes.c_uint64()
6633
- if not runtime.core.cuda_graph_insert_while(
7185
+ if not runtime.core.wp_cuda_graph_insert_while(
6634
7186
  device.context,
6635
7187
  stream.cuda_stream,
6636
7188
  ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
@@ -6641,26 +7193,29 @@ def capture_while(condition: warp.array(dtype=int), while_body: Callable | Graph
6641
7193
 
6642
7194
  # pause capturing parent graph and start capturing child graph
6643
7195
  main_graph = capture_pause(stream=stream)
6644
- capture_resume(body_graph, stream=stream)
7196
+ # store the pointer to the cuda graph to restore it later
7197
+ main_graph_ptr = main_graph.graph
7198
+
7199
+ # temporarily repurpose the main_graph python object such that all dependencies
7200
+ # added through retain_module_exec() end up in the correct python graph object
7201
+ main_graph.graph = body_graph
7202
+ capture_resume(main_graph, stream=stream)
6645
7203
 
6646
7204
  # capture while-body
6647
7205
  if isinstance(while_body, Callable):
6648
7206
  while_body(**kwargs)
6649
7207
  elif isinstance(while_body, Graph):
6650
- if while_body.has_conditional:
6651
- raise RuntimeError("The body graph contains conditional nodes, which are not allowed in child graphs")
6652
-
6653
- if not runtime.core.cuda_graph_insert_child_graph(
7208
+ if not runtime.core.wp_cuda_graph_insert_child_graph(
6654
7209
  device.context,
6655
7210
  stream.cuda_stream,
6656
7211
  while_body.graph,
6657
7212
  ):
6658
7213
  raise RuntimeError(runtime.get_error_string())
6659
7214
  else:
6660
- raise RuntimeError(runtime.get_error_string())
7215
+ raise TypeError("while_body must be a callable or a graph")
6661
7216
 
6662
7217
  # update condition
6663
- if not runtime.core.cuda_graph_set_condition(
7218
+ if not runtime.core.wp_cuda_graph_set_condition(
6664
7219
  device.context,
6665
7220
  stream.cuda_stream,
6666
7221
  ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
@@ -6668,8 +7223,15 @@ def capture_while(condition: warp.array(dtype=int), while_body: Callable | Graph
6668
7223
  ):
6669
7224
  raise RuntimeError(runtime.get_error_string())
6670
7225
 
6671
- # stop capturing child graph and resume capturing parent graph
7226
+ # stop capturing while-body
6672
7227
  capture_pause(stream=stream)
7228
+
7229
+ # check the while-body graph
7230
+ if not runtime.core.wp_cuda_graph_check_conditional_body(body_graph):
7231
+ raise RuntimeError(runtime.get_error_string())
7232
+
7233
+ # restore the main graph to its original state
7234
+ main_graph.graph = main_graph_ptr
6673
7235
  capture_resume(main_graph, stream=stream)
6674
7236
 
6675
7237
 
@@ -6691,12 +7253,14 @@ def capture_launch(graph: Graph, stream: Stream | None = None):
6691
7253
 
6692
7254
  if graph.graph_exec is None:
6693
7255
  g = ctypes.c_void_p()
6694
- result = runtime.core.cuda_graph_create_exec(graph.device.context, graph.graph, ctypes.byref(g))
7256
+ result = runtime.core.wp_cuda_graph_create_exec(
7257
+ graph.device.context, stream.cuda_stream, graph.graph, ctypes.byref(g)
7258
+ )
6695
7259
  if not result:
6696
7260
  raise RuntimeError(f"Graph creation error: {runtime.get_error_string()}")
6697
7261
  graph.graph_exec = g
6698
7262
 
6699
- if not runtime.core.cuda_graph_launch(graph.graph_exec, stream.cuda_stream):
7263
+ if not runtime.core.wp_cuda_graph_launch(graph.graph_exec, stream.cuda_stream):
6700
7264
  raise RuntimeError(f"Graph launch error: {runtime.get_error_string()}")
6701
7265
 
6702
7266
 
@@ -6807,24 +7371,24 @@ def copy(
6807
7371
  if dest.device.is_cuda:
6808
7372
  if src.device.is_cuda:
6809
7373
  if src.device == dest.device:
6810
- result = runtime.core.memcpy_d2d(
7374
+ result = runtime.core.wp_memcpy_d2d(
6811
7375
  dest.device.context, dst_ptr, src_ptr, bytes_to_copy, stream.cuda_stream
6812
7376
  )
6813
7377
  else:
6814
- result = runtime.core.memcpy_p2p(
7378
+ result = runtime.core.wp_memcpy_p2p(
6815
7379
  dest.device.context, dst_ptr, src.device.context, src_ptr, bytes_to_copy, stream.cuda_stream
6816
7380
  )
6817
7381
  else:
6818
- result = runtime.core.memcpy_h2d(
7382
+ result = runtime.core.wp_memcpy_h2d(
6819
7383
  dest.device.context, dst_ptr, src_ptr, bytes_to_copy, stream.cuda_stream
6820
7384
  )
6821
7385
  else:
6822
7386
  if src.device.is_cuda:
6823
- result = runtime.core.memcpy_d2h(
7387
+ result = runtime.core.wp_memcpy_d2h(
6824
7388
  src.device.context, dst_ptr, src_ptr, bytes_to_copy, stream.cuda_stream
6825
7389
  )
6826
7390
  else:
6827
- result = runtime.core.memcpy_h2h(dst_ptr, src_ptr, bytes_to_copy)
7391
+ result = runtime.core.wp_memcpy_h2h(dst_ptr, src_ptr, bytes_to_copy)
6828
7392
 
6829
7393
  if not result:
6830
7394
  raise RuntimeError(f"Warp copy error: {runtime.get_error_string()}")
@@ -6859,17 +7423,17 @@ def copy(
6859
7423
  # This work involves a kernel launch, so it must run on the destination device.
6860
7424
  # If the copy stream is different, we need to synchronize it.
6861
7425
  if stream == dest.device.stream:
6862
- result = runtime.core.array_copy_device(
7426
+ result = runtime.core.wp_array_copy_device(
6863
7427
  dest.device.context, dst_ptr, src_ptr, dst_type, src_type, src_elem_size
6864
7428
  )
6865
7429
  else:
6866
7430
  dest.device.stream.wait_stream(stream)
6867
- result = runtime.core.array_copy_device(
7431
+ result = runtime.core.wp_array_copy_device(
6868
7432
  dest.device.context, dst_ptr, src_ptr, dst_type, src_type, src_elem_size
6869
7433
  )
6870
7434
  stream.wait_stream(dest.device.stream)
6871
7435
  else:
6872
- result = runtime.core.array_copy_host(dst_ptr, src_ptr, dst_type, src_type, src_elem_size)
7436
+ result = runtime.core.wp_array_copy_host(dst_ptr, src_ptr, dst_type, src_type, src_elem_size)
6873
7437
 
6874
7438
  if not result:
6875
7439
  raise RuntimeError(f"Warp copy error: {runtime.get_error_string()}")
@@ -7174,7 +7738,6 @@ def export_stubs(file): # pragma: no cover
7174
7738
  """,
7175
7739
  file=file,
7176
7740
  )
7177
-
7178
7741
  print(
7179
7742
  "# Autogenerated file, do not edit, this file provides stubs for builtins autocomplete in VSCode, PyCharm, etc",
7180
7743
  file=file,