warp-lang 0.11.0-py3-none-manylinux2014_x86_64.whl → 1.0.0-py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (170)
  1. warp/__init__.py +8 -0
  2. warp/bin/warp-clang.so +0 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +7 -6
  5. warp/build_dll.py +70 -79
  6. warp/builtins.py +10 -6
  7. warp/codegen.py +51 -19
  8. warp/config.py +7 -8
  9. warp/constants.py +3 -0
  10. warp/context.py +948 -245
  11. warp/dlpack.py +198 -113
  12. warp/examples/assets/bunny.usd +0 -0
  13. warp/examples/assets/cartpole.urdf +110 -0
  14. warp/examples/assets/crazyflie.usd +0 -0
  15. warp/examples/assets/cube.usda +42 -0
  16. warp/examples/assets/nv_ant.xml +92 -0
  17. warp/examples/assets/nv_humanoid.xml +183 -0
  18. warp/examples/assets/quadruped.urdf +268 -0
  19. warp/examples/assets/rocks.nvdb +0 -0
  20. warp/examples/assets/rocks.usd +0 -0
  21. warp/examples/assets/sphere.usda +56 -0
  22. warp/examples/assets/torus.usda +105 -0
  23. warp/examples/benchmarks/benchmark_api.py +383 -0
  24. warp/examples/benchmarks/benchmark_cloth.py +279 -0
  25. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
  26. warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
  27. warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
  28. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
  29. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
  30. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
  31. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
  32. warp/examples/benchmarks/benchmark_launches.py +295 -0
  33. warp/examples/core/example_dem.py +221 -0
  34. warp/examples/core/example_fluid.py +267 -0
  35. warp/examples/core/example_graph_capture.py +129 -0
  36. warp/examples/core/example_marching_cubes.py +177 -0
  37. warp/examples/core/example_mesh.py +154 -0
  38. warp/examples/core/example_mesh_intersect.py +193 -0
  39. warp/examples/core/example_nvdb.py +169 -0
  40. warp/examples/core/example_raycast.py +89 -0
  41. warp/examples/core/example_raymarch.py +178 -0
  42. warp/examples/core/example_render_opengl.py +141 -0
  43. warp/examples/core/example_sph.py +389 -0
  44. warp/examples/core/example_torch.py +181 -0
  45. warp/examples/core/example_wave.py +249 -0
  46. warp/examples/fem/bsr_utils.py +380 -0
  47. warp/examples/fem/example_apic_fluid.py +391 -0
  48. warp/examples/fem/example_convection_diffusion.py +168 -0
  49. warp/examples/fem/example_convection_diffusion_dg.py +209 -0
  50. warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
  51. warp/examples/fem/example_deformed_geometry.py +159 -0
  52. warp/examples/fem/example_diffusion.py +173 -0
  53. warp/examples/fem/example_diffusion_3d.py +152 -0
  54. warp/examples/fem/example_diffusion_mgpu.py +214 -0
  55. warp/examples/fem/example_mixed_elasticity.py +222 -0
  56. warp/examples/fem/example_navier_stokes.py +243 -0
  57. warp/examples/fem/example_stokes.py +192 -0
  58. warp/examples/fem/example_stokes_transfer.py +249 -0
  59. warp/examples/fem/mesh_utils.py +109 -0
  60. warp/examples/fem/plot_utils.py +287 -0
  61. warp/examples/optim/example_bounce.py +248 -0
  62. warp/examples/optim/example_cloth_throw.py +210 -0
  63. warp/examples/optim/example_diffray.py +535 -0
  64. warp/examples/optim/example_drone.py +850 -0
  65. warp/examples/optim/example_inverse_kinematics.py +169 -0
  66. warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
  67. warp/examples/optim/example_spring_cage.py +234 -0
  68. warp/examples/optim/example_trajectory.py +201 -0
  69. warp/examples/sim/example_cartpole.py +128 -0
  70. warp/examples/sim/example_cloth.py +184 -0
  71. warp/examples/sim/example_granular.py +113 -0
  72. warp/examples/sim/example_granular_collision_sdf.py +185 -0
  73. warp/examples/sim/example_jacobian_ik.py +213 -0
  74. warp/examples/sim/example_particle_chain.py +106 -0
  75. warp/examples/sim/example_quadruped.py +179 -0
  76. warp/examples/sim/example_rigid_chain.py +191 -0
  77. warp/examples/sim/example_rigid_contact.py +176 -0
  78. warp/examples/sim/example_rigid_force.py +126 -0
  79. warp/examples/sim/example_rigid_gyroscopic.py +97 -0
  80. warp/examples/sim/example_rigid_soft_contact.py +124 -0
  81. warp/examples/sim/example_soft_body.py +178 -0
  82. warp/fabric.py +29 -20
  83. warp/fem/cache.py +0 -1
  84. warp/fem/dirichlet.py +0 -2
  85. warp/fem/integrate.py +0 -1
  86. warp/jax.py +45 -0
  87. warp/jax_experimental.py +339 -0
  88. warp/native/builtin.h +12 -0
  89. warp/native/bvh.cu +18 -18
  90. warp/native/clang/clang.cpp +8 -3
  91. warp/native/cuda_util.cpp +94 -5
  92. warp/native/cuda_util.h +35 -6
  93. warp/native/cutlass_gemm.cpp +1 -1
  94. warp/native/cutlass_gemm.cu +4 -1
  95. warp/native/error.cpp +66 -0
  96. warp/native/error.h +27 -0
  97. warp/native/mesh.cu +2 -2
  98. warp/native/reduce.cu +4 -4
  99. warp/native/runlength_encode.cu +2 -2
  100. warp/native/scan.cu +2 -2
  101. warp/native/sparse.cu +0 -1
  102. warp/native/temp_buffer.h +2 -2
  103. warp/native/warp.cpp +95 -60
  104. warp/native/warp.cu +1053 -218
  105. warp/native/warp.h +49 -32
  106. warp/optim/linear.py +33 -16
  107. warp/render/render_opengl.py +202 -101
  108. warp/render/render_usd.py +82 -40
  109. warp/sim/__init__.py +13 -4
  110. warp/sim/articulation.py +4 -5
  111. warp/sim/collide.py +320 -175
  112. warp/sim/import_mjcf.py +25 -30
  113. warp/sim/import_urdf.py +94 -63
  114. warp/sim/import_usd.py +51 -36
  115. warp/sim/inertia.py +3 -2
  116. warp/sim/integrator.py +233 -0
  117. warp/sim/integrator_euler.py +447 -469
  118. warp/sim/integrator_featherstone.py +1991 -0
  119. warp/sim/integrator_xpbd.py +1420 -640
  120. warp/sim/model.py +765 -487
  121. warp/sim/particles.py +2 -1
  122. warp/sim/render.py +35 -13
  123. warp/sim/utils.py +222 -11
  124. warp/stubs.py +8 -0
  125. warp/tape.py +16 -1
  126. warp/tests/aux_test_grad_customs.py +23 -0
  127. warp/tests/test_array.py +190 -1
  128. warp/tests/test_async.py +656 -0
  129. warp/tests/test_bool.py +50 -0
  130. warp/tests/test_dlpack.py +164 -11
  131. warp/tests/test_examples.py +166 -74
  132. warp/tests/test_fem.py +8 -1
  133. warp/tests/test_generics.py +15 -5
  134. warp/tests/test_grad.py +1 -1
  135. warp/tests/test_grad_customs.py +172 -12
  136. warp/tests/test_jax.py +254 -0
  137. warp/tests/test_large.py +29 -6
  138. warp/tests/test_launch.py +25 -0
  139. warp/tests/test_linear_solvers.py +20 -3
  140. warp/tests/test_matmul.py +61 -16
  141. warp/tests/test_matmul_lite.py +13 -13
  142. warp/tests/test_mempool.py +186 -0
  143. warp/tests/test_multigpu.py +3 -0
  144. warp/tests/test_options.py +16 -2
  145. warp/tests/test_peer.py +137 -0
  146. warp/tests/test_print.py +3 -1
  147. warp/tests/test_quat.py +23 -0
  148. warp/tests/test_sim_kinematics.py +97 -0
  149. warp/tests/test_snippet.py +126 -3
  150. warp/tests/test_streams.py +108 -79
  151. warp/tests/test_torch.py +16 -8
  152. warp/tests/test_utils.py +32 -27
  153. warp/tests/test_verify_fp.py +65 -0
  154. warp/tests/test_volume.py +1 -1
  155. warp/tests/unittest_serial.py +2 -0
  156. warp/tests/unittest_suites.py +12 -0
  157. warp/tests/unittest_utils.py +14 -7
  158. warp/thirdparty/unittest_parallel.py +15 -3
  159. warp/torch.py +10 -8
  160. warp/types.py +363 -246
  161. warp/utils.py +143 -19
  162. warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
  163. warp_lang-1.0.0.dist-info/METADATA +394 -0
  164. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
  165. warp/sim/optimizer.py +0 -138
  166. warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
  167. warp_lang-0.11.0.dist-info/METADATA +0 -238
  168. /warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
  169. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
  170. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -7,6 +7,7 @@
 
  import ast
  import ctypes
+ import functools
  import gc
  import hashlib
  import inspect
@@ -72,6 +73,7 @@ class Function:
  custom_replay_func=None,
  native_snippet=None,
  adj_native_snippet=None,
+ replay_snippet=None,
  skip_forward_codegen=False,
  skip_reverse_codegen=False,
  custom_reverse_num_input_args=-1,
@@ -97,6 +99,7 @@ class Function:
  self.custom_replay_func = custom_replay_func
  self.native_snippet = native_snippet
  self.adj_native_snippet = adj_native_snippet
+ self.replay_snippet = replay_snippet
  self.custom_grad_func = None
  self.require_original_output_arg = require_original_output_arg
 
@@ -641,11 +644,13 @@ def func(f):
  func=f, key=name, namespace="", module=m, value_func=None
  ) # value_type not known yet, will be inferred during Adjoint.build()
 
- # return the top of the list of overloads for this key
- return m.functions[name]
+ # use the top of the list of overloads for this key
+ g = m.functions[name]
+ # copy over the function attributes, including docstring
+ return functools.update_wrapper(g, f)
 
 
- def func_native(snippet, adj_snippet=None):
+ def func_native(snippet, adj_snippet=None, replay_snippet=None):
  """
  Decorator to register native code snippet, @func_native
  """
@@ -655,10 +660,17 @@ def func_native(snippet, adj_snippet=None):
 
  m = get_module(f.__module__)
  func = Function(
- func=f, key=name, namespace="", module=m, native_snippet=snippet, adj_native_snippet=adj_snippet
+ func=f,
+ key=name,
+ namespace="",
+ module=m,
+ native_snippet=snippet,
+ adj_native_snippet=adj_snippet,
+ replay_snippet=replay_snippet,
  ) # cuda snippets do not have a return value_type
-
- return m.functions[name]
+ g = m.functions[name]
+ # copy over the function attributes, including docstring
+ return functools.update_wrapper(g, f)
 
  return snippet_func
 
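The hunks above wrap the returned Function object with functools.update_wrapper, so decorated functions keep their Python metadata. A minimal sketch of the effect in user code (a hypothetical example, assuming the @wp.func decorator behaves as shown in this hunk):

    import warp as wp

    @wp.func
    def square(x: float) -> float:
        """Return x squared."""
        return x * x

    # update_wrapper copies __doc__, __name__, etc. onto the Function object,
    # so introspection and help() keep working on decorated functions.
    print(square.__doc__)   # "Return x squared."
    print(square.__name__)  # "square"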
@@ -702,7 +714,11 @@ def func_grad(forward_fn):
  def match_function(f):
  # check whether the function overload f matches the signature of the provided gradient function
  if not hasattr(f.adj, "return_var"):
- f.adj.build(None)
+ # we have to temporarily build this function to figure out its return type(s);
+ # note that we do not have a ModuleBuilder instance here at this wrapping stage, hence we
+ # have to create a dummy builder
+ builder = ModuleBuilder(Module("dummy", None), f.module.options)
+ f.adj.build(builder)
  expected_args = list(f.input_types.items())
  if f.adj.return_var is not None:
  expected_args += [(f"adj_ret_{var.label}", var.type) for var in f.adj.return_var]
@@ -737,13 +753,13 @@
  continue
  if match_function(f):
  add_custom_grad(f)
- return
+ return grad_fn
  raise RuntimeError(
  f"No function overload found for gradient function {grad_fn.__qualname__} for function {forward_fn.key}"
  )
  else:
  # resolve return variables
- forward_fn.adj.build(None)
+ forward_fn.adj.build(None, forward_fn.module.options)
 
  expected_args = list(forward_fn.input_types.items())
  if forward_fn.adj.return_var is not None:
@@ -759,6 +775,8 @@
  f"\n{', '.join(map(lambda nt: f'{nt[0]}: {nt[1].__name__}', expected_args))}"
  )
 
+ return grad_fn
+
  return wrapper
 
 
@@ -802,6 +820,7 @@ def func_replay(forward_fn):
  skip_adding_overload=True,
  code_transformers=f.adj.transformers,
  )
+ return replay_fn
 
  return wrapper
 
@@ -822,6 +841,7 @@ def kernel(f=None, *, enable_backward=None):
  module=m,
  options=options,
  )
+ k = functools.update_wrapper(k, f)
  return k
 
  if f is None:
@@ -835,7 +855,7 @@
  def struct(c):
  m = get_module(c.__module__)
  s = warp.codegen.Struct(cls=c, key=warp.codegen.make_full_qualified_name(c), module=m)
-
+ s = functools.update_wrapper(s, c)
  return s
 
 
@@ -1151,6 +1171,7 @@ class ModuleBuilder:
  self.structs = {}
  self.options = options
  self.module = module
+ self.deferred_functions = []
 
  # build all functions declared in the module
  for func in module.functions.values():
@@ -1167,6 +1188,10 @@
  for k in kernel.overloads.values():
  self.build_kernel(k)
 
+ # build all functions outside this module which are called from functions or kernels in this module
+ for func in self.deferred_functions:
+ self.build_function(func)
+
  def build_struct_recursive(self, struct: warp.codegen.Struct):
  structs = []
 
@@ -1236,7 +1261,11 @@
  )
  else:
  source += warp.codegen.codegen_snippet(
- func.adj, name=func.key, snippet=func.native_snippet, adj_snippet=func.adj_native_snippet
+ func.adj,
+ name=func.key,
+ snippet=func.native_snippet,
+ adj_snippet=func.adj_native_snippet,
+ replay_snippet=func.replay_snippet,
  )
 
  for kernel in self.module.kernels.values():
@@ -1281,7 +1310,7 @@ class Module:
  self.cuda_build_failed = False
 
  self.options = {
- "max_unroll": 16,
+ "max_unroll": warp.config.max_unroll,
  "enable_backward": warp.config.enable_backward,
  "fast_math": False,
  "cuda_output": None, # supported values: "ptx", "cubin", or None (automatic)
@@ -1439,6 +1468,14 @@
  ch.update(bytes(s, "utf-8"))
  if func.custom_replay_func:
  s = func.custom_replay_func.adj.source
+ if func.replay_snippet:
+ s = func.replay_snippet
+ if func.native_snippet:
+ s = func.native_snippet
+ ch.update(bytes(s, "utf-8"))
+ if func.adj_native_snippet:
+ s = func.adj_native_snippet
+ ch.update(bytes(s, "utf-8"))
 
  # cache func arg types
  for arg, arg_type in func.adj.arg_types.items():
@@ -1447,6 +1484,7 @@
 
  # kernel source
  for kernel in module.kernels.values():
+ ch.update(bytes(kernel.key, "utf-8"))
  ch.update(bytes(kernel.adj.source, "utf-8"))
  # cache kernel arg types
  for arg, arg_type in kernel.adj.arg_types.items():
@@ -1646,7 +1684,7 @@
  if cuda_module is not None:
  self.cuda_modules[device.context] = cuda_module
  else:
- raise Exception("Failed to load CUDA module")
+ raise Exception(f"Failed to load CUDA module '{self.name}'")
 
  except Exception as e:
  self.cuda_build_failed = True
@@ -1714,33 +1752,83 @@
 
  # execution context
 
- # a simple allocator
- # TODO: use a pooled allocator to avoid hitting the system allocator
- class Allocator:
+ class CpuDefaultAllocator:
  def __init__(self, device):
- self.device = device
+ assert device.is_cpu
+ self.deleter = lambda ptr, size: self.free(ptr, size)
 
- def alloc(self, size_in_bytes, pinned=False):
- if self.device.is_cuda:
- if self.device.is_capturing:
- raise RuntimeError(f"Cannot allocate memory on device {self} while graph capture is active")
- return runtime.core.alloc_device(self.device.context, size_in_bytes)
- elif self.device.is_cpu:
- if pinned:
- return runtime.core.alloc_pinned(size_in_bytes)
- else:
- return runtime.core.alloc_host(size_in_bytes)
+ def alloc(self, size_in_bytes):
+ ptr = runtime.core.alloc_host(size_in_bytes)
+ if not ptr:
+ raise RuntimeError(f"Failed to allocate {size_in_bytes} bytes on device '{self.device}'")
+ return ptr
 
- def free(self, ptr, size_in_bytes, pinned=False):
- if self.device.is_cuda:
+ def free(self, ptr, size_in_bytes):
+ runtime.core.free_host(ptr)
+
+
+ class CpuPinnedAllocator:
+ def __init__(self, device):
+ assert device.is_cpu
+ self.deleter = lambda ptr, size: self.free(ptr, size)
+
+ def alloc(self, size_in_bytes):
+ ptr = runtime.core.alloc_pinned(size_in_bytes)
+ if not ptr:
+ raise RuntimeError(f"Failed to allocate {size_in_bytes} bytes on device '{self.device}'")
+ return ptr
+
+ def free(self, ptr, size_in_bytes):
+ runtime.core.free_pinned(ptr)
+
+
+ class CudaDefaultAllocator:
+ def __init__(self, device):
+ assert device.is_cuda
+ self.device = device
+ self.deleter = lambda ptr, size: self.free(ptr, size)
+
+ def alloc(self, size_in_bytes):
+ ptr = runtime.core.alloc_device_default(self.device.context, size_in_bytes)
+ # If the allocation fails, check if graph capture is active to raise an informative error.
+ # We delay the capture check to avoid overhead.
+ if not ptr:
+ reason = ""
  if self.device.is_capturing:
- raise RuntimeError(f"Cannot free memory on device {self} while graph capture is active")
- return runtime.core.free_device(self.device.context, ptr)
- elif self.device.is_cpu:
- if pinned:
- return runtime.core.free_pinned(ptr)
- else:
- return runtime.core.free_host(ptr)
+ if not self.device.is_mempool_supported:
+ reason = (
+ ": "
+ f"Failed to allocate memory during graph capture because memory pools are not supported "
+ f"on device '{self.device}'. Try pre-allocating memory before capture begins."
+ )
+ elif not self.device.is_mempool_enabled:
+ reason = (
+ ": "
+ f"Failed to allocate memory during graph capture because memory pools are not enabled "
+ f"on device '{self.device}'. Try calling wp.set_mempool_enabled('{self.device}', True) before capture begins."
+ )
+ raise RuntimeError(f"Failed to allocate {size_in_bytes} bytes on device '{self.device}'{reason}")
+ return ptr
+
+ def free(self, ptr, size_in_bytes):
+ runtime.core.free_device_default(self.device.context, ptr)
+
+
+ class CudaMempoolAllocator:
+ def __init__(self, device):
+ assert device.is_cuda
+ assert device.is_mempool_supported
+ self.device = device
+ self.deleter = lambda ptr, size: self.free(ptr, size)
+
+ def alloc(self, size_in_bytes):
+ ptr = runtime.core.alloc_device_async(self.device.context, size_in_bytes)
+ if not ptr:
+ raise RuntimeError(f"Failed to allocate {size_in_bytes} bytes on device '{self.device}'")
+ return ptr
+
+ def free(self, ptr, size_in_bytes):
+ runtime.core.free_device_async(self.device.context, ptr)
 
 
  class ContextGuard:
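The hunk above replaces the single Allocator class with dedicated allocators per memory type (default host, pinned host, default CUDA, CUDA mempool), each exposing alloc/free and a deleter callback. From user code the pinned path is typically reached through the pinned flag on host arrays; a small sketch (hypothetical sizes, assuming wp.zeros and wp.copy accept these arguments in this release):

    import warp as wp

    wp.init()

    # page-locked host memory (CpuPinnedAllocator) enables async H2D/D2H copies
    host_buf = wp.zeros(1024, dtype=wp.float32, device="cpu", pinned=True)
    dev_buf = wp.zeros(1024, dtype=wp.float32, device="cuda:0")

    wp.copy(dev_buf, host_buf)  # host-to-device copy from pinned memory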
@@ -1762,8 +1850,12 @@ class ContextGuard:
 
  class Stream:
  def __init__(self, device=None, **kwargs):
+ self.cuda_stream = None
  self.owner = False
 
+ # event used internally for synchronization (cached to avoid creating temporary events)
+ self._cached_event = None
+
  # we can't use get_device() if called during init, but we can use an explicit Device arg
  if runtime is not None:
  device = runtime.get_device(device)
@@ -1775,20 +1867,32 @@
  if not device.is_cuda:
  raise RuntimeError(f"Device {device} is not a CUDA device")
 
+ self.device = device
+
  # we pass cuda_stream through kwargs because cuda_stream=None is actually a valid value (CUDA default stream)
  if "cuda_stream" in kwargs:
  self.cuda_stream = kwargs["cuda_stream"]
+ device.runtime.core.cuda_stream_register(device.context, self.cuda_stream)
  else:
  self.cuda_stream = device.runtime.core.cuda_stream_create(device.context)
  if not self.cuda_stream:
  raise RuntimeError(f"Failed to create stream on device {device}")
  self.owner = True
 
- self.device = device
-
  def __del__(self):
+ if not self.cuda_stream:
+ return
+
  if self.owner:
  runtime.core.cuda_stream_destroy(self.device.context, self.cuda_stream)
+ else:
+ runtime.core.cuda_stream_unregister(self.device.context, self.cuda_stream)
+
+ @property
+ def cached_event(self):
+ if self._cached_event is None:
+ self._cached_event = Event(self.device)
+ return self._cached_event
 
  def record_event(self, event=None):
  if event is None:
@@ -1798,20 +1902,23 @@
  f"Event from device {event.device} cannot be recorded on stream from device {self.device}"
  )
 
- runtime.core.cuda_event_record(self.device.context, event.cuda_event, self.cuda_stream)
+ runtime.core.cuda_event_record(event.cuda_event, self.cuda_stream)
 
  return event
 
  def wait_event(self, event):
- runtime.core.cuda_stream_wait_event(self.device.context, self.cuda_stream, event.cuda_event)
+ runtime.core.cuda_stream_wait_event(self.cuda_stream, event.cuda_event)
 
  def wait_stream(self, other_stream, event=None):
  if event is None:
- event = Event(other_stream.device)
+ event = other_stream.cached_event
 
- runtime.core.cuda_stream_wait_stream(
- self.device.context, self.cuda_stream, other_stream.cuda_stream, event.cuda_event
- )
+ runtime.core.cuda_stream_wait_stream(self.cuda_stream, other_stream.cuda_stream, event.cuda_event)
+
+ # whether a graph capture is currently ongoing on this stream
+ @property
+ def is_capturing(self):
+ return bool(runtime.core.cuda_stream_is_capturing(self.cuda_stream))
 
 
  class Event:
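The Stream methods above now call the slimmer native entry points and reuse a cached event inside wait_stream(). A hedged sketch of inter-stream ordering from user code (assuming wp.Stream and wp.ScopedStream work as in this release):

    import warp as wp

    wp.init()
    stream_a = wp.Stream("cuda:0")
    stream_b = wp.Stream("cuda:0")

    with wp.ScopedStream(stream_a):
        a = wp.zeros(1_000_000, dtype=wp.float32, device="cuda:0")

    # stream_b waits for the work queued on stream_a; internally this records
    # stream_a's cached event and makes stream_b wait on it, with no host sync.
    stream_b.wait_stream(stream_a)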
@@ -1842,8 +1949,10 @@
  self.owner = True
 
  def __del__(self):
- if self.owner:
- runtime.core.cuda_event_destroy(self.device.context, self.cuda_event)
+ if not self.owner:
+ return
+
+ runtime.core.cuda_event_destroy(self.cuda_event)
 
 
  class Device:
@@ -1887,10 +1996,9 @@
  self._stream = None
  self.null_stream = None
 
- # indicates whether CUDA graph capture is active for this device
- self.is_capturing = False
+ # set of streams where capture has started
+ self.captures = set()
 
- self.allocator = Allocator(self)
  self.context_guard = ContextGuard(self)
 
  if self.ordinal == -1:
@@ -1898,8 +2006,9 @@
  self.name = platform.processor() or "CPU"
  self.arch = 0
  self.is_uva = False
- self.is_cubin_supported = False
  self.is_mempool_supported = False
+ self.is_mempool_enabled = False
+ self.is_cubin_supported = False
  self.uuid = None
  self.pci_bus_id = None
 
@@ -1907,14 +2016,21 @@
  self.memset = runtime.core.memset_host
  self.memtile = runtime.core.memtile_host
 
+ self.default_allocator = CpuDefaultAllocator(self)
+ self.pinned_allocator = CpuPinnedAllocator(self)
+
  elif ordinal >= 0 and ordinal < runtime.core.cuda_device_get_count():
  # CUDA device
  self.name = runtime.core.cuda_device_get_name(ordinal).decode()
  self.arch = runtime.core.cuda_device_get_arch(ordinal)
  self.is_uva = runtime.core.cuda_device_is_uva(ordinal)
- # check whether our NVRTC can generate CUBINs for this architecture
- self.is_cubin_supported = self.arch in runtime.nvrtc_supported_archs
- self.is_mempool_supported = runtime.core.cuda_device_is_memory_pool_supported(ordinal)
+ self.is_mempool_supported = runtime.core.cuda_device_is_mempool_supported(ordinal)
+ if warp.config.enable_mempools_at_init:
+ # enable if supported
+ self.is_mempool_enabled = self.is_mempool_supported
+ else:
+ # disable by default
+ self.is_mempool_enabled = False
 
  uuid_buffer = (ctypes.c_char * 16)()
  runtime.core.cuda_device_get_uuid(ordinal, uuid_buffer)
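Memory pools can now be turned on at initialization or per device afterwards. A short sketch (assuming warp.config.enable_mempools_at_init and the wp.is_mempool_supported()/wp.set_mempool_enabled() helpers added later in this file are exported at module level):

    import warp as wp

    # opt in to CUDA memory pool (stream-ordered) allocators before init
    wp.config.enable_mempools_at_init = True
    wp.init()

    # ...or toggle per device after init
    if wp.is_mempool_supported("cuda:0"):
        wp.set_mempool_enabled("cuda:0", True)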
@@ -1927,13 +2043,20 @@
  # This is (mis)named to correspond to the naming of cudaDeviceGetPCIBusId
  self.pci_bus_id = f"{pci_domain_id:08X}:{pci_bus_id:02X}:{pci_device_id:02X}"
 
- # Warn the user of a possible misconfiguration of their system
- if not self.is_mempool_supported:
- warp.utils.warn(
- f"Support for stream ordered memory allocators was not detected on device {ordinal}. "
- "This can prevent the use of graphs and/or result in poor performance. "
- "Is the UVM driver enabled?"
- )
+ self.default_allocator = CudaDefaultAllocator(self)
+ if self.is_mempool_supported:
+ self.mempool_allocator = CudaMempoolAllocator(self)
+ else:
+ self.mempool_allocator = None
+
+ # set current allocator
+ if self.is_mempool_enabled:
+ self.current_allocator = self.mempool_allocator
+ else:
+ self.current_allocator = self.default_allocator
+
+ # check whether our NVRTC can generate CUBINs for this architecture
+ self.is_cubin_supported = self.arch in runtime.nvrtc_supported_archs
 
  # initialize streams unless context acquisition is postponed
  if self._context is not None:
@@ -1948,9 +2071,18 @@
  else:
  raise RuntimeError(f"Invalid device ordinal ({ordinal})'")
 
+ def get_allocator(self, pinned=False):
+ if self.is_cuda:
+ return self.current_allocator
+ else:
+ if pinned:
+ return self.pinned_allocator
+ else:
+ return self.default_allocator
+
  def init_streams(self):
  # create a stream for asynchronous work
- self.stream = Stream(self)
+ self.set_stream(Stream(self))
 
  # CUDA default stream for some synchronous operations
  self.null_stream = Stream(self, cuda_stream=None)
@@ -1965,6 +2097,17 @@
  """A boolean indicating whether or not the device is a CUDA device."""
  return self.ordinal >= 0
 
+ @property
+ def is_capturing(self):
+ if self.is_cuda and self.stream is not None:
+ # There is no CUDA API to check if graph capture was started on a device, so we
+ # can't tell if a capture was started by external code on a different stream.
+ # The best we can do is check whether a graph capture was started by Warp on this
+ # device and whether the current stream is capturing.
+ return self.captures or self.stream.is_capturing
+ else:
+ return False
+
  @property
  def context(self):
  """The context associated with the device."""
@@ -1972,12 +2115,15 @@
  return self._context
  elif self.is_primary:
  # acquire primary context on demand
- self._context = self.runtime.core.cuda_device_primary_context_retain(self.ordinal)
+ prev_context = runtime.core.cuda_context_get_current()
+ self._context = self.runtime.core.cuda_device_get_primary_context(self.ordinal)
  if self._context is None:
+ runtime.core.cuda_context_set_current(prev_context)
  raise RuntimeError(f"Failed to acquire primary context for device {self}")
  self.runtime.context_map[self._context] = self
  # initialize streams
  self.init_streams()
+ runtime.core.cuda_context_set_current(prev_context)
  return self._context
 
  @property
@@ -1998,12 +2144,16 @@
  raise RuntimeError(f"Device {self} is not a CUDA device")
 
  @stream.setter
- def stream(self, s):
+ def stream(self, stream):
+ self.set_stream(stream)
+
+ def set_stream(self, stream, sync=True):
  if self.is_cuda:
- if s.device != self:
- raise RuntimeError(f"Stream from device {s.device} cannot be used on device {self}")
- self._stream = s
- self.runtime.core.cuda_context_set_stream(self.context, s.cuda_stream)
+ if stream.device != self:
+ raise RuntimeError(f"Stream from device {stream.device} cannot be used on device {self}")
+
+ self.runtime.core.cuda_context_set_stream(self.context, stream.cuda_stream, int(sync))
+ self._stream = stream
  else:
  raise RuntimeError(f"Device {self} is not a CUDA device")
 
@@ -2012,6 +2162,26 @@
  """A boolean indicating whether or not the device has a stream associated with it."""
  return self._stream is not None
 
+ @property
+ def total_memory(self):
+ if self.is_cuda:
+ total_mem = ctypes.c_size_t()
+ self.runtime.core.cuda_device_get_memory_info(self.ordinal, None, ctypes.byref(total_mem))
+ return total_mem.value
+ else:
+ # TODO: cpu
+ return 0
+
+ @property
+ def free_memory(self):
+ if self.is_cuda:
+ free_mem = ctypes.c_size_t()
+ self.runtime.core.cuda_device_get_memory_info(self.ordinal, ctypes.byref(free_mem), None)
+ return free_mem.value
+ else:
+ # TODO: cpu
+ return 0
+
  def __str__(self):
  return self.alias
 
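The new total_memory/free_memory properties expose per-device memory information. For example (a sketch, assuming wp.get_device() returns the Device shown above):

    import warp as wp

    wp.init()
    device = wp.get_device("cuda:0")
    print(f"total: {device.total_memory / 1024**3:.1f} GiB")
    print(f"free:  {device.free_memory / 1024**3:.1f} GiB")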
@@ -2036,11 +2206,14 @@
  self.runtime.core.cuda_context_set_current(self.context)
 
  def can_access(self, other):
+ # TODO: this function should be redesigned in terms of (device, resource).
+ # - a device can access any resource on the same device
+ # - a CUDA device can access pinned memory on the host
+ # - a CUDA device can access regular allocations on a peer device if peer access is enabled
+ # - a CUDA device can access mempool allocations on a peer device if mempool access is enabled
  other = self.runtime.get_device(other)
  if self.context == other.context:
  return True
- elif self.context is not None and other.context is not None:
- return bool(self.runtime.core.cuda_context_can_access_peer(self.context, other.context))
  else:
  return False
 
@@ -2056,6 +2229,9 @@
  self.exec = exec
 
  def __del__(self):
+ if not self.exec:
+ return
+
  # use CUDA context guard to avoid side effects during garbage collection
  with self.device.context_guard:
  runtime.core.cuda_graph_destroy(self.device.context, self.exec)
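Graph objects now skip destruction when capture failed and self.exec was never set. Typical capture usage that produces such a Graph looks roughly like this (a sketch; my_kernel stands in for any Warp kernel, and memory is pre-allocated because allocating during capture requires mempool support, as noted in the allocator changes above):

    import warp as wp

    wp.init()
    n = 1_000_000
    a = wp.zeros(n, dtype=wp.float32, device="cuda:0")  # allocate before capture

    wp.capture_begin("cuda:0")
    try:
        # kernels launched here are recorded into the graph, not executed
        wp.launch(my_kernel, dim=n, inputs=[a], device="cuda:0")
    finally:
        graph = wp.capture_end("cuda:0")

    wp.capture_launch(graph)  # replay the recorded work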
@@ -2095,12 +2271,23 @@ class Runtime:
  self.llvm = None
 
  # setup c-types for warp.dll
+ self.core.get_error_string.argtypes = []
+ self.core.get_error_string.restype = ctypes.c_char_p
+ self.core.set_error_output_enabled.argtypes = [ctypes.c_int]
+ self.core.set_error_output_enabled.restype = None
+ self.core.is_error_output_enabled.argtypes = []
+ self.core.is_error_output_enabled.restype = ctypes.c_int
+
  self.core.alloc_host.argtypes = [ctypes.c_size_t]
  self.core.alloc_host.restype = ctypes.c_void_p
  self.core.alloc_pinned.argtypes = [ctypes.c_size_t]
  self.core.alloc_pinned.restype = ctypes.c_void_p
  self.core.alloc_device.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
  self.core.alloc_device.restype = ctypes.c_void_p
+ self.core.alloc_device_default.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
+ self.core.alloc_device_default.restype = ctypes.c_void_p
+ self.core.alloc_device_async.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
+ self.core.alloc_device_async.restype = ctypes.c_void_p
 
  self.core.float_to_half_bits.argtypes = [ctypes.c_float]
  self.core.float_to_half_bits.restype = ctypes.c_uint16
@@ -2113,6 +2300,10 @@
  self.core.free_pinned.restype = None
  self.core.free_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
  self.core.free_device.restype = None
+ self.core.free_device_default.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ self.core.free_device_default.restype = None
+ self.core.free_device_async.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ self.core.free_device_async.restype = None
 
  self.core.memset_host.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
  self.core.memset_host.restype = None
@@ -2131,15 +2322,40 @@
  self.core.memtile_device.restype = None
 
  self.core.memcpy_h2h.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t]
- self.core.memcpy_h2h.restype = None
- self.core.memcpy_h2d.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t]
- self.core.memcpy_h2d.restype = None
- self.core.memcpy_d2h.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t]
- self.core.memcpy_d2h.restype = None
- self.core.memcpy_d2d.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t]
- self.core.memcpy_d2d.restype = None
- self.core.memcpy_peer.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t]
- self.core.memcpy_peer.restype = None
+ self.core.memcpy_h2h.restype = ctypes.c_bool
+ self.core.memcpy_h2d.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_size_t,
+ ctypes.c_void_p,
+ ]
+ self.core.memcpy_h2d.restype = ctypes.c_bool
+ self.core.memcpy_d2h.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_size_t,
+ ctypes.c_void_p,
+ ]
+ self.core.memcpy_d2h.restype = ctypes.c_bool
+ self.core.memcpy_d2d.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_size_t,
+ ctypes.c_void_p,
+ ]
+ self.core.memcpy_d2d.restype = ctypes.c_bool
+ self.core.memcpy_p2p.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_size_t,
+ ctypes.c_void_p,
+ ]
+ self.core.memcpy_p2p.restype = ctypes.c_bool
 
  self.core.array_copy_host.argtypes = [
  ctypes.c_void_p,
@@ -2148,7 +2364,7 @@
  ctypes.c_int,
  ctypes.c_int,
  ]
- self.core.array_copy_host.restype = ctypes.c_size_t
+ self.core.array_copy_host.restype = ctypes.c_bool
  self.core.array_copy_device.argtypes = [
  ctypes.c_void_p,
  ctypes.c_void_p,
@@ -2157,7 +2373,7 @@
  ctypes.c_int,
  ctypes.c_int,
  ]
- self.core.array_copy_device.restype = ctypes.c_size_t
+ self.core.array_copy_device.restype = ctypes.c_bool
 
  self.core.array_fill_host.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_int]
  self.core.array_fill_host.restype = None
@@ -2311,6 +2527,7 @@
  self.core.hash_grid_reserve_device.argtypes = [ctypes.c_uint64, ctypes.c_int]
 
  self.core.cutlass_gemm.argtypes = [
+ ctypes.c_void_p,
  ctypes.c_int,
  ctypes.c_int,
  ctypes.c_int,
@@ -2327,7 +2544,7 @@
  ctypes.c_bool,
  ctypes.c_int,
  ]
- self.core.cutlass_gemm.restypes = ctypes.c_bool
+ self.core.cutlass_gemm.restype = ctypes.c_bool
 
  self.core.volume_create_host.argtypes = [ctypes.c_void_p, ctypes.c_uint64]
  self.core.volume_create_host.restype = ctypes.c_uint64
@@ -2463,14 +2680,22 @@
 
  self.core.cuda_device_get_count.argtypes = None
  self.core.cuda_device_get_count.restype = ctypes.c_int
- self.core.cuda_device_primary_context_retain.argtypes = [ctypes.c_int]
- self.core.cuda_device_primary_context_retain.restype = ctypes.c_void_p
+ self.core.cuda_device_get_primary_context.argtypes = [ctypes.c_int]
+ self.core.cuda_device_get_primary_context.restype = ctypes.c_void_p
  self.core.cuda_device_get_name.argtypes = [ctypes.c_int]
  self.core.cuda_device_get_name.restype = ctypes.c_char_p
  self.core.cuda_device_get_arch.argtypes = [ctypes.c_int]
  self.core.cuda_device_get_arch.restype = ctypes.c_int
  self.core.cuda_device_is_uva.argtypes = [ctypes.c_int]
  self.core.cuda_device_is_uva.restype = ctypes.c_int
+ self.core.cuda_device_is_mempool_supported.argtypes = [ctypes.c_int]
+ self.core.cuda_device_is_mempool_supported.restype = ctypes.c_int
+ self.core.cuda_device_set_mempool_release_threshold.argtypes = [ctypes.c_int, ctypes.c_uint64]
+ self.core.cuda_device_set_mempool_release_threshold.restype = ctypes.c_int
+ self.core.cuda_device_get_mempool_release_threshold.argtypes = [ctypes.c_int]
+ self.core.cuda_device_get_mempool_release_threshold.restype = ctypes.c_uint64
+ self.core.cuda_device_get_memory_info.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p]
+ self.core.cuda_device_get_memory_info.restype = None
  self.core.cuda_device_get_uuid.argtypes = [ctypes.c_int, ctypes.c_char * 16]
  self.core.cuda_device_get_uuid.restype = None
  self.core.cuda_device_get_pci_domain_id.argtypes = [ctypes.c_int]
@@ -2503,42 +2728,53 @@
  self.core.cuda_context_is_primary.restype = ctypes.c_int
  self.core.cuda_context_get_stream.argtypes = [ctypes.c_void_p]
  self.core.cuda_context_get_stream.restype = ctypes.c_void_p
- self.core.cuda_context_set_stream.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ self.core.cuda_context_set_stream.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
  self.core.cuda_context_set_stream.restype = None
- self.core.cuda_context_can_access_peer.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
- self.core.cuda_context_can_access_peer.restype = ctypes.c_int
+
+ # peer access
+ self.core.cuda_is_peer_access_supported.argtypes = [ctypes.c_int, ctypes.c_int]
+ self.core.cuda_is_peer_access_supported.restype = ctypes.c_int
+ self.core.cuda_is_peer_access_enabled.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ self.core.cuda_is_peer_access_enabled.restype = ctypes.c_int
+ self.core.cuda_set_peer_access_enabled.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
+ self.core.cuda_set_peer_access_enabled.restype = ctypes.c_int
+ self.core.cuda_is_mempool_access_enabled.argtypes = [ctypes.c_int, ctypes.c_int]
+ self.core.cuda_is_mempool_access_enabled.restype = ctypes.c_int
+ self.core.cuda_set_mempool_access_enabled.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
+ self.core.cuda_set_mempool_access_enabled.restype = ctypes.c_int
 
  self.core.cuda_stream_create.argtypes = [ctypes.c_void_p]
  self.core.cuda_stream_create.restype = ctypes.c_void_p
  self.core.cuda_stream_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
  self.core.cuda_stream_destroy.restype = None
- self.core.cuda_stream_synchronize.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ self.core.cuda_stream_register.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ self.core.cuda_stream_register.restype = None
+ self.core.cuda_stream_unregister.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ self.core.cuda_stream_unregister.restype = None
+ self.core.cuda_stream_synchronize.argtypes = [ctypes.c_void_p]
  self.core.cuda_stream_synchronize.restype = None
- self.core.cuda_stream_wait_event.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
+ self.core.cuda_stream_wait_event.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
  self.core.cuda_stream_wait_event.restype = None
- self.core.cuda_stream_wait_stream.argtypes = [
- ctypes.c_void_p,
- ctypes.c_void_p,
- ctypes.c_void_p,
- ctypes.c_void_p,
- ]
+ self.core.cuda_stream_wait_stream.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
  self.core.cuda_stream_wait_stream.restype = None
+ self.core.cuda_stream_is_capturing.argtypes = [ctypes.c_void_p]
+ self.core.cuda_stream_is_capturing.restype = ctypes.c_int
 
  self.core.cuda_event_create.argtypes = [ctypes.c_void_p, ctypes.c_uint]
  self.core.cuda_event_create.restype = ctypes.c_void_p
- self.core.cuda_event_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ self.core.cuda_event_destroy.argtypes = [ctypes.c_void_p]
  self.core.cuda_event_destroy.restype = None
- self.core.cuda_event_record.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
+ self.core.cuda_event_record.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
  self.core.cuda_event_record.restype = None
 
- self.core.cuda_graph_begin_capture.argtypes = [ctypes.c_void_p]
- self.core.cuda_graph_begin_capture.restype = None
- self.core.cuda_graph_end_capture.argtypes = [ctypes.c_void_p]
- self.core.cuda_graph_end_capture.restype = ctypes.c_void_p
+ self.core.cuda_graph_begin_capture.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
+ self.core.cuda_graph_begin_capture.restype = ctypes.c_bool
+ self.core.cuda_graph_end_capture.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p)]
+ self.core.cuda_graph_end_capture.restype = ctypes.c_bool
  self.core.cuda_graph_launch.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
- self.core.cuda_graph_launch.restype = None
+ self.core.cuda_graph_launch.restype = ctypes.c_bool
  self.core.cuda_graph_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
- self.core.cuda_graph_destroy.restype = None
+ self.core.cuda_graph_destroy.restype = ctypes.c_bool
 
  self.core.cuda_compile_program.argtypes = [
  ctypes.c_char_p,
@@ -2567,6 +2803,7 @@
  ctypes.c_size_t,
  ctypes.c_int,
  ctypes.POINTER(ctypes.c_void_p),
+ ctypes.c_void_p,
  ]
  self.core.cuda_launch_kernel.restype = ctypes.c_size_t
 
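The extra ctypes.c_void_p at the end of cuda_launch_kernel's argument list is the stream handle, so kernels can be launched on an explicit stream. From the public API that looks roughly like this (a sketch; my_kernel and data are placeholders, and wp.launch is assumed to accept stream= in this release):

    import warp as wp

    wp.init()
    s = wp.Stream("cuda:0")

    # launch on an explicit stream instead of the device's current stream
    wp.launch(my_kernel, dim=1024, inputs=[data], stream=s)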
@@ -2620,7 +2857,11 @@ class Runtime:
2620
2857
  else:
2621
2858
  self.nvrtc_supported_archs = []
2622
2859
 
2623
- # register CUDA devices
2860
+ # this is so we can give non-primary contexts a reasonable alias
2861
+ # associated with the physical device (e.g., "cuda:0.0", "cuda:0.1")
2862
+ self.cuda_custom_context_count = [0] * cuda_device_count
2863
+
2864
+ # register primary CUDA devices
2624
2865
  self.cuda_devices = []
2625
2866
  self.cuda_primary_devices = []
2626
2867
  for i in range(cuda_device_count):
@@ -2632,8 +2873,12 @@ class Runtime:
2632
2873
 
2633
2874
  # set default device
2634
2875
  if cuda_device_count > 0:
2635
- if self.core.cuda_context_get_current() is not None:
2876
+ # stick with the current cuda context, if one is bound
2877
+ initial_context = self.core.cuda_context_get_current()
2878
+ if initial_context is not None:
2636
2879
  self.set_default_device("cuda")
2880
+ # if this is a non-primary context that was just registered, update the device count
2881
+ cuda_device_count = len(self.cuda_devices)
2637
2882
  else:
2638
2883
  self.set_default_device("cuda:0")
2639
2884
  else:
@@ -2643,43 +2888,130 @@ class Runtime:
2643
2888
  # initialize kernel cache
2644
2889
  warp.build.init_kernel_cache(warp.config.kernel_cache_dir)
2645
2890
 
2891
+ devices_without_uva = []
2892
+ devices_without_mempool = []
2893
+ for cuda_device in self.cuda_devices:
2894
+ if cuda_device.is_primary:
2895
+ if not cuda_device.is_uva:
2896
+ devices_without_uva.append(cuda_device)
2897
+ if not cuda_device.is_mempool_supported:
2898
+ devices_without_mempool.append(cuda_device)
2899
+
2646
2900
  # print device and version information
2647
2901
  if not warp.config.quiet:
2648
- print(f"Warp {warp.config.version} initialized:")
2902
+ greeting = []
2903
+
2904
+ greeting.append(f"Warp {warp.config.version} initialized:")
2649
2905
  if cuda_device_count > 0:
2650
2906
  toolkit_version = (self.toolkit_version // 1000, (self.toolkit_version % 1000) // 10)
2651
2907
  driver_version = (self.driver_version // 1000, (self.driver_version % 1000) // 10)
2652
- print(
2653
- f" CUDA Toolkit: {toolkit_version[0]}.{toolkit_version[1]}, Driver: {driver_version[0]}.{driver_version[1]}"
2908
+ greeting.append(
2909
+ f" CUDA Toolkit {toolkit_version[0]}.{toolkit_version[1]}, Driver {driver_version[0]}.{driver_version[1]}"
2654
2910
  )
2655
2911
  else:
2656
2912
  if self.core.is_cuda_enabled():
2657
2913
  # Warp was compiled with CUDA support, but no devices are available
2658
- print(" CUDA devices not available")
2914
+ greeting.append(" CUDA devices not available")
2659
2915
  else:
2660
2916
  # Warp was compiled without CUDA support
2661
- print(" CUDA support not enabled in this build")
2662
- print(" Devices:")
2663
- print(f' "{self.cpu_device.alias}" | {self.cpu_device.name}')
2917
+ greeting.append(" CUDA support not enabled in this build")
2918
+ greeting.append(" Devices:")
2919
+ alias_str = f'"{self.cpu_device.alias}"'
2920
+ name_str = f'"{self.cpu_device.name}"'
2921
+ greeting.append(f" {alias_str:10s} : {name_str}")
2664
2922
  for cuda_device in self.cuda_devices:
2665
- print(f' "{cuda_device.alias}" | {cuda_device.name} (sm_{cuda_device.arch})')
2666
- print(f" Kernel cache: {warp.config.kernel_cache_dir}")
2667
-
2668
- # CUDA compatibility check
2669
- if cuda_device_count > 0 and not self.core.is_cuda_compatibility_enabled():
2670
- if self.driver_version < self.toolkit_version:
2671
- print("******************************************************************")
2672
- print("* WARNING: *")
2673
- print("* Warp was compiled without CUDA compatibility support *")
2674
- print("* (quick build). The CUDA Toolkit version used to build *")
2675
- print("* Warp is not fully supported by the current driver. *")
2676
- print("* Some CUDA functionality may not work correctly! *")
2677
- print("* Update the driver or rebuild Warp without the --quick flag. *")
2678
- print("******************************************************************")
2923
+ alias_str = f'"{cuda_device.alias}"'
2924
+ if cuda_device.is_primary:
2925
+ name_str = f'"{cuda_device.name}"'
2926
+ arch_str = f"sm_{cuda_device.arch}"
2927
+ mem_str = f"{cuda_device.total_memory / 1024 / 1024 / 1024:.0f} GiB"
2928
+ if cuda_device.is_mempool_supported:
2929
+ if cuda_device.is_mempool_enabled:
2930
+ mempool_str = "mempool enabled"
2931
+ else:
2932
+ mempool_str = "mempool supported"
2933
+ else:
2934
+ mempool_str = "mempool not supported"
2935
+ greeting.append(f" {alias_str:10s} : {name_str} ({mem_str}, {arch_str}, {mempool_str})")
2936
+ else:
2937
+ primary_alias_str = f'"{self.cuda_primary_devices[cuda_device.ordinal].alias}"'
2938
+ greeting.append(f" {alias_str:10s} : Non-primary context on device {primary_alias_str}")
2939
+ if cuda_device_count > 1:
2940
+ # check peer access support
2941
+ access_matrix = []
2942
+ all_accessible = True
2943
+ none_accessible = True
2944
+ for i in range(cuda_device_count):
2945
+ target_device = self.cuda_devices[i]
2946
+ access_vector = []
2947
+ for j in range(cuda_device_count):
2948
+ if i == j:
2949
+ access_vector.append(1)
2950
+ else:
2951
+ peer_device = self.cuda_devices[j]
2952
+ can_access = self.core.cuda_is_peer_access_supported(
2953
+ target_device.ordinal, peer_device.ordinal
2954
+ )
2955
+ access_vector.append(can_access)
2956
+ all_accessible = all_accessible and can_access
2957
+ none_accessible = none_accessible and not can_access
2958
+ access_matrix.append(access_vector)
2959
+ greeting.append(" CUDA peer access:")
2960
+ if all_accessible:
2961
+ greeting.append(" Supported fully (all-directional)")
2962
+ elif none_accessible:
2963
+ greeting.append(" Not supported")
2964
+ else:
2965
+ greeting.append(" Supported partially (see access matrix)")
2966
+ # print access matrix
2967
+ for i in range(cuda_device_count):
2968
+ alias_str = f'"{self.cuda_devices[i].alias}"'
2969
+ greeting.append(f" {alias_str:10s} : {access_matrix[i]}")
2970
+ greeting.append(" Kernel cache:")
2971
+ greeting.append(f" {warp.config.kernel_cache_dir}")
2972
+
2973
+ print("\n".join(greeting))
2974
+
2975
+ if cuda_device_count > 0:
2976
+ # warn about possible misconfiguration of the system
2977
+ if devices_without_uva:
2978
+ # This should not happen on any system officially supported by Warp. UVA is not available
2979
+ # on 32-bit Windows, which we don't support. Nonetheless, we should check and report a
2980
+ # warning out of abundance of caution. It may help with debugging a broken VM setup etc.
2981
+ warp.utils.warn(
2982
+ f"Support for Unified Virtual Addressing (UVA) was not detected on devices {devices_without_uva}."
2983
+ )
2984
+ if devices_without_mempool:
2985
+ warp.utils.warn(
2986
+ f"Support for CUDA memory pools was not detected on devices {devices_without_mempool}. "
2987
+ "This prevents memory allocations in CUDA graphs and may result in poor performance. "
2988
+ "Is the UVM driver enabled?"
2989
+ )
2990
+
2991
+ # CUDA compatibility check. This should only affect developer builds done with the
2992
+ # --quick flag. The consequences of running with an older driver can be obscure and severe,
2993
+ # so make sure we print a very visible warning.
2994
+ if self.driver_version < self.toolkit_version and not self.core.is_cuda_compatibility_enabled():
2995
+ print(
2996
+ "******************************************************************\n"
2997
+ "* WARNING: *\n"
2998
+ "* Warp was compiled without CUDA compatibility support *\n"
2999
+ "* (quick build). The CUDA Toolkit version used to build *\n"
3000
+ "* Warp is not fully supported by the current driver. *\n"
3001
+ "* Some CUDA functionality may not work correctly! *\n"
3002
+ "* Update the driver or rebuild Warp without the --quick flag. *\n"
3003
+ "******************************************************************\n"
3004
+ )
3005
+
3006
+ # ensure initialization did not change the initial context (e.g. querying available memory)
3007
+ self.core.cuda_context_set_current(initial_context)
2679
3008
 
2680
3009
  # global tape
2681
3010
  self.tape = None
2682
3011
 
3012
+ def get_error_string(self):
3013
+ return self.core.get_error_string().decode("utf-8")
3014
+
2683
3015
  def load_dll(self, dll_path):
2684
3016
  try:
2685
3017
  if sys.version_info[0] > 3 or sys.version_info[0] == 3 and sys.version_info[1] >= 8:
@@ -2691,7 +3023,7 @@ class Runtime:
2691
3023
  raise RuntimeError(
2692
3024
  f"Failed to load the shared library '{dll_path}'.\n"
2693
3025
  "The execution environment's libstdc++ runtime is older than the version the Warp library was built for.\n"
2694
- "See https://nvidia.github.io/warp/_build/html/installation.html#conda-environments for details."
3026
+ "See https://nvidia.github.io/warp/installation.html#conda-environments for details."
2695
3027
  ) from e
2696
3028
  else:
2697
3029
  raise RuntimeError(f"Failed to load the shared library '{dll_path}'") from e
@@ -2728,14 +3060,20 @@ class Runtime:
2728
3060
  return device
2729
3061
  else:
2730
3062
  # this is an unseen non-primary context, register it as a new device with a unique alias
2731
- alias = f"cuda!{current_context:x}"
3063
+ ordinal = self.core.cuda_context_get_device_ordinal(current_context)
3064
+ alias = f"cuda:{ordinal}.{self.cuda_custom_context_count[ordinal]}"
3065
+ self.cuda_custom_context_count[ordinal] += 1
2732
3066
  return self.map_cuda_device(alias, current_context)
2733
3067
  elif self.default_device.is_cuda:
2734
3068
  return self.default_device
2735
3069
  elif self.cuda_devices:
2736
3070
  return self.cuda_devices[0]
2737
3071
  else:
2738
- raise RuntimeError("CUDA is not available")
3072
+ # CUDA is not available
3073
+ if not self.core.is_cuda_enabled():
3074
+ raise RuntimeError('"cuda" device requested but this build of Warp does not support CUDA')
3075
+ else:
3076
+ raise RuntimeError('"cuda" device requested but CUDA is not supported by the hardware or driver')
2739
3077
 
2740
3078
  def rename_device(self, device, alias):
2741
3079
  del self.device_map[device.alias]
@@ -2936,20 +3274,253 @@ def unmap_cuda_device(alias: str):
2936
3274
  runtime.unmap_cuda_device(alias)
2937
3275
 
2938
3276
 
3277
+ def is_mempool_supported(device: Devicelike):
3278
+ """Check if CUDA memory pool allocators are available on the device."""
3279
+
3280
+ assert_initialized()
3281
+
3282
+ device = runtime.get_device(device)
3283
+
3284
+ return device.is_mempool_supported
3285
+
3286
+
3287
+ def is_mempool_enabled(device: Devicelike):
3288
+ """Check if CUDA memory pool allocators are enabled on the device."""
3289
+
3290
+ assert_initialized()
3291
+
3292
+ device = runtime.get_device(device)
3293
+
3294
+ return device.is_mempool_enabled
3295
+
3296
+
3297
+ def set_mempool_enabled(device: Devicelike, enable: bool):
3298
+ """Enable or disable CUDA memory pool allocators on the device.
3299
+
3300
+ Pooled allocators are typically faster and allow allocating memory during graph capture.
3301
+
3302
+ They should generally be enabled, but there is a rare caveat. Copying data between different GPUs
3303
+ may fail during graph capture if the memory was allocated using pooled allocators and memory pool
3304
+ access is not enabled between the two GPUs. This is an internal CUDA limitation that is not related
3305
+ to Warp. The preferred solution is to enable memory pool access using `warp.set_mempool_access_enabled()`.
3306
+ If peer access is not supported, then the default CUDA allocators must be used to pre-allocate the memory
3307
+ prior to graph capture.
3308
+ """
3309
+
3310
+ assert_initialized()
3311
+
3312
+ device = runtime.get_device(device)
3313
+
3314
+ if device.is_cuda:
3315
+ if enable:
3316
+ if not device.is_mempool_supported:
3317
+ raise RuntimeError(f"Device {device} does not support memory pools")
3318
+ device.current_allocator = device.mempool_allocator
3319
+ device.is_mempool_enabled = True
3320
+ else:
3321
+ device.current_allocator = device.default_allocator
3322
+ device.is_mempool_enabled = False
3323
+ else:
3324
+ if enable:
3325
+ raise ValueError("Memory pools are only supported on CUDA devices")
3326
+
3327
+
3328
+ def set_mempool_release_threshold(device: Devicelike, threshold: int):
3329
+ """Set the CUDA memory pool release threshold on the device.
3330
+
3331
+ This is the amount of reserved memory to hold onto before trying to release memory back to the OS.
3332
+ When more than this amount of bytes is held by the memory pool, the allocator will try to release
3333
+ memory back to the OS on the next call to stream, event, or device synchronize.
3334
+ """
3335
+
3336
+ assert_initialized()
3337
+
3338
+ device = runtime.get_device(device)
3339
+
3340
+ if not device.is_cuda:
3341
+ raise ValueError("Memory pools are only supported on CUDA devices")
3342
+
3343
+ if not device.is_mempool_supported:
3344
+ raise RuntimeError(f"Device {device} does not support memory pools")
3345
+
3346
+ if not runtime.core.cuda_device_set_mempool_release_threshold(device.ordinal, threshold):
3347
+ raise RuntimeError(f"Failed to set memory pool release threshold for device {device}")
3348
+
3349
+
3350
+ def get_mempool_release_threshold(device: Devicelike):
3351
+ """Get the CUDA memory pool release threshold on the device."""
3352
+
3353
+ assert_initialized()
3354
+
3355
+ device = runtime.get_device(device)
3356
+
3357
+ if not device.is_cuda:
3358
+ raise ValueError("Memory pools are only supported on CUDA devices")
3359
+
3360
+ if not device.is_mempool_supported:
3361
+ raise RuntimeError(f"Device {device} does not support memory pools")
3362
+
3363
+ return runtime.core.cuda_device_get_mempool_release_threshold(device.ordinal)
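
A short sketch of setting and reading the release threshold, assuming a device that supports memory pools; the 1 GiB value is only illustrative:

    import warp as wp

    wp.init()

    device = "cuda:0"

    if wp.is_mempool_supported(device):
        # Hold up to 1 GiB in the pool before releasing memory back to the OS
        # at the next stream, event, or device synchronization.
        wp.set_mempool_release_threshold(device, 1024 * 1024 * 1024)
        print(wp.get_mempool_release_threshold(device))
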
3364
+
3365
+
3366
+ def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike):
3367
+ """Check if `peer_device` can directly access the memory of `target_device` on this system.
3368
+
3369
+ This applies to memory allocated using default CUDA allocators. For memory allocated using
3370
+ CUDA pooled allocators, use `is_mempool_access_supported()`.
3371
+
3372
+ Returns:
3373
+ A Boolean value indicating if this peer access is supported by the system.
3374
+ """
3375
+
3376
+ assert_initialized()
3377
+
3378
+ target_device = runtime.get_device(target_device)
3379
+ peer_device = runtime.get_device(peer_device)
3380
+
3381
+ if not target_device.is_cuda or not peer_device.is_cuda:
3382
+ return False
3383
+
3384
+ return bool(runtime.core.cuda_is_peer_access_supported(target_device.ordinal, peer_device.ordinal))
3385
+
3386
+
3387
+ def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike):
3388
+ """Check if `peer_device` can currently access the memory of `target_device`.
3389
+
3390
+ This applies to memory allocated using default CUDA allocators. For memory allocated using
3391
+ CUDA pooled allocators, use `is_mempool_access_enabled()`.
3392
+
3393
+ Returns:
3394
+ A Boolean value indicating if this peer access is currently enabled.
3395
+ """
3396
+
3397
+ assert_initialized()
3398
+
3399
+ target_device = runtime.get_device(target_device)
3400
+ peer_device = runtime.get_device(peer_device)
3401
+
3402
+ if not target_device.is_cuda or not peer_device.is_cuda:
3403
+ return False
3404
+
3405
+ return bool(runtime.core.cuda_is_peer_access_enabled(target_device.context, peer_device.context))
3406
+
3407
+
3408
+ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool):
3409
+ """Enable or disable direct access from `peer_device` to the memory of `target_device`.
3410
+
3411
+ Enabling peer access can improve the speed of peer-to-peer memory transfers, but can have
3412
+ a negative impact on memory consumption and allocation performance.
3413
+
3414
+ This applies to memory allocated using default CUDA allocators. For memory allocated using
3415
+ CUDA pooled allocators, use `set_mempool_access_enabled()`.
3416
+ """
3417
+
3418
+ assert_initialized()
3419
+
3420
+ target_device = runtime.get_device(target_device)
3421
+ peer_device = runtime.get_device(peer_device)
3422
+
3423
+ if not target_device.is_cuda or not peer_device.is_cuda:
3424
+ if enable:
3425
+ raise ValueError("Peer access is only supported between CUDA devices")
3426
+ else:
3427
+ return
3428
+
3429
+ if not is_peer_access_supported(target_device, peer_device):
3430
+ if enable:
3431
+ raise RuntimeError(f"Device {peer_device} cannot access device {target_device}")
3432
+ else:
3433
+ return
3434
+
3435
+ if not runtime.core.cuda_set_peer_access_enabled(target_device.context, peer_device.context, int(enable)):
3436
+ action = "enable" if enable else "disable"
3437
+ raise RuntimeError(f"Failed to {action} peer access from device {peer_device} to device {target_device}")
3438
+
3439
+
3440
+ def is_mempool_access_supported(target_device: Devicelike, peer_device: Devicelike):
3441
+ """Check if `peer_device` can directly access the memory pool of `target_device`.
3442
+
3443
+ If mempool access is possible, it can be managed using `set_mempool_access_enabled()` and `is_mempool_access_enabled()`.
3444
+
3445
+ Returns:
3446
+ A Boolean value indicating if this memory pool access is supported by the system.
3447
+ """
3448
+
3449
+ assert_initialized()
3450
+
3451
+ # resolve device arguments so string aliases are accepted, matching the other helpers
+ target_device = runtime.get_device(target_device)
+ peer_device = runtime.get_device(peer_device)
+
+ return target_device.is_mempool_supported and is_peer_access_supported(target_device, peer_device)
3452
+
3453
+
3454
+ def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike):
3455
+ """Check if `peer_device` can currently access the memory pool of `target_device`.
3456
+
3457
+ This applies to memory allocated using CUDA pooled allocators. For memory allocated using
3458
+ default CUDA allocators, use `is_peer_access_enabled()`.
3459
+
3460
+ Returns:
3461
+ A Boolean value indicating if this memory pool access is currently enabled.
3462
+ """
3463
+
3464
+ assert_initialized()
3465
+
3466
+ target_device = runtime.get_device(target_device)
3467
+ peer_device = runtime.get_device(peer_device)
3468
+
3469
+ if not peer_device.is_cuda or not target_device.is_cuda or not target_device.is_mempool_supported:
3470
+ return False
3471
+
3472
+ return bool(runtime.core.cuda_is_mempool_access_enabled(target_device.ordinal, peer_device.ordinal))
3473
+
3474
+
3475
+ def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool):
3476
+ """Enable or disable access from `peer_device` to the memory pool of `target_device`.
3477
+
3478
+ This applies to memory allocated using CUDA pooled allocators. For memory allocated using
3479
+ default CUDA allocators, use `set_peer_access_enabled()`.
3480
+ """
3481
+
3482
+ assert_initialized()
3483
+
3484
+ target_device = runtime.get_device(target_device)
3485
+ peer_device = runtime.get_device(peer_device)
3486
+
3487
+ if not target_device.is_cuda or not peer_device.is_cuda:
3488
+ if enable:
3489
+ raise ValueError("Memory pool access is only supported between CUDA devices")
3490
+ else:
3491
+ return
3492
+
3493
+ if not target_device.is_mempool_supported:
3494
+ if enable:
3495
+ raise RuntimeError(f"Device {target_device} does not support memory pools")
3496
+ else:
3497
+ return
3498
+
3499
+ if not is_peer_access_supported(target_device, peer_device):
3500
+ if enable:
3501
+ raise RuntimeError(f"Device {peer_device} cannot access device {target_device}")
3502
+ else:
3503
+ return
3504
+
3505
+ if not runtime.core.cuda_set_mempool_access_enabled(target_device.ordinal, peer_device.ordinal, int(enable)):
3506
+ action = "enable" if enable else "disable"
3507
+ raise RuntimeError(f"Failed to {action} memory pool access from device {peer_device} to device {target_device}")
3508
+
3509
+
2939
3510
  def get_stream(device: Devicelike = None) -> Stream:
2940
3511
  """Return the stream currently used by the given device"""
2941
3512
 
2942
3513
  return get_device(device).stream
2943
3514
 
2944
3515
 
2945
- def set_stream(stream, device: Devicelike = None):
3516
+ def set_stream(stream, device: Devicelike = None, sync: bool = False):
2946
3517
  """Set the stream to be used by the given device.
2947
3518
 
2948
3519
  If this is an external stream, the caller is responsible for guaranteeing the lifetime of the stream.
2949
3520
  Consider using wp.ScopedStream instead.
2950
3521
  """
2951
3522
 
2952
- get_device(device).stream = stream
3523
+ get_device(device).set_stream(stream, sync=sync)
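
A small sketch of swapping the current stream on a device with the new `sync` option; the `wp.Stream(device)` constructor is assumed to be available, as used elsewhere in this module:

    import warp as wp

    wp.init()

    device = wp.get_device("cuda:0")

    # Make a new stream current; sync=True orders it after work already
    # submitted to the previous stream.
    new_stream = wp.Stream(device)
    wp.set_stream(new_stream, device=device, sync=True)

    print(wp.get_stream(device) is new_stream)
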
2953
3524
 
2954
3525
 
2955
3526
  def record_event(event: Event = None):
@@ -3017,7 +3588,12 @@ class RegisteredGLBuffer:
3017
3588
  self.resource = runtime.core.cuda_graphics_register_gl_buffer(self.context, gl_buffer_id, flags)
3018
3589
 
3019
3590
  def __del__(self):
3020
- runtime.core.cuda_graphics_unregister_resource(self.context, self.resource)
3591
+ if not self.resource:
3592
+ return
3593
+
3594
+ # use CUDA context guard to avoid side effects during garbage collection
3595
+ with self.device.context_guard:
3596
+ runtime.core.cuda_graphics_unregister_resource(self.context, self.resource)
3021
3597
 
3022
3598
  def map(self, dtype, shape) -> warp.array:
3023
3599
  """Map the OpenGL buffer to a Warp array.
@@ -3036,7 +3612,7 @@ class RegisteredGLBuffer:
3036
3612
  runtime.core.cuda_graphics_device_ptr_and_size(
3037
3613
  self.context, self.resource, ctypes.byref(ptr), ctypes.byref(size)
3038
3614
  )
3039
- return warp.array(ptr=ptr.value, dtype=dtype, shape=shape, device=self.device, owner=False)
3615
+ return warp.array(ptr=ptr.value, dtype=dtype, shape=shape, device=self.device)
3040
3616
 
3041
3617
  def unmap(self):
3042
3618
  """Unmap the OpenGL buffer."""
@@ -3066,9 +3642,7 @@ def zeros(
3066
3642
 
3067
3643
  arr = empty(shape=shape, dtype=dtype, device=device, requires_grad=requires_grad, pinned=pinned, **kwargs)
3068
3644
 
3069
- # use the CUDA default stream for synchronous behaviour with other streams
3070
- with warp.ScopedStream(arr.device.null_stream):
3071
- arr.zero_()
3645
+ arr.zero_()
3072
3646
 
3073
3647
  return arr
3074
3648
 
@@ -3095,6 +3669,48 @@ def zeros_like(
3095
3669
  return arr
3096
3670
 
3097
3671
 
3672
+ def ones(
3673
+ shape: Tuple = None,
3674
+ dtype=float,
3675
+ device: Devicelike = None,
3676
+ requires_grad: bool = False,
3677
+ pinned: bool = False,
3678
+ **kwargs,
3679
+ ) -> warp.array:
3680
+ """Return a one-initialized array
3681
+
3682
+ Args:
3683
+ shape: Array dimensions
3684
+ dtype: Type of each element, e.g. warp.vec3, warp.mat33, etc.
3685
+ device: Device that array will live on
3686
+ requires_grad: Whether the array will be tracked for back propagation
3687
+ pinned: Whether the array uses pinned host memory (only applicable to CPU arrays)
3688
+
3689
+ Returns:
3690
+ A warp.array object representing the allocation
3691
+ """
3692
+
3693
+ return full(shape=shape, value=1, dtype=dtype, device=device, requires_grad=requires_grad, pinned=pinned, **kwargs)
3694
+
3695
+
3696
+ def ones_like(
3697
+ src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
3698
+ ) -> warp.array:
3699
+ """Return a one-initialized array with the same type and dimension of another array
3700
+
3701
+ Args:
3702
+ src: The template array to use for shape, data type, and device
3703
+ device: The device where the new array will be created (defaults to src.device)
3704
+ requires_grad: Whether the array will be tracked for back propagation
3705
+ pinned: Whether the array uses pinned host memory (only applicable to CPU arrays)
3706
+
3707
+ Returns:
3708
+ A warp.array object representing the allocation
3709
+ """
3710
+
3711
+ return full_like(src, 1, device=device, requires_grad=requires_grad, pinned=pinned)
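
Usage of the new `ones()`/`ones_like()` helpers is analogous to `zeros()` and `full()`; a quick sketch:

    import warp as wp

    wp.init()

    a = wp.ones(shape=16, dtype=wp.float32, device="cuda:0")
    b = wp.ones_like(a)                       # same shape, dtype, and device as `a`
    v = wp.ones(shape=(8, 3), dtype=wp.vec3, requires_grad=True)

    print(a.numpy())
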
3712
+
3713
+
3098
3714
  def full(
3099
3715
  shape: Tuple = None,
3100
3716
  value=0,
@@ -3154,9 +3770,7 @@ def full(
3154
3770
 
3155
3771
  arr = empty(shape=shape, dtype=dtype, device=device, requires_grad=requires_grad, pinned=pinned, **kwargs)
3156
3772
 
3157
- # use the CUDA default stream for synchronous behaviour with other streams
3158
- with warp.ScopedStream(arr.device.null_stream):
3159
- arr.fill_(value)
3773
+ arr.fill_(value)
3160
3774
 
3161
3775
  return arr
3162
3776
 
@@ -3295,7 +3909,6 @@ def from_numpy(
3295
3909
  data=arr,
3296
3910
  dtype=dtype,
3297
3911
  shape=shape,
3298
- owner=False,
3299
3912
  device=device,
3300
3913
  requires_grad=requires_grad,
3301
3914
  )
@@ -3339,7 +3952,6 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
3339
3952
  )
3340
3953
 
3341
3954
  # check device
3342
- # if a.device != device and not device.can_access(a.device):
3343
3955
  if value.device != device:
3344
3956
  raise RuntimeError(
3345
3957
  f"Error launching kernel '{kernel.key}', trying to launch on device='{device}', but input array for argument '{arg_name}' is on device={value.device}."
@@ -3495,22 +4107,29 @@ class Launch:
3495
4107
  for i, v in enumerate(values):
3496
4108
  self.set_param_at_index_from_ctype(i, v)
3497
4109
 
3498
- def launch(self) -> Any:
4110
+ def launch(self, stream=None) -> Any:
3499
4111
  if self.device.is_cpu:
3500
4112
  self.hooks.forward(*self.params)
3501
4113
  else:
4114
+ if stream is None:
4115
+ stream = self.device.stream
3502
4116
  runtime.core.cuda_launch_kernel(
3503
- self.device.context, self.hooks.forward, self.bounds.size, self.max_blocks, self.params_addr
4117
+ self.device.context,
4118
+ self.hooks.forward,
4119
+ self.bounds.size,
4120
+ self.max_blocks,
4121
+ self.params_addr,
4122
+ stream.cuda_stream,
3504
4123
  )
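
The `Launch` object above can be obtained by recording a launch instead of executing it immediately; a hedged sketch (the `scale` kernel is illustrative, and `record_cmd` is the keyword used elsewhere in this module to request recording):

    import warp as wp

    wp.init()

    @wp.kernel
    def scale(x: wp.array(dtype=float), s: float):
        i = wp.tid()
        x[i] = x[i] * s

    device = wp.get_device("cuda:0")
    x = wp.ones(1024, dtype=float, device=device)

    # Record the launch once; a Launch object is returned instead of running the kernel.
    cmd = wp.launch(scale, dim=x.size, inputs=[x, 2.0], device=device, record_cmd=True)

    # Replay it with low overhead, optionally on an explicit stream.
    cmd.launch()                          # uses the device's current stream
    cmd.launch(stream=wp.Stream(device))  # or an explicitly created stream
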
3505
4124
 
3506
4125
 
3507
4126
  def launch(
3508
4127
  kernel,
3509
4128
  dim: Tuple[int],
3510
- inputs: List,
3511
- outputs: List = [],
3512
- adj_inputs: List = [],
3513
- adj_outputs: List = [],
4129
+ inputs: Sequence = [],
4130
+ outputs: Sequence = [],
4131
+ adj_inputs: Sequence = [],
4132
+ adj_outputs: Sequence = [],
3514
4133
  device: Devicelike = None,
3515
4134
  stream: Stream = None,
3516
4135
  adjoint=False,
@@ -3525,7 +4144,7 @@ def launch(
3525
4144
  Args:
3526
4145
  kernel: The name of a Warp kernel function, decorated with the ``@wp.kernel`` decorator
3527
4146
  dim: The number of threads to launch the kernel with; can be an integer or a Tuple of ints with a maximum of 4 dimensions
3528
- inputs: The input parameters to the kernel
4147
+ inputs: The input parameters to the kernel (optional)
3529
4148
  outputs: The output parameters (optional)
3530
4149
  adj_inputs: The adjoint inputs (optional)
3531
4150
  adj_outputs: The adjoint outputs (optional)
@@ -3570,8 +4189,13 @@ def launch(
3570
4189
 
3571
4190
  params.append(pack_arg(kernel, arg_type, arg_name, a, device, adjoint))
3572
4191
 
3573
- fwd_args = inputs + outputs
3574
- adj_args = adj_inputs + adj_outputs
4192
+ fwd_args = []
4193
+ fwd_args.extend(inputs)
4194
+ fwd_args.extend(outputs)
4195
+
4196
+ adj_args = []
4197
+ adj_args.extend(adj_inputs)
4198
+ adj_args.extend(adj_outputs)
3575
4199
 
3576
4200
  if (len(fwd_args)) != (len(kernel.adj.args)):
3577
4201
  raise RuntimeError(
@@ -3622,45 +4246,47 @@ def launch(
3622
4246
  kernel_args = [ctypes.c_void_p(ctypes.addressof(x)) for x in params]
3623
4247
  kernel_params = (ctypes.c_void_p * len(kernel_args))(*kernel_args)
3624
4248
 
3625
- with warp.ScopedStream(stream):
3626
- if adjoint:
3627
- if hooks.backward is None:
3628
- raise RuntimeError(
3629
- f"Failed to find backward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
3630
- )
4249
+ if stream is None:
4250
+ stream = device.stream
3631
4251
 
3632
- runtime.core.cuda_launch_kernel(
3633
- device.context, hooks.backward, bounds.size, max_blocks, kernel_params
4252
+ if adjoint:
4253
+ if hooks.backward is None:
4254
+ raise RuntimeError(
4255
+ f"Failed to find backward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
3634
4256
  )
3635
4257
 
3636
- else:
3637
- if hooks.forward is None:
3638
- raise RuntimeError(
3639
- f"Failed to find forward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
3640
- )
4258
+ runtime.core.cuda_launch_kernel(
4259
+ device.context, hooks.backward, bounds.size, max_blocks, kernel_params, stream.cuda_stream
4260
+ )
3641
4261
 
3642
- if record_cmd:
3643
- launch = Launch(
3644
- kernel=kernel,
3645
- hooks=hooks,
3646
- params=params,
3647
- params_addr=kernel_params,
3648
- bounds=bounds,
3649
- device=device,
3650
- )
3651
- return launch
4262
+ else:
4263
+ if hooks.forward is None:
4264
+ raise RuntimeError(
4265
+ f"Failed to find forward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
4266
+ )
3652
4267
 
3653
- else:
3654
- # launch
3655
- runtime.core.cuda_launch_kernel(
3656
- device.context, hooks.forward, bounds.size, max_blocks, kernel_params
3657
- )
4268
+ if record_cmd:
4269
+ launch = Launch(
4270
+ kernel=kernel,
4271
+ hooks=hooks,
4272
+ params=params,
4273
+ params_addr=kernel_params,
4274
+ bounds=bounds,
4275
+ device=device,
4276
+ )
4277
+ return launch
3658
4278
 
3659
- try:
3660
- runtime.verify_cuda_device(device)
3661
- except Exception as e:
3662
- print(f"Error launching kernel: {kernel.key} on device {device}")
3663
- raise e
4279
+ else:
4280
+ # launch
4281
+ runtime.core.cuda_launch_kernel(
4282
+ device.context, hooks.forward, bounds.size, max_blocks, kernel_params, stream.cuda_stream
4283
+ )
4284
+
4285
+ try:
4286
+ runtime.verify_cuda_device(device)
4287
+ except Exception as e:
4288
+ print(f"Error launching kernel: {kernel.key} on device {device}")
4289
+ raise e
3664
4290
 
3665
4291
  # record on tape if one is active
3666
4292
  if runtime.tape and record_tape:
@@ -3698,7 +4324,7 @@ def synchronize_device(device: Devicelike = None):
3698
4324
  or memory copies have completed.
3699
4325
 
3700
4326
  Args:
3701
- device: Device to synchronize. If None, synchronize the current CUDA device.
4327
+ device: Device to synchronize.
3702
4328
  """
3703
4329
 
3704
4330
  device = runtime.get_device(device)
@@ -3721,7 +4347,7 @@ def synchronize_stream(stream_or_device=None):
3721
4347
  else:
3722
4348
  stream = runtime.get_device(stream_or_device).stream
3723
4349
 
3724
- runtime.core.cuda_stream_synchronize(stream.device.context, stream.cuda_stream)
4350
+ runtime.core.cuda_stream_synchronize(stream.cuda_stream)
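
A small sketch of the two synchronization helpers; either a `Stream` or anything device-like can be passed to `synchronize_stream()`:

    import warp as wp

    wp.init()

    device = wp.get_device("cuda:0")

    # Wait for all outstanding work on every stream of the device.
    wp.synchronize_device(device)

    # Wait only for the device's current stream.
    wp.synchronize_stream(device.stream)
    wp.synchronize_stream("cuda:0")   # equivalent: resolves to that device's current stream
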
3725
4351
 
3726
4352
 
3727
4353
  def force_load(device: Union[Device, str, List[Device], List[str]] = None, modules: List[Module] = None):
@@ -3805,7 +4431,7 @@ def set_module_options(options: Dict[str, Any], module: Optional[Any] = None):
3805
4431
  for the current module individually. Available options are listed below.
3806
4432
 
3807
4433
  * **mode**: The compilation mode to use, can be "debug", or "release", defaults to the value of ``warp.config.mode``.
3808
- * **max_unroll**: The maximum fixed-size loop to unroll (default 16)
4434
+ * **max_unroll**: The maximum fixed-size loop to unroll, defaults to the value of ``warp.config.max_unroll``.
3809
4435
 
3810
4436
  Args:
3811
4437
 
@@ -3831,22 +4457,28 @@ def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
3831
4457
  return get_module(m.__name__).options
3832
4458
 
3833
4459
 
3834
- def capture_begin(device: Devicelike = None, stream=None, force_module_load=None):
4460
+ def capture_begin(device: Devicelike = None, stream=None, force_module_load=None, external=False):
3835
4461
  """Begin capture of a CUDA graph
3836
4462
 
3837
4463
  Captures all subsequent kernel launches and memory operations on CUDA devices.
3838
- This can be used to record large numbers of kernels and replay them with low-overhead.
4464
+ This can be used to record large numbers of kernels and replay them with low overhead.
4465
+
4466
+ If `device` is specified, the capture will begin on the CUDA stream currently
4467
+ associated with the device. If `stream` is specified, the capture will begin
4468
+ on the given stream. If both are omitted, the capture will begin on the current
4469
+ stream of the current device.
3839
4470
 
3840
4471
  Args:
3841
4472
 
3842
- device: The device to capture on, if None the current CUDA device will be used
4473
+ device: The CUDA device to capture on
3843
4474
  stream: The CUDA stream to capture on
3844
4475
  force_module_load: Whether or not to force loading of all kernels before capture. In general it is better to use :func:`~warp.load_module()` to selectively load kernels.
4476
+ external: Whether the capture was already started externally
3845
4477
 
3846
4478
  """
3847
4479
 
3848
4480
  if force_module_load is None:
3849
- force_module_load = warp.config.graph_capture_module_load_default
4481
+ force_module_load = warp.config.enable_graph_capture_module_load_by_default
3850
4482
 
3851
4483
  if warp.config.verify_cuda:
3852
4484
  raise RuntimeError("Cannot use CUDA error verification during graph capture")
@@ -3857,24 +4489,36 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=None
3857
4489
  device = runtime.get_device(device)
3858
4490
  if not device.is_cuda:
3859
4491
  raise RuntimeError("Must be a CUDA device")
4492
+ stream = device.stream
3860
4493
 
3861
- if force_module_load:
3862
- force_load(device)
4494
+ if external:
4495
+ # make sure the stream is already capturing
4496
+ if not stream.is_capturing:
4497
+ raise RuntimeError("External capture reported, but the stream is not capturing")
4498
+ else:
4499
+ # make sure the stream is not capturing yet
4500
+ if stream.is_capturing:
4501
+ raise RuntimeError("Graph capture already in progress on this stream")
3863
4502
 
3864
- device.is_capturing = True
4503
+ if force_module_load:
4504
+ force_load(device)
3865
4505
 
3866
- # disable garbage collection to avoid older allocations getting collected during graph capture
3867
- gc.disable()
4506
+ device.captures.add(stream)
3868
4507
 
3869
- with warp.ScopedStream(stream):
3870
- runtime.core.cuda_graph_begin_capture(device.context)
4508
+ if not runtime.core.cuda_graph_begin_capture(device.context, stream.cuda_stream, int(external)):
4509
+ raise RuntimeError(runtime.get_error_string())
3871
4510
 
3872
4511
 
3873
- def capture_end(device: Devicelike = None, stream=None) -> Graph:
4512
+ def capture_end(device: Devicelike = None, stream: Stream = None) -> Graph:
3874
4513
  """Ends the capture of a CUDA graph
3875
4514
 
4515
+ Args:
4516
+
4517
+ device: The CUDA device where capture began
4518
+ stream: The CUDA stream where capture began
4519
+
3876
4520
  Returns:
3877
- A handle to a CUDA graph object that can be launched with :func:`~warp.capture_launch()`
4521
+ A Graph object that can be launched with :func:`~warp.capture_launch()`
3878
4522
  """
3879
4523
 
3880
4524
  if stream is not None:
@@ -3883,20 +4527,22 @@ def capture_end(device: Devicelike = None, stream=None) -> Graph:
3883
4527
  device = runtime.get_device(device)
3884
4528
  if not device.is_cuda:
3885
4529
  raise RuntimeError("Must be a CUDA device")
4530
+ stream = device.stream
3886
4531
 
3887
- with warp.ScopedStream(stream):
3888
- graph = runtime.core.cuda_graph_end_capture(device.context)
4532
+ if stream not in device.captures:
4533
+ raise RuntimeError("Graph capture is not active on this stream")
3889
4534
 
3890
- device.is_capturing = False
4535
+ device.captures.remove(stream)
3891
4536
 
3892
- # re-enable GC
3893
- gc.enable()
4537
+ graph = ctypes.c_void_p()
4538
+ result = runtime.core.cuda_graph_end_capture(device.context, stream.cuda_stream, ctypes.byref(graph))
3894
4539
 
3895
- if graph is None:
3896
- raise RuntimeError(
3897
- "Error occurred during CUDA graph capture. This could be due to an unintended allocation or CPU/GPU synchronization event."
3898
- )
3899
- else:
4540
+ if not result:
4541
+ # A concrete error should've already been reported, so we don't need to go into details here
4542
+ raise RuntimeError(f"CUDA graph capture failed. {runtime.get_error_string()}")
4543
+
4544
+ # note that for external captures, we do not return a graph, because we don't instantiate it ourselves
4545
+ if graph:
3900
4546
  return Graph(device, graph)
3901
4547
 
3902
4548
 
@@ -3914,15 +4560,16 @@ def capture_launch(graph: Graph, stream: Stream = None):
3914
4560
  device = stream.device
3915
4561
  else:
3916
4562
  device = graph.device
4563
+ stream = device.stream
3917
4564
 
3918
- with warp.ScopedStream(stream):
3919
- runtime.core.cuda_graph_launch(device.context, graph.exec)
4565
+ if not runtime.core.cuda_graph_launch(graph.exec, stream.cuda_stream):
4566
+ raise RuntimeError(f"Graph launch error: {runtime.get_error_string()}")
3920
4567
 
3921
4568
 
3922
4569
  def copy(
3923
4570
  dest: warp.array, src: warp.array, dest_offset: int = 0, src_offset: int = 0, count: int = 0, stream: Stream = None
3924
4571
  ):
3925
- """Copy array contents from src to dest
4572
+ """Copy array contents from `src` to `dest`.
3926
4573
 
3927
4574
  Args:
3928
4575
  dest: Destination array, must be at least as big as source buffer
@@ -3932,6 +4579,12 @@ def copy(
3932
4579
  count: Number of array elements to copy (will copy all elements if set to 0)
3933
4580
  stream: The stream on which to perform the copy (optional)
3934
4581
 
4582
+ The stream, if specified, can be from any device. If the stream is omitted, then Warp selects a stream based on the following rules:
4583
+ (1) If the destination array is on a CUDA device, use the current stream on the destination device.
4584
+ (2) Otherwise, if the source array is on a CUDA device, use the current stream on the source device.
4585
+
4586
+ If neither the source nor the destination is on a CUDA device, no stream is used for the copy.
4587
+
3935
4588
  """
3936
4589
 
3937
4590
  if not warp.types.is_array(src) or not warp.types.is_array(dest):
@@ -3944,14 +4597,50 @@ def copy(
3944
4597
  if count == 0:
3945
4598
  return
3946
4599
 
3947
- # copying non-contiguous arrays requires that they are on the same device
3948
- if not (src.is_contiguous and dest.is_contiguous) and src.device != dest.device:
3949
- if dest.is_contiguous:
3950
- # make a contiguous copy of the source array
3951
- src = src.contiguous()
3952
- else:
3953
- # make a copy of the source array on the destination device
3954
- src = src.to(dest.device)
4600
+ # figure out the stream for the copy
4601
+ if stream is None:
4602
+ if dest.device.is_cuda:
4603
+ stream = dest.device.stream
4604
+ elif src.device.is_cuda:
4605
+ stream = src.device.stream
4606
+
4607
+ # Copying between different devices requires contiguous arrays. If the arrays
4608
+ # are not contiguous, we must use temporary staging buffers for the transfer.
4609
+ # TODO: We can skip the staging if device access is enabled.
4610
+ if src.device != dest.device:
4611
+ # If the source is not contiguous, make a contiguous copy on the source device.
4612
+ if not src.is_contiguous:
4613
+ # FIXME: We can't use a temporary CPU allocation during graph capture,
4614
+ # because launching the graph will crash after the allocation is
4615
+ # garbage-collected.
4616
+ if src.device.is_cpu and stream.is_capturing:
4617
+ raise RuntimeError("Failed to allocate a CPU staging buffer during graph capture")
4618
+ # This involves an allocation and a kernel launch, which must run on the source device.
4619
+ if src.device.is_cuda and stream != src.device.stream:
4620
+ src.device.stream.wait_stream(stream)
4621
+ src = src.contiguous()
4622
+ stream.wait_stream(src.device.stream)
4623
+ else:
4624
+ src = src.contiguous()
4625
+
4626
+ # The source is now contiguous. If the destination is not contiguous,
4627
+ # clone a contiguous copy on the destination device.
4628
+ if not dest.is_contiguous:
4629
+ # FIXME: We can't use a temporary CPU allocation during graph capture,
4630
+ # because launching the graph will crash after the allocation is
4631
+ # garbage-collected.
4632
+ if dest.device.is_cpu and stream.is_capturing:
4633
+ raise RuntimeError("Failed to allocate a CPU staging buffer during graph capture")
4634
+ # The allocation must run on the destination device
4635
+ if dest.device.is_cuda and stream != dest.device.stream:
4636
+ dest.device.stream.wait_stream(stream)
4637
+ tmp = empty_like(src, device=dest.device)
4638
+ stream.wait_stream(dest.device.stream)
4639
+ else:
4640
+ tmp = empty_like(src, device=dest.device)
4641
+ # Run the copy on the stream given by the caller
4642
+ copy(tmp, src, stream=stream)
4643
+ src = tmp
3955
4644
 
3956
4645
  if src.is_contiguous and dest.is_contiguous:
3957
4646
  bytes_to_copy = count * warp.types.type_size_in_bytes(src.dtype)
@@ -3975,32 +4664,33 @@ def copy(
3975
4664
  f"Trying to copy source buffer with size ({bytes_to_copy}) to offset ({dst_offset_in_bytes}) is larger than destination size ({dst_size_in_bytes})"
3976
4665
  )
3977
4666
 
3978
- if src.device.is_cpu and dest.device.is_cpu:
3979
- runtime.core.memcpy_h2h(dst_ptr, src_ptr, bytes_to_copy)
4667
+ if dest.device.is_cuda:
4668
+ if src.device.is_cuda:
4669
+ if src.device == dest.device:
4670
+ result = runtime.core.memcpy_d2d(
4671
+ dest.device.context, dst_ptr, src_ptr, bytes_to_copy, stream.cuda_stream
4672
+ )
4673
+ else:
4674
+ result = runtime.core.memcpy_p2p(
4675
+ dest.device.context, dst_ptr, src.device.context, src_ptr, bytes_to_copy, stream.cuda_stream
4676
+ )
4677
+ else:
4678
+ result = runtime.core.memcpy_h2d(
4679
+ dest.device.context, dst_ptr, src_ptr, bytes_to_copy, stream.cuda_stream
4680
+ )
3980
4681
  else:
3981
- # figure out the CUDA context/stream for the copy
3982
- if stream is not None:
3983
- copy_device = stream.device
3984
- elif dest.device.is_cuda:
3985
- copy_device = dest.device
4682
+ if src.device.is_cuda:
4683
+ result = runtime.core.memcpy_d2h(
4684
+ src.device.context, dst_ptr, src_ptr, bytes_to_copy, stream.cuda_stream
4685
+ )
3986
4686
  else:
3987
- copy_device = src.device
3988
-
3989
- with warp.ScopedStream(stream):
3990
- if src.device.is_cpu and dest.device.is_cuda:
3991
- runtime.core.memcpy_h2d(copy_device.context, dst_ptr, src_ptr, bytes_to_copy)
3992
- elif src.device.is_cuda and dest.device.is_cpu:
3993
- runtime.core.memcpy_d2h(copy_device.context, dst_ptr, src_ptr, bytes_to_copy)
3994
- elif src.device.is_cuda and dest.device.is_cuda:
3995
- if src.device == dest.device:
3996
- runtime.core.memcpy_d2d(copy_device.context, dst_ptr, src_ptr, bytes_to_copy)
3997
- else:
3998
- runtime.core.memcpy_peer(copy_device.context, dst_ptr, src_ptr, bytes_to_copy)
3999
- else:
4000
- raise RuntimeError("Unexpected source and destination combination")
4687
+ result = runtime.core.memcpy_h2h(dst_ptr, src_ptr, bytes_to_copy)
4688
+
4689
+ if not result:
4690
+ raise RuntimeError(f"Warp copy error: {runtime.get_error_string()}")
4001
4691
 
4002
4692
  else:
4003
- # handle non-contiguous and indexed arrays
4693
+ # handle non-contiguous arrays
4004
4694
 
4005
4695
  if src.shape != dest.shape:
4006
4696
  raise RuntimeError("Incompatible array shapes")
@@ -4028,11 +4718,24 @@ def copy(
4028
4718
  src_type = warp.types.array_type_id(src)
4029
4719
  dst_type = warp.types.array_type_id(dest)
4030
4720
 
4031
- if src.device.is_cuda:
4032
- with warp.ScopedStream(stream):
4033
- runtime.core.array_copy_device(src.device.context, dst_ptr, src_ptr, dst_type, src_type, src_elem_size)
4721
+ if dest.device.is_cuda:
4722
+ # This work involves a kernel launch, so it must run on the destination device.
4723
+ # If the copy stream is different, we need to synchronize it.
4724
+ if stream == dest.device.stream:
4725
+ result = runtime.core.array_copy_device(
4726
+ dest.device.context, dst_ptr, src_ptr, dst_type, src_type, src_elem_size
4727
+ )
4728
+ else:
4729
+ dest.device.stream.wait_stream(stream)
4730
+ result = runtime.core.array_copy_device(
4731
+ dest.device.context, dst_ptr, src_ptr, dst_type, src_type, src_elem_size
4732
+ )
4733
+ stream.wait_stream(dest.device.stream)
4034
4734
  else:
4035
- runtime.core.array_copy_host(dst_ptr, src_ptr, dst_type, src_type, src_elem_size)
4735
+ result = runtime.core.array_copy_host(dst_ptr, src_ptr, dst_type, src_type, src_elem_size)
4736
+
4737
+ if not result:
4738
+ raise RuntimeError(f"Warp copy error: {runtime.get_error_string()}")
4036
4739
 
4037
4740
  # copy gradient, if needed
4038
4741
  if hasattr(src, "grad") and src.grad is not None and hasattr(dest, "grad") and dest.grad is not None: