warp-lang 0.11.0__py3-none-manylinux2014_x86_64.whl → 1.0.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (170)
  1. warp/__init__.py +8 -0
  2. warp/bin/warp-clang.so +0 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +7 -6
  5. warp/build_dll.py +70 -79
  6. warp/builtins.py +10 -6
  7. warp/codegen.py +51 -19
  8. warp/config.py +7 -8
  9. warp/constants.py +3 -0
  10. warp/context.py +948 -245
  11. warp/dlpack.py +198 -113
  12. warp/examples/assets/bunny.usd +0 -0
  13. warp/examples/assets/cartpole.urdf +110 -0
  14. warp/examples/assets/crazyflie.usd +0 -0
  15. warp/examples/assets/cube.usda +42 -0
  16. warp/examples/assets/nv_ant.xml +92 -0
  17. warp/examples/assets/nv_humanoid.xml +183 -0
  18. warp/examples/assets/quadruped.urdf +268 -0
  19. warp/examples/assets/rocks.nvdb +0 -0
  20. warp/examples/assets/rocks.usd +0 -0
  21. warp/examples/assets/sphere.usda +56 -0
  22. warp/examples/assets/torus.usda +105 -0
  23. warp/examples/benchmarks/benchmark_api.py +383 -0
  24. warp/examples/benchmarks/benchmark_cloth.py +279 -0
  25. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
  26. warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
  27. warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
  28. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
  29. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
  30. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
  31. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
  32. warp/examples/benchmarks/benchmark_launches.py +295 -0
  33. warp/examples/core/example_dem.py +221 -0
  34. warp/examples/core/example_fluid.py +267 -0
  35. warp/examples/core/example_graph_capture.py +129 -0
  36. warp/examples/core/example_marching_cubes.py +177 -0
  37. warp/examples/core/example_mesh.py +154 -0
  38. warp/examples/core/example_mesh_intersect.py +193 -0
  39. warp/examples/core/example_nvdb.py +169 -0
  40. warp/examples/core/example_raycast.py +89 -0
  41. warp/examples/core/example_raymarch.py +178 -0
  42. warp/examples/core/example_render_opengl.py +141 -0
  43. warp/examples/core/example_sph.py +389 -0
  44. warp/examples/core/example_torch.py +181 -0
  45. warp/examples/core/example_wave.py +249 -0
  46. warp/examples/fem/bsr_utils.py +380 -0
  47. warp/examples/fem/example_apic_fluid.py +391 -0
  48. warp/examples/fem/example_convection_diffusion.py +168 -0
  49. warp/examples/fem/example_convection_diffusion_dg.py +209 -0
  50. warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
  51. warp/examples/fem/example_deformed_geometry.py +159 -0
  52. warp/examples/fem/example_diffusion.py +173 -0
  53. warp/examples/fem/example_diffusion_3d.py +152 -0
  54. warp/examples/fem/example_diffusion_mgpu.py +214 -0
  55. warp/examples/fem/example_mixed_elasticity.py +222 -0
  56. warp/examples/fem/example_navier_stokes.py +243 -0
  57. warp/examples/fem/example_stokes.py +192 -0
  58. warp/examples/fem/example_stokes_transfer.py +249 -0
  59. warp/examples/fem/mesh_utils.py +109 -0
  60. warp/examples/fem/plot_utils.py +287 -0
  61. warp/examples/optim/example_bounce.py +248 -0
  62. warp/examples/optim/example_cloth_throw.py +210 -0
  63. warp/examples/optim/example_diffray.py +535 -0
  64. warp/examples/optim/example_drone.py +850 -0
  65. warp/examples/optim/example_inverse_kinematics.py +169 -0
  66. warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
  67. warp/examples/optim/example_spring_cage.py +234 -0
  68. warp/examples/optim/example_trajectory.py +201 -0
  69. warp/examples/sim/example_cartpole.py +128 -0
  70. warp/examples/sim/example_cloth.py +184 -0
  71. warp/examples/sim/example_granular.py +113 -0
  72. warp/examples/sim/example_granular_collision_sdf.py +185 -0
  73. warp/examples/sim/example_jacobian_ik.py +213 -0
  74. warp/examples/sim/example_particle_chain.py +106 -0
  75. warp/examples/sim/example_quadruped.py +179 -0
  76. warp/examples/sim/example_rigid_chain.py +191 -0
  77. warp/examples/sim/example_rigid_contact.py +176 -0
  78. warp/examples/sim/example_rigid_force.py +126 -0
  79. warp/examples/sim/example_rigid_gyroscopic.py +97 -0
  80. warp/examples/sim/example_rigid_soft_contact.py +124 -0
  81. warp/examples/sim/example_soft_body.py +178 -0
  82. warp/fabric.py +29 -20
  83. warp/fem/cache.py +0 -1
  84. warp/fem/dirichlet.py +0 -2
  85. warp/fem/integrate.py +0 -1
  86. warp/jax.py +45 -0
  87. warp/jax_experimental.py +339 -0
  88. warp/native/builtin.h +12 -0
  89. warp/native/bvh.cu +18 -18
  90. warp/native/clang/clang.cpp +8 -3
  91. warp/native/cuda_util.cpp +94 -5
  92. warp/native/cuda_util.h +35 -6
  93. warp/native/cutlass_gemm.cpp +1 -1
  94. warp/native/cutlass_gemm.cu +4 -1
  95. warp/native/error.cpp +66 -0
  96. warp/native/error.h +27 -0
  97. warp/native/mesh.cu +2 -2
  98. warp/native/reduce.cu +4 -4
  99. warp/native/runlength_encode.cu +2 -2
  100. warp/native/scan.cu +2 -2
  101. warp/native/sparse.cu +0 -1
  102. warp/native/temp_buffer.h +2 -2
  103. warp/native/warp.cpp +95 -60
  104. warp/native/warp.cu +1053 -218
  105. warp/native/warp.h +49 -32
  106. warp/optim/linear.py +33 -16
  107. warp/render/render_opengl.py +202 -101
  108. warp/render/render_usd.py +82 -40
  109. warp/sim/__init__.py +13 -4
  110. warp/sim/articulation.py +4 -5
  111. warp/sim/collide.py +320 -175
  112. warp/sim/import_mjcf.py +25 -30
  113. warp/sim/import_urdf.py +94 -63
  114. warp/sim/import_usd.py +51 -36
  115. warp/sim/inertia.py +3 -2
  116. warp/sim/integrator.py +233 -0
  117. warp/sim/integrator_euler.py +447 -469
  118. warp/sim/integrator_featherstone.py +1991 -0
  119. warp/sim/integrator_xpbd.py +1420 -640
  120. warp/sim/model.py +765 -487
  121. warp/sim/particles.py +2 -1
  122. warp/sim/render.py +35 -13
  123. warp/sim/utils.py +222 -11
  124. warp/stubs.py +8 -0
  125. warp/tape.py +16 -1
  126. warp/tests/aux_test_grad_customs.py +23 -0
  127. warp/tests/test_array.py +190 -1
  128. warp/tests/test_async.py +656 -0
  129. warp/tests/test_bool.py +50 -0
  130. warp/tests/test_dlpack.py +164 -11
  131. warp/tests/test_examples.py +166 -74
  132. warp/tests/test_fem.py +8 -1
  133. warp/tests/test_generics.py +15 -5
  134. warp/tests/test_grad.py +1 -1
  135. warp/tests/test_grad_customs.py +172 -12
  136. warp/tests/test_jax.py +254 -0
  137. warp/tests/test_large.py +29 -6
  138. warp/tests/test_launch.py +25 -0
  139. warp/tests/test_linear_solvers.py +20 -3
  140. warp/tests/test_matmul.py +61 -16
  141. warp/tests/test_matmul_lite.py +13 -13
  142. warp/tests/test_mempool.py +186 -0
  143. warp/tests/test_multigpu.py +3 -0
  144. warp/tests/test_options.py +16 -2
  145. warp/tests/test_peer.py +137 -0
  146. warp/tests/test_print.py +3 -1
  147. warp/tests/test_quat.py +23 -0
  148. warp/tests/test_sim_kinematics.py +97 -0
  149. warp/tests/test_snippet.py +126 -3
  150. warp/tests/test_streams.py +108 -79
  151. warp/tests/test_torch.py +16 -8
  152. warp/tests/test_utils.py +32 -27
  153. warp/tests/test_verify_fp.py +65 -0
  154. warp/tests/test_volume.py +1 -1
  155. warp/tests/unittest_serial.py +2 -0
  156. warp/tests/unittest_suites.py +12 -0
  157. warp/tests/unittest_utils.py +14 -7
  158. warp/thirdparty/unittest_parallel.py +15 -3
  159. warp/torch.py +10 -8
  160. warp/types.py +363 -246
  161. warp/utils.py +143 -19
  162. warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
  163. warp_lang-1.0.0.dist-info/METADATA +394 -0
  164. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
  165. warp/sim/optimizer.py +0 -138
  166. warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
  167. warp_lang-0.11.0.dist-info/METADATA +0 -238
  168. /warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
  169. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
  170. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
warp/native/warp.cu CHANGED
@@ -9,11 +9,17 @@
  #include "warp.h"
  #include "scan.h"
  #include "cuda_util.h"
+ #include "error.h"

  #include <nvrtc.h>
  #include <nvPTXCompiler.h>

+ #include <algorithm>
+ #include <iterator>
+ #include <list>
  #include <map>
+ #include <unordered_map>
+ #include <unordered_set>
  #include <vector>

  #define check_nvrtc(code) (check_nvrtc_result(code, __FILE__, __LINE__))
@@ -81,14 +87,55 @@ struct DeviceInfo
      char name[kNameLen] = "";
      int arch = 0;
      int is_uva = 0;
-     int is_memory_pool_supported = 0;
+     int is_mempool_supported = 0;
+     CUcontext primary_context = NULL;
  };

  struct ContextInfo
  {
      DeviceInfo* device_info = NULL;

-     CUstream stream = NULL; // created when needed
+     // the current stream, managed from Python (see cuda_context_set_stream() and cuda_context_get_stream())
+     CUstream stream = NULL;
+ };
+
+ struct CaptureInfo
+ {
+     CUstream stream = NULL;   // the main stream where capture begins and ends
+     uint64_t id = 0;          // unique capture id from CUDA
+     bool external = false;    // whether this is an external capture
+ };
+
+ struct StreamInfo
+ {
+     CUevent cached_event = NULL;  // event used for stream synchronization (cached to avoid creating temporary events)
+     CaptureInfo* capture = NULL;  // capture info (only if started on this stream)
+ };
+
+ struct GraphInfo
+ {
+     std::vector<void*> unfreed_allocs;
+ };
+
+ // Information for graph allocations that are not freed by the graph.
+ // These allocations have a shared ownership:
+ // - The graph instance allocates/maps the memory on each launch, even if the user reference is released.
+ // - The user reference must remain valid even if the graph is destroyed.
+ // The memory will be freed once the user reference is released and the graph is destroyed.
+ struct GraphAllocInfo
+ {
+     uint64_t capture_id = 0;
+     void* context = NULL;
+     bool ref_exists = false;       // whether user reference still exists
+     bool graph_destroyed = false;  // whether graph instance was destroyed
+ };
+
+ // Information used when deferring deallocations.
+ struct FreeInfo
+ {
+     void* context = NULL;
+     void* ptr = NULL;
+     bool is_async = false;
  };

  // cached info for all devices, indexed by ordinal
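The GraphAllocInfo bookkeeping above encodes a shared-ownership rule: a pointer allocated during capture is only released once both owners are gone, the user reference and the graph instance. A minimal sketch of that rule (illustrative only, not part of the package; the shipped code inlines this logic in free_device_async() and on_graph_destroy()):

// Illustrative sketch: hypothetical helper expressing the release rule
// implied by GraphAllocInfo.
static bool can_release_graph_alloc(const GraphAllocInfo& info)
{
    // the memory may be freed only after the user reference is released
    // AND the owning graph instance has been destroyed
    return !info.ref_exists && info.graph_destroyed;
}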
@@ -100,6 +147,22 @@ static std::map<CUdevice, DeviceInfo*> g_device_map;
  // cached info for all known contexts
  static std::map<CUcontext, ContextInfo> g_contexts;

+ // cached info for all known streams (including registered external streams)
+ static std::unordered_map<CUstream, StreamInfo> g_streams;
+
+ // Ongoing graph captures registered using wp.capture_begin().
+ // This maps the capture id to the stream where capture was started.
+ // See cuda_graph_begin_capture(), cuda_graph_end_capture(), and free_device_async().
+ static std::unordered_map<uint64_t, CaptureInfo*> g_captures;
+
+ // Memory allocated during graph capture requires special handling.
+ // See alloc_device_async() and free_device_async().
+ static std::unordered_map<void*, GraphAllocInfo> g_graph_allocs;
+
+ // Memory that cannot be freed immediately gets queued here.
+ // Call free_deferred_allocs() to release.
+ static std::vector<FreeInfo> g_deferred_free_list;
+

  void cuda_set_context_restore_policy(bool always_restore)
  {
@@ -116,12 +179,12 @@ int cuda_init()
      if (!init_cuda_driver())
          return -1;

-     int deviceCount = 0;
-     if (check_cu(cuDeviceGetCount_f(&deviceCount)))
+     int device_count = 0;
+     if (check_cu(cuDeviceGetCount_f(&device_count)))
      {
-         g_devices.resize(deviceCount);
+         g_devices.resize(device_count);

-         for (int i = 0; i < deviceCount; i++)
+         for (int i = 0; i < device_count; i++)
          {
              CUdevice device;
              if (check_cu(cuDeviceGet_f(&device, i)))
@@ -135,7 +198,7 @@ int cuda_init()
                  check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_bus_id, CU_DEVICE_ATTRIBUTE_PCI_BUS_ID, device));
                  check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_device_id, CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, device));
                  check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_uva, CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, device));
-                 check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_memory_pool_supported, CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, device));
+                 check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_mempool_supported, CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, device));
                  int major = 0;
                  int minor = 0;
                  check_cu(cuDeviceGetAttribute_f(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device));
@@ -168,9 +231,9 @@ static inline CUcontext get_current_context()
      return NULL;
  }

- static inline CUstream get_current_stream()
+ static inline CUstream get_current_stream(void* context=NULL)
  {
-     return static_cast<CUstream>(cuda_context_get_stream(NULL));
+     return static_cast<CUstream>(cuda_context_get_stream(context));
  }

  static ContextInfo* get_context_info(CUcontext ctx)
@@ -191,11 +254,22 @@ static ContextInfo* get_context_info(CUcontext ctx)
      {
          // previously unseen context, add the info
          ContextGuard guard(ctx, true);
-         ContextInfo context_info;
+
          CUdevice device;
          if (check_cu(cuCtxGetDevice_f(&device)))
          {
-             context_info.device_info = g_device_map[device];
+             DeviceInfo* device_info = g_device_map[device];
+
+             // workaround for https://nvbugspro.nvidia.com/bug/4456003
+             if (device_info->is_mempool_supported)
+             {
+                 void* dummy = NULL;
+                 check_cuda(cudaMallocAsync(&dummy, 1, NULL));
+                 check_cuda(cudaFreeAsync(dummy, NULL));
+             }
+
+             ContextInfo context_info;
+             context_info.device_info = device_info;
              auto result = g_contexts.insert(std::make_pair(ctx, context_info));
              return &result.first->second;
          }
@@ -204,10 +278,116 @@ static ContextInfo* get_context_info(CUcontext ctx)
      return NULL;
  }

+ static inline ContextInfo* get_context_info(void* context)
+ {
+     return get_context_info(static_cast<CUcontext>(context));
+ }
+
+ static inline StreamInfo* get_stream_info(CUstream stream)
+ {
+     auto it = g_streams.find(stream);
+     if (it != g_streams.end())
+         return &it->second;
+     else
+         return NULL;
+ }
+
+ static void deferred_free(void* ptr, void* context, bool is_async)
+ {
+     FreeInfo free_info;
+     free_info.ptr = ptr;
+     free_info.context = context ? context : get_current_context();
+     free_info.is_async = is_async;
+     g_deferred_free_list.push_back(free_info);
+ }
+
+ static int free_deferred_allocs(void* context = NULL)
+ {
+     if (g_deferred_free_list.empty() || !g_captures.empty())
+         return 0;
+
+     int num_freed_allocs = 0;
+     for (auto it = g_deferred_free_list.begin(); it != g_deferred_free_list.end(); /*noop*/)
+     {
+         const FreeInfo& free_info = *it;
+
+         // free the pointer if it matches the given context or if the context is unspecified
+         if (free_info.context == context || !context)
+         {
+             ContextGuard guard(free_info.context);
+
+             if (free_info.is_async)
+             {
+                 // this could be a regular stream-ordered allocation or a graph allocation
+                 cudaError_t res = cudaFreeAsync(free_info.ptr, NULL);
+                 if (res != cudaSuccess)
+                 {
+                     if (res == cudaErrorInvalidValue)
+                     {
+                         // This can happen if we try to release the pointer but the graph was
+                         // never launched, so the memory isn't mapped.
+                         // This is fine, so clear the error.
+                         cudaGetLastError();
+                     }
+                     else
+                     {
+                         // something else went wrong, report error
+                         check_cuda(res);
+                     }
+                 }
+             }
+             else
+             {
+                 check_cuda(cudaFree(free_info.ptr));
+             }
+
+             ++num_freed_allocs;
+
+             it = g_deferred_free_list.erase(it);
+         }
+         else
+         {
+             ++it;
+         }
+     }
+
+     return num_freed_allocs;
+ }
+
+ static void CUDART_CB on_graph_destroy(void* user_data)
+ {
+     if (!user_data)
+         return;
+
+     GraphInfo* graph_info = static_cast<GraphInfo*>(user_data);
+
+     for (void* ptr : graph_info->unfreed_allocs)
+     {
+         auto alloc_iter = g_graph_allocs.find(ptr);
+         if (alloc_iter != g_graph_allocs.end())
+         {
+             GraphAllocInfo& alloc_info = alloc_iter->second;
+             if (alloc_info.ref_exists)
+             {
+                 // unreference from graph so the pointer will be deallocated when the user reference goes away
+                 alloc_info.graph_destroyed = true;
+             }
+             else
+             {
+                 // the pointer can be freed, but we can't call CUDA functions in this callback, so defer it
+                 deferred_free(ptr, alloc_info.context, true);
+                 g_graph_allocs.erase(alloc_iter);
+             }
+         }
+     }
+
+     delete graph_info;
+ }
+

  void* alloc_pinned(size_t s)
  {
-     void* ptr;
+     void* ptr = NULL;
      check_cuda(cudaMallocHost(&ptr, s));
      return ptr;
  }
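The new helpers above queue deallocations that cannot run while any graph capture is active and flush them later. A hedged sketch of the intended call pattern (illustrative only; example_release() is hypothetical and the shipped code performs these checks inside free_device_default()/free_device_async()):

// Illustrative sketch of the deferred-free pattern used above.
static void example_release(void* context, void* ptr)
{
    if (!g_captures.empty())
        deferred_free(ptr, context, /*is_async*/ true);  // queue it, no CUDA call yet
    else
        check_cuda(cudaFreeAsync(ptr, NULL));            // safe to free immediately
}

// later, once all captures have ended (e.g. from cuda_context_synchronize()):
// int released = free_deferred_allocs();  // flushes g_deferred_free_list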
@@ -218,84 +398,320 @@ void free_pinned(void* ptr)
  }

  void* alloc_device(void* context, size_t s)
+ {
+     int ordinal = cuda_context_get_device_ordinal(context);
+
+     // use stream-ordered allocator if available
+     if (cuda_device_is_mempool_supported(ordinal))
+         return alloc_device_async(context, s);
+     else
+         return alloc_device_default(context, s);
+ }
+
+ void free_device(void* context, void* ptr)
+ {
+     int ordinal = cuda_context_get_device_ordinal(context);
+
+     // use stream-ordered allocator if available
+     if (cuda_device_is_mempool_supported(ordinal))
+         free_device_async(context, ptr);
+     else
+         free_device_default(context, ptr);
+ }
+
+ void* alloc_device_default(void* context, size_t s)
  {
      ContextGuard guard(context);

-     void* ptr;
+     void* ptr = NULL;
      check_cuda(cudaMalloc(&ptr, s));
+
      return ptr;
  }

- void* alloc_temp_device(void* context, size_t s)
+ void free_device_default(void* context, void* ptr)
  {
-     // "cudaMallocAsync ignores the current device/context when determining where the allocation will reside. Instead,
-     // cudaMallocAsync determines the resident device based on the specified memory pool or the supplied stream."
      ContextGuard guard(context);

-     void* ptr;
-
-     if (cuda_context_is_memory_pool_supported(context))
+     // check if a capture is in progress
+     if (g_captures.empty())
      {
-         check_cuda(cudaMallocAsync(&ptr, s, get_current_stream()));
+         check_cuda(cudaFree(ptr));
      }
      else
      {
-         check_cuda(cudaMalloc(&ptr, s));
+         // we must defer the operation until graph captures complete
+         deferred_free(ptr, context, false);
      }
-
-     return ptr;
  }

- void free_device(void* context, void* ptr)
+ void* alloc_device_async(void* context, size_t s)
  {
+     // stream-ordered allocations don't rely on the current context,
+     // but we set the context here for consistent behaviour
      ContextGuard guard(context);

-     check_cuda(cudaFree(ptr));
+     ContextInfo* context_info = get_context_info(context);
+     if (!context_info)
+         return NULL;
+
+     CUstream stream = context_info->stream;
+
+     void* ptr = NULL;
+     check_cuda(cudaMallocAsync(&ptr, s, stream));
+
+     if (ptr)
+     {
+         // if the stream is capturing, the allocation requires special handling
+         if (cuda_stream_is_capturing(stream))
+         {
+             // check if this is a known capture
+             uint64_t capture_id = get_capture_id(stream);
+             auto capture_iter = g_captures.find(capture_id);
+             if (capture_iter != g_captures.end())
+             {
+                 // remember graph allocation details
+                 GraphAllocInfo alloc_info;
+                 alloc_info.capture_id = capture_id;
+                 alloc_info.context = context ? context : get_current_context();
+                 alloc_info.ref_exists = true;        // user reference created and returned here
+                 alloc_info.graph_destroyed = false;  // graph not destroyed yet
+                 g_graph_allocs[ptr] = alloc_info;
+             }
+         }
+     }
+
+     return ptr;
  }

- void free_temp_device(void* context, void* ptr)
+ void free_device_async(void* context, void* ptr)
  {
+     // stream-ordered allocators generally don't rely on the current context,
+     // but we set the context here for consistent behaviour
      ContextGuard guard(context);

-     if (cuda_context_is_memory_pool_supported(context))
+     // NB: Stream-ordered deallocations are tricky, because the memory could still be used on another stream
+     // or even multiple streams. To avoid use-after-free errors, we need to ensure that all preceding work
+     // completes before releasing the memory. The strategy is different for regular stream-ordered allocations
+     // and allocations made during graph capture. See below for details.
+
+     // check if this allocation was made during graph capture
+     auto alloc_iter = g_graph_allocs.find(ptr);
+     if (alloc_iter == g_graph_allocs.end())
      {
-         check_cuda(cudaFreeAsync(ptr, get_current_stream()));
+         // Not a graph allocation.
+         // Check if graph capture is ongoing.
+         if (g_captures.empty())
+         {
+             // cudaFreeAsync on the null stream does not block or trigger synchronization, but it postpones
+             // the deallocation until a synchronization point is reached, so preceding work on this pointer
+             // should safely complete.
+             check_cuda(cudaFreeAsync(ptr, NULL));
+         }
+         else
+         {
+             // We must defer the free operation until graph capture completes.
+             deferred_free(ptr, context, true);
+         }
      }
      else
      {
-         check_cuda(cudaFree(ptr));
+         // get the graph allocation details
+         GraphAllocInfo& alloc_info = alloc_iter->second;
+
+         uint64_t capture_id = alloc_info.capture_id;
+
+         // check if the capture is still active
+         auto capture_iter = g_captures.find(capture_id);
+         if (capture_iter != g_captures.end())
+         {
+             // Add a mem free node. Use all current leaf nodes as dependencies to ensure that all prior
+             // work completes before deallocating. This works with both Warp-initiated and external captures
+             // and avoids the need to explicitly track all streams used during the capture.
+             CaptureInfo* capture = capture_iter->second;
+             cudaGraph_t graph = get_capture_graph(capture->stream);
+             std::vector<cudaGraphNode_t> leaf_nodes;
+             if (graph && get_graph_leaf_nodes(graph, leaf_nodes))
+             {
+                 cudaGraphNode_t free_node;
+                 check_cuda(cudaGraphAddMemFreeNode(&free_node, graph, leaf_nodes.data(), leaf_nodes.size(), ptr));
+             }
+
+             // we're done with this allocation, it's owned by the graph
+             g_graph_allocs.erase(alloc_iter);
+         }
+         else
+         {
+             // the capture has ended
+             // if the owning graph was already destroyed, we can free the pointer now
+             if (alloc_info.graph_destroyed)
+             {
+                 if (g_captures.empty())
+                 {
+                     // try to free the pointer now
+                     cudaError_t res = cudaFreeAsync(ptr, NULL);
+                     if (res == cudaErrorInvalidValue)
+                     {
+                         // This can happen if we try to release the pointer but the graph was
+                         // never launched, so the memory isn't mapped.
+                         // This is fine, so clear the error.
+                         cudaGetLastError();
+                     }
+                     else
+                     {
+                         // check for other errors
+                         check_cuda(res);
+                     }
+                 }
+                 else
+                 {
+                     // We must defer the operation until graph capture completes.
+                     deferred_free(ptr, context, true);
+                 }
+
+                 // we're done with this allocation
+                 g_graph_allocs.erase(alloc_iter);
+             }
+             else
+             {
+                 // graph still exists
+                 // unreference the pointer so it will be deallocated once the graph instance is destroyed
+                 alloc_info.ref_exists = false;
+             }
+         }
      }
  }

- void memcpy_h2d(void* context, void* dest, void* src, size_t n)
+ bool memcpy_h2d(void* context, void* dest, void* src, size_t n, void* stream)
  {
      ContextGuard guard(context);
-
-     check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyHostToDevice, get_current_stream()));
+
+     CUstream cuda_stream;
+     if (stream != WP_CURRENT_STREAM)
+         cuda_stream = static_cast<CUstream>(stream);
+     else
+         cuda_stream = get_current_stream(context);
+
+     return check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyHostToDevice, cuda_stream));
  }

- void memcpy_d2h(void* context, void* dest, void* src, size_t n)
+ bool memcpy_d2h(void* context, void* dest, void* src, size_t n, void* stream)
  {
      ContextGuard guard(context);

-     check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyDeviceToHost, get_current_stream()));
+     CUstream cuda_stream;
+     if (stream != WP_CURRENT_STREAM)
+         cuda_stream = static_cast<CUstream>(stream);
+     else
+         cuda_stream = get_current_stream(context);
+
+     return check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyDeviceToHost, cuda_stream));
  }

- void memcpy_d2d(void* context, void* dest, void* src, size_t n)
+ bool memcpy_d2d(void* context, void* dest, void* src, size_t n, void* stream)
  {
      ContextGuard guard(context);

-     check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyDeviceToDevice, get_current_stream()));
+     CUstream cuda_stream;
+     if (stream != WP_CURRENT_STREAM)
+         cuda_stream = static_cast<CUstream>(stream);
+     else
+         cuda_stream = get_current_stream(context);
+
+     return check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyDeviceToDevice, cuda_stream));
  }

- void memcpy_peer(void* context, void* dest, void* src, size_t n)
+ bool memcpy_p2p(void* dst_context, void* dst, void* src_context, void* src, size_t n, void* stream)
  {
-     ContextGuard guard(context);
+     // ContextGuard guard(context);
+
+     CUstream cuda_stream;
+     if (stream != WP_CURRENT_STREAM)
+         cuda_stream = static_cast<CUstream>(stream);
+     else
+         cuda_stream = get_current_stream(dst_context);
+
+     // Notes:
+     // - cuMemcpyPeerAsync() works fine with both regular and pooled allocations (cudaMalloc() and cudaMallocAsync(), respectively)
+     //   when not capturing a graph.
+     // - cuMemcpyPeerAsync() is not supported during graph capture, so we must use cudaMemcpyAsync() with kind=cudaMemcpyDefault.
+     // - cudaMemcpyAsync() works fine with regular allocations, but doesn't work with pooled allocations
+     //   unless mempool access has been enabled.
+     // - There is no reliable way to check if mempool access is enabled during graph capture,
+     //   because cudaMemPoolGetAccess() cannot be called during graph capture.
+     // - CUDA will report error 1 (invalid argument) if cudaMemcpyAsync() is called but mempool access is not enabled.
+
+     if (!cuda_stream_is_capturing(stream))
+     {
+         return check_cu(cuMemcpyPeerAsync_f(
+             (CUdeviceptr)dst, (CUcontext)dst_context,
+             (CUdeviceptr)src, (CUcontext)src_context,
+             n, cuda_stream));
+     }
+     else
+     {
+         cudaError_t result = cudaSuccess;
+
+         // cudaMemcpyAsync() is sensitive to the bound context to resolve pointer locations.
+         // If fails with cudaErrorInvalidValue if it cannot resolve an argument.
+         // We first try the copy in the destination context, then if it fails we retry in the source context.
+         // The cudaErrorInvalidValue error doesn't cause graph capture to fail, so it's ok to retry.
+         // Since this trial-and-error shenanigans only happens during capture, there
+         // is no perf impact when the graph is launched.
+         // For bonus points, this approach simplifies memory pool access requirements.
+         // Access only needs to be enabled one way, either from the source device to the destination device
+         // or vice versa. Sometimes, when it's really quiet, you can actually hear my genius.
+         {
+             // try doing the copy in the destination context
+             ContextGuard guard(dst_context);
+             result = cudaMemcpyAsync(dst, src, n, cudaMemcpyDefault, cuda_stream);

-     // NB: assumes devices involved support UVA
-     check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyDefault, get_current_stream()));
+             if (result != cudaSuccess)
+             {
+                 // clear error in destination context
+                 cudaGetLastError();
+
+                 // try doing the copy in the source context
+                 ContextGuard guard(src_context);
+                 result = cudaMemcpyAsync(dst, src, n, cudaMemcpyDefault, cuda_stream);
+
+                 // clear error in source context
+                 cudaGetLastError();
+             }
+         }
+
+         // If the copy failed, try to detect if mempool allocations are involved to generate a helpful error message.
+         if (!check_cuda(result))
+         {
+             if (result == cudaErrorInvalidValue && src != NULL && dst != NULL)
+             {
+                 // check if either of the pointers was allocated from a mempool
+                 void* src_mempool = NULL;
+                 void* dst_mempool = NULL;
+                 cuPointerGetAttribute_f(&src_mempool, CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE, (CUdeviceptr)src);
+                 cuPointerGetAttribute_f(&dst_mempool, CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE, (CUdeviceptr)dst);
+                 cudaGetLastError();  // clear any errors
+                 // check if either of the pointers was allocated during graph capture
+                 auto src_alloc = g_graph_allocs.find(src);
+                 auto dst_alloc = g_graph_allocs.find(dst);
+                 if (src_mempool != NULL || src_alloc != g_graph_allocs.end() ||
+                     dst_mempool != NULL || dst_alloc != g_graph_allocs.end())
+                 {
+                     wp::append_error_string("*** CUDA mempool allocations were used in a peer-to-peer copy during graph capture.");
+                     wp::append_error_string("*** This operation fails if mempool access is not enabled between the peer devices.");
+                     wp::append_error_string("*** Either enable mempool access between the devices or use the default CUDA allocator");
+                     wp::append_error_string("*** to pre-allocate the arrays before graph capture begins.");
+                 }
+             }
+
+             return false;
+         }
+
+         return true;
+     }
  }

+
  __global__ void memset_kernel(int* dest, int value, size_t n)
  {
      const size_t tid = wp::grid_index();
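free_device_async() above relies on a helper get_graph_leaf_nodes() that is not part of this hunk. A minimal sketch of one way such a helper could look (illustrative only; get_graph_leaf_nodes_sketch is a hypothetical name, and the shipped implementation may differ), treating any node with no dependent nodes as a leaf:

// Illustrative sketch: collect the leaf nodes of a CUDA graph.
static bool get_graph_leaf_nodes_sketch(cudaGraph_t graph, std::vector<cudaGraphNode_t>& leaf_nodes)
{
    size_t node_count = 0;
    if (cudaGraphGetNodes(graph, NULL, &node_count) != cudaSuccess)  // query count
        return false;

    std::vector<cudaGraphNode_t> nodes(node_count);
    if (cudaGraphGetNodes(graph, nodes.data(), &node_count) != cudaSuccess)
        return false;

    leaf_nodes.clear();
    for (cudaGraphNode_t node : nodes)
    {
        size_t dependent_count = 0;
        if (cudaGraphNodeGetDependentNodes(node, NULL, &dependent_count) != cudaSuccess)
            return false;
        if (dependent_count == 0)
            leaf_nodes.push_back(node);  // no outgoing edges, so this is a leaf
    }
    return true;
}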
@@ -378,14 +794,15 @@ void memtile_device(void* context, void* dst, const void* src, size_t srcsize, s
      {
          // generic version

+         // copy value to device memory
          // TODO: use a persistent stream-local staging buffer to avoid allocs?
-         void* src_device;
-         check_cuda(cudaMalloc(&src_device, srcsize));
-         check_cuda(cudaMemcpyAsync(src_device, src, srcsize, cudaMemcpyHostToDevice, get_current_stream()));
+         void* src_devptr = alloc_device(WP_CURRENT_CONTEXT, srcsize);
+         check_cuda(cudaMemcpyAsync(src_devptr, src, srcsize, cudaMemcpyHostToDevice, get_current_stream()));
+
+         wp_launch_device(WP_CURRENT_CONTEXT, memtile_kernel, n, (dst, src_devptr, srcsize, n));

-         wp_launch_device(WP_CURRENT_CONTEXT, memtile_kernel, n, (dst, src_device, srcsize, n));
+         free_device(WP_CURRENT_CONTEXT, src_devptr);

-         check_cuda(cudaFree(src_device));
      }
  }

@@ -611,15 +1028,13 @@ static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::in
  }


- WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_type, int src_type, int elem_size)
+ WP_API bool array_copy_device(void* context, void* dst, void* src, int dst_type, int src_type, int elem_size)
  {
      if (!src || !dst)
-         return 0;
+         return false;

      const void* src_data = NULL;
-     const void* src_grad = NULL;
      void* dst_data = NULL;
-     void* dst_grad = NULL;
      int src_ndim = 0;
      int dst_ndim = 0;
      const int* src_shape = NULL;
@@ -641,7 +1056,6 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
      {
          const wp::array_t<void>& src_arr = *static_cast<const wp::array_t<void>*>(src);
          src_data = src_arr.data;
-         src_grad = src_arr.grad;
          src_ndim = src_arr.ndim;
          src_shape = src_arr.shape.dims;
          src_strides = src_arr.strides;
@@ -669,14 +1083,13 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
      else
      {
          fprintf(stderr, "Warp copy error: Invalid array type (%d)\n", src_type);
-         return 0;
+         return false;
      }

      if (dst_type == wp::ARRAY_TYPE_REGULAR)
      {
          const wp::array_t<void>& dst_arr = *static_cast<const wp::array_t<void>*>(dst);
          dst_data = dst_arr.data;
-         dst_grad = dst_arr.grad;
          dst_ndim = dst_arr.ndim;
          dst_shape = dst_arr.shape.dims;
          dst_strides = dst_arr.strides;
@@ -704,13 +1117,13 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
      else
      {
          fprintf(stderr, "Warp copy error: Invalid array type (%d)\n", dst_type);
-         return 0;
+         return false;
      }

      if (src_ndim != dst_ndim)
      {
          fprintf(stderr, "Warp copy error: Incompatible array dimensionalities (%d and %d)\n", src_ndim, dst_ndim);
-         return 0;
+         return false;
      }

      ContextGuard guard(context);
@@ -725,11 +1138,11 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
          if (src_fabricarray->size != n)
          {
              fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
-             return 0;
+             return false;
          }
          wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_to_fabric_kernel, n,
                           (*dst_fabricarray, *src_fabricarray, elem_size));
-         return n;
+         return true;
      }
      else if (src_indexedfabricarray)
      {
@@ -737,11 +1150,11 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
          if (src_indexedfabricarray->size != n)
          {
              fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
-             return 0;
+             return false;
          }
          wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_indexed_to_fabric_kernel, n,
                           (*dst_fabricarray, *src_indexedfabricarray, elem_size));
-         return n;
+         return true;
      }
      else
      {
@@ -749,11 +1162,11 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
          if (size_t(src_shape[0]) != n)
          {
              fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
-             return 0;
+             return false;
          }
          wp_launch_device(WP_CURRENT_CONTEXT, array_copy_to_fabric_kernel, n,
                           (*dst_fabricarray, src_data, src_strides[0], src_indices[0], elem_size));
-         return n;
+         return true;
      }
  }
  if (dst_indexedfabricarray)
@@ -765,11 +1178,11 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
          if (src_fabricarray->size != n)
          {
              fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
-             return 0;
+             return false;
          }
          wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_to_fabric_indexed_kernel, n,
                           (*dst_indexedfabricarray, *src_fabricarray, elem_size));
-         return n;
+         return true;
      }
      else if (src_indexedfabricarray)
      {
@@ -777,11 +1190,11 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
          if (src_indexedfabricarray->size != n)
          {
              fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
-             return 0;
+             return false;
          }
          wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_indexed_to_fabric_indexed_kernel, n,
                           (*dst_indexedfabricarray, *src_indexedfabricarray, elem_size));
-         return n;
+         return true;
      }
      else
      {
@@ -789,11 +1202,11 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
          if (size_t(src_shape[0]) != n)
          {
              fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
-             return 0;
+             return false;
          }
          wp_launch_device(WP_CURRENT_CONTEXT, array_copy_to_fabric_indexed_kernel, n,
                           (*dst_indexedfabricarray, src_data, src_strides[0], src_indices[0], elem_size));
-         return n;
+         return true;
      }
  }
  else if (src_fabricarray)
@@ -803,11 +1216,11 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
          if (size_t(dst_shape[0]) != n)
          {
              fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
-             return 0;
+             return false;
          }
          wp_launch_device(WP_CURRENT_CONTEXT, array_copy_from_fabric_kernel, n,
                           (*src_fabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size));
-         return n;
+         return true;
      }
      else if (src_indexedfabricarray)
      {
@@ -816,11 +1229,11 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
          if (size_t(dst_shape[0]) != n)
          {
              fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
-             return 0;
+             return false;
          }
          wp_launch_device(WP_CURRENT_CONTEXT, array_copy_from_fabric_indexed_kernel, n,
                           (*src_indexedfabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size));
-         return n;
+         return true;
      }

      size_t n = 1;
@@ -829,7 +1242,7 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
          if (src_shape[i] != dst_shape[i])
          {
              fprintf(stderr, "Warp copy error: Incompatible array shapes\n");
-             return 0;
+             return false;
          }
          n *= src_shape[i];
      }
@@ -888,13 +1301,10 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
      }
      default:
          fprintf(stderr, "Warp copy error: invalid array dimensionality (%d)\n", src_ndim);
-         return 0;
+         return false;
      }

-     if (check_cuda(cudaGetLastError()))
-         return n;
-     else
-         return 0;
+     return check_cuda(cudaGetLastError());
  }


@@ -1065,8 +1475,8 @@ WP_API void array_fill_device(void* context, void* arr_ptr, int arr_type, const
      ContextGuard guard(context);

      // copy value to device memory
-     void* value_devptr;
-     check_cuda(cudaMalloc(&value_devptr, value_size));
+     // TODO: use a persistent stream-local staging buffer to avoid allocs?
+     void* value_devptr = alloc_device(WP_CURRENT_CONTEXT, value_size);
      check_cuda(cudaMemcpyAsync(value_devptr, value_ptr, value_size, cudaMemcpyHostToDevice, get_current_stream()));

      // handle fabric arrays
@@ -1123,6 +1533,8 @@ WP_API void array_fill_device(void* context, void* arr_ptr, int arr_type, const
          fprintf(stderr, "Warp fill error: invalid array dimensionality (%d)\n", ndim);
          return;
      }
+
+     free_device(WP_CURRENT_CONTEXT, value_devptr);
  }

  void array_scan_int_device(uint64_t in, uint64_t out, int len, bool inclusive)
@@ -1178,20 +1590,20 @@ int cuda_device_get_count()
      return count;
  }

- void* cuda_device_primary_context_retain(int ordinal)
+ void* cuda_device_get_primary_context(int ordinal)
  {
-     CUcontext context = NULL;
-     CUdevice device;
-     if (check_cu(cuDeviceGet_f(&device, ordinal)))
-         check_cu(cuDevicePrimaryCtxRetain_f(&context, device));
-     return context;
- }
+     if (ordinal >= 0 && ordinal < int(g_devices.size()))
+     {
+         DeviceInfo& device_info = g_devices[ordinal];

- void cuda_device_primary_context_release(int ordinal)
- {
-     CUdevice device;
-     if (check_cu(cuDeviceGet_f(&device, ordinal)))
-         check_cu(cuDevicePrimaryCtxRelease_f(device));
+         // acquire the primary context if we haven't already
+         if (!device_info.primary_context)
+             check_cu(cuDevicePrimaryCtxRetain_f(&device_info.primary_context, device_info.device));
+
+         return device_info.primary_context;
+     }
+
+     return NULL;
  }

  const char* cuda_device_get_name(int ordinal)
@@ -1241,13 +1653,105 @@ int cuda_device_is_uva(int ordinal)
      return 0;
  }

- int cuda_device_is_memory_pool_supported(int ordinal)
+ int cuda_device_is_mempool_supported(int ordinal)
  {
      if (ordinal >= 0 && ordinal < int(g_devices.size()))
-         return g_devices[ordinal].is_memory_pool_supported;
-     return false;
+         return g_devices[ordinal].is_mempool_supported;
+     return 0;
+ }
+
+ int cuda_device_set_mempool_release_threshold(int ordinal, uint64_t threshold)
+ {
+     if (ordinal < 0 || ordinal > int(g_devices.size()))
+     {
+         fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
+         return 0;
+     }
+
+     if (!g_devices[ordinal].is_mempool_supported)
+         return 0;
+
+     cudaMemPool_t pool;
+     if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
+     {
+         fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
+         return 0;
+     }
+
+     if (!check_cuda(cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold)))
+     {
+         fprintf(stderr, "Warp error: Failed to set memory pool attribute on device %d\n", ordinal);
+         return 0;
+     }
+
+     return 1;  // success
+ }
+
+ uint64_t cuda_device_get_mempool_release_threshold(int ordinal)
+ {
+     if (ordinal < 0 || ordinal > int(g_devices.size()))
+     {
+         fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
+         return 0;
+     }
+
+     if (!g_devices[ordinal].is_mempool_supported)
+         return 0;
+
+     cudaMemPool_t pool;
+     if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
+     {
+         fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
+         return 0;
+     }
+
+     uint64_t threshold = 0;
+     if (!check_cuda(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold)))
+     {
+         fprintf(stderr, "Warp error: Failed to get memory pool release threshold on device %d\n", ordinal);
+         return 0;
+     }
+
+     return threshold;
+ }
+
+ void cuda_device_get_memory_info(int ordinal, size_t* free_mem, size_t* total_mem)
+ {
+     // use temporary storage if user didn't specify pointers
+     size_t tmp_free_mem, tmp_total_mem;
+
+     if (free_mem)
+         *free_mem = 0;
+     else
+         free_mem = &tmp_free_mem;
+
+     if (total_mem)
+         *total_mem = 0;
+     else
+         total_mem = &tmp_total_mem;
+
+     if (ordinal >= 0 && ordinal < int(g_devices.size()))
+     {
+         if (g_devices[ordinal].primary_context)
+         {
+             ContextGuard guard(g_devices[ordinal].primary_context, true);
+             check_cu(cuMemGetInfo_f(free_mem, total_mem));
+         }
+         else
+         {
+             // if we haven't acquired the primary context yet, acquire it temporarily
+             CUcontext primary_context = NULL;
+             check_cu(cuDevicePrimaryCtxRetain_f(&primary_context, g_devices[ordinal].device));
+             {
+                 ContextGuard guard(primary_context, true);
+                 check_cu(cuMemGetInfo_f(free_mem, total_mem));
+             }
+             check_cu(cuDevicePrimaryCtxRelease_f(g_devices[ordinal].device));
+         }
+     }
  }

+
  void* cuda_context_get_current()
  {
      return get_current_context();
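cuda_device_set_mempool_release_threshold() above is a thin wrapper over the standard cudaMemPoolAttrReleaseThreshold attribute. For context, a hedged sketch of the common "keep freed blocks cached" configuration, expressed directly with the same CUDA runtime calls (keep_mempool_cached_sketch is a hypothetical name, not part of the package; assumes <cstdint> for UINT64_MAX):

// Illustrative sketch: raise the release threshold so the default pool
// holds on to freed blocks instead of trimming them back to the OS at
// stream synchronization points.
static void keep_mempool_cached_sketch(int ordinal)
{
    cudaMemPool_t pool;
    if (cudaDeviceGetDefaultMemPool(&pool, ordinal) != cudaSuccess)
        return;

    uint64_t threshold = UINT64_MAX;  // effectively "never release"
    cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold);
}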
@@ -1313,26 +1817,35 @@ void cuda_context_synchronize(void* context)
      ContextGuard guard(context);

      check_cu(cuCtxSynchronize_f());
+
+     if (free_deferred_allocs(context ? context : get_current_context()) > 0)
+     {
+         // ensure deferred asynchronous deallocations complete
+         check_cu(cuCtxSynchronize_f());
+     }
+
+     // check_cuda(cudaDeviceGraphMemTrim(cuda_context_get_device_ordinal(context)));
  }

  uint64_t cuda_context_check(void* context)
  {
      ContextGuard guard(context);

-     cudaStreamCaptureStatus status;
-     cudaStreamIsCapturing(get_current_stream(), &status);
+     // check errors before syncing
+     cudaError_t e = cudaGetLastError();
+     check_cuda(e);
+
+     cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
+     check_cuda(cudaStreamIsCapturing(get_current_stream(), &status));

-     // do not check during cuda stream capture
-     // since we cannot synchronize the device
+     // synchronize if the stream is not capturing
      if (status == cudaStreamCaptureStatusNone)
      {
-         cudaDeviceSynchronize();
-         return cudaPeekAtLastError();
-     }
-     else
-     {
-         return 0;
+         check_cuda(cudaDeviceSynchronize());
+         e = cudaGetLastError();
      }
+
+     return static_cast<uint64_t>(e);
  }


@@ -1344,25 +1857,28 @@ int cuda_context_get_device_ordinal(void* context)

  int cuda_context_is_primary(void* context)
  {
-     int ordinal = cuda_context_get_device_ordinal(context);
-     if (ordinal != -1)
+     CUcontext ctx = static_cast<CUcontext>(context);
+     ContextInfo* context_info = get_context_info(ctx);
+     if (!context_info)
      {
-         // there is no CUDA API to check if a context is primary, but we can temporarily
-         // acquire the device's primary context to check the pointer
-         void* device_primary_context = cuda_device_primary_context_retain(ordinal);
-         cuda_device_primary_context_release(ordinal);
-         return int(context == device_primary_context);
+         fprintf(stderr, "Warp error: Failed to get context info\n");
+         return 0;
      }
-     return 0;
- }

- int cuda_context_is_memory_pool_supported(void* context)
- {
-     int ordinal = cuda_context_get_device_ordinal(context);
-     if (ordinal != -1)
+     // if the device primary context is known, check if it matches the given context
+     DeviceInfo* device_info = context_info->device_info;
+     if (device_info->primary_context)
+         return int(ctx == device_info->primary_context);
+
+     // there is no CUDA API to check if a context is primary, but we can temporarily
+     // acquire the device's primary context to check the pointer
+     CUcontext primary_ctx;
+     if (check_cu(cuDevicePrimaryCtxRetain_f(&primary_ctx, device_info->device)))
      {
-         return cuda_device_is_memory_pool_supported(ordinal);
+         check_cu(cuDevicePrimaryCtxRelease_f(device_info->device));
+         return int(ctx == primary_ctx);
      }
+
      return 0;
  }

@@ -1376,115 +1892,251 @@ void* cuda_context_get_stream(void* context)
1376
1892
  return NULL;
1377
1893
  }
1378
1894
 
1379
- void cuda_context_set_stream(void* context, void* stream)
1895
+ void cuda_context_set_stream(void* context, void* stream, int sync)
1380
1896
  {
1381
- ContextInfo* info = get_context_info(static_cast<CUcontext>(context));
1382
- if (info)
1897
+ ContextInfo* context_info = get_context_info(static_cast<CUcontext>(context));
1898
+ if (context_info)
1383
1899
  {
1384
- info->stream = static_cast<CUstream>(stream);
1900
+ CUstream new_stream = static_cast<CUstream>(stream);
1901
+
1902
+ // check whether we should sync with the previous stream on this device
1903
+ if (sync)
1904
+ {
1905
+ CUstream old_stream = context_info->stream;
1906
+ StreamInfo* old_stream_info = get_stream_info(old_stream);
1907
+ if (old_stream_info)
1908
+ {
1909
+ CUevent cached_event = old_stream_info->cached_event;
1910
+ check_cu(cuEventRecord_f(cached_event, old_stream));
1911
+ check_cu(cuStreamWaitEvent_f(new_stream, cached_event, CU_EVENT_WAIT_DEFAULT));
1912
+ }
1913
+ }
1914
+
1915
+ context_info->stream = new_stream;
1385
1916
  }
1386
1917
  }
1387
1918
 
1388
- int cuda_context_enable_peer_access(void* context, void* peer_context)
1919
+
1920
+ int cuda_is_peer_access_supported(int target_ordinal, int peer_ordinal)
1389
1921
  {
1390
- if (!context || !peer_context)
1922
+ int num_devices = int(g_devices.size());
1923
+
1924
+ if (target_ordinal < 0 || target_ordinal > num_devices)
1391
1925
  {
1392
- fprintf(stderr, "Warp error: Failed to enable peer access: invalid argument\n");
1926
+ fprintf(stderr, "Warp error: Invalid target device ordinal %d\n", target_ordinal);
1393
1927
  return 0;
1394
1928
  }
1395
1929
 
1396
- if (context == peer_context)
1397
- return 1; // ok
1930
+ if (peer_ordinal < 0 || peer_ordinal > num_devices)
1931
+ {
1932
+ fprintf(stderr, "Warp error: Invalid peer device ordinal %d\n", peer_ordinal);
1933
+ return 0;
1934
+ }
1398
1935
 
1399
- CUcontext ctx = static_cast<CUcontext>(context);
1400
- CUcontext peer_ctx = static_cast<CUcontext>(peer_context);
1936
+ if (target_ordinal == peer_ordinal)
1937
+ return 1;
1938
+
1939
+ int can_access = 0;
1940
+ check_cuda(cudaDeviceCanAccessPeer(&can_access, peer_ordinal, target_ordinal));
1941
+
1942
+ return can_access;
1943
+ }
1401
1944
 
1402
- ContextInfo* info = get_context_info(ctx);
1403
- ContextInfo* peer_info = get_context_info(peer_ctx);
1404
- if (!info || !peer_info)
1945
+ int cuda_is_peer_access_enabled(void* target_context, void* peer_context)
1946
+ {
1947
+ if (!target_context || !peer_context)
1405
1948
  {
1406
- fprintf(stderr, "Warp error: Failed to enable peer access: failed to get context info\n");
1949
+ fprintf(stderr, "Warp error: invalid CUDA context\n");
1407
1950
  return 0;
1408
1951
  }
1409
1952
 
1410
- // check if same device
1411
- if (info->device_info == peer_info->device_info)
1953
+ if (target_context == peer_context)
1954
+ return 1;
1955
+
1956
+ int target_ordinal = cuda_context_get_device_ordinal(target_context);
1957
+ int peer_ordinal = cuda_context_get_device_ordinal(peer_context);
1958
+
1959
+ // check if peer access is supported
1960
+ int can_access = 0;
1961
+ check_cuda(cudaDeviceCanAccessPeer(&can_access, peer_ordinal, target_ordinal));
1962
+ if (!can_access)
1963
+ return 0;
1964
+
1965
+ // There is no CUDA API to query if peer access is enabled, but we can try to enable it and check the result.
1966
+
1967
+ ContextGuard guard(peer_context, true);
1968
+
1969
+ CUcontext target_ctx = static_cast<CUcontext>(target_context);
1970
+
1971
+ CUresult result = cuCtxEnablePeerAccess_f(target_ctx, 0);
1972
+ if (result == CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED)
1412
1973
  {
1413
- if (info->device_info->is_uva)
1974
+ return 1;
1975
+ }
1976
+ else if (result == CUDA_SUCCESS)
1977
+ {
1978
+ // undo enablement
1979
+ check_cu(cuCtxDisablePeerAccess_f(target_ctx));
1980
+ return 0;
1981
+ }
1982
+ else
1983
+ {
1984
+ // report error
1985
+ check_cu(result);
1986
+ return 0;
1987
+ }
1988
+ }
1989
+
1990
+ int cuda_set_peer_access_enabled(void* target_context, void* peer_context, int enable)
1991
+ {
1992
+ if (!target_context || !peer_context)
1993
+ {
1994
+ fprintf(stderr, "Warp error: invalid CUDA context\n");
1995
+ return 0;
1996
+ }
1997
+
1998
+ if (target_context == peer_context)
1999
+ return 1; // no-op
2000
+
2001
+ int target_ordinal = cuda_context_get_device_ordinal(target_context);
2002
+ int peer_ordinal = cuda_context_get_device_ordinal(peer_context);
2003
+
2004
+ // check if peer access is supported
2005
+ int can_access = 0;
2006
+ check_cuda(cudaDeviceCanAccessPeer(&can_access, peer_ordinal, target_ordinal));
2007
+ if (!can_access)
2008
+ {
2009
+ // failure if enabling, success if disabling
2010
+ if (enable)
1414
2011
  {
1415
- return 1; // ok
2012
+ fprintf(stderr, "Warp error: device %d cannot access device %d\n", peer_ordinal, target_ordinal);
2013
+ return 0;
1416
2014
  }
1417
2015
  else
2016
+ return 1;
2017
+ }
2018
+
2019
+ ContextGuard guard(peer_context, true);
2020
+
2021
+ CUcontext target_ctx = static_cast<CUcontext>(target_context);
2022
+
2023
+ if (enable)
2024
+ {
2025
+ CUresult status = cuCtxEnablePeerAccess_f(target_ctx, 0);
2026
+ if (status != CUDA_SUCCESS && status != CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED)
1418
2027
  {
1419
- fprintf(stderr, "Warp error: Failed to enable peer access: device doesn't support UVA\n");
2028
+ check_cu(status);
2029
+ fprintf(stderr, "Warp error: failed to enable peer access from device %d to device %d\n", peer_ordinal, target_ordinal);
1420
2030
  return 0;
1421
2031
  }
1422
2032
  }
1423
2033
  else
1424
2034
  {
1425
- // different devices, try to enable
1426
- ContextGuard guard(ctx, true);
1427
- CUresult result = cuCtxEnablePeerAccess_f(peer_ctx, 0);
1428
- if (result == CUDA_SUCCESS || result == CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED)
1429
- {
1430
- return 1; // ok
1431
- }
1432
- else
2035
+ CUresult status = cuCtxDisablePeerAccess_f(target_ctx);
2036
+ if (status != CUDA_SUCCESS && status != CUDA_ERROR_PEER_ACCESS_NOT_ENABLED)
1433
2037
  {
1434
- check_cu(result);
2038
+ check_cu(status);
2039
+ fprintf(stderr, "Warp error: failed to disable peer access from device %d to device %d\n", peer_ordinal, target_ordinal);
1435
2040
  return 0;
1436
2041
  }
1437
2042
  }
2043
+
2044
+ return 1; // success
1438
2045
  }
 
-int cuda_context_can_access_peer(void* context, void* peer_context)
+int cuda_is_mempool_access_enabled(int target_ordinal, int peer_ordinal)
 {
-    if (!context || !peer_context)
+    int num_devices = int(g_devices.size());
+
+    if (target_ordinal < 0 || target_ordinal > num_devices)
+    {
+        fprintf(stderr, "Warp error: Invalid device ordinal %d\n", target_ordinal);
+        return 0;
+    }
+
+    if (peer_ordinal < 0 || peer_ordinal > num_devices)
+    {
+        fprintf(stderr, "Warp error: Invalid peer device ordinal %d\n", peer_ordinal);
         return 0;
+    }
 
-    if (context == peer_context)
+    if (target_ordinal == peer_ordinal)
         return 1;
 
-    CUcontext ctx = static_cast<CUcontext>(context);
-    CUcontext peer_ctx = static_cast<CUcontext>(peer_context);
-
-    ContextInfo* info = get_context_info(ctx);
-    ContextInfo* peer_info = get_context_info(peer_ctx);
-    if (!info || !peer_info)
+    cudaMemPool_t pool;
+    if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, target_ordinal)))
+    {
+        fprintf(stderr, "Warp error: Failed to get memory pool of device %d\n", target_ordinal);
         return 0;
+    }
+
+    cudaMemAccessFlags flags = cudaMemAccessFlagsProtNone;
+    cudaMemLocation location;
+    location.id = peer_ordinal;
+    location.type = cudaMemLocationTypeDevice;
+    if (check_cuda(cudaMemPoolGetAccess(&flags, pool, &location)))
+        return int(flags != cudaMemAccessFlagsProtNone);
+
+    return 0;
+}
+
+int cuda_set_mempool_access_enabled(int target_ordinal, int peer_ordinal, int enable)
+{
+    int num_devices = int(g_devices.size());
 
-    // check if same device
-    if (info->device_info == peer_info->device_info)
+    if (target_ordinal < 0 || target_ordinal > num_devices)
     {
-        if (info->device_info->is_uva)
-            return 1;
-        else
-            return 0;
+        fprintf(stderr, "Warp error: Invalid device ordinal %d\n", target_ordinal);
+        return 0;
+    }
+
+    if (peer_ordinal < 0 || peer_ordinal > num_devices)
+    {
+        fprintf(stderr, "Warp error: Invalid peer device ordinal %d\n", peer_ordinal);
+        return 0;
     }
+
+    if (target_ordinal == peer_ordinal)
+        return 1; // no-op
+
+    // get the memory pool
+    cudaMemPool_t pool;
+    if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, target_ordinal)))
+    {
+        fprintf(stderr, "Warp error: Failed to get memory pool of device %d\n", target_ordinal);
+        return 0;
+    }
+
+    cudaMemAccessDesc desc;
+    desc.location.type = cudaMemLocationTypeDevice;
+    desc.location.id = peer_ordinal;
+
+    // only cudaMemAccessFlagsProtReadWrite and cudaMemAccessFlagsProtNone are supported
+    if (enable)
+        desc.flags = cudaMemAccessFlagsProtReadWrite;
     else
+        desc.flags = cudaMemAccessFlagsProtNone;
+
+    if (!check_cuda(cudaMemPoolSetAccess(pool, &desc, 1)))
     {
-        // different devices, try to enable
-        // TODO: is there a better way to check?
-        ContextGuard guard(ctx, true);
-        CUresult result = cuCtxEnablePeerAccess_f(peer_ctx, 0);
-        if (result == CUDA_SUCCESS || result == CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED)
-            return 1;
-        else
-            return 0;
+        fprintf(stderr, "Warp error: Failed to set mempool access from device %d to device %d\n", peer_ordinal, target_ordinal);
+        return 0;
     }
+
+    return 1; // success
 }
 
+
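The two new mempool helpers expose the stream-ordered allocator's access controls per device pair. As a rough standalone illustration (not Warp's code path), granting one device read/write access to another device's default memory pool, and then verifying it, looks like this; the ordinals are placeholders and error checking is omitted:

```cpp
#include <cuda_runtime.h>

// Sketch: let device `peer` read/write stream-ordered allocations made from
// device `target`'s default memory pool (requires CUDA 11.2+ and P2P support).
void grant_mempool_access(int target, int peer)
{
    cudaMemPool_t pool;
    cudaDeviceGetDefaultMemPool(&pool, target);

    cudaMemAccessDesc desc = {};
    desc.location.type = cudaMemLocationTypeDevice;
    desc.location.id = peer;
    desc.flags = cudaMemAccessFlagsProtReadWrite;
    cudaMemPoolSetAccess(pool, &desc, 1);

    // verify: query the access flags the pool now reports for `peer`
    cudaMemAccessFlags flags = cudaMemAccessFlagsProtNone;
    cudaMemLocation loc = {};
    loc.type = cudaMemLocationTypeDevice;
    loc.id = peer;
    cudaMemPoolGetAccess(&flags, pool, &loc);
    // flags == cudaMemAccessFlagsProtReadWrite when access is enabled
}
```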
 void* cuda_stream_create(void* context)
 {
-    CUcontext ctx = context ? static_cast<CUcontext>(context) : get_current_context();
-    if (!ctx)
-        return NULL;
-
     ContextGuard guard(context, true);
 
     CUstream stream;
     if (check_cu(cuStreamCreate_f(&stream, CU_STREAM_DEFAULT)))
+    {
+        cuda_stream_register(WP_CURRENT_CONTEXT, stream);
         return stream;
+    }
     else
         return NULL;
 }
@@ -1494,20 +2146,45 @@ void cuda_stream_destroy(void* context, void* stream)
     if (!stream)
         return;
 
-    CUcontext ctx = context ? static_cast<CUcontext>(context) : get_current_context();
-    if (!ctx)
-        return;
-
-    ContextGuard guard(context, true);
+    cuda_stream_unregister(context, stream);
 
     check_cu(cuStreamDestroy_f(static_cast<CUstream>(stream)));
 }
 
-void cuda_stream_synchronize(void* context, void* stream)
+void cuda_stream_register(void* context, void* stream)
 {
+    if (!stream)
+        return;
+
     ContextGuard guard(context);
 
-    check_cu(cuStreamSynchronize_f(static_cast<CUstream>(stream)));
+    // populate stream info
+    StreamInfo& stream_info = g_streams[static_cast<CUstream>(stream)];
+    check_cu(cuEventCreate_f(&stream_info.cached_event, CU_EVENT_DISABLE_TIMING));
+}
+
+void cuda_stream_unregister(void* context, void* stream)
+{
+    if (!stream)
+        return;
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+
+    StreamInfo* stream_info = get_stream_info(cuda_stream);
+    if (stream_info)
+    {
+        // release stream info
+        check_cu(cuEventDestroy_f(stream_info->cached_event));
+        g_streams.erase(cuda_stream);
+    }
+
+    // make sure we don't leave dangling references to this stream
+    ContextInfo* context_info = get_context_info(context);
+    if (context_info)
+    {
+        if (cuda_stream == context_info->stream)
+            context_info->stream = NULL;
+    }
 }
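cuda_stream_register/cuda_stream_unregister maintain a global table of known streams, and each entry caches a timing-disabled event so stream-to-stream waits can reuse it instead of creating events on the fly. A simplified sketch of that bookkeeping pattern, written against the runtime API and with the per-stream state reduced to the cached event for illustration:

```cpp
#include <cuda_runtime.h>
#include <unordered_map>

// Simplified sketch of the stream registry above: one cached, timing-disabled
// event per known stream, created on register and destroyed on unregister.
struct StreamInfo
{
    cudaEvent_t cached_event = nullptr;
};

static std::unordered_map<cudaStream_t, StreamInfo> g_streams;

void register_stream(cudaStream_t stream)
{
    StreamInfo& info = g_streams[stream];
    cudaEventCreateWithFlags(&info.cached_event, cudaEventDisableTiming);
}

void unregister_stream(cudaStream_t stream)
{
    auto it = g_streams.find(stream);
    if (it != g_streams.end())
    {
        cudaEventDestroy(it->second.cached_event);
        g_streams.erase(it);
    }
}
```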
 
 void* cuda_stream_get_current()
@@ -1515,24 +2192,33 @@ void* cuda_stream_get_current()
     return get_current_stream();
 }
 
-void cuda_stream_wait_event(void* context, void* stream, void* event)
+void cuda_stream_synchronize(void* stream)
 {
-    ContextGuard guard(context);
+    check_cu(cuStreamSynchronize_f(static_cast<CUstream>(stream)));
+}
 
+void cuda_stream_wait_event(void* stream, void* event)
+{
     check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), 0));
 }
 
-void cuda_stream_wait_stream(void* context, void* stream, void* other_stream, void* event)
+void cuda_stream_wait_stream(void* stream, void* other_stream, void* event)
 {
-    ContextGuard guard(context);
-
     check_cu(cuEventRecord_f(static_cast<CUevent>(event), static_cast<CUstream>(other_stream)));
     check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), 0));
 }
 
+int cuda_stream_is_capturing(void* stream)
+{
+    cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
+    check_cuda(cudaStreamIsCapturing(static_cast<cudaStream_t>(stream), &status));
+
+    return int(status != cudaStreamCaptureStatusNone);
+}
+
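cuda_stream_is_capturing is a thin wrapper over cudaStreamIsCapturing; the same query is useful whenever an operation must behave differently while a stream is recording into a graph. A minimal sketch:

```cpp
#include <cuda_runtime.h>

// Sketch: return true if `stream` is currently recording into a CUDA graph.
bool stream_is_capturing(cudaStream_t stream)
{
    cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
    if (cudaStreamIsCapturing(stream, &status) != cudaSuccess)
        return false;
    return status != cudaStreamCaptureStatusNone;
}
```

Note that cudaStreamCaptureStatusInvalidated also compares as "not none" here, which matches the wrapper above.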
 void* cuda_event_create(void* context, unsigned flags)
 {
-    ContextGuard guard(context);
+    ContextGuard guard(context, true);
 
     CUevent event;
     if (check_cu(cuEventCreate_f(&event, flags)))
@@ -1541,68 +2227,217 @@ void* cuda_event_create(void* context, unsigned flags)
         return NULL;
 }
 
-void cuda_event_destroy(void* context, void* event)
+void cuda_event_destroy(void* event)
 {
-    ContextGuard guard(context, true);
-
     check_cu(cuEventDestroy_f(static_cast<CUevent>(event)));
 }
 
-void cuda_event_record(void* context, void* event, void* stream)
+void cuda_event_record(void* event, void* stream)
 {
-    ContextGuard guard(context);
-
     check_cu(cuEventRecord_f(static_cast<CUevent>(event), static_cast<CUstream>(stream)));
 }
 
-void cuda_graph_begin_capture(void* context)
+bool cuda_graph_begin_capture(void* context, void* stream, int external)
 {
     ContextGuard guard(context);
 
-    check_cuda(cudaStreamBeginCapture(get_current_stream(), cudaStreamCaptureModeGlobal));
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+    StreamInfo* stream_info = get_stream_info(cuda_stream);
+    if (!stream_info)
+    {
+        wp::set_error_string("Warp error: unknown stream");
+        return false;
+    }
+
+    if (external)
+    {
+        // if it's an external capture, make sure it's already active so we can get the capture id
+        cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
+        if (!check_cuda(cudaStreamIsCapturing(cuda_stream, &status)))
+            return false;
+        if (status != cudaStreamCaptureStatusActive)
+        {
+            wp::set_error_string("Warp error: stream is not capturing");
+            return false;
+        }
+    }
+    else
+    {
+        // start the capture
+        if (!check_cuda(cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeGlobal)))
+            return false;
+    }
+
+    uint64_t capture_id = get_capture_id(cuda_stream);
+
+    CaptureInfo* capture = new CaptureInfo();
+    capture->stream = cuda_stream;
+    capture->id = capture_id;
+    capture->external = bool(external);
+
+    // update stream info
+    stream_info->capture = capture;
+
+    // add to known captures
+    g_captures[capture_id] = capture;
+
+    return true;
 }
 
-void* cuda_graph_end_capture(void* context)
+bool cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
 {
     ContextGuard guard(context);
 
-    cudaGraph_t graph = NULL;
-    check_cuda(cudaStreamEndCapture(get_current_stream(), &graph));
+    // check if this is a known stream
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+    StreamInfo* stream_info = get_stream_info(cuda_stream);
+    if (!stream_info)
+    {
+        wp::set_error_string("Warp error: unknown capture stream");
+        return false;
+    }
 
-    if (graph)
+    // check if this stream was used to start a capture
+    CaptureInfo* capture = stream_info->capture;
+    if (!capture)
     {
-        // enable to create debug GraphVis visualization of graph
-        //cudaGraphDebugDotPrint(graph, "graph.dot", cudaGraphDebugDotFlagsVerbose);
+        wp::set_error_string("Warp error: stream has no capture started");
+        return false;
+    }
 
-        cudaGraphExec_t graph_exec = NULL;
-        //check_cuda(cudaGraphInstantiate(&graph_exec, graph, NULL, NULL, 0));
-
-        // can use after CUDA 11.4 to permit graphs to capture cudaMallocAsync() operations
-        check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch));
+    // get capture info
+    bool external = capture->external;
+    uint64_t capture_id = capture->id;
 
-        // free source graph
-        check_cuda(cudaGraphDestroy(graph));
+    // clear capture info
+    stream_info->capture = NULL;
+    g_captures.erase(capture_id);
+    delete capture;
+
+    // a lambda to clean up on exit in case of error
+    auto clean_up = [cuda_stream, capture_id, external]()
+    {
+        // unreference outstanding graph allocs so that they will be released with the user reference
+        for (auto it = g_graph_allocs.begin(); it != g_graph_allocs.end(); ++it)
+        {
+            GraphAllocInfo& alloc_info = it->second;
+            if (alloc_info.capture_id == capture_id)
+                alloc_info.graph_destroyed = true;
+        }
+
+        // make sure we terminate the capture
+        if (!external)
+        {
+            cudaGraph_t graph = NULL;
+            cudaStreamEndCapture(cuda_stream, &graph);
+            cudaGetLastError();
+        }
+    };
 
-        return graph_exec;
+    // get captured graph without ending the capture in case it is external
+    cudaGraph_t graph = get_capture_graph(cuda_stream);
+    if (!graph)
+    {
+        clean_up();
+        return false;
     }
-    else
+
+    // ensure that all forked streams are joined to the main capture stream by manually
+    // adding outstanding capture dependencies gathered from the graph leaf nodes
+    std::vector<cudaGraphNode_t> stream_dependencies;
+    std::vector<cudaGraphNode_t> leaf_nodes;
+    if (get_capture_dependencies(cuda_stream, stream_dependencies) && get_graph_leaf_nodes(graph, leaf_nodes))
+    {
+        // compute set difference to get unjoined dependencies
+        std::vector<cudaGraphNode_t> unjoined_dependencies;
+        std::sort(stream_dependencies.begin(), stream_dependencies.end());
+        std::sort(leaf_nodes.begin(), leaf_nodes.end());
+        std::set_difference(leaf_nodes.begin(), leaf_nodes.end(),
+                            stream_dependencies.begin(), stream_dependencies.end(),
+                            std::back_inserter(unjoined_dependencies));
+        if (!unjoined_dependencies.empty())
+        {
+            check_cu(cuStreamUpdateCaptureDependencies_f(cuda_stream, unjoined_dependencies.data(), unjoined_dependencies.size(),
+                                                         CU_STREAM_ADD_CAPTURE_DEPENDENCIES));
+            // ensure graph is still valid
+            if (get_capture_graph(cuda_stream) != graph)
+            {
+                clean_up();
+                return false;
+            }
+        }
+    }
+
+    // check if this graph has unfreed allocations, which require special handling
+    std::vector<void*> unfreed_allocs;
+    for (auto it = g_graph_allocs.begin(); it != g_graph_allocs.end(); ++it)
     {
-        return NULL;
+        GraphAllocInfo& alloc_info = it->second;
+        if (alloc_info.capture_id == capture_id)
+            unfreed_allocs.push_back(it->first);
+    }
+
+    if (!unfreed_allocs.empty())
+    {
+        // Create a user object that will notify us when the instantiated graph is destroyed.
+        // This works for external captures also, since we wouldn't otherwise know when
+        // the externally-created graph instance gets deleted.
+        // This callback is guaranteed to arrive after the graph has finished executing on the device,
+        // not necessarily when cudaGraphExecDestroy() is called.
+        GraphInfo* graph_info = new GraphInfo;
+        graph_info->unfreed_allocs = unfreed_allocs;
+        cudaUserObject_t user_object;
+        check_cuda(cudaUserObjectCreate(&user_object, graph_info, on_graph_destroy, 1, cudaUserObjectNoDestructorSync));
+        check_cuda(cudaGraphRetainUserObject(graph, user_object, 1, cudaGraphUserObjectMove));
+
+        // ensure graph is still valid
+        if (get_capture_graph(cuda_stream) != graph)
+        {
+            clean_up();
+            return false;
+        }
     }
+
+    // for external captures, we don't instantiate the graph ourselves, so we're done
+    if (external)
+        return true;
+
+    cudaGraphExec_t graph_exec = NULL;
+
+    // end the capture
+    if (!check_cuda(cudaStreamEndCapture(cuda_stream, &graph)))
+        return false;
+
+    // enable to create debug GraphVis visualization of graph
+    // cudaGraphDebugDotPrint(graph, "graph.dot", cudaGraphDebugDotFlagsVerbose);
+
+    // can use after CUDA 11.4 to permit graphs to capture cudaMallocAsync() operations
+    if (!check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch)))
+        return false;
+
+    // free source graph
+    check_cuda(cudaGraphDestroy(graph));
+
+    // process deferred free list if no more captures are ongoing
+    if (g_captures.empty())
+        free_deferred_allocs();
+
+    if (graph_ret)
+        *graph_ret = graph_exec;
+
+    return true;
 }
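Stripped of Warp's bookkeeping (capture ids, joining forked streams, graph-owned allocations, and external captures), the begin/end pair above reduces to the standard CUDA stream-capture flow. A minimal sketch of that flow, where record_work is a hypothetical placeholder for the kernels and copies enqueued on the stream:

```cpp
#include <cuda_runtime.h>

// Sketch: capture work submitted to `stream` into a graph, instantiate it,
// then replay the whole graph with a single launch per step.
void capture_and_replay(cudaStream_t stream, void (*record_work)(cudaStream_t), int steps)
{
    cudaGraph_t graph = nullptr;
    cudaGraphExec_t graph_exec = nullptr;

    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    record_work(stream);                       // work is recorded, not executed
    cudaStreamEndCapture(stream, &graph);

    // same instantiation flag as used above (CUDA 11.4+), needed when the
    // capture contains cudaMallocAsync()/cudaFreeAsync() and will be relaunched
    cudaGraphInstantiateWithFlags(&graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch);
    cudaGraphDestroy(graph);                   // the source graph is no longer needed

    for (int i = 0; i < steps; ++i)
        cudaGraphLaunch(graph_exec, stream);   // replay

    cudaStreamSynchronize(stream);
    cudaGraphExecDestroy(graph_exec);
}
```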
 
-void cuda_graph_launch(void* context, void* graph_exec)
+bool cuda_graph_launch(void* graph_exec, void* stream)
 {
-    ContextGuard guard(context);
-
-    check_cuda(cudaGraphLaunch((cudaGraphExec_t)graph_exec, get_current_stream()));
+    return check_cuda(cudaGraphLaunch((cudaGraphExec_t)graph_exec, (cudaStream_t)stream));
 }
 
-void cuda_graph_destroy(void* context, void* graph_exec)
+bool cuda_graph_destroy(void* context, void* graph_exec)
 {
     ContextGuard guard(context);
 
-    check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
+    return check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
 }
 
 size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_dir, bool debug, bool verbose, bool verify_fp, bool fast_math, const char* output_path)
@@ -1880,7 +2715,7 @@ void* cuda_get_kernel(void* context, void* module, const char* name)
     return kernel;
 }
 
-size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, void** args)
+size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, void** args, void* stream)
 {
     ContextGuard guard(context);
 
@@ -1913,7 +2748,7 @@ size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_block
         (CUfunction)kernel,
         grid_dim, 1, 1,
         block_dim, 1, 1,
-        0, get_current_stream(),
+        0, static_cast<CUstream>(stream),
         args,
         0);
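The launch path now receives the target stream explicitly instead of looking up a per-context "current stream". For reference, a bare driver-API launch of a 1D kernel on an explicit stream looks like the sketch below; func, n, and args are placeholders for a loaded CUfunction, the launch size, and the packed kernel parameters:

```cpp
#include <cuda.h>

// Sketch: launch a 1D kernel over `n` threads on an explicit stream via the
// driver API, mirroring the call above.
CUresult launch_1d(CUfunction func, size_t n, void** args, CUstream stream)
{
    const unsigned block_dim = 256;
    const unsigned grid_dim = (unsigned)((n + block_dim - 1) / block_dim);

    return cuLaunchKernel(func,
                          grid_dim, 1, 1,    // grid
                          block_dim, 1, 1,   // block
                          0,                 // shared memory bytes
                          stream,            // explicit stream, no "current stream" lookup
                          args,              // kernel parameters
                          0);                // extra
}
```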