warp-lang 1.7.2rc1-py3-none-win_amd64.whl → 1.8.1-py3-none-win_amd64.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +130 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +272 -104
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +770 -238
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +99 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +181 -95
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +210 -67
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +103 -4
- warp/native/builtin.h +182 -35
- warp/native/coloring.cpp +6 -2
- warp/native/cuda_util.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +5 -5
- warp/native/mat.h +8 -13
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/mesh.h +1 -1
- warp/native/quat.h +34 -6
- warp/native/rand.h +7 -7
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/svd.h +23 -8
- warp/native/tile.h +603 -73
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +239 -13
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +10 -20
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +588 -52
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +110 -80
- warp/render/render_usd.py +124 -62
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +253 -80
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +761 -322
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +54 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +91 -2
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +23 -24
- warp/tests/test_quat.py +28 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +83 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_static.py +48 -0
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tape.py +38 -0
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/test_vec.py +38 -408
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +438 -131
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_reduce.py +307 -5
- warp/tests/tile/test_tile_shared_memory.py +136 -7
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/unittest_suites.py +14 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/native/warp.cu
CHANGED
@@ -27,6 +27,9 @@
 #if WP_ENABLE_MATHDX
 #include <nvJitLink.h>
 #include <libmathdx.h>
+#include <libcublasdx.h>
+#include <libcufftdx.h>
+#include <libcusolverdx.h>
 #endif
 
 #include <array>
@@ -155,6 +158,7 @@ struct DeviceInfo
     int arch = 0;
     int is_uva = 0;
     int is_mempool_supported = 0;
+    int sm_count = 0;
     int is_ipc_supported = -1;
     int max_smem_bytes = 0;
     CUcontext primary_context = NULL;
@@ -166,6 +170,9 @@ struct ContextInfo
 
     // the current stream, managed from Python (see cuda_context_set_stream() and cuda_context_get_stream())
     CUstream stream = NULL;
+
+    // conditional graph node support, loaded on demand if the driver supports it (CUDA 12.4+)
+    CUmodule conditional_module = NULL;
 };
 
 struct CaptureInfo
@@ -280,6 +287,7 @@ int cuda_init()
         check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_device_id, CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, device));
         check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_uva, CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, device));
         check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_mempool_supported, CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, device));
+        check_cu(cuDeviceGetAttribute_f(&g_devices[i].sm_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
 #ifdef CUDA_VERSION
 #if CUDA_VERSION >= 12000
         int device_attribute_integrated = 0;
@@ -301,7 +309,13 @@ int cuda_init()
         check_cu(cuDeviceGetAttribute_f(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device));
         check_cu(cuDeviceGetAttribute_f(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device));
         g_devices[i].arch = 10 * major + minor;
-
+#ifdef CUDA_VERSION
+#if CUDA_VERSION < 13000
+        if (g_devices[i].arch == 110) {
+            g_devices[i].arch = 101; // Thor SM change
+        }
+#endif
+#endif
         g_device_map[device] = &g_devices[i];
     }
     else
@@ -1786,6 +1800,13 @@ int cuda_device_get_arch(int ordinal)
     return 0;
 }
 
+int cuda_device_get_sm_count(int ordinal)
+{
+    if (ordinal >= 0 && ordinal < int(g_devices.size()))
+        return g_devices[ordinal].sm_count;
+    return 0;
+}
+
 void cuda_device_get_uuid(int ordinal, char uuid[16])
 {
     memcpy(uuid, g_devices[ordinal].uuid.bytes, sizeof(char)*16);
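The new query mirrors cuda_device_get_arch() directly above it. As a hedged illustration of how an embedding layer might use it, a minimal caller-side sketch (the helper name and blocks_per_sm parameter below are hypothetical, not part of this diff):

// Illustrative only: size a grid from the SM count reported by the new query.
static int sketch_grid_size(int ordinal, int blocks_per_sm)
{
    int sm_count = cuda_device_get_sm_count(ordinal); // returns 0 for invalid ordinals
    return sm_count > 0 ? sm_count * blocks_per_sm : 0;
}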
@@ -2034,6 +2055,9 @@ void cuda_context_destroy(void* context)
         if (info->stream)
             check_cu(cuStreamDestroy_f(info->stream));
 
+        if (info->conditional_module)
+            check_cu(cuModuleUnload_f(info->conditional_module));
+
         g_contexts.erase(ctx);
     }
 
@@ -2739,22 +2763,10 @@ bool cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
     if (external)
         return true;
 
-    cudaGraphExec_t graph_exec = NULL;
-
     // end the capture
     if (!check_cuda(cudaStreamEndCapture(cuda_stream, &graph)))
         return false;
 
-    // enable to create debug GraphVis visualization of graph
-    // cudaGraphDebugDotPrint(graph, "graph.dot", cudaGraphDebugDotFlagsVerbose);
-
-    // can use after CUDA 11.4 to permit graphs to capture cudaMallocAsync() operations
-    if (!check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch)))
-        return false;
-
-    // free source graph
-    check_cuda(cudaGraphDestroy(graph));
-
     // process deferred free list if no more captures are ongoing
     if (g_captures.empty())
     {
@@ -2763,11 +2775,510 @@ bool cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
     }
 
     if (graph_ret)
-        *graph_ret = graph_exec;
+        *graph_ret = graph;
 
     return true;
 }
 
+bool capture_debug_dot_print(void* graph, const char *path, uint32_t flags)
+{
+    if (!check_cuda(cudaGraphDebugDotPrint((cudaGraph_t)graph, path, flags)))
+        return false;
+    return true;
+}
+
+bool cuda_graph_create_exec(void* context, void* stream, void* graph, void** graph_exec_ret)
+{
+    ContextGuard guard(context);
+
+    cudaGraphExec_t graph_exec = NULL;
+    if (!check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, (cudaGraph_t)graph, cudaGraphInstantiateFlagAutoFreeOnLaunch)))
+        return false;
+
+    // Usually uploading the graph explicitly is optional, but when updating graph nodes (e.g., indirect dispatch)
+    // then the upload is required because otherwise the graph nodes that get updated might not yet be uploaded, which
+    // results in undefined behavior.
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+    if (!check_cuda(cudaGraphUpload(graph_exec, cuda_stream)))
+        return false;
+
+    if (graph_exec_ret)
+        *graph_exec_ret = graph_exec;
+
+    return true;
+}
+
+// Support for conditional graph nodes available with CUDA 12.4+.
+#if CUDA_VERSION >= 12040
+
+// CUBIN data for compiled conditional modules, loaded on demand, keyed on device architecture
+static std::map<int, void*> g_conditional_cubins;
+
+// Compile module with conditional helper kernels
+static void* compile_conditional_module(int arch)
+{
+    static const char* kernel_source = R"(
+typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
+extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);
+
+extern "C" __global__ void set_conditional_if_handle_kernel(cudaGraphConditionalHandle handle, int* value)
+{
+    if (threadIdx.x + blockIdx.x * blockDim.x == 0)
+        cudaGraphSetConditional(handle, *value);
+}
+
+extern "C" __global__ void set_conditional_else_handle_kernel(cudaGraphConditionalHandle handle, int* value)
+{
+    if (threadIdx.x + blockIdx.x * blockDim.x == 0)
+        cudaGraphSetConditional(handle, !*value);
+}
+
+extern "C" __global__ void set_conditional_if_else_handles_kernel(cudaGraphConditionalHandle if_handle, cudaGraphConditionalHandle else_handle, int* value)
+{
+    if (threadIdx.x + blockIdx.x * blockDim.x == 0)
+    {
+        cudaGraphSetConditional(if_handle, *value);
+        cudaGraphSetConditional(else_handle, !*value);
+    }
+}
+)";
+
+    // avoid recompilation
+    auto it = g_conditional_cubins.find(arch);
+    if (it != g_conditional_cubins.end())
+        return it->second;
+
+    nvrtcProgram prog;
+    if (!check_nvrtc(nvrtcCreateProgram(&prog, kernel_source, "conditional_kernels", 0, NULL, NULL)))
+        return NULL;
+
+    char arch_opt[128];
+    snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=sm_%d", arch);
+
+    std::vector<const char*> opts;
+    opts.push_back(arch_opt);
+
+    if (!check_nvrtc(nvrtcCompileProgram(prog, int(opts.size()), opts.data())))
+    {
+        size_t log_size;
+        if (check_nvrtc(nvrtcGetProgramLogSize(prog, &log_size)))
+        {
+            std::vector<char> log(log_size);
+            if (check_nvrtc(nvrtcGetProgramLog(prog, log.data())))
+                fprintf(stderr, "%s", log.data());
+        }
+        nvrtcDestroyProgram(&prog);
+        return NULL;
+    }
+
+    // get output
+    char* output = NULL;
+    size_t output_size = 0;
+    check_nvrtc(nvrtcGetCUBINSize(prog, &output_size));
+    if (output_size > 0)
+    {
+        output = new char[output_size];
+        if (check_nvrtc(nvrtcGetCUBIN(prog, output)))
+            g_conditional_cubins[arch] = output;
+    }
+
+    nvrtcDestroyProgram(&prog);
+
+    // return CUBIN data
+    return output;
+}
+
+
+// Load module with conditional helper kernels
+static CUmodule load_conditional_module(void* context)
+{
+    ContextInfo* context_info = get_context_info(context);
+    if (!context_info)
+        return NULL;
+
+    // check if already loaded
+    if (context_info->conditional_module)
+        return context_info->conditional_module;
+
+    int arch = context_info->device_info->arch;
+
+    // compile if needed
+    void* compiled_module = compile_conditional_module(arch);
+    if (!compiled_module)
+    {
+        fprintf(stderr, "Warp error: Failed to compile conditional kernels\n");
+        return NULL;
+    }
+
+    // load module
+    CUmodule module = NULL;
+    if (!check_cu(cuModuleLoadDataEx_f(&module, compiled_module, 0, NULL, NULL)))
+    {
+        fprintf(stderr, "Warp error: Failed to load conditional kernels module\n");
+        return NULL;
+    }
+
+    context_info->conditional_module = module;
+
+    return module;
+}
+
+static CUfunction get_conditional_kernel(void* context, const char* name)
+{
+    // load module if needed
+    CUmodule module = load_conditional_module(context);
+    if (!module)
+        return NULL;
+
+    CUfunction kernel;
+    if (!check_cu(cuModuleGetFunction_f(&kernel, module, name)))
+    {
+        fprintf(stderr, "Warp error: Failed to get kernel %s\n", name);
+        return NULL;
+    }
+
+    return kernel;
+}
+
+bool cuda_graph_pause_capture(void* context, void* stream, void** graph_ret)
+{
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+    if (!check_cuda(cudaStreamEndCapture(cuda_stream, (cudaGraph_t*)graph_ret)))
+        return false;
+    return true;
+}
+
+bool cuda_graph_resume_capture(void* context, void* stream, void* graph)
+{
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+    cudaGraph_t cuda_graph = static_cast<cudaGraph_t>(graph);
+
+    std::vector<cudaGraphNode_t> leaf_nodes;
+    if (!get_graph_leaf_nodes(cuda_graph, leaf_nodes))
+        return false;
+
+    if (!check_cuda(cudaStreamBeginCaptureToGraph(cuda_stream,
+                                                  cuda_graph,
+                                                  leaf_nodes.data(),
+                                                  nullptr,
+                                                  leaf_nodes.size(),
+                                                  cudaStreamCaptureModeGlobal)))
+        return false;
+
+    return true;
+}
+
+// https://developer.nvidia.com/blog/constructing-cuda-graphs-with-dynamic-parameters/#combined_approach
+// https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes/
+// condition is a gpu pointer
+// if_graph_ret and else_graph_ret should be NULL if not needed
+bool cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret)
+{
+    bool has_if = if_graph_ret != NULL;
+    bool has_else = else_graph_ret != NULL;
+    int num_branches = int(has_if) + int(has_else);
+
+    // if neither the IF nor ELSE branches are required, it's a no-op
+    if (num_branches == 0)
+        return true;
+
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+
+    // Get the current stream capturing graph
+    cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;
+    cudaGraph_t cuda_graph = NULL;
+    const cudaGraphNode_t* capture_deps = NULL;
+    size_t dep_count = 0;
+    if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+        return false;
+
+    // abort if not capturing
+    if (!cuda_graph || capture_status != cudaStreamCaptureStatusActive)
+    {
+        wp::set_error_string("Stream is not capturing");
+        return false;
+    }
+
+    //int driver_version = cuda_driver_version();
+
+    // IF-ELSE nodes are only supported with CUDA 12.8+
+    // Somehow child graphs produce wrong results when an else branch is used
+    // Seems to be a bug in the CUDA driver: https://nvbugs/5241330
+    if (num_branches == 1 /*|| driver_version >= 12080*/)
+    {
+        cudaGraphConditionalHandle handle;
+        cudaGraphConditionalHandleCreate(&handle, cuda_graph);
+
+        // run a kernel to set the condition handle from the condition pointer
+        // (need to negate the condition if only the else branch is used)
+        CUfunction kernel;
+        if (has_if)
+            kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+        else
+            kernel = get_conditional_kernel(context, "set_conditional_else_handle_kernel");
+
+        if (!kernel)
+        {
+            wp::set_error_string("Failed to get built-in conditional kernel");
+            return false;
+        }
+
+        void* kernel_args[2];
+        kernel_args[0] = &handle;
+        kernel_args[1] = &condition;
+
+        if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
+            return false;
+
+        if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+            return false;
+
+        // create conditional node
+        cudaGraphNode_t condition_node;
+        cudaGraphNodeParams condition_params = { cudaGraphNodeTypeConditional };
+        condition_params.conditional.handle = handle;
+        condition_params.conditional.type = cudaGraphCondTypeIf;
+        condition_params.conditional.size = num_branches;
+        if (!check_cuda(cudaGraphAddNode(&condition_node, cuda_graph, capture_deps, dep_count, &condition_params)))
+            return false;
+
+        if (!check_cuda(cudaStreamUpdateCaptureDependencies(cuda_stream, &condition_node, 1, cudaStreamSetCaptureDependencies)))
+            return false;
+
+        if (num_branches == 1)
+        {
+            if (has_if)
+                *if_graph_ret = condition_params.conditional.phGraph_out[0];
+            else
+                *else_graph_ret = condition_params.conditional.phGraph_out[0];
+        }
+        else
+        {
+            *if_graph_ret = condition_params.conditional.phGraph_out[0];
+            *else_graph_ret = condition_params.conditional.phGraph_out[1];
+        }
+    }
+    else
+    {
+        // Create IF node followed by an additional IF node with negated condition
+        cudaGraphConditionalHandle if_handle, else_handle;
+        cudaGraphConditionalHandleCreate(&if_handle, cuda_graph);
+        cudaGraphConditionalHandleCreate(&else_handle, cuda_graph);
+
+        CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_else_handles_kernel");
+        if (!kernel)
+        {
+            wp::set_error_string("Failed to get built-in conditional kernel");
+            return false;
+        }
+
+        void* kernel_args[3];
+        kernel_args[0] = &if_handle;
+        kernel_args[1] = &else_handle;
+        kernel_args[2] = &condition;
+
+        if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
+            return false;
+
+        if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+            return false;
+
+        cudaGraphNode_t if_node;
+        cudaGraphNodeParams if_params = { cudaGraphNodeTypeConditional };
+        if_params.conditional.handle = if_handle;
+        if_params.conditional.type = cudaGraphCondTypeIf;
+        if_params.conditional.size = 1;
+        if (!check_cuda(cudaGraphAddNode(&if_node, cuda_graph, capture_deps, dep_count, &if_params)))
+            return false;
+
+        cudaGraphNode_t else_node;
+        cudaGraphNodeParams else_params = { cudaGraphNodeTypeConditional };
+        else_params.conditional.handle = else_handle;
+        else_params.conditional.type = cudaGraphCondTypeIf;
+        else_params.conditional.size = 1;
+        if (!check_cuda(cudaGraphAddNode(&else_node, cuda_graph, &if_node, 1, &else_params)))
+            return false;
+
+        if (!check_cuda(cudaStreamUpdateCaptureDependencies(cuda_stream, &else_node, 1, cudaStreamSetCaptureDependencies)))
+            return false;
+
+        *if_graph_ret = if_params.conditional.phGraph_out[0];
+        *else_graph_ret = else_params.conditional.phGraph_out[0];
+    }
+
+    return true;
+}
+
+bool cuda_graph_insert_child_graph(void* context, void* stream, void* child_graph)
+{
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+
+    // Get the current stream capturing graph
+    cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;
+    void* cuda_graph = NULL;
+    const cudaGraphNode_t* capture_deps = NULL;
+    size_t dep_count = 0;
+    if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, (cudaGraph_t*)&cuda_graph, &capture_deps, &dep_count)))
+        return false;
+
+    if (!cuda_graph_pause_capture(context, cuda_stream, &cuda_graph))
+        return false;
+
+    cudaGraphNode_t body_node;
+    if (!check_cuda(cudaGraphAddChildGraphNode(&body_node,
+                                               static_cast<cudaGraph_t>(cuda_graph),
+                                               capture_deps, dep_count,
+                                               static_cast<cudaGraph_t>(child_graph))))
+        return false;
+
+    if (!cuda_graph_resume_capture(context, cuda_stream, cuda_graph))
+        return false;
+
+    if (!check_cuda(cudaStreamUpdateCaptureDependencies(cuda_stream, &body_node, 1, cudaStreamSetCaptureDependencies)))
+        return false;
+
+    return true;
+}
+
+bool cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret)
+{
+    // if there's no body, it's a no-op
+    if (!body_graph_ret)
+        return true;
+
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+
+    // Get the current stream capturing graph
+    cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;
+    cudaGraph_t cuda_graph = NULL;
+    const cudaGraphNode_t* capture_deps = NULL;
+    size_t dep_count = 0;
+    if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+        return false;
+
+    // abort if not capturing
+    if (!cuda_graph || capture_status != cudaStreamCaptureStatusActive)
+    {
+        wp::set_error_string("Stream is not capturing");
+        return false;
+    }
+
+    cudaGraphConditionalHandle handle;
+    if (!check_cuda(cudaGraphConditionalHandleCreate(&handle, cuda_graph)))
+        return false;
+
+    // launch a kernel to set the condition handle from condition pointer
+    CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+    if (!kernel)
+    {
+        wp::set_error_string("Failed to get built-in conditional kernel");
+        return false;
+    }
+
+    void* kernel_args[2];
+    kernel_args[0] = &handle;
+    kernel_args[1] = &condition;
+
+    if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
+        return false;
+
+    if (!check_cuda(cudaStreamGetCaptureInfo(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
+        return false;
+
+    // insert conditional graph node
+    cudaGraphNode_t while_node;
+    cudaGraphNodeParams while_params = { cudaGraphNodeTypeConditional };
+    while_params.conditional.handle = handle;
+    while_params.conditional.type = cudaGraphCondTypeWhile;
+    while_params.conditional.size = 1;
+    if (!check_cuda(cudaGraphAddNode(&while_node, cuda_graph, capture_deps, dep_count, &while_params)))
+        return false;
+
+    if (!check_cuda(cudaStreamUpdateCaptureDependencies(cuda_stream, &while_node, 1, cudaStreamSetCaptureDependencies)))
+        return false;
+
+    *body_graph_ret = while_params.conditional.phGraph_out[0];
+    *handle_ret = handle;
+
+    return true;
+}
+
+bool cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle)
+{
+    ContextGuard guard(context);
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+
+    // launch a kernel to set the condition handle from condition pointer
+    CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+    if (!kernel)
+    {
+        wp::set_error_string("Failed to get built-in conditional kernel");
+        return false;
+    }
+
+    void* kernel_args[2];
+    kernel_args[0] = &handle;
+    kernel_args[1] = &condition;
+
+    if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
+        return false;
+
+    return true;
+}
+
+#else
+// stubs for conditional graph node API if CUDA toolkit is too old.
+
+bool cuda_graph_pause_capture(void* context, void* stream, void** graph_ret)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool cuda_graph_resume_capture(void* context, void* stream, void* graph)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+bool cuda_graph_insert_child_graph(void* context, void* stream, void* child_graph)
+{
+    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
+    return false;
+}
+
+#endif // support for conditional graph nodes
+
+
 bool cuda_graph_launch(void* graph_exec, void* stream)
 {
     // TODO: allow naming graphs?
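To make the control flow above concrete, here is an assumption-laden host-side sketch of how the pause/resume helpers combine with cuda_graph_insert_if_else() during stream capture. The function names are the ones defined in this file; the calling sequence is inferred from their semantics rather than taken from Warp's Python layer, which may differ in detail.

// Sketch (not part of the diff): record work that runs only when *condition != 0.
static bool capture_if_body_sketch(void* context, void* stream, int* condition)
{
    void* outer_graph = NULL; // the graph being captured, set aside while the body is recorded
    void* if_graph = NULL;    // body graph returned by the conditional node

    // while `stream` is actively capturing:
    if (!cuda_graph_insert_if_else(context, stream, condition, &if_graph, NULL))
        return false;

    // pause the outer capture and redirect capture into the IF body graph
    if (!cuda_graph_pause_capture(context, stream, &outer_graph))
        return false;
    if (!cuda_graph_resume_capture(context, stream, if_graph))
        return false;

    // ... launch kernels here; they become the conditional body ...

    // close the body and return to recording the outer graph
    if (!cuda_graph_pause_capture(context, stream, &if_graph))
        return false;
    return cuda_graph_resume_capture(context, stream, outer_graph);
}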
@@ -2780,7 +3291,14 @@ bool cuda_graph_launch(void* graph_exec, void* stream)
     return result;
 }
 
-bool cuda_graph_destroy(void* context, void* graph_exec)
+bool cuda_graph_destroy(void* context, void* graph)
+{
+    ContextGuard guard(context);
+
+    return check_cuda(cudaGraphDestroy((cudaGraph_t)graph));
+}
+
+bool cuda_graph_exec_destroy(void* context, void* graph_exec)
 {
     ContextGuard guard(context);
 
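Taken together with the cuda_graph_end_capture() change above (it now returns the raw source graph instead of an instantiated executable), the lifecycle becomes a two-object affair. A short sketch using only functions defined in this file, with error handling abbreviated:

static bool launch_once_sketch(void* context, void* stream)
{
    void* graph = NULL;
    void* graph_exec = NULL;

    // ... capture work on `stream`, then:
    if (!cuda_graph_end_capture(context, stream, &graph))  // now returns the source graph
        return false;

    bool ok = cuda_graph_create_exec(context, stream, graph, &graph_exec) // instantiate + upload
           && cuda_graph_launch(graph_exec, stream);

    cuda_graph_destroy(context, graph);                // free the source graph
    if (graph_exec)
        cuda_graph_exec_destroy(context, graph_exec);  // free the executable instance
    return ok;
}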
@@ -2832,7 +3350,7 @@ bool write_file(const char* data, size_t size, std::string filename, const char*
 }
 #endif
 
-size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes, int* ltoir_input_types)
+size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, bool compile_time_trace, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes, int* ltoir_input_types)
 {
     // use file extension to determine whether to output PTX or CUBIN
     const char* output_ext = strrchr(output_path, '.');
@@ -2919,11 +3437,11 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
     else
         opts.push_back("--fmad=false");
 
-    std::vector<std::string>
+    std::vector<std::string> stored_options;
     for(int i = 0; i < num_cuda_include_dirs; i++)
     {
-
-        opts.push_back(
+        stored_options.push_back(std::string("--include-path=") + cuda_include_dirs[i]);
+        opts.push_back(stored_options.back().c_str());
     }
 
     opts.push_back("--device-as-default-execution-space");
@@ -2936,6 +3454,16 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
         opts.push_back("--relocatable-device-code=true");
     }
 
+    if (compile_time_trace)
+    {
+#if CUDA_VERSION >= 12080
+        stored_options.push_back(std::string("--fdevice-time-trace=") + std::string(output_path).append("_compile-time-trace.json"));
+        opts.push_back(stored_options.back().c_str());
+#else
+        fprintf(stderr, "Warp warning: CUDA version is less than 12.8, compile_time_trace is not supported\n");
+#endif
+    }
+
     nvrtcProgram prog;
     nvrtcResult res;
 
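The effect of the new flag, sketched under the assumption of a hypothetical output_path of "gen/module.cubin": the branch above appends one NVRTC option, and NVRTC 12.8+ then writes a timing JSON next to the output artifact.

// Illustrative only, for output_path = "gen/module.cubin":
//   option pushed to NVRTC: "--fdevice-time-trace=gen/module.cubin_compile-time-trace.json"
// The resulting JSON appears to follow clang's -ftime-trace format and can be
// opened in a Chrome-trace-style viewer.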
@@ -3162,11 +3690,11 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
     CHECK_ANY(num_include_dirs == 0);
 
     bool res = true;
-
-    CHECK_CUFFTDX(
+    cufftdxDescriptor h;
+    CHECK_CUFFTDX(cufftdxCreateDescriptor(&h));
 
-    //
-    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_API, cufftdxApi::
+    // CUFFTDX_API_LMEM means each thread starts with a subset of the data
+    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_API, cufftdxApi::CUFFTDX_API_LMEM));
     CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
     CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_SIZE, (long long)size));
     CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_DIRECTION, (cufftdxDirection)direction));
@@ -3191,7 +3719,7 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
         res = false;
     }
 
-    CHECK_CUFFTDX(
+    CHECK_CUFFTDX(cufftdxDestroyDescriptor(h));
 
     return res;
 }
@@ -3207,22 +3735,22 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
     CHECK_ANY(num_include_dirs == 0);
 
     bool res = true;
-
-    CHECK_CUBLASDX(
+    cublasdxDescriptor h;
+    CHECK_CUBLASDX(cublasdxCreateDescriptor(&h));
 
     CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_FUNCTION, cublasdxFunction::CUBLASDX_FUNCTION_MM));
     CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
-    CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_API, cublasdxApi::
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_API, cublasdxApi::CUBLASDX_API_SMEM));
     std::array<long long int, 3> precisions = {precision_A, precision_B, precision_C};
-    CHECK_CUBLASDX(
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_PRECISION, 3, precisions.data()));
     CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_SM, (long long)(arch * 10)));
     CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_TYPE, (cublasdxType)type));
     std::array<long long int, 3> block_dim = {num_threads, 1, 1};
-    CHECK_CUBLASDX(
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_BLOCK_DIM, block_dim.size(), block_dim.data()));
     std::array<long long int, 3> size = {M, N, K};
-    CHECK_CUBLASDX(
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_SIZE, size.size(), size.data()));
     std::array<long long int, 3> arrangement = {arrangement_A, arrangement_B, arrangement_C};
-    CHECK_CUBLASDX(
+    CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_ARRANGEMENT, arrangement.size(), arrangement.data()));
 
     CHECK_CUBLASDX(cublasdxSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
 
@@ -3236,12 +3764,12 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
         res = false;
     }
 
-    CHECK_CUBLASDX(
+    CHECK_CUBLASDX(cublasdxDestroyDescriptor(h));
 
     return res;
 }
 
-bool cuda_compile_solver(const char* fatbin_output_path, const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int function, int precision, int fill_mode, int num_threads)
+bool cuda_compile_solver(const char* fatbin_output_path, const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int NRHS, int function, int side, int diag, int precision, int arrangement_A, int arrangement_B, int fill_mode, int num_threads)
 {
 
     CHECK_ANY(ltoir_output_path != nullptr);
@@ -3252,34 +3780,42 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
 
     bool res = true;
 
-
-    CHECK_CUSOLVER(
-    long long int size
-
-
-    CHECK_CUSOLVER(
-    CHECK_CUSOLVER(
-    CHECK_CUSOLVER(
-    CHECK_CUSOLVER(
-
-
-
-
+    cusolverdxDescriptor h { 0 };
+    CHECK_CUSOLVER(cusolverdxCreateDescriptor(&h));
+    std::array<long long int, 3> size = {M, N, NRHS};
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64s(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_SIZE, size.size(), size.data()));
+    std::array<long long int, 3> block_dim = {num_threads, 1, 1};
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64s(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_BLOCK_DIM, block_dim.size(), block_dim.data()));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_TYPE, cusolverdxType::CUSOLVERDX_TYPE_REAL));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_API, cusolverdxApi::CUSOLVERDX_API_SMEM));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_FUNCTION, (cusolverdxFunction)function));
+    if (side >= 0) {
+        CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_SIDE, (cusolverdxSide)side));
+    }
+    if (diag >= 0) {
+        CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_DIAG, (cusolverdxDiag)diag));
+    }
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_PRECISION, (commondxPrecision)precision));
+    std::array<long long int, 2> arrangement = {arrangement_A, arrangement_B};
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64s(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_ARRANGEMENT, arrangement.size(), arrangement.data()));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_FILL_MODE, (cusolverdxFillMode)fill_mode));
+    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_SM, (long long)(arch * 10)));
 
-    CHECK_CUSOLVER(
+    CHECK_CUSOLVER(cusolverdxSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
 
     size_t lto_size = 0;
-    CHECK_CUSOLVER(
+    CHECK_CUSOLVER(cusolverdxGetLTOIRSize(h, &lto_size));
 
     std::vector<char> lto(lto_size);
-    CHECK_CUSOLVER(
+    CHECK_CUSOLVER(cusolverdxGetLTOIR(h, lto.size(), lto.data()));
 
     // This fatbin is universal, ie it is the same for any instantiations of a cusolver device function
     size_t fatbin_size = 0;
-    CHECK_CUSOLVER(
+    CHECK_CUSOLVER(cusolverdxGetUniversalFATBINSize(h, &fatbin_size));
 
     std::vector<char> fatbin(fatbin_size);
-    CHECK_CUSOLVER(
+    CHECK_CUSOLVER(cusolverdxGetUniversalFATBIN(h, fatbin.size(), fatbin.data()));
 
     if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
         res = false;
@@ -3289,7 +3825,7 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
         res = false;
     }
 
-    CHECK_CUSOLVER(
+    CHECK_CUSOLVER(cusolverdxDestroyDescriptor(h));
 
    return res;
 }