warp-lang 1.7.2rc1__py3-none-win_amd64.whl → 1.8.1__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (193)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.dll +0 -0
  5. warp/bin/warp.dll +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +130 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +272 -104
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +770 -238
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_callable.py +34 -4
  36. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  37. warp/examples/interop/example_jax_kernel.py +27 -1
  38. warp/examples/optim/example_drone.py +1 -1
  39. warp/examples/sim/example_cloth.py +1 -1
  40. warp/examples/sim/example_cloth_self_contact.py +48 -54
  41. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  42. warp/examples/tile/example_tile_cholesky.py +2 -1
  43. warp/examples/tile/example_tile_convolution.py +1 -1
  44. warp/examples/tile/example_tile_filtering.py +1 -1
  45. warp/examples/tile/example_tile_matmul.py +1 -1
  46. warp/examples/tile/example_tile_mlp.py +2 -0
  47. warp/fabric.py +7 -7
  48. warp/fem/__init__.py +5 -0
  49. warp/fem/adaptivity.py +1 -1
  50. warp/fem/cache.py +152 -63
  51. warp/fem/dirichlet.py +2 -2
  52. warp/fem/domain.py +136 -6
  53. warp/fem/field/field.py +141 -99
  54. warp/fem/field/nodal_field.py +85 -39
  55. warp/fem/field/virtual.py +99 -52
  56. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  57. warp/fem/geometry/closest_point.py +13 -0
  58. warp/fem/geometry/deformed_geometry.py +102 -40
  59. warp/fem/geometry/element.py +56 -2
  60. warp/fem/geometry/geometry.py +323 -22
  61. warp/fem/geometry/grid_2d.py +157 -62
  62. warp/fem/geometry/grid_3d.py +116 -20
  63. warp/fem/geometry/hexmesh.py +86 -20
  64. warp/fem/geometry/nanogrid.py +166 -86
  65. warp/fem/geometry/partition.py +59 -25
  66. warp/fem/geometry/quadmesh.py +86 -135
  67. warp/fem/geometry/tetmesh.py +47 -119
  68. warp/fem/geometry/trimesh.py +77 -270
  69. warp/fem/integrate.py +181 -95
  70. warp/fem/linalg.py +25 -58
  71. warp/fem/operator.py +124 -27
  72. warp/fem/quadrature/pic_quadrature.py +36 -14
  73. warp/fem/quadrature/quadrature.py +40 -16
  74. warp/fem/space/__init__.py +1 -1
  75. warp/fem/space/basis_function_space.py +66 -46
  76. warp/fem/space/basis_space.py +17 -4
  77. warp/fem/space/dof_mapper.py +1 -1
  78. warp/fem/space/function_space.py +2 -2
  79. warp/fem/space/grid_2d_function_space.py +4 -1
  80. warp/fem/space/hexmesh_function_space.py +4 -2
  81. warp/fem/space/nanogrid_function_space.py +3 -1
  82. warp/fem/space/partition.py +11 -2
  83. warp/fem/space/quadmesh_function_space.py +4 -1
  84. warp/fem/space/restriction.py +5 -2
  85. warp/fem/space/shape/__init__.py +10 -8
  86. warp/fem/space/tetmesh_function_space.py +4 -1
  87. warp/fem/space/topology.py +52 -21
  88. warp/fem/space/trimesh_function_space.py +4 -1
  89. warp/fem/utils.py +53 -8
  90. warp/jax.py +1 -2
  91. warp/jax_experimental/ffi.py +210 -67
  92. warp/jax_experimental/xla_ffi.py +37 -24
  93. warp/math.py +171 -1
  94. warp/native/array.h +103 -4
  95. warp/native/builtin.h +182 -35
  96. warp/native/coloring.cpp +6 -2
  97. warp/native/cuda_util.cpp +1 -1
  98. warp/native/exports.h +118 -63
  99. warp/native/intersect.h +5 -5
  100. warp/native/mat.h +8 -13
  101. warp/native/mathdx.cpp +11 -5
  102. warp/native/matnn.h +1 -123
  103. warp/native/mesh.h +1 -1
  104. warp/native/quat.h +34 -6
  105. warp/native/rand.h +7 -7
  106. warp/native/sparse.cpp +121 -258
  107. warp/native/sparse.cu +181 -274
  108. warp/native/spatial.h +305 -17
  109. warp/native/svd.h +23 -8
  110. warp/native/tile.h +603 -73
  111. warp/native/tile_radix_sort.h +1112 -0
  112. warp/native/tile_reduce.h +239 -13
  113. warp/native/tile_scan.h +240 -0
  114. warp/native/tuple.h +189 -0
  115. warp/native/vec.h +10 -20
  116. warp/native/warp.cpp +36 -4
  117. warp/native/warp.cu +588 -52
  118. warp/native/warp.h +47 -74
  119. warp/optim/linear.py +5 -1
  120. warp/paddle.py +7 -8
  121. warp/py.typed +0 -0
  122. warp/render/render_opengl.py +110 -80
  123. warp/render/render_usd.py +124 -62
  124. warp/sim/__init__.py +9 -0
  125. warp/sim/collide.py +253 -80
  126. warp/sim/graph_coloring.py +8 -1
  127. warp/sim/import_mjcf.py +4 -3
  128. warp/sim/import_usd.py +11 -7
  129. warp/sim/integrator.py +5 -2
  130. warp/sim/integrator_euler.py +1 -1
  131. warp/sim/integrator_featherstone.py +1 -1
  132. warp/sim/integrator_vbd.py +761 -322
  133. warp/sim/integrator_xpbd.py +1 -1
  134. warp/sim/model.py +265 -260
  135. warp/sim/utils.py +10 -7
  136. warp/sparse.py +303 -166
  137. warp/tape.py +54 -51
  138. warp/tests/cuda/test_conditional_captures.py +1046 -0
  139. warp/tests/cuda/test_streams.py +1 -1
  140. warp/tests/geometry/test_volume.py +2 -2
  141. warp/tests/interop/test_dlpack.py +9 -9
  142. warp/tests/interop/test_jax.py +0 -1
  143. warp/tests/run_coverage_serial.py +1 -1
  144. warp/tests/sim/disabled_kinematics.py +2 -2
  145. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  146. warp/tests/sim/test_collision.py +159 -51
  147. warp/tests/sim/test_coloring.py +91 -2
  148. warp/tests/test_array.py +254 -2
  149. warp/tests/test_array_reduce.py +2 -2
  150. warp/tests/test_assert.py +53 -0
  151. warp/tests/test_atomic_cas.py +312 -0
  152. warp/tests/test_codegen.py +142 -19
  153. warp/tests/test_conditional.py +47 -1
  154. warp/tests/test_ctypes.py +0 -20
  155. warp/tests/test_devices.py +8 -0
  156. warp/tests/test_fabricarray.py +4 -2
  157. warp/tests/test_fem.py +58 -25
  158. warp/tests/test_func.py +42 -1
  159. warp/tests/test_grad.py +1 -1
  160. warp/tests/test_lerp.py +1 -3
  161. warp/tests/test_map.py +481 -0
  162. warp/tests/test_mat.py +23 -24
  163. warp/tests/test_quat.py +28 -15
  164. warp/tests/test_rounding.py +10 -38
  165. warp/tests/test_runlength_encode.py +7 -7
  166. warp/tests/test_smoothstep.py +1 -1
  167. warp/tests/test_sparse.py +83 -2
  168. warp/tests/test_spatial.py +507 -1
  169. warp/tests/test_static.py +48 -0
  170. warp/tests/test_struct.py +2 -2
  171. warp/tests/test_tape.py +38 -0
  172. warp/tests/test_tuple.py +265 -0
  173. warp/tests/test_types.py +2 -2
  174. warp/tests/test_utils.py +24 -18
  175. warp/tests/test_vec.py +38 -408
  176. warp/tests/test_vec_constructors.py +325 -0
  177. warp/tests/tile/test_tile.py +438 -131
  178. warp/tests/tile/test_tile_mathdx.py +518 -14
  179. warp/tests/tile/test_tile_matmul.py +179 -0
  180. warp/tests/tile/test_tile_reduce.py +307 -5
  181. warp/tests/tile/test_tile_shared_memory.py +136 -7
  182. warp/tests/tile/test_tile_sort.py +121 -0
  183. warp/tests/unittest_suites.py +14 -6
  184. warp/types.py +462 -308
  185. warp/utils.py +647 -86
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  187. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
  188. warp/stubs.py +0 -3381
  189. warp/tests/sim/test_xpbd.py +0 -399
  190. warp/tests/test_mlp.py +0 -282
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  193. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/native/warp.cu CHANGED
@@ -27,6 +27,9 @@
 #if WP_ENABLE_MATHDX
 #include <nvJitLink.h>
 #include <libmathdx.h>
+#include <libcublasdx.h>
+#include <libcufftdx.h>
+#include <libcusolverdx.h>
 #endif
 
 #include <array>
@@ -155,6 +158,7 @@ struct DeviceInfo
     int arch = 0;
     int is_uva = 0;
     int is_mempool_supported = 0;
+    int sm_count = 0;
     int is_ipc_supported = -1;
     int max_smem_bytes = 0;
     CUcontext primary_context = NULL;
@@ -166,6 +170,9 @@ struct ContextInfo
 
     // the current stream, managed from Python (see cuda_context_set_stream() and cuda_context_get_stream())
     CUstream stream = NULL;
+
+    // conditional graph node support, loaded on demand if the driver supports it (CUDA 12.4+)
+    CUmodule conditional_module = NULL;
 };
 
 struct CaptureInfo
@@ -280,6 +287,7 @@ int cuda_init()
         check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_device_id, CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, device));
         check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_uva, CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, device));
         check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_mempool_supported, CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, device));
+        check_cu(cuDeviceGetAttribute_f(&g_devices[i].sm_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
 #ifdef CUDA_VERSION
 #if CUDA_VERSION >= 12000
         int device_attribute_integrated = 0;
@@ -301,7 +309,13 @@ int cuda_init()
         check_cu(cuDeviceGetAttribute_f(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device));
         check_cu(cuDeviceGetAttribute_f(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device));
         g_devices[i].arch = 10 * major + minor;
-
+#ifdef CUDA_VERSION
+#if CUDA_VERSION < 13000
+        if (g_devices[i].arch == 110) {
+            g_devices[i].arch = 101; // Thor SM change
+        }
+#endif
+#endif
         g_device_map[device] = &g_devices[i];
     }
     else
@@ -1786,6 +1800,13 @@ int cuda_device_get_arch(int ordinal)
     return 0;
 }
 
+int cuda_device_get_sm_count(int ordinal)
+{
+    if (ordinal >= 0 && ordinal < int(g_devices.size()))
+        return g_devices[ordinal].sm_count;
+    return 0;
+}
+
 void cuda_device_get_uuid(int ordinal, char uuid[16])
 {
     memcpy(uuid, g_devices[ordinal].uuid.bytes, sizeof(char)*16);
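
The `sm_count` attribute gathered in `cuda_init()` is exposed through the new `cuda_device_get_sm_count()` entry point above. A minimal sketch of how a caller might use it to size a persistent-kernel launch; the four-blocks-per-SM heuristic and the device ordinal are illustrative assumptions, not Warp's actual policy:

    // Sketch only: derive a grid size from the device's multiprocessor count.
    int sm_count = cuda_device_get_sm_count(0);  // 0 is an assumed device ordinal
    int blocks_per_sm = 4;                       // assumed occupancy target for the example
    int num_blocks = (sm_count > 0) ? sm_count * blocks_per_sm : 256;  // getter returns 0 on a bad ordinal
    // launch <<<num_blocks, block_size>>> ...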
@@ -2034,6 +2055,9 @@ void cuda_context_destroy(void* context)
         if (info->stream)
             check_cu(cuStreamDestroy_f(info->stream));
 
+        if (info->conditional_module)
+            check_cu(cuModuleUnload_f(info->conditional_module));
+
         g_contexts.erase(ctx);
     }
 
@@ -2739,22 +2763,10 @@ bool cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
     if (external)
         return true;
 
-    cudaGraphExec_t graph_exec = NULL;
-
     // end the capture
     if (!check_cuda(cudaStreamEndCapture(cuda_stream, &graph)))
         return false;
 
-    // enable to create debug GraphVis visualization of graph
-    // cudaGraphDebugDotPrint(graph, "graph.dot", cudaGraphDebugDotFlagsVerbose);
-
-    // can use after CUDA 11.4 to permit graphs to capture cudaMallocAsync() operations
-    if (!check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch)))
-        return false;
-
-    // free source graph
-    check_cuda(cudaGraphDestroy(graph));
-
     // process deferred free list if no more captures are ongoing
     if (g_captures.empty())
     {
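
With instantiation moved out of `cuda_graph_end_capture()`, capture now returns the raw `cudaGraph_t`; the new `cuda_graph_create_exec()` in the following hunk performs the instantiate-plus-upload step, and `cuda_graph_destroy()`/`cuda_graph_exec_destroy()` further below free the two objects separately. A sketch of the resulting two-phase lifecycle against those entry points, assuming `ctx` and `stream` are Warp's usual context and stream handles (error checking omitted):

    void* graph = NULL;
    void* graph_exec = NULL;

    // ... capture work on the stream, then:
    cuda_graph_end_capture(ctx, stream, &graph);              // now returns the raw cudaGraph_t
    cuda_graph_create_exec(ctx, stream, graph, &graph_exec);  // instantiates and uploads

    cuda_graph_launch(graph_exec, stream);                    // relaunch as often as needed

    cuda_graph_exec_destroy(ctx, graph_exec);                 // the executable instance and the
    cuda_graph_destroy(ctx, graph);                           // source graph are freed separately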
@@ -2763,11 +2775,510 @@ bool cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
     }
 
     if (graph_ret)
-        *graph_ret = graph_exec;
+        *graph_ret = graph;
 
     return true;
 }
 
+bool capture_debug_dot_print(void* graph, const char *path, uint32_t flags)
+{
+    if (!check_cuda(cudaGraphDebugDotPrint((cudaGraph_t)graph, path, flags)))
+        return false;
+    return true;
+}
+
+bool cuda_graph_create_exec(void* context, void* stream, void* graph, void** graph_exec_ret)
+{
+    ContextGuard guard(context);
+
+    cudaGraphExec_t graph_exec = NULL;
+    if (!check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, (cudaGraph_t)graph, cudaGraphInstantiateFlagAutoFreeOnLaunch)))
+        return false;
+
+    // Usually uploading the graph explicitly is optional, but when updating graph nodes (e.g., indirect dispatch)
+    // then the upload is required because otherwise the graph nodes that get updated might not yet be uploaded, which
+    // results in undefined behavior.
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+    if (!check_cuda(cudaGraphUpload(graph_exec, cuda_stream)))
+        return false;
+
+    if (graph_exec_ret)
+        *graph_exec_ret = graph_exec;
+
+    return true;
+}
+
+// Support for conditional graph nodes available with CUDA 12.4+.
+#if CUDA_VERSION >= 12040
+
+// CUBIN data for compiled conditional modules, loaded on demand, keyed on device architecture
+static std::map<int, void*> g_conditional_cubins;
+
+// Compile module with conditional helper kernels
+static void* compile_conditional_module(int arch)
+{
+    static const char* kernel_source = R"(
+typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
+extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);
+
+extern "C" __global__ void set_conditional_if_handle_kernel(cudaGraphConditionalHandle handle, int* value)
+{
+    if (threadIdx.x + blockIdx.x * blockDim.x == 0)
+        cudaGraphSetConditional(handle, *value);
+}
+
+extern "C" __global__ void set_conditional_else_handle_kernel(cudaGraphConditionalHandle handle, int* value)
+{
+    if (threadIdx.x + blockIdx.x * blockDim.x == 0)
+        cudaGraphSetConditional(handle, !*value);
+}
+
+extern "C" __global__ void set_conditional_if_else_handles_kernel(cudaGraphConditionalHandle if_handle, cudaGraphConditionalHandle else_handle, int* value)
+{
+    if (threadIdx.x + blockIdx.x * blockDim.x == 0)
+    {
+        cudaGraphSetConditional(if_handle, *value);
+        cudaGraphSetConditional(else_handle, !*value);
+    }
+}
+)";
+
+    // avoid recompilation
+    auto it = g_conditional_cubins.find(arch);
+    if (it != g_conditional_cubins.end())
+        return it->second;
+
+    nvrtcProgram prog;
+    if (!check_nvrtc(nvrtcCreateProgram(&prog, kernel_source, "conditional_kernels", 0, NULL, NULL)))
+        return NULL;
+
+    char arch_opt[128];
+    snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=sm_%d", arch);
+
+    std::vector<const char*> opts;
+    opts.push_back(arch_opt);
+
+    if (!check_nvrtc(nvrtcCompileProgram(prog, int(opts.size()), opts.data())))
+    {
+        size_t log_size;
+        if (check_nvrtc(nvrtcGetProgramLogSize(prog, &log_size)))
+        {
+            std::vector<char> log(log_size);
+            if (check_nvrtc(nvrtcGetProgramLog(prog, log.data())))
+                fprintf(stderr, "%s", log.data());
+        }
+        nvrtcDestroyProgram(&prog);
+        return NULL;
+    }
+
+    // get output
+    char* output = NULL;
+    size_t output_size = 0;
+    check_nvrtc(nvrtcGetCUBINSize(prog, &output_size));
+    if (output_size > 0)
+    {
+        output = new char[output_size];
+        if (check_nvrtc(nvrtcGetCUBIN(prog, output)))
+            g_conditional_cubins[arch] = output;
+    }
+
+    nvrtcDestroyProgram(&prog);
+
+    // return CUBIN data
+    return output;
+}
+
+
+// Load module with conditional helper kernels
+static CUmodule load_conditional_module(void* context)
+{
+    ContextInfo* context_info = get_context_info(context);
+    if (!context_info)
+        return NULL;
+
+    // check if already loaded
+    if (context_info->conditional_module)
+        return context_info->conditional_module;
+
+    int arch = context_info->device_info->arch;
+
+    // compile if needed
+    void* compiled_module = compile_conditional_module(arch);
+    if (!compiled_module)
+    {
+        fprintf(stderr, "Warp error: Failed to compile conditional kernels\n");
+        return NULL;
+    }
+
+    // load module
+    CUmodule module = NULL;
+    if (!check_cu(cuModuleLoadDataEx_f(&module, compiled_module, 0, NULL, NULL)))
+    {
+        fprintf(stderr, "Warp error: Failed to load conditional kernels module\n");
+        return NULL;
+    }
+
+    context_info->conditional_module = module;
+
+    return module;
+}
+
+static CUfunction get_conditional_kernel(void* context, const char* name)
+{
+    // load module if needed
+    CUmodule module = load_conditional_module(context);
+    if (!module)
+        return NULL;
+
+    CUfunction kernel;
+    if (!check_cu(cuModuleGetFunction_f(&kernel, module, name)))
+    {
+        fprintf(stderr, "Warp error: Failed to get kernel %s\n", name);
+        return NULL;
+    }
+
+    return kernel;
+}
+
+bool cuda_graph_pause_capture(void* context, void* stream, void** graph_ret)
+{
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+    if (!check_cuda(cudaStreamEndCapture(cuda_stream, (cudaGraph_t*)graph_ret)))
+        return false;
+    return true;
+}
+
+bool cuda_graph_resume_capture(void* context, void* stream, void* graph)
+{
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+    cudaGraph_t cuda_graph = static_cast<cudaGraph_t>(graph);
+
+    std::vector<cudaGraphNode_t> leaf_nodes;
+    if (!get_graph_leaf_nodes(cuda_graph, leaf_nodes))
+        return false;
+
+    if (!check_cuda(cudaStreamBeginCaptureToGraph(cuda_stream,
+                                                  cuda_graph,
+                                                  leaf_nodes.data(),
+                                                  nullptr,
+                                                  leaf_nodes.size(),
+                                                  cudaStreamCaptureModeGlobal)))
+        return false;
+
+    return true;
+}
+
+// https://developer.nvidia.com/blog/constructing-cuda-graphs-with-dynamic-parameters/#combined_approach
+// https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes/
+// condition is a gpu pointer
+// if_graph_ret and else_graph_ret should be NULL if not needed
+bool cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret)
+{
+    bool has_if = if_graph_ret != NULL;
+    bool has_else = else_graph_ret != NULL;
+    int num_branches = int(has_if) + int(has_else);
+
+    // if neither the IF nor ELSE branches are required, it's a no-op
+    if (num_branches == 0)
+        return true;
+
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+
+    // Get the current stream capturing graph
+    cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;
+    cudaGraph_t cuda_graph = NULL;
+    const cudaGraphNode_t* capture_deps = NULL;
+    size_t dep_count = 0;
+    if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+        return false;
+
+    // abort if not capturing
+    if (!cuda_graph || capture_status != cudaStreamCaptureStatusActive)
+    {
+        wp::set_error_string("Stream is not capturing");
+        return false;
+    }
+
+    //int driver_version = cuda_driver_version();
+
+    // IF-ELSE nodes are only supported with CUDA 12.8+
+    // Somehow child graphs produce wrong results when an else branch is used
+    // Seems to be a bug in the CUDA driver: https://nvbugs/5241330
+    if (num_branches == 1 /*|| driver_version >= 12080*/)
+    {
+        cudaGraphConditionalHandle handle;
+        cudaGraphConditionalHandleCreate(&handle, cuda_graph);
+
+        // run a kernel to set the condition handle from the condition pointer
+        // (need to negate the condition if only the else branch is used)
+        CUfunction kernel;
+        if (has_if)
+            kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+        else
+            kernel = get_conditional_kernel(context, "set_conditional_else_handle_kernel");
+
+        if (!kernel)
+        {
+            wp::set_error_string("Failed to get built-in conditional kernel");
+            return false;
+        }
+
+        void* kernel_args[2];
+        kernel_args[0] = &handle;
+        kernel_args[1] = &condition;
+
+        if (!check_cuda(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
+            return false;
+
+        if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+            return false;
+
+        // create conditional node
+        cudaGraphNode_t condition_node;
+        cudaGraphNodeParams condition_params = { cudaGraphNodeTypeConditional };
+        condition_params.conditional.handle = handle;
+        condition_params.conditional.type = cudaGraphCondTypeIf;
+        condition_params.conditional.size = num_branches;
+        if (!check_cuda(cudaGraphAddNode(&condition_node, cuda_graph, capture_deps, dep_count, &condition_params)))
+            return false;
+
+        if (!check_cuda(cudaStreamUpdateCaptureDependencies(cuda_stream, &condition_node, 1, cudaStreamSetCaptureDependencies)))
+            return false;
+
+        if (num_branches == 1)
+        {
+            if (has_if)
+                *if_graph_ret = condition_params.conditional.phGraph_out[0];
+            else
+                *else_graph_ret = condition_params.conditional.phGraph_out[0];
+        }
+        else
+        {
+            *if_graph_ret = condition_params.conditional.phGraph_out[0];
+            *else_graph_ret = condition_params.conditional.phGraph_out[1];
+        }
+    }
+    else
+    {
+        // Create IF node followed by an additional IF node with negated condition
+        cudaGraphConditionalHandle if_handle, else_handle;
+        cudaGraphConditionalHandleCreate(&if_handle, cuda_graph);
+        cudaGraphConditionalHandleCreate(&else_handle, cuda_graph);
+
+        CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_else_handles_kernel");
+        if (!kernel)
+        {
+            wp::set_error_string("Failed to get built-in conditional kernel");
+            return false;
+        }
+
+        void* kernel_args[3];
+        kernel_args[0] = &if_handle;
+        kernel_args[1] = &else_handle;
+        kernel_args[2] = &condition;
+
+        if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
+            return false;
+
+        if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+            return false;
+
+        cudaGraphNode_t if_node;
+        cudaGraphNodeParams if_params = { cudaGraphNodeTypeConditional };
+        if_params.conditional.handle = if_handle;
+        if_params.conditional.type = cudaGraphCondTypeIf;
+        if_params.conditional.size = 1;
+        if (!check_cuda(cudaGraphAddNode(&if_node, cuda_graph, capture_deps, dep_count, &if_params)))
+            return false;
+
+        cudaGraphNode_t else_node;
+        cudaGraphNodeParams else_params = { cudaGraphNodeTypeConditional };
+        else_params.conditional.handle = else_handle;
+        else_params.conditional.type = cudaGraphCondTypeIf;
+        else_params.conditional.size = 1;
+        if (!check_cuda(cudaGraphAddNode(&else_node, cuda_graph, &if_node, 1, &else_params)))
+            return false;
+
+        if (!check_cuda(cudaStreamUpdateCaptureDependencies(cuda_stream, &else_node, 1, cudaStreamSetCaptureDependencies)))
+            return false;
+
+        *if_graph_ret = if_params.conditional.phGraph_out[0];
+        *else_graph_ret = else_params.conditional.phGraph_out[0];
+    }
+
+    return true;
+}
+
+bool cuda_graph_insert_child_graph(void* context, void* stream, void* child_graph)
+{
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+
+    // Get the current stream capturing graph
+    cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;
+    void* cuda_graph = NULL;
+    const cudaGraphNode_t* capture_deps = NULL;
+    size_t dep_count = 0;
+    if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, (cudaGraph_t*)&cuda_graph, &capture_deps, &dep_count)))
+        return false;
+
+    if (!cuda_graph_pause_capture(context, cuda_stream, &cuda_graph))
+        return false;
+
+    cudaGraphNode_t body_node;
+    if (!check_cuda(cudaGraphAddChildGraphNode(&body_node,
+                                               static_cast<cudaGraph_t>(cuda_graph),
+                                               capture_deps, dep_count,
+                                               static_cast<cudaGraph_t>(child_graph))))
+        return false;
+
+    if (!cuda_graph_resume_capture(context, cuda_stream, cuda_graph))
+        return false;
+
+    if (!check_cuda(cudaStreamUpdateCaptureDependencies(cuda_stream, &body_node, 1, cudaStreamSetCaptureDependencies)))
+        return false;
+
+    return true;
+}
+
+bool cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret)
+{
+    // if there's no body, it's a no-op
+    if (!body_graph_ret)
+        return true;
+
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+
+    // Get the current stream capturing graph
+    cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;
+    cudaGraph_t cuda_graph = NULL;
+    const cudaGraphNode_t* capture_deps = NULL;
+    size_t dep_count = 0;
+    if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+        return false;
+
+    // abort if not capturing
+    if (!cuda_graph || capture_status != cudaStreamCaptureStatusActive)
+    {
+        wp::set_error_string("Stream is not capturing");
+        return false;
+    }
+
+    cudaGraphConditionalHandle handle;
+    if (!check_cuda(cudaGraphConditionalHandleCreate(&handle, cuda_graph)))
+        return false;
+
+    // launch a kernel to set the condition handle from condition pointer
+    CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+    if (!kernel)
+    {
+        wp::set_error_string("Failed to get built-in conditional kernel");
+        return false;
+    }
+
+    void* kernel_args[2];
+    kernel_args[0] = &handle;
+    kernel_args[1] = &condition;
+
+    if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
+        return false;
+
+    if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+        return false;
+
+    // insert conditional graph node
+    cudaGraphNode_t while_node;
+    cudaGraphNodeParams while_params = { cudaGraphNodeTypeConditional };
+    while_params.conditional.handle = handle;
+    while_params.conditional.type = cudaGraphCondTypeWhile;
+    while_params.conditional.size = 1;
+    if (!check_cuda(cudaGraphAddNode(&while_node, cuda_graph, capture_deps, dep_count, &while_params)))
+        return false;
+
+    if (!check_cuda(cudaStreamUpdateCaptureDependencies(cuda_stream, &while_node, 1, cudaStreamSetCaptureDependencies)))
+        return false;
+
+    *body_graph_ret = while_params.conditional.phGraph_out[0];
+    *handle_ret = handle;
+
+    return true;
+}
+
+bool cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle)
+{
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+
+    // launch a kernel to set the condition handle from condition pointer
+    CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+    if (!kernel)
+    {
+        wp::set_error_string("Failed to get built-in conditional kernel");
+        return false;
+    }
+
+    void* kernel_args[2];
+    kernel_args[0] = &handle;
+    kernel_args[1] = &condition;
+
+    if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
+        return false;
+
+    return true;
+}
+
+#else
+// stubs for conditional graph node API if CUDA toolkit is too old.
+
+bool cuda_graph_pause_capture(void* context, void* stream, void** graph_ret)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool cuda_graph_resume_capture(void* context, void* stream, void* graph)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool cuda_graph_insert_child_graph(void* context, void* stream, void* child_graph)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+#endif // support for conditional graph nodes
+
+
 bool cuda_graph_launch(void* graph_exec, void* stream)
 {
     // TODO: allow naming graphs?
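
The helpers above follow the CUDA 12.4 conditional-node recipe: create a conditional handle on the graph being captured, launch a one-thread kernel that copies a device-side condition into the handle via `cudaGraphSetConditional()`, then splice in a `cudaGraphNodeTypeConditional` node and record the branch body into the child graph it returns. A condensed, self-contained sketch of that recipe using the same runtime calls; the inline kernel stands in for the NVRTC-compiled helpers above, `stream` is assumed to be actively capturing, `d_cond` a device `int*`, and error checking is omitted:

    // Requires CUDA 12.4+. Condensed version of cuda_graph_insert_if_else() above.
    __global__ void set_if_handle(cudaGraphConditionalHandle h, const int* value)
    {
        if (threadIdx.x == 0 && blockIdx.x == 0)
            cudaGraphSetConditional(h, *value);  // 0 skips the body, nonzero runs it
    }

    // ... host side, while `stream` is capturing:
    cudaStreamCaptureStatus status;
    cudaGraph_t graph;
    const cudaGraphNode_t* deps;
    size_t dep_count;
    cudaStreamGetCaptureInfo(stream, &status, nullptr, &graph, &deps, &dep_count);

    cudaGraphConditionalHandle handle;
    cudaGraphConditionalHandleCreate(&handle, graph);
    set_if_handle<<<1, 1, 0, stream>>>(handle, d_cond);
    cudaStreamGetCaptureInfo(stream, &status, nullptr, &graph, &deps, &dep_count);

    cudaGraphNode_t node;
    cudaGraphNodeParams params = { cudaGraphNodeTypeConditional };
    params.conditional.handle = handle;
    params.conditional.type = cudaGraphCondTypeIf;
    params.conditional.size = 1;
    cudaGraphAddNode(&node, graph, deps, dep_count, &params);
    cudaStreamUpdateCaptureDependencies(stream, &node, 1, cudaStreamSetCaptureDependencies);

    cudaGraph_t body = params.conditional.phGraph_out[0];  // record the IF body into this graph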
@@ -2780,7 +3291,14 @@ bool cuda_graph_launch(void* graph_exec, void* stream)
     return result;
 }
 
-bool cuda_graph_destroy(void* context, void* graph_exec)
+bool cuda_graph_destroy(void* context, void* graph)
+{
+    ContextGuard guard(context);
+
+    return check_cuda(cudaGraphDestroy((cudaGraph_t)graph));
+}
+
+bool cuda_graph_exec_destroy(void* context, void* graph_exec)
 {
     ContextGuard guard(context);
 
@@ -2832,7 +3350,7 @@ bool write_file(const char* data, size_t size, std::string filename, const char*
 }
 #endif
 
-size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes, int* ltoir_input_types)
+size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, bool compile_time_trace, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes, int* ltoir_input_types)
 {
     // use file extension to determine whether to output PTX or CUBIN
     const char* output_ext = strrchr(output_path, '.');
@@ -2919,11 +3437,11 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
     else
         opts.push_back("--fmad=false");
 
-    std::vector<std::string> cuda_include_opt;
+    std::vector<std::string> stored_options;
     for(int i = 0; i < num_cuda_include_dirs; i++)
     {
-        cuda_include_opt.push_back(std::string("--include-path=") + cuda_include_dirs[i]);
-        opts.push_back(cuda_include_opt.back().c_str());
+        stored_options.push_back(std::string("--include-path=") + cuda_include_dirs[i]);
+        opts.push_back(stored_options.back().c_str());
     }
 
     opts.push_back("--device-as-default-execution-space");
@@ -2936,6 +3454,16 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
         opts.push_back("--relocatable-device-code=true");
     }
 
+    if (compile_time_trace)
+    {
+#if CUDA_VERSION >= 12080
+        stored_options.push_back(std::string("--fdevice-time-trace=") + std::string(output_path).append("_compile-time-trace.json"));
+        opts.push_back(stored_options.back().c_str());
+#else
+        fprintf(stderr, "Warp warning: CUDA version is less than 12.8, compile_time_trace is not supported\n");
+#endif
+    }
+
     nvrtcProgram prog;
     nvrtcResult res;
@@ -3162,11 +3690,11 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
     CHECK_ANY(num_include_dirs == 0);
 
     bool res = true;
-    cufftdxHandle h;
-    CHECK_CUFFTDX(cufftdxCreate(&h));
+    cufftdxDescriptor h;
+    CHECK_CUFFTDX(cufftdxCreateDescriptor(&h));
 
-    // CUFFTDX_API_BLOCK_LMEM means each thread starts with a subset of the data
-    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_API, cufftdxApi::CUFFTDX_API_BLOCK_LMEM));
+    // CUFFTDX_API_LMEM means each thread starts with a subset of the data
+    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_API, cufftdxApi::CUFFTDX_API_LMEM));
     CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
     CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_SIZE, (long long)size));
     CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_DIRECTION, (cufftdxDirection)direction));
@@ -3191,7 +3719,7 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
         res = false;
     }
 
-    CHECK_CUFFTDX(cufftdxDestroy(h));
+    CHECK_CUFFTDX(cufftdxDestroyDescriptor(h));
 
     return res;
 }
@@ -3207,22 +3735,22 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
     CHECK_ANY(num_include_dirs == 0);
 
     bool res = true;
-    cublasdxHandle h;
-    CHECK_CUBLASDX(cublasdxCreate(&h));
+    cublasdxDescriptor h;
+    CHECK_CUBLASDX(cublasdxCreateDescriptor(&h));
 
     CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_FUNCTION, cublasdxFunction::CUBLASDX_FUNCTION_MM));
     CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
-    CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_API, cublasdxApi::CUBLASDX_API_BLOCK_SMEM));
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_API, cublasdxApi::CUBLASDX_API_SMEM));
     std::array<long long int, 3> precisions = {precision_A, precision_B, precision_C};
-    CHECK_CUBLASDX(cublasdxSetOperatorInt64Array(h, cublasdxOperatorType::CUBLASDX_OPERATOR_PRECISION, 3, precisions.data()));
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_PRECISION, 3, precisions.data()));
     CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_SM, (long long)(arch * 10)));
     CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_TYPE, (cublasdxType)type));
     std::array<long long int, 3> block_dim = {num_threads, 1, 1};
-    CHECK_CUBLASDX(cublasdxSetOperatorInt64Array(h, cublasdxOperatorType::CUBLASDX_OPERATOR_BLOCK_DIM, block_dim.size(), block_dim.data()));
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_BLOCK_DIM, block_dim.size(), block_dim.data()));
     std::array<long long int, 3> size = {M, N, K};
-    CHECK_CUBLASDX(cublasdxSetOperatorInt64Array(h, cublasdxOperatorType::CUBLASDX_OPERATOR_SIZE, size.size(), size.data()));
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_SIZE, size.size(), size.data()));
     std::array<long long int, 3> arrangement = {arrangement_A, arrangement_B, arrangement_C};
-    CHECK_CUBLASDX(cublasdxSetOperatorInt64Array(h, cublasdxOperatorType::CUBLASDX_OPERATOR_ARRANGEMENT, arrangement.size(), arrangement.data()));
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_ARRANGEMENT, arrangement.size(), arrangement.data()));
 
     CHECK_CUBLASDX(cublasdxSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
 
@@ -3236,12 +3764,12 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
         res = false;
     }
 
-    CHECK_CUBLASDX(cublasdxDestroy(h));
+    CHECK_CUBLASDX(cublasdxDestroyDescriptor(h));
 
     return res;
 }
 
-bool cuda_compile_solver(const char* fatbin_output_path, const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int function, int precision, int fill_mode, int num_threads)
+bool cuda_compile_solver(const char* fatbin_output_path, const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int NRHS, int function, int side, int diag, int precision, int arrangement_A, int arrangement_B, int fill_mode, int num_threads)
 {
 
     CHECK_ANY(ltoir_output_path != nullptr);
@@ -3252,34 +3780,42 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
 
     bool res = true;
 
-    cusolverHandle h { 0 };
-    CHECK_CUSOLVER(cusolverCreate(&h));
-    long long int size[2] = {M, N};
-    long long int block_dim[3] = {num_threads, 1, 1};
-    CHECK_CUSOLVER(cusolverSetOperatorInt64Array(h, cusolverOperatorType::CUSOLVER_OPERATOR_SIZE, 2, size));
-    CHECK_CUSOLVER(cusolverSetOperatorInt64Array(h, cusolverOperatorType::CUSOLVER_OPERATOR_BLOCK_DIM, 3, block_dim));
-    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_TYPE, cusolverType::CUSOLVER_TYPE_REAL));
-    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_API, cusolverApi::CUSOLVER_API_BLOCK_SMEM));
-    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_FUNCTION, (cusolverFunction)function));
-    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
-    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_PRECISION, (commondxPrecision)precision));
-    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_FILL_MODE, (cusolverFillMode)fill_mode));
-    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_SM, (long long)(arch * 10)));
+    cusolverdxDescriptor h { 0 };
+    CHECK_CUSOLVER(cusolverdxCreateDescriptor(&h));
+    std::array<long long int, 3> size = {M, N, NRHS};
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64s(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_SIZE, size.size(), size.data()));
+    std::array<long long int, 3> block_dim = {num_threads, 1, 1};
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64s(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_BLOCK_DIM, block_dim.size(), block_dim.data()));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_TYPE, cusolverdxType::CUSOLVERDX_TYPE_REAL));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_API, cusolverdxApi::CUSOLVERDX_API_SMEM));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_FUNCTION, (cusolverdxFunction)function));
+    if (side >= 0) {
+        CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_SIDE, (cusolverdxSide)side));
+    }
+    if (diag >= 0) {
+        CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_DIAG, (cusolverdxDiag)diag));
+    }
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_PRECISION, (commondxPrecision)precision));
+    std::array<long long int, 2> arrangement = {arrangement_A, arrangement_B};
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64s(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_ARRANGEMENT, arrangement.size(), arrangement.data()));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_FILL_MODE, (cusolverdxFillMode)fill_mode));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_SM, (long long)(arch * 10)));
 
-    CHECK_CUSOLVER(cusolverSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
+    CHECK_CUSOLVER(cusolverdxSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
 
     size_t lto_size = 0;
-    CHECK_CUSOLVER(cusolverGetLTOIRSize(h, &lto_size));
+    CHECK_CUSOLVER(cusolverdxGetLTOIRSize(h, &lto_size));
 
     std::vector<char> lto(lto_size);
-    CHECK_CUSOLVER(cusolverGetLTOIR(h, lto.size(), lto.data()));
+    CHECK_CUSOLVER(cusolverdxGetLTOIR(h, lto.size(), lto.data()));
 
     // This fatbin is universal, ie it is the same for any instantiations of a cusolver device function
     size_t fatbin_size = 0;
-    CHECK_CUSOLVER(cusolverGetUniversalFATBINSize(h, &fatbin_size));
+    CHECK_CUSOLVER(cusolverdxGetUniversalFATBINSize(h, &fatbin_size));
 
     std::vector<char> fatbin(fatbin_size);
-    CHECK_CUSOLVER(cusolverGetUniversalFATBIN(h, fatbin.size(), fatbin.data()));
+    CHECK_CUSOLVER(cusolverdxGetUniversalFATBIN(h, fatbin.size(), fatbin.data()));
 
     if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
         res = false;
@@ -3289,7 +3825,7 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
         res = false;
     }
 
-    CHECK_CUSOLVER(cusolverDestroy(h));
+    CHECK_CUSOLVER(cusolverdxDestroyDescriptor(h));
 
     return res;
 }
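
The widened `cuda_compile_solver()` signature threads through the problem shape (`NRHS`), the optional `side`/`diag` operators (skipped when negative, per the guards above), and per-matrix arrangements. A hypothetical call showing the new argument order; every concrete value below is a placeholder, and the `function`/`precision`/`arrangement` integers are whatever enum encodings the caller passes through:

    // Hypothetical invocation; all values are placeholders for illustration.
    bool ok = cuda_compile_solver(
        "solver.fatbin", "solver.ltoir", "solver_symbol",
        /*num_include_dirs=*/0, /*include_dirs=*/NULL,
        "/path/to/mathdx/include",
        /*arch=*/80,
        /*M=*/32, /*N=*/32, /*NRHS=*/1,
        /*function=*/0,
        /*side=*/-1, /*diag=*/-1,        // negative means "operator not set"
        /*precision=*/0,
        /*arrangement_A=*/0, /*arrangement_B=*/0,
        /*fill_mode=*/0,
        /*num_threads=*/128);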