warp-lang 1.9.0__py3-none-manylinux_2_34_aarch64.whl → 1.10.0__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +301 -287
- warp/__init__.pyi +2302 -307
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1077 -0
- warp/_src/build.py +620 -0
- warp/_src/build_dll.py +642 -0
- warp/{builtins.py → _src/builtins.py} +1546 -224
- warp/_src/codegen.py +4361 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +59 -0
- warp/_src/context.py +8352 -0
- warp/_src/dlpack.py +464 -0
- warp/_src/fabric.py +362 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +510 -0
- warp/_src/fem/cache.py +689 -0
- warp/_src/fem/dirichlet.py +190 -0
- warp/{fem → _src/fem}/domain.py +42 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +703 -0
- warp/{fem → _src/fem}/field/nodal_field.py +32 -15
- warp/{fem → _src/fem}/field/restriction.py +3 -1
- warp/{fem → _src/fem}/field/virtual.py +55 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
- warp/_src/fem/geometry/closest_point.py +99 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
- warp/{fem → _src/fem}/geometry/element.py +34 -10
- warp/{fem → _src/fem}/geometry/geometry.py +50 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
- warp/{fem → _src/fem}/geometry/partition.py +123 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
- warp/{fem → _src/fem}/integrate.py +166 -158
- warp/_src/fem/linalg.py +385 -0
- warp/_src/fem/operator.py +398 -0
- warp/_src/fem/polynomial.py +231 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
- warp/_src/fem/space/basis_space.py +681 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
- warp/{fem → _src/fem}/space/function_space.py +16 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
- warp/{fem → _src/fem}/space/partition.py +119 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
- warp/{fem → _src/fem}/space/restriction.py +68 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
- warp/_src/fem/space/topology.py +461 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
- warp/_src/fem/types.py +114 -0
- warp/_src/fem/utils.py +488 -0
- warp/_src/jax.py +188 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +389 -0
- warp/_src/jax_experimental/ffi.py +1286 -0
- warp/_src/jax_experimental/xla_ffi.py +658 -0
- warp/_src/marching_cubes.py +710 -0
- warp/_src/math.py +416 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +165 -0
- warp/_src/optim/linear.py +1608 -0
- warp/_src/optim/sgd.py +114 -0
- warp/_src/paddle.py +408 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +291 -0
- warp/_src/render/render_opengl.py +3638 -0
- warp/_src/render/render_usd.py +939 -0
- warp/_src/render/utils.py +162 -0
- warp/_src/sparse.py +2718 -0
- warp/_src/tape.py +1208 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +393 -0
- warp/_src/types.py +5888 -0
- warp/_src/utils.py +1695 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -471
- warp/codegen.py +6 -4246
- warp/constants.py +6 -39
- warp/context.py +12 -7851
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +3 -2
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -342
- warp/jax_experimental/ffi.py +17 -853
- warp/jax_experimental/xla_ffi.py +5 -596
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +316 -39
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sort.cu +22 -13
- warp/native/sort.h +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +837 -70
- warp/native/tile_radix_sort.h +3 -3
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -53
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +60 -32
- warp/native/warp.cu +581 -280
- warp/native/warp.h +14 -11
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3616
- warp/render/render_usd.py +6 -918
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_hash_grid.py +38 -0
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +18 -17
- warp/tests/interop/test_jax.py +1382 -79
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +580 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +34 -15
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +60 -14
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +49 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_tuple.py +96 -0
- warp/tests/test_types.py +82 -9
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +239 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5750
- warp/utils.py +10 -1659
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/METADATA +47 -103
- warp_lang-1.10.0.dist-info/RECORD +468 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.0.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
warp/native/warp.cu
CHANGED
|
@@ -19,6 +19,7 @@
|
|
|
19
19
|
#include "scan.h"
|
|
20
20
|
#include "cuda_util.h"
|
|
21
21
|
#include "error.h"
|
|
22
|
+
#include "sort.h"
|
|
22
23
|
|
|
23
24
|
#include <cstdlib>
|
|
24
25
|
#include <fstream>
|
|
@@ -37,6 +38,7 @@
|
|
|
37
38
|
#include <iterator>
|
|
38
39
|
#include <list>
|
|
39
40
|
#include <map>
|
|
41
|
+
#include <mutex>
|
|
40
42
|
#include <string>
|
|
41
43
|
#include <unordered_map>
|
|
42
44
|
#include <unordered_set>
|
|
@@ -175,11 +177,20 @@ struct ContextInfo
|
|
|
175
177
|
CUmodule conditional_module = NULL;
|
|
176
178
|
};
|
|
177
179
|
|
|
180
|
+
// Information used for freeing allocations.
|
|
181
|
+
struct FreeInfo
|
|
182
|
+
{
|
|
183
|
+
void* context = NULL;
|
|
184
|
+
void* ptr = NULL;
|
|
185
|
+
bool is_async = false;
|
|
186
|
+
};
|
|
187
|
+
|
|
178
188
|
struct CaptureInfo
|
|
179
189
|
{
|
|
180
190
|
CUstream stream = NULL; // the main stream where capture begins and ends
|
|
181
191
|
uint64_t id = 0; // unique capture id from CUDA
|
|
182
192
|
bool external = false; // whether this is an external capture
|
|
193
|
+
std::vector<FreeInfo> tmp_allocs; // temporary allocations owned by the graph (e.g., staged array fill values)
|
|
183
194
|
};
|
|
184
195
|
|
|
185
196
|
struct StreamInfo
|
|
@@ -188,9 +199,13 @@ struct StreamInfo
|
|
|
188
199
|
CaptureInfo* capture = NULL; // capture info (only if started on this stream)
|
|
189
200
|
};
|
|
190
201
|
|
|
191
|
-
|
|
202
|
+
// Extra resources tied to a graph, freed after the graph is released by CUDA.
|
|
203
|
+
// Used with the on_graph_destroy() callback.
|
|
204
|
+
struct GraphDestroyCallbackInfo
|
|
192
205
|
{
|
|
193
|
-
|
|
206
|
+
void* context = NULL; // graph CUDA context
|
|
207
|
+
std::vector<void*> unfreed_allocs; // graph allocations not freed by the graph
|
|
208
|
+
std::vector<FreeInfo> tmp_allocs; // temporary allocations owned by the graph (e.g., staged array fill values)
|
|
194
209
|
};
|
|
195
210
|
|
|
196
211
|
// Information for graph allocations that are not freed by the graph.
|
|
@@ -206,19 +221,19 @@ struct GraphAllocInfo
|
|
|
206
221
|
bool graph_destroyed = false; // whether graph instance was destroyed
|
|
207
222
|
};
|
|
208
223
|
|
|
209
|
-
// Information used when deferring
|
|
210
|
-
struct
|
|
224
|
+
// Information used when deferring module unloading.
|
|
225
|
+
struct ModuleInfo
|
|
211
226
|
{
|
|
212
227
|
void* context = NULL;
|
|
213
|
-
void*
|
|
214
|
-
bool is_async = false;
|
|
228
|
+
void* module = NULL;
|
|
215
229
|
};
|
|
216
230
|
|
|
217
|
-
// Information used when deferring
|
|
218
|
-
struct
|
|
231
|
+
// Information used when deferring graph destruction.
|
|
232
|
+
struct GraphDestroyInfo
|
|
219
233
|
{
|
|
220
234
|
void* context = NULL;
|
|
221
|
-
void*
|
|
235
|
+
void* graph = NULL;
|
|
236
|
+
void* graph_exec = NULL;
|
|
222
237
|
};
|
|
223
238
|
|
|
224
239
|
static std::unordered_map<CUfunction, std::string> g_kernel_names;
|
|
@@ -252,6 +267,15 @@ static std::vector<FreeInfo> g_deferred_free_list;
|
|
|
252
267
|
// Call unload_deferred_modules() to release.
|
|
253
268
|
static std::vector<ModuleInfo> g_deferred_module_list;
|
|
254
269
|
|
|
270
|
+
// Graphs that cannot be destroyed immediately get queued here.
|
|
271
|
+
// Call destroy_deferred_graphs() to release.
|
|
272
|
+
static std::vector<GraphDestroyInfo> g_deferred_graph_list;
|
|
273
|
+
|
|
274
|
+
// Data from on_graph_destroy() callbacks that run on a different thread.
|
|
275
|
+
static std::vector<GraphDestroyCallbackInfo*> g_deferred_graph_destroy_list;
|
|
276
|
+
static std::mutex g_graph_destroy_mutex;
|
|
277
|
+
|
|
278
|
+
|
|
255
279
|
void wp_cuda_set_context_restore_policy(bool always_restore)
|
|
256
280
|
{
|
|
257
281
|
ContextGuard::always_restore = always_restore;
|
|
@@ -337,7 +361,7 @@ int cuda_init()
|
|
|
337
361
|
}
|
|
338
362
|
|
|
339
363
|
|
|
340
|
-
|
|
364
|
+
CUcontext get_current_context()
|
|
341
365
|
{
|
|
342
366
|
CUcontext ctx;
|
|
343
367
|
if (check_cu(cuCtxGetCurrent_f(&ctx)))
|
|
@@ -407,6 +431,114 @@ static inline StreamInfo* get_stream_info(CUstream stream)
|
|
|
407
431
|
return NULL;
|
|
408
432
|
}
|
|
409
433
|
|
|
434
|
+
static inline CaptureInfo* get_capture_info(CUstream stream)
|
|
435
|
+
{
|
|
436
|
+
if (!g_captures.empty() && wp_cuda_stream_is_capturing(stream))
|
|
437
|
+
{
|
|
438
|
+
uint64_t capture_id = get_capture_id(stream);
|
|
439
|
+
auto capture_iter = g_captures.find(capture_id);
|
|
440
|
+
if (capture_iter != g_captures.end())
|
|
441
|
+
return capture_iter->second;
|
|
442
|
+
}
|
|
443
|
+
return NULL;
|
|
444
|
+
}
|
|
445
|
+
|
|
446
|
+
// helper function to copy a value to device memory in a graph-friendly way
|
|
447
|
+
static bool capturable_tmp_alloc(void* context, const void* data, size_t size, void** devptr_ret, bool* free_devptr_ret)
|
|
448
|
+
{
|
|
449
|
+
ContextGuard guard(context);
|
|
450
|
+
|
|
451
|
+
CUstream stream = get_current_stream();
|
|
452
|
+
CaptureInfo* capture_info = get_capture_info(stream);
|
|
453
|
+
int device_ordinal = wp_cuda_context_get_device_ordinal(context);
|
|
454
|
+
void* devptr = NULL;
|
|
455
|
+
bool free_devptr = true;
|
|
456
|
+
|
|
457
|
+
if (capture_info)
|
|
458
|
+
{
|
|
459
|
+
// ongoing graph capture - need to stage the fill value so that it persists with the graph
|
|
460
|
+
if (CUDA_VERSION >= 12040 && wp_cuda_driver_version() >= 12040)
|
|
461
|
+
{
|
|
462
|
+
// pause the capture so that the alloc/memcpy won't be captured
|
|
463
|
+
void* graph = NULL;
|
|
464
|
+
if (!wp_cuda_graph_pause_capture(WP_CURRENT_CONTEXT, stream, &graph))
|
|
465
|
+
return false;
|
|
466
|
+
|
|
467
|
+
// copy value to device memory
|
|
468
|
+
devptr = wp_alloc_device(WP_CURRENT_CONTEXT, size);
|
|
469
|
+
if (!devptr)
|
|
470
|
+
{
|
|
471
|
+
fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cuda:%d' (in function %s)\n", (unsigned long long)size, device_ordinal, __FUNCTION__);
|
|
472
|
+
return false;
|
|
473
|
+
}
|
|
474
|
+
if (!check_cuda(cudaMemcpyAsync(devptr, data, size, cudaMemcpyHostToDevice, stream)))
|
|
475
|
+
return false;
|
|
476
|
+
|
|
477
|
+
// graph takes ownership of the value storage
|
|
478
|
+
FreeInfo free_info;
|
|
479
|
+
free_info.context = context ? context : get_current_context();
|
|
480
|
+
free_info.ptr = devptr;
|
|
481
|
+
free_info.is_async = wp_cuda_device_is_mempool_supported(device_ordinal);
|
|
482
|
+
|
|
483
|
+
// allocation will be freed when graph is destroyed
|
|
484
|
+
capture_info->tmp_allocs.push_back(free_info);
|
|
485
|
+
|
|
486
|
+
// resume the capture
|
|
487
|
+
if (!wp_cuda_graph_resume_capture(WP_CURRENT_CONTEXT, stream, graph))
|
|
488
|
+
return false;
|
|
489
|
+
|
|
490
|
+
free_devptr = false; // memory is owned by the graph, doesn't need to be freed
|
|
491
|
+
}
|
|
492
|
+
else
|
|
493
|
+
{
|
|
494
|
+
// older CUDA can't pause/resume the capture, so stage in CPU memory
|
|
495
|
+
void* hostptr = wp_alloc_host(size);
|
|
496
|
+
if (!hostptr)
|
|
497
|
+
{
|
|
498
|
+
fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cpu' (in function %s)\n", (unsigned long long)size, __FUNCTION__);
|
|
499
|
+
return false;
|
|
500
|
+
}
|
|
501
|
+
memcpy(hostptr, data, size);
|
|
502
|
+
|
|
503
|
+
// the device allocation and h2d copy will be captured in the graph
|
|
504
|
+
devptr = wp_alloc_device(WP_CURRENT_CONTEXT, size);
|
|
505
|
+
if (!devptr)
|
|
506
|
+
{
|
|
507
|
+
fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cuda:%d' (in function %s)\n", (unsigned long long)size, device_ordinal, __FUNCTION__);
|
|
508
|
+
return false;
|
|
509
|
+
}
|
|
510
|
+
if (!check_cuda(cudaMemcpyAsync(devptr, hostptr, size, cudaMemcpyHostToDevice, stream)))
|
|
511
|
+
return false;
|
|
512
|
+
|
|
513
|
+
// graph takes ownership of the value storage
|
|
514
|
+
FreeInfo free_info;
|
|
515
|
+
free_info.context = NULL;
|
|
516
|
+
free_info.ptr = hostptr;
|
|
517
|
+
free_info.is_async = false;
|
|
518
|
+
|
|
519
|
+
// allocation will be freed when graph is destroyed
|
|
520
|
+
capture_info->tmp_allocs.push_back(free_info);
|
|
521
|
+
}
|
|
522
|
+
}
|
|
523
|
+
else
|
|
524
|
+
{
|
|
525
|
+
// not capturing, copy the value to device memory
|
|
526
|
+
devptr = wp_alloc_device(WP_CURRENT_CONTEXT, size);
|
|
527
|
+
if (!devptr)
|
|
528
|
+
{
|
|
529
|
+
fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cuda:%d' (in function %s)\n", (unsigned long long)size, device_ordinal, __FUNCTION__);
|
|
530
|
+
return false;
|
|
531
|
+
}
|
|
532
|
+
if (!check_cuda(cudaMemcpyAsync(devptr, data, size, cudaMemcpyHostToDevice, stream)))
|
|
533
|
+
return false;
|
|
534
|
+
}
|
|
535
|
+
|
|
536
|
+
*devptr_ret = devptr;
|
|
537
|
+
*free_devptr_ret = free_devptr;
|
|
538
|
+
|
|
539
|
+
return true;
|
|
540
|
+
}
|
|
541
|
+
|
|
410
542
|
static void deferred_free(void* ptr, void* context, bool is_async)
|
|
411
543
|
{
|
|
412
544
|
FreeInfo free_info;
|
|
@@ -494,34 +626,124 @@ static int unload_deferred_modules(void* context = NULL)
|
|
|
494
626
|
return num_unloaded_modules;
|
|
495
627
|
}
|
|
496
628
|
|
|
497
|
-
static
|
|
629
|
+
static int destroy_deferred_graphs(void* context = NULL)
|
|
498
630
|
{
|
|
499
|
-
if (!
|
|
500
|
-
return;
|
|
631
|
+
if (g_deferred_graph_list.empty() || !g_captures.empty())
|
|
632
|
+
return 0;
|
|
633
|
+
|
|
634
|
+
int num_destroyed_graphs = 0;
|
|
635
|
+
for (auto it = g_deferred_graph_list.begin(); it != g_deferred_graph_list.end(); /*noop*/)
|
|
636
|
+
{
|
|
637
|
+
// destroy the graph if it matches the given context or if the context is unspecified
|
|
638
|
+
const GraphDestroyInfo& graph_info = *it;
|
|
639
|
+
if (graph_info.context == context || !context)
|
|
640
|
+
{
|
|
641
|
+
if (graph_info.graph)
|
|
642
|
+
{
|
|
643
|
+
check_cuda(cudaGraphDestroy((cudaGraph_t)graph_info.graph));
|
|
644
|
+
}
|
|
645
|
+
if (graph_info.graph_exec)
|
|
646
|
+
{
|
|
647
|
+
check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_info.graph_exec));
|
|
648
|
+
}
|
|
649
|
+
++num_destroyed_graphs;
|
|
650
|
+
it = g_deferred_graph_list.erase(it);
|
|
651
|
+
}
|
|
652
|
+
else
|
|
653
|
+
{
|
|
654
|
+
++it;
|
|
655
|
+
}
|
|
656
|
+
}
|
|
657
|
+
|
|
658
|
+
return num_destroyed_graphs;
|
|
659
|
+
}
|
|
660
|
+
|
|
661
|
+
static int process_deferred_graph_destroy_callbacks(void* context = NULL)
|
|
662
|
+
{
|
|
663
|
+
int num_freed = 0;
|
|
501
664
|
|
|
502
|
-
|
|
665
|
+
std::lock_guard<std::mutex> lock(g_graph_destroy_mutex);
|
|
503
666
|
|
|
504
|
-
for (
|
|
667
|
+
for (auto it = g_deferred_graph_destroy_list.begin(); it != g_deferred_graph_destroy_list.end(); /*noop*/)
|
|
505
668
|
{
|
|
506
|
-
|
|
507
|
-
if (
|
|
669
|
+
GraphDestroyCallbackInfo* graph_info = *it;
|
|
670
|
+
if (graph_info->context == context || !context)
|
|
508
671
|
{
|
|
509
|
-
|
|
510
|
-
|
|
672
|
+
// handle unfreed graph allocations (may have outstanding user references)
|
|
673
|
+
for (void* ptr : graph_info->unfreed_allocs)
|
|
511
674
|
{
|
|
512
|
-
|
|
513
|
-
|
|
675
|
+
auto alloc_iter = g_graph_allocs.find(ptr);
|
|
676
|
+
if (alloc_iter != g_graph_allocs.end())
|
|
677
|
+
{
|
|
678
|
+
GraphAllocInfo& alloc_info = alloc_iter->second;
|
|
679
|
+
if (alloc_info.ref_exists)
|
|
680
|
+
{
|
|
681
|
+
// unreference from graph so the pointer will be deallocated when the user reference goes away
|
|
682
|
+
alloc_info.graph_destroyed = true;
|
|
683
|
+
}
|
|
684
|
+
else
|
|
685
|
+
{
|
|
686
|
+
// the pointer can be freed, no references remain
|
|
687
|
+
wp_free_device_async(alloc_info.context, ptr);
|
|
688
|
+
g_graph_allocs.erase(alloc_iter);
|
|
689
|
+
}
|
|
690
|
+
}
|
|
514
691
|
}
|
|
515
|
-
|
|
692
|
+
|
|
693
|
+
// handle temporary allocations owned by the graph (no user references)
|
|
694
|
+
for (const FreeInfo& tmp_info : graph_info->tmp_allocs)
|
|
516
695
|
{
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
696
|
+
if (tmp_info.context)
|
|
697
|
+
{
|
|
698
|
+
// GPU alloc
|
|
699
|
+
if (tmp_info.is_async)
|
|
700
|
+
{
|
|
701
|
+
wp_free_device_async(tmp_info.context, tmp_info.ptr);
|
|
702
|
+
}
|
|
703
|
+
else
|
|
704
|
+
{
|
|
705
|
+
wp_free_device_default(tmp_info.context, tmp_info.ptr);
|
|
706
|
+
}
|
|
707
|
+
}
|
|
708
|
+
else
|
|
709
|
+
{
|
|
710
|
+
// CPU alloc
|
|
711
|
+
wp_free_host(tmp_info.ptr);
|
|
712
|
+
}
|
|
520
713
|
}
|
|
714
|
+
|
|
715
|
+
++num_freed;
|
|
716
|
+
delete graph_info;
|
|
717
|
+
it = g_deferred_graph_destroy_list.erase(it);
|
|
718
|
+
}
|
|
719
|
+
else
|
|
720
|
+
{
|
|
721
|
+
++it;
|
|
521
722
|
}
|
|
522
723
|
}
|
|
523
724
|
|
|
524
|
-
|
|
725
|
+
return num_freed;
|
|
726
|
+
}
|
|
727
|
+
|
|
728
|
+
static int run_deferred_actions(void* context = NULL)
|
|
729
|
+
{
|
|
730
|
+
int num_actions = 0;
|
|
731
|
+
num_actions += free_deferred_allocs(context);
|
|
732
|
+
num_actions += unload_deferred_modules(context);
|
|
733
|
+
num_actions += destroy_deferred_graphs(context);
|
|
734
|
+
num_actions += process_deferred_graph_destroy_callbacks(context);
|
|
735
|
+
return num_actions;
|
|
736
|
+
}
|
|
737
|
+
|
|
738
|
+
// Callback used when a graph is destroyed.
|
|
739
|
+
// NOTE: this runs on an internal CUDA thread and requires synchronization.
|
|
740
|
+
static void CUDART_CB on_graph_destroy(void* user_data)
|
|
741
|
+
{
|
|
742
|
+
if (user_data)
|
|
743
|
+
{
|
|
744
|
+
std::lock_guard<std::mutex> lock(g_graph_destroy_mutex);
|
|
745
|
+
g_deferred_graph_destroy_list.push_back(static_cast<GraphDestroyCallbackInfo*>(user_data));
|
|
746
|
+
}
|
|
525
747
|
}
|
|
526
748
|
|
|
527
749
|
static inline const char* get_cuda_kernel_name(void* kernel)
|
|
@@ -973,30 +1195,36 @@ void wp_memtile_device(void* context, void* dst, const void* src, size_t srcsize
|
|
|
973
1195
|
else
|
|
974
1196
|
{
|
|
975
1197
|
// generic version
|
|
1198
|
+
void* value_devptr = NULL; // fill value in device memory
|
|
1199
|
+
bool free_devptr = true; // whether we need to free the memory
|
|
976
1200
|
|
|
977
|
-
//
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
1201
|
+
// prepare the fill value in a graph-friendly way
|
|
1202
|
+
if (!capturable_tmp_alloc(WP_CURRENT_CONTEXT, src, srcsize, &value_devptr, &free_devptr))
|
|
1203
|
+
{
|
|
1204
|
+
fprintf(stderr, "Warp fill error: failed to copy value to device memory\n");
|
|
1205
|
+
return;
|
|
1206
|
+
}
|
|
983
1207
|
|
|
984
|
-
|
|
1208
|
+
wp_launch_device(WP_CURRENT_CONTEXT, memtile_kernel, n, (dst, value_devptr, srcsize, n));
|
|
985
1209
|
|
|
1210
|
+
if (free_devptr)
|
|
1211
|
+
{
|
|
1212
|
+
wp_free_device(WP_CURRENT_CONTEXT, value_devptr);
|
|
1213
|
+
}
|
|
986
1214
|
}
|
|
987
1215
|
}
|
|
988
1216
|
|
|
989
1217
|
|
|
990
1218
|
static __global__ void array_copy_1d_kernel(void* dst, const void* src,
|
|
991
|
-
|
|
1219
|
+
size_t dst_stride, size_t src_stride,
|
|
992
1220
|
const int* dst_indices, const int* src_indices,
|
|
993
|
-
|
|
1221
|
+
size_t n, size_t elem_size)
|
|
994
1222
|
{
|
|
995
|
-
|
|
1223
|
+
size_t i = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
996
1224
|
if (i < n)
|
|
997
1225
|
{
|
|
998
|
-
|
|
999
|
-
|
|
1226
|
+
size_t src_idx = src_indices ? src_indices[i] : i;
|
|
1227
|
+
size_t dst_idx = dst_indices ? dst_indices[i] : i;
|
|
1000
1228
|
const char* p = (const char*)src + src_idx * src_stride;
|
|
1001
1229
|
char* q = (char*)dst + dst_idx * dst_stride;
|
|
1002
1230
|
memcpy(q, p, elem_size);
|
|
@@ -1004,20 +1232,20 @@ static __global__ void array_copy_1d_kernel(void* dst, const void* src,
|
|
|
1004
1232
|
}
|
|
1005
1233
|
|
|
1006
1234
|
static __global__ void array_copy_2d_kernel(void* dst, const void* src,
|
|
1007
|
-
wp::vec_t<2,
|
|
1235
|
+
wp::vec_t<2, size_t> dst_strides, wp::vec_t<2, size_t> src_strides,
|
|
1008
1236
|
wp::vec_t<2, const int*> dst_indices, wp::vec_t<2, const int*> src_indices,
|
|
1009
|
-
wp::vec_t<2,
|
|
1237
|
+
wp::vec_t<2, size_t> shape, size_t elem_size)
|
|
1010
1238
|
{
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1239
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1240
|
+
size_t n = shape[1];
|
|
1241
|
+
size_t i = tid / n;
|
|
1242
|
+
size_t j = tid % n;
|
|
1015
1243
|
if (i < shape[0] /*&& j < shape[1]*/)
|
|
1016
1244
|
{
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1245
|
+
size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
|
|
1246
|
+
size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
|
|
1247
|
+
size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
|
|
1248
|
+
size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
|
|
1021
1249
|
const char* p = (const char*)src + src_idx0 * src_strides[0] + src_idx1 * src_strides[1];
|
|
1022
1250
|
char* q = (char*)dst + dst_idx0 * dst_strides[0] + dst_idx1 * dst_strides[1];
|
|
1023
1251
|
memcpy(q, p, elem_size);
|
|
@@ -1025,24 +1253,24 @@ static __global__ void array_copy_2d_kernel(void* dst, const void* src,
|
|
|
1025
1253
|
}
|
|
1026
1254
|
|
|
1027
1255
|
static __global__ void array_copy_3d_kernel(void* dst, const void* src,
|
|
1028
|
-
wp::vec_t<3,
|
|
1256
|
+
wp::vec_t<3, size_t> dst_strides, wp::vec_t<3, size_t> src_strides,
|
|
1029
1257
|
wp::vec_t<3, const int*> dst_indices, wp::vec_t<3, const int*> src_indices,
|
|
1030
|
-
wp::vec_t<3,
|
|
1031
|
-
{
|
|
1032
|
-
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1258
|
+
wp::vec_t<3, size_t> shape, size_t elem_size)
|
|
1259
|
+
{
|
|
1260
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1261
|
+
size_t n = shape[1];
|
|
1262
|
+
size_t o = shape[2];
|
|
1263
|
+
size_t i = tid / (n * o);
|
|
1264
|
+
size_t j = tid % (n * o) / o;
|
|
1265
|
+
size_t k = tid % o;
|
|
1038
1266
|
if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
|
|
1039
1267
|
{
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1268
|
+
size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
|
|
1269
|
+
size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
|
|
1270
|
+
size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
|
|
1271
|
+
size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
|
|
1272
|
+
size_t src_idx2 = src_indices[2] ? src_indices[2][k] : k;
|
|
1273
|
+
size_t dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
|
|
1046
1274
|
const char* p = (const char*)src + src_idx0 * src_strides[0]
|
|
1047
1275
|
+ src_idx1 * src_strides[1]
|
|
1048
1276
|
+ src_idx2 * src_strides[2];
|
|
@@ -1054,28 +1282,28 @@ static __global__ void array_copy_3d_kernel(void* dst, const void* src,
|
|
|
1054
1282
|
}
|
|
1055
1283
|
|
|
1056
1284
|
static __global__ void array_copy_4d_kernel(void* dst, const void* src,
|
|
1057
|
-
wp::vec_t<4,
|
|
1285
|
+
wp::vec_t<4, size_t> dst_strides, wp::vec_t<4, size_t> src_strides,
|
|
1058
1286
|
wp::vec_t<4, const int*> dst_indices, wp::vec_t<4, const int*> src_indices,
|
|
1059
|
-
wp::vec_t<4,
|
|
1060
|
-
{
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1287
|
+
wp::vec_t<4, size_t> shape, size_t elem_size)
|
|
1288
|
+
{
|
|
1289
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1290
|
+
size_t n = shape[1];
|
|
1291
|
+
size_t o = shape[2];
|
|
1292
|
+
size_t p = shape[3];
|
|
1293
|
+
size_t i = tid / (n * o * p);
|
|
1294
|
+
size_t j = tid % (n * o * p) / (o * p);
|
|
1295
|
+
size_t k = tid % (o * p) / p;
|
|
1296
|
+
size_t l = tid % p;
|
|
1069
1297
|
if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
|
|
1070
1298
|
{
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1299
|
+
size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
|
|
1300
|
+
size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
|
|
1301
|
+
size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
|
|
1302
|
+
size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
|
|
1303
|
+
size_t src_idx2 = src_indices[2] ? src_indices[2][k] : k;
|
|
1304
|
+
size_t dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
|
|
1305
|
+
size_t src_idx3 = src_indices[3] ? src_indices[3][l] : l;
|
|
1306
|
+
size_t dst_idx3 = dst_indices[3] ? dst_indices[3][l] : l;
|
|
1079
1307
|
const char* p = (const char*)src + src_idx0 * src_strides[0]
|
|
1080
1308
|
+ src_idx1 * src_strides[1]
|
|
1081
1309
|
+ src_idx2 * src_strides[2]
|
|
@@ -1090,14 +1318,14 @@ static __global__ void array_copy_4d_kernel(void* dst, const void* src,
|
|
|
1090
1318
|
|
|
1091
1319
|
|
|
1092
1320
|
static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src,
|
|
1093
|
-
void* dst_data,
|
|
1094
|
-
|
|
1321
|
+
void* dst_data, size_t dst_stride, const int* dst_indices,
|
|
1322
|
+
size_t elem_size)
|
|
1095
1323
|
{
|
|
1096
|
-
|
|
1324
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1097
1325
|
|
|
1098
1326
|
if (tid < src.size)
|
|
1099
1327
|
{
|
|
1100
|
-
|
|
1328
|
+
size_t dst_idx = dst_indices ? dst_indices[tid] : tid;
|
|
1101
1329
|
void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
|
|
1102
1330
|
const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
|
|
1103
1331
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
@@ -1105,15 +1333,15 @@ static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src
|
|
|
1105
1333
|
}
|
|
1106
1334
|
|
|
1107
1335
|
static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricarray_t<void> src,
|
|
1108
|
-
void* dst_data,
|
|
1109
|
-
|
|
1336
|
+
void* dst_data, size_t dst_stride, const int* dst_indices,
|
|
1337
|
+
size_t elem_size)
|
|
1110
1338
|
{
|
|
1111
|
-
|
|
1339
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1112
1340
|
|
|
1113
1341
|
if (tid < src.size)
|
|
1114
1342
|
{
|
|
1115
|
-
|
|
1116
|
-
|
|
1343
|
+
size_t src_index = src.indices[tid];
|
|
1344
|
+
size_t dst_idx = dst_indices ? dst_indices[tid] : tid;
|
|
1117
1345
|
void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
|
|
1118
1346
|
const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
|
|
1119
1347
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
@@ -1121,14 +1349,14 @@ static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricar
|
|
|
1121
1349
|
}
|
|
1122
1350
|
|
|
1123
1351
|
static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
|
|
1124
|
-
const void* src_data,
|
|
1125
|
-
|
|
1352
|
+
const void* src_data, size_t src_stride, const int* src_indices,
|
|
1353
|
+
size_t elem_size)
|
|
1126
1354
|
{
|
|
1127
|
-
|
|
1355
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1128
1356
|
|
|
1129
1357
|
if (tid < dst.size)
|
|
1130
1358
|
{
|
|
1131
|
-
|
|
1359
|
+
size_t src_idx = src_indices ? src_indices[tid] : tid;
|
|
1132
1360
|
const void* src_ptr = (const char*)src_data + src_idx * src_stride;
|
|
1133
1361
|
void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
|
|
1134
1362
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
@@ -1136,25 +1364,25 @@ static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
|
|
|
1136
1364
|
}
|
|
1137
1365
|
|
|
1138
1366
|
static __global__ void array_copy_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst,
|
|
1139
|
-
const void* src_data,
|
|
1140
|
-
|
|
1367
|
+
const void* src_data, size_t src_stride, const int* src_indices,
|
|
1368
|
+
size_t elem_size)
|
|
1141
1369
|
{
|
|
1142
|
-
|
|
1370
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1143
1371
|
|
|
1144
1372
|
if (tid < dst.size)
|
|
1145
1373
|
{
|
|
1146
|
-
|
|
1374
|
+
size_t src_idx = src_indices ? src_indices[tid] : tid;
|
|
1147
1375
|
const void* src_ptr = (const char*)src_data + src_idx * src_stride;
|
|
1148
|
-
|
|
1376
|
+
size_t dst_idx = dst.indices[tid];
|
|
1149
1377
|
void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
|
|
1150
1378
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
1151
1379
|
}
|
|
1152
1380
|
}
|
|
1153
1381
|
|
|
1154
1382
|
|
|
1155
|
-
static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src,
|
|
1383
|
+
static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src, size_t elem_size)
|
|
1156
1384
|
{
|
|
1157
|
-
|
|
1385
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1158
1386
|
|
|
1159
1387
|
if (tid < dst.size)
|
|
1160
1388
|
{
|
|
@@ -1165,27 +1393,27 @@ static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void
|
|
|
1165
1393
|
}
|
|
1166
1394
|
|
|
1167
1395
|
|
|
1168
|
-
static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src,
|
|
1396
|
+
static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src, size_t elem_size)
|
|
1169
1397
|
{
|
|
1170
|
-
|
|
1398
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1171
1399
|
|
|
1172
1400
|
if (tid < dst.size)
|
|
1173
1401
|
{
|
|
1174
1402
|
const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
|
|
1175
|
-
|
|
1403
|
+
size_t dst_index = dst.indices[tid];
|
|
1176
1404
|
void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
|
|
1177
1405
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
1178
1406
|
}
|
|
1179
1407
|
}
|
|
1180
1408
|
|
|
1181
1409
|
|
|
1182
|
-
static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src,
|
|
1410
|
+
static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, size_t elem_size)
|
|
1183
1411
|
{
|
|
1184
|
-
|
|
1412
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1185
1413
|
|
|
1186
1414
|
if (tid < dst.size)
|
|
1187
1415
|
{
|
|
1188
|
-
|
|
1416
|
+
size_t src_index = src.indices[tid];
|
|
1189
1417
|
const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
|
|
1190
1418
|
void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
|
|
1191
1419
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
@@ -1193,14 +1421,14 @@ static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarra
|
|
|
1193
1421
|
}
|
|
1194
1422
|
|
|
1195
1423
|
|
|
1196
|
-
static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src,
|
|
1424
|
+
static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, size_t elem_size)
|
|
1197
1425
|
{
|
|
1198
|
-
|
|
1426
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1199
1427
|
|
|
1200
1428
|
if (tid < dst.size)
|
|
1201
1429
|
{
|
|
1202
|
-
|
|
1203
|
-
|
|
1430
|
+
size_t src_index = src.indices[tid];
|
|
1431
|
+
size_t dst_index = dst.indices[tid];
|
|
1204
1432
|
const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
|
|
1205
1433
|
void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
|
|
1206
1434
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
@@ -1439,9 +1667,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
|
|
|
1439
1667
|
}
|
|
1440
1668
|
case 2:
|
|
1441
1669
|
{
|
|
1442
|
-
wp::vec_t<2,
|
|
1443
|
-
wp::vec_t<2,
|
|
1444
|
-
wp::vec_t<2,
|
|
1670
|
+
wp::vec_t<2, size_t> shape_v(src_shape[0], src_shape[1]);
|
|
1671
|
+
wp::vec_t<2, size_t> src_strides_v(src_strides[0], src_strides[1]);
|
|
1672
|
+
wp::vec_t<2, size_t> dst_strides_v(dst_strides[0], dst_strides[1]);
|
|
1445
1673
|
wp::vec_t<2, const int*> src_indices_v(src_indices[0], src_indices[1]);
|
|
1446
1674
|
wp::vec_t<2, const int*> dst_indices_v(dst_indices[0], dst_indices[1]);
|
|
1447
1675
|
|
|
@@ -1453,9 +1681,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
|
|
|
1453
1681
|
}
|
|
1454
1682
|
case 3:
|
|
1455
1683
|
{
|
|
1456
|
-
wp::vec_t<3,
|
|
1457
|
-
wp::vec_t<3,
|
|
1458
|
-
wp::vec_t<3,
|
|
1684
|
+
wp::vec_t<3, size_t> shape_v(src_shape[0], src_shape[1], src_shape[2]);
|
|
1685
|
+
wp::vec_t<3, size_t> src_strides_v(src_strides[0], src_strides[1], src_strides[2]);
|
|
1686
|
+
wp::vec_t<3, size_t> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2]);
|
|
1459
1687
|
wp::vec_t<3, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2]);
|
|
1460
1688
|
wp::vec_t<3, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2]);
|
|
1461
1689
|
|
|
@@ -1467,9 +1695,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
|
|
|
1467
1695
|
}
|
|
1468
1696
|
case 4:
|
|
1469
1697
|
{
|
|
1470
|
-
wp::vec_t<4,
|
|
1471
|
-
wp::vec_t<4,
|
|
1472
|
-
wp::vec_t<4,
|
|
1698
|
+
wp::vec_t<4, size_t> shape_v(src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
|
|
1699
|
+
wp::vec_t<4, size_t> src_strides_v(src_strides[0], src_strides[1], src_strides[2], src_strides[3]);
|
|
1700
|
+
wp::vec_t<4, size_t> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2], dst_strides[3]);
|
|
1473
1701
|
wp::vec_t<4, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2], src_indices[3]);
|
|
1474
1702
|
wp::vec_t<4, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2], dst_indices[3]);
|
|
1475
1703
|
|
|
@@ -1489,94 +1717,94 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
|
|
|
1489
1717
|
|
|
1490
1718
|
|
|
1491
1719
|
static __global__ void array_fill_1d_kernel(void* data,
|
|
1492
|
-
|
|
1493
|
-
|
|
1720
|
+
size_t n,
|
|
1721
|
+
size_t stride,
|
|
1494
1722
|
const int* indices,
|
|
1495
1723
|
const void* value,
|
|
1496
|
-
|
|
1724
|
+
size_t value_size)
|
|
1497
1725
|
{
|
|
1498
|
-
|
|
1726
|
+
size_t i = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1499
1727
|
if (i < n)
|
|
1500
1728
|
{
|
|
1501
|
-
|
|
1729
|
+
size_t idx = indices ? indices[i] : i;
|
|
1502
1730
|
char* p = (char*)data + idx * stride;
|
|
1503
1731
|
memcpy(p, value, value_size);
|
|
1504
1732
|
}
|
|
1505
1733
|
}
|
|
1506
1734
|
|
|
1507
1735
|
static __global__ void array_fill_2d_kernel(void* data,
|
|
1508
|
-
wp::vec_t<2,
|
|
1509
|
-
wp::vec_t<2,
|
|
1736
|
+
wp::vec_t<2, size_t> shape,
|
|
1737
|
+
wp::vec_t<2, size_t> strides,
|
|
1510
1738
|
wp::vec_t<2, const int*> indices,
|
|
1511
1739
|
const void* value,
|
|
1512
|
-
|
|
1740
|
+
size_t value_size)
|
|
1513
1741
|
{
|
|
1514
|
-
|
|
1515
|
-
|
|
1516
|
-
|
|
1517
|
-
|
|
1742
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1743
|
+
size_t n = shape[1];
|
|
1744
|
+
size_t i = tid / n;
|
|
1745
|
+
size_t j = tid % n;
|
|
1518
1746
|
if (i < shape[0] /*&& j < shape[1]*/)
|
|
1519
1747
|
{
|
|
1520
|
-
|
|
1521
|
-
|
|
1748
|
+
size_t idx0 = indices[0] ? indices[0][i] : i;
|
|
1749
|
+
size_t idx1 = indices[1] ? indices[1][j] : j;
|
|
1522
1750
|
char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1];
|
|
1523
1751
|
memcpy(p, value, value_size);
|
|
1524
1752
|
}
|
|
1525
1753
|
}
|
|
1526
1754
|
|
|
1527
1755
|
static __global__ void array_fill_3d_kernel(void* data,
|
|
1528
|
-
wp::vec_t<3,
|
|
1529
|
-
wp::vec_t<3,
|
|
1756
|
+
wp::vec_t<3, size_t> shape,
|
|
1757
|
+
wp::vec_t<3, size_t> strides,
|
|
1530
1758
|
wp::vec_t<3, const int*> indices,
|
|
1531
1759
|
const void* value,
|
|
1532
|
-
|
|
1533
|
-
{
|
|
1534
|
-
|
|
1535
|
-
|
|
1536
|
-
|
|
1537
|
-
|
|
1538
|
-
|
|
1539
|
-
|
|
1760
|
+
size_t value_size)
|
|
1761
|
+
{
|
|
1762
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1763
|
+
size_t n = shape[1];
|
|
1764
|
+
size_t o = shape[2];
|
|
1765
|
+
size_t i = tid / (n * o);
|
|
1766
|
+
size_t j = tid % (n * o) / o;
|
|
1767
|
+
size_t k = tid % o;
|
|
1540
1768
|
if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
|
|
1541
1769
|
{
|
|
1542
|
-
|
|
1543
|
-
|
|
1544
|
-
|
|
1770
|
+
size_t idx0 = indices[0] ? indices[0][i] : i;
|
|
1771
|
+
size_t idx1 = indices[1] ? indices[1][j] : j;
|
|
1772
|
+
size_t idx2 = indices[2] ? indices[2][k] : k;
|
|
1545
1773
|
char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2];
|
|
1546
1774
|
memcpy(p, value, value_size);
|
|
1547
1775
|
}
|
|
1548
1776
|
}
|
|
1549
1777
|
|
|
1550
1778
|
static __global__ void array_fill_4d_kernel(void* data,
|
|
1551
|
-
wp::vec_t<4,
|
|
1552
|
-
wp::vec_t<4,
|
|
1779
|
+
wp::vec_t<4, size_t> shape,
|
|
1780
|
+
wp::vec_t<4, size_t> strides,
|
|
1553
1781
|
wp::vec_t<4, const int*> indices,
|
|
1554
1782
|
const void* value,
|
|
1555
|
-
|
|
1556
|
-
{
|
|
1557
|
-
|
|
1558
|
-
|
|
1559
|
-
|
|
1560
|
-
|
|
1561
|
-
|
|
1562
|
-
|
|
1563
|
-
|
|
1564
|
-
|
|
1783
|
+
size_t value_size)
|
|
1784
|
+
{
|
|
1785
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1786
|
+
size_t n = shape[1];
|
|
1787
|
+
size_t o = shape[2];
|
|
1788
|
+
size_t p = shape[3];
|
|
1789
|
+
size_t i = tid / (n * o * p);
|
|
1790
|
+
size_t j = tid % (n * o * p) / (o * p);
|
|
1791
|
+
size_t k = tid % (o * p) / p;
|
|
1792
|
+
size_t l = tid % p;
|
|
1565
1793
|
if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
|
|
1566
1794
|
{
|
|
1567
|
-
|
|
1568
|
-
|
|
1569
|
-
|
|
1570
|
-
|
|
1795
|
+
size_t idx0 = indices[0] ? indices[0][i] : i;
|
|
1796
|
+
size_t idx1 = indices[1] ? indices[1][j] : j;
|
|
1797
|
+
size_t idx2 = indices[2] ? indices[2][k] : k;
|
|
1798
|
+
size_t idx3 = indices[3] ? indices[3][l] : l;
|
|
1571
1799
|
char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2] + idx3 * strides[3];
|
|
1572
1800
|
memcpy(p, value, value_size);
|
|
1573
1801
|
}
|
|
1574
1802
|
}
|
|
1575
1803
|
|
|
1576
1804
|
|
|
1577
|
-
static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value,
|
|
1805
|
+
static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value, size_t value_size)
|
|
1578
1806
|
{
|
|
1579
|
-
|
|
1807
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1580
1808
|
if (tid < fa.size)
|
|
1581
1809
|
{
|
|
1582
1810
|
void* dst_ptr = fabricarray_element_ptr(fa, tid, value_size);
|
|
@@ -1585,9 +1813,9 @@ static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, cons
|
|
|
1585
1813
|
}
|
|
1586
1814
|
|
|
1587
1815
|
|
|
1588
|
-
static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value,
|
|
1816
|
+
static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value, size_t value_size)
|
|
1589
1817
|
{
|
|
1590
|
-
|
|
1818
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1591
1819
|
if (tid < ifa.size)
|
|
1592
1820
|
{
|
|
1593
1821
|
size_t idx = size_t(ifa.indices[tid]);
|
|
@@ -1654,67 +1882,76 @@ WP_API void wp_array_fill_device(void* context, void* arr_ptr, int arr_type, con
|
|
|
1654
1882
|
|
|
1655
1883
|
ContextGuard guard(context);
|
|
1656
1884
|
|
|
1657
|
-
//
|
|
1658
|
-
|
|
1659
|
-
|
|
1660
|
-
|
|
1885
|
+
void* value_devptr = NULL; // fill value in device memory
|
|
1886
|
+
bool free_devptr = true; // whether we need to free the memory
|
|
1887
|
+
|
|
1888
|
+
// prepare the fill value in a graph-friendly way
|
|
1889
|
+
if (!capturable_tmp_alloc(WP_CURRENT_CONTEXT, value_ptr, value_size, &value_devptr, &free_devptr))
|
|
1890
|
+
{
|
|
1891
|
+
fprintf(stderr, "Warp fill error: failed to copy value to device memory\n");
|
|
1892
|
+
return;
|
|
1893
|
+
}
|
|
1661
1894
|
|
|
1662
|
-
// handle fabric arrays
|
|
1663
1895
|
if (fa)
|
|
1664
1896
|
{
|
|
1897
|
+
// handle fabric arrays
|
|
1665
1898
|
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_kernel, n,
|
|
1666
1899
|
(*fa, value_devptr, value_size));
|
|
1667
|
-
return;
|
|
1668
1900
|
}
|
|
1669
1901
|
else if (ifa)
|
|
1670
1902
|
{
|
|
1903
|
+
// handle indexed fabric arrays
|
|
1671
1904
|
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_indexed_kernel, n,
|
|
1672
1905
|
(*ifa, value_devptr, value_size));
|
|
1673
|
-
return;
|
|
1674
|
-
}
|
|
1675
|
-
|
|
1676
|
-
// handle regular or indexed arrays
|
|
1677
|
-
switch (ndim)
|
|
1678
|
-
{
|
|
1679
|
-
case 1:
|
|
1680
|
-
{
|
|
1681
|
-
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_1d_kernel, n,
|
|
1682
|
-
(data, shape[0], strides[0], indices[0], value_devptr, value_size));
|
|
1683
|
-
break;
|
|
1684
|
-
}
|
|
1685
|
-
case 2:
|
|
1686
|
-
{
|
|
1687
|
-
wp::vec_t<2, int> shape_v(shape[0], shape[1]);
|
|
1688
|
-
wp::vec_t<2, int> strides_v(strides[0], strides[1]);
|
|
1689
|
-
wp::vec_t<2, const int*> indices_v(indices[0], indices[1]);
|
|
1690
|
-
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_2d_kernel, n,
|
|
1691
|
-
(data, shape_v, strides_v, indices_v, value_devptr, value_size));
|
|
1692
|
-
break;
|
|
1693
1906
|
}
|
|
1694
|
-
|
|
1907
|
+
else
|
|
1695
1908
|
{
|
|
1696
|
-
|
|
1697
|
-
|
|
1698
|
-
|
|
1699
|
-
|
|
1700
|
-
|
|
1701
|
-
|
|
1909
|
+
// handle regular or indexed arrays
|
|
1910
|
+
switch (ndim)
|
|
1911
|
+
{
|
|
1912
|
+
case 1:
|
|
1913
|
+
{
|
|
1914
|
+
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_1d_kernel, n,
|
|
1915
|
+
(data, shape[0], strides[0], indices[0], value_devptr, value_size));
|
|
1916
|
+
break;
|
|
1917
|
+
}
|
|
1918
|
+
case 2:
|
|
1919
|
+
{
|
|
1920
|
+
wp::vec_t<2, size_t> shape_v(shape[0], shape[1]);
|
|
1921
|
+
wp::vec_t<2, size_t> strides_v(strides[0], strides[1]);
|
|
1922
|
+
wp::vec_t<2, const int*> indices_v(indices[0], indices[1]);
|
|
1923
|
+
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_2d_kernel, n,
|
|
1924
|
+
(data, shape_v, strides_v, indices_v, value_devptr, value_size));
|
|
1925
|
+
break;
|
|
1926
|
+
}
|
|
1927
|
+
case 3:
|
|
1928
|
+
{
|
|
1929
|
+
wp::vec_t<3, size_t> shape_v(shape[0], shape[1], shape[2]);
|
|
1930
|
+
wp::vec_t<3, size_t> strides_v(strides[0], strides[1], strides[2]);
|
|
1931
|
+
wp::vec_t<3, const int*> indices_v(indices[0], indices[1], indices[2]);
|
|
1932
|
+
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_3d_kernel, n,
|
|
1933
|
+
(data, shape_v, strides_v, indices_v, value_devptr, value_size));
|
|
1934
|
+
break;
|
|
1935
|
+
}
|
|
1936
|
+
case 4:
|
|
1937
|
+
{
|
|
1938
|
+
wp::vec_t<4, size_t> shape_v(shape[0], shape[1], shape[2], shape[3]);
|
|
1939
|
+
wp::vec_t<4, size_t> strides_v(strides[0], strides[1], strides[2], strides[3]);
|
|
1940
|
+
wp::vec_t<4, const int*> indices_v(indices[0], indices[1], indices[2], indices[3]);
|
|
1941
|
+
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_4d_kernel, n,
|
|
1942
|
+
(data, shape_v, strides_v, indices_v, value_devptr, value_size));
|
|
1943
|
+
break;
|
|
1944
|
+
}
|
|
1945
|
+
default:
|
|
1946
|
+
fprintf(stderr, "Warp fill error: invalid array dimensionality (%d)\n", ndim);
|
|
1947
|
+
break;
|
|
1948
|
+
}
|
|
1702
1949
|
}
|
|
1703
|
-
|
|
1950
|
+
|
|
1951
|
+
if (free_devptr)
|
|
1704
1952
|
{
|
|
1705
|
-
|
|
1706
|
-
wp::vec_t<4, int> strides_v(strides[0], strides[1], strides[2], strides[3]);
|
|
1707
|
-
wp::vec_t<4, const int*> indices_v(indices[0], indices[1], indices[2], indices[3]);
|
|
1708
|
-
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_4d_kernel, n,
|
|
1709
|
-
(data, shape_v, strides_v, indices_v, value_devptr, value_size));
|
|
1710
|
-
break;
|
|
1953
|
+
wp_free_device(WP_CURRENT_CONTEXT, value_devptr);
|
|
1711
1954
|
}
|
|
1712
|
-
default:
|
|
1713
|
-
fprintf(stderr, "Warp fill error: invalid array dimensionality (%d)\n", ndim);
|
|
1714
|
-
return;
|
|
1715
|
-
}
|
|
1716
|
-
|
|
1717
|
-
wp_free_device(WP_CURRENT_CONTEXT, value_devptr);
|
|
1718
1955
|
}
|
|
1719
1956
|
|
|
1720
1957
|
void wp_array_scan_int_device(uint64_t in, uint64_t out, int len, bool inclusive)
|
|
@@ -2071,14 +2308,15 @@ void wp_cuda_context_synchronize(void* context)
|
|
|
2071
2308
|
|
|
2072
2309
|
check_cu(cuCtxSynchronize_f());
|
|
2073
2310
|
|
|
2074
|
-
if (
|
|
2311
|
+
if (!context)
|
|
2312
|
+
context = get_current_context();
|
|
2313
|
+
|
|
2314
|
+
if (run_deferred_actions(context) > 0)
|
|
2075
2315
|
{
|
|
2076
|
-
// ensure deferred asynchronous
|
|
2316
|
+
// ensure deferred asynchronous operations complete
|
|
2077
2317
|
check_cu(cuCtxSynchronize_f());
|
|
2078
2318
|
}
|
|
2079
2319
|
|
|
2080
|
-
unload_deferred_modules(context);
|
|
2081
|
-
|
|
2082
2320
|
// check_cuda(cudaDeviceGraphMemTrim(wp_cuda_context_get_device_ordinal(context)));
|
|
2083
2321
|
}
|
|
2084
2322
|
|
|
@@ -2448,6 +2686,9 @@ void wp_cuda_stream_destroy(void* context, void* stream)
|
|
|
2448
2686
|
|
|
2449
2687
|
wp_cuda_stream_unregister(context, stream);
|
|
2450
2688
|
|
|
2689
|
+
// release temporary radix sort buffer associated with this stream
|
|
2690
|
+
radix_sort_release(context, stream);
|
|
2691
|
+
|
|
2451
2692
|
check_cu(cuStreamDestroy_f(static_cast<CUstream>(stream)));
|
|
2452
2693
|
}
|
|
2453
2694
|
|
|
@@ -2510,15 +2751,36 @@ void wp_cuda_stream_synchronize(void* stream)
|
|
|
2510
2751
|
check_cu(cuStreamSynchronize_f(static_cast<CUstream>(stream)));
|
|
2511
2752
|
}
|
|
2512
2753
|
|
|
2513
|
-
void wp_cuda_stream_wait_event(void* stream, void* event)
|
|
2754
|
+
void wp_cuda_stream_wait_event(void* stream, void* event, bool external)
|
|
2514
2755
|
{
|
|
2515
|
-
|
|
2756
|
+
// the external flag can only be used during graph capture
|
|
2757
|
+
if (external && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
|
|
2758
|
+
{
|
|
2759
|
+
// wait for an external event during graph capture
|
|
2760
|
+
check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), CU_EVENT_WAIT_EXTERNAL));
|
|
2761
|
+
}
|
|
2762
|
+
else
|
|
2763
|
+
{
|
|
2764
|
+
check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), CU_EVENT_WAIT_DEFAULT));
|
|
2765
|
+
}
|
|
2516
2766
|
}
|
|
2517
2767
|
|
|
2518
|
-
void wp_cuda_stream_wait_stream(void* stream, void* other_stream, void* event)
|
|
2768
|
+
void wp_cuda_stream_wait_stream(void* stream, void* other_stream, void* event, bool external)
|
|
2519
2769
|
{
|
|
2520
|
-
|
|
2521
|
-
|
|
2770
|
+
unsigned record_flags = CU_EVENT_RECORD_DEFAULT;
|
|
2771
|
+
unsigned wait_flags = CU_EVENT_WAIT_DEFAULT;
|
|
2772
|
+
|
|
2773
|
+
// the external flag can only be used during graph capture
|
|
2774
|
+
if (external && !g_captures.empty())
|
|
2775
|
+
{
|
|
2776
|
+
if (wp_cuda_stream_is_capturing(other_stream))
|
|
2777
|
+
record_flags = CU_EVENT_RECORD_EXTERNAL;
|
|
2778
|
+
if (wp_cuda_stream_is_capturing(stream))
|
|
2779
|
+
wait_flags = CU_EVENT_WAIT_EXTERNAL;
|
|
2780
|
+
}
|
|
2781
|
+
|
|
2782
|
+
check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(other_stream), record_flags));
|
|
2783
|
+
check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), wait_flags));
|
|
2522
2784
|
}
|
|
2523
2785
|
|
|
2524
2786
|
int wp_cuda_stream_is_capturing(void* stream)
|
|
@@ -2571,11 +2833,12 @@ int wp_cuda_event_query(void* event)
|
|
|
2571
2833
|
return res;
|
|
2572
2834
|
}
|
|
2573
2835
|
|
|
2574
|
-
void wp_cuda_event_record(void* event, void* stream, bool
|
|
2836
|
+
void wp_cuda_event_record(void* event, void* stream, bool external)
|
|
2575
2837
|
{
|
|
2576
|
-
|
|
2838
|
+
// the external flag can only be used during graph capture
|
|
2839
|
+
if (external && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
|
|
2577
2840
|
{
|
|
2578
|
-
// record
|
|
2841
|
+
// record external event during graph capture (e.g., for timing or when explicitly specified by the user)
|
|
2579
2842
|
check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(stream), CU_EVENT_RECORD_EXTERNAL));
|
|
2580
2843
|
}
|
|
2581
2844
|
else
|
|
@@ -2625,7 +2888,7 @@ bool wp_cuda_graph_begin_capture(void* context, void* stream, int external)
|
|
|
2625
2888
|
else
|
|
2626
2889
|
{
|
|
2627
2890
|
// start the capture
|
|
2628
|
-
if (!check_cuda(cudaStreamBeginCapture(cuda_stream,
|
|
2891
|
+
if (!check_cuda(cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeThreadLocal)))
|
|
2629
2892
|
return false;
|
|
2630
2893
|
}
|
|
2631
2894
|
|
|
@@ -2669,6 +2932,7 @@ bool wp_cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
|
|
|
2669
2932
|
// get capture info
|
|
2670
2933
|
bool external = capture->external;
|
|
2671
2934
|
uint64_t capture_id = capture->id;
|
|
2935
|
+
std::vector<FreeInfo> tmp_allocs = capture->tmp_allocs;
|
|
2672
2936
|
|
|
2673
2937
|
// clear capture info
|
|
2674
2938
|
stream_info->capture = NULL;
|
|
@@ -2738,15 +3002,17 @@ bool wp_cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
|
|
|
2738
3002
|
unfreed_allocs.push_back(it->first);
|
|
2739
3003
|
}
|
|
2740
3004
|
|
|
2741
|
-
if (!unfreed_allocs.empty())
|
|
3005
|
+
if (!unfreed_allocs.empty() || !tmp_allocs.empty())
|
|
2742
3006
|
{
|
|
2743
3007
|
// Create a user object that will notify us when the instantiated graph is destroyed.
|
|
2744
3008
|
// This works for external captures also, since we wouldn't otherwise know when
|
|
2745
3009
|
// the externally-created graph instance gets deleted.
|
|
2746
3010
|
// This callback is guaranteed to arrive after the graph has finished executing on the device,
|
|
2747
3011
|
// not necessarily when cudaGraphExecDestroy() is called.
|
|
2748
|
-
|
|
3012
|
+
GraphDestroyCallbackInfo* graph_info = new GraphDestroyCallbackInfo;
|
|
3013
|
+
graph_info->context = context ? context : get_current_context();
|
|
2749
3014
|
graph_info->unfreed_allocs = unfreed_allocs;
|
|
3015
|
+
graph_info->tmp_allocs = tmp_allocs;
|
|
2750
3016
|
cudaUserObject_t user_object;
|
|
2751
3017
|
check_cuda(cudaUserObjectCreate(&user_object, graph_info, on_graph_destroy, 1, cudaUserObjectNoDestructorSync));
|
|
2752
3018
|
check_cuda(cudaGraphRetainUserObject(graph, user_object, 1, cudaGraphUserObjectMove));
|
|
@@ -2770,8 +3036,7 @@ bool wp_cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
|
|
|
2770
3036
|
// process deferred free list if no more captures are ongoing
|
|
2771
3037
|
if (g_captures.empty())
|
|
2772
3038
|
{
|
|
2773
|
-
|
|
2774
|
-
unload_deferred_modules();
|
|
3039
|
+
run_deferred_actions();
|
|
2775
3040
|
}
|
|
2776
3041
|
|
|
2777
3042
|
if (graph_ret)
|
|
@@ -2811,11 +3076,12 @@ bool wp_cuda_graph_create_exec(void* context, void* stream, void* graph, void**
|
|
|
2811
3076
|
// Support for conditional graph nodes available with CUDA 12.4+.
|
|
2812
3077
|
#if CUDA_VERSION >= 12040
|
|
2813
3078
|
|
|
2814
|
-
// CUBIN data for compiled conditional modules, loaded on demand, keyed on device architecture
|
|
2815
|
-
|
|
3079
|
+
// CUBIN or PTX data for compiled conditional modules, loaded on demand, keyed on device architecture
|
|
3080
|
+
using ModuleKey = std::pair<int, bool>; // <arch, use_ptx>
|
|
3081
|
+
static std::map<ModuleKey, void*> g_conditional_modules;
|
|
2816
3082
|
|
|
2817
3083
|
// Compile module with conditional helper kernels
|
|
2818
|
-
static void* compile_conditional_module(int arch)
|
|
3084
|
+
static void* compile_conditional_module(int arch, bool use_ptx)
|
|
2819
3085
|
{
|
|
2820
3086
|
static const char* kernel_source = R"(
|
|
2821
3087
|
typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
|
|
@@ -2844,8 +3110,9 @@ static void* compile_conditional_module(int arch)
|
|
|
2844
3110
|
)";
|
|
2845
3111
|
|
|
2846
3112
|
// avoid recompilation
|
|
2847
|
-
|
|
2848
|
-
|
|
3113
|
+
ModuleKey key = {arch, use_ptx};
|
|
3114
|
+
auto it = g_conditional_modules.find(key);
|
|
3115
|
+
if (it != g_conditional_modules.end())
|
|
2849
3116
|
return it->second;
|
|
2850
3117
|
|
|
2851
3118
|
nvrtcProgram prog;
|
|
@@ -2853,11 +3120,23 @@ static void* compile_conditional_module(int arch)
|
|
|
2853
3120
|
return NULL;
|
|
2854
3121
|
|
|
2855
3122
|
char arch_opt[128];
|
|
2856
|
-
|
|
3123
|
+
if (use_ptx)
|
|
3124
|
+
snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=compute_%d", arch);
|
|
3125
|
+
else
|
|
3126
|
+
snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=sm_%d", arch);
|
|
2857
3127
|
|
|
2858
3128
|
std::vector<const char*> opts;
|
|
2859
3129
|
opts.push_back(arch_opt);
|
|
2860
3130
|
|
|
3131
|
+
const bool print_debug = (std::getenv("WARP_DEBUG") != nullptr);
|
|
3132
|
+
if (print_debug)
|
|
3133
|
+
{
|
|
3134
|
+
printf("NVRTC options (conditional module, arch=%d, use_ptx=%s):\n", arch, use_ptx ? "true" : "false");
|
|
3135
|
+
for(auto o: opts) {
|
|
3136
|
+
printf("%s\n", o);
|
|
3137
|
+
}
|
|
3138
|
+
}
|
|
3139
|
+
|
|
2861
3140
|
if (!check_nvrtc(nvrtcCompileProgram(prog, int(opts.size()), opts.data())))
|
|
2862
3141
|
{
|
|
2863
3142
|
size_t log_size;
|
|
@@ -2874,23 +3153,37 @@ static void* compile_conditional_module(int arch)
|
|
|
2874
3153
|
// get output
|
|
2875
3154
|
char* output = NULL;
|
|
2876
3155
|
size_t output_size = 0;
|
|
2877
|
-
|
|
2878
|
-
if (
|
|
3156
|
+
|
|
3157
|
+
if (use_ptx)
|
|
3158
|
+
{
|
|
3159
|
+
check_nvrtc(nvrtcGetPTXSize(prog, &output_size));
|
|
3160
|
+
if (output_size > 0)
|
|
3161
|
+
{
|
|
3162
|
+
output = new char[output_size];
|
|
3163
|
+
if (check_nvrtc(nvrtcGetPTX(prog, output)))
|
|
3164
|
+
g_conditional_modules[key] = output;
|
|
3165
|
+
}
|
|
3166
|
+
}
|
|
3167
|
+
else
|
|
2879
3168
|
{
|
|
2880
|
-
|
|
2881
|
-
if (
|
|
2882
|
-
|
|
3169
|
+
check_nvrtc(nvrtcGetCUBINSize(prog, &output_size));
|
|
3170
|
+
if (output_size > 0)
|
|
3171
|
+
{
|
|
3172
|
+
output = new char[output_size];
|
|
3173
|
+
if (check_nvrtc(nvrtcGetCUBIN(prog, output)))
|
|
3174
|
+
g_conditional_modules[key] = output;
|
|
3175
|
+
}
|
|
2883
3176
|
}
|
|
2884
3177
|
|
|
2885
3178
|
nvrtcDestroyProgram(&prog);
|
|
2886
3179
|
|
|
2887
|
-
// return CUBIN data
|
|
3180
|
+
// return CUBIN or PTX data
|
|
2888
3181
|
return output;
|
|
2889
3182
|
}
|
|
2890
3183
|
|
|
2891
3184
|
|
|
2892
3185
|
// Load module with conditional helper kernels
|
|
2893
|
-
static CUmodule load_conditional_module(void* context)
|
|
3186
|
+
static CUmodule load_conditional_module(void* context, int arch, bool use_ptx)
|
|
2894
3187
|
{
|
|
2895
3188
|
ContextInfo* context_info = get_context_info(context);
|
|
2896
3189
|
if (!context_info)
|
|
@@ -2900,17 +3193,15 @@ static CUmodule load_conditional_module(void* context)
|
|
|
2900
3193
|
if (context_info->conditional_module)
|
|
2901
3194
|
return context_info->conditional_module;
|
|
2902
3195
|
|
|
2903
|
-
int arch = context_info->device_info->arch;
|
|
2904
|
-
|
|
2905
3196
|
// compile if needed
|
|
2906
|
-
void* compiled_module = compile_conditional_module(arch);
|
|
3197
|
+
void* compiled_module = compile_conditional_module(arch, use_ptx);
|
|
2907
3198
|
if (!compiled_module)
|
|
2908
3199
|
{
|
|
2909
3200
|
fprintf(stderr, "Warp error: Failed to compile conditional kernels\n");
|
|
2910
3201
|
return NULL;
|
|
2911
3202
|
}
|
|
2912
3203
|
|
|
2913
|
-
// load module
|
|
3204
|
+
// load module (handles both PTX and CUBIN data automatically)
|
|
2914
3205
|
CUmodule module = NULL;
|
|
2915
3206
|
if (!check_cu(cuModuleLoadDataEx_f(&module, compiled_module, 0, NULL, NULL)))
|
|
2916
3207
|
{
|
|
@@ -2923,10 +3214,10 @@ static CUmodule load_conditional_module(void* context)
|
|
|
2923
3214
|
return module;
|
|
2924
3215
|
}
|
|
2925
3216
|
|
|
2926
|
-
static CUfunction get_conditional_kernel(void* context, const char* name)
|
|
3217
|
+
static CUfunction get_conditional_kernel(void* context, int arch, bool use_ptx, const char* name)
|
|
2927
3218
|
{
|
|
2928
3219
|
// load module if needed
|
|
2929
|
-
CUmodule module = load_conditional_module(context);
|
|
3220
|
+
CUmodule module = load_conditional_module(context, arch, use_ptx);
|
|
2930
3221
|
if (!module)
|
|
2931
3222
|
return NULL;
|
|
2932
3223
|
|
|
@@ -2966,7 +3257,7 @@ bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
|
|
|
2966
3257
|
leaf_nodes.data(),
|
|
2967
3258
|
nullptr,
|
|
2968
3259
|
leaf_nodes.size(),
|
|
2969
|
-
|
|
3260
|
+
cudaStreamCaptureModeThreadLocal)))
|
|
2970
3261
|
return false;
|
|
2971
3262
|
|
|
2972
3263
|
return true;
|
|
@@ -2976,7 +3267,7 @@ bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
|
|
|
2976
3267
|
// https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes/
|
|
2977
3268
|
// condition is a gpu pointer
|
|
2978
3269
|
// if_graph_ret and else_graph_ret should be NULL if not needed
|
|
2979
|
-
bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret)
|
|
3270
|
+
bool wp_cuda_graph_insert_if_else(void* context, void* stream, int arch, bool use_ptx, int* condition, void** if_graph_ret, void** else_graph_ret)
|
|
2980
3271
|
{
|
|
2981
3272
|
bool has_if = if_graph_ret != NULL;
|
|
2982
3273
|
bool has_else = else_graph_ret != NULL;
|
|
@@ -3019,9 +3310,9 @@ bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, v
|
|
|
3019
3310
|
// (need to negate the condition if only the else branch is used)
|
|
3020
3311
|
CUfunction kernel;
|
|
3021
3312
|
if (has_if)
|
|
3022
|
-
kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
|
|
3313
|
+
kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
|
|
3023
3314
|
else
|
|
3024
|
-
kernel = get_conditional_kernel(context, "set_conditional_else_handle_kernel");
|
|
3315
|
+
kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_else_handle_kernel");
|
|
3025
3316
|
|
|
3026
3317
|
if (!kernel)
|
|
3027
3318
|
{
|
|
@@ -3072,7 +3363,7 @@ bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, v
|
|
|
3072
3363
|
check_cuda(cudaGraphConditionalHandleCreate(&if_handle, cuda_graph));
|
|
3073
3364
|
check_cuda(cudaGraphConditionalHandleCreate(&else_handle, cuda_graph));
|
|
3074
3365
|
|
|
3075
|
-
CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_else_handles_kernel");
|
|
3366
|
+
CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_else_handles_kernel");
|
|
3076
3367
|
if (!kernel)
|
|
3077
3368
|
{
|
|
3078
3369
|
wp::set_error_string("Failed to get built-in conditional kernel");
|
|
@@ -3273,7 +3564,7 @@ bool wp_cuda_graph_insert_child_graph(void* context, void* stream, void* child_g
|
|
|
3273
3564
|
return true;
|
|
3274
3565
|
}
|
|
3275
3566
|
|
|
3276
|
-
bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret)
|
|
3567
|
+
bool wp_cuda_graph_insert_while(void* context, void* stream, int arch, bool use_ptx, int* condition, void** body_graph_ret, uint64_t* handle_ret)
|
|
3277
3568
|
{
|
|
3278
3569
|
// if there's no body, it's a no-op
|
|
3279
3570
|
if (!body_graph_ret)
|
|
@@ -3303,7 +3594,7 @@ bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, voi
|
|
|
3303
3594
|
return false;
|
|
3304
3595
|
|
|
3305
3596
|
// launch a kernel to set the condition handle from condition pointer
|
|
3306
|
-
CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
|
|
3597
|
+
CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
|
|
3307
3598
|
if (!kernel)
|
|
3308
3599
|
{
|
|
3309
3600
|
wp::set_error_string("Failed to get built-in conditional kernel");
|
|
@@ -3339,14 +3630,14 @@ bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, voi
|
|
|
3339
3630
|
return true;
|
|
3340
3631
|
}
|
|
3341
3632
|
|
|
3342
|
-
bool wp_cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle)
|
|
3633
|
+
bool wp_cuda_graph_set_condition(void* context, void* stream, int arch, bool use_ptx, int* condition, uint64_t handle)
|
|
3343
3634
|
{
|
|
3344
3635
|
ContextGuard guard(context);
|
|
3345
3636
|
|
|
3346
3637
|
CUstream cuda_stream = static_cast<CUstream>(stream);
|
|
3347
3638
|
|
|
3348
3639
|
// launch a kernel to set the condition handle from condition pointer
|
|
3349
|
-
CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
|
|
3640
|
+
CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
|
|
3350
3641
|
if (!kernel)
|
|
3351
3642
|
{
|
|
3352
3643
|
wp::set_error_string("Failed to get built-in conditional kernel");
|
|
@@ -3378,19 +3669,19 @@ bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
|
|
|
3378
3669
|
return false;
|
|
3379
3670
|
}
|
|
3380
3671
|
|
|
3381
|
-
bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret)
|
|
3672
|
+
bool wp_cuda_graph_insert_if_else(void* context, void* stream, int arch, bool use_ptx, int* condition, void** if_graph_ret, void** else_graph_ret)
|
|
3382
3673
|
{
|
|
3383
3674
|
wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
|
|
3384
3675
|
return false;
|
|
3385
3676
|
}
|
|
3386
3677
|
|
|
3387
|
-
bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret)
|
|
3678
|
+
bool wp_cuda_graph_insert_while(void* context, void* stream, int arch, bool use_ptx, int* condition, void** body_graph_ret, uint64_t* handle_ret)
|
|
3388
3679
|
{
|
|
3389
3680
|
wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
|
|
3390
3681
|
return false;
|
|
3391
3682
|
}
|
|
3392
3683
|
|
|
3393
|
-
bool wp_cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle)
|
|
3684
|
+
bool wp_cuda_graph_set_condition(void* context, void* stream, int arch, bool use_ptx, int* condition, uint64_t handle)
|
|
3394
3685
|
{
|
|
3395
3686
|
wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
|
|
3396
3687
|
return false;
|
|
@@ -3425,16 +3716,38 @@ bool wp_cuda_graph_launch(void* graph_exec, void* stream)
|
|
|
3425
3716
|
|
|
3426
3717
|
bool wp_cuda_graph_destroy(void* context, void* graph)
|
|
3427
3718
|
{
|
|
3428
|
-
|
|
3429
|
-
|
|
3430
|
-
|
|
3719
|
+
// ensure there are no graph captures in progress
|
|
3720
|
+
if (g_captures.empty())
|
|
3721
|
+
{
|
|
3722
|
+
ContextGuard guard(context);
|
|
3723
|
+
return check_cuda(cudaGraphDestroy((cudaGraph_t)graph));
|
|
3724
|
+
}
|
|
3725
|
+
else
|
|
3726
|
+
{
|
|
3727
|
+
GraphDestroyInfo info;
|
|
3728
|
+
info.context = context ? context : get_current_context();
|
|
3729
|
+
info.graph = graph;
|
|
3730
|
+
g_deferred_graph_list.push_back(info);
|
|
3731
|
+
return true;
|
|
3732
|
+
}
|
|
3431
3733
|
}
|
|
3432
3734
|
|
|
3433
3735
|
bool wp_cuda_graph_exec_destroy(void* context, void* graph_exec)
|
|
3434
3736
|
{
|
|
3435
|
-
|
|
3436
|
-
|
|
3437
|
-
|
|
3737
|
+
// ensure there are no graph captures in progress
|
|
3738
|
+
if (g_captures.empty())
|
|
3739
|
+
{
|
|
3740
|
+
ContextGuard guard(context);
|
|
3741
|
+
return check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
|
|
3742
|
+
}
|
|
3743
|
+
else
|
|
3744
|
+
{
|
|
3745
|
+
GraphDestroyInfo info;
|
|
3746
|
+
info.context = context ? context : get_current_context();
|
|
3747
|
+
info.graph_exec = graph_exec;
|
|
3748
|
+
g_deferred_graph_list.push_back(info);
|
|
3749
|
+
return true;
|
|
3750
|
+
}
|
|
3438
3751
|
}
|
|
3439
3752
|
|
|
3440
3753
|
bool write_file(const char* data, size_t size, std::string filename, const char* mode)
|
|
@@ -4287,17 +4600,5 @@ void wp_cuda_timing_end(timing_result_t* results, int size)
|
|
|
4287
4600
|
g_cuda_timing_state = parent_state;
|
|
4288
4601
|
}
|
|
4289
4602
|
|
|
4290
|
-
// impl. files
|
|
4291
|
-
#include "bvh.cu"
|
|
4292
|
-
#include "mesh.cu"
|
|
4293
|
-
#include "sort.cu"
|
|
4294
|
-
#include "hashgrid.cu"
|
|
4295
|
-
#include "reduce.cu"
|
|
4296
|
-
#include "runlength_encode.cu"
|
|
4297
|
-
#include "scan.cu"
|
|
4298
|
-
#include "sparse.cu"
|
|
4299
|
-
#include "volume.cu"
|
|
4300
|
-
#include "volume_builder.cu"
|
|
4301
|
-
|
|
4302
4603
|
//#include "spline.inl"
|
|
4303
4604
|
//#include "volume.inl"
|