warp-lang 0.11.0__py3-none-manylinux2014_x86_64.whl → 1.0.0__py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic.
- warp/__init__.py +8 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +7 -6
- warp/build_dll.py +70 -79
- warp/builtins.py +10 -6
- warp/codegen.py +51 -19
- warp/config.py +7 -8
- warp/constants.py +3 -0
- warp/context.py +948 -245
- warp/dlpack.py +198 -113
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usda +42 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usda +56 -0
- warp/examples/assets/torus.usda +105 -0
- warp/examples/benchmarks/benchmark_api.py +383 -0
- warp/examples/benchmarks/benchmark_cloth.py +279 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
- warp/examples/benchmarks/benchmark_launches.py +295 -0
- warp/examples/core/example_dem.py +221 -0
- warp/examples/core/example_fluid.py +267 -0
- warp/examples/core/example_graph_capture.py +129 -0
- warp/examples/core/example_marching_cubes.py +177 -0
- warp/examples/core/example_mesh.py +154 -0
- warp/examples/core/example_mesh_intersect.py +193 -0
- warp/examples/core/example_nvdb.py +169 -0
- warp/examples/core/example_raycast.py +89 -0
- warp/examples/core/example_raymarch.py +178 -0
- warp/examples/core/example_render_opengl.py +141 -0
- warp/examples/core/example_sph.py +389 -0
- warp/examples/core/example_torch.py +181 -0
- warp/examples/core/example_wave.py +249 -0
- warp/examples/fem/bsr_utils.py +380 -0
- warp/examples/fem/example_apic_fluid.py +391 -0
- warp/examples/fem/example_convection_diffusion.py +168 -0
- warp/examples/fem/example_convection_diffusion_dg.py +209 -0
- warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
- warp/examples/fem/example_deformed_geometry.py +159 -0
- warp/examples/fem/example_diffusion.py +173 -0
- warp/examples/fem/example_diffusion_3d.py +152 -0
- warp/examples/fem/example_diffusion_mgpu.py +214 -0
- warp/examples/fem/example_mixed_elasticity.py +222 -0
- warp/examples/fem/example_navier_stokes.py +243 -0
- warp/examples/fem/example_stokes.py +192 -0
- warp/examples/fem/example_stokes_transfer.py +249 -0
- warp/examples/fem/mesh_utils.py +109 -0
- warp/examples/fem/plot_utils.py +287 -0
- warp/examples/optim/example_bounce.py +248 -0
- warp/examples/optim/example_cloth_throw.py +210 -0
- warp/examples/optim/example_diffray.py +535 -0
- warp/examples/optim/example_drone.py +850 -0
- warp/examples/optim/example_inverse_kinematics.py +169 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
- warp/examples/optim/example_spring_cage.py +234 -0
- warp/examples/optim/example_trajectory.py +201 -0
- warp/examples/sim/example_cartpole.py +128 -0
- warp/examples/sim/example_cloth.py +184 -0
- warp/examples/sim/example_granular.py +113 -0
- warp/examples/sim/example_granular_collision_sdf.py +185 -0
- warp/examples/sim/example_jacobian_ik.py +213 -0
- warp/examples/sim/example_particle_chain.py +106 -0
- warp/examples/sim/example_quadruped.py +179 -0
- warp/examples/sim/example_rigid_chain.py +191 -0
- warp/examples/sim/example_rigid_contact.py +176 -0
- warp/examples/sim/example_rigid_force.py +126 -0
- warp/examples/sim/example_rigid_gyroscopic.py +97 -0
- warp/examples/sim/example_rigid_soft_contact.py +124 -0
- warp/examples/sim/example_soft_body.py +178 -0
- warp/fabric.py +29 -20
- warp/fem/cache.py +0 -1
- warp/fem/dirichlet.py +0 -2
- warp/fem/integrate.py +0 -1
- warp/jax.py +45 -0
- warp/jax_experimental.py +339 -0
- warp/native/builtin.h +12 -0
- warp/native/bvh.cu +18 -18
- warp/native/clang/clang.cpp +8 -3
- warp/native/cuda_util.cpp +94 -5
- warp/native/cuda_util.h +35 -6
- warp/native/cutlass_gemm.cpp +1 -1
- warp/native/cutlass_gemm.cu +4 -1
- warp/native/error.cpp +66 -0
- warp/native/error.h +27 -0
- warp/native/mesh.cu +2 -2
- warp/native/reduce.cu +4 -4
- warp/native/runlength_encode.cu +2 -2
- warp/native/scan.cu +2 -2
- warp/native/sparse.cu +0 -1
- warp/native/temp_buffer.h +2 -2
- warp/native/warp.cpp +95 -60
- warp/native/warp.cu +1053 -218
- warp/native/warp.h +49 -32
- warp/optim/linear.py +33 -16
- warp/render/render_opengl.py +202 -101
- warp/render/render_usd.py +82 -40
- warp/sim/__init__.py +13 -4
- warp/sim/articulation.py +4 -5
- warp/sim/collide.py +320 -175
- warp/sim/import_mjcf.py +25 -30
- warp/sim/import_urdf.py +94 -63
- warp/sim/import_usd.py +51 -36
- warp/sim/inertia.py +3 -2
- warp/sim/integrator.py +233 -0
- warp/sim/integrator_euler.py +447 -469
- warp/sim/integrator_featherstone.py +1991 -0
- warp/sim/integrator_xpbd.py +1420 -640
- warp/sim/model.py +765 -487
- warp/sim/particles.py +2 -1
- warp/sim/render.py +35 -13
- warp/sim/utils.py +222 -11
- warp/stubs.py +8 -0
- warp/tape.py +16 -1
- warp/tests/aux_test_grad_customs.py +23 -0
- warp/tests/test_array.py +190 -1
- warp/tests/test_async.py +656 -0
- warp/tests/test_bool.py +50 -0
- warp/tests/test_dlpack.py +164 -11
- warp/tests/test_examples.py +166 -74
- warp/tests/test_fem.py +8 -1
- warp/tests/test_generics.py +15 -5
- warp/tests/test_grad.py +1 -1
- warp/tests/test_grad_customs.py +172 -12
- warp/tests/test_jax.py +254 -0
- warp/tests/test_large.py +29 -6
- warp/tests/test_launch.py +25 -0
- warp/tests/test_linear_solvers.py +20 -3
- warp/tests/test_matmul.py +61 -16
- warp/tests/test_matmul_lite.py +13 -13
- warp/tests/test_mempool.py +186 -0
- warp/tests/test_multigpu.py +3 -0
- warp/tests/test_options.py +16 -2
- warp/tests/test_peer.py +137 -0
- warp/tests/test_print.py +3 -1
- warp/tests/test_quat.py +23 -0
- warp/tests/test_sim_kinematics.py +97 -0
- warp/tests/test_snippet.py +126 -3
- warp/tests/test_streams.py +108 -79
- warp/tests/test_torch.py +16 -8
- warp/tests/test_utils.py +32 -27
- warp/tests/test_verify_fp.py +65 -0
- warp/tests/test_volume.py +1 -1
- warp/tests/unittest_serial.py +2 -0
- warp/tests/unittest_suites.py +12 -0
- warp/tests/unittest_utils.py +14 -7
- warp/thirdparty/unittest_parallel.py +15 -3
- warp/torch.py +10 -8
- warp/types.py +363 -246
- warp/utils.py +143 -19
- warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
- warp_lang-1.0.0.dist-info/METADATA +394 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
- warp/sim/optimizer.py +0 -138
- warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
- warp_lang-0.11.0.dist-info/METADATA +0 -238
- /warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
warp/native/warp.cu
CHANGED
@@ -9,11 +9,17 @@
 #include "warp.h"
 #include "scan.h"
 #include "cuda_util.h"
+#include "error.h"

 #include <nvrtc.h>
 #include <nvPTXCompiler.h>

+#include <algorithm>
+#include <iterator>
+#include <list>
 #include <map>
+#include <unordered_map>
+#include <unordered_set>
 #include <vector>

 #define check_nvrtc(code) (check_nvrtc_result(code, __FILE__, __LINE__))
@@ -81,14 +87,55 @@ struct DeviceInfo
     char name[kNameLen] = "";
     int arch = 0;
     int is_uva = 0;
-    int
+    int is_mempool_supported = 0;
+    CUcontext primary_context = NULL;
 };

 struct ContextInfo
 {
     DeviceInfo* device_info = NULL;

-
+    // the current stream, managed from Python (see cuda_context_set_stream() and cuda_context_get_stream())
+    CUstream stream = NULL;
+};
+
+struct CaptureInfo
+{
+    CUstream stream = NULL; // the main stream where capture begins and ends
+    uint64_t id = 0; // unique capture id from CUDA
+    bool external = false; // whether this is an external capture
+};
+
+struct StreamInfo
+{
+    CUevent cached_event = NULL; // event used for stream synchronization (cached to avoid creating temporary events)
+    CaptureInfo* capture = NULL; // capture info (only if started on this stream)
+};
+
+struct GraphInfo
+{
+    std::vector<void*> unfreed_allocs;
+};
+
+// Information for graph allocations that are not freed by the graph.
+// These allocations have a shared ownership:
+// - The graph instance allocates/maps the memory on each launch, even if the user reference is released.
+// - The user reference must remain valid even if the graph is destroyed.
+// The memory will be freed once the user reference is released and the graph is destroyed.
+struct GraphAllocInfo
+{
+    uint64_t capture_id = 0;
+    void* context = NULL;
+    bool ref_exists = false; // whether user reference still exists
+    bool graph_destroyed = false; // whether graph instance was destroyed
+};
+
+// Information used when deferring deallocations.
+struct FreeInfo
+{
+    void* context = NULL;
+    void* ptr = NULL;
+    bool is_async = false;
 };

 // cached info for all devices, indexed by ordinal
@@ -100,6 +147,22 @@ static std::map<CUdevice, DeviceInfo*> g_device_map;
 // cached info for all known contexts
 static std::map<CUcontext, ContextInfo> g_contexts;

+// cached info for all known streams (including registered external streams)
+static std::unordered_map<CUstream, StreamInfo> g_streams;
+
+// Ongoing graph captures registered using wp.capture_begin().
+// This maps the capture id to the stream where capture was started.
+// See cuda_graph_begin_capture(), cuda_graph_end_capture(), and free_device_async().
+static std::unordered_map<uint64_t, CaptureInfo*> g_captures;
+
+// Memory allocated during graph capture requires special handling.
+// See alloc_device_async() and free_device_async().
+static std::unordered_map<void*, GraphAllocInfo> g_graph_allocs;
+
+// Memory that cannot be freed immediately gets queued here.
+// Call free_deferred_allocs() to release.
+static std::vector<FreeInfo> g_deferred_free_list;
+

 void cuda_set_context_restore_policy(bool always_restore)
 {
@@ -116,12 +179,12 @@ int cuda_init()
     if (!init_cuda_driver())
         return -1;

-    int
-    if (check_cu(cuDeviceGetCount_f(&
+    int device_count = 0;
+    if (check_cu(cuDeviceGetCount_f(&device_count)))
     {
-        g_devices.resize(
+        g_devices.resize(device_count);

-        for (int i = 0; i <
+        for (int i = 0; i < device_count; i++)
         {
             CUdevice device;
             if (check_cu(cuDeviceGet_f(&device, i)))
@@ -135,7 +198,7 @@ int cuda_init()
                 check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_bus_id, CU_DEVICE_ATTRIBUTE_PCI_BUS_ID, device));
                 check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_device_id, CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, device));
                 check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_uva, CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, device));
-                check_cu(cuDeviceGetAttribute_f(&g_devices[i].
+                check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_mempool_supported, CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, device));
                 int major = 0;
                 int minor = 0;
                 check_cu(cuDeviceGetAttribute_f(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device));
@@ -168,9 +231,9 @@ static inline CUcontext get_current_context()
         return NULL;
 }

-static inline CUstream get_current_stream()
+static inline CUstream get_current_stream(void* context=NULL)
 {
-    return static_cast<CUstream>(cuda_context_get_stream(
+    return static_cast<CUstream>(cuda_context_get_stream(context));
 }

 static ContextInfo* get_context_info(CUcontext ctx)
@@ -191,11 +254,22 @@ static ContextInfo* get_context_info(CUcontext ctx)
     {
         // previously unseen context, add the info
        ContextGuard guard(ctx, true);
-
+
        CUdevice device;
        if (check_cu(cuCtxGetDevice_f(&device)))
        {
-
+            DeviceInfo* device_info = g_device_map[device];
+
+            // workaround for https://nvbugspro.nvidia.com/bug/4456003
+            if (device_info->is_mempool_supported)
+            {
+                void* dummy = NULL;
+                check_cuda(cudaMallocAsync(&dummy, 1, NULL));
+                check_cuda(cudaFreeAsync(dummy, NULL));
+            }
+
+            ContextInfo context_info;
+            context_info.device_info = device_info;
            auto result = g_contexts.insert(std::make_pair(ctx, context_info));
            return &result.first->second;
        }
@@ -204,10 +278,116 @@ static ContextInfo* get_context_info(CUcontext ctx)
     return NULL;
 }

+static inline ContextInfo* get_context_info(void* context)
+{
+    return get_context_info(static_cast<CUcontext>(context));
+}
+
+static inline StreamInfo* get_stream_info(CUstream stream)
+{
+    auto it = g_streams.find(stream);
+    if (it != g_streams.end())
+        return &it->second;
+    else
+        return NULL;
+}
+
+static void deferred_free(void* ptr, void* context, bool is_async)
+{
+    FreeInfo free_info;
+    free_info.ptr = ptr;
+    free_info.context = context ? context : get_current_context();
+    free_info.is_async = is_async;
+    g_deferred_free_list.push_back(free_info);
+}
+
+static int free_deferred_allocs(void* context = NULL)
+{
+    if (g_deferred_free_list.empty() || !g_captures.empty())
+        return 0;
+
+    int num_freed_allocs = 0;
+    for (auto it = g_deferred_free_list.begin(); it != g_deferred_free_list.end(); /*noop*/)
+    {
+        const FreeInfo& free_info = *it;
+
+        // free the pointer if it matches the given context or if the context is unspecified
+        if (free_info.context == context || !context)
+        {
+            ContextGuard guard(free_info.context);
+
+            if (free_info.is_async)
+            {
+                // this could be a regular stream-ordered allocation or a graph allocation
+                cudaError_t res = cudaFreeAsync(free_info.ptr, NULL);
+                if (res != cudaSuccess)
+                {
+                    if (res == cudaErrorInvalidValue)
+                    {
+                        // This can happen if we try to release the pointer but the graph was
+                        // never launched, so the memory isn't mapped.
+                        // This is fine, so clear the error.
+                        cudaGetLastError();
+                    }
+                    else
+                    {
+                        // something else went wrong, report error
+                        check_cuda(res);
+                    }
+                }
+            }
+            else
+            {
+                check_cuda(cudaFree(free_info.ptr));
+            }
+
+            ++num_freed_allocs;
+
+            it = g_deferred_free_list.erase(it);
+        }
+        else
+        {
+            ++it;
+        }
+    }
+
+    return num_freed_allocs;
+}
+
+static void CUDART_CB on_graph_destroy(void* user_data)
+{
+    if (!user_data)
+        return;
+
+    GraphInfo* graph_info = static_cast<GraphInfo*>(user_data);
+
+    for (void* ptr : graph_info->unfreed_allocs)
+    {
+        auto alloc_iter = g_graph_allocs.find(ptr);
+        if (alloc_iter != g_graph_allocs.end())
+        {
+            GraphAllocInfo& alloc_info = alloc_iter->second;
+            if (alloc_info.ref_exists)
+            {
+                // unreference from graph so the pointer will be deallocated when the user reference goes away
+                alloc_info.graph_destroyed = true;
+            }
+            else
+            {
+                // the pointer can be freed, but we can't call CUDA functions in this callback, so defer it
+                deferred_free(ptr, alloc_info.context, true);
+                g_graph_allocs.erase(alloc_iter);
+            }
+        }
+    }
+
+    delete graph_info;
+}
+

 void* alloc_pinned(size_t s)
 {
-    void* ptr;
+    void* ptr = NULL;
     check_cuda(cudaMallocHost(&ptr, s));
     return ptr;
 }
@@ -218,84 +398,320 @@ void free_pinned(void* ptr)
 }

 void* alloc_device(void* context, size_t s)
+{
+    int ordinal = cuda_context_get_device_ordinal(context);
+
+    // use stream-ordered allocator if available
+    if (cuda_device_is_mempool_supported(ordinal))
+        return alloc_device_async(context, s);
+    else
+        return alloc_device_default(context, s);
+}
+
+void free_device(void* context, void* ptr)
+{
+    int ordinal = cuda_context_get_device_ordinal(context);
+
+    // use stream-ordered allocator if available
+    if (cuda_device_is_mempool_supported(ordinal))
+        free_device_async(context, ptr);
+    else
+        free_device_default(context, ptr);
+}
+
+void* alloc_device_default(void* context, size_t s)
 {
     ContextGuard guard(context);

-    void* ptr;
+    void* ptr = NULL;
     check_cuda(cudaMalloc(&ptr, s));
+
     return ptr;
 }

-void
+void free_device_default(void* context, void* ptr)
 {
-    // "cudaMallocAsync ignores the current device/context when determining where the allocation will reside. Instead,
-    // cudaMallocAsync determines the resident device based on the specified memory pool or the supplied stream."
     ContextGuard guard(context);

-
-
-    if (cuda_context_is_memory_pool_supported(context))
+    // check if a capture is in progress
+    if (g_captures.empty())
     {
-        check_cuda(
+        check_cuda(cudaFree(ptr));
     }
     else
     {
-
+        // we must defer the operation until graph captures complete
+        deferred_free(ptr, context, false);
     }
-
-    return ptr;
 }

-void
+void* alloc_device_async(void* context, size_t s)
 {
+    // stream-ordered allocations don't rely on the current context,
+    // but we set the context here for consistent behaviour
     ContextGuard guard(context);

-
+    ContextInfo* context_info = get_context_info(context);
+    if (!context_info)
+        return NULL;
+
+    CUstream stream = context_info->stream;
+
+    void* ptr = NULL;
+    check_cuda(cudaMallocAsync(&ptr, s, stream));
+
+    if (ptr)
+    {
+        // if the stream is capturing, the allocation requires special handling
+        if (cuda_stream_is_capturing(stream))
+        {
+            // check if this is a known capture
+            uint64_t capture_id = get_capture_id(stream);
+            auto capture_iter = g_captures.find(capture_id);
+            if (capture_iter != g_captures.end())
+            {
+                // remember graph allocation details
+                GraphAllocInfo alloc_info;
+                alloc_info.capture_id = capture_id;
+                alloc_info.context = context ? context : get_current_context();
+                alloc_info.ref_exists = true; // user reference created and returned here
+                alloc_info.graph_destroyed = false; // graph not destroyed yet
+                g_graph_allocs[ptr] = alloc_info;
+            }
+        }
+    }
+
+    return ptr;
 }

-void
+void free_device_async(void* context, void* ptr)
 {
+    // stream-ordered allocators generally don't rely on the current context,
+    // but we set the context here for consistent behaviour
     ContextGuard guard(context);

-
+    // NB: Stream-ordered deallocations are tricky, because the memory could still be used on another stream
+    // or even multiple streams. To avoid use-after-free errors, we need to ensure that all preceding work
+    // completes before releasing the memory. The strategy is different for regular stream-ordered allocations
+    // and allocations made during graph capture. See below for details.
+
+    // check if this allocation was made during graph capture
+    auto alloc_iter = g_graph_allocs.find(ptr);
+    if (alloc_iter == g_graph_allocs.end())
     {
-
+        // Not a graph allocation.
+        // Check if graph capture is ongoing.
+        if (g_captures.empty())
+        {
+            // cudaFreeAsync on the null stream does not block or trigger synchronization, but it postpones
+            // the deallocation until a synchronization point is reached, so preceding work on this pointer
+            // should safely complete.
+            check_cuda(cudaFreeAsync(ptr, NULL));
+        }
+        else
+        {
+            // We must defer the free operation until graph capture completes.
+            deferred_free(ptr, context, true);
+        }
     }
     else
     {
-
+        // get the graph allocation details
+        GraphAllocInfo& alloc_info = alloc_iter->second;
+
+        uint64_t capture_id = alloc_info.capture_id;
+
+        // check if the capture is still active
+        auto capture_iter = g_captures.find(capture_id);
+        if (capture_iter != g_captures.end())
+        {
+            // Add a mem free node. Use all current leaf nodes as dependencies to ensure that all prior
+            // work completes before deallocating. This works with both Warp-initiated and external captures
+            // and avoids the need to explicitly track all streams used during the capture.
+            CaptureInfo* capture = capture_iter->second;
+            cudaGraph_t graph = get_capture_graph(capture->stream);
+            std::vector<cudaGraphNode_t> leaf_nodes;
+            if (graph && get_graph_leaf_nodes(graph, leaf_nodes))
+            {
+                cudaGraphNode_t free_node;
+                check_cuda(cudaGraphAddMemFreeNode(&free_node, graph, leaf_nodes.data(), leaf_nodes.size(), ptr));
+            }
+
+            // we're done with this allocation, it's owned by the graph
+            g_graph_allocs.erase(alloc_iter);
+        }
+        else
+        {
+            // the capture has ended
+            // if the owning graph was already destroyed, we can free the pointer now
+            if (alloc_info.graph_destroyed)
+            {
+                if (g_captures.empty())
+                {
+                    // try to free the pointer now
+                    cudaError_t res = cudaFreeAsync(ptr, NULL);
+                    if (res == cudaErrorInvalidValue)
+                    {
+                        // This can happen if we try to release the pointer but the graph was
+                        // never launched, so the memory isn't mapped.
+                        // This is fine, so clear the error.
+                        cudaGetLastError();
+                    }
+                    else
+                    {
+                        // check for other errors
+                        check_cuda(res);
+                    }
+                }
+                else
+                {
+                    // We must defer the operation until graph capture completes.
+                    deferred_free(ptr, context, true);
+                }
+
+                // we're done with this allocation
+                g_graph_allocs.erase(alloc_iter);
+            }
+            else
+            {
+                // graph still exists
+                // unreference the pointer so it will be deallocated once the graph instance is destroyed
+                alloc_info.ref_exists = false;
+            }
+        }
     }
 }

-
+bool memcpy_h2d(void* context, void* dest, void* src, size_t n, void* stream)
 {
     ContextGuard guard(context);
-
-
+
+    CUstream cuda_stream;
+    if (stream != WP_CURRENT_STREAM)
+        cuda_stream = static_cast<CUstream>(stream);
+    else
+        cuda_stream = get_current_stream(context);
+
+    return check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyHostToDevice, cuda_stream));
 }

-
+bool memcpy_d2h(void* context, void* dest, void* src, size_t n, void* stream)
 {
     ContextGuard guard(context);

-
+    CUstream cuda_stream;
+    if (stream != WP_CURRENT_STREAM)
+        cuda_stream = static_cast<CUstream>(stream);
+    else
+        cuda_stream = get_current_stream(context);
+
+    return check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyDeviceToHost, cuda_stream));
 }

-
+bool memcpy_d2d(void* context, void* dest, void* src, size_t n, void* stream)
 {
     ContextGuard guard(context);

-
+    CUstream cuda_stream;
+    if (stream != WP_CURRENT_STREAM)
+        cuda_stream = static_cast<CUstream>(stream);
+    else
+        cuda_stream = get_current_stream(context);
+
+    return check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyDeviceToDevice, cuda_stream));
 }

-
+bool memcpy_p2p(void* dst_context, void* dst, void* src_context, void* src, size_t n, void* stream)
 {
-    ContextGuard guard(context);
+    // ContextGuard guard(context);
+
+    CUstream cuda_stream;
+    if (stream != WP_CURRENT_STREAM)
+        cuda_stream = static_cast<CUstream>(stream);
+    else
+        cuda_stream = get_current_stream(dst_context);
+
+    // Notes:
+    // - cuMemcpyPeerAsync() works fine with both regular and pooled allocations (cudaMalloc() and cudaMallocAsync(), respectively)
+    // when not capturing a graph.
+    // - cuMemcpyPeerAsync() is not supported during graph capture, so we must use cudaMemcpyAsync() with kind=cudaMemcpyDefault.
+    // - cudaMemcpyAsync() works fine with regular allocations, but doesn't work with pooled allocations
+    // unless mempool access has been enabled.
+    // - There is no reliable way to check if mempool access is enabled during graph capture,
+    // because cudaMemPoolGetAccess() cannot be called during graph capture.
+    // - CUDA will report error 1 (invalid argument) if cudaMemcpyAsync() is called but mempool access is not enabled.
+
+    if (!cuda_stream_is_capturing(stream))
+    {
+        return check_cu(cuMemcpyPeerAsync_f(
+            (CUdeviceptr)dst, (CUcontext)dst_context,
+            (CUdeviceptr)src, (CUcontext)src_context,
+            n, cuda_stream));
+    }
+    else
+    {
+        cudaError_t result = cudaSuccess;
+
+        // cudaMemcpyAsync() is sensitive to the bound context to resolve pointer locations.
+        // If fails with cudaErrorInvalidValue if it cannot resolve an argument.
+        // We first try the copy in the destination context, then if it fails we retry in the source context.
+        // The cudaErrorInvalidValue error doesn't cause graph capture to fail, so it's ok to retry.
+        // Since this trial-and-error shenanigans only happens during capture, there
+        // is no perf impact when the graph is launched.
+        // For bonus points, this approach simplifies memory pool access requirements.
+        // Access only needs to be enabled one way, either from the source device to the destination device
+        // or vice versa. Sometimes, when it's really quiet, you can actually hear my genius.
+        {
+            // try doing the copy in the destination context
+            ContextGuard guard(dst_context);
+            result = cudaMemcpyAsync(dst, src, n, cudaMemcpyDefault, cuda_stream);

-
-
+            if (result != cudaSuccess)
+            {
+                // clear error in destination context
+                cudaGetLastError();
+
+                // try doing the copy in the source context
+                ContextGuard guard(src_context);
+                result = cudaMemcpyAsync(dst, src, n, cudaMemcpyDefault, cuda_stream);
+
+                // clear error in source context
+                cudaGetLastError();
+            }
+        }
+
+        // If the copy failed, try to detect if mempool allocations are involved to generate a helpful error message.
+        if (!check_cuda(result))
+        {
+            if (result == cudaErrorInvalidValue && src != NULL && dst != NULL)
+            {
+                // check if either of the pointers was allocated from a mempool
+                void* src_mempool = NULL;
+                void* dst_mempool = NULL;
+                cuPointerGetAttribute_f(&src_mempool, CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE, (CUdeviceptr)src);
+                cuPointerGetAttribute_f(&dst_mempool, CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE, (CUdeviceptr)dst);
+                cudaGetLastError(); // clear any errors
+                // check if either of the pointers was allocated during graph capture
+                auto src_alloc = g_graph_allocs.find(src);
+                auto dst_alloc = g_graph_allocs.find(dst);
+                if (src_mempool != NULL || src_alloc != g_graph_allocs.end() ||
+                    dst_mempool != NULL || dst_alloc != g_graph_allocs.end())
+                {
+                    wp::append_error_string("*** CUDA mempool allocations were used in a peer-to-peer copy during graph capture.");
+                    wp::append_error_string("*** This operation fails if mempool access is not enabled between the peer devices.");
+                    wp::append_error_string("*** Either enable mempool access between the devices or use the default CUDA allocator");
+                    wp::append_error_string("*** to pre-allocate the arrays before graph capture begins.");
+                }
+            }
+
+            return false;
+        }
+
+        return true;
+    }
 }

+
 __global__ void memset_kernel(int* dest, int value, size_t n)
 {
     const size_t tid = wp::grid_index();
@@ -378,14 +794,15 @@ void memtile_device(void* context, void* dst, const void* src, size_t srcsize, s
     {
         // generic version

+        // copy value to device memory
         // TODO: use a persistent stream-local staging buffer to avoid allocs?
-        void*
-        check_cuda(
-
+        void* src_devptr = alloc_device(WP_CURRENT_CONTEXT, srcsize);
+        check_cuda(cudaMemcpyAsync(src_devptr, src, srcsize, cudaMemcpyHostToDevice, get_current_stream()));
+
+        wp_launch_device(WP_CURRENT_CONTEXT, memtile_kernel, n, (dst, src_devptr, srcsize, n));

-
+        free_device(WP_CURRENT_CONTEXT, src_devptr);

-        check_cuda(cudaFree(src_device));
     }
 }

@@ -611,15 +1028,13 @@ static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::in
 }


-WP_API
+WP_API bool array_copy_device(void* context, void* dst, void* src, int dst_type, int src_type, int elem_size)
 {
     if (!src || !dst)
-        return
+        return false;

     const void* src_data = NULL;
-    const void* src_grad = NULL;
     void* dst_data = NULL;
-    void* dst_grad = NULL;
     int src_ndim = 0;
     int dst_ndim = 0;
     const int* src_shape = NULL;
@@ -641,7 +1056,6 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
     {
         const wp::array_t<void>& src_arr = *static_cast<const wp::array_t<void>*>(src);
         src_data = src_arr.data;
-        src_grad = src_arr.grad;
         src_ndim = src_arr.ndim;
         src_shape = src_arr.shape.dims;
         src_strides = src_arr.strides;
@@ -669,14 +1083,13 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
     else
     {
         fprintf(stderr, "Warp copy error: Invalid array type (%d)\n", src_type);
-        return
+        return false;
     }

     if (dst_type == wp::ARRAY_TYPE_REGULAR)
     {
         const wp::array_t<void>& dst_arr = *static_cast<const wp::array_t<void>*>(dst);
         dst_data = dst_arr.data;
-        dst_grad = dst_arr.grad;
         dst_ndim = dst_arr.ndim;
         dst_shape = dst_arr.shape.dims;
         dst_strides = dst_arr.strides;
@@ -704,13 +1117,13 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
     else
     {
         fprintf(stderr, "Warp copy error: Invalid array type (%d)\n", dst_type);
-        return
+        return false;
     }

     if (src_ndim != dst_ndim)
     {
         fprintf(stderr, "Warp copy error: Incompatible array dimensionalities (%d and %d)\n", src_ndim, dst_ndim);
-        return
+        return false;
     }

     ContextGuard guard(context);
@@ -725,11 +1138,11 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
             if (src_fabricarray->size != n)
             {
                 fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
-                return
+                return false;
             }
             wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_to_fabric_kernel, n,
                              (*dst_fabricarray, *src_fabricarray, elem_size));
-            return
+            return true;
         }
         else if (src_indexedfabricarray)
         {
@@ -737,11 +1150,11 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
             if (src_indexedfabricarray->size != n)
             {
                 fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
-                return
+                return false;
             }
             wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_indexed_to_fabric_kernel, n,
                              (*dst_fabricarray, *src_indexedfabricarray, elem_size));
-            return
+            return true;
         }
         else
         {
@@ -749,11 +1162,11 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
             if (size_t(src_shape[0]) != n)
             {
                 fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
-                return
+                return false;
             }
             wp_launch_device(WP_CURRENT_CONTEXT, array_copy_to_fabric_kernel, n,
                              (*dst_fabricarray, src_data, src_strides[0], src_indices[0], elem_size));
-            return
+            return true;
         }
     }
     if (dst_indexedfabricarray)
@@ -765,11 +1178,11 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
             if (src_fabricarray->size != n)
             {
                 fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
-                return
+                return false;
             }
             wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_to_fabric_indexed_kernel, n,
                              (*dst_indexedfabricarray, *src_fabricarray, elem_size));
-            return
+            return true;
         }
         else if (src_indexedfabricarray)
         {
@@ -777,11 +1190,11 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
             if (src_indexedfabricarray->size != n)
             {
                 fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
-                return
+                return false;
             }
             wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_indexed_to_fabric_indexed_kernel, n,
                              (*dst_indexedfabricarray, *src_indexedfabricarray, elem_size));
-            return
+            return true;
         }
         else
        {
@@ -789,11 +1202,11 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
             if (size_t(src_shape[0]) != n)
             {
                 fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
-                return
+                return false;
             }
             wp_launch_device(WP_CURRENT_CONTEXT, array_copy_to_fabric_indexed_kernel, n,
                              (*dst_indexedfabricarray, src_data, src_strides[0], src_indices[0], elem_size));
-            return
+            return true;
         }
     }
     else if (src_fabricarray)
@@ -803,11 +1216,11 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
         if (size_t(dst_shape[0]) != n)
         {
             fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
-            return
+            return false;
         }
         wp_launch_device(WP_CURRENT_CONTEXT, array_copy_from_fabric_kernel, n,
                          (*src_fabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size));
-        return
+        return true;
     }
     else if (src_indexedfabricarray)
     {
@@ -816,11 +1229,11 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
         if (size_t(dst_shape[0]) != n)
         {
             fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
-            return
+            return false;
         }
         wp_launch_device(WP_CURRENT_CONTEXT, array_copy_from_fabric_indexed_kernel, n,
                          (*src_indexedfabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size));
-        return
+        return true;
     }

     size_t n = 1;
@@ -829,7 +1242,7 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
         if (src_shape[i] != dst_shape[i])
         {
             fprintf(stderr, "Warp copy error: Incompatible array shapes\n");
-            return
+            return false;
         }
         n *= src_shape[i];
     }
@@ -888,13 +1301,10 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
         }
     default:
         fprintf(stderr, "Warp copy error: invalid array dimensionality (%d)\n", src_ndim);
-        return
+        return false;
     }

-
-    return n;
-    else
-        return 0;
+    return check_cuda(cudaGetLastError());
 }


@@ -1065,8 +1475,8 @@ WP_API void array_fill_device(void* context, void* arr_ptr, int arr_type, const
     ContextGuard guard(context);

     // copy value to device memory
-
-
+    // TODO: use a persistent stream-local staging buffer to avoid allocs?
+    void* value_devptr = alloc_device(WP_CURRENT_CONTEXT, value_size);
     check_cuda(cudaMemcpyAsync(value_devptr, value_ptr, value_size, cudaMemcpyHostToDevice, get_current_stream()));

     // handle fabric arrays
@@ -1123,6 +1533,8 @@ WP_API void array_fill_device(void* context, void* arr_ptr, int arr_type, const
         fprintf(stderr, "Warp fill error: invalid array dimensionality (%d)\n", ndim);
         return;
     }
+
+    free_device(WP_CURRENT_CONTEXT, value_devptr);
 }

 void array_scan_int_device(uint64_t in, uint64_t out, int len, bool inclusive)
@@ -1178,20 +1590,20 @@ int cuda_device_get_count()
     return count;
 }

-void*
+void* cuda_device_get_primary_context(int ordinal)
 {
-
-
-
-    check_cu(cuDevicePrimaryCtxRetain_f(&context, device));
-    return context;
-}
+    if (ordinal >= 0 && ordinal < int(g_devices.size()))
+    {
+        DeviceInfo& device_info = g_devices[ordinal];

-
-
-
-
-
+        // acquire the primary context if we haven't already
+        if (!device_info.primary_context)
+            check_cu(cuDevicePrimaryCtxRetain_f(&device_info.primary_context, device_info.device));
+
+        return device_info.primary_context;
+    }
+
+    return NULL;
 }

 const char* cuda_device_get_name(int ordinal)
@@ -1241,13 +1653,105 @@ int cuda_device_is_uva(int ordinal)
         return 0;
 }

-int
+int cuda_device_is_mempool_supported(int ordinal)
 {
     if (ordinal >= 0 && ordinal < int(g_devices.size()))
-        return g_devices[ordinal].
-    return
+        return g_devices[ordinal].is_mempool_supported;
+    return 0;
+}
+
+int cuda_device_set_mempool_release_threshold(int ordinal, uint64_t threshold)
+{
+    if (ordinal < 0 || ordinal > int(g_devices.size()))
+    {
+        fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
+        return 0;
+    }
+
+    if (!g_devices[ordinal].is_mempool_supported)
+        return 0;
+
+    cudaMemPool_t pool;
+    if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
+    {
+        fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
+        return 0;
+    }
+
+    if (!check_cuda(cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold)))
+    {
+        fprintf(stderr, "Warp error: Failed to set memory pool attribute on device %d\n", ordinal);
+        return 0;
+    }
+
+    return 1; // success
+}
+
+uint64_t cuda_device_get_mempool_release_threshold(int ordinal)
+{
+    if (ordinal < 0 || ordinal > int(g_devices.size()))
+    {
+        fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
+        return 0;
+    }
+
+    if (!g_devices[ordinal].is_mempool_supported)
+        return 0;
+
+    cudaMemPool_t pool;
+    if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
+    {
+        fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
+        return 0;
+    }
+
+    uint64_t threshold = 0;
+    if (!check_cuda(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold)))
+    {
+        fprintf(stderr, "Warp error: Failed to get memory pool release threshold on device %d\n", ordinal);
+        return 0;
+    }
+
+    return threshold;
+}
+
+void cuda_device_get_memory_info(int ordinal, size_t* free_mem, size_t* total_mem)
+{
+    // use temporary storage if user didn't specify pointers
+    size_t tmp_free_mem, tmp_total_mem;
+
+    if (free_mem)
+        *free_mem = 0;
+    else
+        free_mem = &tmp_free_mem;
+
+    if (total_mem)
+        *total_mem = 0;
+    else
+        total_mem = &tmp_total_mem;
+
+    if (ordinal >= 0 && ordinal < int(g_devices.size()))
+    {
+        if (g_devices[ordinal].primary_context)
+        {
+            ContextGuard guard(g_devices[ordinal].primary_context, true);
+            check_cu(cuMemGetInfo_f(free_mem, total_mem));
+        }
+        else
+        {
+            // if we haven't acquired the primary context yet, acquire it temporarily
+            CUcontext primary_context = NULL;
+            check_cu(cuDevicePrimaryCtxRetain_f(&primary_context, g_devices[ordinal].device));
+            {
+                ContextGuard guard(primary_context, true);
+                check_cu(cuMemGetInfo_f(free_mem, total_mem));
+            }
+            check_cu(cuDevicePrimaryCtxRelease_f(g_devices[ordinal].device));
+        }
+    }
 }

+
 void* cuda_context_get_current()
 {
     return get_current_context();
@@ -1313,26 +1817,35 @@ void cuda_context_synchronize(void* context)
     ContextGuard guard(context);

     check_cu(cuCtxSynchronize_f());
+
+    if (free_deferred_allocs(context ? context : get_current_context()) > 0)
+    {
+        // ensure deferred asynchronous deallocations complete
+        check_cu(cuCtxSynchronize_f());
+    }
+
+    // check_cuda(cudaDeviceGraphMemTrim(cuda_context_get_device_ordinal(context)));
 }

 uint64_t cuda_context_check(void* context)
 {
     ContextGuard guard(context);

-
-
+    // check errors before syncing
+    cudaError_t e = cudaGetLastError();
+    check_cuda(e);
+
+    cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
+    check_cuda(cudaStreamIsCapturing(get_current_stream(), &status));

-    //
-    // since we cannot synchronize the device
+    // synchronize if the stream is not capturing
     if (status == cudaStreamCaptureStatusNone)
     {
-        cudaDeviceSynchronize();
-
-    }
-    else
-    {
-        return 0;
+        check_cuda(cudaDeviceSynchronize());
+        e = cudaGetLastError();
     }
+
+    return static_cast<uint64_t>(e);
 }

@@ -1344,25 +1857,28 @@ int cuda_context_get_device_ordinal(void* context)

 int cuda_context_is_primary(void* context)
 {
-
-
+    CUcontext ctx = static_cast<CUcontext>(context);
+    ContextInfo* context_info = get_context_info(ctx);
+    if (!context_info)
     {
-
-
-        void* device_primary_context = cuda_device_primary_context_retain(ordinal);
-        cuda_device_primary_context_release(ordinal);
-        return int(context == device_primary_context);
+        fprintf(stderr, "Warp error: Failed to get context info\n");
+        return 0;
     }
-
-    return 0;
-}
-
-
-
-
+
+    // if the device primary context is known, check if it matches the given context
+    DeviceInfo* device_info = context_info->device_info;
+    if (device_info->primary_context)
+        return int(ctx == device_info->primary_context);
+
+    // there is no CUDA API to check if a context is primary, but we can temporarily
+    // acquire the device's primary context to check the pointer
+    CUcontext primary_ctx;
+    if (check_cu(cuDevicePrimaryCtxRetain_f(&primary_ctx, device_info->device)))
     {
-
+        check_cu(cuDevicePrimaryCtxRelease_f(device_info->device));
+        return int(ctx == primary_ctx);
     }
+
     return 0;
 }

@@ -1376,115 +1892,251 @@ void* cuda_context_get_stream(void* context)
|
|
|
1376
1892
|
return NULL;
|
|
1377
1893
|
}
|
|
1378
1894
|
|
|
1379
|
-
void cuda_context_set_stream(void* context, void* stream)
|
|
1895
|
+
void cuda_context_set_stream(void* context, void* stream, int sync)
|
|
1380
1896
|
{
|
|
1381
|
-
ContextInfo*
|
|
1382
|
-
if (
|
|
1897
|
+
ContextInfo* context_info = get_context_info(static_cast<CUcontext>(context));
|
|
1898
|
+
if (context_info)
|
|
1383
1899
|
{
|
|
1384
|
-
|
|
1900
|
+
CUstream new_stream = static_cast<CUstream>(stream);
|
|
1901
|
+
|
|
1902
|
+
// check whether we should sync with the previous stream on this device
|
|
1903
|
+
if (sync)
|
|
1904
|
+
{
|
|
1905
|
+
CUstream old_stream = context_info->stream;
|
|
1906
|
+
StreamInfo* old_stream_info = get_stream_info(old_stream);
|
|
1907
|
+
if (old_stream_info)
|
|
1908
|
+
{
|
|
1909
|
+
CUevent cached_event = old_stream_info->cached_event;
|
|
1910
|
+
check_cu(cuEventRecord_f(cached_event, old_stream));
|
|
1911
|
+
check_cu(cuStreamWaitEvent_f(new_stream, cached_event, CU_EVENT_WAIT_DEFAULT));
|
|
1912
|
+
}
|
|
1913
|
+
}
|
|
1914
|
+
|
|
1915
|
+
context_info->stream = new_stream;
|
|
1385
1916
|
}
|
|
1386
1917
|
}
|
|
1387
1918
|
|
|
1388
|
-
|
|
1919
|
+
|
|
1920
|
+
int cuda_is_peer_access_supported(int target_ordinal, int peer_ordinal)
|
|
1389
1921
|
{
|
|
1390
|
-
|
|
1922
|
+
int num_devices = int(g_devices.size());
|
|
1923
|
+
|
|
1924
|
+
if (target_ordinal < 0 || target_ordinal > num_devices)
|
|
1391
1925
|
{
|
|
1392
|
-
fprintf(stderr, "Warp error:
|
|
1926
|
+
fprintf(stderr, "Warp error: Invalid target device ordinal %d\n", target_ordinal);
|
|
1393
1927
|
return 0;
|
|
1394
1928
|
}
|
|
1395
1929
|
|
|
1396
|
-
if (
|
|
1397
|
-
|
|
1930
|
+
if (peer_ordinal < 0 || peer_ordinal > num_devices)
|
|
1931
|
+
{
|
|
1932
|
+
fprintf(stderr, "Warp error: Invalid peer device ordinal %d\n", peer_ordinal);
|
|
1933
|
+
return 0;
|
|
1934
|
+
}
|
|
1398
1935
|
|
|
1399
|
-
|
|
1400
|
-
|
|
1936
|
+
if (target_ordinal == peer_ordinal)
|
|
1937
|
+
return 1;
|
|
1938
|
+
|
|
1939
|
+
int can_access = 0;
|
|
1940
|
+
check_cuda(cudaDeviceCanAccessPeer(&can_access, peer_ordinal, target_ordinal));
|
|
1941
|
+
|
|
1942
|
+
return can_access;
|
|
1943
|
+
}
|
|
1401
1944
|
|
|
1402
|
-
|
|
1403
|
-
|
|
1404
|
-
if (!
|
|
1945
|
+
int cuda_is_peer_access_enabled(void* target_context, void* peer_context)
|
|
1946
|
+
{
|
|
1947
|
+
if (!target_context || !peer_context)
|
|
1405
1948
|
{
|
|
1406
|
-
fprintf(stderr, "Warp error:
|
|
1949
|
+
fprintf(stderr, "Warp error: invalid CUDA context\n");
|
|
1407
1950
|
return 0;
|
|
1408
1951
|
}
|
|
1409
1952
|
|
|
1410
|
-
|
|
1411
|
-
|
|
1953
|
+
if (target_context == peer_context)
|
|
1954
|
+
return 1;
|
|
1955
|
+
|
|
1956
|
+
int target_ordinal = cuda_context_get_device_ordinal(target_context);
|
|
1957
|
+
int peer_ordinal = cuda_context_get_device_ordinal(peer_context);
|
|
1958
|
+
|
|
1959
|
+
// check if peer access is supported
|
|
1960
|
+
int can_access = 0;
|
|
1961
|
+
check_cuda(cudaDeviceCanAccessPeer(&can_access, peer_ordinal, target_ordinal));
|
|
1962
|
+
if (!can_access)
|
|
1963
|
+
return 0;
|
|
1964
|
+
|
|
1965
|
+
// There is no CUDA API to query if peer access is enabled, but we can try to enable it and check the result.
|
|
1966
|
+
|
|
1967
|
+
ContextGuard guard(peer_context, true);
|
|
1968
|
+
|
|
1969
|
+
CUcontext target_ctx = static_cast<CUcontext>(target_context);
|
|
1970
|
+
|
|
1971
|
+
CUresult result = cuCtxEnablePeerAccess_f(target_ctx, 0);
|
|
1972
|
+
if (result == CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED)
|
|
1412
1973
|
{
|
|
1413
|
-
|
|
1974
|
+
return 1;
|
|
1975
|
+
}
|
|
1976
|
+
else if (result == CUDA_SUCCESS)
|
|
1977
|
+
{
|
|
1978
|
+
// undo enablement
|
|
1979
|
+
check_cu(cuCtxDisablePeerAccess_f(target_ctx));
|
|
1980
|
+
return 0;
|
|
1981
|
+
}
|
|
1982
|
+
else
|
|
1983
|
+
{
|
|
1984
|
+
// report error
|
|
1985
|
+
check_cu(result);
|
|
1986
|
+
return 0;
|
|
1987
|
+
}
|
|
1988
|
+
}
|
|
1989
|
+
|
|
+int cuda_set_peer_access_enabled(void* target_context, void* peer_context, int enable)
+{
+    if (!target_context || !peer_context)
+    {
+        fprintf(stderr, "Warp error: invalid CUDA context\n");
+        return 0;
+    }
+
+    if (target_context == peer_context)
+        return 1; // no-op
+
+    int target_ordinal = cuda_context_get_device_ordinal(target_context);
+    int peer_ordinal = cuda_context_get_device_ordinal(peer_context);
+
+    // check if peer access is supported
+    int can_access = 0;
+    check_cuda(cudaDeviceCanAccessPeer(&can_access, peer_ordinal, target_ordinal));
+    if (!can_access)
+    {
+        // failure if enabling, success if disabling
+        if (enable)
         {
-
+            fprintf(stderr, "Warp error: device %d cannot access device %d\n", peer_ordinal, target_ordinal);
+            return 0;
         }
         else
+            return 1;
+    }
+
+    ContextGuard guard(peer_context, true);
+
+    CUcontext target_ctx = static_cast<CUcontext>(target_context);
+
+    if (enable)
+    {
+        CUresult status = cuCtxEnablePeerAccess_f(target_ctx, 0);
+        if (status != CUDA_SUCCESS && status != CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED)
         {
-
+            check_cu(status);
+            fprintf(stderr, "Warp error: failed to enable peer access from device %d to device %d\n", peer_ordinal, target_ordinal);
             return 0;
         }
     }
     else
     {
-
-
-        CUresult result = cuCtxEnablePeerAccess_f(peer_ctx, 0);
-        if (result == CUDA_SUCCESS || result == CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED)
-        {
-            return 1; // ok
-        }
-        else
+        CUresult status = cuCtxDisablePeerAccess_f(target_ctx);
+        if (status != CUDA_SUCCESS && status != CUDA_ERROR_PEER_ACCESS_NOT_ENABLED)
         {
-            check_cu(
+            check_cu(status);
+            fprintf(stderr, "Warp error: failed to disable peer access from device %d to device %d\n", peer_ordinal, target_ordinal);
             return 0;
         }
     }
+
+    return 1; // success
 }

-int
+int cuda_is_mempool_access_enabled(int target_ordinal, int peer_ordinal)
 {
-
+    int num_devices = int(g_devices.size());
+
+    if (target_ordinal < 0 || target_ordinal > num_devices)
+    {
+        fprintf(stderr, "Warp error: Invalid device ordinal %d\n", target_ordinal);
+        return 0;
+    }
+
+    if (peer_ordinal < 0 || peer_ordinal > num_devices)
+    {
+        fprintf(stderr, "Warp error: Invalid peer device ordinal %d\n", peer_ordinal);
         return 0;
+    }

-    if (
+    if (target_ordinal == peer_ordinal)
         return 1;

-
-
-
-
-    ContextInfo* peer_info = get_context_info(peer_ctx);
-    if (!info || !peer_info)
+    cudaMemPool_t pool;
+    if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, target_ordinal)))
+    {
+        fprintf(stderr, "Warp error: Failed to get memory pool of device %d\n", target_ordinal);
         return 0;
+    }
+
+    cudaMemAccessFlags flags = cudaMemAccessFlagsProtNone;
+    cudaMemLocation location;
+    location.id = peer_ordinal;
+    location.type = cudaMemLocationTypeDevice;
+    if (check_cuda(cudaMemPoolGetAccess(&flags, pool, &location)))
+        return int(flags != cudaMemAccessFlagsProtNone);
+
+    return 0;
+}
+
+int cuda_set_mempool_access_enabled(int target_ordinal, int peer_ordinal, int enable)
+{
+    int num_devices = int(g_devices.size());

-
-    if (info->device_info == peer_info->device_info)
+    if (target_ordinal < 0 || target_ordinal > num_devices)
     {
-
-
-
-
+        fprintf(stderr, "Warp error: Invalid device ordinal %d\n", target_ordinal);
+        return 0;
+    }
+
+    if (peer_ordinal < 0 || peer_ordinal > num_devices)
+    {
+        fprintf(stderr, "Warp error: Invalid peer device ordinal %d\n", peer_ordinal);
+        return 0;
     }
+
+    if (target_ordinal == peer_ordinal)
+        return 1; // no-op
+
+    // get the memory pool
+    cudaMemPool_t pool;
+    if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, target_ordinal)))
+    {
+        fprintf(stderr, "Warp error: Failed to get memory pool of device %d\n", target_ordinal);
+        return 0;
+    }
+
+    cudaMemAccessDesc desc;
+    desc.location.type = cudaMemLocationTypeDevice;
+    desc.location.id = peer_ordinal;
+
+    // only cudaMemAccessFlagsProtReadWrite and cudaMemAccessFlagsProtNone are supported
+    if (enable)
+        desc.flags = cudaMemAccessFlagsProtReadWrite;
     else
+        desc.flags = cudaMemAccessFlagsProtNone;
+
+    if (!check_cuda(cudaMemPoolSetAccess(pool, &desc, 1)))
     {
-
-
-        ContextGuard guard(ctx, true);
-        CUresult result = cuCtxEnablePeerAccess_f(peer_ctx, 0);
-        if (result == CUDA_SUCCESS || result == CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED)
-            return 1;
-        else
-            return 0;
+        fprintf(stderr, "Warp error: Failed to set mempool access from device %d to device %d\n", peer_ordinal, target_ordinal);
+        return 0;
     }
+
+    return 1; // success
 }

+
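The two mempool functions above query and toggle access to a device's default memory pool via cudaDeviceGetDefaultMemPool(), cudaMemPoolGetAccess(), and cudaMemPoolSetAccess(). A minimal standalone sketch of that pattern follows, assuming CUDA 11.2+ and two visible, peer-capable GPUs; it is illustrative only, not Warp code.

    #include <cstdio>
    #include <cuda_runtime.h>

    int main()
    {
        const int target = 0;  // device that owns the pool
        const int peer = 1;    // device that wants to access pooled allocations

        cudaMemPool_t pool;
        if (cudaDeviceGetDefaultMemPool(&pool, target) != cudaSuccess)
            return 1;

        // grant read/write access for 'peer' to allocations served by 'target's pool
        cudaMemAccessDesc desc = {};
        desc.location.type = cudaMemLocationTypeDevice;
        desc.location.id = peer;
        desc.flags = cudaMemAccessFlagsProtReadWrite;  // cudaMemAccessFlagsProtNone revokes access
        cudaMemPoolSetAccess(pool, &desc, 1);

        // query the access flags back
        cudaMemAccessFlags flags = cudaMemAccessFlagsProtNone;
        cudaMemLocation location = {};
        location.type = cudaMemLocationTypeDevice;
        location.id = peer;
        cudaMemPoolGetAccess(&flags, pool, &location);

        printf("pool access enabled: %d\n", int(flags != cudaMemAccessFlagsProtNone));
        return 0;
    }
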
 void* cuda_stream_create(void* context)
 {
-    CUcontext ctx = context ? static_cast<CUcontext>(context) : get_current_context();
-    if (!ctx)
-        return NULL;
-
     ContextGuard guard(context, true);

     CUstream stream;
     if (check_cu(cuStreamCreate_f(&stream, CU_STREAM_DEFAULT)))
+    {
+        cuda_stream_register(WP_CURRENT_CONTEXT, stream);
         return stream;
+    }
     else
         return NULL;
 }
@@ -1494,20 +2146,45 @@ void cuda_stream_destroy(void* context, void* stream)
     if (!stream)
         return;

-
-    if (!ctx)
-        return;
-
-    ContextGuard guard(context, true);
+    cuda_stream_unregister(context, stream);

     check_cu(cuStreamDestroy_f(static_cast<CUstream>(stream)));
 }

-void
+void cuda_stream_register(void* context, void* stream)
 {
+    if (!stream)
+        return;
+
     ContextGuard guard(context);

-
+    // populate stream info
+    StreamInfo& stream_info = g_streams[static_cast<CUstream>(stream)];
+    check_cu(cuEventCreate_f(&stream_info.cached_event, CU_EVENT_DISABLE_TIMING));
+}
+
+void cuda_stream_unregister(void* context, void* stream)
+{
+    if (!stream)
+        return;
+
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+
+    StreamInfo* stream_info = get_stream_info(cuda_stream);
+    if (stream_info)
+    {
+        // release stream info
+        check_cu(cuEventDestroy_f(stream_info->cached_event));
+        g_streams.erase(cuda_stream);
+    }
+
+    // make sure we don't leave dangling references to this stream
+    ContextInfo* context_info = get_context_info(context);
+    if (context_info)
+    {
+        if (cuda_stream == context_info->stream)
+            context_info->stream = NULL;
+    }
 }

 void* cuda_stream_get_current()
@@ -1515,24 +2192,33 @@ void* cuda_stream_get_current()
     return get_current_stream();
 }

-void
+void cuda_stream_synchronize(void* stream)
 {
-
+    check_cu(cuStreamSynchronize_f(static_cast<CUstream>(stream)));
+}

+void cuda_stream_wait_event(void* stream, void* event)
+{
     check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), 0));
 }

-void cuda_stream_wait_stream(void*
+void cuda_stream_wait_stream(void* stream, void* other_stream, void* event)
 {
-    ContextGuard guard(context);
-
     check_cu(cuEventRecord_f(static_cast<CUevent>(event), static_cast<CUstream>(other_stream)));
     check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), 0));
 }

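cuda_stream_wait_stream() above orders two streams by recording an event on the producer stream and making the consumer stream wait on it, and cuda_stream_register() caches a timing-disabled event per stream for exactly this purpose. A minimal runtime-API sketch of the same pattern, illustrative only:

    #include <cuda_runtime.h>

    int main()
    {
        cudaStream_t producer, consumer;
        cudaStreamCreate(&producer);
        cudaStreamCreate(&consumer);

        // a synchronization-only event (no timing), like the cached per-stream event above
        cudaEvent_t ev;
        cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);

        // ... enqueue work on 'producer' here ...

        cudaEventRecord(ev, producer);         // mark the current tail of 'producer'
        cudaStreamWaitEvent(consumer, ev, 0);  // work enqueued on 'consumer' after this waits for it

        // ... enqueue dependent work on 'consumer' here ...

        cudaStreamSynchronize(consumer);

        cudaEventDestroy(ev);
        cudaStreamDestroy(producer);
        cudaStreamDestroy(consumer);
        return 0;
    }
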
+int cuda_stream_is_capturing(void* stream)
+{
+    cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
+    check_cuda(cudaStreamIsCapturing(static_cast<cudaStream_t>(stream), &status));
+
+    return int(status != cudaStreamCaptureStatusNone);
+}
+
 void* cuda_event_create(void* context, unsigned flags)
 {
-    ContextGuard guard(context);
+    ContextGuard guard(context, true);

     CUevent event;
     if (check_cu(cuEventCreate_f(&event, flags)))
@@ -1541,68 +2227,217 @@ void* cuda_event_create(void* context, unsigned flags)
         return NULL;
 }

-void cuda_event_destroy(void*
+void cuda_event_destroy(void* event)
 {
-    ContextGuard guard(context, true);
-
     check_cu(cuEventDestroy_f(static_cast<CUevent>(event)));
 }

-void cuda_event_record(void*
+void cuda_event_record(void* event, void* stream)
 {
-    ContextGuard guard(context);
-
     check_cu(cuEventRecord_f(static_cast<CUevent>(event), static_cast<CUstream>(stream)));
 }

-
+bool cuda_graph_begin_capture(void* context, void* stream, int external)
 {
     ContextGuard guard(context);

-
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+    StreamInfo* stream_info = get_stream_info(cuda_stream);
+    if (!stream_info)
+    {
+        wp::set_error_string("Warp error: unknown stream");
+        return false;
+    }
+
+    if (external)
+    {
+        // if it's an external capture, make sure it's already active so we can get the capture id
+        cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
+        if (!check_cuda(cudaStreamIsCapturing(cuda_stream, &status)))
+            return false;
+        if (status != cudaStreamCaptureStatusActive)
+        {
+            wp::set_error_string("Warp error: stream is not capturing");
+            return false;
+        }
+    }
+    else
+    {
+        // start the capture
+        if (!check_cuda(cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeGlobal)))
+            return false;
+    }
+
+    uint64_t capture_id = get_capture_id(cuda_stream);
+
+    CaptureInfo* capture = new CaptureInfo();
+    capture->stream = cuda_stream;
+    capture->id = capture_id;
+    capture->external = bool(external);
+
+    // update stream info
+    stream_info->capture = capture;
+
+    // add to known captures
+    g_captures[capture_id] = capture;
+
+    return true;
 }

-
+bool cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
 {
     ContextGuard guard(context);

-
-
+    // check if this is a known stream
+    CUstream cuda_stream = static_cast<CUstream>(stream);
+    StreamInfo* stream_info = get_stream_info(cuda_stream);
+    if (!stream_info)
+    {
+        wp::set_error_string("Warp error: unknown capture stream");
+        return false;
+    }

-    if
+    // check if this stream was used to start a capture
+    CaptureInfo* capture = stream_info->capture;
+    if (!capture)
     {
-
-
+        wp::set_error_string("Warp error: stream has no capture started");
+        return false;
+    }

-
-
-
-    // can use after CUDA 11.4 to permit graphs to capture cudaMallocAsync() operations
-    check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch));
+    // get capture info
+    bool external = capture->external;
+    uint64_t capture_id = capture->id;

-
-
+    // clear capture info
+    stream_info->capture = NULL;
+    g_captures.erase(capture_id);
+    delete capture;
+
+    // a lambda to clean up on exit in case of error
+    auto clean_up = [cuda_stream, capture_id, external]()
+    {
+        // unreference outstanding graph allocs so that they will be released with the user reference
+        for (auto it = g_graph_allocs.begin(); it != g_graph_allocs.end(); ++it)
+        {
+            GraphAllocInfo& alloc_info = it->second;
+            if (alloc_info.capture_id == capture_id)
+                alloc_info.graph_destroyed = true;
+        }
+
+        // make sure we terminate the capture
+        if (!external)
+        {
+            cudaGraph_t graph = NULL;
+            cudaStreamEndCapture(cuda_stream, &graph);
+            cudaGetLastError();
+        }
+    };

-
+    // get captured graph without ending the capture in case it is external
+    cudaGraph_t graph = get_capture_graph(cuda_stream);
+    if (!graph)
+    {
+        clean_up();
+        return false;
     }
-
+
+    // ensure that all forked streams are joined to the main capture stream by manually
+    // adding outstanding capture dependencies gathered from the graph leaf nodes
+    std::vector<cudaGraphNode_t> stream_dependencies;
+    std::vector<cudaGraphNode_t> leaf_nodes;
+    if (get_capture_dependencies(cuda_stream, stream_dependencies) && get_graph_leaf_nodes(graph, leaf_nodes))
+    {
+        // compute set difference to get unjoined dependencies
+        std::vector<cudaGraphNode_t> unjoined_dependencies;
+        std::sort(stream_dependencies.begin(), stream_dependencies.end());
+        std::sort(leaf_nodes.begin(), leaf_nodes.end());
+        std::set_difference(leaf_nodes.begin(), leaf_nodes.end(),
+                            stream_dependencies.begin(), stream_dependencies.end(),
+                            std::back_inserter(unjoined_dependencies));
+        if (!unjoined_dependencies.empty())
+        {
+            check_cu(cuStreamUpdateCaptureDependencies_f(cuda_stream, unjoined_dependencies.data(), unjoined_dependencies.size(),
+                                                         CU_STREAM_ADD_CAPTURE_DEPENDENCIES));
+            // ensure graph is still valid
+            if (get_capture_graph(cuda_stream) != graph)
+            {
+                clean_up();
+                return false;
+            }
+        }
+    }
+
+    // check if this graph has unfreed allocations, which require special handling
+    std::vector<void*> unfreed_allocs;
+    for (auto it = g_graph_allocs.begin(); it != g_graph_allocs.end(); ++it)
     {
-
+        GraphAllocInfo& alloc_info = it->second;
+        if (alloc_info.capture_id == capture_id)
+            unfreed_allocs.push_back(it->first);
+    }
+
+    if (!unfreed_allocs.empty())
+    {
+        // Create a user object that will notify us when the instantiated graph is destroyed.
+        // This works for external captures also, since we wouldn't otherwise know when
+        // the externally-created graph instance gets deleted.
+        // This callback is guaranteed to arrive after the graph has finished executing on the device,
+        // not necessarily when cudaGraphExecDestroy() is called.
+        GraphInfo* graph_info = new GraphInfo;
+        graph_info->unfreed_allocs = unfreed_allocs;
+        cudaUserObject_t user_object;
+        check_cuda(cudaUserObjectCreate(&user_object, graph_info, on_graph_destroy, 1, cudaUserObjectNoDestructorSync));
+        check_cuda(cudaGraphRetainUserObject(graph, user_object, 1, cudaGraphUserObjectMove));
+
+        // ensure graph is still valid
+        if (get_capture_graph(cuda_stream) != graph)
+        {
+            clean_up();
+            return false;
+        }
     }
+
+    // for external captures, we don't instantiate the graph ourselves, so we're done
+    if (external)
+        return true;
+
+    cudaGraphExec_t graph_exec = NULL;
+
+    // end the capture
+    if (!check_cuda(cudaStreamEndCapture(cuda_stream, &graph)))
+        return false;
+
+    // enable to create debug GraphVis visualization of graph
+    // cudaGraphDebugDotPrint(graph, "graph.dot", cudaGraphDebugDotFlagsVerbose);
+
+    // can use after CUDA 11.4 to permit graphs to capture cudaMallocAsync() operations
+    if (!check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch)))
+        return false;
+
+    // free source graph
+    check_cuda(cudaGraphDestroy(graph));
+
+    // process deferred free list if no more captures are ongoing
+    if (g_captures.empty())
+        free_deferred_allocs();
+
+    if (graph_ret)
+        *graph_ret = graph_exec;
+
+    return true;
 }

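cuda_graph_end_capture() above attaches a CUDA user object to the captured graph so that a destructor callback fires once the graph, and therefore its unfreed pooled allocations, can be released. A minimal standalone sketch of that technique with a hypothetical payload and callback (requires CUDA 11.3+, not Warp code):

    #include <cstdio>
    #include <cuda_runtime.h>

    static void on_destroy(void* user_data)
    {
        // runs once the last reference to the user object is released,
        // i.e. after the owning graph is destroyed (possibly on a CUDA internal thread)
        printf("graph released payload %p\n", user_data);
        delete static_cast<int*>(user_data);
    }

    int main()
    {
        cudaStream_t stream;
        cudaStreamCreate(&stream);

        // capture a (trivial) graph on the stream
        cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
        cudaGraph_t graph = NULL;
        cudaStreamEndCapture(stream, &graph);

        // attach a payload via a user object; the graph takes over our single reference
        int* payload = new int(42);
        cudaUserObject_t user_object;
        cudaUserObjectCreate(&user_object, payload, on_destroy, 1, cudaUserObjectNoDestructorSync);
        cudaGraphRetainUserObject(graph, user_object, 1, cudaGraphUserObjectMove);

        // destroying the graph drops the reference and eventually triggers on_destroy()
        cudaGraphDestroy(graph);

        cudaStreamDestroy(stream);
        cudaDeviceSynchronize();
        return 0;
    }
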
-
+bool cuda_graph_launch(void* graph_exec, void* stream)
 {
-
-
-    check_cuda(cudaGraphLaunch((cudaGraphExec_t)graph_exec, get_current_stream()));
+    return check_cuda(cudaGraphLaunch((cudaGraphExec_t)graph_exec, (cudaStream_t)stream));
 }

-
+bool cuda_graph_destroy(void* context, void* graph_exec)
 {
     ContextGuard guard(context);

-    check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
+    return check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
 }

 size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_dir, bool debug, bool verbose, bool verify_fp, bool fast_math, const char* output_path)
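The graph functions above wrap the usual capture/instantiate/launch workflow, including cudaGraphInstantiateFlagAutoFreeOnLaunch, which lets graphs capture cudaMallocAsync() operations. A minimal standalone CUDA sketch of that workflow (illustrative kernel and sizes, not Warp code):

    #include <cuda_runtime.h>

    __global__ void saxpy(float a, const float* x, float* y, int n)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n)
            y[i] += a * x[i];
    }

    int main()
    {
        const int n = 1 << 20;
        float *x, *y;
        cudaMalloc(&x, n * sizeof(float));
        cudaMalloc(&y, n * sizeof(float));

        cudaStream_t stream;
        cudaStreamCreate(&stream);

        // capture the launch sequence into a graph instead of executing it immediately
        cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
        saxpy<<<(n + 255) / 256, 256, 0, stream>>>(2.0f, x, y, n);
        cudaGraph_t graph = NULL;
        cudaStreamEndCapture(stream, &graph);

        // instantiate; the auto-free flag allows captured async allocations to be reclaimed on launch
        cudaGraphExec_t graph_exec = NULL;
        cudaGraphInstantiateWithFlags(&graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch);
        cudaGraphDestroy(graph);  // the source graph is no longer needed

        // replay the captured work as many times as needed
        for (int i = 0; i < 10; ++i)
            cudaGraphLaunch(graph_exec, stream);
        cudaStreamSynchronize(stream);

        cudaGraphExecDestroy(graph_exec);
        cudaStreamDestroy(stream);
        cudaFree(x);
        cudaFree(y);
        return 0;
    }
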
@@ -1880,7 +2715,7 @@ void* cuda_get_kernel(void* context, void* module, const char* name)
     return kernel;
 }

-size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, void** args)
+size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, void** args, void* stream)
 {
     ContextGuard guard(context);

@@ -1913,7 +2748,7 @@ size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_block
         (CUfunction)kernel,
         grid_dim, 1, 1,
         block_dim, 1, 1,
-        0,
+        0, static_cast<CUstream>(stream),
         args,
         0);
