warp-lang 1.7.2-py3-none-macosx_10_13_universal2.whl → 1.8.0-py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (181)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/libwarp-clang.dylib +0 -0
  5. warp/bin/libwarp.dylib +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +125 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +257 -101
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +657 -223
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/optim/example_drone.py +1 -1
  37. warp/examples/sim/example_cloth.py +1 -1
  38. warp/examples/sim/example_cloth_self_contact.py +48 -54
  39. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  40. warp/examples/tile/example_tile_cholesky.py +2 -1
  41. warp/examples/tile/example_tile_convolution.py +1 -1
  42. warp/examples/tile/example_tile_filtering.py +1 -1
  43. warp/examples/tile/example_tile_matmul.py +1 -1
  44. warp/examples/tile/example_tile_mlp.py +2 -0
  45. warp/fabric.py +7 -7
  46. warp/fem/__init__.py +5 -0
  47. warp/fem/adaptivity.py +1 -1
  48. warp/fem/cache.py +152 -63
  49. warp/fem/dirichlet.py +2 -2
  50. warp/fem/domain.py +136 -6
  51. warp/fem/field/field.py +141 -99
  52. warp/fem/field/nodal_field.py +85 -39
  53. warp/fem/field/virtual.py +97 -52
  54. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  55. warp/fem/geometry/closest_point.py +13 -0
  56. warp/fem/geometry/deformed_geometry.py +102 -40
  57. warp/fem/geometry/element.py +56 -2
  58. warp/fem/geometry/geometry.py +323 -22
  59. warp/fem/geometry/grid_2d.py +157 -62
  60. warp/fem/geometry/grid_3d.py +116 -20
  61. warp/fem/geometry/hexmesh.py +86 -20
  62. warp/fem/geometry/nanogrid.py +166 -86
  63. warp/fem/geometry/partition.py +59 -25
  64. warp/fem/geometry/quadmesh.py +86 -135
  65. warp/fem/geometry/tetmesh.py +47 -119
  66. warp/fem/geometry/trimesh.py +77 -270
  67. warp/fem/integrate.py +107 -52
  68. warp/fem/linalg.py +25 -58
  69. warp/fem/operator.py +124 -27
  70. warp/fem/quadrature/pic_quadrature.py +36 -14
  71. warp/fem/quadrature/quadrature.py +40 -16
  72. warp/fem/space/__init__.py +1 -1
  73. warp/fem/space/basis_function_space.py +66 -46
  74. warp/fem/space/basis_space.py +17 -4
  75. warp/fem/space/dof_mapper.py +1 -1
  76. warp/fem/space/function_space.py +2 -2
  77. warp/fem/space/grid_2d_function_space.py +4 -1
  78. warp/fem/space/hexmesh_function_space.py +4 -2
  79. warp/fem/space/nanogrid_function_space.py +3 -1
  80. warp/fem/space/partition.py +11 -2
  81. warp/fem/space/quadmesh_function_space.py +4 -1
  82. warp/fem/space/restriction.py +5 -2
  83. warp/fem/space/shape/__init__.py +10 -8
  84. warp/fem/space/tetmesh_function_space.py +4 -1
  85. warp/fem/space/topology.py +52 -21
  86. warp/fem/space/trimesh_function_space.py +4 -1
  87. warp/fem/utils.py +53 -8
  88. warp/jax.py +1 -2
  89. warp/jax_experimental/ffi.py +12 -17
  90. warp/jax_experimental/xla_ffi.py +37 -24
  91. warp/math.py +171 -1
  92. warp/native/array.h +99 -0
  93. warp/native/builtin.h +174 -31
  94. warp/native/coloring.cpp +1 -1
  95. warp/native/exports.h +118 -63
  96. warp/native/intersect.h +3 -3
  97. warp/native/mat.h +5 -10
  98. warp/native/mathdx.cpp +11 -5
  99. warp/native/matnn.h +1 -123
  100. warp/native/quat.h +28 -4
  101. warp/native/sparse.cpp +121 -258
  102. warp/native/sparse.cu +181 -274
  103. warp/native/spatial.h +305 -17
  104. warp/native/tile.h +583 -72
  105. warp/native/tile_radix_sort.h +1108 -0
  106. warp/native/tile_reduce.h +237 -2
  107. warp/native/tile_scan.h +240 -0
  108. warp/native/tuple.h +189 -0
  109. warp/native/vec.h +6 -16
  110. warp/native/warp.cpp +36 -4
  111. warp/native/warp.cu +574 -51
  112. warp/native/warp.h +47 -74
  113. warp/optim/linear.py +5 -1
  114. warp/paddle.py +7 -8
  115. warp/py.typed +0 -0
  116. warp/render/render_opengl.py +58 -29
  117. warp/render/render_usd.py +124 -61
  118. warp/sim/__init__.py +9 -0
  119. warp/sim/collide.py +252 -78
  120. warp/sim/graph_coloring.py +8 -1
  121. warp/sim/import_mjcf.py +4 -3
  122. warp/sim/import_usd.py +11 -7
  123. warp/sim/integrator.py +5 -2
  124. warp/sim/integrator_euler.py +1 -1
  125. warp/sim/integrator_featherstone.py +1 -1
  126. warp/sim/integrator_vbd.py +751 -320
  127. warp/sim/integrator_xpbd.py +1 -1
  128. warp/sim/model.py +265 -260
  129. warp/sim/utils.py +10 -7
  130. warp/sparse.py +303 -166
  131. warp/tape.py +52 -51
  132. warp/tests/cuda/test_conditional_captures.py +1046 -0
  133. warp/tests/cuda/test_streams.py +1 -1
  134. warp/tests/geometry/test_volume.py +2 -2
  135. warp/tests/interop/test_dlpack.py +9 -9
  136. warp/tests/interop/test_jax.py +0 -1
  137. warp/tests/run_coverage_serial.py +1 -1
  138. warp/tests/sim/disabled_kinematics.py +2 -2
  139. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  140. warp/tests/sim/test_collision.py +159 -51
  141. warp/tests/sim/test_coloring.py +15 -1
  142. warp/tests/test_array.py +254 -2
  143. warp/tests/test_array_reduce.py +2 -2
  144. warp/tests/test_atomic_cas.py +299 -0
  145. warp/tests/test_codegen.py +142 -19
  146. warp/tests/test_conditional.py +47 -1
  147. warp/tests/test_ctypes.py +0 -20
  148. warp/tests/test_devices.py +8 -0
  149. warp/tests/test_fabricarray.py +4 -2
  150. warp/tests/test_fem.py +58 -25
  151. warp/tests/test_func.py +42 -1
  152. warp/tests/test_grad.py +1 -1
  153. warp/tests/test_lerp.py +1 -3
  154. warp/tests/test_map.py +481 -0
  155. warp/tests/test_mat.py +1 -24
  156. warp/tests/test_quat.py +6 -15
  157. warp/tests/test_rounding.py +10 -38
  158. warp/tests/test_runlength_encode.py +7 -7
  159. warp/tests/test_smoothstep.py +1 -1
  160. warp/tests/test_sparse.py +51 -2
  161. warp/tests/test_spatial.py +507 -1
  162. warp/tests/test_struct.py +2 -2
  163. warp/tests/test_tuple.py +265 -0
  164. warp/tests/test_types.py +2 -2
  165. warp/tests/test_utils.py +24 -18
  166. warp/tests/tile/test_tile.py +420 -1
  167. warp/tests/tile/test_tile_mathdx.py +518 -14
  168. warp/tests/tile/test_tile_reduce.py +213 -0
  169. warp/tests/tile/test_tile_shared_memory.py +130 -1
  170. warp/tests/tile/test_tile_sort.py +117 -0
  171. warp/tests/unittest_suites.py +4 -6
  172. warp/types.py +462 -308
  173. warp/utils.py +647 -86
  174. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  175. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
  176. warp/stubs.py +0 -3381
  177. warp/tests/sim/test_xpbd.py +0 -399
  178. warp/tests/test_mlp.py +0 -282
  179. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  180. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  181. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/native/warp.cu CHANGED
@@ -27,6 +27,9 @@
 #if WP_ENABLE_MATHDX
 #include <nvJitLink.h>
 #include <libmathdx.h>
+#include <libcublasdx.h>
+#include <libcufftdx.h>
+#include <libcusolverdx.h>
 #endif

 #include <array>
@@ -155,6 +158,7 @@ struct DeviceInfo
     int arch = 0;
     int is_uva = 0;
     int is_mempool_supported = 0;
+    int sm_count = 0;
     int is_ipc_supported = -1;
     int max_smem_bytes = 0;
     CUcontext primary_context = NULL;
@@ -166,6 +170,9 @@ struct ContextInfo

     // the current stream, managed from Python (see cuda_context_set_stream() and cuda_context_get_stream())
     CUstream stream = NULL;
+
+    // conditional graph node support, loaded on demand if the driver supports it (CUDA 12.4+)
+    CUmodule conditional_module = NULL;
 };

 struct CaptureInfo
@@ -280,6 +287,7 @@ int cuda_init()
         check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_device_id, CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, device));
         check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_uva, CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, device));
         check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_mempool_supported, CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, device));
+        check_cu(cuDeviceGetAttribute_f(&g_devices[i].sm_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
 #ifdef CUDA_VERSION
 #if CUDA_VERSION >= 12000
         int device_attribute_integrated = 0;
@@ -1786,6 +1794,13 @@ int cuda_device_get_arch(int ordinal)
     return 0;
 }

+int cuda_device_get_sm_count(int ordinal)
+{
+    if (ordinal >= 0 && ordinal < int(g_devices.size()))
+        return g_devices[ordinal].sm_count;
+    return 0;
+}
+
 void cuda_device_get_uuid(int ordinal, char uuid[16])
 {
     memcpy(uuid, g_devices[ordinal].uuid.bytes, sizeof(char)*16);
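The getter added above simply returns the SM count cached during cuda_init(). For readers who want to reproduce the query outside of Warp, a minimal standalone sketch using the CUDA runtime API (illustrative only; Warp itself goes through the driver API as shown in the cuda_init() hunk):

    #include <cstdio>
    #include <cuda_runtime.h>

    int main()
    {
        int sm_count = 0;
        // runtime-API twin of CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT used in cuda_init()
        if (cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, /*device*/ 0) == cudaSuccess)
            printf("device 0 has %d SMs\n", sm_count);
        return 0;
    }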
@@ -2034,6 +2049,9 @@ void cuda_context_destroy(void* context)
     if (info->stream)
         check_cu(cuStreamDestroy_f(info->stream));

+    if (info->conditional_module)
+        check_cu(cuModuleUnload_f(info->conditional_module));
+
     g_contexts.erase(ctx);
 }

@@ -2739,22 +2757,10 @@ bool cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
     if (external)
         return true;

-    cudaGraphExec_t graph_exec = NULL;
-
     // end the capture
     if (!check_cuda(cudaStreamEndCapture(cuda_stream, &graph)))
         return false;

-    // enable to create debug GraphVis visualization of graph
-    // cudaGraphDebugDotPrint(graph, "graph.dot", cudaGraphDebugDotFlagsVerbose);
-
-    // can use after CUDA 11.4 to permit graphs to capture cudaMallocAsync() operations
-    if (!check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch)))
-        return false;
-
-    // free source graph
-    check_cuda(cudaGraphDestroy(graph));
-
     // process deferred free list if no more captures are ongoing
     if (g_captures.empty())
     {
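This hunk removes the implicit instantiation: ending a capture now returns the raw cudaGraph_t, and the executable graph is produced separately by the new cuda_graph_create_exec() below, leaving room to splice in conditional nodes or dump a debug visualization first. A minimal sketch of the resulting two-step flow in plain CUDA runtime calls (illustrative only; my_kernel and its launch configuration are placeholders, and error checking is elided):

    cudaGraph_t graph = NULL;
    cudaGraphExec_t graph_exec = NULL;

    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    my_kernel<<<grid, block, 0, stream>>>(args);
    cudaStreamEndCapture(stream, &graph);     // step 1: capture yields a raw graph

    // step 2: instantiate on demand; the raw graph stays available for reuse
    cudaGraphInstantiateWithFlags(&graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch);
    cudaGraphLaunch(graph_exec, stream);

    cudaGraphExecDestroy(graph_exec);
    cudaGraphDestroy(graph);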
@@ -2763,11 +2769,503 @@ bool cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
     }

     if (graph_ret)
-        *graph_ret = graph_exec;
+        *graph_ret = graph;
+
+    return true;
+}
+
+bool capture_debug_dot_print(void* graph, const char *path, uint32_t flags)
+{
+    if (!check_cuda(cudaGraphDebugDotPrint((cudaGraph_t)graph, path, flags)))
+        return false;
+    return true;
+}
+
+bool cuda_graph_create_exec(void* context, void* graph, void** graph_exec_ret)
+{
+    ContextGuard guard(context);
+
+    cudaGraphExec_t graph_exec = NULL;
+    if (!check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, (cudaGraph_t)graph, cudaGraphInstantiateFlagAutoFreeOnLaunch)))
+        return false;
+
+    if (graph_exec_ret)
+        *graph_exec_ret = graph_exec;
+
+    return true;
+}
+
+// Support for conditional graph nodes available with CUDA 12.4+.
+#if CUDA_VERSION >= 12040
+
+// CUBIN data for compiled conditional modules, loaded on demand, keyed on device architecture
+static std::map<int, void*> g_conditional_cubins;
+
+// Compile module with conditional helper kernels
+static void* compile_conditional_module(int arch)
+{
+    static const char* kernel_source = R"(
+        typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
+        extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);
+
+        extern "C" __global__ void set_conditional_if_handle_kernel(cudaGraphConditionalHandle handle, int* value)
+        {
+            if (threadIdx.x + blockIdx.x * blockDim.x == 0)
+                cudaGraphSetConditional(handle, *value);
+        }
+
+        extern "C" __global__ void set_conditional_else_handle_kernel(cudaGraphConditionalHandle handle, int* value)
+        {
+            if (threadIdx.x + blockIdx.x * blockDim.x == 0)
+                cudaGraphSetConditional(handle, !*value);
+        }
+
+        extern "C" __global__ void set_conditional_if_else_handles_kernel(cudaGraphConditionalHandle if_handle, cudaGraphConditionalHandle else_handle, int* value)
+        {
+            if (threadIdx.x + blockIdx.x * blockDim.x == 0)
+            {
+                cudaGraphSetConditional(if_handle, *value);
+                cudaGraphSetConditional(else_handle, !*value);
+            }
+        }
+    )";
+
+    // avoid recompilation
+    auto it = g_conditional_cubins.find(arch);
+    if (it != g_conditional_cubins.end())
+        return it->second;
+
+    nvrtcProgram prog;
+    if (!check_nvrtc(nvrtcCreateProgram(&prog, kernel_source, "conditional_kernels", 0, NULL, NULL)))
+        return NULL;
+
+    char arch_opt[128];
+    snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=sm_%d", arch);
+
+    std::vector<const char*> opts;
+    opts.push_back(arch_opt);
+
+    if (!check_nvrtc(nvrtcCompileProgram(prog, int(opts.size()), opts.data())))
+    {
+        size_t log_size;
+        if (check_nvrtc(nvrtcGetProgramLogSize(prog, &log_size)))
+        {
+            std::vector<char> log(log_size);
+            if (check_nvrtc(nvrtcGetProgramLog(prog, log.data())))
+                fprintf(stderr, "%s", log.data());
+        }
+        nvrtcDestroyProgram(&prog);
+        return NULL;
+    }
+
+    // get output
+    char* output = NULL;
+    size_t output_size = 0;
+    check_nvrtc(nvrtcGetCUBINSize(prog, &output_size));
+    if (output_size > 0)
+    {
+        output = new char[output_size];
+        if (check_nvrtc(nvrtcGetCUBIN(prog, output)))
+            g_conditional_cubins[arch] = output;
+    }
+
+    nvrtcDestroyProgram(&prog);
+
+    // return CUBIN data
+    return output;
+}
+
+
+// Load module with conditional helper kernels
+static CUmodule load_conditional_module(void* context)
+{
+    ContextInfo* context_info = get_context_info(context);
+    if (!context_info)
+        return NULL;
+
+    // check if already loaded
+    if (context_info->conditional_module)
+        return context_info->conditional_module;
+
+    int arch = context_info->device_info->arch;
+
+    // compile if needed
+    void* compiled_module = compile_conditional_module(arch);
+    if (!compiled_module)
+    {
+        fprintf(stderr, "Warp error: Failed to compile conditional kernels\n");
+        return NULL;
+    }
+
+    // load module
+    CUmodule module = NULL;
+    if (!check_cu(cuModuleLoadDataEx_f(&module, compiled_module, 0, NULL, NULL)))
+    {
+        fprintf(stderr, "Warp error: Failed to load conditional kernels module\n");
+        return NULL;
+    }
+
+    context_info->conditional_module = module;
+
+    return module;
+}
+
+static CUfunction get_conditional_kernel(void* context, const char* name)
+{
+    // load module if needed
+    CUmodule module = load_conditional_module(context);
+    if (!module)
+        return NULL;
+
+    CUfunction kernel;
+    if (!check_cu(cuModuleGetFunction_f(&kernel, module, name)))
+    {
+        fprintf(stderr, "Warp error: Failed to get kernel %s\n", name);
+        return NULL;
+    }
+
+    return kernel;
+}
+
+bool cuda_graph_pause_capture(void* context, void* stream, void** graph_ret)
+{
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+    if (!check_cuda(cudaStreamEndCapture(cuda_stream, (cudaGraph_t*)graph_ret)))
+        return false;
+    return true;
+}
+
+bool cuda_graph_resume_capture(void* context, void* stream, void* graph)
+{
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+    cudaGraph_t cuda_graph = static_cast<cudaGraph_t>(graph);
+
+    std::vector<cudaGraphNode_t> leaf_nodes;
+    if (!get_graph_leaf_nodes(cuda_graph, leaf_nodes))
+        return false;
+
+    if (!check_cuda(cudaStreamBeginCaptureToGraph(cuda_stream,
+                                                  cuda_graph,
+                                                  leaf_nodes.data(),
+                                                  nullptr,
+                                                  leaf_nodes.size(),
+                                                  cudaStreamCaptureModeGlobal)))
+        return false;
+
+    return true;
+}
+
+// https://developer.nvidia.com/blog/constructing-cuda-graphs-with-dynamic-parameters/#combined_approach
+// https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes/
+// condition is a gpu pointer
+// if_graph_ret and else_graph_ret should be NULL if not needed
+bool cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret)
+{
+    bool has_if = if_graph_ret != NULL;
+    bool has_else = else_graph_ret != NULL;
+    int num_branches = int(has_if) + int(has_else);
+
+    // if neither the IF nor ELSE branches are required, it's a no-op
+    if (num_branches == 0)
+        return true;
+
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+
+    // Get the current stream capturing graph
+    cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;
+    cudaGraph_t cuda_graph = NULL;
+    const cudaGraphNode_t* capture_deps = NULL;
+    size_t dep_count = 0;
+    if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+        return false;
+
+    // abort if not capturing
+    if (!cuda_graph || capture_status != cudaStreamCaptureStatusActive)
+    {
+        wp::set_error_string("Stream is not capturing");
+        return false;
+    }
+
+    //int driver_version = cuda_driver_version();
+
+    // IF-ELSE nodes are only supported with CUDA 12.8+
+    // Somehow child graphs produce wrong results when an else branch is used
+    // Seems to be a bug in the CUDA driver: https://nvbugs/5241330
+    if (num_branches == 1 /*|| driver_version >= 12080*/)
+    {
+        cudaGraphConditionalHandle handle;
+        cudaGraphConditionalHandleCreate(&handle, cuda_graph);
+
+        // run a kernel to set the condition handle from the condition pointer
+        // (need to negate the condition if only the else branch is used)
+        CUfunction kernel;
+        if (has_if)
+            kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+        else
+            kernel = get_conditional_kernel(context, "set_conditional_else_handle_kernel");
+
+        if (!kernel)
+        {
+            wp::set_error_string("Failed to get built-in conditional kernel");
+            return false;
+        }
+
+        void* kernel_args[2];
+        kernel_args[0] = &handle;
+        kernel_args[1] = &condition;
+
+        if (!check_cuda(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
+            return false;
+
+        if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+            return false;
+
+        // create conditional node
+        cudaGraphNode_t condition_node;
+        cudaGraphNodeParams condition_params = { cudaGraphNodeTypeConditional };
+        condition_params.conditional.handle = handle;
+        condition_params.conditional.type = cudaGraphCondTypeIf;
+        condition_params.conditional.size = num_branches;
+        if (!check_cuda(cudaGraphAddNode(&condition_node, cuda_graph, capture_deps, dep_count, &condition_params)))
+            return false;
+
+        if (!check_cuda(cudaStreamUpdateCaptureDependencies(cuda_stream, &condition_node, 1, cudaStreamSetCaptureDependencies)))
+            return false;
+
+        if (num_branches == 1)
+        {
+            if (has_if)
+                *if_graph_ret = condition_params.conditional.phGraph_out[0];
+            else
+                *else_graph_ret = condition_params.conditional.phGraph_out[0];
+        }
+        else
+        {
+            *if_graph_ret = condition_params.conditional.phGraph_out[0];
+            *else_graph_ret = condition_params.conditional.phGraph_out[1];
+        }
+    }
+    else
+    {
+        // Create IF node followed by an additional IF node with negated condition
+        cudaGraphConditionalHandle if_handle, else_handle;
+        cudaGraphConditionalHandleCreate(&if_handle, cuda_graph);
+        cudaGraphConditionalHandleCreate(&else_handle, cuda_graph);
+
+        CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_else_handles_kernel");
+        if (!kernel)
+        {
+            wp::set_error_string("Failed to get built-in conditional kernel");
+            return false;
+        }
+
+        void* kernel_args[3];
+        kernel_args[0] = &if_handle;
+        kernel_args[1] = &else_handle;
+        kernel_args[2] = &condition;
+
+        if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
+            return false;
+
+        if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+            return false;
+
+        cudaGraphNode_t if_node;
+        cudaGraphNodeParams if_params = { cudaGraphNodeTypeConditional };
+        if_params.conditional.handle = if_handle;
+        if_params.conditional.type = cudaGraphCondTypeIf;
+        if_params.conditional.size = 1;
+        if (!check_cuda(cudaGraphAddNode(&if_node, cuda_graph, capture_deps, dep_count, &if_params)))
+            return false;
+
+        cudaGraphNode_t else_node;
+        cudaGraphNodeParams else_params = { cudaGraphNodeTypeConditional };
+        else_params.conditional.handle = else_handle;
+        else_params.conditional.type = cudaGraphCondTypeIf;
+        else_params.conditional.size = 1;
+        if (!check_cuda(cudaGraphAddNode(&else_node, cuda_graph, &if_node, 1, &else_params)))
+            return false;
+
+        if (!check_cuda(cudaStreamUpdateCaptureDependencies(cuda_stream, &else_node, 1, cudaStreamSetCaptureDependencies)))
+            return false;
+
+        *if_graph_ret = if_params.conditional.phGraph_out[0];
+        *else_graph_ret = else_params.conditional.phGraph_out[0];
+    }
+
+    return true;
+}
+
+bool cuda_graph_insert_child_graph(void* context, void* stream, void* child_graph)
+{
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+
+    // Get the current stream capturing graph
+    cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;
+    void* cuda_graph = NULL;
+    const cudaGraphNode_t* capture_deps = NULL;
+    size_t dep_count = 0;
+    if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, (cudaGraph_t*)&cuda_graph, &capture_deps, &dep_count)))
+        return false;
+
+    if (!cuda_graph_pause_capture(context, cuda_stream, &cuda_graph))
+        return false;
+
+    cudaGraphNode_t body_node;
+    if (!check_cuda(cudaGraphAddChildGraphNode(&body_node,
+                                               static_cast<cudaGraph_t>(cuda_graph),
+                                               capture_deps, dep_count,
+                                               static_cast<cudaGraph_t>(child_graph))))
+        return false;
+
+    if (!cuda_graph_resume_capture(context, cuda_stream, cuda_graph))
+        return false;
+
+    if (!check_cuda(cudaStreamUpdateCaptureDependencies(cuda_stream, &body_node, 1, cudaStreamSetCaptureDependencies)))
+        return false;

     return true;
 }

+bool cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret)
+{
+    // if there's no body, it's a no-op
+    if (!body_graph_ret)
+        return true;
+
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+
+    // Get the current stream capturing graph
+    cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;
+    cudaGraph_t cuda_graph = NULL;
+    const cudaGraphNode_t* capture_deps = NULL;
+    size_t dep_count = 0;
+    if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+        return false;
+
+    // abort if not capturing
+    if (!cuda_graph || capture_status != cudaStreamCaptureStatusActive)
+    {
+        wp::set_error_string("Stream is not capturing");
+        return false;
+    }
+
+    cudaGraphConditionalHandle handle;
+    if (!check_cuda(cudaGraphConditionalHandleCreate(&handle, cuda_graph)))
+        return false;
+
+    // launch a kernel to set the condition handle from condition pointer
+    CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+    if (!kernel)
+    {
+        wp::set_error_string("Failed to get built-in conditional kernel");
+        return false;
+    }
+
+    void* kernel_args[2];
+    kernel_args[0] = &handle;
+    kernel_args[1] = &condition;
+
+    if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
+        return false;
+
+    if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+        return false;
+
+    // insert conditional graph node
+    cudaGraphNode_t while_node;
+    cudaGraphNodeParams while_params = { cudaGraphNodeTypeConditional };
+    while_params.conditional.handle = handle;
+    while_params.conditional.type = cudaGraphCondTypeWhile;
+    while_params.conditional.size = 1;
+    if (!check_cuda(cudaGraphAddNode(&while_node, cuda_graph, capture_deps, dep_count, &while_params)))
+        return false;
+
+    if (!check_cuda(cudaStreamUpdateCaptureDependencies(cuda_stream, &while_node, 1, cudaStreamSetCaptureDependencies)))
+        return false;
+
+    *body_graph_ret = while_params.conditional.phGraph_out[0];
+    *handle_ret = handle;
+
+    return true;
+}
+
+bool cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle)
+{
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+
+    // launch a kernel to set the condition handle from condition pointer
+    CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+    if (!kernel)
+    {
+        wp::set_error_string("Failed to get built-in conditional kernel");
+        return false;
+    }
+
+    void* kernel_args[2];
+    kernel_args[0] = &handle;
+    kernel_args[1] = &condition;
+
+    if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
+        return false;
+
+    return true;
+}
+
+#else
+// stubs for conditional graph node API if CUDA toolkit is too old.
+
+bool cuda_graph_pause_capture(void* context, void* stream, void** graph_ret)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool cuda_graph_resume_capture(void* context, void* stream, void* graph)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool cuda_graph_insert_child_graph(void* context, void* stream, void* child_graph)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+#endif // support for conditional graph nodes
+
+
 bool cuda_graph_launch(void* graph_exec, void* stream)
 {
     // TODO: allow naming graphs?
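The block above is the core of 1.8.0's conditional graph support: helper kernels compiled once per architecture via NVRTC flip cudaGraphConditionalHandle values on the device, and IF/ELSE/WHILE nodes are spliced into the stream's capture graph. A condensed sketch of the WHILE sequence, mirroring cuda_graph_insert_while() above (error handling elided; assumes stream is actively capturing and that a kernel captured into body_graph later calls cudaGraphSetConditional() to end the loop):

    // query the graph being captured and its current leaf dependencies
    cudaStreamCaptureStatus status;
    cudaGraph_t capture_graph = NULL;
    const cudaGraphNode_t* deps = NULL;
    size_t dep_count = 0;
    cudaStreamGetCaptureInfo(stream, &status, nullptr, &capture_graph, &deps, &dep_count);

    // handle that device code toggles via cudaGraphSetConditional()
    cudaGraphConditionalHandle handle;
    cudaGraphConditionalHandleCreate(&handle, capture_graph);

    // conditional node whose body graph re-executes while the handle is non-zero
    cudaGraphNodeParams params = { cudaGraphNodeTypeConditional };
    params.conditional.handle = handle;
    params.conditional.type = cudaGraphCondTypeWhile;
    params.conditional.size = 1;
    cudaGraphNode_t while_node;
    cudaGraphAddNode(&while_node, capture_graph, deps, dep_count, &params);
    cudaGraph_t body_graph = params.conditional.phGraph_out[0];

    // subsequently captured work runs after the loop completes
    cudaStreamUpdateCaptureDependencies(stream, &while_node, 1, cudaStreamSetCaptureDependencies);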
@@ -2780,7 +3278,14 @@ bool cuda_graph_launch(void* graph_exec, void* stream)
     return result;
 }

-bool cuda_graph_destroy(void* context, void* graph_exec)
+bool cuda_graph_destroy(void* context, void* graph)
+{
+    ContextGuard guard(context);
+
+    return check_cuda(cudaGraphDestroy((cudaGraph_t)graph));
+}
+
+bool cuda_graph_exec_destroy(void* context, void* graph_exec)
 {
     ContextGuard guard(context);

@@ -2832,7 +3337,7 @@ bool write_file(const char* data, size_t size, std::string filename, const char*
 }
 #endif

-size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes, int* ltoir_input_types)
+size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, bool compile_time_trace, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes, int* ltoir_input_types)
 {
     // use file extension to determine whether to output PTX or CUBIN
     const char* output_ext = strrchr(output_path, '.');
@@ -2919,11 +3424,11 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
     else
         opts.push_back("--fmad=false");

-    std::vector<std::string> cuda_include_opt;
+    std::vector<std::string> stored_options;
     for(int i = 0; i < num_cuda_include_dirs; i++)
     {
-        cuda_include_opt.push_back(std::string("--include-path=") + cuda_include_dirs[i]);
-        opts.push_back(cuda_include_opt.back().c_str());
+        stored_options.push_back(std::string("--include-path=") + cuda_include_dirs[i]);
+        opts.push_back(stored_options.back().c_str());
     }

     opts.push_back("--device-as-default-execution-space");
@@ -2936,6 +3441,16 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
         opts.push_back("--relocatable-device-code=true");
     }

+    if (compile_time_trace)
+    {
+#if CUDA_VERSION >= 12080
+        stored_options.push_back(std::string("--fdevice-time-trace=") + std::string(output_path).append("_compile-time-trace.json"));
+        opts.push_back(stored_options.back().c_str());
+#else
+        fprintf(stderr, "Warp warning: CUDA version is less than 12.8, compile_time_trace is not supported\n");
+#endif
+    }
+
     nvrtcProgram prog;
     nvrtcResult res;

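The new compile_time_trace path forwards NVRTC's --fdevice-time-trace option (available from CUDA 12.8), which writes a JSON timeline of device-code compilation next to the module output, in the same spirit as clang's -ftime-trace. A minimal standalone sketch of passing the option to NVRTC (illustrative only; the source string, file names, and architecture are placeholders, not Warp's full option set):

    #include <nvrtc.h>

    nvrtcProgram prog;
    nvrtcCreateProgram(&prog, cuda_src, "kernels.cu", 0, NULL, NULL);

    const char* opts[] = {
        "--gpu-architecture=sm_80",
        "--fdevice-time-trace=kernels_compile-time-trace.json",  // trace written during compile
    };
    nvrtcCompileProgram(prog, 2, opts);
    nvrtcDestroyProgram(&prog);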
@@ -3162,11 +3677,11 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
     CHECK_ANY(num_include_dirs == 0);

     bool res = true;
-    cufftdxHandle h;
-    CHECK_CUFFTDX(cufftdxCreate(&h));
+    cufftdxDescriptor h;
+    CHECK_CUFFTDX(cufftdxCreateDescriptor(&h));

-    // CUFFTDX_API_BLOCK_LMEM means each thread starts with a subset of the data
-    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_API, cufftdxApi::CUFFTDX_API_BLOCK_LMEM));
+    // CUFFTDX_API_LMEM means each thread starts with a subset of the data
+    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_API, cufftdxApi::CUFFTDX_API_LMEM));
     CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
     CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_SIZE, (long long)size));
     CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_DIRECTION, (cufftdxDirection)direction));
@@ -3191,7 +3706,7 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
         res = false;
     }

-    CHECK_CUFFTDX(cufftdxDestroy(h));
+    CHECK_CUFFTDX(cufftdxDestroyDescriptor(h));

     return res;
 }
@@ -3207,22 +3722,22 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
     CHECK_ANY(num_include_dirs == 0);

     bool res = true;
-    cublasdxHandle h;
-    CHECK_CUBLASDX(cublasdxCreate(&h));
+    cublasdxDescriptor h;
+    CHECK_CUBLASDX(cublasdxCreateDescriptor(&h));

     CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_FUNCTION, cublasdxFunction::CUBLASDX_FUNCTION_MM));
     CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
-    CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_API, cublasdxApi::CUBLASDX_API_BLOCK_SMEM));
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_API, cublasdxApi::CUBLASDX_API_SMEM));
     std::array<long long int, 3> precisions = {precision_A, precision_B, precision_C};
-    CHECK_CUBLASDX(cublasdxSetOperatorInt64Array(h, cublasdxOperatorType::CUBLASDX_OPERATOR_PRECISION, 3, precisions.data()));
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_PRECISION, 3, precisions.data()));
     CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_SM, (long long)(arch * 10)));
     CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_TYPE, (cublasdxType)type));
     std::array<long long int, 3> block_dim = {num_threads, 1, 1};
-    CHECK_CUBLASDX(cublasdxSetOperatorInt64Array(h, cublasdxOperatorType::CUBLASDX_OPERATOR_BLOCK_DIM, block_dim.size(), block_dim.data()));
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_BLOCK_DIM, block_dim.size(), block_dim.data()));
     std::array<long long int, 3> size = {M, N, K};
-    CHECK_CUBLASDX(cublasdxSetOperatorInt64Array(h, cublasdxOperatorType::CUBLASDX_OPERATOR_SIZE, size.size(), size.data()));
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_SIZE, size.size(), size.data()));
     std::array<long long int, 3> arrangement = {arrangement_A, arrangement_B, arrangement_C};
-    CHECK_CUBLASDX(cublasdxSetOperatorInt64Array(h, cublasdxOperatorType::CUBLASDX_OPERATOR_ARRANGEMENT, arrangement.size(), arrangement.data()));
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_ARRANGEMENT, arrangement.size(), arrangement.data()));

     CHECK_CUBLASDX(cublasdxSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));

@@ -3236,12 +3751,12 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
         res = false;
     }

-    CHECK_CUBLASDX(cublasdxDestroy(h));
+    CHECK_CUBLASDX(cublasdxDestroyDescriptor(h));

     return res;
 }

-bool cuda_compile_solver(const char* fatbin_output_path, const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int function, int precision, int fill_mode, int num_threads)
+bool cuda_compile_solver(const char* fatbin_output_path, const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int NRHS, int function, int side, int diag, int precision, int arrangement_A, int arrangement_B, int fill_mode, int num_threads)
 {

     CHECK_ANY(ltoir_output_path != nullptr);
@@ -3252,34 +3767,42 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int

     bool res = true;

-    cusolverHandle h { 0 };
-    CHECK_CUSOLVER(cusolverCreate(&h));
-    long long int size[2] = {M, N};
-    long long int block_dim[3] = {num_threads, 1, 1};
-    CHECK_CUSOLVER(cusolverSetOperatorInt64Array(h, cusolverOperatorType::CUSOLVER_OPERATOR_SIZE, 2, size));
-    CHECK_CUSOLVER(cusolverSetOperatorInt64Array(h, cusolverOperatorType::CUSOLVER_OPERATOR_BLOCK_DIM, 3, block_dim));
-    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_TYPE, cusolverType::CUSOLVER_TYPE_REAL));
-    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_API, cusolverApi::CUSOLVER_API_BLOCK_SMEM));
-    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_FUNCTION, (cusolverFunction)function));
-    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
-    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_PRECISION, (commondxPrecision)precision));
-    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_FILL_MODE, (cusolverFillMode)fill_mode));
-    CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_SM, (long long)(arch * 10)));
+    cusolverdxDescriptor h { 0 };
+    CHECK_CUSOLVER(cusolverdxCreateDescriptor(&h));
+    std::array<long long int, 3> size = {M, N, NRHS};
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64s(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_SIZE, size.size(), size.data()));
+    std::array<long long int, 3> block_dim = {num_threads, 1, 1};
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64s(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_BLOCK_DIM, block_dim.size(), block_dim.data()));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_TYPE, cusolverdxType::CUSOLVERDX_TYPE_REAL));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_API, cusolverdxApi::CUSOLVERDX_API_SMEM));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_FUNCTION, (cusolverdxFunction)function));
+    if (side >= 0) {
+        CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_SIDE, (cusolverdxSide)side));
+    }
+    if (diag >= 0) {
+        CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_DIAG, (cusolverdxDiag)diag));
+    }
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_PRECISION, (commondxPrecision)precision));
+    std::array<long long int, 2> arrangement = {arrangement_A, arrangement_B};
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64s(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_ARRANGEMENT, arrangement.size(), arrangement.data()));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_FILL_MODE, (cusolverdxFillMode)fill_mode));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_SM, (long long)(arch * 10)));

-    CHECK_CUSOLVER(cusolverSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
+    CHECK_CUSOLVER(cusolverdxSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));

     size_t lto_size = 0;
-    CHECK_CUSOLVER(cusolverGetLTOIRSize(h, &lto_size));
+    CHECK_CUSOLVER(cusolverdxGetLTOIRSize(h, &lto_size));

     std::vector<char> lto(lto_size);
-    CHECK_CUSOLVER(cusolverGetLTOIR(h, lto.size(), lto.data()));
+    CHECK_CUSOLVER(cusolverdxGetLTOIR(h, lto.size(), lto.data()));

     // This fatbin is universal, ie it is the same for any instantiations of a cusolver device function
     size_t fatbin_size = 0;
-    CHECK_CUSOLVER(cusolverGetUniversalFATBINSize(h, &fatbin_size));
+    CHECK_CUSOLVER(cusolverdxGetUniversalFATBINSize(h, &fatbin_size));

     std::vector<char> fatbin(fatbin_size);
-    CHECK_CUSOLVER(cusolverGetUniversalFATBIN(h, fatbin.size(), fatbin.data()));
+    CHECK_CUSOLVER(cusolverdxGetUniversalFATBIN(h, fatbin.size(), fatbin.data()));

     if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
         res = false;
@@ -3289,7 +3812,7 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
         res = false;
     }

-    CHECK_CUSOLVER(cusolverDestroy(h));
+    CHECK_CUSOLVER(cusolverdxDestroyDescriptor(h));

     return res;
 }
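Across the cuFFTDx, cuBLASDx, and cuSOLVERDx hunks the rename follows one pattern from libmathdx: Handle types become Descriptor types, Create/Destroy become CreateDescriptor/DestroyDescriptor, SetOperatorInt64Array becomes SetOperatorInt64s, and the cusolver entry points gain the dx infix. A condensed sketch of the shared descriptor lifecycle, restricted to calls that appear in the hunks above (CHECK_* error macros and most operators elided):

    cusolverdxDescriptor h { 0 };
    cusolverdxCreateDescriptor(&h);

    // describe the problem through operators ...
    std::array<long long int, 3> size = {M, N, NRHS};
    cusolverdxSetOperatorInt64s(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_SIZE, size.size(), size.data());
    cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_SM, (long long)(arch * 10));

    // ... then extract the LTOIR blob that gets linked into the Warp module
    size_t lto_size = 0;
    cusolverdxGetLTOIRSize(h, &lto_size);
    std::vector<char> lto(lto_size);
    cusolverdxGetLTOIR(h, lto.size(), lto.data());

    cusolverdxDestroyDescriptor(h);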