warp-lang 1.8.1__py3-none-win_amd64.whl → 1.9.1__py3-none-win_amd64.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (141):
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +1904 -114
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +331 -101
  7. warp/builtins.py +1244 -160
  8. warp/codegen.py +317 -206
  9. warp/config.py +1 -1
  10. warp/context.py +1465 -789
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_kernel.py +2 -1
  18. warp/fabric.py +1 -1
  19. warp/fem/cache.py +27 -19
  20. warp/fem/domain.py +2 -2
  21. warp/fem/field/nodal_field.py +2 -2
  22. warp/fem/field/virtual.py +264 -166
  23. warp/fem/geometry/geometry.py +5 -5
  24. warp/fem/integrate.py +129 -51
  25. warp/fem/space/restriction.py +4 -0
  26. warp/fem/space/shape/tet_shape_function.py +3 -10
  27. warp/jax_experimental/custom_call.py +25 -2
  28. warp/jax_experimental/ffi.py +22 -1
  29. warp/jax_experimental/xla_ffi.py +16 -7
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +99 -4
  32. warp/native/builtin.h +86 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +8 -2
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +41 -10
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +2 -2
  48. warp/native/mat.h +1910 -116
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +4 -2
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +331 -14
  59. warp/native/range.h +7 -1
  60. warp/native/reduce.cpp +10 -10
  61. warp/native/reduce.cu +13 -14
  62. warp/native/runlength_encode.cpp +2 -2
  63. warp/native/runlength_encode.cu +5 -5
  64. warp/native/scan.cpp +3 -3
  65. warp/native/scan.cu +4 -4
  66. warp/native/sort.cpp +10 -10
  67. warp/native/sort.cu +40 -31
  68. warp/native/sort.h +2 -0
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +13 -13
  71. warp/native/spatial.h +366 -17
  72. warp/native/temp_buffer.h +2 -2
  73. warp/native/tile.h +471 -82
  74. warp/native/vec.h +328 -14
  75. warp/native/volume.cpp +54 -54
  76. warp/native/volume.cu +1 -1
  77. warp/native/volume.h +2 -1
  78. warp/native/volume_builder.cu +30 -37
  79. warp/native/warp.cpp +150 -149
  80. warp/native/warp.cu +377 -216
  81. warp/native/warp.h +227 -226
  82. warp/optim/linear.py +736 -271
  83. warp/render/imgui_manager.py +289 -0
  84. warp/render/render_opengl.py +99 -18
  85. warp/render/render_usd.py +1 -0
  86. warp/sim/graph_coloring.py +2 -2
  87. warp/sparse.py +558 -175
  88. warp/tests/aux_test_module_aot.py +7 -0
  89. warp/tests/cuda/test_async.py +3 -3
  90. warp/tests/cuda/test_conditional_captures.py +101 -0
  91. warp/tests/geometry/test_hash_grid.py +38 -0
  92. warp/tests/geometry/test_marching_cubes.py +233 -12
  93. warp/tests/interop/test_jax.py +608 -28
  94. warp/tests/sim/test_coloring.py +6 -6
  95. warp/tests/test_array.py +58 -5
  96. warp/tests/test_codegen.py +4 -3
  97. warp/tests/test_context.py +8 -15
  98. warp/tests/test_enum.py +136 -0
  99. warp/tests/test_examples.py +2 -2
  100. warp/tests/test_fem.py +49 -6
  101. warp/tests/test_fixedarray.py +229 -0
  102. warp/tests/test_func.py +18 -15
  103. warp/tests/test_future_annotations.py +7 -5
  104. warp/tests/test_linear_solvers.py +30 -0
  105. warp/tests/test_map.py +15 -1
  106. warp/tests/test_mat.py +1518 -378
  107. warp/tests/test_mat_assign_copy.py +178 -0
  108. warp/tests/test_mat_constructors.py +574 -0
  109. warp/tests/test_module_aot.py +287 -0
  110. warp/tests/test_print.py +69 -0
  111. warp/tests/test_quat.py +140 -34
  112. warp/tests/test_quat_assign_copy.py +145 -0
  113. warp/tests/test_reload.py +2 -1
  114. warp/tests/test_sparse.py +71 -0
  115. warp/tests/test_spatial.py +140 -34
  116. warp/tests/test_spatial_assign_copy.py +160 -0
  117. warp/tests/test_struct.py +43 -3
  118. warp/tests/test_tuple.py +96 -0
  119. warp/tests/test_types.py +61 -20
  120. warp/tests/test_vec.py +179 -34
  121. warp/tests/test_vec_assign_copy.py +143 -0
  122. warp/tests/tile/test_tile.py +245 -18
  123. warp/tests/tile/test_tile_cholesky.py +605 -0
  124. warp/tests/tile/test_tile_load.py +169 -0
  125. warp/tests/tile/test_tile_mathdx.py +2 -558
  126. warp/tests/tile/test_tile_matmul.py +1 -1
  127. warp/tests/tile/test_tile_mlp.py +1 -1
  128. warp/tests/tile/test_tile_shared_memory.py +5 -5
  129. warp/tests/unittest_suites.py +6 -0
  130. warp/tests/walkthrough_debug.py +1 -1
  131. warp/thirdparty/unittest_parallel.py +108 -9
  132. warp/types.py +571 -267
  133. warp/utils.py +68 -86
  134. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/METADATA +29 -69
  135. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/RECORD +138 -128
  136. warp/native/marching.cpp +0 -19
  137. warp/native/marching.cu +0 -514
  138. warp/native/marching.h +0 -19
  139. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/WHEEL +0 -0
  140. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/licenses/LICENSE.md +0 -0
  141. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/top_level.txt +0 -0
warp/native/warp.cu CHANGED
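The dominant change in this file is a rename of Warp's exported C entry points: every public function gains a wp_ prefix (alloc_device() becomes wp_alloc_device(), cuda_stream_create() becomes wp_cuda_stream_create(), and so on), presumably to namespace the symbols exported from warp.dll. A minimal sketch of the pattern, with one representative declaration (illustrative only; WP_API stands in for Warp's export macro):

    // before (1.8.1): unprefixed C exports
    extern "C" WP_API int cuda_device_get_count();

    // after (1.9.1): wp_-prefixed C exports
    extern "C" WP_API int wp_cuda_device_get_count();

Callers that resolve these symbols dynamically (e.g. via ctypes) must look up the new names.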
@@ -19,6 +19,7 @@
 #include "scan.h"
 #include "cuda_util.h"
 #include "error.h"
+#include "sort.h"
 
 #include <cstdlib>
 #include <fstream>
@@ -168,7 +169,7 @@ struct ContextInfo
 {
     DeviceInfo* device_info = NULL;
 
-    // the current stream, managed from Python (see cuda_context_set_stream() and cuda_context_get_stream())
+    // the current stream, managed from Python (see wp_cuda_context_set_stream() and wp_cuda_context_get_stream())
     CUstream stream = NULL;
 
     // conditional graph node support, loaded on demand if the driver supports it (CUDA 12.4+)
@@ -237,11 +238,11 @@ static std::unordered_map<CUstream, StreamInfo> g_streams;
 
 // Ongoing graph captures registered using wp.capture_begin().
 // This maps the capture id to the stream where capture was started.
-// See cuda_graph_begin_capture(), cuda_graph_end_capture(), and free_device_async().
+// See wp_cuda_graph_begin_capture(), wp_cuda_graph_end_capture(), and wp_free_device_async().
 static std::unordered_map<uint64_t, CaptureInfo*> g_captures;
 
 // Memory allocated during graph capture requires special handling.
-// See alloc_device_async() and free_device_async().
+// See wp_alloc_device_async() and wp_free_device_async().
 static std::unordered_map<void*, GraphAllocInfo> g_graph_allocs;
 
 // Memory that cannot be freed immediately gets queued here.
@@ -252,12 +253,12 @@ static std::vector<FreeInfo> g_deferred_free_list;
 // Call unload_deferred_modules() to release.
 static std::vector<ModuleInfo> g_deferred_module_list;
 
-void cuda_set_context_restore_policy(bool always_restore)
+void wp_cuda_set_context_restore_policy(bool always_restore)
 {
     ContextGuard::always_restore = always_restore;
 }
 
-int cuda_get_context_restore_policy()
+int wp_cuda_get_context_restore_policy()
 {
     return int(ContextGuard::always_restore);
 }
@@ -348,7 +349,7 @@ static inline CUcontext get_current_context()
 
 static inline CUstream get_current_stream(void* context=NULL)
 {
-    return static_cast<CUstream>(cuda_context_get_stream(context));
+    return static_cast<CUstream>(wp_cuda_context_get_stream(context));
 }
 
 static ContextInfo* get_context_info(CUcontext ctx)
@@ -481,7 +482,7 @@ static int unload_deferred_modules(void* context = NULL)
         const ModuleInfo& module_info = *it;
         if (module_info.context == context || !context)
         {
-            cuda_unload_module(module_info.context, module_info.module);
+            wp_cuda_unload_module(module_info.context, module_info.module);
             ++num_unloaded_modules;
             it = g_deferred_module_list.erase(it);
         }
@@ -535,41 +536,41 @@ static inline const char* get_cuda_kernel_name(void* kernel)
 }
 
 
-void* alloc_pinned(size_t s)
+void* wp_alloc_pinned(size_t s)
 {
     void* ptr = NULL;
     check_cuda(cudaMallocHost(&ptr, s));
     return ptr;
 }
 
-void free_pinned(void* ptr)
+void wp_free_pinned(void* ptr)
 {
     cudaFreeHost(ptr);
 }
 
-void* alloc_device(void* context, size_t s)
+void* wp_alloc_device(void* context, size_t s)
 {
-    int ordinal = cuda_context_get_device_ordinal(context);
+    int ordinal = wp_cuda_context_get_device_ordinal(context);
 
     // use stream-ordered allocator if available
-    if (cuda_device_is_mempool_supported(ordinal))
-        return alloc_device_async(context, s);
+    if (wp_cuda_device_is_mempool_supported(ordinal))
+        return wp_alloc_device_async(context, s);
     else
-        return alloc_device_default(context, s);
+        return wp_alloc_device_default(context, s);
 }
 
-void free_device(void* context, void* ptr)
+void wp_free_device(void* context, void* ptr)
 {
-    int ordinal = cuda_context_get_device_ordinal(context);
+    int ordinal = wp_cuda_context_get_device_ordinal(context);
 
     // use stream-ordered allocator if available
-    if (cuda_device_is_mempool_supported(ordinal))
-        free_device_async(context, ptr);
+    if (wp_cuda_device_is_mempool_supported(ordinal))
+        wp_free_device_async(context, ptr);
     else
-        free_device_default(context, ptr);
+        wp_free_device_default(context, ptr);
 }
 
-void* alloc_device_default(void* context, size_t s)
+void* wp_alloc_device_default(void* context, size_t s)
 {
     ContextGuard guard(context);
 
@@ -579,7 +580,7 @@ void* alloc_device_default(void* context, size_t s)
     return ptr;
 }
 
-void free_device_default(void* context, void* ptr)
+void wp_free_device_default(void* context, void* ptr)
 {
     ContextGuard guard(context);
 
@@ -595,7 +596,7 @@ void free_device_default(void* context, void* ptr)
     }
 }
 
-void* alloc_device_async(void* context, size_t s)
+void* wp_alloc_device_async(void* context, size_t s)
 {
     // stream-ordered allocations don't rely on the current context,
     // but we set the context here for consistent behaviour
@@ -613,7 +614,7 @@ void* alloc_device_async(void* context, size_t s)
     if (ptr)
     {
         // if the stream is capturing, the allocation requires special handling
-        if (cuda_stream_is_capturing(stream))
+        if (wp_cuda_stream_is_capturing(stream))
         {
             // check if this is a known capture
             uint64_t capture_id = get_capture_id(stream);
@@ -634,7 +635,7 @@ void* alloc_device_async(void* context, size_t s)
     return ptr;
 }
 
-void free_device_async(void* context, void* ptr)
+void wp_free_device_async(void* context, void* ptr)
 {
     // stream-ordered allocators generally don't rely on the current context,
     // but we set the context here for consistent behaviour
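For context: wp_alloc_device() above dispatches between CUDA's stream-ordered (memory-pool-backed) allocator and the classic synchronizing cudaMalloc() path, depending on device support. A self-contained sketch of that dispatch using only the CUDA runtime API (illustrative, not Warp's actual code):

    #include <cuda_runtime.h>

    // Allocate on the current device, preferring the stream-ordered allocator.
    void* alloc_auto(size_t size, cudaStream_t stream)
    {
        int device = 0;
        cudaGetDevice(&device);

        int mempool_supported = 0;
        cudaDeviceGetAttribute(&mempool_supported, cudaDevAttrMemoryPoolsSupported, device);

        void* ptr = nullptr;
        if (mempool_supported)
            cudaMallocAsync(&ptr, size, stream);  // stream-ordered, capture-friendly
        else
            cudaMalloc(&ptr, size);               // legacy path, synchronizes the device
        return ptr;
    }

    // The free must match the allocation path.
    void free_auto(void* ptr, cudaStream_t stream, bool used_mempool)
    {
        if (used_mempool)
            cudaFreeAsync(ptr, stream);
        else
            cudaFree(ptr);
    }

The stream-ordered path is what allows allocations while a stream is capturing a CUDA graph, which is why wp_alloc_device_async() carries the extra capture bookkeeping seen above.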
@@ -732,7 +733,7 @@ void free_device_async(void* context, void* ptr)
     }
 }
 
-bool memcpy_h2d(void* context, void* dest, void* src, size_t n, void* stream)
+bool wp_memcpy_h2d(void* context, void* dest, void* src, size_t n, void* stream)
 {
     ContextGuard guard(context);
 
@@ -751,7 +752,7 @@ bool memcpy_h2d(void* context, void* dest, void* src, size_t n, void* stream)
     return result;
 }
 
-bool memcpy_d2h(void* context, void* dest, void* src, size_t n, void* stream)
+bool wp_memcpy_d2h(void* context, void* dest, void* src, size_t n, void* stream)
 {
     ContextGuard guard(context);
 
@@ -770,7 +771,7 @@ bool memcpy_d2h(void* context, void* dest, void* src, size_t n, void* stream)
     return result;
 }
 
-bool memcpy_d2d(void* context, void* dest, void* src, size_t n, void* stream)
+bool wp_memcpy_d2d(void* context, void* dest, void* src, size_t n, void* stream)
 {
     ContextGuard guard(context);
 
@@ -789,7 +790,7 @@ bool memcpy_d2d(void* context, void* dest, void* src, size_t n, void* stream)
     return result;
 }
 
-bool memcpy_p2p(void* dst_context, void* dst, void* src_context, void* src, size_t n, void* stream)
+bool wp_memcpy_p2p(void* dst_context, void* dst, void* src_context, void* src, size_t n, void* stream)
 {
     // ContextGuard guard(context);
 
@@ -809,7 +810,7 @@ bool memcpy_p2p(void* dst_context, void* dst, void* src_context, void* src, size
     // because cudaMemPoolGetAccess() cannot be called during graph capture.
     // - CUDA will report error 1 (invalid argument) if cudaMemcpyAsync() is called but mempool access is not enabled.
 
-    if (!cuda_stream_is_capturing(stream))
+    if (!wp_cuda_stream_is_capturing(stream))
     {
         begin_cuda_range(WP_TIMING_MEMCPY, cuda_stream, get_stream_context(stream), "memcpy PtoP");
 
@@ -896,7 +897,7 @@ __global__ void memset_kernel(int* dest, int value, size_t n)
     }
 }
 
-void memset_device(void* context, void* dest, int value, size_t n)
+void wp_memset_device(void* context, void* dest, int value, size_t n)
 {
     ContextGuard guard(context);
 
@@ -940,7 +941,7 @@ __global__ void memtile_value_kernel(T* dst, T value, size_t n)
     }
 }
 
-void memtile_device(void* context, void* dst, const void* src, size_t srcsize, size_t n)
+void wp_memtile_device(void* context, void* dst, const void* src, size_t srcsize, size_t n)
 {
     ContextGuard guard(context);
 
@@ -976,12 +977,12 @@ void memtile_device(void* context, void* dst, const void* src, size_t srcsize, s
 
         // copy value to device memory
        // TODO: use a persistent stream-local staging buffer to avoid allocs?
-        void* src_devptr = alloc_device(WP_CURRENT_CONTEXT, srcsize);
+        void* src_devptr = wp_alloc_device(WP_CURRENT_CONTEXT, srcsize);
        check_cuda(cudaMemcpyAsync(src_devptr, src, srcsize, cudaMemcpyHostToDevice, get_current_stream()));
 
        wp_launch_device(WP_CURRENT_CONTEXT, memtile_kernel, n, (dst, src_devptr, srcsize, n));
 
-        free_device(WP_CURRENT_CONTEXT, src_devptr);
+        wp_free_device(WP_CURRENT_CONTEXT, src_devptr);
 
     }
 }
@@ -1208,7 +1209,7 @@ static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::in
 }
 
 
-WP_API bool array_copy_device(void* context, void* dst, void* src, int dst_type, int src_type, int elem_size)
+WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_type, int src_type, int elem_size)
 {
     if (!src || !dst)
         return false;
@@ -1600,7 +1601,7 @@ static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t
 }
 
 
-WP_API void array_fill_device(void* context, void* arr_ptr, int arr_type, const void* value_ptr, int value_size)
+WP_API void wp_array_fill_device(void* context, void* arr_ptr, int arr_type, const void* value_ptr, int value_size)
 {
     if (!arr_ptr || !value_ptr)
         return;
@@ -1656,7 +1657,7 @@ WP_API void array_fill_device(void* context, void* arr_ptr, int arr_type, const
 
     // copy value to device memory
     // TODO: use a persistent stream-local staging buffer to avoid allocs?
-    void* value_devptr = alloc_device(WP_CURRENT_CONTEXT, value_size);
+    void* value_devptr = wp_alloc_device(WP_CURRENT_CONTEXT, value_size);
     check_cuda(cudaMemcpyAsync(value_devptr, value_ptr, value_size, cudaMemcpyHostToDevice, get_current_stream()));
 
     // handle fabric arrays
@@ -1714,20 +1715,20 @@ WP_API void array_fill_device(void* context, void* arr_ptr, int arr_type, const
         return;
     }
 
-    free_device(WP_CURRENT_CONTEXT, value_devptr);
+    wp_free_device(WP_CURRENT_CONTEXT, value_devptr);
 }
 
-void array_scan_int_device(uint64_t in, uint64_t out, int len, bool inclusive)
+void wp_array_scan_int_device(uint64_t in, uint64_t out, int len, bool inclusive)
 {
     scan_device((const int*)in, (int*)out, len, inclusive);
 }
 
-void array_scan_float_device(uint64_t in, uint64_t out, int len, bool inclusive)
+void wp_array_scan_float_device(uint64_t in, uint64_t out, int len, bool inclusive)
 {
     scan_device((const float*)in, (float*)out, len, inclusive);
 }
 
-int cuda_driver_version()
+int wp_cuda_driver_version()
 {
     int version;
     if (check_cu(cuDriverGetVersion_f(&version)))
@@ -1736,17 +1737,17 @@ int cuda_driver_version()
     return 0;
 }
 
-int cuda_toolkit_version()
+int wp_cuda_toolkit_version()
 {
     return CUDA_VERSION;
 }
 
-bool cuda_driver_is_initialized()
+bool wp_cuda_driver_is_initialized()
 {
     return is_cuda_driver_initialized();
 }
 
-int nvrtc_supported_arch_count()
+int wp_nvrtc_supported_arch_count()
 {
     int count;
     if (check_nvrtc(nvrtcGetNumSupportedArchs(&count)))
@@ -1755,7 +1756,7 @@ int nvrtc_supported_arch_count()
     return 0;
 }
 
-void nvrtc_supported_archs(int* archs)
+void wp_nvrtc_supported_archs(int* archs)
 {
     if (archs)
     {
@@ -1763,14 +1764,14 @@ void nvrtc_supported_archs(int* archs)
     }
 }
 
-int cuda_device_get_count()
+int wp_cuda_device_get_count()
 {
     int count = 0;
     check_cu(cuDeviceGetCount_f(&count));
     return count;
 }
 
-void* cuda_device_get_primary_context(int ordinal)
+void* wp_cuda_device_get_primary_context(int ordinal)
 {
     if (ordinal >= 0 && ordinal < int(g_devices.size()))
     {
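For context: wp_nvrtc_supported_arch_count() and wp_nvrtc_supported_archs() expose which GPU architectures the bundled NVRTC can target. A standalone sketch of the same query using the plain NVRTC API (error handling elided for brevity):

    #include <nvrtc.h>
    #include <cstdio>
    #include <vector>

    int main()
    {
        int count = 0;
        nvrtcGetNumSupportedArchs(&count);

        std::vector<int> archs(count);
        nvrtcGetSupportedArchs(archs.data());

        for (int arch : archs)
            printf("sm_%d\n", arch);  // e.g. sm_70, sm_86, sm_90, ...
        return 0;
    }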
@@ -1786,75 +1787,75 @@ void* cuda_device_get_primary_context(int ordinal)
     return NULL;
 }
 
-const char* cuda_device_get_name(int ordinal)
+const char* wp_cuda_device_get_name(int ordinal)
 {
     if (ordinal >= 0 && ordinal < int(g_devices.size()))
         return g_devices[ordinal].name;
     return NULL;
 }
 
-int cuda_device_get_arch(int ordinal)
+int wp_cuda_device_get_arch(int ordinal)
 {
     if (ordinal >= 0 && ordinal < int(g_devices.size()))
         return g_devices[ordinal].arch;
     return 0;
 }
 
-int cuda_device_get_sm_count(int ordinal)
+int wp_cuda_device_get_sm_count(int ordinal)
 {
     if (ordinal >= 0 && ordinal < int(g_devices.size()))
         return g_devices[ordinal].sm_count;
     return 0;
 }
 
-void cuda_device_get_uuid(int ordinal, char uuid[16])
+void wp_cuda_device_get_uuid(int ordinal, char uuid[16])
 {
     memcpy(uuid, g_devices[ordinal].uuid.bytes, sizeof(char)*16);
 }
 
-int cuda_device_get_pci_domain_id(int ordinal)
+int wp_cuda_device_get_pci_domain_id(int ordinal)
 {
     if (ordinal >= 0 && ordinal < int(g_devices.size()))
         return g_devices[ordinal].pci_domain_id;
     return -1;
 }
 
-int cuda_device_get_pci_bus_id(int ordinal)
+int wp_cuda_device_get_pci_bus_id(int ordinal)
 {
     if (ordinal >= 0 && ordinal < int(g_devices.size()))
         return g_devices[ordinal].pci_bus_id;
     return -1;
 }
 
-int cuda_device_get_pci_device_id(int ordinal)
+int wp_cuda_device_get_pci_device_id(int ordinal)
 {
     if (ordinal >= 0 && ordinal < int(g_devices.size()))
         return g_devices[ordinal].pci_device_id;
     return -1;
 }
 
-int cuda_device_is_uva(int ordinal)
+int wp_cuda_device_is_uva(int ordinal)
 {
     if (ordinal >= 0 && ordinal < int(g_devices.size()))
         return g_devices[ordinal].is_uva;
     return 0;
 }
 
-int cuda_device_is_mempool_supported(int ordinal)
+int wp_cuda_device_is_mempool_supported(int ordinal)
 {
     if (ordinal >= 0 && ordinal < int(g_devices.size()))
         return g_devices[ordinal].is_mempool_supported;
     return 0;
 }
 
-int cuda_device_is_ipc_supported(int ordinal)
+int wp_cuda_device_is_ipc_supported(int ordinal)
 {
     if (ordinal >= 0 && ordinal < int(g_devices.size()))
         return g_devices[ordinal].is_ipc_supported;
     return 0;
 }
 
-int cuda_device_set_mempool_release_threshold(int ordinal, uint64_t threshold)
+int wp_cuda_device_set_mempool_release_threshold(int ordinal, uint64_t threshold)
 {
     if (ordinal < 0 || ordinal > int(g_devices.size()))
     {
@@ -1881,7 +1882,7 @@ int cuda_device_set_mempool_release_threshold(int ordinal, uint64_t threshold)
     return 1; // success
 }
 
-uint64_t cuda_device_get_mempool_release_threshold(int ordinal)
+uint64_t wp_cuda_device_get_mempool_release_threshold(int ordinal)
 {
     if (ordinal < 0 || ordinal > int(g_devices.size()))
     {
@@ -1909,7 +1910,7 @@ uint64_t cuda_device_get_mempool_release_threshold(int ordinal)
     return threshold;
 }
 
-uint64_t cuda_device_get_mempool_used_mem_current(int ordinal)
+uint64_t wp_cuda_device_get_mempool_used_mem_current(int ordinal)
 {
     if (ordinal < 0 || ordinal > int(g_devices.size()))
     {
@@ -1937,7 +1938,7 @@ uint64_t cuda_device_get_mempool_used_mem_current(int ordinal)
     return mem_used;
 }
 
-uint64_t cuda_device_get_mempool_used_mem_high(int ordinal)
+uint64_t wp_cuda_device_get_mempool_used_mem_high(int ordinal)
 {
     if (ordinal < 0 || ordinal > int(g_devices.size()))
     {
@@ -1965,7 +1966,7 @@ uint64_t cuda_device_get_mempool_used_mem_high(int ordinal)
     return mem_high_water_mark;
 }
 
-void cuda_device_get_memory_info(int ordinal, size_t* free_mem, size_t* total_mem)
+void wp_cuda_device_get_memory_info(int ordinal, size_t* free_mem, size_t* total_mem)
 {
     // use temporary storage if user didn't specify pointers
     size_t tmp_free_mem, tmp_total_mem;
@@ -2002,12 +2003,12 @@ void cuda_device_get_memory_info(int ordinal, size_t* free_mem, size_t* total_me
 }
 
 
-void* cuda_context_get_current()
+void* wp_cuda_context_get_current()
 {
     return get_current_context();
 }
 
-void cuda_context_set_current(void* context)
+void wp_cuda_context_set_current(void* context)
 {
     CUcontext ctx = static_cast<CUcontext>(context);
     CUcontext prev_ctx = NULL;
@@ -2018,18 +2019,18 @@ void cuda_context_set_current(void* context)
     }
 }
 
-void cuda_context_push_current(void* context)
+void wp_cuda_context_push_current(void* context)
 {
     check_cu(cuCtxPushCurrent_f(static_cast<CUcontext>(context)));
 }
 
-void cuda_context_pop_current()
+void wp_cuda_context_pop_current()
 {
     CUcontext context;
     check_cu(cuCtxPopCurrent_f(&context));
 }
 
-void* cuda_context_create(int device_ordinal)
+void* wp_cuda_context_create(int device_ordinal)
 {
     CUcontext ctx = NULL;
     CUdevice device;
@@ -2038,15 +2039,15 @@ void* cuda_context_create(int device_ordinal)
     return ctx;
 }
 
-void cuda_context_destroy(void* context)
+void wp_cuda_context_destroy(void* context)
 {
     if (context)
     {
         CUcontext ctx = static_cast<CUcontext>(context);
 
         // ensure this is not the current context
-        if (ctx == cuda_context_get_current())
-            cuda_context_set_current(NULL);
+        if (ctx == wp_cuda_context_get_current())
+            wp_cuda_context_set_current(NULL);
 
         // release the cached info about this context
         ContextInfo* info = get_context_info(ctx);
@@ -2065,7 +2066,7 @@ void cuda_context_destroy(void* context)
     }
 }
 
-void cuda_context_synchronize(void* context)
+void wp_cuda_context_synchronize(void* context)
 {
     ContextGuard guard(context);
 
@@ -2079,10 +2080,10 @@ void cuda_context_synchronize(void* context)
 
     unload_deferred_modules(context);
 
-    // check_cuda(cudaDeviceGraphMemTrim(cuda_context_get_device_ordinal(context)));
+    // check_cuda(cudaDeviceGraphMemTrim(wp_cuda_context_get_device_ordinal(context)));
 }
 
-uint64_t cuda_context_check(void* context)
+uint64_t wp_cuda_context_check(void* context)
 {
     ContextGuard guard(context);
 
@@ -2104,13 +2105,13 @@ uint64_t cuda_context_check(void* context)
 }
 
 
-int cuda_context_get_device_ordinal(void* context)
+int wp_cuda_context_get_device_ordinal(void* context)
 {
     ContextInfo* info = get_context_info(static_cast<CUcontext>(context));
     return info && info->device_info ? info->device_info->ordinal : -1;
 }
 
-int cuda_context_is_primary(void* context)
+int wp_cuda_context_is_primary(void* context)
 {
     CUcontext ctx = static_cast<CUcontext>(context);
     ContextInfo* context_info = get_context_info(ctx);
2137
2138
  return 0;
2138
2139
  }
2139
2140
 
2140
- void* cuda_context_get_stream(void* context)
2141
+ void* wp_cuda_context_get_stream(void* context)
2141
2142
  {
2142
2143
  ContextInfo* info = get_context_info(static_cast<CUcontext>(context));
2143
2144
  if (info)
@@ -2147,7 +2148,7 @@ void* cuda_context_get_stream(void* context)
2147
2148
  return NULL;
2148
2149
  }
2149
2150
 
2150
- void cuda_context_set_stream(void* context, void* stream, int sync)
2151
+ void wp_cuda_context_set_stream(void* context, void* stream, int sync)
2151
2152
  {
2152
2153
  ContextInfo* context_info = get_context_info(static_cast<CUcontext>(context));
2153
2154
  if (context_info)
@@ -2171,7 +2172,7 @@ void cuda_context_set_stream(void* context, void* stream, int sync)
2171
2172
  }
2172
2173
  }
2173
2174
 
2174
- int cuda_is_peer_access_supported(int target_ordinal, int peer_ordinal)
2175
+ int wp_cuda_is_peer_access_supported(int target_ordinal, int peer_ordinal)
2175
2176
  {
2176
2177
  int num_devices = int(g_devices.size());
2177
2178
 
@@ -2196,7 +2197,7 @@ int cuda_is_peer_access_supported(int target_ordinal, int peer_ordinal)
2196
2197
  return can_access;
2197
2198
  }
2198
2199
 
2199
- int cuda_is_peer_access_enabled(void* target_context, void* peer_context)
2200
+ int wp_cuda_is_peer_access_enabled(void* target_context, void* peer_context)
2200
2201
  {
2201
2202
  if (!target_context || !peer_context)
2202
2203
  {
@@ -2207,8 +2208,8 @@ int cuda_is_peer_access_enabled(void* target_context, void* peer_context)
2207
2208
  if (target_context == peer_context)
2208
2209
  return 1;
2209
2210
 
2210
- int target_ordinal = cuda_context_get_device_ordinal(target_context);
2211
- int peer_ordinal = cuda_context_get_device_ordinal(peer_context);
2211
+ int target_ordinal = wp_cuda_context_get_device_ordinal(target_context);
2212
+ int peer_ordinal = wp_cuda_context_get_device_ordinal(peer_context);
2212
2213
 
2213
2214
  // check if peer access is supported
2214
2215
  int can_access = 0;
@@ -2241,7 +2242,7 @@ int cuda_is_peer_access_enabled(void* target_context, void* peer_context)
2241
2242
  }
2242
2243
  }
2243
2244
 
2244
- int cuda_set_peer_access_enabled(void* target_context, void* peer_context, int enable)
2245
+ int wp_cuda_set_peer_access_enabled(void* target_context, void* peer_context, int enable)
2245
2246
  {
2246
2247
  if (!target_context || !peer_context)
2247
2248
  {
@@ -2252,8 +2253,8 @@ int cuda_set_peer_access_enabled(void* target_context, void* peer_context, int e
2252
2253
  if (target_context == peer_context)
2253
2254
  return 1; // no-op
2254
2255
 
2255
- int target_ordinal = cuda_context_get_device_ordinal(target_context);
2256
- int peer_ordinal = cuda_context_get_device_ordinal(peer_context);
2256
+ int target_ordinal = wp_cuda_context_get_device_ordinal(target_context);
2257
+ int peer_ordinal = wp_cuda_context_get_device_ordinal(peer_context);
2257
2258
 
2258
2259
  // check if peer access is supported
2259
2260
  int can_access = 0;
@@ -2298,7 +2299,7 @@ int cuda_set_peer_access_enabled(void* target_context, void* peer_context, int e
2298
2299
  return 1; // success
2299
2300
  }
2300
2301
 
2301
- int cuda_is_mempool_access_enabled(int target_ordinal, int peer_ordinal)
2302
+ int wp_cuda_is_mempool_access_enabled(int target_ordinal, int peer_ordinal)
2302
2303
  {
2303
2304
  int num_devices = int(g_devices.size());
2304
2305
 
@@ -2334,7 +2335,7 @@ int cuda_is_mempool_access_enabled(int target_ordinal, int peer_ordinal)
2334
2335
  return 0;
2335
2336
  }
2336
2337
 
2337
- int cuda_set_mempool_access_enabled(int target_ordinal, int peer_ordinal, int enable)
2338
+ int wp_cuda_set_mempool_access_enabled(int target_ordinal, int peer_ordinal, int enable)
2338
2339
  {
2339
2340
  int num_devices = int(g_devices.size());
2340
2341
 
@@ -2380,13 +2381,13 @@ int cuda_set_mempool_access_enabled(int target_ordinal, int peer_ordinal, int en
     return 1; // success
 }
 
-void cuda_ipc_get_mem_handle(void* ptr, char* out_buffer) {
+void wp_cuda_ipc_get_mem_handle(void* ptr, char* out_buffer) {
     CUipcMemHandle memHandle;
     check_cu(cuIpcGetMemHandle_f(&memHandle, (CUdeviceptr)ptr));
     memcpy(out_buffer, memHandle.reserved, CU_IPC_HANDLE_SIZE);
 }
 
-void* cuda_ipc_open_mem_handle(void* context, char* handle) {
+void* wp_cuda_ipc_open_mem_handle(void* context, char* handle) {
     ContextGuard guard(context);
 
     CUipcMemHandle memHandle;
@@ -2401,11 +2402,11 @@ void* cuda_ipc_open_mem_handle(void* context, char* handle) {
     return NULL;
 }
 
-void cuda_ipc_close_mem_handle(void* ptr) {
+void wp_cuda_ipc_close_mem_handle(void* ptr) {
     check_cu(cuIpcCloseMemHandle_f((CUdeviceptr) ptr));
 }
 
-void cuda_ipc_get_event_handle(void* context, void* event, char* out_buffer) {
+void wp_cuda_ipc_get_event_handle(void* context, void* event, char* out_buffer) {
     ContextGuard guard(context);
 
     CUipcEventHandle eventHandle;
@@ -2413,7 +2414,7 @@ void cuda_ipc_get_event_handle(void* context, void* event, char* out_buffer) {
     memcpy(out_buffer, eventHandle.reserved, CU_IPC_HANDLE_SIZE);
 }
 
-void* cuda_ipc_open_event_handle(void* context, char* handle) {
+void* wp_cuda_ipc_open_event_handle(void* context, char* handle) {
     ContextGuard guard(context);
 
     CUipcEventHandle eventHandle;
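For context: the IPC entry points above let one process export a device allocation (or event) and another process map it. A runtime-API sketch of the mem-handle round trip (illustrative only; the 64 handle bytes travel over any IPC channel you like):

    #include <cuda_runtime.h>
    #include <cstring>

    // process A: export a handle for a device pointer
    void export_mem_handle(void* d_ptr, char out[CUDA_IPC_HANDLE_SIZE])
    {
        cudaIpcMemHandle_t handle;
        cudaIpcGetMemHandle(&handle, d_ptr);
        memcpy(out, &handle, CUDA_IPC_HANDLE_SIZE);  // send via pipe/socket/shared memory
    }

    // process B: map the same allocation into this process
    void* import_mem_handle(const char in[CUDA_IPC_HANDLE_SIZE])
    {
        cudaIpcMemHandle_t handle;
        memcpy(&handle, in, CUDA_IPC_HANDLE_SIZE);
        void* d_ptr = nullptr;
        cudaIpcOpenMemHandle(&d_ptr, handle, cudaIpcMemLazyEnablePeerAccess);
        return d_ptr;  // unmap later with cudaIpcCloseMemHandle(d_ptr)
    }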
@@ -2427,31 +2428,34 @@ void* cuda_ipc_open_event_handle(void* context, char* handle) {
     return NULL;
 }
 
-void* cuda_stream_create(void* context, int priority)
+void* wp_cuda_stream_create(void* context, int priority)
 {
     ContextGuard guard(context, true);
 
     CUstream stream;
     if (check_cu(cuStreamCreateWithPriority_f(&stream, CU_STREAM_DEFAULT, priority)))
     {
-        cuda_stream_register(WP_CURRENT_CONTEXT, stream);
+        wp_cuda_stream_register(WP_CURRENT_CONTEXT, stream);
         return stream;
     }
     else
         return NULL;
 }
 
-void cuda_stream_destroy(void* context, void* stream)
+void wp_cuda_stream_destroy(void* context, void* stream)
 {
     if (!stream)
         return;
 
-    cuda_stream_unregister(context, stream);
+    wp_cuda_stream_unregister(context, stream);
+
+    // release temporary radix sort buffer associated with this stream
+    radix_sort_release(context, stream);
 
     check_cu(cuStreamDestroy_f(static_cast<CUstream>(stream)));
 }
 
-int cuda_stream_query(void* stream)
+int wp_cuda_stream_query(void* stream)
 {
     CUresult res = cuStreamQuery_f(static_cast<CUstream>(stream));
 
@@ -2464,7 +2468,7 @@ int cuda_stream_query(void* stream)
     return res;
 }
 
-void cuda_stream_register(void* context, void* stream)
+void wp_cuda_stream_register(void* context, void* stream)
 {
     if (!stream)
         return;
@@ -2476,7 +2480,7 @@ void cuda_stream_register(void* context, void* stream)
     check_cu(cuEventCreate_f(&stream_info.cached_event, CU_EVENT_DISABLE_TIMING));
 }
 
-void cuda_stream_unregister(void* context, void* stream)
+void wp_cuda_stream_unregister(void* context, void* stream)
 {
     if (!stream)
         return;
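For context: new in this version, wp_cuda_stream_destroy() releases a per-stream radix-sort scratch buffer via radix_sort_release() (hence the new #include "sort.h" at the top of the file) before destroying the stream. A generic sketch of that per-stream scratch-cache pattern (all names hypothetical; this is not Warp's radix_sort_release implementation):

    #include <cuda_runtime.h>
    #include <unordered_map>

    static std::unordered_map<cudaStream_t, void*> g_scratch;

    // lazily create one scratch buffer per stream
    void* scratch_acquire(cudaStream_t stream, size_t size)
    {
        auto it = g_scratch.find(stream);
        if (it == g_scratch.end())
        {
            void* buf = nullptr;
            cudaMallocAsync(&buf, size, stream);  // stream-ordered allocation
            it = g_scratch.emplace(stream, buf).first;
        }
        return it->second;
    }

    // must run before the stream itself is destroyed
    void scratch_release(cudaStream_t stream)
    {
        auto it = g_scratch.find(stream);
        if (it != g_scratch.end())
        {
            cudaFreeAsync(it->second, stream);
            g_scratch.erase(it);
        }
    }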
@@ -2500,28 +2504,28 @@ void cuda_stream_unregister(void* context, void* stream)
     }
 }
 
-void* cuda_stream_get_current()
+void* wp_cuda_stream_get_current()
 {
     return get_current_stream();
 }
 
-void cuda_stream_synchronize(void* stream)
+void wp_cuda_stream_synchronize(void* stream)
 {
     check_cu(cuStreamSynchronize_f(static_cast<CUstream>(stream)));
 }
 
-void cuda_stream_wait_event(void* stream, void* event)
+void wp_cuda_stream_wait_event(void* stream, void* event)
 {
     check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), 0));
 }
 
-void cuda_stream_wait_stream(void* stream, void* other_stream, void* event)
+void wp_cuda_stream_wait_stream(void* stream, void* other_stream, void* event)
 {
     check_cu(cuEventRecord_f(static_cast<CUevent>(event), static_cast<CUstream>(other_stream)));
     check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), 0));
 }
 
-int cuda_stream_is_capturing(void* stream)
+int wp_cuda_stream_is_capturing(void* stream)
 {
     cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
     check_cuda(cudaStreamIsCapturing(static_cast<cudaStream_t>(stream), &status));
@@ -2529,12 +2533,12 @@ int cuda_stream_is_capturing(void* stream)
     return int(status != cudaStreamCaptureStatusNone);
 }
 
-uint64_t cuda_stream_get_capture_id(void* stream)
+uint64_t wp_cuda_stream_get_capture_id(void* stream)
 {
     return get_capture_id(static_cast<CUstream>(stream));
 }
 
-int cuda_stream_get_priority(void* stream)
+int wp_cuda_stream_get_priority(void* stream)
 {
     int priority = 0;
     check_cuda(cuStreamGetPriority_f(static_cast<CUstream>(stream), &priority));
@@ -2542,7 +2546,7 @@ int cuda_stream_get_priority(void* stream)
     return priority;
 }
 
-void* cuda_event_create(void* context, unsigned flags)
+void* wp_cuda_event_create(void* context, unsigned flags)
 {
     ContextGuard guard(context, true);
 
@@ -2553,12 +2557,12 @@ void* cuda_event_create(void* context, unsigned flags)
     return NULL;
 }
 
-void cuda_event_destroy(void* event)
+void wp_cuda_event_destroy(void* event)
 {
     check_cu(cuEventDestroy_f(static_cast<CUevent>(event)));
 }
 
-int cuda_event_query(void* event)
+int wp_cuda_event_query(void* event)
 {
     CUresult res = cuEventQuery_f(static_cast<CUevent>(event));
 
@@ -2571,9 +2575,9 @@ int cuda_event_query(void* event)
     return res;
 }
 
-void cuda_event_record(void* event, void* stream, bool timing)
+void wp_cuda_event_record(void* event, void* stream, bool timing)
 {
-    if (timing && !g_captures.empty() && cuda_stream_is_capturing(stream))
+    if (timing && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
     {
         // record timing event during graph capture
         check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(stream), CU_EVENT_RECORD_EXTERNAL));
@@ -2584,12 +2588,12 @@ void cuda_event_record(void* event, void* stream, bool timing)
     }
 }
 
-void cuda_event_synchronize(void* event)
+void wp_cuda_event_synchronize(void* event)
 {
     check_cu(cuEventSynchronize_f(static_cast<CUevent>(event)));
 }
 
-float cuda_event_elapsed_time(void* start_event, void* end_event)
+float wp_cuda_event_elapsed_time(void* start_event, void* end_event)
 {
     float elapsed = 0.0f;
     cudaEvent_t start = static_cast<cudaEvent_t>(start_event);
@@ -2598,7 +2602,7 @@ float cuda_event_elapsed_time(void* start_event, void* end_event)
     return elapsed;
 }
 
-bool cuda_graph_begin_capture(void* context, void* stream, int external)
+bool wp_cuda_graph_begin_capture(void* context, void* stream, int external)
 {
     ContextGuard guard(context);
 
@@ -2645,7 +2649,7 @@ bool cuda_graph_begin_capture(void* context, void* stream, int external)
     return true;
 }
 
-bool cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
+bool wp_cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
 {
     ContextGuard guard(context);
 
@@ -2780,14 +2784,14 @@ bool cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
     return true;
 }
 
-bool capture_debug_dot_print(void* graph, const char *path, uint32_t flags)
+bool wp_capture_debug_dot_print(void* graph, const char *path, uint32_t flags)
 {
     if (!check_cuda(cudaGraphDebugDotPrint((cudaGraph_t)graph, path, flags)))
         return false;
     return true;
 }
 
-bool cuda_graph_create_exec(void* context, void* stream, void* graph, void** graph_exec_ret)
+bool wp_cuda_graph_create_exec(void* context, void* stream, void* graph, void** graph_exec_ret)
 {
     ContextGuard guard(context);
 
@@ -2811,11 +2815,12 @@ bool cuda_graph_create_exec(void* context, void* stream, void* graph, void** gra
 // Support for conditional graph nodes available with CUDA 12.4+.
 #if CUDA_VERSION >= 12040
 
-// CUBIN data for compiled conditional modules, loaded on demand, keyed on device architecture
-static std::map<int, void*> g_conditional_cubins;
+// CUBIN or PTX data for compiled conditional modules, loaded on demand, keyed on device architecture
+using ModuleKey = std::pair<int, bool>; // <arch, use_ptx>
+static std::map<ModuleKey, void*> g_conditional_modules;
 
 // Compile module with conditional helper kernels
-static void* compile_conditional_module(int arch)
+static void* compile_conditional_module(int arch, bool use_ptx)
 {
     static const char* kernel_source = R"(
 typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
@@ -2844,8 +2849,9 @@ static void* compile_conditional_module(int arch)
 )";
 
     // avoid recompilation
-    auto it = g_conditional_cubins.find(arch);
-    if (it != g_conditional_cubins.end())
+    ModuleKey key = {arch, use_ptx};
+    auto it = g_conditional_modules.find(key);
+    if (it != g_conditional_modules.end())
         return it->second;
 
     nvrtcProgram prog;
@@ -2853,11 +2859,23 @@
         return NULL;
 
     char arch_opt[128];
-    snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=sm_%d", arch);
+    if (use_ptx)
+        snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=compute_%d", arch);
+    else
+        snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=sm_%d", arch);
 
     std::vector<const char*> opts;
     opts.push_back(arch_opt);
 
+    const bool print_debug = (std::getenv("WARP_DEBUG") != nullptr);
+    if (print_debug)
+    {
+        printf("NVRTC options (conditional module, arch=%d, use_ptx=%s):\n", arch, use_ptx ? "true" : "false");
+        for (auto o : opts) {
+            printf("%s\n", o);
+        }
+    }
+
     if (!check_nvrtc(nvrtcCompileProgram(prog, int(opts.size()), opts.data())))
     {
         size_t log_size;
@@ -2874,23 +2892,37 @@ static void* compile_conditional_module(int arch)
     // get output
     char* output = NULL;
     size_t output_size = 0;
-    check_nvrtc(nvrtcGetCUBINSize(prog, &output_size));
-    if (output_size > 0)
+
+    if (use_ptx)
+    {
+        check_nvrtc(nvrtcGetPTXSize(prog, &output_size));
+        if (output_size > 0)
+        {
+            output = new char[output_size];
+            if (check_nvrtc(nvrtcGetPTX(prog, output)))
+                g_conditional_modules[key] = output;
+        }
+    }
+    else
     {
-        output = new char[output_size];
-        if (check_nvrtc(nvrtcGetCUBIN(prog, output)))
-            g_conditional_cubins[arch] = output;
+        check_nvrtc(nvrtcGetCUBINSize(prog, &output_size));
+        if (output_size > 0)
+        {
+            output = new char[output_size];
+            if (check_nvrtc(nvrtcGetCUBIN(prog, output)))
+                g_conditional_modules[key] = output;
+        }
     }
 
     nvrtcDestroyProgram(&prog);
 
-    // return CUBIN data
+    // return CUBIN or PTX data
    return output;
 }
 
 
 // Load module with conditional helper kernels
-static CUmodule load_conditional_module(void* context)
+static CUmodule load_conditional_module(void* context, int arch, bool use_ptx)
 {
     ContextInfo* context_info = get_context_info(context);
     if (!context_info)
@@ -2900,17 +2932,15 @@ static CUmodule load_conditional_module(void* context)
     if (context_info->conditional_module)
         return context_info->conditional_module;
 
-    int arch = context_info->device_info->arch;
-
     // compile if needed
-    void* compiled_module = compile_conditional_module(arch);
+    void* compiled_module = compile_conditional_module(arch, use_ptx);
     if (!compiled_module)
     {
         fprintf(stderr, "Warp error: Failed to compile conditional kernels\n");
         return NULL;
     }
 
-    // load module
+    // load module (handles both PTX and CUBIN data automatically)
     CUmodule module = NULL;
     if (!check_cu(cuModuleLoadDataEx_f(&module, compiled_module, 0, NULL, NULL)))
    {
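For context: compiling to PTX (--gpu-architecture=compute_XX) yields forward-compatible intermediate code that the driver JIT-compiles for whatever GPU is present, while CUBIN (sm_XX) is native machine code for one architecture; cuModuleLoadDataEx() accepts either image. A compact standalone sketch of the same switch using plain NVRTC (illustrative; error handling elided):

    #include <nvrtc.h>
    #include <cstdio>
    #include <cstddef>

    // Compile src to PTX or CUBIN; caller frees the returned buffer with delete[].
    char* compile(const char* src, int arch, bool use_ptx, size_t* size_out)
    {
        nvrtcProgram prog;
        nvrtcCreateProgram(&prog, src, "demo.cu", 0, nullptr, nullptr);

        char arch_opt[64];
        // compute_XX -> PTX (JIT-compiled by the driver), sm_XX -> CUBIN (native SASS)
        snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=%s_%d",
                 use_ptx ? "compute" : "sm", arch);
        const char* opts[] = { arch_opt };
        nvrtcCompileProgram(prog, 1, opts);

        size_t size = 0;
        if (use_ptx)
            nvrtcGetPTXSize(prog, &size);
        else
            nvrtcGetCUBINSize(prog, &size);

        char* image = new char[size];
        if (use_ptx)
            nvrtcGetPTX(prog, image);
        else
            nvrtcGetCUBIN(prog, image);

        nvrtcDestroyProgram(&prog);
        *size_out = size;
        return image;  // load with cuModuleLoadDataEx(), as above
    }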
@@ -2923,10 +2953,10 @@ static CUmodule load_conditional_module(void* context)
     return module;
 }
 
-static CUfunction get_conditional_kernel(void* context, const char* name)
+static CUfunction get_conditional_kernel(void* context, int arch, bool use_ptx, const char* name)
 {
     // load module if needed
-    CUmodule module = load_conditional_module(context);
+    CUmodule module = load_conditional_module(context, arch, use_ptx);
     if (!module)
         return NULL;
 
@@ -2940,7 +2970,7 @@ static CUfunction get_conditional_kernel(void* context, const char* name)
     return kernel;
 }
 
-bool cuda_graph_pause_capture(void* context, void* stream, void** graph_ret)
+bool wp_cuda_graph_pause_capture(void* context, void* stream, void** graph_ret)
 {
     ContextGuard guard(context);
 
@@ -2950,7 +2980,7 @@ bool cuda_graph_pause_capture(void* context, void* stream, void** graph_ret)
     return true;
 }
 
-bool cuda_graph_resume_capture(void* context, void* stream, void* graph)
+bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
 {
     ContextGuard guard(context);
 
@@ -2976,7 +3006,7 @@
 // https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes/
 // condition is a gpu pointer
 // if_graph_ret and else_graph_ret should be NULL if not needed
-bool cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret)
+bool wp_cuda_graph_insert_if_else(void* context, void* stream, int arch, bool use_ptx, int* condition, void** if_graph_ret, void** else_graph_ret)
 {
     bool has_if = if_graph_ret != NULL;
     bool has_else = else_graph_ret != NULL;
@@ -2991,21 +3021,21 @@ bool cuda_graph_insert_if_else(void* context, void* stream, int* condition, void
     CUstream cuda_stream = static_cast<CUstream>(stream);
 
     // Get the current stream capturing graph
-    cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;
+    CUstreamCaptureStatus capture_status = CU_STREAM_CAPTURE_STATUS_NONE;
     cudaGraph_t cuda_graph = NULL;
     const cudaGraphNode_t* capture_deps = NULL;
     size_t dep_count = 0;
-    if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+    if (!check_cu(cuStreamGetCaptureInfo_f(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
         return false;
 
     // abort if not capturing
-    if (!cuda_graph || capture_status != cudaStreamCaptureStatusActive)
+    if (!cuda_graph || capture_status != CU_STREAM_CAPTURE_STATUS_ACTIVE)
     {
         wp::set_error_string("Stream is not capturing");
         return false;
     }
 
-    //int driver_version = cuda_driver_version();
+    //int driver_version = wp_cuda_driver_version();
 
     // IF-ELSE nodes are only supported with CUDA 12.8+
     // Somehow child graphs produce wrong results when an else branch is used
@@ -3013,15 +3043,15 @@ bool cuda_graph_insert_if_else(void* context, void* stream, int* condition, void
     if (num_branches == 1 /*|| driver_version >= 12080*/)
     {
         cudaGraphConditionalHandle handle;
-        cudaGraphConditionalHandleCreate(&handle, cuda_graph);
+        check_cuda(cudaGraphConditionalHandleCreate(&handle, cuda_graph));
 
         // run a kernel to set the condition handle from the condition pointer
         // (need to negate the condition if only the else branch is used)
         CUfunction kernel;
         if (has_if)
-            kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+            kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
         else
-            kernel = get_conditional_kernel(context, "set_conditional_else_handle_kernel");
+            kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_else_handle_kernel");
 
         if (!kernel)
         {
@@ -3033,22 +3063,23 @@ bool cuda_graph_insert_if_else(void* context, void* stream, int* condition, void
         kernel_args[0] = &handle;
         kernel_args[1] = &condition;
 
-        if (!check_cuda(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
+        if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
             return false;
 
-        if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+        if (!check_cu(cuStreamGetCaptureInfo_f(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
            return false;
 
         // create conditional node
-        cudaGraphNode_t condition_node;
-        cudaGraphNodeParams condition_params = { cudaGraphNodeTypeConditional };
+        CUgraphNode condition_node;
+        CUgraphNodeParams condition_params = { CU_GRAPH_NODE_TYPE_CONDITIONAL };
         condition_params.conditional.handle = handle;
-        condition_params.conditional.type = cudaGraphCondTypeIf;
+        condition_params.conditional.type = CU_GRAPH_COND_TYPE_IF;
         condition_params.conditional.size = num_branches;
-        if (!check_cuda(cudaGraphAddNode(&condition_node, cuda_graph, capture_deps, dep_count, &condition_params)))
+        condition_params.conditional.ctx = get_current_context();
+        if (!check_cu(cuGraphAddNode_f(&condition_node, cuda_graph, capture_deps, NULL, dep_count, &condition_params)))
             return false;
 
-        if (!check_cuda(cudaStreamUpdateCaptureDependencies(cuda_stream, &condition_node, 1, cudaStreamSetCaptureDependencies)))
+        if (!check_cu(cuStreamUpdateCaptureDependencies_f(cuda_stream, &condition_node, 1, cudaStreamSetCaptureDependencies)))
             return false;
 
         if (num_branches == 1)
@@ -3068,10 +3099,10 @@ bool cuda_graph_insert_if_else(void* context, void* stream, int* condition, void
     {
         // Create IF node followed by an additional IF node with negated condition
         cudaGraphConditionalHandle if_handle, else_handle;
-        cudaGraphConditionalHandleCreate(&if_handle, cuda_graph);
-        cudaGraphConditionalHandleCreate(&else_handle, cuda_graph);
+        check_cuda(cudaGraphConditionalHandleCreate(&if_handle, cuda_graph));
+        check_cuda(cudaGraphConditionalHandleCreate(&else_handle, cuda_graph));
 
-        CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_else_handles_kernel");
+        CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_else_handles_kernel");
         if (!kernel)
         {
             wp::set_error_string("Failed to get built-in conditional kernel");
@@ -3086,26 +3117,28 @@ bool cuda_graph_insert_if_else(void* context, void* stream, int* condition, void
         if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
             return false;
 
-        if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+        if (!check_cu(cuStreamGetCaptureInfo_f(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
            return false;
 
-        cudaGraphNode_t if_node;
-        cudaGraphNodeParams if_params = { cudaGraphNodeTypeConditional };
+        CUgraphNode if_node;
+        CUgraphNodeParams if_params = { CU_GRAPH_NODE_TYPE_CONDITIONAL };
         if_params.conditional.handle = if_handle;
-        if_params.conditional.type = cudaGraphCondTypeIf;
+        if_params.conditional.type = CU_GRAPH_COND_TYPE_IF;
         if_params.conditional.size = 1;
-        if (!check_cuda(cudaGraphAddNode(&if_node, cuda_graph, capture_deps, dep_count, &if_params)))
+        if_params.conditional.ctx = get_current_context();
+        if (!check_cu(cuGraphAddNode_f(&if_node, cuda_graph, capture_deps, NULL, dep_count, &if_params)))
            return false;
 
-        cudaGraphNode_t else_node;
-        cudaGraphNodeParams else_params = { cudaGraphNodeTypeConditional };
+        CUgraphNode else_node;
+        CUgraphNodeParams else_params = { CU_GRAPH_NODE_TYPE_CONDITIONAL };
         else_params.conditional.handle = else_handle;
-        else_params.conditional.type = cudaGraphCondTypeIf;
+        else_params.conditional.type = CU_GRAPH_COND_TYPE_IF;
         else_params.conditional.size = 1;
-        if (!check_cuda(cudaGraphAddNode(&else_node, cuda_graph, &if_node, 1, &else_params)))
+        else_params.conditional.ctx = get_current_context();
+        if (!check_cu(cuGraphAddNode_f(&else_node, cuda_graph, &if_node, NULL, 1, &else_params)))
            return false;
 
-        if (!check_cuda(cudaStreamUpdateCaptureDependencies(cuda_stream, &else_node, 1, cudaStreamSetCaptureDependencies)))
+        if (!check_cu(cuStreamUpdateCaptureDependencies_f(cuda_stream, &else_node, 1, cudaStreamSetCaptureDependencies)))
            return false;
 
         *if_graph_ret = if_params.conditional.phGraph_out[0];
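For context: conditional nodes (CUDA 12.4+) let a captured graph branch on a device-side value at graph execution time, without a round trip to the host; a kernel sets the condition handle via cudaGraphSetConditional(). A minimal runtime-API sketch of building an IF node directly, outside of stream capture (illustrative only; see the NVIDIA blog post linked above):

    #include <cuda_runtime.h>

    __global__ void set_handle(cudaGraphConditionalHandle handle, const int* cond)
    {
        cudaGraphSetConditional(handle, *cond);  // IF body runs iff *cond != 0
    }

    void build_if_graph(const int* d_cond)
    {
        cudaGraph_t graph;
        cudaGraphCreate(&graph, 0);

        cudaGraphConditionalHandle handle;
        cudaGraphConditionalHandleCreate(&handle, graph);

        // ... add a kernel node that launches set_handle(handle, d_cond) ...

        cudaGraphNode_t node;
        cudaGraphNodeParams params = { cudaGraphNodeTypeConditional };
        params.conditional.handle = handle;
        params.conditional.type = cudaGraphCondTypeIf;
        params.conditional.size = 1;
        cudaGraphAddNode(&node, graph, nullptr, 0, &params);

        // the body to execute when the condition is true:
        cudaGraph_t body = params.conditional.phGraph_out[0];
        (void)body;  // populate with the branch's work, then instantiate the parent graph
    }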
@@ -3115,21 +3148,143 @@ bool cuda_graph_insert_if_else(void* context, void* stream, int* condition, void
     return true;
 }
 
-bool cuda_graph_insert_child_graph(void* context, void* stream, void* child_graph)
+// graph node type names for intelligible error reporting
+static const char* get_graph_node_type_name(CUgraphNodeType type)
+{
+    static const std::unordered_map<CUgraphNodeType, const char*> names
+    {
+        {CU_GRAPH_NODE_TYPE_KERNEL, "kernel launch"},
+        {CU_GRAPH_NODE_TYPE_MEMCPY, "memcpy"},
+        {CU_GRAPH_NODE_TYPE_MEMSET, "memset"},
+        {CU_GRAPH_NODE_TYPE_HOST, "host execution"},
+        {CU_GRAPH_NODE_TYPE_GRAPH, "graph launch"},
+        {CU_GRAPH_NODE_TYPE_EMPTY, "empty node"},
+        {CU_GRAPH_NODE_TYPE_WAIT_EVENT, "event wait"},
+        {CU_GRAPH_NODE_TYPE_EVENT_RECORD, "event record"},
+        {CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL, "semaphore signal"},
+        {CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT, "semaphore wait"},
+        {CU_GRAPH_NODE_TYPE_MEM_ALLOC, "memory allocation"},
+        {CU_GRAPH_NODE_TYPE_MEM_FREE, "memory deallocation"},
+        {CU_GRAPH_NODE_TYPE_BATCH_MEM_OP, "batched mem op"},
+        {CU_GRAPH_NODE_TYPE_CONDITIONAL, "conditional node"},
+    };
+
+    auto it = names.find(type);
+    if (it != names.end())
+        return it->second;
+    else
+        return "unknown node";
+}
+
+// check if a graph can be launched as a child graph
+static bool is_valid_child_graph(void* child_graph)
+{
+    // disallowed child graph nodes according to the documentation of cuGraphAddChildGraphNode()
+    static const std::unordered_set<CUgraphNodeType> disallowed_nodes
+    {
+        CU_GRAPH_NODE_TYPE_MEM_ALLOC,
+        CU_GRAPH_NODE_TYPE_MEM_FREE,
+        CU_GRAPH_NODE_TYPE_CONDITIONAL,
+    };
+
+    if (!child_graph)
+    {
+        wp::set_error_string("Child graph is null");
+        return false;
+    }
+
+    size_t num_nodes = 0;
+    if (!check_cuda(cudaGraphGetNodes((cudaGraph_t)child_graph, NULL, &num_nodes)))
+        return false;
+    std::vector<cudaGraphNode_t> nodes(num_nodes);
+    if (!check_cuda(cudaGraphGetNodes((cudaGraph_t)child_graph, nodes.data(), &num_nodes)))
+        return false;
+
+    for (size_t i = 0; i < num_nodes; i++)
+    {
+        // note: we use the driver API to get the node type, otherwise some nodes are not recognized correctly
+        CUgraphNodeType node_type;
+        check_cu(cuGraphNodeGetType_f(nodes[i], &node_type));
+        auto it = disallowed_nodes.find(node_type);
+        if (it != disallowed_nodes.end())
+        {
+            wp::set_error_string("Child graph contains an unsupported operation (%s)", get_graph_node_type_name(node_type));
+            return false;
+        }
+    }
+
+    return true;
+}
+
+// check if a graph can be used as a conditional body graph
+// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#condtional-node-body-graph-requirements
+bool wp_cuda_graph_check_conditional_body(void* body_graph)
 {
+    static const std::unordered_set<CUgraphNodeType> allowed_nodes
+    {
+        CU_GRAPH_NODE_TYPE_MEMCPY,
+        CU_GRAPH_NODE_TYPE_MEMSET,
+        CU_GRAPH_NODE_TYPE_KERNEL,
+        CU_GRAPH_NODE_TYPE_GRAPH,
+        CU_GRAPH_NODE_TYPE_EMPTY,
+        CU_GRAPH_NODE_TYPE_CONDITIONAL,
+    };
+
+    if (!body_graph)
+    {
+        wp::set_error_string("Conditional body graph is null");
+        return false;
+    }
+
+    size_t num_nodes = 0;
+    if (!check_cuda(cudaGraphGetNodes((cudaGraph_t)body_graph, NULL, &num_nodes)))
+        return false;
+    std::vector<cudaGraphNode_t> nodes(num_nodes);
+    if (!check_cuda(cudaGraphGetNodes((cudaGraph_t)body_graph, nodes.data(), &num_nodes)))
+        return false;
+
+    for (size_t i = 0; i < num_nodes; i++)
+    {
+        // note: we use the driver API to get the node type, otherwise some nodes are not recognized correctly
+        CUgraphNodeType node_type;
+        check_cu(cuGraphNodeGetType_f(nodes[i], &node_type));
+        if (allowed_nodes.find(node_type) == allowed_nodes.end())
+        {
+            wp::set_error_string("Conditional body graph contains an unsupported operation (%s)", get_graph_node_type_name(node_type));
+            return false;
+        }
+        else if (node_type == CU_GRAPH_NODE_TYPE_GRAPH)
+        {
+            // check nested child graphs recursively
+            cudaGraph_t child_graph = NULL;
+            if (!check_cuda(cudaGraphChildGraphNodeGetGraph(nodes[i], &child_graph)))
+                return false;
+            if (!wp_cuda_graph_check_conditional_body(child_graph))
+                return false;
+        }
+    }
+
+    return true;
+}
+
+bool wp_cuda_graph_insert_child_graph(void* context, void* stream, void* child_graph)
+{
+    if (!is_valid_child_graph(child_graph))
+        return false;
+
     ContextGuard guard(context);
 
     CUstream cuda_stream = static_cast<CUstream>(stream);
 
     // Get the current stream capturing graph
-    cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;
+    CUstreamCaptureStatus capture_status = CU_STREAM_CAPTURE_STATUS_NONE;
     void* cuda_graph = NULL;
-    const cudaGraphNode_t* capture_deps = NULL;
+    const CUgraphNode* capture_deps = NULL;
     size_t dep_count = 0;
-    if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, (cudaGraph_t*)&cuda_graph, &capture_deps, &dep_count)))
+    if (!check_cu(cuStreamGetCaptureInfo_f(cuda_stream, &capture_status, nullptr, (cudaGraph_t*)&cuda_graph, &capture_deps, &dep_count)))
         return false;
 
-    if (!cuda_graph_pause_capture(context, cuda_stream, &cuda_graph))
+    if (!wp_cuda_graph_pause_capture(context, cuda_stream, &cuda_graph))
        return false;
 
     cudaGraphNode_t body_node;
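A hypothetical host-side use of the new validator (sketch, not lifted from this diff): it walks the captured graph's node list, recursing into nested child graphs, so unsupported operations are reported by name at capture time instead of failing opaquely at instantiation.

    cudaGraph_t body = NULL;
    check_cuda(cudaStreamEndCapture(stream, &body));
    if (!wp_cuda_graph_check_conditional_body(body))
    {
        // the error string names the offending node type,
        // e.g. "memory allocation" or "event record"
        return false;
    }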
@@ -3139,16 +3294,16 @@ bool cuda_graph_insert_child_graph(void* context, void* stream, void* child_grap
         static_cast<cudaGraph_t>(child_graph))))
         return false;
 
-    if (!cuda_graph_resume_capture(context, cuda_stream, cuda_graph))
+    if (!wp_cuda_graph_resume_capture(context, cuda_stream, cuda_graph))
         return false;
 
-    if (!check_cuda(cudaStreamUpdateCaptureDependencies(cuda_stream, &body_node, 1, cudaStreamSetCaptureDependencies)))
+    if (!check_cu(cuStreamUpdateCaptureDependencies_f(cuda_stream, &body_node, 1, cudaStreamSetCaptureDependencies)))
         return false;
 
     return true;
 }
 
-bool cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret)
+bool wp_cuda_graph_insert_while(void* context, void* stream, int arch, bool use_ptx, int* condition, void** body_graph_ret, uint64_t* handle_ret)
 {
     // if there's no body, it's a no-op
     if (!body_graph_ret)
@@ -3159,15 +3314,15 @@ bool cuda_graph_insert_while(void* context, void* stream, int* condition, void**
     CUstream cuda_stream = static_cast<CUstream>(stream);
 
     // Get the current stream capturing graph
-    cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;
+    CUstreamCaptureStatus capture_status = CU_STREAM_CAPTURE_STATUS_NONE;
     cudaGraph_t cuda_graph = NULL;
     const cudaGraphNode_t* capture_deps = NULL;
     size_t dep_count = 0;
-    if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+    if (!check_cu(cuStreamGetCaptureInfo_f(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
         return false;
 
     // abort if not capturing
-    if (!cuda_graph || capture_status != cudaStreamCaptureStatusActive)
+    if (!cuda_graph || capture_status != CU_STREAM_CAPTURE_STATUS_ACTIVE)
     {
         wp::set_error_string("Stream is not capturing");
         return false;
@@ -3178,7 +3333,7 @@ bool cuda_graph_insert_while(void* context, void* stream, int* condition, void**
         return false;
 
     // launch a kernel to set the condition handle from condition pointer
-    CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+    CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
     if (!kernel)
     {
         wp::set_error_string("Failed to get built-in conditional kernel");
@@ -3192,19 +3347,20 @@ bool cuda_graph_insert_while(void* context, void* stream, int* condition, void**
     if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
         return false;
 
-    if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+    if (!check_cu(cuStreamGetCaptureInfo_f(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
         return false;
 
     // insert conditional graph node
-    cudaGraphNode_t while_node;
-    cudaGraphNodeParams while_params = { cudaGraphNodeTypeConditional };
+    CUgraphNode while_node;
+    CUgraphNodeParams while_params = { CU_GRAPH_NODE_TYPE_CONDITIONAL };
     while_params.conditional.handle = handle;
-    while_params.conditional.type = cudaGraphCondTypeWhile;
+    while_params.conditional.type = CU_GRAPH_COND_TYPE_WHILE;
     while_params.conditional.size = 1;
-    if (!check_cuda(cudaGraphAddNode(&while_node, cuda_graph, capture_deps, dep_count, &while_params)))
+    while_params.conditional.ctx = get_current_context();
+    if (!check_cu(cuGraphAddNode_f(&while_node, cuda_graph, capture_deps, NULL, dep_count, &while_params)))
         return false;
 
-    if (!check_cuda(cudaStreamUpdateCaptureDependencies(cuda_stream, &while_node, 1, cudaStreamSetCaptureDependencies)))
+    if (!check_cu(cuStreamUpdateCaptureDependencies_f(cuda_stream, &while_node, 1, cudaStreamSetCaptureDependencies)))
         return false;
 
     *body_graph_ret = while_params.conditional.phGraph_out[0];
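How these entry points compose into a captured while-loop (an illustrative sketch using names from this file, not Warp's public API): the WHILE node re-evaluates its handle before each iteration, so the captured body must refresh the condition as its last step.

    void* body_graph = NULL;
    uint64_t handle = 0;
    // inserts the WHILE node and returns the (initially empty) body graph
    wp_cuda_graph_insert_while(context, stream, arch, use_ptx, condition, &body_graph, &handle);
    // ... capture the loop-body work into body_graph ...
    // update the handle so the WHILE node knows whether to run another iteration
    wp_cuda_graph_set_condition(context, stream, arch, use_ptx, condition, handle);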
@@ -3213,14 +3369,14 @@ bool cuda_graph_insert_while(void* context, void* stream, int* condition, void**
     return true;
 }
 
-bool cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle)
+bool wp_cuda_graph_set_condition(void* context, void* stream, int arch, bool use_ptx, int* condition, uint64_t handle)
 {
     ContextGuard guard(context);
 
     CUstream cuda_stream = static_cast<CUstream>(stream);
 
     // launch a kernel to set the condition handle from condition pointer
-    CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+    CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
     if (!kernel)
     {
         wp::set_error_string("Failed to get built-in conditional kernel");
@@ -3240,37 +3396,43 @@ bool cuda_graph_set_condition(void* context, void* stream, int* condition, uint6
 #else
 // stubs for conditional graph node API if CUDA toolkit is too old.
 
-bool cuda_graph_pause_capture(void* context, void* stream, void** graph_ret)
+bool wp_cuda_graph_pause_capture(void* context, void* stream, void** graph_ret)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
 {
     wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
     return false;
 }
 
-bool cuda_graph_resume_capture(void* context, void* stream, void* graph)
+bool wp_cuda_graph_insert_if_else(void* context, void* stream, int arch, bool use_ptx, int* condition, void** if_graph_ret, void** else_graph_ret)
 {
     wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
     return false;
 }
 
-bool cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret)
+bool wp_cuda_graph_insert_while(void* context, void* stream, int arch, bool use_ptx, int* condition, void** body_graph_ret, uint64_t* handle_ret)
 {
     wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
     return false;
 }
 
-bool cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret)
+bool wp_cuda_graph_set_condition(void* context, void* stream, int arch, bool use_ptx, int* condition, uint64_t handle)
 {
     wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
     return false;
 }
 
-bool cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle)
+bool wp_cuda_graph_insert_child_graph(void* context, void* stream, void* child_graph)
 {
     wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
     return false;
 }
 
-bool cuda_graph_insert_child_graph(void* context, void* stream, void* child_graph)
+bool wp_cuda_graph_check_conditional_body(void* body_graph)
 {
     wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
     return false;
@@ -3279,7 +3441,7 @@ bool cuda_graph_insert_child_graph(void* context, void* stream, void* child_grap
 #endif // support for conditional graph nodes
 
 
-bool cuda_graph_launch(void* graph_exec, void* stream)
+bool wp_cuda_graph_launch(void* graph_exec, void* stream)
 {
     // TODO: allow naming graphs?
     begin_cuda_range(WP_TIMING_GRAPH, stream, get_stream_context(stream), "graph");
@@ -3291,14 +3453,14 @@ bool cuda_graph_launch(void* graph_exec, void* stream)
     return result;
 }
 
-bool cuda_graph_destroy(void* context, void* graph)
+bool wp_cuda_graph_destroy(void* context, void* graph)
 {
     ContextGuard guard(context);
 
     return check_cuda(cudaGraphDestroy((cudaGraph_t)graph));
 }
 
-bool cuda_graph_exec_destroy(void* context, void* graph_exec)
+bool wp_cuda_graph_exec_destroy(void* context, void* graph_exec)
 {
     ContextGuard guard(context);
 
@@ -3350,7 +3512,7 @@ bool write_file(const char* data, size_t size, std::string filename, const char*
 }
 #endif
 
-size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, bool compile_time_trace, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes, int* ltoir_input_types)
+size_t wp_cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, bool compile_time_trace, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes, int* ltoir_input_types)
 {
     // use file extension to determine whether to output PTX or CUBIN
     const char* output_ext = strrchr(output_path, '.');
@@ -3406,9 +3568,9 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
     {
         opts.push_back("--define-macro=_DEBUG");
         opts.push_back("--generate-line-info");
-
-        // disabling since it causes issues with `Unresolved extern function 'cudaGetParameterBufferV2'
-        //opts.push_back("--device-debug");
+#ifndef _WIN32
+        opts.push_back("--device-debug"); // -G
+#endif
     }
     else
     {
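The debug path now enables full device debugging (-G) through NVRTC on non-Windows platforms instead of leaving it commented out. A self-contained sketch of the same option handling, using plain NVRTC and nothing Warp-specific:

    #include <nvrtc.h>
    #include <vector>

    nvrtcResult compile_debug_kernel()
    {
        nvrtcProgram prog;
        // error checking elided for brevity
        nvrtcCreateProgram(&prog, "__global__ void k() {}", "k.cu", 0, NULL, NULL);

        std::vector<const char*> opts;
        opts.push_back("--define-macro=_DEBUG");
        opts.push_back("--generate-line-info");
    #ifndef _WIN32
        // -G: disables device optimizations and embeds full debug info
        opts.push_back("--device-debug");
    #endif
        return nvrtcCompileProgram(prog, (int)opts.size(), opts.data());
    }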
@@ -3678,7 +3840,7 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
     }
 }
 
-bool cuda_compile_fft(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int size, int elements_per_thread, int direction, int precision, int* shared_memory_size)
+bool wp_cuda_compile_fft(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int size, int elements_per_thread, int direction, int precision, int* shared_memory_size)
 {
 
     CHECK_ANY(ltoir_output_path != nullptr);
@@ -3724,7 +3886,7 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
     return res;
 }
 
-bool cuda_compile_dot(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int K, int precision_A, int precision_B, int precision_C, int type, int arrangement_A, int arrangement_B, int arrangement_C, int num_threads)
+bool wp_cuda_compile_dot(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int K, int precision_A, int precision_B, int precision_C, int type, int arrangement_A, int arrangement_B, int arrangement_C, int num_threads)
 {
 
     CHECK_ANY(ltoir_output_path != nullptr);
@@ -3769,7 +3931,7 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
     return res;
 }
 
-bool cuda_compile_solver(const char* fatbin_output_path, const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int NRHS, int function, int side, int diag, int precision, int arrangement_A, int arrangement_B, int fill_mode, int num_threads)
+bool wp_cuda_compile_solver(const char* fatbin_output_path, const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int NRHS, int function, int side, int diag, int precision, int arrangement_A, int arrangement_B, int fill_mode, int num_threads)
 {
 
     CHECK_ANY(ltoir_output_path != nullptr);
@@ -3832,7 +3994,7 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
 
 #endif
 
-void* cuda_load_module(void* context, const char* path)
+void* wp_cuda_load_module(void* context, const char* path)
 {
     ContextGuard guard(context);
 
@@ -3951,7 +4113,7 @@ void* cuda_load_module(void* context, const char* path)
     return module;
 }
 
-void cuda_unload_module(void* context, void* module)
+void wp_cuda_unload_module(void* context, void* module)
 {
     // ensure there are no graph captures in progress
     if (g_captures.empty())
@@ -3970,7 +4132,7 @@ void cuda_unload_module(void* context, void* module)
 }
 
 
-int cuda_get_max_shared_memory(void* context)
+int wp_cuda_get_max_shared_memory(void* context)
 {
     ContextInfo* info = get_context_info(context);
     if (!info)
@@ -3980,7 +4142,7 @@ int cuda_get_max_shared_memory(void* context)
     return max_smem_bytes;
 }
 
-bool cuda_configure_kernel_shared_memory(void* kernel, int size)
+bool wp_cuda_configure_kernel_shared_memory(void* kernel, int size)
 {
     int requested_smem_bytes = size;
 
@@ -3992,7 +4154,7 @@ bool cuda_configure_kernel_shared_memory(void* kernel, int size)
     return true;
 }
 
-void* cuda_get_kernel(void* context, void* module, const char* name)
+void* wp_cuda_get_kernel(void* context, void* module, const char* name)
 {
     ContextGuard guard(context);
 
@@ -4007,7 +4169,7 @@ void* cuda_get_kernel(void* context, void* module, const char* name)
     return kernel;
 }
 
-size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, int block_dim, int shared_memory_bytes, void** args, void* stream)
+size_t wp_cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, int block_dim, int shared_memory_bytes, void** args, void* stream)
 {
     ContextGuard guard(context);
 
@@ -4061,21 +4223,21 @@ size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_block
     return res;
 }
 
-void cuda_graphics_map(void* context, void* resource)
+void wp_cuda_graphics_map(void* context, void* resource)
 {
     ContextGuard guard(context);
 
     check_cu(cuGraphicsMapResources_f(1, (CUgraphicsResource*)resource, get_current_stream()));
 }
 
-void cuda_graphics_unmap(void* context, void* resource)
+void wp_cuda_graphics_unmap(void* context, void* resource)
 {
     ContextGuard guard(context);
 
     check_cu(cuGraphicsUnmapResources_f(1, (CUgraphicsResource*)resource, get_current_stream()));
 }
 
-void cuda_graphics_device_ptr_and_size(void* context, void* resource, uint64_t* ptr, size_t* size)
+void wp_cuda_graphics_device_ptr_and_size(void* context, void* resource, uint64_t* ptr, size_t* size)
 {
     ContextGuard guard(context);
 
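These graphics-interop wrappers follow the standard register, map, access, unmap, unregister lifecycle; a hypothetical call sequence (variable names and flag choice illustrative, not from this diff):

    // register an existing OpenGL buffer object with CUDA
    void* res = wp_cuda_graphics_register_gl_buffer(context, gl_buffer_id, cudaGraphicsRegisterFlagsNone);

    wp_cuda_graphics_map(context, res);
    uint64_t ptr = 0;
    size_t size = 0;
    wp_cuda_graphics_device_ptr_and_size(context, res, &ptr, &size);
    // ... read/write the buffer through ptr as ordinary device memory ...
    wp_cuda_graphics_unmap(context, res);

    wp_cuda_graphics_unregister_resource(context, res);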
@@ -4087,7 +4249,7 @@ void cuda_graphics_device_ptr_and_size(void* context, void* resource, uint64_t*
     *size = bytes;
 }
 
-void* cuda_graphics_register_gl_buffer(void* context, uint32_t gl_buffer, unsigned int flags)
+void* wp_cuda_graphics_register_gl_buffer(void* context, uint32_t gl_buffer, unsigned int flags)
 {
     ContextGuard guard(context);
 
@@ -4102,7 +4264,7 @@ void* cuda_graphics_register_gl_buffer(void* context, uint32_t gl_buffer, unsign
     return resource;
 }
 
-void cuda_graphics_unregister_resource(void* context, void* resource)
+void wp_cuda_graphics_unregister_resource(void* context, void* resource)
 {
     ContextGuard guard(context);
 
@@ -4111,25 +4273,25 @@ void cuda_graphics_unregister_resource(void* context, void* resource)
     delete res;
 }
 
-void cuda_timing_begin(int flags)
+void wp_cuda_timing_begin(int flags)
 {
     g_cuda_timing_state = new CudaTimingState(flags, g_cuda_timing_state);
 }
 
-int cuda_timing_get_result_count()
+int wp_cuda_timing_get_result_count()
 {
     if (g_cuda_timing_state)
         return int(g_cuda_timing_state->ranges.size());
     return 0;
 }
 
-void cuda_timing_end(timing_result_t* results, int size)
+void wp_cuda_timing_end(timing_result_t* results, int size)
 {
     if (!g_cuda_timing_state)
         return;
 
     // number of results to write to the user buffer
-    int count = std::min(cuda_timing_get_result_count(), size);
+    int count = std::min(wp_cuda_timing_get_result_count(), size);
 
     // compute timings and write results
     for (int i = 0; i < count; i++)
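Such timing ranges are conventionally resolved with CUDA events once the recorded work has completed; an illustrative fragment of the per-range computation (the range and result field names are assumptions, not taken from this file):

    float elapsed_ms = 0.0f;
    cudaEventSynchronize(range.end_event);  // ensure the range has finished executing
    cudaEventElapsedTime(&elapsed_ms, range.start_event, range.end_event);
    results[i].elapsed = elapsed_ms;        // field name hypothetical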
@@ -4163,7 +4325,6 @@ void cuda_timing_end(timing_result_t* results, int size)
 #include "reduce.cu"
 #include "runlength_encode.cu"
 #include "scan.cu"
-#include "marching.cu"
 #include "sparse.cu"
 #include "volume.cu"
 #include "volume_builder.cu"