warp-lang 1.7.2-py3-none-macosx_10_13_universal2.whl → 1.8.0-py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +125 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +257 -101
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +657 -223
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +97 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +107 -52
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +12 -17
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +99 -0
- warp/native/builtin.h +174 -31
- warp/native/coloring.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +3 -3
- warp/native/mat.h +5 -10
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/quat.h +28 -4
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/tile.h +583 -72
- warp/native/tile_radix_sort.h +1108 -0
- warp/native/tile_reduce.h +237 -2
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +6 -16
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +574 -51
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +58 -29
- warp/render/render_usd.py +124 -61
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +252 -78
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +751 -320
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +52 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +15 -1
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_atomic_cas.py +299 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +1 -24
- warp/tests/test_quat.py +6 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +51 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/tile/test_tile.py +420 -1
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_reduce.py +213 -0
- warp/tests/tile/test_tile_shared_memory.py +130 -1
- warp/tests/tile/test_tile_sort.py +117 -0
- warp/tests/unittest_suites.py +4 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/native/warp.cu
CHANGED

```diff
@@ -27,6 +27,9 @@
 #if WP_ENABLE_MATHDX
 #include <nvJitLink.h>
 #include <libmathdx.h>
+#include <libcublasdx.h>
+#include <libcufftdx.h>
+#include <libcusolverdx.h>
 #endif
 
 #include <array>
```
```diff
@@ -155,6 +158,7 @@ struct DeviceInfo
     int arch = 0;
     int is_uva = 0;
     int is_mempool_supported = 0;
+    int sm_count = 0;
     int is_ipc_supported = -1;
     int max_smem_bytes = 0;
     CUcontext primary_context = NULL;
```
```diff
@@ -166,6 +170,9 @@ struct ContextInfo
 
     // the current stream, managed from Python (see cuda_context_set_stream() and cuda_context_get_stream())
     CUstream stream = NULL;
+
+    // conditional graph node support, loaded on demand if the driver supports it (CUDA 12.4+)
+    CUmodule conditional_module = NULL;
 };
 
 struct CaptureInfo
```
```diff
@@ -280,6 +287,7 @@ int cuda_init()
         check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_device_id, CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, device));
         check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_uva, CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, device));
         check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_mempool_supported, CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, device));
+        check_cu(cuDeviceGetAttribute_f(&g_devices[i].sm_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
 #ifdef CUDA_VERSION
 #if CUDA_VERSION >= 12000
         int device_attribute_integrated = 0;
```
```diff
@@ -1786,6 +1794,13 @@ int cuda_device_get_arch(int ordinal)
     return 0;
 }
 
+int cuda_device_get_sm_count(int ordinal)
+{
+    if (ordinal >= 0 && ordinal < int(g_devices.size()))
+        return g_devices[ordinal].sm_count;
+    return 0;
+}
+
 void cuda_device_get_uuid(int ordinal, char uuid[16])
 {
     memcpy(uuid, g_devices[ordinal].uuid.bytes, sizeof(char)*16);
```
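The new `sm_count` plumbing caches a standard driver attribute at init time so `cuda_device_get_sm_count()` never has to touch the driver. A minimal standalone sketch of the same query (plain CUDA driver API, not Warp code):

```cpp
#include <cstdio>
#include <cuda.h>

// Standalone sketch of the attribute query that cuda_init() now caches per
// device and cuda_device_get_sm_count() serves from DeviceInfo.
int main()
{
    cuInit(0);

    CUdevice dev;
    cuDeviceGet(&dev, 0);

    int sm_count = 0;
    cuDeviceGetAttribute(&sm_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, dev);
    printf("SMs on device 0: %d\n", sm_count);
    return 0;
}
```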
```diff
@@ -2034,6 +2049,9 @@ void cuda_context_destroy(void* context)
         if (info->stream)
             check_cu(cuStreamDestroy_f(info->stream));
 
+        if (info->conditional_module)
+            check_cu(cuModuleUnload_f(info->conditional_module));
+
         g_contexts.erase(ctx);
     }
 
```
```diff
@@ -2739,22 +2757,10 @@ bool cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
     if (external)
         return true;
 
-    cudaGraphExec_t graph_exec = NULL;
-
     // end the capture
     if (!check_cuda(cudaStreamEndCapture(cuda_stream, &graph)))
         return false;
 
-    // enable to create debug GraphVis visualization of graph
-    // cudaGraphDebugDotPrint(graph, "graph.dot", cudaGraphDebugDotFlagsVerbose);
-
-    // can use after CUDA 11.4 to permit graphs to capture cudaMallocAsync() operations
-    if (!check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch)))
-        return false;
-
-    // free source graph
-    check_cuda(cudaGraphDestroy(graph));
-
     // process deferred free list if no more captures are ongoing
     if (g_captures.empty())
     {
```
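This hunk stops fusing capture-end with instantiation: `cuda_graph_end_capture()` now hands back the raw `cudaGraph_t`, while instantiation and destruction become separate entry points (`cuda_graph_create_exec()`, `cuda_graph_destroy()`, and `cuda_graph_exec_destroy()` in the hunks below). A minimal runtime-API sketch of the resulting two-step flow, assuming `stream` is already mid-capture (illustrative only, not Warp code):

```cpp
#include <cuda_runtime.h>

// `stream` is assumed to be in an active capture when this is called.
bool finish_capture(cudaStream_t stream)
{
    // Step 1: end capture and keep the bare graph around.
    cudaGraph_t graph = nullptr;
    if (cudaStreamEndCapture(stream, &graph) != cudaSuccess)
        return false;

    // Optional debug step the old fused path could not offer: dump a GraphViz
    // .dot file, which is what the new capture_debug_dot_print() wraps.
    cudaGraphDebugDotPrint(graph, "graph.dot", cudaGraphDebugDotFlagsVerbose);

    // Step 2: instantiate on demand, which is what cuda_graph_create_exec() wraps.
    cudaGraphExec_t exec = nullptr;
    if (cudaGraphInstantiateWithFlags(&exec, graph,
            cudaGraphInstantiateFlagAutoFreeOnLaunch) != cudaSuccess)
        return false;

    cudaGraphLaunch(exec, stream);
    cudaStreamSynchronize(stream);

    // Destruction is split the same way (cuda_graph_destroy / cuda_graph_exec_destroy).
    cudaGraphDestroy(graph);
    cudaGraphExecDestroy(exec);
    return true;
}
```

Keeping the source graph alive is what makes the conditional-node and child-graph insertion below possible, since those APIs need to splice subgraphs into a graph that has not yet been instantiated.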
```diff
@@ -2763,11 +2769,503 @@
     }
 
     if (graph_ret)
-        *graph_ret =
+        *graph_ret = graph;
+
+    return true;
+}
+
+bool capture_debug_dot_print(void* graph, const char *path, uint32_t flags)
+{
+    if (!check_cuda(cudaGraphDebugDotPrint((cudaGraph_t)graph, path, flags)))
+        return false;
+    return true;
+}
+
+bool cuda_graph_create_exec(void* context, void* graph, void** graph_exec_ret)
+{
+    ContextGuard guard(context);
+
+    cudaGraphExec_t graph_exec = NULL;
+    if (!check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, (cudaGraph_t)graph, cudaGraphInstantiateFlagAutoFreeOnLaunch)))
+        return false;
+
+    if (graph_exec_ret)
+        *graph_exec_ret = graph_exec;
+
+    return true;
+}
+
+// Support for conditional graph nodes available with CUDA 12.4+.
+#if CUDA_VERSION >= 12040
+
+// CUBIN data for compiled conditional modules, loaded on demand, keyed on device architecture
+static std::map<int, void*> g_conditional_cubins;
+
+// Compile module with conditional helper kernels
+static void* compile_conditional_module(int arch)
+{
+    static const char* kernel_source = R"(
+typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
+extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);
+
+extern "C" __global__ void set_conditional_if_handle_kernel(cudaGraphConditionalHandle handle, int* value)
+{
+    if (threadIdx.x + blockIdx.x * blockDim.x == 0)
+        cudaGraphSetConditional(handle, *value);
+}
+
+extern "C" __global__ void set_conditional_else_handle_kernel(cudaGraphConditionalHandle handle, int* value)
+{
+    if (threadIdx.x + blockIdx.x * blockDim.x == 0)
+        cudaGraphSetConditional(handle, !*value);
+}
+
+extern "C" __global__ void set_conditional_if_else_handles_kernel(cudaGraphConditionalHandle if_handle, cudaGraphConditionalHandle else_handle, int* value)
+{
+    if (threadIdx.x + blockIdx.x * blockDim.x == 0)
+    {
+        cudaGraphSetConditional(if_handle, *value);
+        cudaGraphSetConditional(else_handle, !*value);
+    }
+}
+)";
+
+    // avoid recompilation
+    auto it = g_conditional_cubins.find(arch);
+    if (it != g_conditional_cubins.end())
+        return it->second;
+
+    nvrtcProgram prog;
+    if (!check_nvrtc(nvrtcCreateProgram(&prog, kernel_source, "conditional_kernels", 0, NULL, NULL)))
+        return NULL;
+
+    char arch_opt[128];
+    snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=sm_%d", arch);
+
+    std::vector<const char*> opts;
+    opts.push_back(arch_opt);
+
+    if (!check_nvrtc(nvrtcCompileProgram(prog, int(opts.size()), opts.data())))
+    {
+        size_t log_size;
+        if (check_nvrtc(nvrtcGetProgramLogSize(prog, &log_size)))
+        {
+            std::vector<char> log(log_size);
+            if (check_nvrtc(nvrtcGetProgramLog(prog, log.data())))
+                fprintf(stderr, "%s", log.data());
+        }
+        nvrtcDestroyProgram(&prog);
+        return NULL;
+    }
+
+    // get output
+    char* output = NULL;
+    size_t output_size = 0;
+    check_nvrtc(nvrtcGetCUBINSize(prog, &output_size));
+    if (output_size > 0)
+    {
+        output = new char[output_size];
+        if (check_nvrtc(nvrtcGetCUBIN(prog, output)))
+            g_conditional_cubins[arch] = output;
+    }
+
+    nvrtcDestroyProgram(&prog);
+
+    // return CUBIN data
+    return output;
+}
+
+
+// Load module with conditional helper kernels
+static CUmodule load_conditional_module(void* context)
+{
+    ContextInfo* context_info = get_context_info(context);
+    if (!context_info)
+        return NULL;
+
+    // check if already loaded
+    if (context_info->conditional_module)
+        return context_info->conditional_module;
+
+    int arch = context_info->device_info->arch;
+
+    // compile if needed
+    void* compiled_module = compile_conditional_module(arch);
+    if (!compiled_module)
+    {
+        fprintf(stderr, "Warp error: Failed to compile conditional kernels\n");
+        return NULL;
+    }
+
+    // load module
+    CUmodule module = NULL;
+    if (!check_cu(cuModuleLoadDataEx_f(&module, compiled_module, 0, NULL, NULL)))
+    {
+        fprintf(stderr, "Warp error: Failed to load conditional kernels module\n");
+        return NULL;
+    }
+
+    context_info->conditional_module = module;
+
+    return module;
+}
+
+static CUfunction get_conditional_kernel(void* context, const char* name)
+{
+    // load module if needed
+    CUmodule module = load_conditional_module(context);
+    if (!module)
+        return NULL;
+
+    CUfunction kernel;
+    if (!check_cu(cuModuleGetFunction_f(&kernel, module, name)))
+    {
+        fprintf(stderr, "Warp error: Failed to get kernel %s\n", name);
+        return NULL;
+    }
+
+    return kernel;
+}
+
+bool cuda_graph_pause_capture(void* context, void* stream, void** graph_ret)
+{
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+    if (!check_cuda(cudaStreamEndCapture(cuda_stream, (cudaGraph_t*)graph_ret)))
+        return false;
+    return true;
+}
+
+bool cuda_graph_resume_capture(void* context, void* stream, void* graph)
+{
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+    cudaGraph_t cuda_graph = static_cast<cudaGraph_t>(graph);
+
+    std::vector<cudaGraphNode_t> leaf_nodes;
+    if (!get_graph_leaf_nodes(cuda_graph, leaf_nodes))
+        return false;
+
+    if (!check_cuda(cudaStreamBeginCaptureToGraph(cuda_stream,
+                                                  cuda_graph,
+                                                  leaf_nodes.data(),
+                                                  nullptr,
+                                                  leaf_nodes.size(),
+                                                  cudaStreamCaptureModeGlobal)))
+        return false;
+
+    return true;
+}
+
+// https://developer.nvidia.com/blog/constructing-cuda-graphs-with-dynamic-parameters/#combined_approach
+// https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes/
+// condition is a gpu pointer
+// if_graph_ret and else_graph_ret should be NULL if not needed
+bool cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret)
+{
+    bool has_if = if_graph_ret != NULL;
+    bool has_else = else_graph_ret != NULL;
+    int num_branches = int(has_if) + int(has_else);
+
+    // if neither the IF nor ELSE branches are required, it's a no-op
+    if (num_branches == 0)
+        return true;
+
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+
+    // Get the current stream capturing graph
+    cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;
+    cudaGraph_t cuda_graph = NULL;
+    const cudaGraphNode_t* capture_deps = NULL;
+    size_t dep_count = 0;
+    if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+        return false;
+
+    // abort if not capturing
+    if (!cuda_graph || capture_status != cudaStreamCaptureStatusActive)
+    {
+        wp::set_error_string("Stream is not capturing");
+        return false;
+    }
+
+    //int driver_version = cuda_driver_version();
+
+    // IF-ELSE nodes are only supported with CUDA 12.8+
+    // Somehow child graphs produce wrong results when an else branch is used
+    // Seems to be a bug in the CUDA driver: https://nvbugs/5241330
+    if (num_branches == 1 /*|| driver_version >= 12080*/)
+    {
+        cudaGraphConditionalHandle handle;
+        cudaGraphConditionalHandleCreate(&handle, cuda_graph);
+
+        // run a kernel to set the condition handle from the condition pointer
+        // (need to negate the condition if only the else branch is used)
+        CUfunction kernel;
+        if (has_if)
+            kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+        else
+            kernel = get_conditional_kernel(context, "set_conditional_else_handle_kernel");
+
+        if (!kernel)
+        {
+            wp::set_error_string("Failed to get built-in conditional kernel");
+            return false;
+        }
+
+        void* kernel_args[2];
+        kernel_args[0] = &handle;
+        kernel_args[1] = &condition;
+
+        if (!check_cuda(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
+            return false;
+
+        if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+            return false;
+
+        // create conditional node
+        cudaGraphNode_t condition_node;
+        cudaGraphNodeParams condition_params = { cudaGraphNodeTypeConditional };
+        condition_params.conditional.handle = handle;
+        condition_params.conditional.type = cudaGraphCondTypeIf;
+        condition_params.conditional.size = num_branches;
+        if (!check_cuda(cudaGraphAddNode(&condition_node, cuda_graph, capture_deps, dep_count, &condition_params)))
+            return false;
+
+        if (!check_cuda(cudaStreamUpdateCaptureDependencies(cuda_stream, &condition_node, 1, cudaStreamSetCaptureDependencies)))
+            return false;
+
+        if (num_branches == 1)
+        {
+            if (has_if)
+                *if_graph_ret = condition_params.conditional.phGraph_out[0];
+            else
+                *else_graph_ret = condition_params.conditional.phGraph_out[0];
+        }
+        else
+        {
+            *if_graph_ret = condition_params.conditional.phGraph_out[0];
+            *else_graph_ret = condition_params.conditional.phGraph_out[1];
+        }
+    }
+    else
+    {
+        // Create IF node followed by an additional IF node with negated condition
+        cudaGraphConditionalHandle if_handle, else_handle;
+        cudaGraphConditionalHandleCreate(&if_handle, cuda_graph);
+        cudaGraphConditionalHandleCreate(&else_handle, cuda_graph);
+
+        CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_else_handles_kernel");
+        if (!kernel)
+        {
+            wp::set_error_string("Failed to get built-in conditional kernel");
+            return false;
+        }
+
+        void* kernel_args[3];
+        kernel_args[0] = &if_handle;
+        kernel_args[1] = &else_handle;
+        kernel_args[2] = &condition;
+
+        if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
+            return false;
+
+        if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+            return false;
+
+        cudaGraphNode_t if_node;
+        cudaGraphNodeParams if_params = { cudaGraphNodeTypeConditional };
+        if_params.conditional.handle = if_handle;
+        if_params.conditional.type = cudaGraphCondTypeIf;
+        if_params.conditional.size = 1;
+        if (!check_cuda(cudaGraphAddNode(&if_node, cuda_graph, capture_deps, dep_count, &if_params)))
+            return false;
+
+        cudaGraphNode_t else_node;
+        cudaGraphNodeParams else_params = { cudaGraphNodeTypeConditional };
+        else_params.conditional.handle = else_handle;
+        else_params.conditional.type = cudaGraphCondTypeIf;
+        else_params.conditional.size = 1;
+        if (!check_cuda(cudaGraphAddNode(&else_node, cuda_graph, &if_node, 1, &else_params)))
+            return false;
+
+        if (!check_cuda(cudaStreamUpdateCaptureDependencies(cuda_stream, &else_node, 1, cudaStreamSetCaptureDependencies)))
+            return false;
+
+        *if_graph_ret = if_params.conditional.phGraph_out[0];
+        *else_graph_ret = else_params.conditional.phGraph_out[0];
+    }
+
+    return true;
+}
+
+bool cuda_graph_insert_child_graph(void* context, void* stream, void* child_graph)
+{
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+
+    // Get the current stream capturing graph
+    cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;
+    void* cuda_graph = NULL;
+    const cudaGraphNode_t* capture_deps = NULL;
+    size_t dep_count = 0;
+    if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, (cudaGraph_t*)&cuda_graph, &capture_deps, &dep_count)))
+        return false;
+
+    if (!cuda_graph_pause_capture(context, cuda_stream, &cuda_graph))
+        return false;
+
+    cudaGraphNode_t body_node;
+    if (!check_cuda(cudaGraphAddChildGraphNode(&body_node,
+                                               static_cast<cudaGraph_t>(cuda_graph),
+                                               capture_deps, dep_count,
+                                               static_cast<cudaGraph_t>(child_graph))))
+        return false;
+
+    if (!cuda_graph_resume_capture(context, cuda_stream, cuda_graph))
+        return false;
+
+    if (!check_cuda(cudaStreamUpdateCaptureDependencies(cuda_stream, &body_node, 1, cudaStreamSetCaptureDependencies)))
+        return false;
 
     return true;
 }
 
+bool cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret)
+{
+    // if there's no body, it's a no-op
+    if (!body_graph_ret)
+        return true;
+
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+
+    // Get the current stream capturing graph
+    cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;
+    cudaGraph_t cuda_graph = NULL;
+    const cudaGraphNode_t* capture_deps = NULL;
+    size_t dep_count = 0;
+    if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+        return false;
+
+    // abort if not capturing
+    if (!cuda_graph || capture_status != cudaStreamCaptureStatusActive)
+    {
+        wp::set_error_string("Stream is not capturing");
+        return false;
+    }
+
+    cudaGraphConditionalHandle handle;
+    if (!check_cuda(cudaGraphConditionalHandleCreate(&handle, cuda_graph)))
+        return false;
+
+    // launch a kernel to set the condition handle from condition pointer
+    CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+    if (!kernel)
+    {
+        wp::set_error_string("Failed to get built-in conditional kernel");
+        return false;
+    }
+
+    void* kernel_args[2];
+    kernel_args[0] = &handle;
+    kernel_args[1] = &condition;
+
+    if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
+        return false;
+
+    if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+        return false;
+
+    // insert conditional graph node
+    cudaGraphNode_t while_node;
+    cudaGraphNodeParams while_params = { cudaGraphNodeTypeConditional };
+    while_params.conditional.handle = handle;
+    while_params.conditional.type = cudaGraphCondTypeWhile;
+    while_params.conditional.size = 1;
+    if (!check_cuda(cudaGraphAddNode(&while_node, cuda_graph, capture_deps, dep_count, &while_params)))
+        return false;
+
+    if (!check_cuda(cudaStreamUpdateCaptureDependencies(cuda_stream, &while_node, 1, cudaStreamSetCaptureDependencies)))
+        return false;
+
+    *body_graph_ret = while_params.conditional.phGraph_out[0];
+    *handle_ret = handle;
+
+    return true;
+}
+
+bool cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle)
+{
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+
+    // launch a kernel to set the condition handle from condition pointer
+    CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+    if (!kernel)
+    {
+        wp::set_error_string("Failed to get built-in conditional kernel");
+        return false;
+    }
+
+    void* kernel_args[2];
+    kernel_args[0] = &handle;
+    kernel_args[1] = &condition;
+
+    if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
+        return false;
+
+    return true;
+}
+
+#else
+// stubs for conditional graph node API if CUDA toolkit is too old.
+
+bool cuda_graph_pause_capture(void* context, void* stream, void** graph_ret)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool cuda_graph_resume_capture(void* context, void* stream, void* graph)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool cuda_graph_insert_child_graph(void* context, void* stream, void* child_graph)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+#endif // support for conditional graph nodes
+
+
 bool cuda_graph_launch(void* graph_exec, void* stream)
 {
     // TODO: allow naming graphs?
```
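The block above wraps the conditional graph node API introduced in CUDA 12.4: create a `cudaGraphConditionalHandle` on the capturing graph, launch a tiny NVRTC-built kernel that publishes the runtime condition via `cudaGraphSetConditional()`, then add a `cudaGraphNodeTypeConditional` node whose `phGraph_out` body graphs the caller fills in. The same pattern in a self-contained sketch that builds the graph explicitly rather than through stream capture (CUDA 12.4+, compile with nvcc; illustrative only, not Warp code; a WHILE node differs only in `cudaGraphCondTypeWhile`, with the body re-executing until the handle is set to 0):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Device code: copy the runtime condition into the conditional handle,
// the same job as the NVRTC-compiled set_conditional_if_handle_kernel above.
__global__ void eval_condition(cudaGraphConditionalHandle handle, const int* flag)
{
    cudaGraphSetConditional(handle, *flag);
}

__global__ void if_body()
{
    printf("branch taken\n");
}

int main()
{
    int* d_flag;
    cudaMalloc(&d_flag, sizeof(int));
    const int one = 1;
    cudaMemcpy(d_flag, &one, sizeof(int), cudaMemcpyHostToDevice);

    cudaGraph_t graph;
    cudaGraphCreate(&graph, 0);

    // The handle lives on the graph; it resets to the default value 0 at each launch.
    cudaGraphConditionalHandle handle;
    cudaGraphConditionalHandleCreate(&handle, graph, 0, cudaGraphCondAssignDefault);

    // Kernel node that evaluates the condition on every launch.
    void* eval_args[] = { &handle, &d_flag };
    cudaKernelNodeParams eval_params = {};
    eval_params.func = (void*)eval_condition;
    eval_params.gridDim = dim3(1);
    eval_params.blockDim = dim3(1);
    eval_params.kernelParams = eval_args;
    cudaGraphNode_t eval_node;
    cudaGraphAddKernelNode(&eval_node, graph, nullptr, 0, &eval_params);

    // IF node that runs its body graph when the handle is non-zero.
    cudaGraphNodeParams cond_params = { cudaGraphNodeTypeConditional };
    cond_params.conditional.handle = handle;
    cond_params.conditional.type = cudaGraphCondTypeIf;
    cond_params.conditional.size = 1;
    cudaGraphNode_t cond_node;
    cudaGraphAddNode(&cond_node, graph, &eval_node, 1, &cond_params);

    // Populate the body graph owned by the conditional node.
    cudaGraph_t body = cond_params.conditional.phGraph_out[0];
    cudaKernelNodeParams body_params = {};
    body_params.func = (void*)if_body;
    body_params.gridDim = dim3(1);
    body_params.blockDim = dim3(1);
    cudaGraphNode_t body_node;
    cudaGraphAddKernelNode(&body_node, body, nullptr, 0, &body_params);

    cudaGraphExec_t exec;
    cudaGraphInstantiate(&exec, graph, 0);
    cudaGraphLaunch(exec, 0);
    cudaDeviceSynchronize();

    cudaGraphExecDestroy(exec);
    cudaGraphDestroy(graph);
    cudaFree(d_flag);
    return 0;
}
```

The Warp code performs the same insertion during stream capture, which is why it re-queries `cudaStreamGetCaptureInfo()` after launching the condition kernel and reparents the capture with `cudaStreamUpdateCaptureDependencies()`.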
```diff
@@ -2780,7 +3278,14 @@ bool cuda_graph_launch(void* graph_exec, void* stream)
     return result;
 }
 
-bool cuda_graph_destroy(void* context, void*
+bool cuda_graph_destroy(void* context, void* graph)
+{
+    ContextGuard guard(context);
+
+    return check_cuda(cudaGraphDestroy((cudaGraph_t)graph));
+}
+
+bool cuda_graph_exec_destroy(void* context, void* graph_exec)
 {
     ContextGuard guard(context);
 
```
```diff
@@ -2832,7 +3337,7 @@ bool write_file(const char* data, size_t size, std::string filename, const char*
 }
 #endif
 
-size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes, int* ltoir_input_types)
+size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, bool compile_time_trace, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes, int* ltoir_input_types)
 {
     // use file extension to determine whether to output PTX or CUBIN
     const char* output_ext = strrchr(output_path, '.');
```
```diff
@@ -2919,11 +3424,11 @@
     else
         opts.push_back("--fmad=false");
 
-    std::vector<std::string>
+    std::vector<std::string> stored_options;
     for(int i = 0; i < num_cuda_include_dirs; i++)
     {
-
-        opts.push_back(
+        stored_options.push_back(std::string("--include-path=") + cuda_include_dirs[i]);
+        opts.push_back(stored_options.back().c_str());
     }
 
     opts.push_back("--device-as-default-execution-space");
```
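The `stored_options` change addresses a string-lifetime hazard: `opts` holds raw `const char*` pointers, so the `std::string` backing each `--include-path=` flag must outlive the eventual `nvrtcCompileProgram()` call; calling `c_str()` on a temporary would dangle. A distilled sketch of the pattern, using a hypothetical helper that is not Warp code:

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for the NVRTC option plumbing in cuda_compile_program().
std::vector<const char*> build_opts(const std::vector<std::string>& include_dirs,
                                    std::vector<std::string>& stored_options)
{
    std::vector<const char*> opts;

    // Reserve up front so later push_backs cannot reallocate and invalidate
    // c_str() pointers already handed to opts: short strings keep their bytes
    // inline, so a vector reallocation would move them.
    stored_options.reserve(include_dirs.size());

    for (const std::string& dir : include_dirs)
    {
        // BAD: ("--include-path=" + dir).c_str() points into a temporary that
        // dies at the end of the statement. Store the string first instead.
        stored_options.push_back("--include-path=" + dir);
        opts.push_back(stored_options.back().c_str());
    }
    return opts;  // valid while stored_options is alive and unmodified
}
```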
```diff
@@ -2936,6 +3441,16 @@
         opts.push_back("--relocatable-device-code=true");
     }
 
+    if (compile_time_trace)
+    {
+#if CUDA_VERSION >= 12080
+        stored_options.push_back(std::string("--fdevice-time-trace=") + std::string(output_path).append("_compile-time-trace.json"));
+        opts.push_back(stored_options.back().c_str());
+#else
+        fprintf(stderr, "Warp warning: CUDA version is less than 12.8, compile_time_trace is not supported\n");
+#endif
+    }
+
     nvrtcProgram prog;
     nvrtcResult res;
 
```
```diff
@@ -3162,11 +3677,11 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
     CHECK_ANY(num_include_dirs == 0);
 
     bool res = true;
-
-    CHECK_CUFFTDX(
+    cufftdxDescriptor h;
+    CHECK_CUFFTDX(cufftdxCreateDescriptor(&h));
 
-    //
-    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_API, cufftdxApi::
+    // CUFFTDX_API_LMEM means each thread starts with a subset of the data
+    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_API, cufftdxApi::CUFFTDX_API_LMEM));
     CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
     CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_SIZE, (long long)size));
     CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_DIRECTION, (cufftdxDirection)direction));
```
```diff
@@ -3191,7 +3706,7 @@
         res = false;
     }
 
-    CHECK_CUFFTDX(
+    CHECK_CUFFTDX(cufftdxDestroyDescriptor(h));
 
     return res;
 }
```
```diff
@@ -3207,22 +3722,22 @@
     CHECK_ANY(num_include_dirs == 0);
 
     bool res = true;
-
-    CHECK_CUBLASDX(
+    cublasdxDescriptor h;
+    CHECK_CUBLASDX(cublasdxCreateDescriptor(&h));
 
     CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_FUNCTION, cublasdxFunction::CUBLASDX_FUNCTION_MM));
     CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
-    CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_API, cublasdxApi::
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_API, cublasdxApi::CUBLASDX_API_SMEM));
     std::array<long long int, 3> precisions = {precision_A, precision_B, precision_C};
-    CHECK_CUBLASDX(
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_PRECISION, 3, precisions.data()));
     CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_SM, (long long)(arch * 10)));
     CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_TYPE, (cublasdxType)type));
     std::array<long long int, 3> block_dim = {num_threads, 1, 1};
-    CHECK_CUBLASDX(
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_BLOCK_DIM, block_dim.size(), block_dim.data()));
     std::array<long long int, 3> size = {M, N, K};
-    CHECK_CUBLASDX(
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_SIZE, size.size(), size.data()));
     std::array<long long int, 3> arrangement = {arrangement_A, arrangement_B, arrangement_C};
-    CHECK_CUBLASDX(
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_ARRANGEMENT, arrangement.size(), arrangement.data()));
 
     CHECK_CUBLASDX(cublasdxSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
 
```
```diff
@@ -3236,12 +3751,12 @@
         res = false;
     }
 
-    CHECK_CUBLASDX(
+    CHECK_CUBLASDX(cublasdxDestroyDescriptor(h));
 
     return res;
 }
 
-bool cuda_compile_solver(const char* fatbin_output_path, const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int function, int precision, int fill_mode, int num_threads)
+bool cuda_compile_solver(const char* fatbin_output_path, const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int NRHS, int function, int side, int diag, int precision, int arrangement_A, int arrangement_B, int fill_mode, int num_threads)
 {
 
     CHECK_ANY(ltoir_output_path != nullptr);
```
```diff
@@ -3252,34 +3767,42 @@
 
     bool res = true;
 
-
-    CHECK_CUSOLVER(
-    long long int size
-
-
-    CHECK_CUSOLVER(
-    CHECK_CUSOLVER(
-    CHECK_CUSOLVER(
-    CHECK_CUSOLVER(
-
-
-
-
+    cusolverdxDescriptor h { 0 };
+    CHECK_CUSOLVER(cusolverdxCreateDescriptor(&h));
+    std::array<long long int, 3> size = {M, N, NRHS};
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64s(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_SIZE, size.size(), size.data()));
+    std::array<long long int, 3> block_dim = {num_threads, 1, 1};
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64s(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_BLOCK_DIM, block_dim.size(), block_dim.data()));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_TYPE, cusolverdxType::CUSOLVERDX_TYPE_REAL));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_API, cusolverdxApi::CUSOLVERDX_API_SMEM));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_FUNCTION, (cusolverdxFunction)function));
+    if (side >= 0) {
+        CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_SIDE, (cusolverdxSide)side));
+    }
+    if (diag >= 0) {
+        CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_DIAG, (cusolverdxDiag)diag));
+    }
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_PRECISION, (commondxPrecision)precision));
+    std::array<long long int, 2> arrangement = {arrangement_A, arrangement_B};
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64s(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_ARRANGEMENT, arrangement.size(), arrangement.data()));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_FILL_MODE, (cusolverdxFillMode)fill_mode));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_SM, (long long)(arch * 10)));
 
-    CHECK_CUSOLVER(
+    CHECK_CUSOLVER(cusolverdxSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
 
     size_t lto_size = 0;
-    CHECK_CUSOLVER(
+    CHECK_CUSOLVER(cusolverdxGetLTOIRSize(h, &lto_size));
 
     std::vector<char> lto(lto_size);
-    CHECK_CUSOLVER(
+    CHECK_CUSOLVER(cusolverdxGetLTOIR(h, lto.size(), lto.data()));
 
     // This fatbin is universal, ie it is the same for any instantiations of a cusolver device function
     size_t fatbin_size = 0;
-    CHECK_CUSOLVER(
+    CHECK_CUSOLVER(cusolverdxGetUniversalFATBINSize(h, &fatbin_size));
 
     std::vector<char> fatbin(fatbin_size);
-    CHECK_CUSOLVER(
+    CHECK_CUSOLVER(cusolverdxGetUniversalFATBIN(h, fatbin.size(), fatbin.data()));
 
     if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
         res = false;
```
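The solver entry point gains `NRHS`, `side`, `diag`, and per-matrix `arrangement` parameters, with negative `side`/`diag` meaning the optional operator is simply not set (see the `if (side >= 0)` guards above). A hypothetical invocation, with every numeric value a placeholder rather than a real libcusolverdx enum:

```cpp
// Illustrative only: the parameter order comes from the new signature; the
// numeric enum values and the include path are placeholders, not taken from
// this diff or from libcusolverdx headers.
bool ok = cuda_compile_solver(
    "solver.fatbin",            // fatbin_output_path
    "solver.ltoir",             // ltoir_output_path
    "solver_symbol",            // symbol_name
    0, nullptr,                 // num_include_dirs, include_dirs
    "/path/to/mathdx/include",  // mathdx_include_dir (assumed)
    80,                         // arch: sm_80, so the SM operator becomes 800
    32, 32, 1,                  // M, N, NRHS
    0,                          // function (placeholder cusolverdxFunction value)
    -1, -1,                     // side, diag: negative means operator not set
    0,                          // precision (placeholder commondxPrecision value)
    0, 0,                       // arrangement_A, arrangement_B (placeholders)
    0,                          // fill_mode (placeholder cusolverdxFillMode value)
    256);                       // num_threads: block_dim becomes {256, 1, 1}
```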
```diff
@@ -3289,7 +3812,7 @@
         res = false;
     }
 
-    CHECK_CUSOLVER(
+    CHECK_CUSOLVER(cusolverdxDestroyDescriptor(h));
 
     return res;
 }
```