warp-lang 1.8.1-py3-none-manylinux_2_34_aarch64.whl → 1.9.1-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (141)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +1904 -114
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +331 -101
  7. warp/builtins.py +1244 -160
  8. warp/codegen.py +317 -206
  9. warp/config.py +1 -1
  10. warp/context.py +1465 -789
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_kernel.py +2 -1
  18. warp/fabric.py +1 -1
  19. warp/fem/cache.py +27 -19
  20. warp/fem/domain.py +2 -2
  21. warp/fem/field/nodal_field.py +2 -2
  22. warp/fem/field/virtual.py +264 -166
  23. warp/fem/geometry/geometry.py +5 -5
  24. warp/fem/integrate.py +129 -51
  25. warp/fem/space/restriction.py +4 -0
  26. warp/fem/space/shape/tet_shape_function.py +3 -10
  27. warp/jax_experimental/custom_call.py +25 -2
  28. warp/jax_experimental/ffi.py +22 -1
  29. warp/jax_experimental/xla_ffi.py +16 -7
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +99 -4
  32. warp/native/builtin.h +86 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +8 -2
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +41 -10
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +2 -2
  48. warp/native/mat.h +1910 -116
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +4 -2
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +331 -14
  59. warp/native/range.h +7 -1
  60. warp/native/reduce.cpp +10 -10
  61. warp/native/reduce.cu +13 -14
  62. warp/native/runlength_encode.cpp +2 -2
  63. warp/native/runlength_encode.cu +5 -5
  64. warp/native/scan.cpp +3 -3
  65. warp/native/scan.cu +4 -4
  66. warp/native/sort.cpp +10 -10
  67. warp/native/sort.cu +40 -31
  68. warp/native/sort.h +2 -0
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +13 -13
  71. warp/native/spatial.h +366 -17
  72. warp/native/temp_buffer.h +2 -2
  73. warp/native/tile.h +471 -82
  74. warp/native/vec.h +328 -14
  75. warp/native/volume.cpp +54 -54
  76. warp/native/volume.cu +1 -1
  77. warp/native/volume.h +2 -1
  78. warp/native/volume_builder.cu +30 -37
  79. warp/native/warp.cpp +150 -149
  80. warp/native/warp.cu +377 -216
  81. warp/native/warp.h +227 -226
  82. warp/optim/linear.py +736 -271
  83. warp/render/imgui_manager.py +289 -0
  84. warp/render/render_opengl.py +99 -18
  85. warp/render/render_usd.py +1 -0
  86. warp/sim/graph_coloring.py +2 -2
  87. warp/sparse.py +558 -175
  88. warp/tests/aux_test_module_aot.py +7 -0
  89. warp/tests/cuda/test_async.py +3 -3
  90. warp/tests/cuda/test_conditional_captures.py +101 -0
  91. warp/tests/geometry/test_hash_grid.py +38 -0
  92. warp/tests/geometry/test_marching_cubes.py +233 -12
  93. warp/tests/interop/test_jax.py +608 -28
  94. warp/tests/sim/test_coloring.py +6 -6
  95. warp/tests/test_array.py +58 -5
  96. warp/tests/test_codegen.py +4 -3
  97. warp/tests/test_context.py +8 -15
  98. warp/tests/test_enum.py +136 -0
  99. warp/tests/test_examples.py +2 -2
  100. warp/tests/test_fem.py +49 -6
  101. warp/tests/test_fixedarray.py +229 -0
  102. warp/tests/test_func.py +18 -15
  103. warp/tests/test_future_annotations.py +7 -5
  104. warp/tests/test_linear_solvers.py +30 -0
  105. warp/tests/test_map.py +15 -1
  106. warp/tests/test_mat.py +1518 -378
  107. warp/tests/test_mat_assign_copy.py +178 -0
  108. warp/tests/test_mat_constructors.py +574 -0
  109. warp/tests/test_module_aot.py +287 -0
  110. warp/tests/test_print.py +69 -0
  111. warp/tests/test_quat.py +140 -34
  112. warp/tests/test_quat_assign_copy.py +145 -0
  113. warp/tests/test_reload.py +2 -1
  114. warp/tests/test_sparse.py +71 -0
  115. warp/tests/test_spatial.py +140 -34
  116. warp/tests/test_spatial_assign_copy.py +160 -0
  117. warp/tests/test_struct.py +43 -3
  118. warp/tests/test_tuple.py +96 -0
  119. warp/tests/test_types.py +61 -20
  120. warp/tests/test_vec.py +179 -34
  121. warp/tests/test_vec_assign_copy.py +143 -0
  122. warp/tests/tile/test_tile.py +245 -18
  123. warp/tests/tile/test_tile_cholesky.py +605 -0
  124. warp/tests/tile/test_tile_load.py +169 -0
  125. warp/tests/tile/test_tile_mathdx.py +2 -558
  126. warp/tests/tile/test_tile_matmul.py +1 -1
  127. warp/tests/tile/test_tile_mlp.py +1 -1
  128. warp/tests/tile/test_tile_shared_memory.py +5 -5
  129. warp/tests/unittest_suites.py +6 -0
  130. warp/tests/walkthrough_debug.py +1 -1
  131. warp/thirdparty/unittest_parallel.py +108 -9
  132. warp/types.py +571 -267
  133. warp/utils.py +68 -86
  134. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/METADATA +29 -69
  135. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/RECORD +138 -128
  136. warp/native/marching.cpp +0 -19
  137. warp/native/marching.cu +0 -514
  138. warp/native/marching.h +0 -19
  139. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/WHEEL +0 -0
  140. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/licenses/LICENSE.md +0 -0
  141. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -26,13 +26,28 @@ import json
26
26
  import operator
27
27
  import os
28
28
  import platform
29
+ import shutil
29
30
  import sys
30
31
  import types
31
32
  import typing
32
33
  import weakref
33
34
  from copy import copy as shallowcopy
34
35
  from pathlib import Path
35
- from typing import Any, Callable, Dict, List, Literal, Mapping, Sequence, Tuple, TypeVar, Union, get_args, get_origin
36
+ from typing import (
37
+ Any,
38
+ Callable,
39
+ Dict,
40
+ Iterable,
41
+ List,
42
+ Literal,
43
+ Mapping,
44
+ Sequence,
45
+ Tuple,
46
+ TypeVar,
47
+ Union,
48
+ get_args,
49
+ get_origin,
50
+ )
36
51
 
37
52
  import numpy as np
38
53
 
@@ -327,39 +342,25 @@ class Function:
327
342
  warp.codegen.apply_defaults(bound_args, self.defaults)
328
343
 
329
344
  arguments = tuple(bound_args.arguments.values())
330
-
331
- # Store the last runtime error we encountered from a function execution
332
- last_execution_error = None
345
+ arg_types = tuple(warp.codegen.get_arg_type(x) for x in arguments)
333
346
 
334
347
  # try and find a matching overload
335
348
  for overload in self.user_overloads.values():
336
349
  if len(overload.input_types) != len(arguments):
337
350
  continue
351
+
352
+ if not warp.codegen.func_match_args(overload, arg_types, {}):
353
+ continue
354
+
338
355
  template_types = list(overload.input_types.values())
339
356
  arg_names = list(overload.input_types.keys())
340
- try:
341
- # attempt to unify argument types with function template types
342
- warp.types.infer_argument_types(arguments, template_types, arg_names)
343
- return overload.func(*arguments)
344
- except Exception as e:
345
- # The function was callable but threw an error during its execution.
346
- # This might be the intended overload, but it failed, or it might be the wrong overload.
347
- # We save this specific error and continue, just in case another overload later in the
348
- # list is a better match and doesn't fail.
349
- last_execution_error = e
350
- continue
351
357
 
352
- if last_execution_error:
353
- # Raise a new, more contextual RuntimeError, but link it to the
354
- # original error that was caught. This preserves the original
355
- # traceback and error type for easier debugging.
356
- raise RuntimeError(
357
- f"Error calling function '{self.key}'. No version succeeded. "
358
- f"See above for the error from the last version that was tried."
359
- ) from last_execution_error
360
- else:
361
- # We got here without ever calling an overload.func
362
- raise RuntimeError(f"Error calling function '{self.key}', no overload found for arguments {args}")
358
+ # attempt to unify argument types with function template types
359
+ warp.types.infer_argument_types(arguments, template_types, arg_names)
360
+ return overload.func(*arguments)
361
+
362
+ # We got here without ever calling an overload.func
363
+ raise RuntimeError(f"Error calling function '{self.key}', no overload found for arguments {args}")
363
364
 
364
365
  # user-defined function with no overloads
365
366
  if self.func is None:
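
Note on the hunk above: overload dispatch for user functions called from Python scope now pre-filters candidates by matching argument types (`warp.codegen.func_match_args`) instead of invoking each overload and catching whatever it raises. The sketch below is illustrative only (the overload definitions are not from the package); it assumes the usual `@wp.func` overloading behavior.

    import warp as wp

    @wp.func
    def scale(x: float, s: float) -> float:
        return x * s

    @wp.func
    def scale(x: wp.vec3, s: float) -> wp.vec3:
        return x * s

    # Calling a @wp.func from Python goes through Function.__call__ above:
    # overloads whose signatures don't match the argument types are skipped up
    # front, so an exception raised inside the matching overload propagates
    # directly instead of being re-wrapped as "no version succeeded".
    print(scale(2.0, 3.0))                      # float overload
    print(scale(wp.vec3(1.0, 2.0, 3.0), 0.5))   # vec3 overload
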
@@ -385,7 +386,7 @@ class Function:
385
386
  def mangle(self) -> str:
386
387
  """Build a mangled name for the C-exported function, e.g.: `builtin_normalize_vec3()`."""
387
388
 
388
- name = "builtin_" + self.key
389
+ name = "wp_builtin_" + self.key
389
390
 
390
391
  # Runtime arguments that are to be passed to the function, not its template signature.
391
392
  if self.export_func is not None:
@@ -475,6 +476,25 @@ class Function:
475
476
  # failed to find overload
476
477
  return None
477
478
 
479
+ def build(self, builder: ModuleBuilder | None):
480
+ self.adj.build(builder)
481
+
482
+ # complete the function return type after we have analyzed it (inferred from return statement in ast)
483
+ if not self.value_func:
484
+
485
+ def wrap(adj):
486
+ def value_type(arg_types, arg_values):
487
+ if adj.return_var is None or len(adj.return_var) == 0:
488
+ return None
489
+ if len(adj.return_var) == 1:
490
+ return adj.return_var[0].type
491
+ else:
492
+ return [v.type for v in adj.return_var]
493
+
494
+ return value_type
495
+
496
+ self.value_func = wrap(self.adj)
497
+
478
498
  def __repr__(self):
479
499
  inputs_str = ", ".join([f"{k}: {warp.types.type_repr(v)}" for k, v in self.input_types.items()])
480
500
  return f"<Function {self.key}({inputs_str})>"
@@ -807,14 +827,17 @@ class Kernel:
807
827
  sig = warp.types.get_signature(arg_types, func_name=self.key)
808
828
  return self.overloads.get(sig)
809
829
 
810
- def get_mangled_name(self):
811
- if self.hash is None:
812
- raise RuntimeError(f"Missing hash for kernel {self.key} in module {self.module.name}")
830
+ def get_mangled_name(self) -> str:
831
+ if self.module.options["strip_hash"]:
832
+ return self.key
833
+ else:
834
+ if self.hash is None:
835
+ raise RuntimeError(f"Missing hash for kernel {self.key} in module {self.module.name}")
813
836
 
814
- # TODO: allow customizing the number of hash characters used
815
- hash_suffix = self.hash.hex()[:8]
837
+ # TODO: allow customizing the number of hash characters used
838
+ hash_suffix = self.hash.hex()[:8]
816
839
 
817
- return f"{self.key}_{hash_suffix}"
840
+ return f"{self.key}_{hash_suffix}"
818
841
 
819
842
  def __call__(self, *args, **kwargs):
820
843
  # we implement this function only to ensure Kernel is a callable object
@@ -1597,6 +1620,9 @@ class ModuleHasher:
1597
1620
  # line directives, e.g. for Nsight Compute
1598
1621
  ch.update(bytes(ctypes.c_int(warp.config.line_directives)))
1599
1622
 
1623
+ # whether to use `assign_copy` instead of `assign_inplace`
1624
+ ch.update(bytes(ctypes.c_int(warp.config.enable_vector_component_overwrites)))
1625
+
1600
1626
  # build config
1601
1627
  ch.update(bytes(warp.config.mode, "utf-8"))
1602
1628
 
@@ -1784,6 +1810,9 @@ class ModuleBuilder:
1784
1810
  self.structs[struct] = None
1785
1811
 
1786
1812
  def build_kernel(self, kernel):
1813
+ if kernel.options.get("enable_backward", True):
1814
+ kernel.adj.used_by_backward_kernel = True
1815
+
1787
1816
  kernel.adj.build(self)
1788
1817
 
1789
1818
  if kernel.adj.return_var is not None:
@@ -1794,23 +1823,7 @@ class ModuleBuilder:
1794
1823
  if func in self.functions:
1795
1824
  return
1796
1825
  else:
1797
- func.adj.build(self)
1798
-
1799
- # complete the function return type after we have analyzed it (inferred from return statement in ast)
1800
- if not func.value_func:
1801
-
1802
- def wrap(adj):
1803
- def value_type(arg_types, arg_values):
1804
- if adj.return_var is None or len(adj.return_var) == 0:
1805
- return None
1806
- if len(adj.return_var) == 1:
1807
- return adj.return_var[0].type
1808
- else:
1809
- return [v.type for v in adj.return_var]
1810
-
1811
- return value_type
1812
-
1813
- func.value_func = wrap(func.adj)
1826
+ func.build(self)
1814
1827
 
1815
1828
  # use dict to preserve import order
1816
1829
  self.functions[func] = None
@@ -1830,10 +1843,11 @@ class ModuleBuilder:
1830
1843
  source = ""
1831
1844
 
1832
1845
  # code-gen LTO forward declarations
1833
- source += 'extern "C" {\n'
1834
- for fwd in self.ltoirs_decl.values():
1835
- source += fwd + "\n"
1836
- source += "}\n"
1846
+ if len(self.ltoirs_decl) > 0:
1847
+ source += 'extern "C" {\n'
1848
+ for fwd in self.ltoirs_decl.values():
1849
+ source += fwd + "\n"
1850
+ source += "}\n"
1837
1851
 
1838
1852
  # code-gen structs
1839
1853
  visited_structs = set()
@@ -1898,9 +1912,9 @@ class ModuleExec:
1898
1912
  if self.device.is_cuda:
1899
1913
  # use CUDA context guard to avoid side effects during garbage collection
1900
1914
  with self.device.context_guard:
1901
- runtime.core.cuda_unload_module(self.device.context, self.handle)
1915
+ runtime.core.wp_cuda_unload_module(self.device.context, self.handle)
1902
1916
  else:
1903
- runtime.llvm.unload_obj(self.handle.encode("utf-8"))
1917
+ runtime.llvm.wp_unload_obj(self.handle.encode("utf-8"))
1904
1918
 
1905
1919
  # lookup and cache kernel entry points
1906
1920
  def get_kernel_hooks(self, kernel) -> KernelHooks:
@@ -1918,13 +1932,13 @@ class ModuleExec:
1918
1932
 
1919
1933
  if self.device.is_cuda:
1920
1934
  forward_name = name + "_cuda_kernel_forward"
1921
- forward_kernel = runtime.core.cuda_get_kernel(
1935
+ forward_kernel = runtime.core.wp_cuda_get_kernel(
1922
1936
  self.device.context, self.handle, forward_name.encode("utf-8")
1923
1937
  )
1924
1938
 
1925
1939
  if options["enable_backward"]:
1926
1940
  backward_name = name + "_cuda_kernel_backward"
1927
- backward_kernel = runtime.core.cuda_get_kernel(
1941
+ backward_kernel = runtime.core.wp_cuda_get_kernel(
1928
1942
  self.device.context, self.handle, backward_name.encode("utf-8")
1929
1943
  )
1930
1944
  else:
@@ -1935,14 +1949,14 @@ class ModuleExec:
1935
1949
  backward_smem_bytes = self.meta[backward_name + "_smem_bytes"] if options["enable_backward"] else 0
1936
1950
 
1937
1951
  # configure kernels maximum shared memory size
1938
- max_smem_bytes = runtime.core.cuda_get_max_shared_memory(self.device.context)
1952
+ max_smem_bytes = runtime.core.wp_cuda_get_max_shared_memory(self.device.context)
1939
1953
 
1940
- if not runtime.core.cuda_configure_kernel_shared_memory(forward_kernel, forward_smem_bytes):
1954
+ if not runtime.core.wp_cuda_configure_kernel_shared_memory(forward_kernel, forward_smem_bytes):
1941
1955
  print(
1942
1956
  f"Warning: Failed to configure kernel dynamic shared memory for this device, tried to configure {forward_name} kernel for {forward_smem_bytes} bytes, but maximum available is {max_smem_bytes}"
1943
1957
  )
1944
1958
 
1945
- if options["enable_backward"] and not runtime.core.cuda_configure_kernel_shared_memory(
1959
+ if options["enable_backward"] and not runtime.core.wp_cuda_configure_kernel_shared_memory(
1946
1960
  backward_kernel, backward_smem_bytes
1947
1961
  ):
1948
1962
  print(
@@ -1954,12 +1968,13 @@ class ModuleExec:
1954
1968
  else:
1955
1969
  func = ctypes.CFUNCTYPE(None)
1956
1970
  forward = (
1957
- func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_forward").encode("utf-8"))) or None
1971
+ func(runtime.llvm.wp_lookup(self.handle.encode("utf-8"), (name + "_cpu_forward").encode("utf-8")))
1972
+ or None
1958
1973
  )
1959
1974
 
1960
1975
  if options["enable_backward"]:
1961
1976
  backward = (
1962
- func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8")))
1977
+ func(runtime.llvm.wp_lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8")))
1963
1978
  or None
1964
1979
  )
1965
1980
  else:
@@ -1971,6 +1986,25 @@ class ModuleExec:
1971
1986
  return hooks
1972
1987
 
1973
1988
 
1989
+ def _check_and_raise_long_path_error(e: FileNotFoundError):
1990
+ """Check if the error is due to a Windows long path and provide work-around instructions if it is.
1991
+
1992
+ ``FileNotFoundError.filename`` may legitimately be ``None`` when the originating
1993
+ API does not supply a path. Guard against that to avoid masking the original
1994
+ error with a ``TypeError``.
1995
+ """
1996
+ filename = getattr(e, "filename", None)
1997
+
1998
+ # Fast-exit when this is clearly not a legacy-path limitation:
1999
+ if filename is None or len(filename) < 260 or os.name != "nt" or filename.startswith("\\\\?\\"):
2000
+ raise e
2001
+
2002
+ raise RuntimeError(
2003
+ f"File path '{e.filename}' exceeds 259 characters, long-path support is required for this operation. "
2004
+ "See https://learn.microsoft.com/en-us/windows/win32/fileio/maximum-file-path-limitation for more information."
2005
+ ) from e
2006
+
2007
+
1974
2008
  # -----------------------------------------------------
1975
2009
  # stores all functions and kernels for a Python module
1976
2010
  # creates a hash of the function to use for checking
@@ -2024,6 +2058,7 @@ class Module:
2024
2058
  "mode": None,
2025
2059
  "block_dim": 256,
2026
2060
  "compile_time_trace": warp.config.compile_time_trace,
2061
+ "strip_hash": False,
2027
2062
  }
2028
2063
 
2029
2064
  # Module dependencies are determined by scanning each function
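
Note on the hunk above: the new per-module `"strip_hash"` option (default `False`) drops the content-hash suffix from kernel entry-point names and from the module's cache directory, which supports the ahead-of-time build path exercised by the new `warp/tests/test_module_aot.py`. A hypothetical way to opt a module in, assuming the existing `wp.set_module_options()` API; the kernel is illustrative.

    import warp as wp

    # applies to the calling module; kernels then mangle to e.g. "fill"
    # instead of "fill_0340cd1a", and the cache directory omits the hash
    wp.set_module_options({"strip_hash": True})

    @wp.kernel
    def fill(a: wp.array(dtype=float), value: float):
        i = wp.tid()
        a[i] = value
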
@@ -2170,20 +2205,23 @@ class Module:
2170
2205
  if isinstance(arg.type, warp.codegen.Struct) and arg.type.module is not None:
2171
2206
  add_ref(arg.type.module)
2172
2207
 
2173
- def hash_module(self):
2208
+ def hash_module(self) -> bytes:
2209
+ """Get the hash of the module for the current block_dim.
2210
+
2211
+ This function always creates a new `ModuleHasher` instance and computes the hash.
2212
+ """
2174
2213
  # compute latest hash
2175
2214
  block_dim = self.options["block_dim"]
2176
2215
  self.hashers[block_dim] = ModuleHasher(self)
2177
2216
  return self.hashers[block_dim].get_module_hash()
2178
2217
 
2179
- def load(self, device, block_dim=None) -> ModuleExec | None:
2180
- device = runtime.get_device(device)
2218
+ def get_module_hash(self, block_dim: int | None = None) -> bytes:
2219
+ """Get the hash of the module for the current block_dim.
2181
2220
 
2182
- # update module options if launching with a new block dim
2183
- if block_dim is not None:
2184
- self.options["block_dim"] = block_dim
2185
-
2186
- active_block_dim = self.options["block_dim"]
2221
+ If a hash has not been computed for the current block_dim, it will be computed and cached.
2222
+ """
2223
+ if block_dim is None:
2224
+ block_dim = self.options["block_dim"]
2187
2225
 
2188
2226
  if self.has_unresolved_static_expressions:
2189
2227
  # The module hash currently does not account for unresolved static expressions
@@ -2200,210 +2238,360 @@ class Module:
2200
2238
  self.has_unresolved_static_expressions = False
2201
2239
 
2202
2240
  # compute the hash if needed
2203
- if active_block_dim not in self.hashers:
2204
- self.hashers[active_block_dim] = ModuleHasher(self)
2241
+ if block_dim not in self.hashers:
2242
+ self.hashers[block_dim] = ModuleHasher(self)
2205
2243
 
2206
- # check if executable module is already loaded and not stale
2207
- exec = self.execs.get((device.context, active_block_dim))
2208
- if exec is not None:
2209
- if exec.module_hash == self.hashers[active_block_dim].get_module_hash():
2210
- return exec
2244
+ return self.hashers[block_dim].get_module_hash()
2211
2245
 
2212
- # quietly avoid repeated build attempts to reduce error spew
2213
- if device.context in self.failed_builds:
2214
- return None
2246
+ def _use_ptx(self, device) -> bool:
2247
+ return device.get_cuda_output_format(self.options.get("cuda_output")) == "ptx"
2215
2248
 
2216
- module_name = "wp_" + self.name
2217
- module_hash = self.hashers[active_block_dim].get_module_hash()
2249
+ def get_module_identifier(self) -> str:
2250
+ """Get an abbreviated module name to use for directories and files in the cache.
2218
2251
 
2219
- # use a unique module path using the module short hash
2220
- module_name_short = f"{module_name}_{module_hash.hex()[:7]}"
2221
- module_dir = os.path.join(warp.config.kernel_cache_dir, module_name_short)
2252
+ Depending on the setting of the ``"strip_hash"`` option for this module,
2253
+ the module identifier might include a content-dependent hash as a suffix.
2254
+ """
2255
+ if self.options["strip_hash"]:
2256
+ module_name_short = f"wp_{self.name}"
2257
+ else:
2258
+ module_hash = self.get_module_hash()
2259
+ module_name_short = f"wp_{self.name}_{module_hash.hex()[:7]}"
2222
2260
 
2223
- with warp.ScopedTimer(
2224
- f"Module {self.name} {module_hash.hex()[:7]} load on device '{device}'", active=not warp.config.quiet
2225
- ) as module_load_timer:
2226
- # -----------------------------------------------------------
2227
- # determine output paths
2228
- if device.is_cpu:
2229
- output_name = f"{module_name_short}.o"
2230
- output_arch = None
2261
+ return module_name_short
2231
2262
 
2232
- elif device.is_cuda:
2233
- # determine whether to use PTX or CUBIN
2234
- if device.is_cubin_supported:
2235
- # get user preference specified either per module or globally
2236
- preferred_cuda_output = self.options.get("cuda_output") or warp.config.cuda_output
2237
- if preferred_cuda_output is not None:
2238
- use_ptx = preferred_cuda_output == "ptx"
2239
- else:
2240
- # determine automatically: older drivers may not be able to handle PTX generated using newer
2241
- # CUDA Toolkits, in which case we fall back on generating CUBIN modules
2242
- use_ptx = runtime.driver_version >= runtime.toolkit_version
2243
- else:
2244
- # CUBIN not an option, must use PTX (e.g. CUDA Toolkit too old)
2245
- use_ptx = True
2263
+ def get_compile_arch(self, device: Device | None = None) -> int | None:
2264
+ if device is None:
2265
+ device = runtime.get_device()
2246
2266
 
2247
- if use_ptx:
2248
- # use the default PTX arch if the device supports it
2249
- if warp.config.ptx_target_arch is not None:
2250
- output_arch = min(device.arch, warp.config.ptx_target_arch)
2251
- else:
2252
- output_arch = min(device.arch, runtime.default_ptx_arch)
2253
- output_name = f"{module_name_short}.sm{output_arch}.ptx"
2254
- else:
2255
- output_arch = device.arch
2256
- output_name = f"{module_name_short}.sm{output_arch}.cubin"
2267
+ return device.get_cuda_compile_arch()
2257
2268
 
2258
- # final object binary path
2259
- binary_path = os.path.join(module_dir, output_name)
2269
+ def get_compile_output_name(
2270
+ self, device: Device | None, output_arch: int | None = None, use_ptx: bool | None = None
2271
+ ) -> str:
2272
+ """Get the filename to use for the compiled module binary.
2260
2273
 
2261
- # -----------------------------------------------------------
2262
- # check cache and build if necessary
2274
+ This is only the filename, e.g. ``wp___main___0340cd1.sm86.ptx``.
2275
+ It should be used to form a path.
2276
+ """
2277
+ module_name_short = self.get_module_identifier()
2263
2278
 
2264
- build_dir = None
2279
+ if device and device.is_cpu:
2280
+ return f"{module_name_short}.o"
2265
2281
 
2266
- # we always want to build if binary doesn't exist yet
2267
- # and we want to rebuild if we are not caching kernels or if we are tracking array access
2268
- if (
2269
- not os.path.exists(binary_path)
2270
- or not warp.config.cache_kernels
2271
- or warp.config.verify_autograd_array_access
2272
- ):
2273
- builder_options = {
2274
- **self.options,
2275
- # Some of the tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
2276
- "output_arch": output_arch,
2277
- }
2278
- builder = ModuleBuilder(self, builder_options, hasher=self.hashers[active_block_dim])
2279
-
2280
- # create a temporary (process unique) dir for build outputs before moving to the binary dir
2281
- build_dir = os.path.join(
2282
- warp.config.kernel_cache_dir, f"{module_name}_{module_hash.hex()[:7]}_p{os.getpid()}"
2282
+ # For CUDA compilation, we must have an architecture.
2283
+ final_arch = output_arch
2284
+ if final_arch is None:
2285
+ if device:
2286
+ # Infer the architecture from the device
2287
+ final_arch = self.get_compile_arch(device)
2288
+ else:
2289
+ raise ValueError(
2290
+ "Either 'device' or 'output_arch' must be provided to determine compilation architecture"
2283
2291
  )
2284
2292
 
2285
- # dir may exist from previous attempts / runs / archs
2286
- Path(build_dir).mkdir(parents=True, exist_ok=True)
2293
+ # Determine if we should compile to PTX or CUBIN
2294
+ if use_ptx is None:
2295
+ if device:
2296
+ use_ptx = self._use_ptx(device)
2297
+ else:
2298
+ init()
2299
+ use_ptx = final_arch not in runtime.nvrtc_supported_archs
2287
2300
 
2288
- module_load_timer.extra_msg = " (compiled)" # For wp.ScopedTimer informational purposes
2301
+ if use_ptx:
2302
+ output_name = f"{module_name_short}.sm{final_arch}.ptx"
2303
+ else:
2304
+ output_name = f"{module_name_short}.sm{final_arch}.cubin"
2289
2305
 
2290
- mode = self.options["mode"] if self.options["mode"] is not None else warp.config.mode
2306
+ return output_name
2291
2307
 
2292
- # build CPU
2293
- if device.is_cpu:
2294
- # build
2295
- try:
2296
- source_code_path = os.path.join(build_dir, f"{module_name_short}.cpp")
2308
+ def get_meta_name(self) -> str:
2309
+ """Get the filename to use for the module metadata file.
2297
2310
 
2298
- # write cpp sources
2299
- cpp_source = builder.codegen("cpu")
2311
+ This is only the filename. It should be used to form a path.
2312
+ """
2313
+ return f"{self.get_module_identifier()}.meta"
2300
2314
 
2301
- with open(source_code_path, "w") as cpp_file:
2302
- cpp_file.write(cpp_source)
2315
+ def compile(
2316
+ self,
2317
+ device: Device | None = None,
2318
+ output_dir: str | os.PathLike | None = None,
2319
+ output_name: str | None = None,
2320
+ output_arch: int | None = None,
2321
+ use_ptx: bool | None = None,
2322
+ ) -> None:
2323
+ """Compile this module for a specific device.
2303
2324
 
2304
- output_path = os.path.join(build_dir, output_name)
2325
+ Note that this function only generates and compiles code. The resulting
2326
+ binary is not loaded into the runtime.
2305
2327
 
2306
- # build object code
2307
- with warp.ScopedTimer("Compile x86", active=warp.config.verbose):
2308
- warp.build.build_cpu(
2309
- output_path,
2310
- source_code_path,
2311
- mode=mode,
2312
- fast_math=self.options["fast_math"],
2313
- verify_fp=warp.config.verify_fp,
2314
- fuse_fp=self.options["fuse_fp"],
2315
- )
2328
+ Args:
2329
+ device: The device to compile the module for.
2330
+ output_dir: The directory to write the compiled module to.
2331
+ output_name: The name of the compiled module binary file.
2332
+ output_arch: The architecture to compile the module for.
2333
+ """
2334
+ if output_arch is None:
2335
+ output_arch = self.get_compile_arch(device) # Will remain at None if device is CPU
2316
2336
 
2317
- except Exception as e:
2318
- self.failed_builds.add(None)
2319
- module_load_timer.extra_msg = " (error)"
2320
- raise (e)
2337
+ if output_name is None:
2338
+ output_name = self.get_compile_output_name(device, output_arch, use_ptx)
2321
2339
 
2322
- elif device.is_cuda:
2323
- # build
2324
- try:
2325
- source_code_path = os.path.join(build_dir, f"{module_name_short}.cu")
2326
-
2327
- # write cuda sources
2328
- cu_source = builder.codegen("cuda")
2329
-
2330
- with open(source_code_path, "w") as cu_file:
2331
- cu_file.write(cu_source)
2332
-
2333
- output_path = os.path.join(build_dir, output_name)
2334
-
2335
- # generate PTX or CUBIN
2336
- with warp.ScopedTimer("Compile CUDA", active=warp.config.verbose):
2337
- warp.build.build_cuda(
2338
- source_code_path,
2339
- output_arch,
2340
- output_path,
2341
- config=mode,
2342
- verify_fp=warp.config.verify_fp,
2343
- fast_math=self.options["fast_math"],
2344
- fuse_fp=self.options["fuse_fp"],
2345
- lineinfo=self.options["lineinfo"],
2346
- compile_time_trace=self.options["compile_time_trace"],
2347
- ltoirs=builder.ltoirs.values(),
2348
- fatbins=builder.fatbins.values(),
2349
- )
2340
+ builder_options = {
2341
+ **self.options,
2342
+ # Some of the tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
2343
+ "output_arch": output_arch,
2344
+ }
2345
+ builder = ModuleBuilder(
2346
+ self,
2347
+ builder_options,
2348
+ hasher=self.hashers.get(self.options["block_dim"], None),
2349
+ )
2350
2350
 
2351
- except Exception as e:
2352
- self.failed_builds.add(device.context)
2353
- module_load_timer.extra_msg = " (error)"
2354
- raise (e)
2351
+ # create a temporary (process unique) dir for build outputs before moving to the binary dir
2352
+ module_name_short = self.get_module_identifier()
2355
2353
 
2356
- # ------------------------------------------------------------
2357
- # build meta data
2354
+ if output_dir is None:
2355
+ output_dir = os.path.join(warp.config.kernel_cache_dir, f"{module_name_short}")
2356
+ else:
2357
+ output_dir = os.fspath(output_dir)
2358
2358
 
2359
- meta = builder.build_meta()
2360
- meta_path = os.path.join(build_dir, f"{module_name_short}.meta")
2359
+ meta_path = os.path.join(output_dir, self.get_meta_name())
2361
2360
 
2362
- with open(meta_path, "w") as meta_file:
2363
- json.dump(meta, meta_file)
2361
+ build_dir = os.path.normpath(output_dir) + f"_p{os.getpid()}"
2364
2362
 
2365
- # -----------------------------------------------------------
2366
- # update cache
2363
+ # dir may exist from previous attempts / runs / archs
2364
+ Path(build_dir).mkdir(parents=True, exist_ok=True)
2367
2365
 
2368
- # try to move process outputs to cache
2369
- warp.build.safe_rename(build_dir, module_dir)
2366
+ mode = self.options["mode"] if self.options["mode"] is not None else warp.config.mode
2370
2367
 
2371
- if os.path.exists(module_dir):
2372
- if not os.path.exists(binary_path):
2373
- # copy our output file to the destination module
2374
- # this is necessary in case different processes
2375
- # have different GPU architectures / devices
2376
- try:
2377
- os.rename(output_path, binary_path)
2378
- except (OSError, FileExistsError):
2379
- # another process likely updated the module dir first
2380
- pass
2368
+ # build CPU
2369
+ if output_arch is None:
2370
+ # build
2371
+ try:
2372
+ source_code_path = os.path.join(build_dir, f"{module_name_short}.cpp")
2373
+
2374
+ # write cpp sources
2375
+ cpp_source = builder.codegen("cpu")
2376
+
2377
+ with open(source_code_path, "w") as cpp_file:
2378
+ cpp_file.write(cpp_source)
2379
+
2380
+ output_path = os.path.join(build_dir, output_name)
2381
+
2382
+ # build object code
2383
+ with warp.ScopedTimer("Compile x86", active=warp.config.verbose):
2384
+ warp.build.build_cpu(
2385
+ output_path,
2386
+ source_code_path,
2387
+ mode=mode,
2388
+ fast_math=self.options["fast_math"],
2389
+ verify_fp=warp.config.verify_fp,
2390
+ fuse_fp=self.options["fuse_fp"],
2391
+ )
2392
+
2393
+ except Exception as e:
2394
+ if isinstance(e, FileNotFoundError):
2395
+ _check_and_raise_long_path_error(e)
2396
+
2397
+ self.failed_builds.add(None)
2398
+
2399
+ raise (e)
2381
2400
 
2401
+ else:
2402
+ # build
2403
+ try:
2404
+ source_code_path = os.path.join(build_dir, f"{module_name_short}.cu")
2405
+
2406
+ # write cuda sources
2407
+ cu_source = builder.codegen("cuda")
2408
+
2409
+ with open(source_code_path, "w") as cu_file:
2410
+ cu_file.write(cu_source)
2411
+
2412
+ output_path = os.path.join(build_dir, output_name)
2413
+
2414
+ # generate PTX or CUBIN
2415
+ with warp.ScopedTimer(
2416
+ f"Compile CUDA (arch={builder_options['output_arch']}, mode={mode}, block_dim={self.options['block_dim']})",
2417
+ active=warp.config.verbose,
2418
+ ):
2419
+ warp.build.build_cuda(
2420
+ source_code_path,
2421
+ builder_options["output_arch"],
2422
+ output_path,
2423
+ config=mode,
2424
+ verify_fp=warp.config.verify_fp,
2425
+ fast_math=self.options["fast_math"],
2426
+ fuse_fp=self.options["fuse_fp"],
2427
+ lineinfo=self.options["lineinfo"],
2428
+ compile_time_trace=self.options["compile_time_trace"],
2429
+ ltoirs=builder.ltoirs.values(),
2430
+ fatbins=builder.fatbins.values(),
2431
+ )
2432
+
2433
+ except Exception as e:
2434
+ if isinstance(e, FileNotFoundError):
2435
+ _check_and_raise_long_path_error(e)
2436
+
2437
+ if device:
2438
+ self.failed_builds.add(device.context)
2439
+
2440
+ raise (e)
2441
+
2442
+ # ------------------------------------------------------------
2443
+ # build meta data
2444
+
2445
+ meta = builder.build_meta()
2446
+ output_meta_path = os.path.join(build_dir, self.get_meta_name())
2447
+
2448
+ with open(output_meta_path, "w") as meta_file:
2449
+ json.dump(meta, meta_file)
2450
+
2451
+ # -----------------------------------------------------------
2452
+ # update cache
2453
+
2454
+ # try to move process outputs to cache
2455
+ warp.build.safe_rename(build_dir, output_dir)
2456
+
2457
+ if os.path.exists(output_dir):
2458
+ # final object binary path
2459
+ binary_path = os.path.join(output_dir, output_name)
2460
+
2461
+ if not os.path.exists(binary_path) or self.options["strip_hash"]:
2462
+ # copy our output file to the destination module
2463
+ # this is necessary in case different processes
2464
+ # have different GPU architectures / devices
2465
+ try:
2466
+ os.rename(output_path, binary_path)
2467
+ except (OSError, FileExistsError):
2468
+ # another process likely updated the module dir first
2469
+ pass
2470
+
2471
+ if not os.path.exists(meta_path) or self.options["strip_hash"]:
2472
+ # copy our output file to the destination module
2473
+ # this is necessary in case different processes
2474
+ # have different GPU architectures / devices
2475
+ try:
2476
+ os.rename(output_meta_path, meta_path)
2477
+ except (OSError, FileExistsError):
2478
+ # another process likely updated the module dir first
2479
+ pass
2480
+
2481
+ try:
2482
+ final_source_path = os.path.join(output_dir, os.path.basename(source_code_path))
2483
+ if not os.path.exists(final_source_path) or self.options["strip_hash"]:
2484
+ os.rename(source_code_path, final_source_path)
2485
+ except (OSError, FileExistsError):
2486
+ # another process likely updated the module dir first
2487
+ pass
2488
+ except Exception as e:
2489
+ # We don't need source_code_path to be copied successfully to proceed, so warn and keep running
2490
+ warp.utils.warn(f"Exception when renaming {source_code_path}: {e}")
2491
+
2492
+ # clean up build_dir used for this process regardless
2493
+ shutil.rmtree(build_dir, ignore_errors=True)
2494
+
2495
+ def load(
2496
+ self,
2497
+ device,
2498
+ block_dim: int | None = None,
2499
+ binary_path: os.PathLike | None = None,
2500
+ output_arch: int | None = None,
2501
+ meta_path: os.PathLike | None = None,
2502
+ ) -> ModuleExec | None:
2503
+ device = runtime.get_device(device)
2504
+
2505
+ # update module options if launching with a new block dim
2506
+ if block_dim is not None:
2507
+ self.options["block_dim"] = block_dim
2508
+
2509
+ active_block_dim = self.options["block_dim"]
2510
+
2511
+ # check if executable module is already loaded and not stale
2512
+ exec = self.execs.get((device.context, active_block_dim))
2513
+ if exec is not None:
2514
+ if self.options["strip_hash"] or (exec.module_hash == self.get_module_hash(active_block_dim)):
2515
+ return exec
2516
+
2517
+ # quietly avoid repeated build attempts to reduce error spew
2518
+ if device.context in self.failed_builds:
2519
+ return None
2520
+
2521
+ module_hash = self.get_module_hash(active_block_dim)
2522
+
2523
+ # use a unique module path using the module short hash
2524
+ module_name_short = self.get_module_identifier()
2525
+
2526
+ module_load_timer_name = (
2527
+ f"Module {self.name} {module_hash.hex()[:7]} load on device '{device}'"
2528
+ if self.options["strip_hash"] is False
2529
+ else f"Module {self.name} load on device '{device}'"
2530
+ )
2531
+
2532
+ if warp.config.verbose:
2533
+ module_load_timer_name += f" (block_dim={active_block_dim})"
2534
+
2535
+ with warp.ScopedTimer(module_load_timer_name, active=not warp.config.quiet) as module_load_timer:
2536
+ # -----------------------------------------------------------
2537
+ # Determine binary path and build if necessary
2538
+
2539
+ if binary_path:
2540
+ # We will never re-codegen or re-compile in this situation
2541
+ # The expected files must already exist
2542
+
2543
+ if device.is_cuda and output_arch is None:
2544
+ raise ValueError("'output_arch' must be provided if a 'binary_path' is provided")
2545
+
2546
+ if meta_path is None:
2547
+ raise ValueError("'meta_path' must be provided if a 'binary_path' is provided")
2548
+
2549
+ if not os.path.exists(binary_path):
2550
+ module_load_timer.extra_msg = " (error)"
2551
+ raise FileNotFoundError(f"Binary file {binary_path} does not exist")
2552
+ else:
2553
+ module_load_timer.extra_msg = " (cached)"
2554
+ else:
2555
+ # we will build if binary doesn't exist yet
2556
+ # we will rebuild if we are not caching kernels or if we are tracking array access
2557
+
2558
+ output_name = self.get_compile_output_name(device)
2559
+ output_arch = self.get_compile_arch(device)
2560
+
2561
+ module_dir = os.path.join(warp.config.kernel_cache_dir, module_name_short)
2562
+ meta_path = os.path.join(module_dir, self.get_meta_name())
2563
+ # final object binary path
2564
+ binary_path = os.path.join(module_dir, output_name)
2565
+
2566
+ if (
2567
+ not os.path.exists(binary_path)
2568
+ or not warp.config.cache_kernels
2569
+ or warp.config.verify_autograd_array_access
2570
+ ):
2382
2571
  try:
2383
- final_source_path = os.path.join(module_dir, os.path.basename(source_code_path))
2384
- if not os.path.exists(final_source_path):
2385
- os.rename(source_code_path, final_source_path)
2386
- except (OSError, FileExistsError):
2387
- # another process likely updated the module dir first
2388
- pass
2572
+ self.compile(device, module_dir, output_name, output_arch)
2389
2573
  except Exception as e:
2390
- # We don't need source_code_path to be copied successfully to proceed, so warn and keep running
2391
- warp.utils.warn(f"Exception when renaming {source_code_path}: {e}")
2392
- else:
2393
- module_load_timer.extra_msg = " (cached)" # For wp.ScopedTimer informational purposes
2574
+ module_load_timer.extra_msg = " (error)"
2575
+ raise (e)
2576
+
2577
+ module_load_timer.extra_msg = " (compiled)"
2578
+ else:
2579
+ module_load_timer.extra_msg = " (cached)"
2394
2580
 
2395
2581
  # -----------------------------------------------------------
2396
2582
  # Load CPU or CUDA binary
2397
2583
 
2398
- meta_path = os.path.join(module_dir, f"{module_name_short}.meta")
2399
- with open(meta_path) as meta_file:
2400
- meta = json.load(meta_file)
2584
+ if os.path.exists(meta_path):
2585
+ with open(meta_path) as meta_file:
2586
+ meta = json.load(meta_file)
2587
+ else:
2588
+ raise FileNotFoundError(f"Module metadata file {meta_path} was not found in the cache")
2401
2589
 
2402
2590
  if device.is_cpu:
2403
2591
  # LLVM modules are identified using strings, so we need to ensure uniqueness
2404
- module_handle = f"{module_name}_{self.cpu_exec_id}"
2592
+ module_handle = f"wp_{self.name}_{self.cpu_exec_id}"
2405
2593
  self.cpu_exec_id += 1
2406
- runtime.llvm.load_obj(binary_path.encode("utf-8"), module_handle.encode("utf-8"))
2594
+ runtime.llvm.wp_load_obj(binary_path.encode("utf-8"), module_handle.encode("utf-8"))
2407
2595
  module_exec = ModuleExec(module_handle, module_hash, device, meta)
2408
2596
  self.execs[(None, active_block_dim)] = module_exec
2409
2597
 
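
Note on the hunk above: `Module.load()` has been split so that code generation and compilation live in the new `Module.compile()`, which writes the binary plus a `.meta` file to a target directory without loading anything into the runtime; `load()` can then be pointed at a pre-built `binary_path` (with `output_arch` and `meta_path` for CUDA) instead of consulting the kernel cache. A hedged sketch of an ahead-of-time flow using the signatures added above; the kernel and paths are illustrative.

    import warp as wp

    @wp.kernel
    def saxpy(x: wp.array(dtype=float), y: wp.array(dtype=float), a: float):
        i = wp.tid()
        y[i] = a * x[i] + y[i]

    wp.init()
    device = wp.get_device("cuda:0")
    module = saxpy.module

    # compile only: writes e.g. wp___main___<hash>.sm<arch>.cubin/.ptx and the
    # .meta file into ./aot_artifacts, but does not load the module
    module.compile(device=device, output_dir="./aot_artifacts")

Loading the pre-built binary later would go through `module.load(device, binary_path=..., output_arch=..., meta_path=...)`, per the new keyword arguments.
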
@@ -2416,12 +2604,6 @@ class Module:
2416
2604
  module_load_timer.extra_msg = " (error)"
2417
2605
  raise Exception(f"Failed to load CUDA module '{self.name}'")
2418
2606
 
2419
- if build_dir:
2420
- import shutil
2421
-
2422
- # clean up build_dir used for this process regardless
2423
- shutil.rmtree(build_dir, ignore_errors=True)
2424
-
2425
2607
  return module_exec
2426
2608
 
2427
2609
  def unload(self):
@@ -2457,13 +2639,13 @@ class CpuDefaultAllocator:
2457
2639
  self.deleter = lambda ptr, size: self.free(ptr, size)
2458
2640
 
2459
2641
  def alloc(self, size_in_bytes):
2460
- ptr = runtime.core.alloc_host(size_in_bytes)
2642
+ ptr = runtime.core.wp_alloc_host(size_in_bytes)
2461
2643
  if not ptr:
2462
2644
  raise RuntimeError(f"Failed to allocate {size_in_bytes} bytes on device 'cpu'")
2463
2645
  return ptr
2464
2646
 
2465
2647
  def free(self, ptr, size_in_bytes):
2466
- runtime.core.free_host(ptr)
2648
+ runtime.core.wp_free_host(ptr)
2467
2649
 
2468
2650
 
2469
2651
  class CpuPinnedAllocator:
@@ -2472,13 +2654,13 @@ class CpuPinnedAllocator:
2472
2654
  self.deleter = lambda ptr, size: self.free(ptr, size)
2473
2655
 
2474
2656
  def alloc(self, size_in_bytes):
2475
- ptr = runtime.core.alloc_pinned(size_in_bytes)
2657
+ ptr = runtime.core.wp_alloc_pinned(size_in_bytes)
2476
2658
  if not ptr:
2477
2659
  raise RuntimeError(f"Failed to allocate {size_in_bytes} bytes on device '{self.device}'")
2478
2660
  return ptr
2479
2661
 
2480
2662
  def free(self, ptr, size_in_bytes):
2481
- runtime.core.free_pinned(ptr)
2663
+ runtime.core.wp_free_pinned(ptr)
2482
2664
 
2483
2665
 
2484
2666
  class CudaDefaultAllocator:
@@ -2488,7 +2670,7 @@ class CudaDefaultAllocator:
2488
2670
  self.deleter = lambda ptr, size: self.free(ptr, size)
2489
2671
 
2490
2672
  def alloc(self, size_in_bytes):
2491
- ptr = runtime.core.alloc_device_default(self.device.context, size_in_bytes)
2673
+ ptr = runtime.core.wp_alloc_device_default(self.device.context, size_in_bytes)
2492
2674
  # If the allocation fails, check if graph capture is active to raise an informative error.
2493
2675
  # We delay the capture check to avoid overhead.
2494
2676
  if not ptr:
@@ -2510,7 +2692,7 @@ class CudaDefaultAllocator:
2510
2692
  return ptr
2511
2693
 
2512
2694
  def free(self, ptr, size_in_bytes):
2513
- runtime.core.free_device_default(self.device.context, ptr)
2695
+ runtime.core.wp_free_device_default(self.device.context, ptr)
2514
2696
 
2515
2697
 
2516
2698
  class CudaMempoolAllocator:
@@ -2521,13 +2703,13 @@ class CudaMempoolAllocator:
2521
2703
  self.deleter = lambda ptr, size: self.free(ptr, size)
2522
2704
 
2523
2705
  def alloc(self, size_in_bytes):
2524
- ptr = runtime.core.alloc_device_async(self.device.context, size_in_bytes)
2706
+ ptr = runtime.core.wp_alloc_device_async(self.device.context, size_in_bytes)
2525
2707
  if not ptr:
2526
2708
  raise RuntimeError(f"Failed to allocate {size_in_bytes} bytes on device '{self.device}'")
2527
2709
  return ptr
2528
2710
 
2529
2711
  def free(self, ptr, size_in_bytes):
2530
- runtime.core.free_device_async(self.device.context, ptr)
2712
+ runtime.core.wp_free_device_async(self.device.context, ptr)
2531
2713
 
2532
2714
 
2533
2715
  class ContextGuard:
@@ -2536,15 +2718,15 @@ class ContextGuard:
2536
2718
 
2537
2719
  def __enter__(self):
2538
2720
  if self.device.is_cuda:
2539
- runtime.core.cuda_context_push_current(self.device.context)
2721
+ runtime.core.wp_cuda_context_push_current(self.device.context)
2540
2722
  elif is_cuda_driver_initialized():
2541
- self.saved_context = runtime.core.cuda_context_get_current()
2723
+ self.saved_context = runtime.core.wp_cuda_context_get_current()
2542
2724
 
2543
2725
  def __exit__(self, exc_type, exc_value, traceback):
2544
2726
  if self.device.is_cuda:
2545
- runtime.core.cuda_context_pop_current()
2727
+ runtime.core.wp_cuda_context_pop_current()
2546
2728
  elif is_cuda_driver_initialized():
2547
- runtime.core.cuda_context_set_current(self.saved_context)
2729
+ runtime.core.wp_cuda_context_set_current(self.saved_context)
2548
2730
 
2549
2731
 
2550
2732
  class Event:
@@ -2607,7 +2789,7 @@ class Event:
2607
2789
  raise ValueError("The combination of 'enable_timing=True' and 'interprocess=True' is not allowed.")
2608
2790
  flags |= Event.Flags.INTERPROCESS
2609
2791
 
2610
- self.cuda_event = runtime.core.cuda_event_create(device.context, flags)
2792
+ self.cuda_event = runtime.core.wp_cuda_event_create(device.context, flags)
2611
2793
  if not self.cuda_event:
2612
2794
  raise RuntimeError(f"Failed to create event on device {device}")
2613
2795
  self.owner = True
@@ -2634,7 +2816,9 @@ class Event:
2634
2816
  # Allocate a buffer for the data (64-element char array)
2635
2817
  ipc_handle_buffer = (ctypes.c_char * 64)()
2636
2818
 
2637
- warp.context.runtime.core.cuda_ipc_get_event_handle(self.device.context, self.cuda_event, ipc_handle_buffer)
2819
+ warp.context.runtime.core.wp_cuda_ipc_get_event_handle(
2820
+ self.device.context, self.cuda_event, ipc_handle_buffer
2821
+ )
2638
2822
 
2639
2823
  if ipc_handle_buffer.raw == bytes(64):
2640
2824
  warp.utils.warn("IPC event handle appears to be invalid. Was interprocess=True used?")
@@ -2651,7 +2835,7 @@ class Event:
2651
2835
  This property may not be accessed during a graph capture on any stream.
2652
2836
  """
2653
2837
 
2654
- result_code = runtime.core.cuda_event_query(self.cuda_event)
2838
+ result_code = runtime.core.wp_cuda_event_query(self.cuda_event)
2655
2839
 
2656
2840
  return result_code == 0
2657
2841
 
@@ -2659,7 +2843,7 @@ class Event:
2659
2843
  if not self.owner:
2660
2844
  return
2661
2845
 
2662
- runtime.core.cuda_event_destroy(self.cuda_event)
2846
+ runtime.core.wp_cuda_event_destroy(self.cuda_event)
2663
2847
 
2664
2848
 
2665
2849
  class Stream:
@@ -2709,12 +2893,12 @@ class Stream:
2709
2893
  # we pass cuda_stream through kwargs because cuda_stream=None is actually a valid value (CUDA default stream)
2710
2894
  if "cuda_stream" in kwargs:
2711
2895
  self.cuda_stream = kwargs["cuda_stream"]
2712
- device.runtime.core.cuda_stream_register(device.context, self.cuda_stream)
2896
+ device.runtime.core.wp_cuda_stream_register(device.context, self.cuda_stream)
2713
2897
  else:
2714
2898
  if not isinstance(priority, int):
2715
2899
  raise TypeError("Stream priority must be an integer.")
2716
2900
  clamped_priority = max(-1, min(priority, 0)) # Only support two priority levels
2717
- self.cuda_stream = device.runtime.core.cuda_stream_create(device.context, clamped_priority)
2901
+ self.cuda_stream = device.runtime.core.wp_cuda_stream_create(device.context, clamped_priority)
2718
2902
 
2719
2903
  if not self.cuda_stream:
2720
2904
  raise RuntimeError(f"Failed to create stream on device {device}")
@@ -2725,9 +2909,9 @@ class Stream:
2725
2909
  return
2726
2910
 
2727
2911
  if self.owner:
2728
- runtime.core.cuda_stream_destroy(self.device.context, self.cuda_stream)
2912
+ runtime.core.wp_cuda_stream_destroy(self.device.context, self.cuda_stream)
2729
2913
  else:
2730
- runtime.core.cuda_stream_unregister(self.device.context, self.cuda_stream)
2914
+ runtime.core.wp_cuda_stream_unregister(self.device.context, self.cuda_stream)
2731
2915
 
2732
2916
  @property
2733
2917
  def cached_event(self) -> Event:
@@ -2753,7 +2937,7 @@ class Stream:
2753
2937
  f"Event from device {event.device} cannot be recorded on stream from device {self.device}"
2754
2938
  )
2755
2939
 
2756
- runtime.core.cuda_event_record(event.cuda_event, self.cuda_stream, event.enable_timing)
2940
+ runtime.core.wp_cuda_event_record(event.cuda_event, self.cuda_stream, event.enable_timing)
2757
2941
 
2758
2942
  return event
2759
2943
 
@@ -2762,7 +2946,7 @@ class Stream:
2762
2946
 
2763
2947
  This function does not block the host thread.
2764
2948
  """
2765
- runtime.core.cuda_stream_wait_event(self.cuda_stream, event.cuda_event)
2949
+ runtime.core.wp_cuda_stream_wait_event(self.cuda_stream, event.cuda_event)
2766
2950
 
2767
2951
  def wait_stream(self, other_stream: Stream, event: Event | None = None):
2768
2952
  """Records an event on `other_stream` and makes this stream wait on it.
@@ -2785,7 +2969,7 @@ class Stream:
2785
2969
  if event is None:
2786
2970
  event = other_stream.cached_event
2787
2971
 
2788
- runtime.core.cuda_stream_wait_stream(self.cuda_stream, other_stream.cuda_stream, event.cuda_event)
2972
+ runtime.core.wp_cuda_stream_wait_stream(self.cuda_stream, other_stream.cuda_stream, event.cuda_event)
2789
2973
 
2790
2974
  @property
2791
2975
  def is_complete(self) -> bool:
@@ -2794,19 +2978,19 @@ class Stream:
2794
2978
  This property may not be accessed during a graph capture on any stream.
2795
2979
  """
2796
2980
 
2797
- result_code = runtime.core.cuda_stream_query(self.cuda_stream)
2981
+ result_code = runtime.core.wp_cuda_stream_query(self.cuda_stream)
2798
2982
 
2799
2983
  return result_code == 0
2800
2984
 
2801
2985
  @property
2802
2986
  def is_capturing(self) -> bool:
2803
2987
  """A boolean indicating whether a graph capture is currently ongoing on this stream."""
2804
- return bool(runtime.core.cuda_stream_is_capturing(self.cuda_stream))
2988
+ return bool(runtime.core.wp_cuda_stream_is_capturing(self.cuda_stream))
2805
2989
 
2806
2990
  @property
2807
2991
  def priority(self) -> int:
2808
2992
  """An integer representing the priority of the stream."""
2809
- return runtime.core.cuda_stream_get_priority(self.cuda_stream)
2993
+ return runtime.core.wp_cuda_stream_get_priority(self.cuda_stream)
2810
2994
 
2811
2995
 
2812
2996
  class Device:
@@ -2875,22 +3059,22 @@ class Device:
2875
3059
  self.pci_bus_id = None
2876
3060
 
2877
3061
  # TODO: add more device-specific dispatch functions
2878
- self.memset = runtime.core.memset_host
2879
- self.memtile = runtime.core.memtile_host
3062
+ self.memset = runtime.core.wp_memset_host
3063
+ self.memtile = runtime.core.wp_memtile_host
2880
3064
 
2881
3065
  self.default_allocator = CpuDefaultAllocator(self)
2882
3066
  self.pinned_allocator = CpuPinnedAllocator(self)
2883
3067
 
2884
- elif ordinal >= 0 and ordinal < runtime.core.cuda_device_get_count():
3068
+ elif ordinal >= 0 and ordinal < runtime.core.wp_cuda_device_get_count():
2885
3069
  # CUDA device
2886
- self.name = runtime.core.cuda_device_get_name(ordinal).decode()
2887
- self.arch = runtime.core.cuda_device_get_arch(ordinal)
2888
- self.sm_count = runtime.core.cuda_device_get_sm_count(ordinal)
2889
- self.is_uva = runtime.core.cuda_device_is_uva(ordinal) > 0
2890
- self.is_mempool_supported = runtime.core.cuda_device_is_mempool_supported(ordinal) > 0
3070
+ self.name = runtime.core.wp_cuda_device_get_name(ordinal).decode()
3071
+ self.arch = runtime.core.wp_cuda_device_get_arch(ordinal)
3072
+ self.sm_count = runtime.core.wp_cuda_device_get_sm_count(ordinal)
3073
+ self.is_uva = runtime.core.wp_cuda_device_is_uva(ordinal) > 0
3074
+ self.is_mempool_supported = runtime.core.wp_cuda_device_is_mempool_supported(ordinal) > 0
2891
3075
  if platform.system() == "Linux":
2892
3076
  # Use None when IPC support cannot be determined
2893
- ipc_support_api_query = runtime.core.cuda_device_is_ipc_supported(ordinal)
3077
+ ipc_support_api_query = runtime.core.wp_cuda_device_is_ipc_supported(ordinal)
2894
3078
  self.is_ipc_supported = bool(ipc_support_api_query) if ipc_support_api_query >= 0 else None
2895
3079
  else:
2896
3080
  self.is_ipc_supported = False
@@ -2902,13 +3086,13 @@ class Device:
2902
3086
  self.is_mempool_enabled = False
2903
3087
 
2904
3088
  uuid_buffer = (ctypes.c_char * 16)()
2905
- runtime.core.cuda_device_get_uuid(ordinal, uuid_buffer)
3089
+ runtime.core.wp_cuda_device_get_uuid(ordinal, uuid_buffer)
2906
3090
  uuid_byte_str = bytes(uuid_buffer).hex()
2907
3091
  self.uuid = f"GPU-{uuid_byte_str[0:8]}-{uuid_byte_str[8:12]}-{uuid_byte_str[12:16]}-{uuid_byte_str[16:20]}-{uuid_byte_str[20:]}"
2908
3092
 
2909
- pci_domain_id = runtime.core.cuda_device_get_pci_domain_id(ordinal)
2910
- pci_bus_id = runtime.core.cuda_device_get_pci_bus_id(ordinal)
2911
- pci_device_id = runtime.core.cuda_device_get_pci_device_id(ordinal)
3093
+ pci_domain_id = runtime.core.wp_cuda_device_get_pci_domain_id(ordinal)
3094
+ pci_bus_id = runtime.core.wp_cuda_device_get_pci_bus_id(ordinal)
3095
+ pci_device_id = runtime.core.wp_cuda_device_get_pci_device_id(ordinal)
2912
3096
  # This is (mis)named to correspond to the naming of cudaDeviceGetPCIBusId
2913
3097
  self.pci_bus_id = f"{pci_domain_id:08X}:{pci_bus_id:02X}:{pci_device_id:02X}"
2914
3098
 
@@ -2932,8 +3116,8 @@ class Device:
2932
3116
  self._init_streams()
2933
3117
 
2934
3118
  # TODO: add more device-specific dispatch functions
2935
- self.memset = lambda ptr, value, size: runtime.core.memset_device(self.context, ptr, value, size)
2936
- self.memtile = lambda ptr, src, srcsize, reps: runtime.core.memtile_device(
3119
+ self.memset = lambda ptr, value, size: runtime.core.wp_memset_device(self.context, ptr, value, size)
3120
+ self.memtile = lambda ptr, src, srcsize, reps: runtime.core.wp_memtile_device(
2937
3121
  self.context, ptr, src, srcsize, reps
2938
3122
  )
2939
3123
 
@@ -2992,15 +3176,15 @@ class Device:
2992
3176
  return self._context
2993
3177
  elif self.is_primary:
2994
3178
  # acquire primary context on demand
2995
- prev_context = runtime.core.cuda_context_get_current()
2996
- self._context = self.runtime.core.cuda_device_get_primary_context(self.ordinal)
3179
+ prev_context = runtime.core.wp_cuda_context_get_current()
3180
+ self._context = self.runtime.core.wp_cuda_device_get_primary_context(self.ordinal)
2997
3181
  if self._context is None:
2998
- runtime.core.cuda_context_set_current(prev_context)
3182
+ runtime.core.wp_cuda_context_set_current(prev_context)
2999
3183
  raise RuntimeError(f"Failed to acquire primary context for device {self}")
3000
3184
  self.runtime.context_map[self._context] = self
3001
3185
  # initialize streams
3002
3186
  self._init_streams()
3003
- runtime.core.cuda_context_set_current(prev_context)
3187
+ runtime.core.wp_cuda_context_set_current(prev_context)
3004
3188
  return self._context
3005
3189
 
3006
3190
  @property
@@ -3044,7 +3228,7 @@ class Device:
3044
3228
  if stream.device != self:
3045
3229
  raise RuntimeError(f"Stream from device {stream.device} cannot be used on device {self}")
3046
3230
 
3047
- self.runtime.core.cuda_context_set_stream(self.context, stream.cuda_stream, int(sync))
3231
+ self.runtime.core.wp_cuda_context_set_stream(self.context, stream.cuda_stream, int(sync))
3048
3232
  self._stream = stream
3049
3233
  else:
3050
3234
  raise RuntimeError(f"Device {self} is not a CUDA device")
@@ -3062,7 +3246,7 @@ class Device:
3062
3246
  """
3063
3247
  if self.is_cuda:
3064
3248
  total_mem = ctypes.c_size_t()
3065
- self.runtime.core.cuda_device_get_memory_info(self.ordinal, None, ctypes.byref(total_mem))
3249
+ self.runtime.core.wp_cuda_device_get_memory_info(self.ordinal, None, ctypes.byref(total_mem))
3066
3250
  return total_mem.value
3067
3251
  else:
3068
3252
  # TODO: cpu
@@ -3076,7 +3260,7 @@ class Device:
3076
3260
  """
3077
3261
  if self.is_cuda:
3078
3262
  free_mem = ctypes.c_size_t()
3079
- self.runtime.core.cuda_device_get_memory_info(self.ordinal, ctypes.byref(free_mem), None)
3263
+ self.runtime.core.wp_cuda_device_get_memory_info(self.ordinal, ctypes.byref(free_mem), None)
3080
3264
  return free_mem.value
3081
3265
  else:
3082
3266
  # TODO: cpu
@@ -3103,7 +3287,7 @@ class Device:
3103
3287
 
3104
3288
  def make_current(self):
3105
3289
  if self.context is not None:
3106
- self.runtime.core.cuda_context_set_current(self.context)
3290
+ self.runtime.core.wp_cuda_context_set_current(self.context)
3107
3291
 
3108
3292
  def can_access(self, other):
3109
3293
  # TODO: this function should be redesigned in terms of (device, resource).
@@ -3117,6 +3301,78 @@ class Device:
3117
3301
  else:
3118
3302
  return False
3119
3303
 
3304
+ def get_cuda_output_format(self, preferred_cuda_output: str | None = None) -> str | None:
3305
+ """Determine the CUDA output format to use for this device.
3306
+
3307
+ This method is intended for internal use by Warp's compilation system.
3308
+ External users should not need to call this method directly.
3309
+
3310
+ It determines whether to use PTX or CUBIN output based on device capabilities,
3311
+ caller preferences, and runtime constraints.
3312
+
3313
+ Args:
3314
+ preferred_cuda_output: Caller's preferred format (``"ptx"``, ``"cubin"``, or ``None``).
3315
+ If ``None``, falls back to global config or automatic determination.
3316
+
3317
+ Returns:
3318
+ The output format to use: ``"ptx"``, ``"cubin"``, or ``None`` for CPU devices.
3319
+ """
3320
+
3321
+ if self.is_cpu:
3322
+ # CPU devices don't use CUDA compilation
3323
+ return None
3324
+
3325
+ if not self.is_cubin_supported:
3326
+ return "ptx"
3327
+
3328
+ # Use provided preference or fall back to global config
3329
+ if preferred_cuda_output is None:
3330
+ preferred_cuda_output = warp.config.cuda_output
3331
+
3332
+ if preferred_cuda_output is not None:
3333
+ # Caller specified a preference, use it if supported
3334
+ if preferred_cuda_output in ("ptx", "cubin"):
3335
+ return preferred_cuda_output
3336
+ else:
3337
+ # Invalid preference, fall back to automatic determination
3338
+ pass
3339
+
3340
+ # Determine automatically: Older drivers may not be able to handle PTX generated using newer CUDA Toolkits,
3341
+ # in which case we fall back on generating CUBIN modules
3342
+ return "ptx" if self.runtime.driver_version >= self.runtime.toolkit_version else "cubin"
3343
+
3344
+ def get_cuda_compile_arch(self) -> int | None:
3345
+ """Get the CUDA architecture to use when compiling code for this device.
3346
+
3347
+ This method is intended for internal use by Warp's compilation system.
3348
+ External users should not need to call this method directly.
3349
+
3350
+ Determines the appropriate compute capability version to use when compiling
3351
+ CUDA kernels for this device. The architecture depends on the device's
3352
+ CUDA output format preference and available target architectures.
3353
+
3354
+ For PTX output format, uses the minimum of the device's architecture and
3355
+ the configured PTX target architecture to ensure compatibility.
3356
+ For CUBIN output format, uses the device's exact architecture.
3357
+
3358
+ Returns:
3359
+ The compute capability version (e.g., 75 for ``sm_75``) to use for compilation,
3360
+ or ``None`` for CPU devices which don't use CUDA compilation.
3361
+ """
3362
+ if self.is_cpu:
3363
+ return None
3364
+
3365
+ if self.get_cuda_output_format() == "ptx":
3366
+ # use the default PTX arch if the device supports it
3367
+ if warp.config.ptx_target_arch is not None:
3368
+ output_arch = min(self.arch, warp.config.ptx_target_arch)
3369
+ else:
3370
+ output_arch = min(self.arch, runtime.default_ptx_arch)
3371
+ else:
3372
+ output_arch = self.arch
3373
+
3374
+ return output_arch
3375
+
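
The two helpers added above encode a small decision table: PTX is forced when the device cannot load CUBINs, an explicit caller or config preference wins when valid, and otherwise PTX is chosen only when the installed driver is at least as new as the toolkit that generated it; the compile architecture is then the device arch, clamped to the PTX target arch when emitting PTX. A standalone sketch of that logic with the device/runtime queries replaced by plain values (the numbers in the example are made up for illustration):

from __future__ import annotations

def choose_output_format(is_cubin_supported: bool, preferred: str | None,
                         driver_version: int, toolkit_version: int) -> str:
    if not is_cubin_supported:
        return "ptx"
    if preferred in ("ptx", "cubin"):
        return preferred
    # older drivers may reject PTX emitted by a newer toolkit -> fall back to CUBIN
    return "ptx" if driver_version >= toolkit_version else "cubin"

def choose_compile_arch(device_arch: int, output_format: str, ptx_target_arch: int) -> int:
    # clamp so the PTX stays loadable by the driver's JIT; CUBIN targets the exact arch
    return min(device_arch, ptx_target_arch) if output_format == "ptx" else device_arch

# Example: an sm_89 device with driver 12.4 and toolkit 12.8 (illustrative values):
fmt = choose_output_format(True, None, driver_version=12040, toolkit_version=12080)
arch = choose_compile_arch(89, fmt, ptx_target_arch=75)
# fmt == "cubin", arch == 89
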
3120
3376
 
3121
3377
  """ Meta-type for arguments that can be resolved to a concrete Device.
3122
3378
  """
@@ -3129,11 +3385,7 @@ class Graph:
3129
3385
  self.capture_id = capture_id
3130
3386
  self.module_execs: set[ModuleExec] = set()
3131
3387
  self.graph_exec: ctypes.c_void_p | None = None
3132
-
3133
3388
  self.graph: ctypes.c_void_p | None = None
3134
- self.has_conditional = (
3135
- False # Track if there are conditional nodes in the graph since they are not allowed in child graphs
3136
- )
3137
3389
 
3138
3390
  def __del__(self):
3139
3391
  if not hasattr(self, "graph") or not hasattr(self, "device") or not self.graph:
@@ -3141,9 +3393,9 @@ class Graph:
3141
3393
 
3142
3394
  # use CUDA context guard to avoid side effects during garbage collection
3143
3395
  with self.device.context_guard:
3144
- runtime.core.cuda_graph_destroy(self.device.context, self.graph)
3396
+ runtime.core.wp_cuda_graph_destroy(self.device.context, self.graph)
3145
3397
  if hasattr(self, "graph_exec") and self.graph_exec is not None:
3146
- runtime.core.cuda_graph_exec_destroy(self.device.context, self.graph_exec)
3398
+ runtime.core.wp_cuda_graph_exec_destroy(self.device.context, self.graph_exec)
3147
3399
 
3148
3400
  # retain executable CUDA modules used by this graph, which prevents them from being unloaded
3149
3401
  def retain_module_exec(self, module_exec: ModuleExec):
@@ -3155,6 +3407,14 @@ class Runtime:
3155
3407
  if sys.version_info < (3, 9):
3156
3408
  warp.utils.warn(f"Python 3.9 or newer is recommended for running Warp, detected {sys.version_info}")
3157
3409
 
3410
+ if platform.system() == "Darwin" and platform.machine() == "x86_64":
3411
+ warp.utils.warn(
3412
+ "Support for Warp on Intel-based macOS is deprecated and will be removed in the near future. "
3413
+ "Apple Silicon-based Macs will continue to be supported.",
3414
+ DeprecationWarning,
3415
+ stacklevel=3,
3416
+ )
3417
+
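
The added check above issues a `DeprecationWarning` when the runtime initializes on Intel-based macOS. Assuming `warp.utils.warn` routes through Python's standard `warnings` machinery, a caller can opt in to seeing such warnings (they are filtered out by default outside `__main__`) or silence this one specifically; the filters below are plain `warnings` usage, not a Warp API:

import warnings

# surface all deprecation warnings triggered from library code
warnings.simplefilter("always", DeprecationWarning)

# ...or suppress just this message while keeping other warnings visible
warnings.filterwarnings(
    "ignore",
    message="Support for Warp on Intel-based macOS is deprecated",
    category=DeprecationWarning,
)

import warp as wp  # the warning, if applicable, fires when the runtime is initialized
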
3158
3418
  bin_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bin")
3159
3419
 
3160
3420
  if os.name == "nt":
@@ -3177,7 +3437,7 @@ class Runtime:
3177
3437
  if os.path.exists(llvm_lib):
3178
3438
  self.llvm = self.load_dll(llvm_lib)
3179
3439
  # setup c-types for warp-clang.dll
3180
- self.llvm.lookup.restype = ctypes.c_uint64
3440
+ self.llvm.wp_lookup.restype = ctypes.c_uint64
3181
3441
  else:
3182
3442
  self.llvm = None
3183
3443
 
@@ -3186,83 +3446,83 @@ class Runtime:
3186
3446
 
3187
3447
  # setup c-types for warp.dll
3188
3448
  try:
3189
- self.core.get_error_string.argtypes = []
3190
- self.core.get_error_string.restype = ctypes.c_char_p
3191
- self.core.set_error_output_enabled.argtypes = [ctypes.c_int]
3192
- self.core.set_error_output_enabled.restype = None
3193
- self.core.is_error_output_enabled.argtypes = []
3194
- self.core.is_error_output_enabled.restype = ctypes.c_int
3195
-
3196
- self.core.alloc_host.argtypes = [ctypes.c_size_t]
3197
- self.core.alloc_host.restype = ctypes.c_void_p
3198
- self.core.alloc_pinned.argtypes = [ctypes.c_size_t]
3199
- self.core.alloc_pinned.restype = ctypes.c_void_p
3200
- self.core.alloc_device.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
3201
- self.core.alloc_device.restype = ctypes.c_void_p
3202
- self.core.alloc_device_default.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
3203
- self.core.alloc_device_default.restype = ctypes.c_void_p
3204
- self.core.alloc_device_async.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
3205
- self.core.alloc_device_async.restype = ctypes.c_void_p
3206
-
3207
- self.core.float_to_half_bits.argtypes = [ctypes.c_float]
3208
- self.core.float_to_half_bits.restype = ctypes.c_uint16
3209
- self.core.half_bits_to_float.argtypes = [ctypes.c_uint16]
3210
- self.core.half_bits_to_float.restype = ctypes.c_float
3211
-
3212
- self.core.free_host.argtypes = [ctypes.c_void_p]
3213
- self.core.free_host.restype = None
3214
- self.core.free_pinned.argtypes = [ctypes.c_void_p]
3215
- self.core.free_pinned.restype = None
3216
- self.core.free_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3217
- self.core.free_device.restype = None
3218
- self.core.free_device_default.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3219
- self.core.free_device_default.restype = None
3220
- self.core.free_device_async.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3221
- self.core.free_device_async.restype = None
3222
-
3223
- self.core.memset_host.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
3224
- self.core.memset_host.restype = None
3225
- self.core.memset_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
3226
- self.core.memset_device.restype = None
3227
-
3228
- self.core.memtile_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_size_t]
3229
- self.core.memtile_host.restype = None
3230
- self.core.memtile_device.argtypes = [
3449
+ self.core.wp_get_error_string.argtypes = []
3450
+ self.core.wp_get_error_string.restype = ctypes.c_char_p
3451
+ self.core.wp_set_error_output_enabled.argtypes = [ctypes.c_int]
3452
+ self.core.wp_set_error_output_enabled.restype = None
3453
+ self.core.wp_is_error_output_enabled.argtypes = []
3454
+ self.core.wp_is_error_output_enabled.restype = ctypes.c_int
3455
+
3456
+ self.core.wp_alloc_host.argtypes = [ctypes.c_size_t]
3457
+ self.core.wp_alloc_host.restype = ctypes.c_void_p
3458
+ self.core.wp_alloc_pinned.argtypes = [ctypes.c_size_t]
3459
+ self.core.wp_alloc_pinned.restype = ctypes.c_void_p
3460
+ self.core.wp_alloc_device.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
3461
+ self.core.wp_alloc_device.restype = ctypes.c_void_p
3462
+ self.core.wp_alloc_device_default.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
3463
+ self.core.wp_alloc_device_default.restype = ctypes.c_void_p
3464
+ self.core.wp_alloc_device_async.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
3465
+ self.core.wp_alloc_device_async.restype = ctypes.c_void_p
3466
+
3467
+ self.core.wp_float_to_half_bits.argtypes = [ctypes.c_float]
3468
+ self.core.wp_float_to_half_bits.restype = ctypes.c_uint16
3469
+ self.core.wp_half_bits_to_float.argtypes = [ctypes.c_uint16]
3470
+ self.core.wp_half_bits_to_float.restype = ctypes.c_float
3471
+
3472
+ self.core.wp_free_host.argtypes = [ctypes.c_void_p]
3473
+ self.core.wp_free_host.restype = None
3474
+ self.core.wp_free_pinned.argtypes = [ctypes.c_void_p]
3475
+ self.core.wp_free_pinned.restype = None
3476
+ self.core.wp_free_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3477
+ self.core.wp_free_device.restype = None
3478
+ self.core.wp_free_device_default.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3479
+ self.core.wp_free_device_default.restype = None
3480
+ self.core.wp_free_device_async.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3481
+ self.core.wp_free_device_async.restype = None
3482
+
3483
+ self.core.wp_memset_host.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
3484
+ self.core.wp_memset_host.restype = None
3485
+ self.core.wp_memset_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
3486
+ self.core.wp_memset_device.restype = None
3487
+
3488
+ self.core.wp_memtile_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_size_t]
3489
+ self.core.wp_memtile_host.restype = None
3490
+ self.core.wp_memtile_device.argtypes = [
3231
3491
  ctypes.c_void_p,
3232
3492
  ctypes.c_void_p,
3233
3493
  ctypes.c_void_p,
3234
3494
  ctypes.c_size_t,
3235
3495
  ctypes.c_size_t,
3236
3496
  ]
3237
- self.core.memtile_device.restype = None
3497
+ self.core.wp_memtile_device.restype = None
3238
3498
 
3239
- self.core.memcpy_h2h.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t]
3240
- self.core.memcpy_h2h.restype = ctypes.c_bool
3241
- self.core.memcpy_h2d.argtypes = [
3499
+ self.core.wp_memcpy_h2h.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t]
3500
+ self.core.wp_memcpy_h2h.restype = ctypes.c_bool
3501
+ self.core.wp_memcpy_h2d.argtypes = [
3242
3502
  ctypes.c_void_p,
3243
3503
  ctypes.c_void_p,
3244
3504
  ctypes.c_void_p,
3245
3505
  ctypes.c_size_t,
3246
3506
  ctypes.c_void_p,
3247
3507
  ]
3248
- self.core.memcpy_h2d.restype = ctypes.c_bool
3249
- self.core.memcpy_d2h.argtypes = [
3508
+ self.core.wp_memcpy_h2d.restype = ctypes.c_bool
3509
+ self.core.wp_memcpy_d2h.argtypes = [
3250
3510
  ctypes.c_void_p,
3251
3511
  ctypes.c_void_p,
3252
3512
  ctypes.c_void_p,
3253
3513
  ctypes.c_size_t,
3254
3514
  ctypes.c_void_p,
3255
3515
  ]
3256
- self.core.memcpy_d2h.restype = ctypes.c_bool
3257
- self.core.memcpy_d2d.argtypes = [
3516
+ self.core.wp_memcpy_d2h.restype = ctypes.c_bool
3517
+ self.core.wp_memcpy_d2d.argtypes = [
3258
3518
  ctypes.c_void_p,
3259
3519
  ctypes.c_void_p,
3260
3520
  ctypes.c_void_p,
3261
3521
  ctypes.c_size_t,
3262
3522
  ctypes.c_void_p,
3263
3523
  ]
3264
- self.core.memcpy_d2d.restype = ctypes.c_bool
3265
- self.core.memcpy_p2p.argtypes = [
3524
+ self.core.wp_memcpy_d2d.restype = ctypes.c_bool
3525
+ self.core.wp_memcpy_p2p.argtypes = [
3266
3526
  ctypes.c_void_p,
3267
3527
  ctypes.c_void_p,
3268
3528
  ctypes.c_void_p,
@@ -3270,17 +3530,17 @@ class Runtime:
3270
3530
  ctypes.c_size_t,
3271
3531
  ctypes.c_void_p,
3272
3532
  ]
3273
- self.core.memcpy_p2p.restype = ctypes.c_bool
3533
+ self.core.wp_memcpy_p2p.restype = ctypes.c_bool
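
Every binding in this block declares `argtypes` and `restype` before first use; without a prototype, ctypes passes plain Python integers as default-width C ints, which can truncate 64-bit pointers and sizes. A self-contained illustration of the same pattern against libc's `memset` (POSIX assumed) rather than warp.so:

import ctypes
import ctypes.util

libc = ctypes.CDLL(ctypes.util.find_library("c"))  # POSIX assumed for this sketch

# declaring the prototype keeps pointer/size arguments full-width
libc.memset.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
libc.memset.restype = ctypes.c_void_p

buf = ctypes.create_string_buffer(16)
libc.memset(ctypes.cast(buf, ctypes.c_void_p), 0xAB, ctypes.sizeof(buf))
assert buf.raw == b"\xab" * 16
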
3274
3534
 
3275
- self.core.array_copy_host.argtypes = [
3535
+ self.core.wp_array_copy_host.argtypes = [
3276
3536
  ctypes.c_void_p,
3277
3537
  ctypes.c_void_p,
3278
3538
  ctypes.c_int,
3279
3539
  ctypes.c_int,
3280
3540
  ctypes.c_int,
3281
3541
  ]
3282
- self.core.array_copy_host.restype = ctypes.c_bool
3283
- self.core.array_copy_device.argtypes = [
3542
+ self.core.wp_array_copy_host.restype = ctypes.c_bool
3543
+ self.core.wp_array_copy_device.argtypes = [
3284
3544
  ctypes.c_void_p,
3285
3545
  ctypes.c_void_p,
3286
3546
  ctypes.c_void_p,
@@ -3288,41 +3548,41 @@ class Runtime:
3288
3548
  ctypes.c_int,
3289
3549
  ctypes.c_int,
3290
3550
  ]
3291
- self.core.array_copy_device.restype = ctypes.c_bool
3551
+ self.core.wp_array_copy_device.restype = ctypes.c_bool
3292
3552
 
3293
- self.core.array_fill_host.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_int]
3294
- self.core.array_fill_host.restype = None
3295
- self.core.array_fill_device.argtypes = [
3553
+ self.core.wp_array_fill_host.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_int]
3554
+ self.core.wp_array_fill_host.restype = None
3555
+ self.core.wp_array_fill_device.argtypes = [
3296
3556
  ctypes.c_void_p,
3297
3557
  ctypes.c_void_p,
3298
3558
  ctypes.c_int,
3299
3559
  ctypes.c_void_p,
3300
3560
  ctypes.c_int,
3301
3561
  ]
3302
- self.core.array_fill_device.restype = None
3562
+ self.core.wp_array_fill_device.restype = None
3303
3563
 
3304
- self.core.array_sum_double_host.argtypes = [
3564
+ self.core.wp_array_sum_double_host.argtypes = [
3305
3565
  ctypes.c_uint64,
3306
3566
  ctypes.c_uint64,
3307
3567
  ctypes.c_int,
3308
3568
  ctypes.c_int,
3309
3569
  ctypes.c_int,
3310
3570
  ]
3311
- self.core.array_sum_float_host.argtypes = [
3571
+ self.core.wp_array_sum_float_host.argtypes = [
3312
3572
  ctypes.c_uint64,
3313
3573
  ctypes.c_uint64,
3314
3574
  ctypes.c_int,
3315
3575
  ctypes.c_int,
3316
3576
  ctypes.c_int,
3317
3577
  ]
3318
- self.core.array_sum_double_device.argtypes = [
3578
+ self.core.wp_array_sum_double_device.argtypes = [
3319
3579
  ctypes.c_uint64,
3320
3580
  ctypes.c_uint64,
3321
3581
  ctypes.c_int,
3322
3582
  ctypes.c_int,
3323
3583
  ctypes.c_int,
3324
3584
  ]
3325
- self.core.array_sum_float_device.argtypes = [
3585
+ self.core.wp_array_sum_float_device.argtypes = [
3326
3586
  ctypes.c_uint64,
3327
3587
  ctypes.c_uint64,
3328
3588
  ctypes.c_int,
@@ -3330,7 +3590,7 @@ class Runtime:
3330
3590
  ctypes.c_int,
3331
3591
  ]
3332
3592
 
3333
- self.core.array_inner_double_host.argtypes = [
3593
+ self.core.wp_array_inner_double_host.argtypes = [
3334
3594
  ctypes.c_uint64,
3335
3595
  ctypes.c_uint64,
3336
3596
  ctypes.c_uint64,
@@ -3339,7 +3599,7 @@ class Runtime:
3339
3599
  ctypes.c_int,
3340
3600
  ctypes.c_int,
3341
3601
  ]
3342
- self.core.array_inner_float_host.argtypes = [
3602
+ self.core.wp_array_inner_float_host.argtypes = [
3343
3603
  ctypes.c_uint64,
3344
3604
  ctypes.c_uint64,
3345
3605
  ctypes.c_uint64,
@@ -3348,7 +3608,7 @@ class Runtime:
3348
3608
  ctypes.c_int,
3349
3609
  ctypes.c_int,
3350
3610
  ]
3351
- self.core.array_inner_double_device.argtypes = [
3611
+ self.core.wp_array_inner_double_device.argtypes = [
3352
3612
  ctypes.c_uint64,
3353
3613
  ctypes.c_uint64,
3354
3614
  ctypes.c_uint64,
@@ -3357,7 +3617,7 @@ class Runtime:
3357
3617
  ctypes.c_int,
3358
3618
  ctypes.c_int,
3359
3619
  ]
3360
- self.core.array_inner_float_device.argtypes = [
3620
+ self.core.wp_array_inner_float_device.argtypes = [
3361
3621
  ctypes.c_uint64,
3362
3622
  ctypes.c_uint64,
3363
3623
  ctypes.c_uint64,
@@ -3367,21 +3627,36 @@ class Runtime:
3367
3627
  ctypes.c_int,
3368
3628
  ]
3369
3629
 
3370
- self.core.array_scan_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
3371
- self.core.array_scan_float_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
3372
- self.core.array_scan_int_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
3373
- self.core.array_scan_float_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
3630
+ self.core.wp_array_scan_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
3631
+ self.core.wp_array_scan_float_host.argtypes = [
3632
+ ctypes.c_uint64,
3633
+ ctypes.c_uint64,
3634
+ ctypes.c_int,
3635
+ ctypes.c_bool,
3636
+ ]
3637
+ self.core.wp_array_scan_int_device.argtypes = [
3638
+ ctypes.c_uint64,
3639
+ ctypes.c_uint64,
3640
+ ctypes.c_int,
3641
+ ctypes.c_bool,
3642
+ ]
3643
+ self.core.wp_array_scan_float_device.argtypes = [
3644
+ ctypes.c_uint64,
3645
+ ctypes.c_uint64,
3646
+ ctypes.c_int,
3647
+ ctypes.c_bool,
3648
+ ]
3374
3649
 
3375
- self.core.radix_sort_pairs_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3376
- self.core.radix_sort_pairs_int_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3650
+ self.core.wp_radix_sort_pairs_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3651
+ self.core.wp_radix_sort_pairs_int_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3377
3652
 
3378
- self.core.radix_sort_pairs_float_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3379
- self.core.radix_sort_pairs_float_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3653
+ self.core.wp_radix_sort_pairs_float_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3654
+ self.core.wp_radix_sort_pairs_float_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3380
3655
 
3381
- self.core.radix_sort_pairs_int64_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3382
- self.core.radix_sort_pairs_int64_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3656
+ self.core.wp_radix_sort_pairs_int64_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3657
+ self.core.wp_radix_sort_pairs_int64_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
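
Nearly every binding above comes as a `_host`/`_device` pair sharing one signature. Purely as an illustration of that symmetry (Warp's `Runtime.__init__` spells each declaration out, which keeps the prototypes greppable), the radix-sort family could be declared in a single pass:

import ctypes

RADIX_SORT_ARGTYPES = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]

def declare_sort_pairs(core):
    # `core` is assumed to be the ctypes.CDLL for warp.so
    for dtype in ("int", "float", "int64"):
        for suffix in ("host", "device"):
            fn = getattr(core, f"wp_radix_sort_pairs_{dtype}_{suffix}")
            fn.argtypes = RADIX_SORT_ARGTYPES
            # restype is left at the ctypes default, matching the declarations above
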
3383
3658
 
3384
- self.core.segmented_sort_pairs_int_host.argtypes = [
3659
+ self.core.wp_segmented_sort_pairs_int_host.argtypes = [
3385
3660
  ctypes.c_uint64,
3386
3661
  ctypes.c_uint64,
3387
3662
  ctypes.c_int,
@@ -3389,7 +3664,7 @@ class Runtime:
3389
3664
  ctypes.c_uint64,
3390
3665
  ctypes.c_int,
3391
3666
  ]
3392
- self.core.segmented_sort_pairs_int_device.argtypes = [
3667
+ self.core.wp_segmented_sort_pairs_int_device.argtypes = [
3393
3668
  ctypes.c_uint64,
3394
3669
  ctypes.c_uint64,
3395
3670
  ctypes.c_int,
@@ -3398,7 +3673,7 @@ class Runtime:
3398
3673
  ctypes.c_int,
3399
3674
  ]
3400
3675
 
3401
- self.core.segmented_sort_pairs_float_host.argtypes = [
3676
+ self.core.wp_segmented_sort_pairs_float_host.argtypes = [
3402
3677
  ctypes.c_uint64,
3403
3678
  ctypes.c_uint64,
3404
3679
  ctypes.c_int,
@@ -3406,7 +3681,7 @@ class Runtime:
3406
3681
  ctypes.c_uint64,
3407
3682
  ctypes.c_int,
3408
3683
  ]
3409
- self.core.segmented_sort_pairs_float_device.argtypes = [
3684
+ self.core.wp_segmented_sort_pairs_float_device.argtypes = [
3410
3685
  ctypes.c_uint64,
3411
3686
  ctypes.c_uint64,
3412
3687
  ctypes.c_int,
@@ -3415,14 +3690,14 @@ class Runtime:
3415
3690
  ctypes.c_int,
3416
3691
  ]
3417
3692
 
3418
- self.core.runlength_encode_int_host.argtypes = [
3693
+ self.core.wp_runlength_encode_int_host.argtypes = [
3419
3694
  ctypes.c_uint64,
3420
3695
  ctypes.c_uint64,
3421
3696
  ctypes.c_uint64,
3422
3697
  ctypes.c_uint64,
3423
3698
  ctypes.c_int,
3424
3699
  ]
3425
- self.core.runlength_encode_int_device.argtypes = [
3700
+ self.core.wp_runlength_encode_int_device.argtypes = [
3426
3701
  ctypes.c_uint64,
3427
3702
  ctypes.c_uint64,
3428
3703
  ctypes.c_uint64,
@@ -3430,11 +3705,11 @@ class Runtime:
3430
3705
  ctypes.c_int,
3431
3706
  ]
3432
3707
 
3433
- self.core.bvh_create_host.restype = ctypes.c_uint64
3434
- self.core.bvh_create_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_int]
3708
+ self.core.wp_bvh_create_host.restype = ctypes.c_uint64
3709
+ self.core.wp_bvh_create_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_int]
3435
3710
 
3436
- self.core.bvh_create_device.restype = ctypes.c_uint64
3437
- self.core.bvh_create_device.argtypes = [
3711
+ self.core.wp_bvh_create_device.restype = ctypes.c_uint64
3712
+ self.core.wp_bvh_create_device.argtypes = [
3438
3713
  ctypes.c_void_p,
3439
3714
  ctypes.c_void_p,
3440
3715
  ctypes.c_void_p,
@@ -3442,14 +3717,14 @@ class Runtime:
3442
3717
  ctypes.c_int,
3443
3718
  ]
3444
3719
 
3445
- self.core.bvh_destroy_host.argtypes = [ctypes.c_uint64]
3446
- self.core.bvh_destroy_device.argtypes = [ctypes.c_uint64]
3720
+ self.core.wp_bvh_destroy_host.argtypes = [ctypes.c_uint64]
3721
+ self.core.wp_bvh_destroy_device.argtypes = [ctypes.c_uint64]
3447
3722
 
3448
- self.core.bvh_refit_host.argtypes = [ctypes.c_uint64]
3449
- self.core.bvh_refit_device.argtypes = [ctypes.c_uint64]
3723
+ self.core.wp_bvh_refit_host.argtypes = [ctypes.c_uint64]
3724
+ self.core.wp_bvh_refit_device.argtypes = [ctypes.c_uint64]
3450
3725
 
3451
- self.core.mesh_create_host.restype = ctypes.c_uint64
3452
- self.core.mesh_create_host.argtypes = [
3726
+ self.core.wp_mesh_create_host.restype = ctypes.c_uint64
3727
+ self.core.wp_mesh_create_host.argtypes = [
3453
3728
  warp.types.array_t,
3454
3729
  warp.types.array_t,
3455
3730
  warp.types.array_t,
@@ -3459,8 +3734,8 @@ class Runtime:
3459
3734
  ctypes.c_int,
3460
3735
  ]
3461
3736
 
3462
- self.core.mesh_create_device.restype = ctypes.c_uint64
3463
- self.core.mesh_create_device.argtypes = [
3737
+ self.core.wp_mesh_create_device.restype = ctypes.c_uint64
3738
+ self.core.wp_mesh_create_device.argtypes = [
3464
3739
  ctypes.c_void_p,
3465
3740
  warp.types.array_t,
3466
3741
  warp.types.array_t,
@@ -3471,61 +3746,61 @@ class Runtime:
3471
3746
  ctypes.c_int,
3472
3747
  ]
3473
3748
 
3474
- self.core.mesh_destroy_host.argtypes = [ctypes.c_uint64]
3475
- self.core.mesh_destroy_device.argtypes = [ctypes.c_uint64]
3749
+ self.core.wp_mesh_destroy_host.argtypes = [ctypes.c_uint64]
3750
+ self.core.wp_mesh_destroy_device.argtypes = [ctypes.c_uint64]
3476
3751
 
3477
- self.core.mesh_refit_host.argtypes = [ctypes.c_uint64]
3478
- self.core.mesh_refit_device.argtypes = [ctypes.c_uint64]
3752
+ self.core.wp_mesh_refit_host.argtypes = [ctypes.c_uint64]
3753
+ self.core.wp_mesh_refit_device.argtypes = [ctypes.c_uint64]
3479
3754
 
3480
- self.core.mesh_set_points_host.argtypes = [ctypes.c_uint64, warp.types.array_t]
3481
- self.core.mesh_set_points_device.argtypes = [ctypes.c_uint64, warp.types.array_t]
3755
+ self.core.wp_mesh_set_points_host.argtypes = [ctypes.c_uint64, warp.types.array_t]
3756
+ self.core.wp_mesh_set_points_device.argtypes = [ctypes.c_uint64, warp.types.array_t]
3482
3757
 
3483
- self.core.mesh_set_velocities_host.argtypes = [ctypes.c_uint64, warp.types.array_t]
3484
- self.core.mesh_set_velocities_device.argtypes = [ctypes.c_uint64, warp.types.array_t]
3758
+ self.core.wp_mesh_set_velocities_host.argtypes = [ctypes.c_uint64, warp.types.array_t]
3759
+ self.core.wp_mesh_set_velocities_device.argtypes = [ctypes.c_uint64, warp.types.array_t]
3485
3760
 
3486
- self.core.hash_grid_create_host.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
3487
- self.core.hash_grid_create_host.restype = ctypes.c_uint64
3488
- self.core.hash_grid_destroy_host.argtypes = [ctypes.c_uint64]
3489
- self.core.hash_grid_update_host.argtypes = [ctypes.c_uint64, ctypes.c_float, ctypes.c_void_p]
3490
- self.core.hash_grid_reserve_host.argtypes = [ctypes.c_uint64, ctypes.c_int]
3761
+ self.core.wp_hash_grid_create_host.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
3762
+ self.core.wp_hash_grid_create_host.restype = ctypes.c_uint64
3763
+ self.core.wp_hash_grid_destroy_host.argtypes = [ctypes.c_uint64]
3764
+ self.core.wp_hash_grid_update_host.argtypes = [ctypes.c_uint64, ctypes.c_float, ctypes.c_void_p]
3765
+ self.core.wp_hash_grid_reserve_host.argtypes = [ctypes.c_uint64, ctypes.c_int]
3491
3766
 
3492
- self.core.hash_grid_create_device.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_int, ctypes.c_int]
3493
- self.core.hash_grid_create_device.restype = ctypes.c_uint64
3494
- self.core.hash_grid_destroy_device.argtypes = [ctypes.c_uint64]
3495
- self.core.hash_grid_update_device.argtypes = [ctypes.c_uint64, ctypes.c_float, ctypes.c_void_p]
3496
- self.core.hash_grid_reserve_device.argtypes = [ctypes.c_uint64, ctypes.c_int]
3767
+ self.core.wp_hash_grid_create_device.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_int, ctypes.c_int]
3768
+ self.core.wp_hash_grid_create_device.restype = ctypes.c_uint64
3769
+ self.core.wp_hash_grid_destroy_device.argtypes = [ctypes.c_uint64]
3770
+ self.core.wp_hash_grid_update_device.argtypes = [ctypes.c_uint64, ctypes.c_float, ctypes.c_void_p]
3771
+ self.core.wp_hash_grid_reserve_device.argtypes = [ctypes.c_uint64, ctypes.c_int]
3497
3772
 
3498
- self.core.volume_create_host.argtypes = [ctypes.c_void_p, ctypes.c_uint64, ctypes.c_bool, ctypes.c_bool]
3499
- self.core.volume_create_host.restype = ctypes.c_uint64
3500
- self.core.volume_get_tiles_host.argtypes = [
3773
+ self.core.wp_volume_create_host.argtypes = [ctypes.c_void_p, ctypes.c_uint64, ctypes.c_bool, ctypes.c_bool]
3774
+ self.core.wp_volume_create_host.restype = ctypes.c_uint64
3775
+ self.core.wp_volume_get_tiles_host.argtypes = [
3501
3776
  ctypes.c_uint64,
3502
3777
  ctypes.c_void_p,
3503
3778
  ]
3504
- self.core.volume_get_voxels_host.argtypes = [
3779
+ self.core.wp_volume_get_voxels_host.argtypes = [
3505
3780
  ctypes.c_uint64,
3506
3781
  ctypes.c_void_p,
3507
3782
  ]
3508
- self.core.volume_destroy_host.argtypes = [ctypes.c_uint64]
3783
+ self.core.wp_volume_destroy_host.argtypes = [ctypes.c_uint64]
3509
3784
 
3510
- self.core.volume_create_device.argtypes = [
3785
+ self.core.wp_volume_create_device.argtypes = [
3511
3786
  ctypes.c_void_p,
3512
3787
  ctypes.c_void_p,
3513
3788
  ctypes.c_uint64,
3514
3789
  ctypes.c_bool,
3515
3790
  ctypes.c_bool,
3516
3791
  ]
3517
- self.core.volume_create_device.restype = ctypes.c_uint64
3518
- self.core.volume_get_tiles_device.argtypes = [
3792
+ self.core.wp_volume_create_device.restype = ctypes.c_uint64
3793
+ self.core.wp_volume_get_tiles_device.argtypes = [
3519
3794
  ctypes.c_uint64,
3520
3795
  ctypes.c_void_p,
3521
3796
  ]
3522
- self.core.volume_get_voxels_device.argtypes = [
3797
+ self.core.wp_volume_get_voxels_device.argtypes = [
3523
3798
  ctypes.c_uint64,
3524
3799
  ctypes.c_void_p,
3525
3800
  ]
3526
- self.core.volume_destroy_device.argtypes = [ctypes.c_uint64]
3801
+ self.core.wp_volume_destroy_device.argtypes = [ctypes.c_uint64]
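
The volume bindings, like the earlier bvh/mesh/hash-grid ones, follow an opaque-handle convention: `*_create_*` returns a 64-bit identifier (`restype = ctypes.c_uint64`) and the matching `*_destroy_*` takes it back. A hedged sketch of wrapping such a handle with a destroy-on-collect guard; this is illustrative only, not how Warp's own `Volume`/`Mesh` classes are implemented:

class NativeHandle:
    """Owns an opaque uint64 handle returned by a *_create_* export."""

    def __init__(self, create, destroy, *args):
        self._destroy = destroy
        self.id = create(*args)  # c_uint64 handle minted by the native side
        if not self.id:
            raise RuntimeError("native object creation failed")

    def __del__(self):
        if getattr(self, "id", 0):
            self._destroy(self.id)  # return the opaque id to the library
            self.id = 0
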
3527
3802
 
3528
- self.core.volume_from_tiles_device.argtypes = [
3803
+ self.core.wp_volume_from_tiles_device.argtypes = [
3529
3804
  ctypes.c_void_p,
3530
3805
  ctypes.c_void_p,
3531
3806
  ctypes.c_int,
@@ -3536,8 +3811,8 @@ class Runtime:
3536
3811
  ctypes.c_uint32,
3537
3812
  ctypes.c_char_p,
3538
3813
  ]
3539
- self.core.volume_from_tiles_device.restype = ctypes.c_uint64
3540
- self.core.volume_index_from_tiles_device.argtypes = [
3814
+ self.core.wp_volume_from_tiles_device.restype = ctypes.c_uint64
3815
+ self.core.wp_volume_index_from_tiles_device.argtypes = [
3541
3816
  ctypes.c_void_p,
3542
3817
  ctypes.c_void_p,
3543
3818
  ctypes.c_int,
@@ -3545,8 +3820,8 @@ class Runtime:
3545
3820
  ctypes.c_float * 3,
3546
3821
  ctypes.c_bool,
3547
3822
  ]
3548
- self.core.volume_index_from_tiles_device.restype = ctypes.c_uint64
3549
- self.core.volume_from_active_voxels_device.argtypes = [
3823
+ self.core.wp_volume_index_from_tiles_device.restype = ctypes.c_uint64
3824
+ self.core.wp_volume_from_active_voxels_device.argtypes = [
3550
3825
  ctypes.c_void_p,
3551
3826
  ctypes.c_void_p,
3552
3827
  ctypes.c_int,
@@ -3554,25 +3829,25 @@ class Runtime:
3554
3829
  ctypes.c_float * 3,
3555
3830
  ctypes.c_bool,
3556
3831
  ]
3557
- self.core.volume_from_active_voxels_device.restype = ctypes.c_uint64
3832
+ self.core.wp_volume_from_active_voxels_device.restype = ctypes.c_uint64
3558
3833
 
3559
- self.core.volume_get_buffer_info.argtypes = [
3834
+ self.core.wp_volume_get_buffer_info.argtypes = [
3560
3835
  ctypes.c_uint64,
3561
3836
  ctypes.POINTER(ctypes.c_void_p),
3562
3837
  ctypes.POINTER(ctypes.c_uint64),
3563
3838
  ]
3564
- self.core.volume_get_voxel_size.argtypes = [
3839
+ self.core.wp_volume_get_voxel_size.argtypes = [
3565
3840
  ctypes.c_uint64,
3566
3841
  ctypes.POINTER(ctypes.c_float),
3567
3842
  ctypes.POINTER(ctypes.c_float),
3568
3843
  ctypes.POINTER(ctypes.c_float),
3569
3844
  ]
3570
- self.core.volume_get_tile_and_voxel_count.argtypes = [
3845
+ self.core.wp_volume_get_tile_and_voxel_count.argtypes = [
3571
3846
  ctypes.c_uint64,
3572
3847
  ctypes.POINTER(ctypes.c_uint32),
3573
3848
  ctypes.POINTER(ctypes.c_uint64),
3574
3849
  ]
3575
- self.core.volume_get_grid_info.argtypes = [
3850
+ self.core.wp_volume_get_grid_info.argtypes = [
3576
3851
  ctypes.c_uint64,
3577
3852
  ctypes.POINTER(ctypes.c_uint64),
3578
3853
  ctypes.POINTER(ctypes.c_uint32),
@@ -3581,12 +3856,12 @@ class Runtime:
3581
3856
  ctypes.c_float * 9,
3582
3857
  ctypes.c_char * 16,
3583
3858
  ]
3584
- self.core.volume_get_grid_info.restype = ctypes.c_char_p
3585
- self.core.volume_get_blind_data_count.argtypes = [
3859
+ self.core.wp_volume_get_grid_info.restype = ctypes.c_char_p
3860
+ self.core.wp_volume_get_blind_data_count.argtypes = [
3586
3861
  ctypes.c_uint64,
3587
3862
  ]
3588
- self.core.volume_get_blind_data_count.restype = ctypes.c_uint64
3589
- self.core.volume_get_blind_data_info.argtypes = [
3863
+ self.core.wp_volume_get_blind_data_count.restype = ctypes.c_uint64
3864
+ self.core.wp_volume_get_blind_data_info.argtypes = [
3590
3865
  ctypes.c_uint64,
3591
3866
  ctypes.c_uint32,
3592
3867
  ctypes.POINTER(ctypes.c_void_p),
@@ -3594,7 +3869,7 @@ class Runtime:
3594
3869
  ctypes.POINTER(ctypes.c_uint32),
3595
3870
  ctypes.c_char * 16,
3596
3871
  ]
3597
- self.core.volume_get_blind_data_info.restype = ctypes.c_char_p
3872
+ self.core.wp_volume_get_blind_data_info.restype = ctypes.c_char_p
3598
3873
 
3599
3874
  bsr_matrix_from_triplets_argtypes = [
3600
3875
  ctypes.c_int, # block_size
@@ -3616,8 +3891,8 @@ class Runtime:
3616
3891
  ctypes.c_void_p, # bsr_nnz_event
3617
3892
  ]
3618
3893
 
3619
- self.core.bsr_matrix_from_triplets_host.argtypes = bsr_matrix_from_triplets_argtypes
3620
- self.core.bsr_matrix_from_triplets_device.argtypes = bsr_matrix_from_triplets_argtypes
3894
+ self.core.wp_bsr_matrix_from_triplets_host.argtypes = bsr_matrix_from_triplets_argtypes
3895
+ self.core.wp_bsr_matrix_from_triplets_device.argtypes = bsr_matrix_from_triplets_argtypes
3621
3896
 
3622
3897
  bsr_transpose_argtypes = [
3623
3898
  ctypes.c_int, # row_count
@@ -3629,229 +3904,238 @@ class Runtime:
3629
3904
  ctypes.POINTER(ctypes.c_int), # transposed_bsr_columns
3630
3905
  ctypes.POINTER(ctypes.c_int), # src to dest block map
3631
3906
  ]
3632
- self.core.bsr_transpose_host.argtypes = bsr_transpose_argtypes
3633
- self.core.bsr_transpose_device.argtypes = bsr_transpose_argtypes
3634
-
3635
- self.core.is_cuda_enabled.argtypes = None
3636
- self.core.is_cuda_enabled.restype = ctypes.c_int
3637
- self.core.is_cuda_compatibility_enabled.argtypes = None
3638
- self.core.is_cuda_compatibility_enabled.restype = ctypes.c_int
3639
- self.core.is_mathdx_enabled.argtypes = None
3640
- self.core.is_mathdx_enabled.restype = ctypes.c_int
3641
-
3642
- self.core.cuda_driver_version.argtypes = None
3643
- self.core.cuda_driver_version.restype = ctypes.c_int
3644
- self.core.cuda_toolkit_version.argtypes = None
3645
- self.core.cuda_toolkit_version.restype = ctypes.c_int
3646
- self.core.cuda_driver_is_initialized.argtypes = None
3647
- self.core.cuda_driver_is_initialized.restype = ctypes.c_bool
3648
-
3649
- self.core.nvrtc_supported_arch_count.argtypes = None
3650
- self.core.nvrtc_supported_arch_count.restype = ctypes.c_int
3651
- self.core.nvrtc_supported_archs.argtypes = [ctypes.POINTER(ctypes.c_int)]
3652
- self.core.nvrtc_supported_archs.restype = None
3653
-
3654
- self.core.cuda_device_get_count.argtypes = None
3655
- self.core.cuda_device_get_count.restype = ctypes.c_int
3656
- self.core.cuda_device_get_primary_context.argtypes = [ctypes.c_int]
3657
- self.core.cuda_device_get_primary_context.restype = ctypes.c_void_p
3658
- self.core.cuda_device_get_name.argtypes = [ctypes.c_int]
3659
- self.core.cuda_device_get_name.restype = ctypes.c_char_p
3660
- self.core.cuda_device_get_arch.argtypes = [ctypes.c_int]
3661
- self.core.cuda_device_get_arch.restype = ctypes.c_int
3662
- self.core.cuda_device_get_sm_count.argtypes = [ctypes.c_int]
3663
- self.core.cuda_device_get_sm_count.restype = ctypes.c_int
3664
- self.core.cuda_device_is_uva.argtypes = [ctypes.c_int]
3665
- self.core.cuda_device_is_uva.restype = ctypes.c_int
3666
- self.core.cuda_device_is_mempool_supported.argtypes = [ctypes.c_int]
3667
- self.core.cuda_device_is_mempool_supported.restype = ctypes.c_int
3668
- self.core.cuda_device_is_ipc_supported.argtypes = [ctypes.c_int]
3669
- self.core.cuda_device_is_ipc_supported.restype = ctypes.c_int
3670
- self.core.cuda_device_set_mempool_release_threshold.argtypes = [ctypes.c_int, ctypes.c_uint64]
3671
- self.core.cuda_device_set_mempool_release_threshold.restype = ctypes.c_int
3672
- self.core.cuda_device_get_mempool_release_threshold.argtypes = [ctypes.c_int]
3673
- self.core.cuda_device_get_mempool_release_threshold.restype = ctypes.c_uint64
3674
- self.core.cuda_device_get_mempool_used_mem_current.argtypes = [ctypes.c_int]
3675
- self.core.cuda_device_get_mempool_used_mem_current.restype = ctypes.c_uint64
3676
- self.core.cuda_device_get_mempool_used_mem_high.argtypes = [ctypes.c_int]
3677
- self.core.cuda_device_get_mempool_used_mem_high.restype = ctypes.c_uint64
3678
- self.core.cuda_device_get_memory_info.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p]
3679
- self.core.cuda_device_get_memory_info.restype = None
3680
- self.core.cuda_device_get_uuid.argtypes = [ctypes.c_int, ctypes.c_char * 16]
3681
- self.core.cuda_device_get_uuid.restype = None
3682
- self.core.cuda_device_get_pci_domain_id.argtypes = [ctypes.c_int]
3683
- self.core.cuda_device_get_pci_domain_id.restype = ctypes.c_int
3684
- self.core.cuda_device_get_pci_bus_id.argtypes = [ctypes.c_int]
3685
- self.core.cuda_device_get_pci_bus_id.restype = ctypes.c_int
3686
- self.core.cuda_device_get_pci_device_id.argtypes = [ctypes.c_int]
3687
- self.core.cuda_device_get_pci_device_id.restype = ctypes.c_int
3688
-
3689
- self.core.cuda_context_get_current.argtypes = None
3690
- self.core.cuda_context_get_current.restype = ctypes.c_void_p
3691
- self.core.cuda_context_set_current.argtypes = [ctypes.c_void_p]
3692
- self.core.cuda_context_set_current.restype = None
3693
- self.core.cuda_context_push_current.argtypes = [ctypes.c_void_p]
3694
- self.core.cuda_context_push_current.restype = None
3695
- self.core.cuda_context_pop_current.argtypes = None
3696
- self.core.cuda_context_pop_current.restype = None
3697
- self.core.cuda_context_create.argtypes = [ctypes.c_int]
3698
- self.core.cuda_context_create.restype = ctypes.c_void_p
3699
- self.core.cuda_context_destroy.argtypes = [ctypes.c_void_p]
3700
- self.core.cuda_context_destroy.restype = None
3701
- self.core.cuda_context_synchronize.argtypes = [ctypes.c_void_p]
3702
- self.core.cuda_context_synchronize.restype = None
3703
- self.core.cuda_context_check.argtypes = [ctypes.c_void_p]
3704
- self.core.cuda_context_check.restype = ctypes.c_uint64
3705
-
3706
- self.core.cuda_context_get_device_ordinal.argtypes = [ctypes.c_void_p]
3707
- self.core.cuda_context_get_device_ordinal.restype = ctypes.c_int
3708
- self.core.cuda_context_is_primary.argtypes = [ctypes.c_void_p]
3709
- self.core.cuda_context_is_primary.restype = ctypes.c_int
3710
- self.core.cuda_context_get_stream.argtypes = [ctypes.c_void_p]
3711
- self.core.cuda_context_get_stream.restype = ctypes.c_void_p
3712
- self.core.cuda_context_set_stream.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
3713
- self.core.cuda_context_set_stream.restype = None
3907
+ self.core.wp_bsr_transpose_host.argtypes = bsr_transpose_argtypes
3908
+ self.core.wp_bsr_transpose_device.argtypes = bsr_transpose_argtypes
3909
+
3910
+ self.core.wp_is_cuda_enabled.argtypes = None
3911
+ self.core.wp_is_cuda_enabled.restype = ctypes.c_int
3912
+ self.core.wp_is_cuda_compatibility_enabled.argtypes = None
3913
+ self.core.wp_is_cuda_compatibility_enabled.restype = ctypes.c_int
3914
+ self.core.wp_is_mathdx_enabled.argtypes = None
3915
+ self.core.wp_is_mathdx_enabled.restype = ctypes.c_int
3916
+
3917
+ self.core.wp_cuda_driver_version.argtypes = None
3918
+ self.core.wp_cuda_driver_version.restype = ctypes.c_int
3919
+ self.core.wp_cuda_toolkit_version.argtypes = None
3920
+ self.core.wp_cuda_toolkit_version.restype = ctypes.c_int
3921
+ self.core.wp_cuda_driver_is_initialized.argtypes = None
3922
+ self.core.wp_cuda_driver_is_initialized.restype = ctypes.c_bool
3923
+
3924
+ self.core.wp_nvrtc_supported_arch_count.argtypes = None
3925
+ self.core.wp_nvrtc_supported_arch_count.restype = ctypes.c_int
3926
+ self.core.wp_nvrtc_supported_archs.argtypes = [ctypes.POINTER(ctypes.c_int)]
3927
+ self.core.wp_nvrtc_supported_archs.restype = None
3928
+
3929
+ self.core.wp_cuda_device_get_count.argtypes = None
3930
+ self.core.wp_cuda_device_get_count.restype = ctypes.c_int
3931
+ self.core.wp_cuda_device_get_primary_context.argtypes = [ctypes.c_int]
3932
+ self.core.wp_cuda_device_get_primary_context.restype = ctypes.c_void_p
3933
+ self.core.wp_cuda_device_get_name.argtypes = [ctypes.c_int]
3934
+ self.core.wp_cuda_device_get_name.restype = ctypes.c_char_p
3935
+ self.core.wp_cuda_device_get_arch.argtypes = [ctypes.c_int]
3936
+ self.core.wp_cuda_device_get_arch.restype = ctypes.c_int
3937
+ self.core.wp_cuda_device_get_sm_count.argtypes = [ctypes.c_int]
3938
+ self.core.wp_cuda_device_get_sm_count.restype = ctypes.c_int
3939
+ self.core.wp_cuda_device_is_uva.argtypes = [ctypes.c_int]
3940
+ self.core.wp_cuda_device_is_uva.restype = ctypes.c_int
3941
+ self.core.wp_cuda_device_is_mempool_supported.argtypes = [ctypes.c_int]
3942
+ self.core.wp_cuda_device_is_mempool_supported.restype = ctypes.c_int
3943
+ self.core.wp_cuda_device_is_ipc_supported.argtypes = [ctypes.c_int]
3944
+ self.core.wp_cuda_device_is_ipc_supported.restype = ctypes.c_int
3945
+ self.core.wp_cuda_device_set_mempool_release_threshold.argtypes = [ctypes.c_int, ctypes.c_uint64]
3946
+ self.core.wp_cuda_device_set_mempool_release_threshold.restype = ctypes.c_int
3947
+ self.core.wp_cuda_device_get_mempool_release_threshold.argtypes = [ctypes.c_int]
3948
+ self.core.wp_cuda_device_get_mempool_release_threshold.restype = ctypes.c_uint64
3949
+ self.core.wp_cuda_device_get_mempool_used_mem_current.argtypes = [ctypes.c_int]
3950
+ self.core.wp_cuda_device_get_mempool_used_mem_current.restype = ctypes.c_uint64
3951
+ self.core.wp_cuda_device_get_mempool_used_mem_high.argtypes = [ctypes.c_int]
3952
+ self.core.wp_cuda_device_get_mempool_used_mem_high.restype = ctypes.c_uint64
3953
+ self.core.wp_cuda_device_get_memory_info.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p]
3954
+ self.core.wp_cuda_device_get_memory_info.restype = None
3955
+ self.core.wp_cuda_device_get_uuid.argtypes = [ctypes.c_int, ctypes.c_char * 16]
3956
+ self.core.wp_cuda_device_get_uuid.restype = None
3957
+ self.core.wp_cuda_device_get_pci_domain_id.argtypes = [ctypes.c_int]
3958
+ self.core.wp_cuda_device_get_pci_domain_id.restype = ctypes.c_int
3959
+ self.core.wp_cuda_device_get_pci_bus_id.argtypes = [ctypes.c_int]
3960
+ self.core.wp_cuda_device_get_pci_bus_id.restype = ctypes.c_int
3961
+ self.core.wp_cuda_device_get_pci_device_id.argtypes = [ctypes.c_int]
3962
+ self.core.wp_cuda_device_get_pci_device_id.restype = ctypes.c_int
3963
+
3964
+ self.core.wp_cuda_context_get_current.argtypes = None
3965
+ self.core.wp_cuda_context_get_current.restype = ctypes.c_void_p
3966
+ self.core.wp_cuda_context_set_current.argtypes = [ctypes.c_void_p]
3967
+ self.core.wp_cuda_context_set_current.restype = None
3968
+ self.core.wp_cuda_context_push_current.argtypes = [ctypes.c_void_p]
3969
+ self.core.wp_cuda_context_push_current.restype = None
3970
+ self.core.wp_cuda_context_pop_current.argtypes = None
3971
+ self.core.wp_cuda_context_pop_current.restype = None
3972
+ self.core.wp_cuda_context_create.argtypes = [ctypes.c_int]
3973
+ self.core.wp_cuda_context_create.restype = ctypes.c_void_p
3974
+ self.core.wp_cuda_context_destroy.argtypes = [ctypes.c_void_p]
3975
+ self.core.wp_cuda_context_destroy.restype = None
3976
+ self.core.wp_cuda_context_synchronize.argtypes = [ctypes.c_void_p]
3977
+ self.core.wp_cuda_context_synchronize.restype = None
3978
+ self.core.wp_cuda_context_check.argtypes = [ctypes.c_void_p]
3979
+ self.core.wp_cuda_context_check.restype = ctypes.c_uint64
3980
+
3981
+ self.core.wp_cuda_context_get_device_ordinal.argtypes = [ctypes.c_void_p]
3982
+ self.core.wp_cuda_context_get_device_ordinal.restype = ctypes.c_int
3983
+ self.core.wp_cuda_context_is_primary.argtypes = [ctypes.c_void_p]
3984
+ self.core.wp_cuda_context_is_primary.restype = ctypes.c_int
3985
+ self.core.wp_cuda_context_get_stream.argtypes = [ctypes.c_void_p]
3986
+ self.core.wp_cuda_context_get_stream.restype = ctypes.c_void_p
3987
+ self.core.wp_cuda_context_set_stream.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
3988
+ self.core.wp_cuda_context_set_stream.restype = None
3714
3989
 
3715
3990
  # peer access
3716
- self.core.cuda_is_peer_access_supported.argtypes = [ctypes.c_int, ctypes.c_int]
3717
- self.core.cuda_is_peer_access_supported.restype = ctypes.c_int
3718
- self.core.cuda_is_peer_access_enabled.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3719
- self.core.cuda_is_peer_access_enabled.restype = ctypes.c_int
3720
- self.core.cuda_set_peer_access_enabled.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
3721
- self.core.cuda_set_peer_access_enabled.restype = ctypes.c_int
3722
- self.core.cuda_is_mempool_access_enabled.argtypes = [ctypes.c_int, ctypes.c_int]
3723
- self.core.cuda_is_mempool_access_enabled.restype = ctypes.c_int
3724
- self.core.cuda_set_mempool_access_enabled.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
3725
- self.core.cuda_set_mempool_access_enabled.restype = ctypes.c_int
3991
+ self.core.wp_cuda_is_peer_access_supported.argtypes = [ctypes.c_int, ctypes.c_int]
3992
+ self.core.wp_cuda_is_peer_access_supported.restype = ctypes.c_int
3993
+ self.core.wp_cuda_is_peer_access_enabled.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3994
+ self.core.wp_cuda_is_peer_access_enabled.restype = ctypes.c_int
3995
+ self.core.wp_cuda_set_peer_access_enabled.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
3996
+ self.core.wp_cuda_set_peer_access_enabled.restype = ctypes.c_int
3997
+ self.core.wp_cuda_is_mempool_access_enabled.argtypes = [ctypes.c_int, ctypes.c_int]
3998
+ self.core.wp_cuda_is_mempool_access_enabled.restype = ctypes.c_int
3999
+ self.core.wp_cuda_set_mempool_access_enabled.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
4000
+ self.core.wp_cuda_set_mempool_access_enabled.restype = ctypes.c_int
3726
4001
 
3727
4002
  # inter-process communication
3728
- self.core.cuda_ipc_get_mem_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
3729
- self.core.cuda_ipc_get_mem_handle.restype = None
3730
- self.core.cuda_ipc_open_mem_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
3731
- self.core.cuda_ipc_open_mem_handle.restype = ctypes.c_void_p
3732
- self.core.cuda_ipc_close_mem_handle.argtypes = [ctypes.c_void_p]
3733
- self.core.cuda_ipc_close_mem_handle.restype = None
3734
- self.core.cuda_ipc_get_event_handle.argtypes = [
4003
+ self.core.wp_cuda_ipc_get_mem_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
4004
+ self.core.wp_cuda_ipc_get_mem_handle.restype = None
4005
+ self.core.wp_cuda_ipc_open_mem_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
4006
+ self.core.wp_cuda_ipc_open_mem_handle.restype = ctypes.c_void_p
4007
+ self.core.wp_cuda_ipc_close_mem_handle.argtypes = [ctypes.c_void_p]
4008
+ self.core.wp_cuda_ipc_close_mem_handle.restype = None
4009
+ self.core.wp_cuda_ipc_get_event_handle.argtypes = [
3735
4010
  ctypes.c_void_p,
3736
4011
  ctypes.c_void_p,
3737
4012
  ctypes.POINTER(ctypes.c_char),
3738
4013
  ]
3739
- self.core.cuda_ipc_get_event_handle.restype = None
3740
- self.core.cuda_ipc_open_event_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
3741
- self.core.cuda_ipc_open_event_handle.restype = ctypes.c_void_p
3742
-
3743
- self.core.cuda_stream_create.argtypes = [ctypes.c_void_p, ctypes.c_int]
3744
- self.core.cuda_stream_create.restype = ctypes.c_void_p
3745
- self.core.cuda_stream_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3746
- self.core.cuda_stream_destroy.restype = None
3747
- self.core.cuda_stream_query.argtypes = [ctypes.c_void_p]
3748
- self.core.cuda_stream_query.restype = ctypes.c_int
3749
- self.core.cuda_stream_register.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3750
- self.core.cuda_stream_register.restype = None
3751
- self.core.cuda_stream_unregister.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3752
- self.core.cuda_stream_unregister.restype = None
3753
- self.core.cuda_stream_synchronize.argtypes = [ctypes.c_void_p]
3754
- self.core.cuda_stream_synchronize.restype = None
3755
- self.core.cuda_stream_wait_event.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3756
- self.core.cuda_stream_wait_event.restype = None
3757
- self.core.cuda_stream_wait_stream.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
3758
- self.core.cuda_stream_wait_stream.restype = None
3759
- self.core.cuda_stream_is_capturing.argtypes = [ctypes.c_void_p]
3760
- self.core.cuda_stream_is_capturing.restype = ctypes.c_int
3761
- self.core.cuda_stream_get_capture_id.argtypes = [ctypes.c_void_p]
3762
- self.core.cuda_stream_get_capture_id.restype = ctypes.c_uint64
3763
- self.core.cuda_stream_get_priority.argtypes = [ctypes.c_void_p]
3764
- self.core.cuda_stream_get_priority.restype = ctypes.c_int
3765
-
3766
- self.core.cuda_event_create.argtypes = [ctypes.c_void_p, ctypes.c_uint]
3767
- self.core.cuda_event_create.restype = ctypes.c_void_p
3768
- self.core.cuda_event_destroy.argtypes = [ctypes.c_void_p]
3769
- self.core.cuda_event_destroy.restype = None
3770
- self.core.cuda_event_query.argtypes = [ctypes.c_void_p]
3771
- self.core.cuda_event_query.restype = ctypes.c_int
3772
- self.core.cuda_event_record.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_bool]
3773
- self.core.cuda_event_record.restype = None
3774
- self.core.cuda_event_synchronize.argtypes = [ctypes.c_void_p]
3775
- self.core.cuda_event_synchronize.restype = None
3776
- self.core.cuda_event_elapsed_time.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3777
- self.core.cuda_event_elapsed_time.restype = ctypes.c_float
3778
-
3779
- self.core.cuda_graph_begin_capture.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
3780
- self.core.cuda_graph_begin_capture.restype = ctypes.c_bool
3781
- self.core.cuda_graph_end_capture.argtypes = [
4014
+ self.core.wp_cuda_ipc_get_event_handle.restype = None
4015
+ self.core.wp_cuda_ipc_open_event_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
4016
+ self.core.wp_cuda_ipc_open_event_handle.restype = ctypes.c_void_p
4017
+
4018
+ self.core.wp_cuda_stream_create.argtypes = [ctypes.c_void_p, ctypes.c_int]
4019
+ self.core.wp_cuda_stream_create.restype = ctypes.c_void_p
4020
+ self.core.wp_cuda_stream_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
4021
+ self.core.wp_cuda_stream_destroy.restype = None
4022
+ self.core.wp_cuda_stream_query.argtypes = [ctypes.c_void_p]
4023
+ self.core.wp_cuda_stream_query.restype = ctypes.c_int
4024
+ self.core.wp_cuda_stream_register.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
4025
+ self.core.wp_cuda_stream_register.restype = None
4026
+ self.core.wp_cuda_stream_unregister.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
4027
+ self.core.wp_cuda_stream_unregister.restype = None
4028
+ self.core.wp_cuda_stream_synchronize.argtypes = [ctypes.c_void_p]
4029
+ self.core.wp_cuda_stream_synchronize.restype = None
4030
+ self.core.wp_cuda_stream_wait_event.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
4031
+ self.core.wp_cuda_stream_wait_event.restype = None
4032
+ self.core.wp_cuda_stream_wait_stream.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
4033
+ self.core.wp_cuda_stream_wait_stream.restype = None
4034
+ self.core.wp_cuda_stream_is_capturing.argtypes = [ctypes.c_void_p]
4035
+ self.core.wp_cuda_stream_is_capturing.restype = ctypes.c_int
4036
+ self.core.wp_cuda_stream_get_capture_id.argtypes = [ctypes.c_void_p]
4037
+ self.core.wp_cuda_stream_get_capture_id.restype = ctypes.c_uint64
4038
+ self.core.wp_cuda_stream_get_priority.argtypes = [ctypes.c_void_p]
4039
+ self.core.wp_cuda_stream_get_priority.restype = ctypes.c_int
4040
+
4041
+ self.core.wp_cuda_event_create.argtypes = [ctypes.c_void_p, ctypes.c_uint]
4042
+ self.core.wp_cuda_event_create.restype = ctypes.c_void_p
4043
+ self.core.wp_cuda_event_destroy.argtypes = [ctypes.c_void_p]
4044
+ self.core.wp_cuda_event_destroy.restype = None
4045
+ self.core.wp_cuda_event_query.argtypes = [ctypes.c_void_p]
4046
+ self.core.wp_cuda_event_query.restype = ctypes.c_int
4047
+ self.core.wp_cuda_event_record.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_bool]
4048
+ self.core.wp_cuda_event_record.restype = None
4049
+ self.core.wp_cuda_event_synchronize.argtypes = [ctypes.c_void_p]
4050
+ self.core.wp_cuda_event_synchronize.restype = None
4051
+ self.core.wp_cuda_event_elapsed_time.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
4052
+ self.core.wp_cuda_event_elapsed_time.restype = ctypes.c_float
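
The event bindings above map onto the usual CUDA timing idiom: create two events, record them around the work, synchronize the second, and ask for the elapsed time. The sequence below is a hedged sketch built only from the declared signatures; the meaning of the creation flags (0 assumed to mean a default, timing-capable event), the boolean on record (assumed to enable timing), and the argument order of `elapsed_time` (start, stop) are assumptions for illustration, not documented behavior:

def time_region(core, context, stream, body):
    # `core` is the ctypes handle to warp.so; `context`/`stream` are the raw
    # context/stream pointers that Device and Stream manage internally.
    start = core.wp_cuda_event_create(context, 0)
    stop = core.wp_cuda_event_create(context, 0)
    try:
        core.wp_cuda_event_record(start, stream, True)
        body()
        core.wp_cuda_event_record(stop, stream, True)
        core.wp_cuda_event_synchronize(stop)
        return core.wp_cuda_event_elapsed_time(start, stop)  # assumed milliseconds
    finally:
        core.wp_cuda_event_destroy(start)
        core.wp_cuda_event_destroy(stop)
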
4053
+
4054
+ self.core.wp_cuda_graph_begin_capture.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
4055
+ self.core.wp_cuda_graph_begin_capture.restype = ctypes.c_bool
4056
+ self.core.wp_cuda_graph_end_capture.argtypes = [
3782
4057
  ctypes.c_void_p,
3783
4058
  ctypes.c_void_p,
3784
4059
  ctypes.POINTER(ctypes.c_void_p),
3785
4060
  ]
3786
- self.core.cuda_graph_end_capture.restype = ctypes.c_bool
4061
+ self.core.wp_cuda_graph_end_capture.restype = ctypes.c_bool
3787
4062
 
3788
- self.core.cuda_graph_create_exec.argtypes = [
4063
+ self.core.wp_cuda_graph_create_exec.argtypes = [
3789
4064
  ctypes.c_void_p,
3790
4065
  ctypes.c_void_p,
3791
4066
  ctypes.c_void_p,
3792
4067
  ctypes.POINTER(ctypes.c_void_p),
3793
4068
  ]
3794
- self.core.cuda_graph_create_exec.restype = ctypes.c_bool
4069
+ self.core.wp_cuda_graph_create_exec.restype = ctypes.c_bool
3795
4070
 
3796
- self.core.capture_debug_dot_print.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_uint32]
3797
- self.core.capture_debug_dot_print.restype = ctypes.c_bool
4071
+ self.core.wp_capture_debug_dot_print.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_uint32]
4072
+ self.core.wp_capture_debug_dot_print.restype = ctypes.c_bool
3798
4073
 
3799
- self.core.cuda_graph_launch.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3800
- self.core.cuda_graph_launch.restype = ctypes.c_bool
3801
- self.core.cuda_graph_exec_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3802
- self.core.cuda_graph_exec_destroy.restype = ctypes.c_bool
4074
+ self.core.wp_cuda_graph_launch.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
4075
+ self.core.wp_cuda_graph_launch.restype = ctypes.c_bool
4076
+ self.core.wp_cuda_graph_exec_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
4077
+ self.core.wp_cuda_graph_exec_destroy.restype = ctypes.c_bool
3803
4078
 
3804
- self.core.cuda_graph_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3805
- self.core.cuda_graph_destroy.restype = ctypes.c_bool
4079
+ self.core.wp_cuda_graph_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
4080
+ self.core.wp_cuda_graph_destroy.restype = ctypes.c_bool
3806
4081
 
3807
- self.core.cuda_graph_insert_if_else.argtypes = [
4082
+ self.core.wp_cuda_graph_insert_if_else.argtypes = [
3808
4083
  ctypes.c_void_p,
3809
4084
  ctypes.c_void_p,
4085
+ ctypes.c_int,
4086
+ ctypes.c_bool,
3810
4087
  ctypes.POINTER(ctypes.c_int),
3811
4088
  ctypes.POINTER(ctypes.c_void_p),
3812
4089
  ctypes.POINTER(ctypes.c_void_p),
3813
4090
  ]
3814
- self.core.cuda_graph_insert_if_else.restype = ctypes.c_bool
4091
+ self.core.wp_cuda_graph_insert_if_else.restype = ctypes.c_bool
3815
4092
 
3816
- self.core.cuda_graph_insert_while.argtypes = [
4093
+ self.core.wp_cuda_graph_insert_while.argtypes = [
3817
4094
  ctypes.c_void_p,
3818
4095
  ctypes.c_void_p,
4096
+ ctypes.c_int,
4097
+ ctypes.c_bool,
3819
4098
  ctypes.POINTER(ctypes.c_int),
3820
4099
  ctypes.POINTER(ctypes.c_void_p),
3821
4100
  ctypes.POINTER(ctypes.c_uint64),
3822
4101
  ]
3823
- self.core.cuda_graph_insert_while.restype = ctypes.c_bool
4102
+ self.core.wp_cuda_graph_insert_while.restype = ctypes.c_bool
3824
4103
 
3825
- self.core.cuda_graph_set_condition.argtypes = [
4104
+ self.core.wp_cuda_graph_set_condition.argtypes = [
3826
4105
  ctypes.c_void_p,
3827
4106
  ctypes.c_void_p,
4107
+ ctypes.c_int,
4108
+ ctypes.c_bool,
3828
4109
  ctypes.POINTER(ctypes.c_int),
3829
4110
  ctypes.c_uint64,
3830
4111
  ]
3831
- self.core.cuda_graph_set_condition.restype = ctypes.c_bool
4112
+ self.core.wp_cuda_graph_set_condition.restype = ctypes.c_bool
3832
4113
 
3833
- self.core.cuda_graph_pause_capture.argtypes = [
4114
+ self.core.wp_cuda_graph_pause_capture.argtypes = [
3834
4115
  ctypes.c_void_p,
3835
4116
  ctypes.c_void_p,
3836
4117
  ctypes.POINTER(ctypes.c_void_p),
3837
4118
  ]
3838
- self.core.cuda_graph_pause_capture.restype = ctypes.c_bool
4119
+ self.core.wp_cuda_graph_pause_capture.restype = ctypes.c_bool
3839
4120
 
3840
- self.core.cuda_graph_resume_capture.argtypes = [
4121
+ self.core.wp_cuda_graph_resume_capture.argtypes = [
3841
4122
  ctypes.c_void_p,
3842
4123
  ctypes.c_void_p,
3843
4124
  ctypes.c_void_p,
3844
4125
  ]
3845
- self.core.cuda_graph_resume_capture.restype = ctypes.c_bool
4126
+ self.core.wp_cuda_graph_resume_capture.restype = ctypes.c_bool
3846
4127
 
3847
- self.core.cuda_graph_insert_child_graph.argtypes = [
4128
+ self.core.wp_cuda_graph_insert_child_graph.argtypes = [
3848
4129
  ctypes.c_void_p,
3849
4130
  ctypes.c_void_p,
3850
4131
  ctypes.c_void_p,
3851
4132
  ]
3852
- self.core.cuda_graph_insert_child_graph.restype = ctypes.c_bool
4133
+ self.core.wp_cuda_graph_insert_child_graph.restype = ctypes.c_bool
4134
+
4135
+ self.core.wp_cuda_graph_check_conditional_body.argtypes = [ctypes.c_void_p]
4136
+ self.core.wp_cuda_graph_check_conditional_body.restype = ctypes.c_bool
3853
4137
 
3854
- self.core.cuda_compile_program.argtypes = [
4138
+ self.core.wp_cuda_compile_program.argtypes = [
3855
4139
  ctypes.c_char_p, # cuda_src
3856
4140
  ctypes.c_char_p, # program name
3857
4141
  ctypes.c_int, # arch
@@ -3871,9 +4155,9 @@ class Runtime:
3871
4155
  ctypes.POINTER(ctypes.c_size_t), # ltoir_sizes
3872
4156
  ctypes.POINTER(ctypes.c_int), # ltoir_input_types, each of type nvJitLinkInputType
3873
4157
  ]
3874
- self.core.cuda_compile_program.restype = ctypes.c_size_t
4158
+ self.core.wp_cuda_compile_program.restype = ctypes.c_size_t
3875
4159
 
3876
- self.core.cuda_compile_fft.argtypes = [
4160
+ self.core.wp_cuda_compile_fft.argtypes = [
3877
4161
  ctypes.c_char_p, # lto
3878
4162
  ctypes.c_char_p, # function name
3879
4163
  ctypes.c_int, # num include dirs
@@ -3886,9 +4170,9 @@ class Runtime:
3886
4170
  ctypes.c_int, # precision
3887
4171
  ctypes.POINTER(ctypes.c_int), # smem (out)
3888
4172
  ]
3889
- self.core.cuda_compile_fft.restype = ctypes.c_bool
4173
+ self.core.wp_cuda_compile_fft.restype = ctypes.c_bool
3890
4174
 
3891
- self.core.cuda_compile_dot.argtypes = [
4175
+ self.core.wp_cuda_compile_dot.argtypes = [
3892
4176
  ctypes.c_char_p, # lto
3893
4177
  ctypes.c_char_p, # function name
3894
4178
  ctypes.c_int, # num include dirs
@@ -3907,9 +4191,9 @@ class Runtime:
3907
4191
  ctypes.c_int, # c_arrangement
3908
4192
  ctypes.c_int, # num threads
3909
4193
  ]
3910
- self.core.cuda_compile_dot.restype = ctypes.c_bool
4194
+ self.core.wp_cuda_compile_dot.restype = ctypes.c_bool
3911
4195
 
3912
- self.core.cuda_compile_solver.argtypes = [
4196
+ self.core.wp_cuda_compile_solver.argtypes = [
3913
4197
  ctypes.c_char_p, # universal fatbin
3914
4198
  ctypes.c_char_p, # lto
3915
4199
  ctypes.c_char_p, # function name
@@ -3929,24 +4213,24 @@ class Runtime:
3929
4213
  ctypes.c_int, # fill_mode
3930
4214
  ctypes.c_int, # num threads
3931
4215
  ]
3932
- self.core.cuda_compile_solver.restype = ctypes.c_bool
4216
+ self.core.wp_cuda_compile_solver.restype = ctypes.c_bool
3933
4217
 
3934
- self.core.cuda_load_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
3935
- self.core.cuda_load_module.restype = ctypes.c_void_p
4218
+ self.core.wp_cuda_load_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
4219
+ self.core.wp_cuda_load_module.restype = ctypes.c_void_p
3936
4220
 
3937
- self.core.cuda_unload_module.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3938
- self.core.cuda_unload_module.restype = None
4221
+ self.core.wp_cuda_unload_module.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
4222
+ self.core.wp_cuda_unload_module.restype = None
3939
4223
 
3940
- self.core.cuda_get_kernel.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_char_p]
3941
- self.core.cuda_get_kernel.restype = ctypes.c_void_p
4224
+ self.core.wp_cuda_get_kernel.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_char_p]
4225
+ self.core.wp_cuda_get_kernel.restype = ctypes.c_void_p
3942
4226
 
3943
- self.core.cuda_get_max_shared_memory.argtypes = [ctypes.c_void_p]
3944
- self.core.cuda_get_max_shared_memory.restype = ctypes.c_int
4227
+ self.core.wp_cuda_get_max_shared_memory.argtypes = [ctypes.c_void_p]
4228
+ self.core.wp_cuda_get_max_shared_memory.restype = ctypes.c_int
3945
4229
 
3946
- self.core.cuda_configure_kernel_shared_memory.argtypes = [ctypes.c_void_p, ctypes.c_int]
3947
- self.core.cuda_configure_kernel_shared_memory.restype = ctypes.c_bool
4230
+ self.core.wp_cuda_configure_kernel_shared_memory.argtypes = [ctypes.c_void_p, ctypes.c_int]
4231
+ self.core.wp_cuda_configure_kernel_shared_memory.restype = ctypes.c_bool
3948
4232
 
3949
- self.core.cuda_launch_kernel.argtypes = [
4233
+ self.core.wp_cuda_launch_kernel.argtypes = [
3950
4234
  ctypes.c_void_p,
3951
4235
  ctypes.c_void_p,
3952
4236
  ctypes.c_size_t,
@@ -3956,54 +4240,54 @@ class Runtime:
3956
4240
  ctypes.POINTER(ctypes.c_void_p),
3957
4241
  ctypes.c_void_p,
3958
4242
  ]
3959
- self.core.cuda_launch_kernel.restype = ctypes.c_size_t
4243
+ self.core.wp_cuda_launch_kernel.restype = ctypes.c_size_t
3960
4244
 
3961
- self.core.cuda_graphics_map.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3962
- self.core.cuda_graphics_map.restype = None
3963
- self.core.cuda_graphics_unmap.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3964
- self.core.cuda_graphics_unmap.restype = None
3965
- self.core.cuda_graphics_device_ptr_and_size.argtypes = [
4245
+ self.core.wp_cuda_graphics_map.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
4246
+ self.core.wp_cuda_graphics_map.restype = None
4247
+ self.core.wp_cuda_graphics_unmap.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
4248
+ self.core.wp_cuda_graphics_unmap.restype = None
4249
+ self.core.wp_cuda_graphics_device_ptr_and_size.argtypes = [
3966
4250
  ctypes.c_void_p,
3967
4251
  ctypes.c_void_p,
3968
4252
  ctypes.POINTER(ctypes.c_uint64),
3969
4253
  ctypes.POINTER(ctypes.c_size_t),
3970
4254
  ]
3971
- self.core.cuda_graphics_device_ptr_and_size.restype = None
3972
- self.core.cuda_graphics_register_gl_buffer.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint]
3973
- self.core.cuda_graphics_register_gl_buffer.restype = ctypes.c_void_p
3974
- self.core.cuda_graphics_unregister_resource.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3975
- self.core.cuda_graphics_unregister_resource.restype = None
3976
-
3977
- self.core.cuda_timing_begin.argtypes = [ctypes.c_int]
3978
- self.core.cuda_timing_begin.restype = None
3979
- self.core.cuda_timing_get_result_count.argtypes = []
3980
- self.core.cuda_timing_get_result_count.restype = int
3981
- self.core.cuda_timing_end.argtypes = []
3982
- self.core.cuda_timing_end.restype = None
3983
-
3984
- self.core.graph_coloring.argtypes = [
4255
+ self.core.wp_cuda_graphics_device_ptr_and_size.restype = None
4256
+ self.core.wp_cuda_graphics_register_gl_buffer.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint]
4257
+ self.core.wp_cuda_graphics_register_gl_buffer.restype = ctypes.c_void_p
4258
+ self.core.wp_cuda_graphics_unregister_resource.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
4259
+ self.core.wp_cuda_graphics_unregister_resource.restype = None
4260
+
4261
+ self.core.wp_cuda_timing_begin.argtypes = [ctypes.c_int]
4262
+ self.core.wp_cuda_timing_begin.restype = None
4263
+ self.core.wp_cuda_timing_get_result_count.argtypes = []
4264
+ self.core.wp_cuda_timing_get_result_count.restype = int
4265
+ self.core.wp_cuda_timing_end.argtypes = []
4266
+ self.core.wp_cuda_timing_end.restype = None
4267
+
4268
+ self.core.wp_graph_coloring.argtypes = [
3985
4269
  ctypes.c_int,
3986
4270
  warp.types.array_t,
3987
4271
  ctypes.c_int,
3988
4272
  warp.types.array_t,
3989
4273
  ]
3990
- self.core.graph_coloring.restype = ctypes.c_int
4274
+ self.core.wp_graph_coloring.restype = ctypes.c_int
3991
4275
 
3992
- self.core.balance_coloring.argtypes = [
4276
+ self.core.wp_balance_coloring.argtypes = [
3993
4277
  ctypes.c_int,
3994
4278
  warp.types.array_t,
3995
4279
  ctypes.c_int,
3996
4280
  ctypes.c_float,
3997
4281
  warp.types.array_t,
3998
4282
  ]
3999
- self.core.balance_coloring.restype = ctypes.c_float
4283
+ self.core.wp_balance_coloring.restype = ctypes.c_float
4000
4284
 
4001
- self.core.init.restype = ctypes.c_int
4285
+ self.core.wp_init.restype = ctypes.c_int
4002
4286
 
4003
4287
  except AttributeError as e:
4004
4288
  raise RuntimeError(f"Setting C-types for {warp_lib} failed. It may need rebuilding.") from e
4005
4289
 
4006
- error = self.core.init()
4290
+ error = self.core.wp_init()
4007
4291
 
4008
4292
  if error != 0:
4009
4293
  raise Exception("Warp initialization failed")
@@ -4019,8 +4303,8 @@ class Runtime:
4019
4303
  self.device_map["cpu"] = self.cpu_device
4020
4304
  self.context_map[None] = self.cpu_device
4021
4305
 
4022
- self.is_cuda_enabled = bool(self.core.is_cuda_enabled())
4023
- self.is_cuda_compatibility_enabled = bool(self.core.is_cuda_compatibility_enabled())
4306
+ self.is_cuda_enabled = bool(self.core.wp_is_cuda_enabled())
4307
+ self.is_cuda_compatibility_enabled = bool(self.core.wp_is_cuda_compatibility_enabled())
4024
4308
 
4025
4309
  self.toolkit_version = None # CTK version used to build the core lib
4026
4310
  self.driver_version = None # installed driver version
@@ -4033,12 +4317,15 @@ class Runtime:
4033
4317
 
4034
4318
  if self.is_cuda_enabled:
4035
4319
  # get CUDA Toolkit and driver versions
4036
- toolkit_version = self.core.cuda_toolkit_version()
4037
- driver_version = self.core.cuda_driver_version()
4038
-
4039
- # save versions as tuples, e.g., (12, 4)
4320
+ toolkit_version = self.core.wp_cuda_toolkit_version()
4040
4321
  self.toolkit_version = (toolkit_version // 1000, (toolkit_version % 1000) // 10)
4041
- self.driver_version = (driver_version // 1000, (driver_version % 1000) // 10)
4322
+
4323
+ if self.core.wp_cuda_driver_is_initialized():
4324
+ # save versions as tuples, e.g., (12, 4)
4325
+ driver_version = self.core.wp_cuda_driver_version()
4326
+ self.driver_version = (driver_version // 1000, (driver_version % 1000) // 10)
4327
+ else:
4328
+ self.driver_version = None
4042
4329
 
4043
4330
  # determine minimum required driver version
4044
4331
  if self.is_cuda_compatibility_enabled:
@@ -4052,18 +4339,18 @@ class Runtime:
4052
4339
  self.min_driver_version = self.toolkit_version
4053
4340
 
4054
4341
  # determine if the installed driver is sufficient
4055
- if self.driver_version >= self.min_driver_version:
4342
+ if self.driver_version is not None and self.driver_version >= self.min_driver_version:
4056
4343
  # get all architectures supported by NVRTC
4057
- num_archs = self.core.nvrtc_supported_arch_count()
4344
+ num_archs = self.core.wp_nvrtc_supported_arch_count()
4058
4345
  if num_archs > 0:
4059
4346
  archs = (ctypes.c_int * num_archs)()
4060
- self.core.nvrtc_supported_archs(archs)
4347
+ self.core.wp_nvrtc_supported_archs(archs)
4061
4348
  self.nvrtc_supported_archs = set(archs)
4062
4349
  else:
4063
4350
  self.nvrtc_supported_archs = set()
4064
4351
 
4065
4352
  # get CUDA device count
4066
- cuda_device_count = self.core.cuda_device_get_count()
4353
+ cuda_device_count = self.core.wp_cuda_device_get_count()
4067
4354
 
4068
4355
  # register primary CUDA devices
4069
4356
  for i in range(cuda_device_count):
@@ -4080,7 +4367,7 @@ class Runtime:
4080
4367
  # set default device
4081
4368
  if cuda_device_count > 0:
4082
4369
  # stick with the current cuda context, if one is bound
4083
- initial_context = self.core.cuda_context_get_current()
4370
+ initial_context = self.core.wp_cuda_context_get_current()
4084
4371
  if initial_context is not None:
4085
4372
  self.set_default_device("cuda")
4086
4373
  # if this is a non-primary context that was just registered, update the device count
@@ -4133,6 +4420,8 @@ class Runtime:
4133
4420
  if not self.is_cuda_enabled:
4134
4421
  # Warp was compiled without CUDA support
4135
4422
  greeting.append(" CUDA not enabled in this build")
4423
+ elif self.driver_version is None:
4424
+ greeting.append(" CUDA driver not found or failed to initialize")
4136
4425
  elif self.driver_version < self.min_driver_version:
4137
4426
  # insufficient CUDA driver version
4138
4427
  greeting.append(
@@ -4176,7 +4465,7 @@ class Runtime:
4176
4465
  access_vector.append(1)
4177
4466
  else:
4178
4467
  peer_device = self.cuda_devices[j]
4179
- can_access = self.core.cuda_is_peer_access_supported(
4468
+ can_access = self.core.wp_cuda_is_peer_access_supported(
4180
4469
  target_device.ordinal, peer_device.ordinal
4181
4470
  )
4182
4471
  access_vector.append(can_access)
@@ -4201,7 +4490,7 @@ class Runtime:
4201
4490
 
4202
4491
  if cuda_device_count > 0:
4203
4492
  # ensure initialization did not change the initial context (e.g. querying available memory)
4204
- self.core.cuda_context_set_current(initial_context)
4493
+ self.core.wp_cuda_context_set_current(initial_context)
4205
4494
 
4206
4495
  # detect possible misconfiguration of the system
4207
4496
  devices_without_uva = []
@@ -4229,7 +4518,7 @@ class Runtime:
4229
4518
  elif self.is_cuda_enabled:
4230
4519
  # Report a warning about insufficient driver version. The warning should appear even in quiet mode
4231
4520
  # when the greeting message is suppressed. Also try to provide guidance for resolving the situation.
4232
- if self.driver_version < self.min_driver_version:
4521
+ if self.driver_version is not None and self.driver_version < self.min_driver_version:
4233
4522
  msg = []
4234
4523
  msg.append("\n Insufficient CUDA driver version.")
4235
4524
  msg.append(
@@ -4240,7 +4529,7 @@ class Runtime:
4240
4529
  warp.utils.warn("\n ".join(msg))
4241
4530
 
4242
4531
  def get_error_string(self):
4243
- return self.core.get_error_string().decode("utf-8")
4532
+ return self.core.wp_get_error_string().decode("utf-8")
4244
4533
 
4245
4534
  def load_dll(self, dll_path):
4246
4535
  try:
@@ -4276,21 +4565,21 @@ class Runtime:
4276
4565
  self.default_device = self.get_device(ident)
4277
4566
 
4278
4567
  def get_current_cuda_device(self) -> Device:
4279
- current_context = self.core.cuda_context_get_current()
4568
+ current_context = self.core.wp_cuda_context_get_current()
4280
4569
  if current_context is not None:
4281
4570
  current_device = self.context_map.get(current_context)
4282
4571
  if current_device is not None:
4283
4572
  # this is a known device
4284
4573
  return current_device
4285
- elif self.core.cuda_context_is_primary(current_context):
4574
+ elif self.core.wp_cuda_context_is_primary(current_context):
4286
4575
  # this is a primary context that we haven't used yet
4287
- ordinal = self.core.cuda_context_get_device_ordinal(current_context)
4576
+ ordinal = self.core.wp_cuda_context_get_device_ordinal(current_context)
4288
4577
  device = self.cuda_devices[ordinal]
4289
4578
  self.context_map[current_context] = device
4290
4579
  return device
4291
4580
  else:
4292
4581
  # this is an unseen non-primary context, register it as a new device with a unique alias
4293
- ordinal = self.core.cuda_context_get_device_ordinal(current_context)
4582
+ ordinal = self.core.wp_cuda_context_get_device_ordinal(current_context)
4294
4583
  alias = f"cuda:{ordinal}.{self.cuda_custom_context_count[ordinal]}"
4295
4584
  self.cuda_custom_context_count[ordinal] += 1
4296
4585
  return self.map_cuda_device(alias, current_context)
@@ -4313,7 +4602,7 @@ class Runtime:
4313
4602
 
4314
4603
  def map_cuda_device(self, alias, context=None) -> Device:
4315
4604
  if context is None:
4316
- context = self.core.cuda_context_get_current()
4605
+ context = self.core.wp_cuda_context_get_current()
4317
4606
  if context is None:
4318
4607
  raise RuntimeError(f"Unable to determine CUDA context for device alias '{alias}'")
4319
4608
 
@@ -4335,10 +4624,10 @@ class Runtime:
4335
4624
  # it's an unmapped context
4336
4625
 
4337
4626
  # get the device ordinal
4338
- ordinal = self.core.cuda_context_get_device_ordinal(context)
4627
+ ordinal = self.core.wp_cuda_context_get_device_ordinal(context)
4339
4628
 
4340
4629
  # check if this is a primary context (we could get here if it's a device that hasn't been used yet)
4341
- if self.core.cuda_context_is_primary(context):
4630
+ if self.core.wp_cuda_context_is_primary(context):
4342
4631
  # rename the device
4343
4632
  device = self.cuda_primary_devices[ordinal]
4344
4633
  return self.rename_device(device, alias)
@@ -4369,7 +4658,7 @@ class Runtime:
4369
4658
  if not device.is_cuda:
4370
4659
  return
4371
4660
 
4372
- err = self.core.cuda_context_check(device.context)
4661
+ err = self.core.wp_cuda_context_check(device.context)
4373
4662
  if err != 0:
4374
4663
  raise RuntimeError(f"CUDA error detected: {err}")
4375
4664
 
@@ -4401,7 +4690,7 @@ def is_cuda_driver_initialized() -> bool:
4401
4690
  """
4402
4691
  init()
4403
4692
 
4404
- return runtime.core.cuda_driver_is_initialized()
4693
+ return runtime.core.wp_cuda_driver_is_initialized()
4405
4694
 
4406
4695
 
4407
4696
  def get_devices() -> list[Device]:
@@ -4609,7 +4898,7 @@ def set_mempool_release_threshold(device: Devicelike, threshold: int | float) ->
4609
4898
  elif threshold > 0 and threshold <= 1:
4610
4899
  threshold = int(threshold * device.total_memory)
4611
4900
 
4612
- if not runtime.core.cuda_device_set_mempool_release_threshold(device.ordinal, threshold):
4901
+ if not runtime.core.wp_cuda_device_set_mempool_release_threshold(device.ordinal, threshold):
4613
4902
  raise RuntimeError(f"Failed to set memory pool release threshold for device {device}")
4614
4903
 
4615
4904
 
@@ -4639,7 +4928,7 @@ def get_mempool_release_threshold(device: Devicelike = None) -> int:
4639
4928
  if not device.is_mempool_supported:
4640
4929
  raise RuntimeError(f"Device {device} does not support memory pools")
4641
4930
 
4642
- return runtime.core.cuda_device_get_mempool_release_threshold(device.ordinal)
4931
+ return runtime.core.wp_cuda_device_get_mempool_release_threshold(device.ordinal)
4643
4932
 
4644
4933
 
4645
4934
  def get_mempool_used_mem_current(device: Devicelike = None) -> int:
@@ -4668,7 +4957,7 @@ def get_mempool_used_mem_current(device: Devicelike = None) -> int:
4668
4957
  if not device.is_mempool_supported:
4669
4958
  raise RuntimeError(f"Device {device} does not support memory pools")
4670
4959
 
4671
- return runtime.core.cuda_device_get_mempool_used_mem_current(device.ordinal)
4960
+ return runtime.core.wp_cuda_device_get_mempool_used_mem_current(device.ordinal)
4672
4961
 
4673
4962
 
4674
4963
  def get_mempool_used_mem_high(device: Devicelike = None) -> int:
@@ -4697,7 +4986,7 @@ def get_mempool_used_mem_high(device: Devicelike = None) -> int:
4697
4986
  if not device.is_mempool_supported:
4698
4987
  raise RuntimeError(f"Device {device} does not support memory pools")
4699
4988
 
4700
- return runtime.core.cuda_device_get_mempool_used_mem_high(device.ordinal)
4989
+ return runtime.core.wp_cuda_device_get_mempool_used_mem_high(device.ordinal)
4701
4990
 
4702
4991
 
4703
4992
  def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike) -> bool:
@@ -4718,7 +5007,7 @@ def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike)
4718
5007
  if not target_device.is_cuda or not peer_device.is_cuda:
4719
5008
  return False
4720
5009
 
4721
- return bool(runtime.core.cuda_is_peer_access_supported(target_device.ordinal, peer_device.ordinal))
5010
+ return bool(runtime.core.wp_cuda_is_peer_access_supported(target_device.ordinal, peer_device.ordinal))
4722
5011
 
4723
5012
 
4724
5013
  def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike) -> bool:
@@ -4739,7 +5028,7 @@ def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike) -
4739
5028
  if not target_device.is_cuda or not peer_device.is_cuda:
4740
5029
  return False
4741
5030
 
4742
- return bool(runtime.core.cuda_is_peer_access_enabled(target_device.context, peer_device.context))
5031
+ return bool(runtime.core.wp_cuda_is_peer_access_enabled(target_device.context, peer_device.context))
4743
5032
 
4744
5033
 
4745
5034
  def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool) -> None:
@@ -4769,7 +5058,7 @@ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike,
4769
5058
  else:
4770
5059
  return
4771
5060
 
4772
- if not runtime.core.cuda_set_peer_access_enabled(target_device.context, peer_device.context, int(enable)):
5061
+ if not runtime.core.wp_cuda_set_peer_access_enabled(target_device.context, peer_device.context, int(enable)):
4773
5062
  action = "enable" if enable else "disable"
4774
5063
  raise RuntimeError(f"Failed to {action} peer access from device {peer_device} to device {target_device}")
4775
5064
 
@@ -4810,7 +5099,7 @@ def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike
4810
5099
  if not peer_device.is_cuda or not target_device.is_cuda or not target_device.is_mempool_supported:
4811
5100
  return False
4812
5101
 
4813
- return bool(runtime.core.cuda_is_mempool_access_enabled(target_device.ordinal, peer_device.ordinal))
5102
+ return bool(runtime.core.wp_cuda_is_mempool_access_enabled(target_device.ordinal, peer_device.ordinal))
4814
5103
 
4815
5104
 
4816
5105
  def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool) -> None:
@@ -4843,7 +5132,7 @@ def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelik
4843
5132
  else:
4844
5133
  return
4845
5134
 
4846
- if not runtime.core.cuda_set_mempool_access_enabled(target_device.ordinal, peer_device.ordinal, int(enable)):
5135
+ if not runtime.core.wp_cuda_set_mempool_access_enabled(target_device.ordinal, peer_device.ordinal, int(enable)):
4847
5136
  action = "enable" if enable else "disable"
4848
5137
  raise RuntimeError(f"Failed to {action} memory pool access from device {peer_device} to device {target_device}")
4849
5138
 
@@ -4924,7 +5213,7 @@ def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: bo
4924
5213
  if synchronize:
4925
5214
  synchronize_event(end_event)
4926
5215
 
4927
- return runtime.core.cuda_event_elapsed_time(start_event.cuda_event, end_event.cuda_event)
5216
+ return runtime.core.wp_cuda_event_elapsed_time(start_event.cuda_event, end_event.cuda_event)
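A hedged usage sketch for event timing, assuming a CUDA device "cuda:0" and that both events are created with enable_timing=True (required for elapsed-time queries):

import warp as wp

@wp.kernel
def axpy(x: wp.array(dtype=float), y: wp.array(dtype=float), a: float):
    i = wp.tid()
    y[i] = a * x[i] + y[i]

wp.init()
with wp.ScopedDevice("cuda:0"):
    x = wp.ones(1 << 20, dtype=float)
    y = wp.zeros_like(x)
    start = wp.Event(enable_timing=True)
    end = wp.Event(enable_timing=True)

    wp.record_event(start)
    wp.launch(axpy, dim=x.size, inputs=[x, y, 2.0])
    wp.record_event(end)

    # synchronize=True (the default) waits on the end event before reading the timer
    print(wp.get_event_elapsed_time(start, end), "ms")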
4928
5217
 
4929
5218
 
4930
5219
  def wait_stream(other_stream: Stream, event: Event | None = None):
@@ -5018,7 +5307,7 @@ class RegisteredGLBuffer:
5018
5307
  self.context = self.device.context
5019
5308
  self.flags = flags
5020
5309
  self.fallback_to_copy = fallback_to_copy
5021
- self.resource = runtime.core.cuda_graphics_register_gl_buffer(self.context, gl_buffer_id, flags)
5310
+ self.resource = runtime.core.wp_cuda_graphics_register_gl_buffer(self.context, gl_buffer_id, flags)
5022
5311
  if self.resource is None:
5023
5312
  if self.fallback_to_copy:
5024
5313
  self.warp_buffer = None
@@ -5037,7 +5326,7 @@ class RegisteredGLBuffer:
5037
5326
 
5038
5327
  # use CUDA context guard to avoid side effects during garbage collection
5039
5328
  with self.device.context_guard:
5040
- runtime.core.cuda_graphics_unregister_resource(self.context, self.resource)
5329
+ runtime.core.wp_cuda_graphics_unregister_resource(self.context, self.resource)
5041
5330
 
5042
5331
  def map(self, dtype, shape) -> warp.array:
5043
5332
  """Map the OpenGL buffer to a Warp array.
@@ -5050,10 +5339,10 @@ class RegisteredGLBuffer:
5050
5339
  A Warp array object representing the mapped OpenGL buffer.
5051
5340
  """
5052
5341
  if self.resource is not None:
5053
- runtime.core.cuda_graphics_map(self.context, self.resource)
5342
+ runtime.core.wp_cuda_graphics_map(self.context, self.resource)
5054
5343
  ptr = ctypes.c_uint64(0)
5055
5344
  size = ctypes.c_size_t(0)
5056
- runtime.core.cuda_graphics_device_ptr_and_size(
5345
+ runtime.core.wp_cuda_graphics_device_ptr_and_size(
5057
5346
  self.context, self.resource, ctypes.byref(ptr), ctypes.byref(size)
5058
5347
  )
5059
5348
  return warp.array(ptr=ptr.value, dtype=dtype, shape=shape, device=self.device)
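A hedged sketch of the map/unmap pattern, assuming the caller already owns a valid OpenGL buffer id created in an active GL context (gl_buffer_id and the point count n below are placeholders):

import warp as wp

def process_gl_points(gl_buffer_id: int, n: int, device="cuda:0"):
    buf = wp.RegisteredGLBuffer(gl_buffer_id, device)
    points = buf.map(dtype=wp.vec3, shape=(n,))  # zero-copy view when graphics interop is available
    # ... launch kernels that read or write `points` here ...
    buf.unmap()  # always unmap before OpenGL touches the buffer again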
@@ -5078,7 +5367,7 @@ class RegisteredGLBuffer:
5078
5367
  def unmap(self):
5079
5368
  """Unmap the OpenGL buffer."""
5080
5369
  if self.resource is not None:
5081
- runtime.core.cuda_graphics_unmap(self.context, self.resource)
5370
+ runtime.core.wp_cuda_graphics_unmap(self.context, self.resource)
5082
5371
  elif self.fallback_to_copy:
5083
5372
  if self.warp_buffer is None:
5084
5373
  raise RuntimeError("RegisteredGLBuffer first has to be mapped")
@@ -5434,7 +5723,7 @@ def event_from_ipc_handle(handle, device: Devicelike = None) -> Event:
5434
5723
  raise RuntimeError(f"IPC is not supported on device {device}.")
5435
5724
 
5436
5725
  event = Event(
5437
- device=device, cuda_event=warp.context.runtime.core.cuda_ipc_open_event_handle(device.context, handle)
5726
+ device=device, cuda_event=warp.context.runtime.core.wp_cuda_ipc_open_event_handle(device.context, handle)
5438
5727
  )
5439
5728
  # Events created from IPC handles must be freed with cuEventDestroy
5440
5729
  event.owner = True
@@ -5566,6 +5855,44 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
5566
5855
  ) from e
5567
5856
 
5568
5857
 
5858
+ # invoke a CPU kernel by passing the parameters as a ctypes structure
5859
+ def invoke(kernel, hooks, params: Sequence[Any], adjoint: bool):
5860
+ fields = []
5861
+
5862
+ for i in range(0, len(kernel.adj.args)):
5863
+ arg_name = kernel.adj.args[i].label
5864
+ field = (arg_name, type(params[1 + i])) # skip the first argument, which is the launch bounds
5865
+ fields.append(field)
5866
+
5867
+ ArgsStruct = type("ArgsStruct", (ctypes.Structure,), {"_fields_": fields})
5868
+
5869
+ args = ArgsStruct()
5870
+ for i, field in enumerate(fields):
5871
+ name = field[0]
5872
+ setattr(args, name, params[1 + i])
5873
+
5874
+ if not adjoint:
5875
+ hooks.forward(params[0], ctypes.byref(args))
5876
+
5877
+ # for adjoint kernels the adjoint arguments are passed through a second struct
5878
+ else:
5879
+ adj_fields = []
5880
+
5881
+ for i in range(0, len(kernel.adj.args)):
5882
+ arg_name = kernel.adj.args[i].label
5883
+ field = (arg_name, type(params[1 + len(fields) + i])) # skip the launch bounds and the forward arguments
5884
+ adj_fields.append(field)
5885
+
5886
+ AdjArgsStruct = type("AdjArgsStruct", (ctypes.Structure,), {"_fields_": adj_fields})
5887
+
5888
+ adj_args = AdjArgsStruct()
5889
+ for i, field in enumerate(adj_fields):
5890
+ name = field[0]
5891
+ setattr(adj_args, name, params[1 + len(fields) + i])
5892
+
5893
+ hooks.backward(params[0], ctypes.byref(args), ctypes.byref(adj_args))
5894
+
5895
+
5569
5896
  class Launch:
5570
5897
  """Represents all data required for a kernel launch so that launches can be replayed quickly.
5571
5898
 
@@ -5758,24 +6085,21 @@ class Launch:
5758
6085
  stream: The stream to launch on.
5759
6086
  """
5760
6087
  if self.device.is_cpu:
5761
- if self.adjoint:
5762
- self.hooks.backward(*self.params)
5763
- else:
5764
- self.hooks.forward(*self.params)
6088
+ invoke(self.kernel, self.hooks, self.params, self.adjoint)
5765
6089
  else:
5766
6090
  if stream is None:
5767
6091
  stream = self.device.stream
5768
6092
 
5769
6093
  # If the stream is capturing, we retain the CUDA module so that it doesn't get unloaded
5770
6094
  # before the captured graph is released.
5771
- if len(runtime.captures) > 0 and runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
5772
- capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
6095
+ if len(runtime.captures) > 0 and runtime.core.wp_cuda_stream_is_capturing(stream.cuda_stream):
6096
+ capture_id = runtime.core.wp_cuda_stream_get_capture_id(stream.cuda_stream)
5773
6097
  graph = runtime.captures.get(capture_id)
5774
6098
  if graph is not None:
5775
6099
  graph.retain_module_exec(self.module_exec)
5776
6100
 
5777
6101
  if self.adjoint:
5778
- runtime.core.cuda_launch_kernel(
6102
+ runtime.core.wp_cuda_launch_kernel(
5779
6103
  self.device.context,
5780
6104
  self.hooks.backward,
5781
6105
  self.bounds.size,
@@ -5786,7 +6110,7 @@ class Launch:
5786
6110
  stream.cuda_stream,
5787
6111
  )
5788
6112
  else:
5789
- runtime.core.cuda_launch_kernel(
6113
+ runtime.core.wp_cuda_launch_kernel(
5790
6114
  self.device.context,
5791
6115
  self.hooks.forward,
5792
6116
  self.bounds.size,
@@ -5905,7 +6229,7 @@ def launch(
5905
6229
  # late bind
5906
6230
  hooks = module_exec.get_kernel_hooks(kernel)
5907
6231
 
5908
- pack_args(fwd_args, params)
6232
+ pack_args(fwd_args, params, adjoint=False)
5909
6233
  pack_args(adj_args, params, adjoint=True)
5910
6234
 
5911
6235
  # run kernel
@@ -5916,38 +6240,25 @@ def launch(
5916
6240
  f"Failed to find backward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
5917
6241
  )
5918
6242
 
5919
- if record_cmd:
5920
- launch = Launch(
5921
- kernel=kernel,
5922
- hooks=hooks,
5923
- params=params,
5924
- params_addr=None,
5925
- bounds=bounds,
5926
- device=device,
5927
- adjoint=adjoint,
5928
- )
5929
- return launch
5930
- hooks.backward(*params)
5931
-
5932
6243
  else:
5933
6244
  if hooks.forward is None:
5934
6245
  raise RuntimeError(
5935
6246
  f"Failed to find forward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
5936
6247
  )
5937
6248
 
5938
- if record_cmd:
5939
- launch = Launch(
5940
- kernel=kernel,
5941
- hooks=hooks,
5942
- params=params,
5943
- params_addr=None,
5944
- bounds=bounds,
5945
- device=device,
5946
- adjoint=adjoint,
5947
- )
5948
- return launch
5949
- else:
5950
- hooks.forward(*params)
6249
+ if record_cmd:
6250
+ launch = Launch(
6251
+ kernel=kernel,
6252
+ hooks=hooks,
6253
+ params=params,
6254
+ params_addr=None,
6255
+ bounds=bounds,
6256
+ device=device,
6257
+ adjoint=adjoint,
6258
+ )
6259
+ return launch
6260
+
6261
+ invoke(kernel, hooks, params, adjoint)
5951
6262
 
5952
6263
  else:
5953
6264
  kernel_args = [ctypes.c_void_p(ctypes.addressof(x)) for x in params]
@@ -5958,8 +6269,8 @@ def launch(
5958
6269
 
5959
6270
  # If the stream is capturing, we retain the CUDA module so that it doesn't get unloaded
5960
6271
  # before the captured graph is released.
5961
- if len(runtime.captures) > 0 and runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
5962
- capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
6272
+ if len(runtime.captures) > 0 and runtime.core.wp_cuda_stream_is_capturing(stream.cuda_stream):
6273
+ capture_id = runtime.core.wp_cuda_stream_get_capture_id(stream.cuda_stream)
5963
6274
  graph = runtime.captures.get(capture_id)
5964
6275
  if graph is not None:
5965
6276
  graph.retain_module_exec(module_exec)
@@ -5984,7 +6295,7 @@ def launch(
5984
6295
  )
5985
6296
  return launch
5986
6297
  else:
5987
- runtime.core.cuda_launch_kernel(
6298
+ runtime.core.wp_cuda_launch_kernel(
5988
6299
  device.context,
5989
6300
  hooks.backward,
5990
6301
  bounds.size,
@@ -6015,7 +6326,7 @@ def launch(
6015
6326
  return launch
6016
6327
  else:
6017
6328
  # launch
6018
- runtime.core.cuda_launch_kernel(
6329
+ runtime.core.wp_cuda_launch_kernel(
6019
6330
  device.context,
6020
6331
  hooks.forward,
6021
6332
  bounds.size,
@@ -6117,7 +6428,7 @@ def synchronize():
6117
6428
 
6118
6429
  if is_cuda_driver_initialized():
6119
6430
  # save the original context to avoid side effects
6120
- saved_context = runtime.core.cuda_context_get_current()
6431
+ saved_context = runtime.core.wp_cuda_context_get_current()
6121
6432
 
6122
6433
  # TODO: only synchronize devices that have outstanding work
6123
6434
  for device in runtime.cuda_devices:
@@ -6126,10 +6437,10 @@ def synchronize():
6126
6437
  if device.is_capturing:
6127
6438
  raise RuntimeError(f"Cannot synchronize device {device} while graph capture is active")
6128
6439
 
6129
- runtime.core.cuda_context_synchronize(device.context)
6440
+ runtime.core.wp_cuda_context_synchronize(device.context)
6130
6441
 
6131
6442
  # restore the original context to avoid side effects
6132
- runtime.core.cuda_context_set_current(saved_context)
6443
+ runtime.core.wp_cuda_context_set_current(saved_context)
6133
6444
 
6134
6445
 
6135
6446
  def synchronize_device(device: Devicelike = None):
@@ -6147,7 +6458,7 @@ def synchronize_device(device: Devicelike = None):
6147
6458
  if device.is_capturing:
6148
6459
  raise RuntimeError(f"Cannot synchronize device {device} while graph capture is active")
6149
6460
 
6150
- runtime.core.cuda_context_synchronize(device.context)
6461
+ runtime.core.wp_cuda_context_synchronize(device.context)
6151
6462
 
6152
6463
 
6153
6464
  def synchronize_stream(stream_or_device: Stream | Devicelike | None = None):
@@ -6165,7 +6476,7 @@ def synchronize_stream(stream_or_device: Stream | Devicelike | None = None):
6165
6476
  else:
6166
6477
  stream = runtime.get_device(stream_or_device).stream
6167
6478
 
6168
- runtime.core.cuda_stream_synchronize(stream.cuda_stream)
6479
+ runtime.core.wp_cuda_stream_synchronize(stream.cuda_stream)
6169
6480
 
6170
6481
 
6171
6482
  def synchronize_event(event: Event):
@@ -6177,20 +6488,25 @@ def synchronize_event(event: Event):
6177
6488
  event: Event to wait for.
6178
6489
  """
6179
6490
 
6180
- runtime.core.cuda_event_synchronize(event.cuda_event)
6491
+ runtime.core.wp_cuda_event_synchronize(event.cuda_event)
6181
6492
 
6182
6493
 
6183
- def force_load(device: Device | str | list[Device] | list[str] | None = None, modules: list[Module] | None = None):
6494
+ def force_load(
6495
+ device: Device | str | list[Device] | list[str] | None = None,
6496
+ modules: list[Module] | None = None,
6497
+ block_dim: int | None = None,
6498
+ ):
6184
6499
  """Force user-defined kernels to be compiled and loaded
6185
6500
 
6186
6501
  Args:
6187
6502
  device: The device or list of devices to load the modules on. If None, load on all devices.
6188
6503
  modules: List of modules to load. If None, load all imported modules.
6504
+ block_dim: The number of threads per block (always 1 for "cpu" devices).
6189
6505
  """
6190
6506
 
6191
6507
  if is_cuda_driver_initialized():
6192
6508
  # save original context to avoid side effects
6193
- saved_context = runtime.core.cuda_context_get_current()
6509
+ saved_context = runtime.core.wp_cuda_context_get_current()
6194
6510
 
6195
6511
  if device is None:
6196
6512
  devices = get_devices()
@@ -6204,22 +6520,26 @@ def force_load(device: Device | str | list[Device] | list[str] | None = None, mo
6204
6520
 
6205
6521
  for d in devices:
6206
6522
  for m in modules:
6207
- m.load(d)
6523
+ m.load(d, block_dim=block_dim)
6208
6524
 
6209
6525
  if is_cuda_available():
6210
6526
  # restore original context to avoid side effects
6211
- runtime.core.cuda_context_set_current(saved_context)
6527
+ runtime.core.wp_cuda_context_set_current(saved_context)
6212
6528
 
6213
6529
 
6214
6530
  def load_module(
6215
- module: Module | types.ModuleType | str | None = None, device: Device | str | None = None, recursive: bool = False
6531
+ module: Module | types.ModuleType | str | None = None,
6532
+ device: Device | str | None = None,
6533
+ recursive: bool = False,
6534
+ block_dim: int | None = None,
6216
6535
  ):
6217
- """Force user-defined module to be compiled and loaded
6536
+ """Force a user-defined module to be compiled and loaded
6218
6537
 
6219
6538
  Args:
6220
6539
  module: The module to load. If None, load the current module.
6221
6540
  device: The device to load the modules on. If None, load on all devices.
6222
6541
  recursive: Whether to load submodules. E.g., if the given module is `warp.sim`, this will also load `warp.sim.model`, `warp.sim.articulation`, etc.
6542
+ block_dim: The number of threads per block (always 1 for "cpu" devices).
6223
6543
 
6224
6544
  Note: A module must be imported before it can be loaded by this function.
6225
6545
  """
@@ -6240,9 +6560,13 @@ def load_module(
6240
6560
  modules = []
6241
6561
 
6242
6562
  # add the given module, if found
6243
- m = user_modules.get(module_name)
6244
- if m is not None:
6245
- modules.append(m)
6563
+ if isinstance(module, Module):
6564
+ # this ensures that we can load "unique" or procedural modules, which aren't added to `user_modules` by name
6565
+ modules.append(module)
6566
+ else:
6567
+ m = user_modules.get(module_name)
6568
+ if m is not None:
6569
+ modules.append(m)
6246
6570
 
6247
6571
  # add submodules, if recursive
6248
6572
  if recursive:
@@ -6251,7 +6575,203 @@ def load_module(
6251
6575
  if name.startswith(prefix):
6252
6576
  modules.append(mod)
6253
6577
 
6254
- force_load(device=device, modules=modules)
6578
+ force_load(device=device, modules=modules, block_dim=block_dim)
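A short usage sketch, assuming the calling script defines at least one Warp kernel and a CUDA device "cuda:0" is present; block_dim is the new per-module thread-block size introduced above and is ignored (treated as 1) on CPU devices.

import warp as wp

@wp.kernel
def scale(a: wp.array(dtype=float), s: float):
    i = wp.tid()
    a[i] = a[i] * s

wp.init()
# compile and load the current module ahead of the first launch, with 128 threads per block
wp.load_module(device="cuda:0", block_dim=128)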
6579
+
6580
+
6581
+ def _resolve_module(module: Module | types.ModuleType | str) -> Module:
6582
+ """Resolve a module from a string, Module, or types.ModuleType.
6583
+
6584
+ Args:
6585
+ module: The module to resolve.
6586
+
6587
+ Returns:
6588
+ The resolved module.
6589
+
6590
+ Raises:
6591
+ TypeError: If the module argument is not a Module, a types.ModuleType, or a string.
6592
+ """
6593
+
6594
+ if isinstance(module, str):
6595
+ module_object = get_module(module)
6596
+ elif isinstance(module, Module):
6597
+ module_object = module
6598
+ elif isinstance(module, types.ModuleType):
6599
+ module_object = get_module(module.__name__)
6600
+ else:
6601
+ raise TypeError(f"Argument 'module' must be a Module or a string, got {type(module)}")
6602
+
6603
+ return module_object
6604
+
6605
+
6606
+ def compile_aot_module(
6607
+ module: Module | types.ModuleType | str,
6608
+ device: Device | str | list[Device] | list[str] | None = None,
6609
+ arch: int | Iterable[int] | None = None,
6610
+ module_dir: str | os.PathLike | None = None,
6611
+ use_ptx: bool | None = None,
6612
+ strip_hash: bool | None = None,
6613
+ ) -> None:
6614
+ """Compile a module (ahead of time) for a given device.
6615
+
6616
+ Args:
6617
+ module: The module to compile.
6618
+ device: The device or devices to compile the module for. If ``None``,
6619
+ and ``arch`` is not specified, compile the module for the current device.
6620
+ arch: The architecture or architectures to compile the module for. If ``None``,
6621
+ the architecture to compile for will be inferred from the current device.
6622
+ module_dir: The directory to save the source, meta, and compiled files to.
6623
+ If not specified, the module will be compiled to the default cache directory.
6624
+ use_ptx: Whether to compile the module to PTX. This setting is only used
6625
+ when compiling modules for the GPU. If ``None``, Warp will decide an
6626
+ appropriate setting based on the runtime environment.
6627
+ strip_hash: Whether to strip the hash from the module and kernel names.
6628
+ Setting this value to ``True`` or ``False`` will update the module's
6629
+ ``"strip_hash"`` option. If left at ``None``, the current value will
6630
+ be used.
6631
+
6632
+ Warning: Do not enable ``strip_hash`` for modules that contain generic
6633
+ kernels. Generic kernels compile to multiple overloads, and the
6634
+ per-overload hash is required to distinguish them. Stripping the hash
6635
+ in this case will cause the module to fail to compile.
6636
+
6637
+ Raises:
6638
+ TypeError: If the module argument is not a Module, a types.ModuleType, or a string.
6639
+ """
6640
+
6641
+ if is_cuda_driver_initialized():
6642
+ # save original context to avoid side effects
6643
+ saved_context = runtime.core.wp_cuda_context_get_current()
6644
+
6645
+ module_object = _resolve_module(module)
6646
+
6647
+ if strip_hash is not None:
6648
+ module_object.options["strip_hash"] = strip_hash
6649
+
6650
+ if device is None and arch:
6651
+ # User provided no device, but an arch, so we will not compile for the default device
6652
+ devices = []
6653
+ elif isinstance(device, list):
6654
+ devices = [get_device(device_item) for device_item in device]
6655
+ else:
6656
+ devices = [get_device(device)]
6657
+
6658
+ for d in devices:
6659
+ module_object.compile(d, module_dir, use_ptx=use_ptx)
6660
+
6661
+ if arch:
6662
+ if isinstance(arch, str) or not hasattr(arch, "__iter__"):
6663
+ arch = [arch]
6664
+
6665
+ for arch_value in arch:
6666
+ module_object.compile(None, module_dir, output_arch=arch_value, use_ptx=use_ptx)
6667
+
6668
+ if is_cuda_available():
6669
+ # restore original context to avoid side effects
6670
+ runtime.core.wp_cuda_context_set_current(saved_context)
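A hedged usage sketch of ahead-of-time compilation. The cache directory name and architecture list are illustrative, and the function is called through warp.context since it may not be re-exported at the package top level.

import warp as wp
import warp.context

@wp.kernel
def scale(a: wp.array(dtype=float), s: float):
    i = wp.tid()
    a[i] = a[i] * s

wp.init()
# build CUBINs for SM 8.6 and SM 9.0 into ./aot_cache without loading them
warp.context.compile_aot_module(__name__, arch=[86, 90], module_dir="./aot_cache", use_ptx=False)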
6671
+
6672
+
6673
+ def load_aot_module(
6674
+ module: Module | types.ModuleType | str,
6675
+ device: Device | str | list[Device] | list[str] | None = None,
6676
+ arch: int | None = None,
6677
+ module_dir: str | os.PathLike | None = None,
6678
+ use_ptx: bool | None = None,
6679
+ strip_hash: bool = False,
6680
+ ) -> None:
6681
+ """Load a previously compiled module (ahead of time).
6682
+
6683
+ Args:
6684
+ module: The module to load.
6685
+ device: The device or devices to load the module on. If ``None``,
6686
+ load the module for the current device.
6687
+ arch: The architecture to load the module for on all devices.
6688
+ If ``None``, the architecture to load for will be inferred from the
6689
+ current device.
6690
+ module_dir: The directory to load the module from.
6691
+ If not specified, the module will be loaded from the default cache directory.
6692
+ use_ptx: Whether to load the module from PTX. This setting is only used
6693
+ when loading modules for the GPU. If ``None`` on a CUDA device, Warp will
6694
+ try both PTX and CUBIN (PTX first) and load the first that exists.
6695
+ If neither exists, a ``FileNotFoundError`` is raised listing all
6696
+ attempted paths.
6697
+ strip_hash: Whether to strip the hash from the module and kernel names.
6698
+ Setting this value to ``True`` or ``False`` will update the module's
6699
+ ``"strip_hash"`` option. If left at ``None``, the current value will
6700
+ be used.
6701
+
6702
+ Warning: Do not enable ``strip_hash`` for modules that contain generic
6703
+ kernels. Generic kernels compile to multiple overloads, and the
6704
+ per-overload hash is required to distinguish them. Stripping the hash
6705
+ in this case will cause the module to fail to compile.
6706
+
6707
+ Raises:
6708
+ FileNotFoundError: If no matching binary is found. When ``use_ptx`` is
6709
+ ``None`` on a CUDA device, both PTX and CUBIN candidates are tried
6710
+ before raising.
6711
+ TypeError: If the module argument is not a Module, a types.ModuleType, or a string.
6712
+ """
6713
+
6714
+ if is_cuda_driver_initialized():
6715
+ # save original context to avoid side effects
6716
+ saved_context = runtime.core.wp_cuda_context_get_current()
6717
+
6718
+ if device is None:
6719
+ devices = [runtime.get_device()]
6720
+ elif isinstance(device, list):
6721
+ devices = [get_device(device_item) for device_item in device]
6722
+ else:
6723
+ devices = [get_device(device)]
6724
+
6725
+ module_object = _resolve_module(module)
6726
+
6727
+ if strip_hash is not None:
6728
+ module_object.options["strip_hash"] = strip_hash
6729
+
6730
+ if module_dir is None:
6731
+ module_dir = os.path.join(warp.config.kernel_cache_dir, module_object.get_module_identifier())
6732
+ else:
6733
+ module_dir = os.fspath(module_dir)
6734
+
6735
+ for d in devices:
6736
+ # Identify the files in the cache to load
6737
+ if arch is None:
6738
+ output_arch = module_object.get_compile_arch(d)
6739
+ else:
6740
+ output_arch = arch
6741
+
6742
+ meta_path = os.path.join(module_dir, module_object.get_meta_name())
6743
+
6744
+ # Determine candidate binaries to try
6745
+ tried_paths = []
6746
+ binary_path = None
6747
+ if d.is_cuda and use_ptx is None:
6748
+ candidate_flags = (True, False) # try PTX first, then CUBIN
6749
+ else:
6750
+ candidate_flags = (use_ptx,)
6751
+
6752
+ for candidate_use_ptx in candidate_flags:
6753
+ candidate_path = os.path.join(
6754
+ module_dir, module_object.get_compile_output_name(d, output_arch, candidate_use_ptx)
6755
+ )
6756
+ tried_paths.append(candidate_path)
6757
+ if os.path.exists(candidate_path):
6758
+ binary_path = candidate_path
6759
+ break
6760
+
6761
+ if binary_path is None:
6762
+ raise FileNotFoundError(f"Binary file not found. Tried: {', '.join(tried_paths)}")
6763
+
6764
+ module_object.load(
6765
+ d,
6766
+ block_dim=module_object.options["block_dim"],
6767
+ binary_path=binary_path,
6768
+ output_arch=output_arch,
6769
+ meta_path=meta_path,
6770
+ )
6771
+
6772
+ if is_cuda_available():
6773
+ # restore original context to avoid side effects
6774
+ runtime.core.wp_cuda_context_set_current(saved_context)
6255
6775
 
6256
6776
 
6257
6777
  def set_module_options(options: dict[str, Any], module: Any = None):
@@ -6381,10 +6901,10 @@ def capture_begin(
6381
6901
  if force_module_load:
6382
6902
  force_load(device)
6383
6903
 
6384
- if not runtime.core.cuda_graph_begin_capture(device.context, stream.cuda_stream, int(external)):
6904
+ if not runtime.core.wp_cuda_graph_begin_capture(device.context, stream.cuda_stream, int(external)):
6385
6905
  raise RuntimeError(runtime.get_error_string())
6386
6906
 
6387
- capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
6907
+ capture_id = runtime.core.wp_cuda_stream_get_capture_id(stream.cuda_stream)
6388
6908
  graph = Graph(device, capture_id)
6389
6909
 
6390
6910
  _register_capture(device, stream, graph, capture_id)
@@ -6419,7 +6939,7 @@ def capture_end(device: Devicelike = None, stream: Stream | None = None) -> Grap
6419
6939
 
6420
6940
  # get the graph executable
6421
6941
  g = ctypes.c_void_p()
6422
- result = runtime.core.cuda_graph_end_capture(device.context, stream.cuda_stream, ctypes.byref(g))
6942
+ result = runtime.core.wp_cuda_graph_end_capture(device.context, stream.cuda_stream, ctypes.byref(g))
6423
6943
 
6424
6944
  if not result:
6425
6945
  # A concrete error should've already been reported, so we don't need to go into details here
@@ -6440,7 +6960,7 @@ def capture_debug_dot_print(graph: Graph, path: str, verbose: bool = False):
6440
6960
  path: Path to save the DOT file
6441
6961
  verbose: Whether to include additional debug information in the output
6442
6962
  """
6443
- if not runtime.core.capture_debug_dot_print(graph.graph, path.encode(), 0 if verbose else 1):
6963
+ if not runtime.core.wp_capture_debug_dot_print(graph.graph, path.encode(), 0 if verbose else 1):
6444
6964
  raise RuntimeError(f"Graph debug dot print error: {runtime.get_error_string()}")
6445
6965
 
6446
6966
 
@@ -6473,7 +6993,7 @@ def capture_pause(device: Devicelike = None, stream: Stream | None = None) -> Gr
6473
6993
  _unregister_capture(device, stream, graph)
6474
6994
 
6475
6995
  g = ctypes.c_void_p()
6476
- if not runtime.core.cuda_graph_pause_capture(device.context, stream.cuda_stream, ctypes.byref(g)):
6996
+ if not runtime.core.wp_cuda_graph_pause_capture(device.context, stream.cuda_stream, ctypes.byref(g)):
6477
6997
  raise RuntimeError(runtime.get_error_string())
6478
6998
 
6479
6999
  graph.graph = g
@@ -6490,10 +7010,10 @@ def capture_resume(graph: Graph, device: Devicelike = None, stream: Stream | Non
6490
7010
  raise RuntimeError("Must be a CUDA device")
6491
7011
  stream = device.stream
6492
7012
 
6493
- if not runtime.core.cuda_graph_resume_capture(device.context, stream.cuda_stream, graph.graph):
7013
+ if not runtime.core.wp_cuda_graph_resume_capture(device.context, stream.cuda_stream, graph.graph):
6494
7014
  raise RuntimeError(runtime.get_error_string())
6495
7015
 
6496
- capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
7016
+ capture_id = runtime.core.wp_cuda_stream_get_capture_id(stream.cuda_stream)
6497
7017
  graph.capture_id = capture_id
6498
7018
 
6499
7019
  _register_capture(device, stream, graph, capture_id)
@@ -6576,17 +7096,17 @@ def capture_if(
6576
7096
 
6577
7097
  return
6578
7098
 
6579
- graph.has_conditional = True
6580
-
6581
7099
  # ensure conditional graph nodes are supported
6582
7100
  assert_conditional_graph_support()
6583
7101
 
6584
7102
  # insert conditional node
6585
7103
  graph_on_true = ctypes.c_void_p()
6586
7104
  graph_on_false = ctypes.c_void_p()
6587
- if not runtime.core.cuda_graph_insert_if_else(
7105
+ if not runtime.core.wp_cuda_graph_insert_if_else(
6588
7106
  device.context,
6589
7107
  stream.cuda_stream,
7108
+ device.get_cuda_compile_arch(),
7109
+ device.get_cuda_output_format() == "ptx",
6590
7110
  ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
6591
7111
  None if on_true is None else ctypes.byref(graph_on_true),
6592
7112
  None if on_false is None else ctypes.byref(graph_on_false),
@@ -6607,11 +7127,7 @@ def capture_if(
6607
7127
  if isinstance(on_true, Callable):
6608
7128
  on_true(**kwargs)
6609
7129
  elif isinstance(on_true, Graph):
6610
- if on_true.has_conditional:
6611
- raise RuntimeError(
6612
- "The on_true graph contains conditional nodes, which are not allowed in child graphs"
6613
- )
6614
- if not runtime.core.cuda_graph_insert_child_graph(
7130
+ if not runtime.core.wp_cuda_graph_insert_child_graph(
6615
7131
  device.context,
6616
7132
  stream.cuda_stream,
6617
7133
  on_true.graph,
@@ -6621,6 +7137,10 @@ def capture_if(
6621
7137
  raise TypeError("on_true must be a Callable or a Graph")
6622
7138
  capture_pause(stream=stream)
6623
7139
 
7140
+ # check the if-body graph
7141
+ if not runtime.core.wp_cuda_graph_check_conditional_body(graph_on_true):
7142
+ raise RuntimeError(runtime.get_error_string())
7143
+
6624
7144
  # capture else-graph
6625
7145
  if on_false is not None:
6626
7146
  # temporarily repurpose the main_graph python object such that all dependencies
@@ -6630,11 +7150,7 @@ def capture_if(
6630
7150
  if isinstance(on_false, Callable):
6631
7151
  on_false(**kwargs)
6632
7152
  elif isinstance(on_false, Graph):
6633
- if on_false.has_conditional:
6634
- raise RuntimeError(
6635
- "The on_false graph contains conditional nodes, which are not allowed in child graphs"
6636
- )
6637
- if not runtime.core.cuda_graph_insert_child_graph(
7153
+ if not runtime.core.wp_cuda_graph_insert_child_graph(
6638
7154
  device.context,
6639
7155
  stream.cuda_stream,
6640
7156
  on_false.graph,
@@ -6644,6 +7160,10 @@ def capture_if(
6644
7160
  raise TypeError("on_false must be a Callable or a Graph")
6645
7161
  capture_pause(stream=stream)
6646
7162
 
7163
+ # check the else-body graph
7164
+ if not runtime.core.wp_cuda_graph_check_conditional_body(graph_on_false):
7165
+ raise RuntimeError(runtime.get_error_string())
7166
+
6647
7167
  # restore the main graph to its original state
6648
7168
  main_graph.graph = main_graph_ptr
6649
7169
 
@@ -6651,7 +7171,9 @@ def capture_if(
6651
7171
  capture_resume(main_graph, stream=stream)
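A hedged usage sketch for conditional capture, assuming capture_if is exposed as wp.capture_if, a CUDA device "cuda:0", and a driver recent enough to support conditional graph nodes. A nonzero value in the condition array selects the on_true branch when the graph is launched.

import warp as wp

@wp.kernel
def fill(a: wp.array(dtype=float), value: float):
    a[wp.tid()] = value

wp.init()
device = "cuda:0"
data = wp.zeros(8, dtype=float, device=device)
cond = wp.array([1], dtype=wp.int32, device=device)

wp.capture_begin(device)
try:
    wp.capture_if(
        condition=cond,
        on_true=lambda: wp.launch(fill, dim=data.size, inputs=[data, 1.0], device=device),
        on_false=lambda: wp.launch(fill, dim=data.size, inputs=[data, -1.0], device=device),
    )
finally:
    graph = wp.capture_end(device)

wp.capture_launch(graph)  # the branch taken depends on cond[0] at launch time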
6652
7172
 
6653
7173
 
6654
- def capture_while(condition: warp.array(dtype=int), while_body: Callable | Graph, stream: Stream = None, **kwargs):
7174
+ def capture_while(
7175
+ condition: warp.array(dtype=int), while_body: Callable | Graph, stream: Stream | None = None, **kwargs
7176
+ ):
6655
7177
  """Create a dynamic loop based on a condition.
6656
7178
 
6657
7179
  The condition value is retrieved from the first element of the ``condition`` array.
@@ -6710,17 +7232,17 @@ def capture_while(condition: warp.array(dtype=int), while_body: Callable | Graph
6710
7232
 
6711
7233
  return
6712
7234
 
6713
- graph.has_conditional = True
6714
-
6715
7235
  # ensure conditional graph nodes are supported
6716
7236
  assert_conditional_graph_support()
6717
7237
 
6718
7238
  # insert conditional while-node
6719
7239
  body_graph = ctypes.c_void_p()
6720
7240
  cond_handle = ctypes.c_uint64()
6721
- if not runtime.core.cuda_graph_insert_while(
7241
+ if not runtime.core.wp_cuda_graph_insert_while(
6722
7242
  device.context,
6723
7243
  stream.cuda_stream,
7244
+ device.get_cuda_compile_arch(),
7245
+ device.get_cuda_output_format() == "ptx",
6724
7246
  ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
6725
7247
  ctypes.byref(body_graph),
6726
7248
  ctypes.byref(cond_handle),
@@ -6741,29 +7263,33 @@ def capture_while(condition: warp.array(dtype=int), while_body: Callable | Graph
6741
7263
  if isinstance(while_body, Callable):
6742
7264
  while_body(**kwargs)
6743
7265
  elif isinstance(while_body, Graph):
6744
- if while_body.has_conditional:
6745
- raise RuntimeError("The body graph contains conditional nodes, which are not allowed in child graphs")
6746
-
6747
- if not runtime.core.cuda_graph_insert_child_graph(
7266
+ if not runtime.core.wp_cuda_graph_insert_child_graph(
6748
7267
  device.context,
6749
7268
  stream.cuda_stream,
6750
7269
  while_body.graph,
6751
7270
  ):
6752
7271
  raise RuntimeError(runtime.get_error_string())
6753
7272
  else:
6754
- raise RuntimeError(runtime.get_error_string())
7273
+ raise TypeError("while_body must be a callable or a graph")
6755
7274
 
6756
7275
  # update condition
6757
- if not runtime.core.cuda_graph_set_condition(
7276
+ if not runtime.core.wp_cuda_graph_set_condition(
6758
7277
  device.context,
6759
7278
  stream.cuda_stream,
7279
+ device.get_cuda_compile_arch(),
7280
+ device.get_cuda_output_format() == "ptx",
6760
7281
  ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
6761
7282
  cond_handle,
6762
7283
  ):
6763
7284
  raise RuntimeError(runtime.get_error_string())
6764
7285
 
6765
- # stop capturing child graph and resume capturing parent graph
7286
+ # stop capturing while-body
6766
7287
  capture_pause(stream=stream)
7288
+
7289
+ # check the while-body graph
7290
+ if not runtime.core.wp_cuda_graph_check_conditional_body(body_graph):
7291
+ raise RuntimeError(runtime.get_error_string())
7292
+
6767
7293
  # restore the main graph to its original state
6768
7294
  main_graph.graph = main_graph_ptr
6769
7295
  capture_resume(main_graph, stream=stream)
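A hedged sketch for capture_while under the same assumptions as above (wp.capture_while exposed at the top level, CUDA device "cuda:0", conditional graph node support). The body must update the condition array, otherwise the captured loop never terminates.

import warp as wp

@wp.kernel
def body(counter: wp.array(dtype=wp.int32), x: wp.array(dtype=float)):
    i = wp.tid()
    x[i] = x[i] + 1.0
    if i == 0:
        counter[0] = counter[0] - 1  # loop stops once this reaches zero

wp.init()
device = "cuda:0"
x = wp.zeros(8, dtype=float, device=device)
counter = wp.array([4], dtype=wp.int32, device=device)

wp.capture_begin(device)
try:
    wp.capture_while(
        condition=counter,
        while_body=lambda: wp.launch(body, dim=x.size, inputs=[counter, x], device=device),
    )
finally:
    graph = wp.capture_end(device)

wp.capture_launch(graph)  # x ends up filled with 4.0 for the initial counter of 4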
@@ -6787,14 +7313,14 @@ def capture_launch(graph: Graph, stream: Stream | None = None):
6787
7313
 
6788
7314
  if graph.graph_exec is None:
6789
7315
  g = ctypes.c_void_p()
6790
- result = runtime.core.cuda_graph_create_exec(
7316
+ result = runtime.core.wp_cuda_graph_create_exec(
6791
7317
  graph.device.context, stream.cuda_stream, graph.graph, ctypes.byref(g)
6792
7318
  )
6793
7319
  if not result:
6794
7320
  raise RuntimeError(f"Graph creation error: {runtime.get_error_string()}")
6795
7321
  graph.graph_exec = g
6796
7322
 
6797
- if not runtime.core.cuda_graph_launch(graph.graph_exec, stream.cuda_stream):
7323
+ if not runtime.core.wp_cuda_graph_launch(graph.graph_exec, stream.cuda_stream):
6798
7324
  raise RuntimeError(f"Graph launch error: {runtime.get_error_string()}")
6799
7325
 
6800
7326
 
@@ -6905,24 +7431,24 @@ def copy(
6905
7431
  if dest.device.is_cuda:
6906
7432
  if src.device.is_cuda:
6907
7433
  if src.device == dest.device:
6908
- result = runtime.core.memcpy_d2d(
7434
+ result = runtime.core.wp_memcpy_d2d(
6909
7435
  dest.device.context, dst_ptr, src_ptr, bytes_to_copy, stream.cuda_stream
6910
7436
  )
6911
7437
  else:
6912
- result = runtime.core.memcpy_p2p(
7438
+ result = runtime.core.wp_memcpy_p2p(
6913
7439
  dest.device.context, dst_ptr, src.device.context, src_ptr, bytes_to_copy, stream.cuda_stream
6914
7440
  )
6915
7441
  else:
6916
- result = runtime.core.memcpy_h2d(
7442
+ result = runtime.core.wp_memcpy_h2d(
6917
7443
  dest.device.context, dst_ptr, src_ptr, bytes_to_copy, stream.cuda_stream
6918
7444
  )
6919
7445
  else:
6920
7446
  if src.device.is_cuda:
6921
- result = runtime.core.memcpy_d2h(
7447
+ result = runtime.core.wp_memcpy_d2h(
6922
7448
  src.device.context, dst_ptr, src_ptr, bytes_to_copy, stream.cuda_stream
6923
7449
  )
6924
7450
  else:
6925
- result = runtime.core.memcpy_h2h(dst_ptr, src_ptr, bytes_to_copy)
7451
+ result = runtime.core.wp_memcpy_h2h(dst_ptr, src_ptr, bytes_to_copy)
6926
7452
 
6927
7453
  if not result:
6928
7454
  raise RuntimeError(f"Warp copy error: {runtime.get_error_string()}")
@@ -6957,17 +7483,17 @@ def copy(
6957
7483
  # This work involves a kernel launch, so it must run on the destination device.
6958
7484
  # If the copy stream is different, we need to synchronize it.
6959
7485
  if stream == dest.device.stream:
6960
- result = runtime.core.array_copy_device(
7486
+ result = runtime.core.wp_array_copy_device(
6961
7487
  dest.device.context, dst_ptr, src_ptr, dst_type, src_type, src_elem_size
6962
7488
  )
6963
7489
  else:
6964
7490
  dest.device.stream.wait_stream(stream)
6965
- result = runtime.core.array_copy_device(
7491
+ result = runtime.core.wp_array_copy_device(
6966
7492
  dest.device.context, dst_ptr, src_ptr, dst_type, src_type, src_elem_size
6967
7493
  )
6968
7494
  stream.wait_stream(dest.device.stream)
6969
7495
  else:
6970
- result = runtime.core.array_copy_host(dst_ptr, src_ptr, dst_type, src_type, src_elem_size)
7496
+ result = runtime.core.wp_array_copy_host(dst_ptr, src_ptr, dst_type, src_type, src_elem_size)
6971
7497
 
6972
7498
  if not result:
6973
7499
  raise RuntimeError(f"Warp copy error: {runtime.get_error_string()}")
@@ -7272,7 +7798,6 @@ def export_stubs(file): # pragma: no cover
7272
7798
  """,
7273
7799
  file=file,
7274
7800
  )
7275
-
7276
7801
  print(
7277
7802
  "# Autogenerated file, do not edit, this file provides stubs for builtins autocomplete in VSCode, PyCharm, etc",
7278
7803
  file=file,
@@ -7283,6 +7808,7 @@ def export_stubs(file): # pragma: no cover
7283
7808
  print("from typing import Callable", file=file)
7284
7809
  print("from typing import TypeVar", file=file)
7285
7810
  print("from typing import Generic", file=file)
7811
+ print("from typing import Sequence", file=file)
7286
7812
  print("from typing import overload as over", file=file)
7287
7813
  print(file=file)
7288
7814
 
@@ -7311,7 +7837,7 @@ def export_stubs(file): # pragma: no cover
7311
7837
  print(header, file=file)
7312
7838
  print(file=file)
7313
7839
 
7314
- def add_stub(f):
7840
+ def add_builtin_function_stub(f):
7315
7841
  args = ", ".join(f"{k}: {type_str(v)}" for k, v in f.input_types.items())
7316
7842
 
7317
7843
  return_str = ""
@@ -7331,12 +7857,162 @@ def export_stubs(file): # pragma: no cover
         print('    """', file=file)
         print("    ...\n\n", file=file)

+    def add_vector_type_stub(cls, label):
+        cls_name = cls.__name__
+        scalar_type_name = cls._wp_scalar_type_.__name__
+
+        print(f"class {cls_name}:", file=file)
+
+        print("    @over", file=file)
+        print("    def __init__(self) -> None:", file=file)
+        print(f'        """Construct a zero-initialized {label}."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, other: {cls_name}) -> None:", file=file)
+        print(f'        """Construct a {label} by copy."""', file=file)
+        print("        ...\n\n", file=file)
+
+        args = ", ".join(f"{x}: {scalar_type_name}" for x in "xyzw"[: cls._length_])
+        print("    @over", file=file)
+        print(f"    def __init__(self, {args}) -> None:", file=file)
+        print(f'        """Construct a {label} from its component values."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, args: Sequence[{scalar_type_name}]) -> None:", file=file)
+        print(f'        """Construct a {label} from a sequence of values."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, value: {scalar_type_name}) -> None:", file=file)
+        print(f'        """Construct a {label} filled with a value."""', file=file)
+        print("        ...\n\n", file=file)
+
+    def add_matrix_type_stub(cls, label):
+        cls_name = cls.__name__
+        scalar_type_name = cls._wp_scalar_type_.__name__
+        scalar_short_name = warp.types.scalar_short_name(cls._wp_scalar_type_)
+
+        print(f"class {cls_name}:", file=file)
+
+        print("    @over", file=file)
+        print("    def __init__(self) -> None:", file=file)
+        print(f'        """Construct a zero-initialized {label}."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, other: {cls_name}) -> None:", file=file)
+        print(f'        """Construct a {label} by copy."""', file=file)
+        print("        ...\n\n", file=file)
+
+        args = ", ".join(f"m{i}{j}: {scalar_type_name}" for i in range(cls._shape_[0]) for j in range(cls._shape_[1]))
+        print("    @over", file=file)
+        print(f"    def __init__(self, {args}) -> None:", file=file)
+        print(f'        """Construct a {label} from its component values."""', file=file)
+        print("        ...\n\n", file=file)
+
+        args = ", ".join(f"v{i}: vec{cls._shape_[0]}{scalar_short_name}" for i in range(cls._shape_[0]))
+        print("    @over", file=file)
+        print(f"    def __init__(self, {args}) -> None:", file=file)
+        print(f'        """Construct a {label} from its row vectors."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, args: Sequence[{scalar_type_name}]) -> None:", file=file)
+        print(f'        """Construct a {label} from a sequence of values."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, value: {scalar_type_name}) -> None:", file=file)
+        print(f'        """Construct a {label} filled with a value."""', file=file)
+        print("        ...\n\n", file=file)
+
+    def add_transform_type_stub(cls, label):
+        cls_name = cls.__name__
+        scalar_type_name = cls._wp_scalar_type_.__name__
+        scalar_short_name = warp.types.scalar_short_name(cls._wp_scalar_type_)
+
+        print(f"class {cls_name}:", file=file)
+
+        print("    @over", file=file)
+        print("    def __init__(self) -> None:", file=file)
+        print(f'        """Construct a zero-initialized {label}."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, other: {cls_name}) -> None:", file=file)
+        print(f'        """Construct a {label} by copy."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, p: vec3{scalar_short_name}, q: quat{scalar_short_name}) -> None:", file=file)
+        print(f'        """Construct a {label} from its p and q components."""', file=file)
+        print("        ...\n\n", file=file)
+
+        args = ()
+        args += tuple(f"p{x}: {scalar_type_name}" for x in "xyz")
+        args += tuple(f"q{x}: {scalar_type_name}" for x in "xyzw")
+        args = ", ".join(args)
+        print("    @over", file=file)
+        print(f"    def __init__(self, {args}) -> None:", file=file)
+        print(f'        """Construct a {label} from its component values."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(
+            f"    def __init__(self, p: Sequence[{scalar_type_name}], q: Sequence[{scalar_type_name}]) -> None:",
+            file=file,
+        )
+        print(f'        """Construct a {label} from two sequences of values."""', file=file)
+        print("        ...\n\n", file=file)
+
+        print("    @over", file=file)
+        print(f"    def __init__(self, value: {scalar_type_name}) -> None:", file=file)
+        print(f'        """Construct a {label} filled with a value."""', file=file)
+        print("        ...\n\n", file=file)
+
+    # Vector types.
+    suffixes = ("h", "f", "d", "b", "ub", "s", "us", "i", "ui", "l", "ul")
+    for length in (2, 3, 4):
+        for suffix in suffixes:
+            cls = getattr(warp.types, f"vec{length}{suffix}")
+            add_vector_type_stub(cls, "vector")
+
+        print(f"vec{length} = vec{length}f", file=file)
+
+    # Matrix types.
+    suffixes = ("h", "f", "d")
+    for length in (2, 3, 4):
+        shape = f"{length}{length}"
+        for suffix in suffixes:
+            cls = getattr(warp.types, f"mat{shape}{suffix}")
+            add_matrix_type_stub(cls, "matrix")
+
+        print(f"mat{shape} = mat{shape}f", file=file)
+
+    # Quaternion types.
+    suffixes = ("h", "f", "d")
+    for suffix in suffixes:
+        cls = getattr(warp.types, f"quat{suffix}")
+        add_vector_type_stub(cls, "quaternion")
+
+    print("quat = quatf", file=file)
+
+    # Transformation types.
+    suffixes = ("h", "f", "d")
+    for suffix in suffixes:
+        cls = getattr(warp.types, f"transform{suffix}")
+        add_transform_type_stub(cls, "transformation")
+
+    print("transform = transformf", file=file)
+
     for g in builtin_functions.values():
         if hasattr(g, "overloads"):
             for f in g.overloads:
-                add_stub(f)
+                add_builtin_function_stub(f)
         elif isinstance(g, Function):
-            add_stub(g)
+            add_builtin_function_stub(g)


 def export_builtins(file: io.TextIOBase): # pragma: no cover
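Taken together, the new helpers emit constructor stubs for the fixed-size vector, matrix, quaternion, and transformation types into the generated warp/__init__.pyi, with `over` aliasing `typing.overload` and `Sequence` coming from the import added earlier in this diff. As an illustration, the stub that add_vector_type_stub produces for vec3f would read roughly as follows (reconstructed from the print calls above; only the vec3f class and the vec3 alias are shown):

    class vec3f:
        @over
        def __init__(self) -> None:
            """Construct a zero-initialized vector."""
            ...

        @over
        def __init__(self, other: vec3f) -> None:
            """Construct a vector by copy."""
            ...

        @over
        def __init__(self, x: float32, y: float32, z: float32) -> None:
            """Construct a vector from its component values."""
            ...

        @over
        def __init__(self, args: Sequence[float32]) -> None:
            """Construct a vector from a sequence of values."""
            ...

        @over
        def __init__(self, value: float32) -> None:
            """Construct a vector filled with a value."""
            ...

    vec3 = vec3f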