warp-lang 1.10.0 (py3-none-macosx_11_0_arm64.whl)
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
Potentially problematic release: this version of warp-lang has been flagged as potentially problematic.
- warp/__init__.py +334 -0
- warp/__init__.pyi +5856 -0
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1077 -0
- warp/_src/build.py +620 -0
- warp/_src/build_dll.py +642 -0
- warp/_src/builtins.py +10555 -0
- warp/_src/codegen.py +4361 -0
- warp/_src/config.py +178 -0
- warp/_src/constants.py +59 -0
- warp/_src/context.py +8352 -0
- warp/_src/dlpack.py +464 -0
- warp/_src/fabric.py +362 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +510 -0
- warp/_src/fem/cache.py +689 -0
- warp/_src/fem/dirichlet.py +190 -0
- warp/_src/fem/domain.py +553 -0
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +703 -0
- warp/_src/fem/field/nodal_field.py +403 -0
- warp/_src/fem/field/restriction.py +39 -0
- warp/_src/fem/field/virtual.py +1021 -0
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
- warp/_src/fem/geometry/closest_point.py +99 -0
- warp/_src/fem/geometry/deformed_geometry.py +277 -0
- warp/_src/fem/geometry/element.py +854 -0
- warp/_src/fem/geometry/geometry.py +693 -0
- warp/_src/fem/geometry/grid_2d.py +478 -0
- warp/_src/fem/geometry/grid_3d.py +539 -0
- warp/_src/fem/geometry/hexmesh.py +956 -0
- warp/_src/fem/geometry/nanogrid.py +660 -0
- warp/_src/fem/geometry/partition.py +483 -0
- warp/_src/fem/geometry/quadmesh.py +597 -0
- warp/_src/fem/geometry/tetmesh.py +762 -0
- warp/_src/fem/geometry/trimesh.py +588 -0
- warp/_src/fem/integrate.py +2507 -0
- warp/_src/fem/linalg.py +385 -0
- warp/_src/fem/operator.py +398 -0
- warp/_src/fem/polynomial.py +231 -0
- warp/_src/fem/quadrature/__init__.py +17 -0
- warp/_src/fem/quadrature/pic_quadrature.py +318 -0
- warp/_src/fem/quadrature/quadrature.py +665 -0
- warp/_src/fem/space/__init__.py +248 -0
- warp/_src/fem/space/basis_function_space.py +499 -0
- warp/_src/fem/space/basis_space.py +681 -0
- warp/_src/fem/space/dof_mapper.py +253 -0
- warp/_src/fem/space/function_space.py +312 -0
- warp/_src/fem/space/grid_2d_function_space.py +179 -0
- warp/_src/fem/space/grid_3d_function_space.py +229 -0
- warp/_src/fem/space/hexmesh_function_space.py +255 -0
- warp/_src/fem/space/nanogrid_function_space.py +199 -0
- warp/_src/fem/space/partition.py +435 -0
- warp/_src/fem/space/quadmesh_function_space.py +222 -0
- warp/_src/fem/space/restriction.py +221 -0
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
- warp/_src/fem/space/shape/shape_function.py +134 -0
- warp/_src/fem/space/shape/square_shape_function.py +928 -0
- warp/_src/fem/space/shape/tet_shape_function.py +829 -0
- warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
- warp/_src/fem/space/tetmesh_function_space.py +270 -0
- warp/_src/fem/space/topology.py +461 -0
- warp/_src/fem/space/trimesh_function_space.py +193 -0
- warp/_src/fem/types.py +114 -0
- warp/_src/fem/utils.py +488 -0
- warp/_src/jax.py +188 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +389 -0
- warp/_src/jax_experimental/ffi.py +1286 -0
- warp/_src/jax_experimental/xla_ffi.py +658 -0
- warp/_src/marching_cubes.py +710 -0
- warp/_src/math.py +416 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +165 -0
- warp/_src/optim/linear.py +1608 -0
- warp/_src/optim/sgd.py +114 -0
- warp/_src/paddle.py +408 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +291 -0
- warp/_src/render/render_opengl.py +3638 -0
- warp/_src/render/render_usd.py +939 -0
- warp/_src/render/utils.py +162 -0
- warp/_src/sparse.py +2718 -0
- warp/_src/tape.py +1208 -0
- warp/_src/thirdparty/__init__.py +0 -0
- warp/_src/thirdparty/appdirs.py +598 -0
- warp/_src/thirdparty/dlpack.py +145 -0
- warp/_src/thirdparty/unittest_parallel.py +676 -0
- warp/_src/torch.py +393 -0
- warp/_src/types.py +5888 -0
- warp/_src/utils.py +1695 -0
- warp/autograd.py +33 -0
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +29 -0
- warp/build_dll.py +24 -0
- warp/codegen.py +24 -0
- warp/constants.py +24 -0
- warp/context.py +33 -0
- warp/dlpack.py +24 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +195 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +290 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/distributed/example_jacobi_mpi.py +506 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +469 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +181 -0
- warp/examples/fem/example_convection_diffusion_dg.py +225 -0
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +225 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +242 -0
- warp/examples/fem/example_mixed_elasticity.py +293 -0
- warp/examples/fem/example_navier_stokes.py +263 -0
- warp/examples/fem/example_nonconforming_contact.py +300 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +357 -0
- warp/examples/fem/utils.py +1047 -0
- warp/examples/interop/example_jax_callable.py +146 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +232 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +88 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/examples/tile/example_tile_mlp.py +385 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/fabric.py +24 -0
- warp/fem/__init__.py +173 -0
- warp/fem/adaptivity.py +26 -0
- warp/fem/cache.py +30 -0
- warp/fem/dirichlet.py +24 -0
- warp/fem/field/__init__.py +24 -0
- warp/fem/field/field.py +26 -0
- warp/fem/geometry/__init__.py +21 -0
- warp/fem/geometry/closest_point.py +31 -0
- warp/fem/linalg.py +38 -0
- warp/fem/operator.py +32 -0
- warp/fem/polynomial.py +29 -0
- warp/fem/space/__init__.py +22 -0
- warp/fem/space/basis_space.py +24 -0
- warp/fem/space/shape/__init__.py +68 -0
- warp/fem/space/topology.py +24 -0
- warp/fem/types.py +24 -0
- warp/fem/utils.py +32 -0
- warp/jax.py +29 -0
- warp/jax_experimental/__init__.py +29 -0
- warp/jax_experimental/custom_call.py +29 -0
- warp/jax_experimental/ffi.py +39 -0
- warp/jax_experimental/xla_ffi.py +24 -0
- warp/marching_cubes.py +24 -0
- warp/math.py +37 -0
- warp/native/array.h +1687 -0
- warp/native/builtin.h +2327 -0
- warp/native/bvh.cpp +562 -0
- warp/native/bvh.cu +826 -0
- warp/native/bvh.h +555 -0
- warp/native/clang/clang.cpp +541 -0
- warp/native/coloring.cpp +622 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +568 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +677 -0
- warp/native/cuda_util.h +313 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +2023 -0
- warp/native/fabric.h +246 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +89 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1253 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +348 -0
- warp/native/mat.h +5189 -0
- warp/native/mathdx.cpp +93 -0
- warp/native/matnn.h +221 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +406 -0
- warp/native/mesh.h +2097 -0
- warp/native/nanovdb/GridHandle.h +533 -0
- warp/native/nanovdb/HostBuffer.h +591 -0
- warp/native/nanovdb/NanoVDB.h +6246 -0
- warp/native/nanovdb/NodeManager.h +323 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1664 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +145 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +363 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +55 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +286 -0
- warp/native/sort.h +35 -0
- warp/native/sparse.cpp +241 -0
- warp/native/sparse.cu +435 -0
- warp/native/spatial.h +1306 -0
- warp/native/svd.h +727 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +4124 -0
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +838 -0
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +2199 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +68 -0
- warp/native/volume.h +970 -0
- warp/native/volume_builder.cu +483 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1143 -0
- warp/native/warp.cu +4604 -0
- warp/native/warp.h +358 -0
- warp/optim/__init__.py +20 -0
- warp/optim/adam.py +24 -0
- warp/optim/linear.py +35 -0
- warp/optim/sgd.py +24 -0
- warp/paddle.py +24 -0
- warp/py.typed +0 -0
- warp/render/__init__.py +22 -0
- warp/render/imgui_manager.py +29 -0
- warp/render/render_opengl.py +24 -0
- warp/render/render_usd.py +24 -0
- warp/render/utils.py +24 -0
- warp/sparse.py +51 -0
- warp/tape.py +24 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_aot.py +7 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_conditional_captures.py +1147 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +691 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +335 -0
- warp/tests/geometry/test_hash_grid.py +259 -0
- warp/tests/geometry/test_marching_cubes.py +294 -0
- warp/tests/geometry/test_mesh.py +318 -0
- warp/tests/geometry/test_mesh_query_aabb.py +392 -0
- warp/tests/geometry/test_mesh_query_point.py +935 -0
- warp/tests/geometry/test_mesh_query_ray.py +323 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +730 -0
- warp/tests/interop/test_jax.py +1673 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/test_adam.py +162 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +3756 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +303 -0
- warp/tests/test_atomic.py +336 -0
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +732 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +974 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +298 -0
- warp/tests/test_context.py +35 -0
- warp/tests/test_copy.py +319 -0
- warp/tests/test_ctypes.py +618 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +127 -0
- warp/tests/test_enum.py +136 -0
- warp/tests/test_examples.py +424 -0
- warp/tests/test_fabricarray.py +998 -0
- warp/tests/test_fast_math.py +72 -0
- warp/tests/test_fem.py +2204 -0
- warp/tests/test_fixedarray.py +229 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +501 -0
- warp/tests/test_future_annotations.py +100 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +103 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +223 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_map.py +526 -0
- warp/tests/test_mat.py +3515 -0
- warp/tests/test_mat_assign_copy.py +178 -0
- warp/tests/test_mat_constructors.py +573 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +212 -0
- warp/tests/test_module_aot.py +287 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +70 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +408 -0
- warp/tests/test_quat.py +2653 -0
- warp/tests/test_quat_assign_copy.py +145 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +303 -0
- warp/tests/test_rounding.py +157 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +133 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +845 -0
- warp/tests/test_spatial.py +2859 -0
- warp/tests/test_spatial_assign_copy.py +160 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +640 -0
- warp/tests/test_struct.py +901 -0
- warp/tests/test_tape.py +242 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +192 -0
- warp/tests/test_tuple.py +361 -0
- warp/tests/test_types.py +615 -0
- warp/tests/test_utils.py +594 -0
- warp/tests/test_vec.py +1408 -0
- warp/tests/test_vec_assign_copy.py +143 -0
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/test_version.py +75 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +1519 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +608 -0
- warp/tests/tile/test_tile_load.py +724 -0
- warp/tests/tile/test_tile_mathdx.py +156 -0
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_mlp.py +400 -0
- warp/tests/tile/test_tile_reduce.py +950 -0
- warp/tests/tile/test_tile_shared_memory.py +376 -0
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +430 -0
- warp/tests/unittest_utils.py +469 -0
- warp/tests/walkthrough_debug.py +95 -0
- warp/torch.py +24 -0
- warp/types.py +51 -0
- warp/utils.py +31 -0
- warp_lang-1.10.0.dist-info/METADATA +459 -0
- warp_lang-1.10.0.dist-info/RECORD +468 -0
- warp_lang-1.10.0.dist-info/WHEEL +5 -0
- warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp_lang-1.10.0.dist-info/top_level.txt +1 -0
warp/native/warp.cu
ADDED
|
@@ -0,0 +1,4604 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
16
|
+
*/
|
|
17
|
+
|
|
18
|
+
#include "warp.h"
|
|
19
|
+
#include "scan.h"
|
|
20
|
+
#include "cuda_util.h"
|
|
21
|
+
#include "error.h"
|
|
22
|
+
#include "sort.h"
|
|
23
|
+
|
|
24
|
+
#include <cstdlib>
|
|
25
|
+
#include <fstream>
|
|
26
|
+
#include <nvrtc.h>
|
|
27
|
+
#include <nvPTXCompiler.h>
|
|
28
|
+
#if WP_ENABLE_MATHDX
|
|
29
|
+
#include <nvJitLink.h>
|
|
30
|
+
#include <libmathdx.h>
|
|
31
|
+
#include <libcublasdx.h>
|
|
32
|
+
#include <libcufftdx.h>
|
|
33
|
+
#include <libcusolverdx.h>
|
|
34
|
+
#endif
|
|
35
|
+
|
|
36
|
+
#include <array>
|
|
37
|
+
#include <algorithm>
|
|
38
|
+
#include <iterator>
|
|
39
|
+
#include <list>
|
|
40
|
+
#include <map>
|
|
41
|
+
#include <mutex>
|
|
42
|
+
#include <string>
|
|
43
|
+
#include <unordered_map>
|
|
44
|
+
#include <unordered_set>
|
|
45
|
+
#include <vector>
|
|
46
|
+
|
|
47
|
+
#define check_any(result) (check_generic(result, __FILE__, __LINE__))
|
|
48
|
+
#define check_nvrtc(code) (check_nvrtc_result(code, __FILE__, __LINE__))
|
|
49
|
+
#define check_nvptx(code) (check_nvptx_result(code, __FILE__, __LINE__))
|
|
50
|
+
#define check_nvjitlink(handle, code) (check_nvjitlink_result(handle, code, __FILE__, __LINE__))
|
|
51
|
+
#define check_cufftdx(code) (check_cufftdx_result(code, __FILE__, __LINE__))
|
|
52
|
+
#define check_cublasdx(code) (check_cublasdx_result(code, __FILE__, __LINE__))
|
|
53
|
+
#define check_cusolver(code) (check_cusolver_result(code, __FILE__, __LINE__))
|
|
54
|
+
#define CHECK_ANY(code) \
|
|
55
|
+
{ \
|
|
56
|
+
do { \
|
|
57
|
+
bool out = (check_any(code)); \
|
|
58
|
+
if(!out) { \
|
|
59
|
+
return out; \
|
|
60
|
+
} \
|
|
61
|
+
} while(0); \
|
|
62
|
+
}
|
|
63
|
+
#define CHECK_CUFFTDX(code) \
|
|
64
|
+
{ \
|
|
65
|
+
do { \
|
|
66
|
+
bool out = (check_cufftdx(code)); \
|
|
67
|
+
if(!out) { \
|
|
68
|
+
return out; \
|
|
69
|
+
} \
|
|
70
|
+
} while(0); \
|
|
71
|
+
}
|
|
72
|
+
#define CHECK_CUBLASDX(code) \
|
|
73
|
+
{ \
|
|
74
|
+
do { \
|
|
75
|
+
bool out = (check_cufftdx(code)); \
|
|
76
|
+
if(!out) { \
|
|
77
|
+
return out; \
|
|
78
|
+
} \
|
|
79
|
+
} while(0); \
|
|
80
|
+
}
|
|
81
|
+
#define CHECK_CUSOLVER(code) \
|
|
82
|
+
{ \
|
|
83
|
+
do { \
|
|
84
|
+
bool out = (check_cusolver(code)); \
|
|
85
|
+
if(!out) { \
|
|
86
|
+
return out; \
|
|
87
|
+
} \
|
|
88
|
+
} while(0); \
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
bool check_nvrtc_result(nvrtcResult result, const char* file, int line)
|
|
92
|
+
{
|
|
93
|
+
if (result == NVRTC_SUCCESS)
|
|
94
|
+
return true;
|
|
95
|
+
|
|
96
|
+
const char* error_string = nvrtcGetErrorString(result);
|
|
97
|
+
fprintf(stderr, "Warp NVRTC compilation error %u: %s (%s:%d)\n", unsigned(result), error_string, file, line);
|
|
98
|
+
return false;
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
bool check_nvptx_result(nvPTXCompileResult result, const char* file, int line)
|
|
102
|
+
{
|
|
103
|
+
if (result == NVPTXCOMPILE_SUCCESS)
|
|
104
|
+
return true;
|
|
105
|
+
|
|
106
|
+
const char* error_string;
|
|
107
|
+
switch (result)
|
|
108
|
+
{
|
|
109
|
+
case NVPTXCOMPILE_ERROR_INVALID_COMPILER_HANDLE:
|
|
110
|
+
error_string = "Invalid compiler handle";
|
|
111
|
+
break;
|
|
112
|
+
case NVPTXCOMPILE_ERROR_INVALID_INPUT:
|
|
113
|
+
error_string = "Invalid input";
|
|
114
|
+
break;
|
|
115
|
+
case NVPTXCOMPILE_ERROR_COMPILATION_FAILURE:
|
|
116
|
+
error_string = "Compilation failure";
|
|
117
|
+
break;
|
|
118
|
+
case NVPTXCOMPILE_ERROR_INTERNAL:
|
|
119
|
+
error_string = "Internal error";
|
|
120
|
+
break;
|
|
121
|
+
case NVPTXCOMPILE_ERROR_OUT_OF_MEMORY:
|
|
122
|
+
error_string = "Out of memory";
|
|
123
|
+
break;
|
|
124
|
+
case NVPTXCOMPILE_ERROR_COMPILER_INVOCATION_INCOMPLETE:
|
|
125
|
+
error_string = "Incomplete compiler invocation";
|
|
126
|
+
break;
|
|
127
|
+
case NVPTXCOMPILE_ERROR_UNSUPPORTED_PTX_VERSION:
|
|
128
|
+
error_string = "Unsupported PTX version";
|
|
129
|
+
break;
|
|
130
|
+
default:
|
|
131
|
+
error_string = "Unknown error";
|
|
132
|
+
break;
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
fprintf(stderr, "Warp PTX compilation error %u: %s (%s:%d)\n", unsigned(result), error_string, file, line);
|
|
136
|
+
return false;
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
bool check_generic(int result, const char* file, int line)
|
|
140
|
+
{
|
|
141
|
+
if (!result) {
|
|
142
|
+
fprintf(stderr, "Error %d on %s:%d\n", (int)result, file, line);
|
|
143
|
+
return false;
|
|
144
|
+
} else {
|
|
145
|
+
return true;
|
|
146
|
+
}
|
|
147
|
+
}
|
|
148
|
+
|
|
149
|
+
struct DeviceInfo
|
|
150
|
+
{
|
|
151
|
+
static constexpr int kNameLen = 128;
|
|
152
|
+
|
|
153
|
+
CUdevice device = -1;
|
|
154
|
+
CUuuid uuid = {0};
|
|
155
|
+
int ordinal = -1;
|
|
156
|
+
int pci_domain_id = -1;
|
|
157
|
+
int pci_bus_id = -1;
|
|
158
|
+
int pci_device_id = -1;
|
|
159
|
+
char name[kNameLen] = "";
|
|
160
|
+
int arch = 0;
|
|
161
|
+
int is_uva = 0;
|
|
162
|
+
int is_mempool_supported = 0;
|
|
163
|
+
int sm_count = 0;
|
|
164
|
+
int is_ipc_supported = -1;
|
|
165
|
+
int max_smem_bytes = 0;
|
|
166
|
+
CUcontext primary_context = NULL;
|
|
167
|
+
};
|
|
168
|
+
|
|
169
|
+
struct ContextInfo
|
|
170
|
+
{
|
|
171
|
+
DeviceInfo* device_info = NULL;
|
|
172
|
+
|
|
173
|
+
// the current stream, managed from Python (see wp_cuda_context_set_stream() and wp_cuda_context_get_stream())
|
|
174
|
+
CUstream stream = NULL;
|
|
175
|
+
|
|
176
|
+
// conditional graph node support, loaded on demand if the driver supports it (CUDA 12.4+)
|
|
177
|
+
CUmodule conditional_module = NULL;
|
|
178
|
+
};
|
|
179
|
+
|
|
180
|
+
// Information used for freeing allocations.
|
|
181
|
+
struct FreeInfo
|
|
182
|
+
{
|
|
183
|
+
void* context = NULL;
|
|
184
|
+
void* ptr = NULL;
|
|
185
|
+
bool is_async = false;
|
|
186
|
+
};
|
|
187
|
+
|
|
188
|
+
struct CaptureInfo
|
|
189
|
+
{
|
|
190
|
+
CUstream stream = NULL; // the main stream where capture begins and ends
|
|
191
|
+
uint64_t id = 0; // unique capture id from CUDA
|
|
192
|
+
bool external = false; // whether this is an external capture
|
|
193
|
+
std::vector<FreeInfo> tmp_allocs; // temporary allocations owned by the graph (e.g., staged array fill values)
|
|
194
|
+
};
|
|
195
|
+
|
|
196
|
+
struct StreamInfo
|
|
197
|
+
{
|
|
198
|
+
CUevent cached_event = NULL; // event used for stream synchronization (cached to avoid creating temporary events)
|
|
199
|
+
CaptureInfo* capture = NULL; // capture info (only if started on this stream)
|
|
200
|
+
};
|
|
201
|
+
|
|
202
|
+
// Extra resources tied to a graph, freed after the graph is released by CUDA.
|
|
203
|
+
// Used with the on_graph_destroy() callback.
|
|
204
|
+
struct GraphDestroyCallbackInfo
|
|
205
|
+
{
|
|
206
|
+
void* context = NULL; // graph CUDA context
|
|
207
|
+
std::vector<void*> unfreed_allocs; // graph allocations not freed by the graph
|
|
208
|
+
std::vector<FreeInfo> tmp_allocs; // temporary allocations owned by the graph (e.g., staged array fill values)
|
|
209
|
+
};
|
|
210
|
+
|
|
211
|
+
// Information for graph allocations that are not freed by the graph.
|
|
212
|
+
// These allocations have a shared ownership:
|
|
213
|
+
// - The graph instance allocates/maps the memory on each launch, even if the user reference is released.
|
|
214
|
+
// - The user reference must remain valid even if the graph is destroyed.
|
|
215
|
+
// The memory will be freed once the user reference is released and the graph is destroyed.
|
|
216
|
+
struct GraphAllocInfo
|
|
217
|
+
{
|
|
218
|
+
uint64_t capture_id = 0;
|
|
219
|
+
void* context = NULL;
|
|
220
|
+
bool ref_exists = false; // whether user reference still exists
|
|
221
|
+
bool graph_destroyed = false; // whether graph instance was destroyed
|
|
222
|
+
};
|
|
223
|
+
|
|
224
|
+
// Information used when deferring module unloading.
|
|
225
|
+
struct ModuleInfo
|
|
226
|
+
{
|
|
227
|
+
void* context = NULL;
|
|
228
|
+
void* module = NULL;
|
|
229
|
+
};
|
|
230
|
+
|
|
231
|
+
// Information used when deferring graph destruction.
|
|
232
|
+
struct GraphDestroyInfo
|
|
233
|
+
{
|
|
234
|
+
void* context = NULL;
|
|
235
|
+
void* graph = NULL;
|
|
236
|
+
void* graph_exec = NULL;
|
|
237
|
+
};
|
|
238
|
+
|
|
239
|
+
static std::unordered_map<CUfunction, std::string> g_kernel_names;
|
|
240
|
+
|
|
241
|
+
// cached info for all devices, indexed by ordinal
|
|
242
|
+
static std::vector<DeviceInfo> g_devices;
|
|
243
|
+
|
|
244
|
+
// maps CUdevice to DeviceInfo
|
|
245
|
+
static std::map<CUdevice, DeviceInfo*> g_device_map;
|
|
246
|
+
|
|
247
|
+
// cached info for all known contexts
|
|
248
|
+
static std::map<CUcontext, ContextInfo> g_contexts;
|
|
249
|
+
|
|
250
|
+
// cached info for all known streams (including registered external streams)
|
|
251
|
+
static std::unordered_map<CUstream, StreamInfo> g_streams;
|
|
252
|
+
|
|
253
|
+
// Ongoing graph captures registered using wp.capture_begin().
|
|
254
|
+
// This maps the capture id to the stream where capture was started.
|
|
255
|
+
// See wp_cuda_graph_begin_capture(), wp_cuda_graph_end_capture(), and wp_free_device_async().
|
|
256
|
+
static std::unordered_map<uint64_t, CaptureInfo*> g_captures;
|
|
257
|
+
|
|
258
|
+
// Memory allocated during graph capture requires special handling.
|
|
259
|
+
// See wp_alloc_device_async() and wp_free_device_async().
|
|
260
|
+
static std::unordered_map<void*, GraphAllocInfo> g_graph_allocs;
|
|
261
|
+
|
|
262
|
+
// Memory that cannot be freed immediately gets queued here.
|
|
263
|
+
// Call free_deferred_allocs() to release.
|
|
264
|
+
static std::vector<FreeInfo> g_deferred_free_list;
|
|
265
|
+
|
|
266
|
+
// Modules that cannot be unloaded immediately get queued here.
|
|
267
|
+
// Call unload_deferred_modules() to release.
|
|
268
|
+
static std::vector<ModuleInfo> g_deferred_module_list;
|
|
269
|
+
|
|
270
|
+
// Graphs that cannot be destroyed immediately get queued here.
|
|
271
|
+
// Call destroy_deferred_graphs() to release.
|
|
272
|
+
static std::vector<GraphDestroyInfo> g_deferred_graph_list;
|
|
273
|
+
|
|
274
|
+
// Data from on_graph_destroy() callbacks that run on a different thread.
|
|
275
|
+
static std::vector<GraphDestroyCallbackInfo*> g_deferred_graph_destroy_list;
|
|
276
|
+
static std::mutex g_graph_destroy_mutex;
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
void wp_cuda_set_context_restore_policy(bool always_restore)
|
|
280
|
+
{
|
|
281
|
+
ContextGuard::always_restore = always_restore;
|
|
282
|
+
}
|
|
283
|
+
|
|
284
|
+
int wp_cuda_get_context_restore_policy()
|
|
285
|
+
{
|
|
286
|
+
return int(ContextGuard::always_restore);
|
|
287
|
+
}
|
|
288
|
+
|
|
289
|
+
int cuda_init()
|
|
290
|
+
{
|
|
291
|
+
if (!init_cuda_driver())
|
|
292
|
+
return -1;
|
|
293
|
+
|
|
294
|
+
int device_count = 0;
|
|
295
|
+
if (check_cu(cuDeviceGetCount_f(&device_count)))
|
|
296
|
+
{
|
|
297
|
+
g_devices.resize(device_count);
|
|
298
|
+
|
|
299
|
+
for (int i = 0; i < device_count; i++)
|
|
300
|
+
{
|
|
301
|
+
CUdevice device;
|
|
302
|
+
if (check_cu(cuDeviceGet_f(&device, i)))
|
|
303
|
+
{
|
|
304
|
+
// query device info
|
|
305
|
+
g_devices[i].device = device;
|
|
306
|
+
g_devices[i].ordinal = i;
|
|
307
|
+
check_cu(cuDeviceGetName_f(g_devices[i].name, DeviceInfo::kNameLen, device));
|
|
308
|
+
check_cu(cuDeviceGetUuid_f(&g_devices[i].uuid, device));
|
|
309
|
+
check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_domain_id, CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID, device));
|
|
310
|
+
check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_bus_id, CU_DEVICE_ATTRIBUTE_PCI_BUS_ID, device));
|
|
311
|
+
check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_device_id, CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, device));
|
|
312
|
+
check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_uva, CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, device));
|
|
313
|
+
check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_mempool_supported, CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, device));
|
|
314
|
+
check_cu(cuDeviceGetAttribute_f(&g_devices[i].sm_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
|
|
315
|
+
#ifdef CUDA_VERSION
|
|
316
|
+
#if CUDA_VERSION >= 12000
|
|
317
|
+
int device_attribute_integrated = 0;
|
|
318
|
+
check_cu(cuDeviceGetAttribute_f(&device_attribute_integrated, CU_DEVICE_ATTRIBUTE_INTEGRATED, device));
|
|
319
|
+
if (device_attribute_integrated == 0)
|
|
320
|
+
{
|
|
321
|
+
check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_ipc_supported, CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED, device));
|
|
322
|
+
}
|
|
323
|
+
else
|
|
324
|
+
{
|
|
325
|
+
// integrated devices do not support CUDA IPC
|
|
326
|
+
g_devices[i].is_ipc_supported = 0;
|
|
327
|
+
}
|
|
328
|
+
#endif
|
|
329
|
+
#endif
|
|
330
|
+
check_cu(cuDeviceGetAttribute_f(&g_devices[i].max_smem_bytes, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device));
|
|
331
|
+
int major = 0;
|
|
332
|
+
int minor = 0;
|
|
333
|
+
check_cu(cuDeviceGetAttribute_f(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device));
|
|
334
|
+
check_cu(cuDeviceGetAttribute_f(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device));
|
|
335
|
+
g_devices[i].arch = 10 * major + minor;
|
|
336
|
+
#ifdef CUDA_VERSION
|
|
337
|
+
#if CUDA_VERSION < 13000
|
|
338
|
+
if (g_devices[i].arch == 110) {
|
|
339
|
+
g_devices[i].arch = 101; // Thor SM change
|
|
340
|
+
}
|
|
341
|
+
#endif
|
|
342
|
+
#endif
|
|
343
|
+
g_device_map[device] = &g_devices[i];
|
|
344
|
+
}
|
|
345
|
+
else
|
|
346
|
+
{
|
|
347
|
+
return -1;
|
|
348
|
+
}
|
|
349
|
+
}
|
|
350
|
+
}
|
|
351
|
+
else
|
|
352
|
+
{
|
|
353
|
+
return -1;
|
|
354
|
+
}
|
|
355
|
+
|
|
356
|
+
// initialize default timing state
|
|
357
|
+
static CudaTimingState default_timing_state(0, NULL);
|
|
358
|
+
g_cuda_timing_state = &default_timing_state;
|
|
359
|
+
|
|
360
|
+
return 0;
|
|
361
|
+
}
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
CUcontext get_current_context()
|
|
365
|
+
{
|
|
366
|
+
CUcontext ctx;
|
|
367
|
+
if (check_cu(cuCtxGetCurrent_f(&ctx)))
|
|
368
|
+
return ctx;
|
|
369
|
+
else
|
|
370
|
+
return NULL;
|
|
371
|
+
}
|
|
372
|
+
|
|
373
|
+
static inline CUstream get_current_stream(void* context=NULL)
|
|
374
|
+
{
|
|
375
|
+
return static_cast<CUstream>(wp_cuda_context_get_stream(context));
|
|
376
|
+
}
|
|
377
|
+
|
|
378
|
+
static ContextInfo* get_context_info(CUcontext ctx)
|
|
379
|
+
{
|
|
380
|
+
if (!ctx)
|
|
381
|
+
{
|
|
382
|
+
ctx = get_current_context();
|
|
383
|
+
if (!ctx)
|
|
384
|
+
return NULL;
|
|
385
|
+
}
|
|
386
|
+
|
|
387
|
+
auto it = g_contexts.find(ctx);
|
|
388
|
+
if (it != g_contexts.end())
|
|
389
|
+
{
|
|
390
|
+
return &it->second;
|
|
391
|
+
}
|
|
392
|
+
else
|
|
393
|
+
{
|
|
394
|
+
// previously unseen context, add the info
|
|
395
|
+
ContextGuard guard(ctx, true);
|
|
396
|
+
|
|
397
|
+
CUdevice device;
|
|
398
|
+
if (check_cu(cuCtxGetDevice_f(&device)))
|
|
399
|
+
{
|
|
400
|
+
DeviceInfo* device_info = g_device_map[device];
|
|
401
|
+
|
|
402
|
+
// workaround for https://nvbugspro.nvidia.com/bug/4456003
|
|
403
|
+
if (device_info->is_mempool_supported)
|
|
404
|
+
{
|
|
405
|
+
void* dummy = NULL;
|
|
406
|
+
check_cuda(cudaMallocAsync(&dummy, 1, NULL));
|
|
407
|
+
check_cuda(cudaFreeAsync(dummy, NULL));
|
|
408
|
+
}
|
|
409
|
+
|
|
410
|
+
ContextInfo context_info;
|
|
411
|
+
context_info.device_info = device_info;
|
|
412
|
+
auto result = g_contexts.insert(std::make_pair(ctx, context_info));
|
|
413
|
+
return &result.first->second;
|
|
414
|
+
}
|
|
415
|
+
}
|
|
416
|
+
|
|
417
|
+
return NULL;
|
|
418
|
+
}
|
|
419
|
+
|
|
420
|
+
static inline ContextInfo* get_context_info(void* context)
|
|
421
|
+
{
|
|
422
|
+
return get_context_info(static_cast<CUcontext>(context));
|
|
423
|
+
}
|
|
424
|
+
|
|
425
|
+
static inline StreamInfo* get_stream_info(CUstream stream)
|
|
426
|
+
{
|
|
427
|
+
auto it = g_streams.find(stream);
|
|
428
|
+
if (it != g_streams.end())
|
|
429
|
+
return &it->second;
|
|
430
|
+
else
|
|
431
|
+
return NULL;
|
|
432
|
+
}
|
|
433
|
+
|
|
434
|
+
static inline CaptureInfo* get_capture_info(CUstream stream)
|
|
435
|
+
{
|
|
436
|
+
if (!g_captures.empty() && wp_cuda_stream_is_capturing(stream))
|
|
437
|
+
{
|
|
438
|
+
uint64_t capture_id = get_capture_id(stream);
|
|
439
|
+
auto capture_iter = g_captures.find(capture_id);
|
|
440
|
+
if (capture_iter != g_captures.end())
|
|
441
|
+
return capture_iter->second;
|
|
442
|
+
}
|
|
443
|
+
return NULL;
|
|
444
|
+
}
|
|
445
|
+
|
|
446
|
+
// helper function to copy a value to device memory in a graph-friendly way
|
|
447
|
+
static bool capturable_tmp_alloc(void* context, const void* data, size_t size, void** devptr_ret, bool* free_devptr_ret)
|
|
448
|
+
{
|
|
449
|
+
ContextGuard guard(context);
|
|
450
|
+
|
|
451
|
+
CUstream stream = get_current_stream();
|
|
452
|
+
CaptureInfo* capture_info = get_capture_info(stream);
|
|
453
|
+
int device_ordinal = wp_cuda_context_get_device_ordinal(context);
|
|
454
|
+
void* devptr = NULL;
|
|
455
|
+
bool free_devptr = true;
|
|
456
|
+
|
|
457
|
+
if (capture_info)
|
|
458
|
+
{
|
|
459
|
+
// ongoing graph capture - need to stage the fill value so that it persists with the graph
|
|
460
|
+
if (CUDA_VERSION >= 12040 && wp_cuda_driver_version() >= 12040)
|
|
461
|
+
{
|
|
462
|
+
// pause the capture so that the alloc/memcpy won't be captured
|
|
463
|
+
void* graph = NULL;
|
|
464
|
+
if (!wp_cuda_graph_pause_capture(WP_CURRENT_CONTEXT, stream, &graph))
|
|
465
|
+
return false;
|
|
466
|
+
|
|
467
|
+
// copy value to device memory
|
|
468
|
+
devptr = wp_alloc_device(WP_CURRENT_CONTEXT, size);
|
|
469
|
+
if (!devptr)
|
|
470
|
+
{
|
|
471
|
+
fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cuda:%d' (in function %s)\n", (unsigned long long)size, device_ordinal, __FUNCTION__);
|
|
472
|
+
return false;
|
|
473
|
+
}
|
|
474
|
+
if (!check_cuda(cudaMemcpyAsync(devptr, data, size, cudaMemcpyHostToDevice, stream)))
|
|
475
|
+
return false;
|
|
476
|
+
|
|
477
|
+
// graph takes ownership of the value storage
|
|
478
|
+
FreeInfo free_info;
|
|
479
|
+
free_info.context = context ? context : get_current_context();
|
|
480
|
+
free_info.ptr = devptr;
|
|
481
|
+
free_info.is_async = wp_cuda_device_is_mempool_supported(device_ordinal);
|
|
482
|
+
|
|
483
|
+
// allocation will be freed when graph is destroyed
|
|
484
|
+
capture_info->tmp_allocs.push_back(free_info);
|
|
485
|
+
|
|
486
|
+
// resume the capture
|
|
487
|
+
if (!wp_cuda_graph_resume_capture(WP_CURRENT_CONTEXT, stream, graph))
|
|
488
|
+
return false;
|
|
489
|
+
|
|
490
|
+
free_devptr = false; // memory is owned by the graph, doesn't need to be freed
|
|
491
|
+
}
|
|
492
|
+
else
|
|
493
|
+
{
|
|
494
|
+
// older CUDA can't pause/resume the capture, so stage in CPU memory
|
|
495
|
+
void* hostptr = wp_alloc_host(size);
|
|
496
|
+
if (!hostptr)
|
|
497
|
+
{
|
|
498
|
+
fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cpu' (in function %s)\n", (unsigned long long)size, __FUNCTION__);
|
|
499
|
+
return false;
|
|
500
|
+
}
|
|
501
|
+
memcpy(hostptr, data, size);
|
|
502
|
+
|
|
503
|
+
// the device allocation and h2d copy will be captured in the graph
|
|
504
|
+
devptr = wp_alloc_device(WP_CURRENT_CONTEXT, size);
|
|
505
|
+
if (!devptr)
|
|
506
|
+
{
|
|
507
|
+
fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cuda:%d' (in function %s)\n", (unsigned long long)size, device_ordinal, __FUNCTION__);
|
|
508
|
+
return false;
|
|
509
|
+
}
|
|
510
|
+
if (!check_cuda(cudaMemcpyAsync(devptr, hostptr, size, cudaMemcpyHostToDevice, stream)))
|
|
511
|
+
return false;
|
|
512
|
+
|
|
513
|
+
// graph takes ownership of the value storage
|
|
514
|
+
FreeInfo free_info;
|
|
515
|
+
free_info.context = NULL;
|
|
516
|
+
free_info.ptr = hostptr;
|
|
517
|
+
free_info.is_async = false;
|
|
518
|
+
|
|
519
|
+
// allocation will be freed when graph is destroyed
|
|
520
|
+
capture_info->tmp_allocs.push_back(free_info);
|
|
521
|
+
}
|
|
522
|
+
}
|
|
523
|
+
else
|
|
524
|
+
{
|
|
525
|
+
// not capturing, copy the value to device memory
|
|
526
|
+
devptr = wp_alloc_device(WP_CURRENT_CONTEXT, size);
|
|
527
|
+
if (!devptr)
|
|
528
|
+
{
|
|
529
|
+
fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cuda:%d' (in function %s)\n", (unsigned long long)size, device_ordinal, __FUNCTION__);
|
|
530
|
+
return false;
|
|
531
|
+
}
|
|
532
|
+
if (!check_cuda(cudaMemcpyAsync(devptr, data, size, cudaMemcpyHostToDevice, stream)))
|
|
533
|
+
return false;
|
|
534
|
+
}
|
|
535
|
+
|
|
536
|
+
*devptr_ret = devptr;
|
|
537
|
+
*free_devptr_ret = free_devptr;
|
|
538
|
+
|
|
539
|
+
return true;
|
|
540
|
+
}
|
|
541
|
+
|
|
542
|
+
static void deferred_free(void* ptr, void* context, bool is_async)
|
|
543
|
+
{
|
|
544
|
+
FreeInfo free_info;
|
|
545
|
+
free_info.ptr = ptr;
|
|
546
|
+
free_info.context = context ? context : get_current_context();
|
|
547
|
+
free_info.is_async = is_async;
|
|
548
|
+
g_deferred_free_list.push_back(free_info);
|
|
549
|
+
}
|
|
550
|
+
|
|
551
|
+
static int free_deferred_allocs(void* context = NULL)
|
|
552
|
+
{
|
|
553
|
+
if (g_deferred_free_list.empty() || !g_captures.empty())
|
|
554
|
+
return 0;
|
|
555
|
+
|
|
556
|
+
int num_freed_allocs = 0;
|
|
557
|
+
for (auto it = g_deferred_free_list.begin(); it != g_deferred_free_list.end(); /*noop*/)
|
|
558
|
+
{
|
|
559
|
+
const FreeInfo& free_info = *it;
|
|
560
|
+
|
|
561
|
+
// free the pointer if it matches the given context or if the context is unspecified
|
|
562
|
+
if (free_info.context == context || !context)
|
|
563
|
+
{
|
|
564
|
+
ContextGuard guard(free_info.context);
|
|
565
|
+
|
|
566
|
+
if (free_info.is_async)
|
|
567
|
+
{
|
|
568
|
+
// this could be a regular stream-ordered allocation or a graph allocation
|
|
569
|
+
cudaError_t res = cudaFreeAsync(free_info.ptr, NULL);
|
|
570
|
+
if (res != cudaSuccess)
|
|
571
|
+
{
|
|
572
|
+
if (res == cudaErrorInvalidValue)
|
|
573
|
+
{
|
|
574
|
+
// This can happen if we try to release the pointer but the graph was
|
|
575
|
+
// never launched, so the memory isn't mapped.
|
|
576
|
+
// This is fine, so clear the error.
|
|
577
|
+
cudaGetLastError();
|
|
578
|
+
}
|
|
579
|
+
else
|
|
580
|
+
{
|
|
581
|
+
// something else went wrong, report error
|
|
582
|
+
check_cuda(res);
|
|
583
|
+
}
|
|
584
|
+
}
|
|
585
|
+
}
|
|
586
|
+
else
|
|
587
|
+
{
|
|
588
|
+
check_cuda(cudaFree(free_info.ptr));
|
|
589
|
+
}
|
|
590
|
+
|
|
591
|
+
++num_freed_allocs;
|
|
592
|
+
|
|
593
|
+
it = g_deferred_free_list.erase(it);
|
|
594
|
+
}
|
|
595
|
+
else
|
|
596
|
+
{
|
|
597
|
+
++it;
|
|
598
|
+
}
|
|
599
|
+
}
|
|
600
|
+
|
|
601
|
+
return num_freed_allocs;
|
|
602
|
+
}
|
|
603
|
+
|
|
604
|
+
static int unload_deferred_modules(void* context = NULL)
|
|
605
|
+
{
|
|
606
|
+
if (g_deferred_module_list.empty() || !g_captures.empty())
|
|
607
|
+
return 0;
|
|
608
|
+
|
|
609
|
+
int num_unloaded_modules = 0;
|
|
610
|
+
for (auto it = g_deferred_module_list.begin(); it != g_deferred_module_list.end(); /*noop*/)
|
|
611
|
+
{
|
|
612
|
+
// free the module if it matches the given context or if the context is unspecified
|
|
613
|
+
const ModuleInfo& module_info = *it;
|
|
614
|
+
if (module_info.context == context || !context)
|
|
615
|
+
{
|
|
616
|
+
wp_cuda_unload_module(module_info.context, module_info.module);
|
|
617
|
+
++num_unloaded_modules;
|
|
618
|
+
it = g_deferred_module_list.erase(it);
|
|
619
|
+
}
|
|
620
|
+
else
|
|
621
|
+
{
|
|
622
|
+
++it;
|
|
623
|
+
}
|
|
624
|
+
}
|
|
625
|
+
|
|
626
|
+
return num_unloaded_modules;
|
|
627
|
+
}
|
|
628
|
+
|
|
629
|
+
static int destroy_deferred_graphs(void* context = NULL)
|
|
630
|
+
{
|
|
631
|
+
if (g_deferred_graph_list.empty() || !g_captures.empty())
|
|
632
|
+
return 0;
|
|
633
|
+
|
|
634
|
+
int num_destroyed_graphs = 0;
|
|
635
|
+
for (auto it = g_deferred_graph_list.begin(); it != g_deferred_graph_list.end(); /*noop*/)
|
|
636
|
+
{
|
|
637
|
+
// destroy the graph if it matches the given context or if the context is unspecified
|
|
638
|
+
const GraphDestroyInfo& graph_info = *it;
|
|
639
|
+
if (graph_info.context == context || !context)
|
|
640
|
+
{
|
|
641
|
+
if (graph_info.graph)
|
|
642
|
+
{
|
|
643
|
+
check_cuda(cudaGraphDestroy((cudaGraph_t)graph_info.graph));
|
|
644
|
+
}
|
|
645
|
+
if (graph_info.graph_exec)
|
|
646
|
+
{
|
|
647
|
+
check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_info.graph_exec));
|
|
648
|
+
}
|
|
649
|
+
++num_destroyed_graphs;
|
|
650
|
+
it = g_deferred_graph_list.erase(it);
|
|
651
|
+
}
|
|
652
|
+
else
|
|
653
|
+
{
|
|
654
|
+
++it;
|
|
655
|
+
}
|
|
656
|
+
}
|
|
657
|
+
|
|
658
|
+
return num_destroyed_graphs;
|
|
659
|
+
}
|
|
660
|
+
|
|
661
|
+
static int process_deferred_graph_destroy_callbacks(void* context = NULL)
|
|
662
|
+
{
|
|
663
|
+
int num_freed = 0;
|
|
664
|
+
|
|
665
|
+
std::lock_guard<std::mutex> lock(g_graph_destroy_mutex);
|
|
666
|
+
|
|
667
|
+
for (auto it = g_deferred_graph_destroy_list.begin(); it != g_deferred_graph_destroy_list.end(); /*noop*/)
|
|
668
|
+
{
|
|
669
|
+
GraphDestroyCallbackInfo* graph_info = *it;
|
|
670
|
+
if (graph_info->context == context || !context)
|
|
671
|
+
{
|
|
672
|
+
// handle unfreed graph allocations (may have outstanding user references)
|
|
673
|
+
for (void* ptr : graph_info->unfreed_allocs)
|
|
674
|
+
{
|
|
675
|
+
auto alloc_iter = g_graph_allocs.find(ptr);
|
|
676
|
+
if (alloc_iter != g_graph_allocs.end())
|
|
677
|
+
{
|
|
678
|
+
GraphAllocInfo& alloc_info = alloc_iter->second;
|
|
679
|
+
if (alloc_info.ref_exists)
|
|
680
|
+
{
|
|
681
|
+
// unreference from graph so the pointer will be deallocated when the user reference goes away
|
|
682
|
+
alloc_info.graph_destroyed = true;
|
|
683
|
+
}
|
|
684
|
+
else
|
|
685
|
+
{
|
|
686
|
+
// the pointer can be freed, no references remain
|
|
687
|
+
wp_free_device_async(alloc_info.context, ptr);
|
|
688
|
+
g_graph_allocs.erase(alloc_iter);
|
|
689
|
+
}
|
|
690
|
+
}
|
|
691
|
+
}
|
|
692
|
+
|
|
693
|
+
// handle temporary allocations owned by the graph (no user references)
|
|
694
|
+
for (const FreeInfo& tmp_info : graph_info->tmp_allocs)
|
|
695
|
+
{
|
|
696
|
+
if (tmp_info.context)
|
|
697
|
+
{
|
|
698
|
+
// GPU alloc
|
|
699
|
+
if (tmp_info.is_async)
|
|
700
|
+
{
|
|
701
|
+
wp_free_device_async(tmp_info.context, tmp_info.ptr);
|
|
702
|
+
}
|
|
703
|
+
else
|
|
704
|
+
{
|
|
705
|
+
wp_free_device_default(tmp_info.context, tmp_info.ptr);
|
|
706
|
+
}
|
|
707
|
+
}
|
|
708
|
+
else
|
|
709
|
+
{
|
|
710
|
+
// CPU alloc
|
|
711
|
+
wp_free_host(tmp_info.ptr);
|
|
712
|
+
}
|
|
713
|
+
}
|
|
714
|
+
|
|
715
|
+
++num_freed;
|
|
716
|
+
delete graph_info;
|
|
717
|
+
it = g_deferred_graph_destroy_list.erase(it);
|
|
718
|
+
}
|
|
719
|
+
else
|
|
720
|
+
{
|
|
721
|
+
++it;
|
|
722
|
+
}
|
|
723
|
+
}
|
|
724
|
+
|
|
725
|
+
return num_freed;
|
|
726
|
+
}
|
|
727
|
+
|
|
728
|
+
static int run_deferred_actions(void* context = NULL)
|
|
729
|
+
{
|
|
730
|
+
int num_actions = 0;
|
|
731
|
+
num_actions += free_deferred_allocs(context);
|
|
732
|
+
num_actions += unload_deferred_modules(context);
|
|
733
|
+
num_actions += destroy_deferred_graphs(context);
|
|
734
|
+
num_actions += process_deferred_graph_destroy_callbacks(context);
|
|
735
|
+
return num_actions;
|
|
736
|
+
}
|
|
737
|
+
|
|
738
|
+
// Callback used when a graph is destroyed.
|
|
739
|
+
// NOTE: this runs on an internal CUDA thread and requires synchronization.
|
|
740
|
+
static void CUDART_CB on_graph_destroy(void* user_data)
|
|
741
|
+
{
|
|
742
|
+
if (user_data)
|
|
743
|
+
{
|
|
744
|
+
std::lock_guard<std::mutex> lock(g_graph_destroy_mutex);
|
|
745
|
+
g_deferred_graph_destroy_list.push_back(static_cast<GraphDestroyCallbackInfo*>(user_data));
|
|
746
|
+
}
|
|
747
|
+
}
|
|
748
|
+
|
|
749
|
+
static inline const char* get_cuda_kernel_name(void* kernel)
|
|
750
|
+
{
|
|
751
|
+
CUfunction cuda_func = static_cast<CUfunction>(kernel);
|
|
752
|
+
auto name_iter = g_kernel_names.find((CUfunction)cuda_func);
|
|
753
|
+
if (name_iter != g_kernel_names.end())
|
|
754
|
+
return name_iter->second.c_str();
|
|
755
|
+
else
|
|
756
|
+
return "unknown_kernel";
|
|
757
|
+
}
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
void* wp_alloc_pinned(size_t s)
|
|
761
|
+
{
|
|
762
|
+
void* ptr = NULL;
|
|
763
|
+
check_cuda(cudaMallocHost(&ptr, s));
|
|
764
|
+
return ptr;
|
|
765
|
+
}
|
|
766
|
+
|
|
767
|
+
void wp_free_pinned(void* ptr)
|
|
768
|
+
{
|
|
769
|
+
cudaFreeHost(ptr);
|
|
770
|
+
}
|
|
771
|
+
|
|
772
|
+
void* wp_alloc_device(void* context, size_t s)
|
|
773
|
+
{
|
|
774
|
+
int ordinal = wp_cuda_context_get_device_ordinal(context);
|
|
775
|
+
|
|
776
|
+
// use stream-ordered allocator if available
|
|
777
|
+
if (wp_cuda_device_is_mempool_supported(ordinal))
|
|
778
|
+
return wp_alloc_device_async(context, s);
|
|
779
|
+
else
|
|
780
|
+
return wp_alloc_device_default(context, s);
|
|
781
|
+
}
|
|
782
|
+
|
|
783
|
+
void wp_free_device(void* context, void* ptr)
|
|
784
|
+
{
|
|
785
|
+
int ordinal = wp_cuda_context_get_device_ordinal(context);
|
|
786
|
+
|
|
787
|
+
// use stream-ordered allocator if available
|
|
788
|
+
if (wp_cuda_device_is_mempool_supported(ordinal))
|
|
789
|
+
wp_free_device_async(context, ptr);
|
|
790
|
+
else
|
|
791
|
+
wp_free_device_default(context, ptr);
|
|
792
|
+
}
|
|
793
|
+
|
|
794
|
+
void* wp_alloc_device_default(void* context, size_t s)
|
|
795
|
+
{
|
|
796
|
+
ContextGuard guard(context);
|
|
797
|
+
|
|
798
|
+
void* ptr = NULL;
|
|
799
|
+
check_cuda(cudaMalloc(&ptr, s));
|
|
800
|
+
|
|
801
|
+
return ptr;
|
|
802
|
+
}
|
|
803
|
+
|
|
804
|
+
void wp_free_device_default(void* context, void* ptr)
|
|
805
|
+
{
|
|
806
|
+
ContextGuard guard(context);
|
|
807
|
+
|
|
808
|
+
// check if a capture is in progress
|
|
809
|
+
if (g_captures.empty())
|
|
810
|
+
{
|
|
811
|
+
check_cuda(cudaFree(ptr));
|
|
812
|
+
}
|
|
813
|
+
else
|
|
814
|
+
{
|
|
815
|
+
// we must defer the operation until graph captures complete
|
|
816
|
+
deferred_free(ptr, context, false);
|
|
817
|
+
}
|
|
818
|
+
}
|
|
819
|
+
|
|
820
|
+
void* wp_alloc_device_async(void* context, size_t s)
|
|
821
|
+
{
|
|
822
|
+
// stream-ordered allocations don't rely on the current context,
|
|
823
|
+
// but we set the context here for consistent behaviour
|
|
824
|
+
ContextGuard guard(context);
|
|
825
|
+
|
|
826
|
+
ContextInfo* context_info = get_context_info(context);
|
|
827
|
+
if (!context_info)
|
|
828
|
+
return NULL;
|
|
829
|
+
|
|
830
|
+
CUstream stream = context_info->stream;
|
|
831
|
+
|
|
832
|
+
void* ptr = NULL;
|
|
833
|
+
check_cuda(cudaMallocAsync(&ptr, s, stream));
|
|
834
|
+
|
|
835
|
+
if (ptr)
|
|
836
|
+
{
|
|
837
|
+
// if the stream is capturing, the allocation requires special handling
|
|
838
|
+
if (wp_cuda_stream_is_capturing(stream))
|
|
839
|
+
{
|
|
840
|
+
// check if this is a known capture
|
|
841
|
+
uint64_t capture_id = get_capture_id(stream);
|
|
842
|
+
auto capture_iter = g_captures.find(capture_id);
|
|
843
|
+
if (capture_iter != g_captures.end())
|
|
844
|
+
{
|
|
845
|
+
// remember graph allocation details
|
|
846
|
+
GraphAllocInfo alloc_info;
|
|
847
|
+
alloc_info.capture_id = capture_id;
|
|
848
|
+
alloc_info.context = context ? context : get_current_context();
|
|
849
|
+
alloc_info.ref_exists = true; // user reference created and returned here
|
|
850
|
+
alloc_info.graph_destroyed = false; // graph not destroyed yet
|
|
851
|
+
g_graph_allocs[ptr] = alloc_info;
|
|
852
|
+
}
|
|
853
|
+
}
|
|
854
|
+
}
|
|
855
|
+
|
|
856
|
+
return ptr;
|
|
857
|
+
}
|
|
858
|
+
|
|
859
|
+
void wp_free_device_async(void* context, void* ptr)
|
|
860
|
+
{
|
|
861
|
+
// stream-ordered allocators generally don't rely on the current context,
|
|
862
|
+
// but we set the context here for consistent behaviour
|
|
863
|
+
ContextGuard guard(context);
|
|
864
|
+
|
|
865
|
+
// NB: Stream-ordered deallocations are tricky, because the memory could still be used on another stream
|
|
866
|
+
// or even multiple streams. To avoid use-after-free errors, we need to ensure that all preceding work
|
|
867
|
+
    // completes before releasing the memory. The strategy is different for regular stream-ordered allocations
    // and allocations made during graph capture. See below for details.

    // check if this allocation was made during graph capture
    auto alloc_iter = g_graph_allocs.find(ptr);
    if (alloc_iter == g_graph_allocs.end())
    {
        // Not a graph allocation.
        // Check if graph capture is ongoing.
        if (g_captures.empty())
        {
            // cudaFreeAsync on the null stream does not block or trigger synchronization, but it postpones
            // the deallocation until a synchronization point is reached, so preceding work on this pointer
            // should safely complete.
            check_cuda(cudaFreeAsync(ptr, NULL));
        }
        else
        {
            // We must defer the free operation until graph capture completes.
            deferred_free(ptr, context, true);
        }
    }
    else
    {
        // get the graph allocation details
        GraphAllocInfo& alloc_info = alloc_iter->second;

        uint64_t capture_id = alloc_info.capture_id;

        // check if the capture is still active
        auto capture_iter = g_captures.find(capture_id);
        if (capture_iter != g_captures.end())
        {
            // Add a mem free node. Use all current leaf nodes as dependencies to ensure that all prior
            // work completes before deallocating. This works with both Warp-initiated and external captures
            // and avoids the need to explicitly track all streams used during the capture.
            CaptureInfo* capture = capture_iter->second;
            cudaGraph_t graph = get_capture_graph(capture->stream);
            std::vector<cudaGraphNode_t> leaf_nodes;
            if (graph && get_graph_leaf_nodes(graph, leaf_nodes))
            {
                cudaGraphNode_t free_node;
                check_cuda(cudaGraphAddMemFreeNode(&free_node, graph, leaf_nodes.data(), leaf_nodes.size(), ptr));
            }

            // we're done with this allocation, it's owned by the graph
            g_graph_allocs.erase(alloc_iter);
        }
        else
        {
            // the capture has ended
            // if the owning graph was already destroyed, we can free the pointer now
            if (alloc_info.graph_destroyed)
            {
                if (g_captures.empty())
                {
                    // try to free the pointer now
                    cudaError_t res = cudaFreeAsync(ptr, NULL);
                    if (res == cudaErrorInvalidValue)
                    {
                        // This can happen if we try to release the pointer but the graph was
                        // never launched, so the memory isn't mapped.
                        // This is fine, so clear the error.
                        cudaGetLastError();
                    }
                    else
                    {
                        // check for other errors
                        check_cuda(res);
                    }
                }
                else
                {
                    // We must defer the operation until graph capture completes.
                    deferred_free(ptr, context, true);
                }

                // we're done with this allocation
                g_graph_allocs.erase(alloc_iter);
            }
            else
            {
                // graph still exists
                // unreference the pointer so it will be deallocated once the graph instance is destroyed
                alloc_info.ref_exists = false;
            }
        }
    }
}
bool wp_memcpy_h2d(void* context, void* dest, void* src, size_t n, void* stream)
{
    ContextGuard guard(context);

    CUstream cuda_stream;
    if (stream != WP_CURRENT_STREAM)
        cuda_stream = static_cast<CUstream>(stream);
    else
        cuda_stream = get_current_stream(context);

    begin_cuda_range(WP_TIMING_MEMCPY, cuda_stream, context, "memcpy HtoD");

    bool result = check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyHostToDevice, cuda_stream));

    end_cuda_range(WP_TIMING_MEMCPY, cuda_stream);

    return result;
}

bool wp_memcpy_d2h(void* context, void* dest, void* src, size_t n, void* stream)
{
    ContextGuard guard(context);

    CUstream cuda_stream;
    if (stream != WP_CURRENT_STREAM)
        cuda_stream = static_cast<CUstream>(stream);
    else
        cuda_stream = get_current_stream(context);

    begin_cuda_range(WP_TIMING_MEMCPY, cuda_stream, context, "memcpy DtoH");

    bool result = check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyDeviceToHost, cuda_stream));

    end_cuda_range(WP_TIMING_MEMCPY, cuda_stream);

    return result;
}

bool wp_memcpy_d2d(void* context, void* dest, void* src, size_t n, void* stream)
{
    ContextGuard guard(context);

    CUstream cuda_stream;
    if (stream != WP_CURRENT_STREAM)
        cuda_stream = static_cast<CUstream>(stream);
    else
        cuda_stream = get_current_stream(context);

    begin_cuda_range(WP_TIMING_MEMCPY, cuda_stream, context, "memcpy DtoD");

    bool result = check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyDeviceToDevice, cuda_stream));

    end_cuda_range(WP_TIMING_MEMCPY, cuda_stream);

    return result;
}
bool wp_memcpy_p2p(void* dst_context, void* dst, void* src_context, void* src, size_t n, void* stream)
{
    // ContextGuard guard(context);

    CUstream cuda_stream;
    if (stream != WP_CURRENT_STREAM)
        cuda_stream = static_cast<CUstream>(stream);
    else
        cuda_stream = get_current_stream(dst_context);

    // Notes:
    // - cuMemcpyPeerAsync() works fine with both regular and pooled allocations (cudaMalloc() and cudaMallocAsync(), respectively)
    //   when not capturing a graph.
    // - cuMemcpyPeerAsync() is not supported during graph capture, so we must use cudaMemcpyAsync() with kind=cudaMemcpyDefault.
    // - cudaMemcpyAsync() works fine with regular allocations, but doesn't work with pooled allocations
    //   unless mempool access has been enabled.
    // - There is no reliable way to check if mempool access is enabled during graph capture,
    //   because cudaMemPoolGetAccess() cannot be called during graph capture.
    // - CUDA will report error 1 (invalid argument) if cudaMemcpyAsync() is called but mempool access is not enabled.

    if (!wp_cuda_stream_is_capturing(stream))
    {
        begin_cuda_range(WP_TIMING_MEMCPY, cuda_stream, get_stream_context(stream), "memcpy PtoP");

        bool result = check_cu(cuMemcpyPeerAsync_f(
            (CUdeviceptr)dst, (CUcontext)dst_context,
            (CUdeviceptr)src, (CUcontext)src_context,
            n, cuda_stream));

        end_cuda_range(WP_TIMING_MEMCPY, cuda_stream);

        return result;
    }
    else
    {
        cudaError_t result = cudaSuccess;

        // cudaMemcpyAsync() is sensitive to the bound context to resolve pointer locations.
        // It fails with cudaErrorInvalidValue if it cannot resolve an argument.
        // We first try the copy in the destination context, then if it fails we retry in the source context.
        // The cudaErrorInvalidValue error doesn't cause graph capture to fail, so it's ok to retry.
        // Since this trial-and-error shenanigans only happens during capture, there
        // is no perf impact when the graph is launched.
        // For bonus points, this approach simplifies memory pool access requirements.
        // Access only needs to be enabled one way, either from the source device to the destination device
        // or vice versa. Sometimes, when it's really quiet, you can actually hear my genius.
        {
            // try doing the copy in the destination context
            ContextGuard guard(dst_context);
            result = cudaMemcpyAsync(dst, src, n, cudaMemcpyDefault, cuda_stream);

            if (result != cudaSuccess)
            {
                // clear error in destination context
                cudaGetLastError();

                // try doing the copy in the source context
                ContextGuard guard(src_context);
                result = cudaMemcpyAsync(dst, src, n, cudaMemcpyDefault, cuda_stream);

                // clear error in source context
                cudaGetLastError();
            }
        }

        // If the copy failed, try to detect if mempool allocations are involved to generate a helpful error message.
        if (!check_cuda(result))
        {
            if (result == cudaErrorInvalidValue && src != NULL && dst != NULL)
            {
                // check if either of the pointers was allocated from a mempool
                void* src_mempool = NULL;
                void* dst_mempool = NULL;
                cuPointerGetAttribute_f(&src_mempool, CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE, (CUdeviceptr)src);
                cuPointerGetAttribute_f(&dst_mempool, CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE, (CUdeviceptr)dst);
                cudaGetLastError();  // clear any errors
                // check if either of the pointers was allocated during graph capture
                auto src_alloc = g_graph_allocs.find(src);
                auto dst_alloc = g_graph_allocs.find(dst);
                if (src_mempool != NULL || src_alloc != g_graph_allocs.end() ||
                    dst_mempool != NULL || dst_alloc != g_graph_allocs.end())
                {
                    wp::append_error_string("*** CUDA mempool allocations were used in a peer-to-peer copy during graph capture.");
                    wp::append_error_string("*** This operation fails if mempool access is not enabled between the peer devices.");
                    wp::append_error_string("*** Either enable mempool access between the devices or use the default CUDA allocator");
                    wp::append_error_string("*** to pre-allocate the arrays before graph capture begins.");
                }
            }

            return false;
        }

        return true;
    }
}
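The error strings above tell the caller to enable mempool access between the peer devices. As a point of reference, this is a minimal sketch of how that can be done with the standard CUDA runtime API, assuming device ordinals `src_device` and `dst_device` (illustrative only; this helper is not part of the file above):

```cpp
#include <cuda_runtime.h>

// Allow dst_device to read/write allocations made from src_device's default memory pool.
// Assumes both ordinals are valid and that memory pools are supported on these devices.
static bool enable_mempool_peer_access(int src_device, int dst_device)
{
    cudaMemPool_t pool;
    if (cudaDeviceGetDefaultMemPool(&pool, src_device) != cudaSuccess)
        return false;

    cudaMemAccessDesc desc = {};
    desc.location.type = cudaMemLocationTypeDevice;
    desc.location.id = dst_device;
    desc.flags = cudaMemAccessFlagsProtReadWrite;

    // grant access for a single peer location
    return cudaMemPoolSetAccess(pool, &desc, 1) == cudaSuccess;
}
```

As the comment block notes, access only needs to be granted in one direction for the capture-time fallback path to succeed.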
__global__ void memset_kernel(int* dest, int value, size_t n)
{
    const size_t tid = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);

    if (tid < n)
    {
        dest[tid] = value;
    }
}

void wp_memset_device(void* context, void* dest, int value, size_t n)
{
    ContextGuard guard(context);

    if (true)// ((n%4) > 0)
    {
        cudaStream_t stream = get_current_stream();

        begin_cuda_range(WP_TIMING_MEMSET, stream, context, "memset");

        // for unaligned lengths fallback to CUDA memset
        check_cuda(cudaMemsetAsync(dest, value, n, stream));

        end_cuda_range(WP_TIMING_MEMSET, stream);
    }
    else
    {
        // custom kernel to support 4-byte values (and slightly lower host overhead)
        const size_t num_words = n/4;
        wp_launch_device(WP_CURRENT_CONTEXT, memset_kernel, num_words, ((int*)dest, value, num_words));
    }
}

// fill memory buffer with a value: generic memtile kernel using memcpy for each element
__global__ void memtile_kernel(void* dst, const void* src, size_t srcsize, size_t n)
{
    size_t tid = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
    if (tid < n)
    {
        memcpy((int8_t*)dst + srcsize * tid, src, srcsize);
    }
}

// this should be faster than memtile_kernel, but requires proper alignment of dst
template <typename T>
__global__ void memtile_value_kernel(T* dst, T value, size_t n)
{
    size_t tid = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
    if (tid < n)
    {
        dst[tid] = value;
    }
}

void wp_memtile_device(void* context, void* dst, const void* src, size_t srcsize, size_t n)
{
    ContextGuard guard(context);

    size_t dst_addr = reinterpret_cast<size_t>(dst);
    size_t src_addr = reinterpret_cast<size_t>(src);

    // try memtile_value first because it should be faster, but we need to ensure proper alignment
    if (srcsize == 8 && (dst_addr & 7) == 0 && (src_addr & 7) == 0)
    {
        int64_t* p = reinterpret_cast<int64_t*>(dst);
        int64_t value = *reinterpret_cast<const int64_t*>(src);
        wp_launch_device(WP_CURRENT_CONTEXT, memtile_value_kernel, n, (p, value, n));
    }
    else if (srcsize == 4 && (dst_addr & 3) == 0 && (src_addr & 3) == 0)
    {
        int32_t* p = reinterpret_cast<int32_t*>(dst);
        int32_t value = *reinterpret_cast<const int32_t*>(src);
        wp_launch_device(WP_CURRENT_CONTEXT, memtile_value_kernel, n, (p, value, n));
    }
    else if (srcsize == 2 && (dst_addr & 1) == 0 && (src_addr & 1) == 0)
    {
        int16_t* p = reinterpret_cast<int16_t*>(dst);
        int16_t value = *reinterpret_cast<const int16_t*>(src);
        wp_launch_device(WP_CURRENT_CONTEXT, memtile_value_kernel, n, (p, value, n));
    }
    else if (srcsize == 1)
    {
        check_cuda(cudaMemset(dst, *reinterpret_cast<const int8_t*>(src), n));
    }
    else
    {
        // generic version
        void* value_devptr = NULL; // fill value in device memory
        bool free_devptr = true;   // whether we need to free the memory

        // prepare the fill value in a graph-friendly way
        if (!capturable_tmp_alloc(WP_CURRENT_CONTEXT, src, srcsize, &value_devptr, &free_devptr))
        {
            fprintf(stderr, "Warp fill error: failed to copy value to device memory\n");
            return;
        }

        wp_launch_device(WP_CURRENT_CONTEXT, memtile_kernel, n, (dst, value_devptr, srcsize, n));

        if (free_devptr)
        {
            wp_free_device(WP_CURRENT_CONTEXT, value_devptr);
        }
    }
}
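These kernels are launched through the `wp_launch_device` helper defined elsewhere in the library, which is not shown in this excerpt. For readers following along, a minimal sketch of the usual one-dimensional launch-configuration arithmetic for an n-element kernel is shown below; the fixed block size of 256 threads is an assumption for illustration, and the real macro may choose a different configuration:

```cpp
// Illustrative only: cover n elements with 256-thread blocks and launch memset_kernel.
static void launch_memset_example(int* dest, int value, size_t n, cudaStream_t stream)
{
    const int block_dim = 256;
    const int grid_dim = static_cast<int>((n + block_dim - 1) / block_dim);
    if (grid_dim > 0)
        memset_kernel<<<grid_dim, block_dim, 0, stream>>>(dest, value, n);
}
```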
static __global__ void array_copy_1d_kernel(void* dst, const void* src,
                                            size_t dst_stride, size_t src_stride,
                                            const int* dst_indices, const int* src_indices,
                                            size_t n, size_t elem_size)
{
    size_t i = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
    if (i < n)
    {
        size_t src_idx = src_indices ? src_indices[i] : i;
        size_t dst_idx = dst_indices ? dst_indices[i] : i;
        const char* p = (const char*)src + src_idx * src_stride;
        char* q = (char*)dst + dst_idx * dst_stride;
        memcpy(q, p, elem_size);
    }
}

static __global__ void array_copy_2d_kernel(void* dst, const void* src,
                                            wp::vec_t<2, size_t> dst_strides, wp::vec_t<2, size_t> src_strides,
                                            wp::vec_t<2, const int*> dst_indices, wp::vec_t<2, const int*> src_indices,
                                            wp::vec_t<2, size_t> shape, size_t elem_size)
{
    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
    size_t n = shape[1];
    size_t i = tid / n;
    size_t j = tid % n;
    if (i < shape[0] /*&& j < shape[1]*/)
    {
        size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
        size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
        size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
        size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
        const char* p = (const char*)src + src_idx0 * src_strides[0] + src_idx1 * src_strides[1];
        char* q = (char*)dst + dst_idx0 * dst_strides[0] + dst_idx1 * dst_strides[1];
        memcpy(q, p, elem_size);
    }
}

static __global__ void array_copy_3d_kernel(void* dst, const void* src,
                                            wp::vec_t<3, size_t> dst_strides, wp::vec_t<3, size_t> src_strides,
                                            wp::vec_t<3, const int*> dst_indices, wp::vec_t<3, const int*> src_indices,
                                            wp::vec_t<3, size_t> shape, size_t elem_size)
{
    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
    size_t n = shape[1];
    size_t o = shape[2];
    size_t i = tid / (n * o);
    size_t j = tid % (n * o) / o;
    size_t k = tid % o;
    if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
    {
        size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
        size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
        size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
        size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
        size_t src_idx2 = src_indices[2] ? src_indices[2][k] : k;
        size_t dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
        const char* p = (const char*)src + src_idx0 * src_strides[0]
                                         + src_idx1 * src_strides[1]
                                         + src_idx2 * src_strides[2];
        char* q = (char*)dst + dst_idx0 * dst_strides[0]
                             + dst_idx1 * dst_strides[1]
                             + dst_idx2 * dst_strides[2];
        memcpy(q, p, elem_size);
    }
}

static __global__ void array_copy_4d_kernel(void* dst, const void* src,
                                            wp::vec_t<4, size_t> dst_strides, wp::vec_t<4, size_t> src_strides,
                                            wp::vec_t<4, const int*> dst_indices, wp::vec_t<4, const int*> src_indices,
                                            wp::vec_t<4, size_t> shape, size_t elem_size)
{
    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
    size_t n = shape[1];
    size_t o = shape[2];
    size_t p = shape[3];
    size_t i = tid / (n * o * p);
    size_t j = tid % (n * o * p) / (o * p);
    size_t k = tid % (o * p) / p;
    size_t l = tid % p;
    if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
    {
        size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
        size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
        size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
        size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
        size_t src_idx2 = src_indices[2] ? src_indices[2][k] : k;
        size_t dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
        size_t src_idx3 = src_indices[3] ? src_indices[3][l] : l;
        size_t dst_idx3 = dst_indices[3] ? dst_indices[3][l] : l;
        const char* p = (const char*)src + src_idx0 * src_strides[0]
                                         + src_idx1 * src_strides[1]
                                         + src_idx2 * src_strides[2]
                                         + src_idx3 * src_strides[3];
        char* q = (char*)dst + dst_idx0 * dst_strides[0]
                             + dst_idx1 * dst_strides[1]
                             + dst_idx2 * dst_strides[2]
                             + dst_idx3 * dst_strides[3];
        memcpy(q, p, elem_size);
    }
}
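The 2D/3D/4D copy kernels above turn the flat launch index into per-dimension indices with div/mod arithmetic, with the innermost dimension varying fastest. A small host-side sanity check of that mapping, written under the same conventions and purely for illustration (this helper is not part of the file), could look like this:

```cpp
#include <cassert>
#include <cstddef>

// Verify that the tid -> (i, j, k) mapping used by array_copy_3d_kernel
// round-trips for a small shape (row-major, last dimension fastest).
static void check_3d_delinearization()
{
    const size_t shape[3] = {3, 4, 5};
    size_t tid = 0;
    for (size_t i = 0; i < shape[0]; ++i)
        for (size_t j = 0; j < shape[1]; ++j)
            for (size_t k = 0; k < shape[2]; ++k, ++tid)
            {
                const size_t n = shape[1];
                const size_t o = shape[2];
                assert(tid / (n * o) == i);
                assert(tid % (n * o) / o == j);
                assert(tid % o == k);
            }
}
```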
static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src,
                                                     void* dst_data, size_t dst_stride, const int* dst_indices,
                                                     size_t elem_size)
{
    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);

    if (tid < src.size)
    {
        size_t dst_idx = dst_indices ? dst_indices[tid] : tid;
        void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
        const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
        memcpy(dst_ptr, src_ptr, elem_size);
    }
}

static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricarray_t<void> src,
                                                             void* dst_data, size_t dst_stride, const int* dst_indices,
                                                             size_t elem_size)
{
    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);

    if (tid < src.size)
    {
        size_t src_index = src.indices[tid];
        size_t dst_idx = dst_indices ? dst_indices[tid] : tid;
        void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
        const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
        memcpy(dst_ptr, src_ptr, elem_size);
    }
}

static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
                                                   const void* src_data, size_t src_stride, const int* src_indices,
                                                   size_t elem_size)
{
    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);

    if (tid < dst.size)
    {
        size_t src_idx = src_indices ? src_indices[tid] : tid;
        const void* src_ptr = (const char*)src_data + src_idx * src_stride;
        void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
        memcpy(dst_ptr, src_ptr, elem_size);
    }
}

static __global__ void array_copy_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst,
                                                           const void* src_data, size_t src_stride, const int* src_indices,
                                                           size_t elem_size)
{
    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);

    if (tid < dst.size)
    {
        size_t src_idx = src_indices ? src_indices[tid] : tid;
        const void* src_ptr = (const char*)src_data + src_idx * src_stride;
        size_t dst_idx = dst.indices[tid];
        void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
        memcpy(dst_ptr, src_ptr, elem_size);
    }
}


static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src, size_t elem_size)
{
    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);

    if (tid < dst.size)
    {
        const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
        void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
        memcpy(dst_ptr, src_ptr, elem_size);
    }
}


static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src, size_t elem_size)
{
    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);

    if (tid < dst.size)
    {
        const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
        size_t dst_index = dst.indices[tid];
        void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
        memcpy(dst_ptr, src_ptr, elem_size);
    }
}


static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, size_t elem_size)
{
    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);

    if (tid < dst.size)
    {
        size_t src_index = src.indices[tid];
        const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
        void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
        memcpy(dst_ptr, src_ptr, elem_size);
    }
}


static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, size_t elem_size)
{
    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);

    if (tid < dst.size)
    {
        size_t src_index = src.indices[tid];
        size_t dst_index = dst.indices[tid];
        const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
        void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
        memcpy(dst_ptr, src_ptr, elem_size);
    }
}
WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_type, int src_type, int elem_size)
{
    if (!src || !dst)
        return false;

    const void* src_data = NULL;
    void* dst_data = NULL;
    int src_ndim = 0;
    int dst_ndim = 0;
    const int* src_shape = NULL;
    const int* dst_shape = NULL;
    const int* src_strides = NULL;
    const int* dst_strides = NULL;
    const int*const* src_indices = NULL;
    const int*const* dst_indices = NULL;

    const wp::fabricarray_t<void>* src_fabricarray = NULL;
    wp::fabricarray_t<void>* dst_fabricarray = NULL;

    const wp::indexedfabricarray_t<void>* src_indexedfabricarray = NULL;
    wp::indexedfabricarray_t<void>* dst_indexedfabricarray = NULL;

    const int* null_indices[wp::ARRAY_MAX_DIMS] = { NULL };

    if (src_type == wp::ARRAY_TYPE_REGULAR)
    {
        const wp::array_t<void>& src_arr = *static_cast<const wp::array_t<void>*>(src);
        src_data = src_arr.data;
        src_ndim = src_arr.ndim;
        src_shape = src_arr.shape.dims;
        src_strides = src_arr.strides;
        src_indices = null_indices;
    }
    else if (src_type == wp::ARRAY_TYPE_INDEXED)
    {
        const wp::indexedarray_t<void>& src_arr = *static_cast<const wp::indexedarray_t<void>*>(src);
        src_data = src_arr.arr.data;
        src_ndim = src_arr.arr.ndim;
        src_shape = src_arr.shape.dims;
        src_strides = src_arr.arr.strides;
        src_indices = src_arr.indices;
    }
    else if (src_type == wp::ARRAY_TYPE_FABRIC)
    {
        src_fabricarray = static_cast<const wp::fabricarray_t<void>*>(src);
        src_ndim = 1;
    }
    else if (src_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
    {
        src_indexedfabricarray = static_cast<const wp::indexedfabricarray_t<void>*>(src);
        src_ndim = 1;
    }
    else
    {
        fprintf(stderr, "Warp copy error: Invalid array type (%d)\n", src_type);
        return false;
    }

    if (dst_type == wp::ARRAY_TYPE_REGULAR)
    {
        const wp::array_t<void>& dst_arr = *static_cast<const wp::array_t<void>*>(dst);
        dst_data = dst_arr.data;
        dst_ndim = dst_arr.ndim;
        dst_shape = dst_arr.shape.dims;
        dst_strides = dst_arr.strides;
        dst_indices = null_indices;
    }
    else if (dst_type == wp::ARRAY_TYPE_INDEXED)
    {
        const wp::indexedarray_t<void>& dst_arr = *static_cast<const wp::indexedarray_t<void>*>(dst);
        dst_data = dst_arr.arr.data;
        dst_ndim = dst_arr.arr.ndim;
        dst_shape = dst_arr.shape.dims;
        dst_strides = dst_arr.arr.strides;
        dst_indices = dst_arr.indices;
    }
    else if (dst_type == wp::ARRAY_TYPE_FABRIC)
    {
        dst_fabricarray = static_cast<wp::fabricarray_t<void>*>(dst);
        dst_ndim = 1;
    }
    else if (dst_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
    {
        dst_indexedfabricarray = static_cast<wp::indexedfabricarray_t<void>*>(dst);
        dst_ndim = 1;
    }
    else
    {
        fprintf(stderr, "Warp copy error: Invalid array type (%d)\n", dst_type);
        return false;
    }

    if (src_ndim != dst_ndim)
    {
        fprintf(stderr, "Warp copy error: Incompatible array dimensionalities (%d and %d)\n", src_ndim, dst_ndim);
        return false;
    }

    ContextGuard guard(context);

    // handle fabric arrays
    if (dst_fabricarray)
    {
        size_t n = dst_fabricarray->size;
        if (src_fabricarray)
        {
            // copy from fabric to fabric
            if (src_fabricarray->size != n)
            {
                fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
                return false;
            }
            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_to_fabric_kernel, n,
                             (*dst_fabricarray, *src_fabricarray, elem_size));
            return true;
        }
        else if (src_indexedfabricarray)
        {
            // copy from fabric indexed to fabric
            if (src_indexedfabricarray->size != n)
            {
                fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
                return false;
            }
            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_indexed_to_fabric_kernel, n,
                             (*dst_fabricarray, *src_indexedfabricarray, elem_size));
            return true;
        }
        else
        {
            // copy to fabric
            if (size_t(src_shape[0]) != n)
            {
                fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
                return false;
            }
            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_to_fabric_kernel, n,
                             (*dst_fabricarray, src_data, src_strides[0], src_indices[0], elem_size));
            return true;
        }
    }
    if (dst_indexedfabricarray)
    {
        size_t n = dst_indexedfabricarray->size;
        if (src_fabricarray)
        {
            // copy from fabric to fabric indexed
            if (src_fabricarray->size != n)
            {
                fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
                return false;
            }
            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_to_fabric_indexed_kernel, n,
                             (*dst_indexedfabricarray, *src_fabricarray, elem_size));
            return true;
        }
        else if (src_indexedfabricarray)
        {
            // copy from fabric indexed to fabric indexed
            if (src_indexedfabricarray->size != n)
            {
                fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
                return false;
            }
            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_indexed_to_fabric_indexed_kernel, n,
                             (*dst_indexedfabricarray, *src_indexedfabricarray, elem_size));
            return true;
        }
        else
        {
            // copy to fabric indexed
            if (size_t(src_shape[0]) != n)
            {
                fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
                return false;
            }
            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_to_fabric_indexed_kernel, n,
                             (*dst_indexedfabricarray, src_data, src_strides[0], src_indices[0], elem_size));
            return true;
        }
    }
    else if (src_fabricarray)
    {
        // copy from fabric
        size_t n = src_fabricarray->size;
        if (size_t(dst_shape[0]) != n)
        {
            fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
            return false;
        }
        wp_launch_device(WP_CURRENT_CONTEXT, array_copy_from_fabric_kernel, n,
                         (*src_fabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size));
        return true;
    }
    else if (src_indexedfabricarray)
    {
        // copy from fabric indexed
        size_t n = src_indexedfabricarray->size;
        if (size_t(dst_shape[0]) != n)
        {
            fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
            return false;
        }
        wp_launch_device(WP_CURRENT_CONTEXT, array_copy_from_fabric_indexed_kernel, n,
                         (*src_indexedfabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size));
        return true;
    }

    size_t n = 1;
    for (int i = 0; i < src_ndim; i++)
    {
        if (src_shape[i] != dst_shape[i])
        {
            fprintf(stderr, "Warp copy error: Incompatible array shapes\n");
            return false;
        }
        n *= src_shape[i];
    }

    switch (src_ndim)
    {
    case 1:
    {
        wp_launch_device(WP_CURRENT_CONTEXT, array_copy_1d_kernel, n, (dst_data, src_data,
                         dst_strides[0], src_strides[0],
                         dst_indices[0], src_indices[0],
                         src_shape[0], elem_size));
        break;
    }
    case 2:
    {
        wp::vec_t<2, size_t> shape_v(src_shape[0], src_shape[1]);
        wp::vec_t<2, size_t> src_strides_v(src_strides[0], src_strides[1]);
        wp::vec_t<2, size_t> dst_strides_v(dst_strides[0], dst_strides[1]);
        wp::vec_t<2, const int*> src_indices_v(src_indices[0], src_indices[1]);
        wp::vec_t<2, const int*> dst_indices_v(dst_indices[0], dst_indices[1]);

        wp_launch_device(WP_CURRENT_CONTEXT, array_copy_2d_kernel, n, (dst_data, src_data,
                         dst_strides_v, src_strides_v,
                         dst_indices_v, src_indices_v,
                         shape_v, elem_size));
        break;
    }
    case 3:
    {
        wp::vec_t<3, size_t> shape_v(src_shape[0], src_shape[1], src_shape[2]);
        wp::vec_t<3, size_t> src_strides_v(src_strides[0], src_strides[1], src_strides[2]);
        wp::vec_t<3, size_t> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2]);
        wp::vec_t<3, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2]);
        wp::vec_t<3, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2]);

        wp_launch_device(WP_CURRENT_CONTEXT, array_copy_3d_kernel, n, (dst_data, src_data,
                         dst_strides_v, src_strides_v,
                         dst_indices_v, src_indices_v,
                         shape_v, elem_size));
        break;
    }
    case 4:
    {
        wp::vec_t<4, size_t> shape_v(src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
        wp::vec_t<4, size_t> src_strides_v(src_strides[0], src_strides[1], src_strides[2], src_strides[3]);
        wp::vec_t<4, size_t> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2], dst_strides[3]);
        wp::vec_t<4, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2], src_indices[3]);
        wp::vec_t<4, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2], dst_indices[3]);

        wp_launch_device(WP_CURRENT_CONTEXT, array_copy_4d_kernel, n, (dst_data, src_data,
                         dst_strides_v, src_strides_v,
                         dst_indices_v, src_indices_v,
                         shape_v, elem_size));
        break;
    }
    default:
        fprintf(stderr, "Warp copy error: invalid array dimensionality (%d)\n", src_ndim);
        return false;
    }

    return check_cuda(cudaGetLastError());
}
static __global__ void array_fill_1d_kernel(void* data,
                                            size_t n,
                                            size_t stride,
                                            const int* indices,
                                            const void* value,
                                            size_t value_size)
{
    size_t i = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
    if (i < n)
    {
        size_t idx = indices ? indices[i] : i;
        char* p = (char*)data + idx * stride;
        memcpy(p, value, value_size);
    }
}

static __global__ void array_fill_2d_kernel(void* data,
                                            wp::vec_t<2, size_t> shape,
                                            wp::vec_t<2, size_t> strides,
                                            wp::vec_t<2, const int*> indices,
                                            const void* value,
                                            size_t value_size)
{
    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
    size_t n = shape[1];
    size_t i = tid / n;
    size_t j = tid % n;
    if (i < shape[0] /*&& j < shape[1]*/)
    {
        size_t idx0 = indices[0] ? indices[0][i] : i;
        size_t idx1 = indices[1] ? indices[1][j] : j;
        char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1];
        memcpy(p, value, value_size);
    }
}

static __global__ void array_fill_3d_kernel(void* data,
                                            wp::vec_t<3, size_t> shape,
                                            wp::vec_t<3, size_t> strides,
                                            wp::vec_t<3, const int*> indices,
                                            const void* value,
                                            size_t value_size)
{
    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
    size_t n = shape[1];
    size_t o = shape[2];
    size_t i = tid / (n * o);
    size_t j = tid % (n * o) / o;
    size_t k = tid % o;
    if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
    {
        size_t idx0 = indices[0] ? indices[0][i] : i;
        size_t idx1 = indices[1] ? indices[1][j] : j;
        size_t idx2 = indices[2] ? indices[2][k] : k;
        char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2];
        memcpy(p, value, value_size);
    }
}

static __global__ void array_fill_4d_kernel(void* data,
                                            wp::vec_t<4, size_t> shape,
                                            wp::vec_t<4, size_t> strides,
                                            wp::vec_t<4, const int*> indices,
                                            const void* value,
                                            size_t value_size)
{
    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
    size_t n = shape[1];
    size_t o = shape[2];
    size_t p = shape[3];
    size_t i = tid / (n * o * p);
    size_t j = tid % (n * o * p) / (o * p);
    size_t k = tid % (o * p) / p;
    size_t l = tid % p;
    if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
    {
        size_t idx0 = indices[0] ? indices[0][i] : i;
        size_t idx1 = indices[1] ? indices[1][j] : j;
        size_t idx2 = indices[2] ? indices[2][k] : k;
        size_t idx3 = indices[3] ? indices[3][l] : l;
        char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2] + idx3 * strides[3];
        memcpy(p, value, value_size);
    }
}


static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value, size_t value_size)
{
    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
    if (tid < fa.size)
    {
        void* dst_ptr = fabricarray_element_ptr(fa, tid, value_size);
        memcpy(dst_ptr, value, value_size);
    }
}


static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value, size_t value_size)
{
    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
    if (tid < ifa.size)
    {
        size_t idx = size_t(ifa.indices[tid]);
        if (idx < ifa.fa.size)
        {
            void* dst_ptr = fabricarray_element_ptr(ifa.fa, idx, value_size);
            memcpy(dst_ptr, value, value_size);
        }
    }
}
WP_API void wp_array_fill_device(void* context, void* arr_ptr, int arr_type, const void* value_ptr, int value_size)
{
    if (!arr_ptr || !value_ptr)
        return;

    void* data = NULL;
    int ndim = 0;
    const int* shape = NULL;
    const int* strides = NULL;
    const int*const* indices = NULL;

    wp::fabricarray_t<void>* fa = NULL;
    wp::indexedfabricarray_t<void>* ifa = NULL;

    const int* null_indices[wp::ARRAY_MAX_DIMS] = { NULL };

    if (arr_type == wp::ARRAY_TYPE_REGULAR)
    {
        wp::array_t<void>& arr = *static_cast<wp::array_t<void>*>(arr_ptr);
        data = arr.data;
        ndim = arr.ndim;
        shape = arr.shape.dims;
        strides = arr.strides;
        indices = null_indices;
    }
    else if (arr_type == wp::ARRAY_TYPE_INDEXED)
    {
        wp::indexedarray_t<void>& ia = *static_cast<wp::indexedarray_t<void>*>(arr_ptr);
        data = ia.arr.data;
        ndim = ia.arr.ndim;
        shape = ia.shape.dims;
        strides = ia.arr.strides;
        indices = ia.indices;
    }
    else if (arr_type == wp::ARRAY_TYPE_FABRIC)
    {
        fa = static_cast<wp::fabricarray_t<void>*>(arr_ptr);
    }
    else if (arr_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
    {
        ifa = static_cast<wp::indexedfabricarray_t<void>*>(arr_ptr);
    }
    else
    {
        fprintf(stderr, "Warp fill error: Invalid array type id %d\n", arr_type);
        return;
    }

    size_t n = 1;
    for (int i = 0; i < ndim; i++)
        n *= shape[i];

    ContextGuard guard(context);

    void* value_devptr = NULL; // fill value in device memory
    bool free_devptr = true;   // whether we need to free the memory

    // prepare the fill value in a graph-friendly way
    if (!capturable_tmp_alloc(WP_CURRENT_CONTEXT, value_ptr, value_size, &value_devptr, &free_devptr))
    {
        fprintf(stderr, "Warp fill error: failed to copy value to device memory\n");
        return;
    }

    if (fa)
    {
        // handle fabric arrays
        wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_kernel, n,
                         (*fa, value_devptr, value_size));
    }
    else if (ifa)
    {
        // handle indexed fabric arrays
        wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_indexed_kernel, n,
                         (*ifa, value_devptr, value_size));
    }
    else
    {
        // handle regular or indexed arrays
        switch (ndim)
        {
        case 1:
        {
            wp_launch_device(WP_CURRENT_CONTEXT, array_fill_1d_kernel, n,
                             (data, shape[0], strides[0], indices[0], value_devptr, value_size));
            break;
        }
        case 2:
        {
            wp::vec_t<2, size_t> shape_v(shape[0], shape[1]);
            wp::vec_t<2, size_t> strides_v(strides[0], strides[1]);
            wp::vec_t<2, const int*> indices_v(indices[0], indices[1]);
            wp_launch_device(WP_CURRENT_CONTEXT, array_fill_2d_kernel, n,
                             (data, shape_v, strides_v, indices_v, value_devptr, value_size));
            break;
        }
        case 3:
        {
            wp::vec_t<3, size_t> shape_v(shape[0], shape[1], shape[2]);
            wp::vec_t<3, size_t> strides_v(strides[0], strides[1], strides[2]);
            wp::vec_t<3, const int*> indices_v(indices[0], indices[1], indices[2]);
            wp_launch_device(WP_CURRENT_CONTEXT, array_fill_3d_kernel, n,
                             (data, shape_v, strides_v, indices_v, value_devptr, value_size));
            break;
        }
        case 4:
        {
            wp::vec_t<4, size_t> shape_v(shape[0], shape[1], shape[2], shape[3]);
            wp::vec_t<4, size_t> strides_v(strides[0], strides[1], strides[2], strides[3]);
            wp::vec_t<4, const int*> indices_v(indices[0], indices[1], indices[2], indices[3]);
            wp_launch_device(WP_CURRENT_CONTEXT, array_fill_4d_kernel, n,
                             (data, shape_v, strides_v, indices_v, value_devptr, value_size));
            break;
        }
        default:
            fprintf(stderr, "Warp fill error: invalid array dimensionality (%d)\n", ndim);
            break;
        }
    }

    if (free_devptr)
    {
        wp_free_device(WP_CURRENT_CONTEXT, value_devptr);
    }
}
void wp_array_scan_int_device(uint64_t in, uint64_t out, int len, bool inclusive)
{
    scan_device((const int*)in, (int*)out, len, inclusive);
}

void wp_array_scan_float_device(uint64_t in, uint64_t out, int len, bool inclusive)
{
    scan_device((const float*)in, (float*)out, len, inclusive);
}

int wp_cuda_driver_version()
{
    int version;
    if (check_cu(cuDriverGetVersion_f(&version)))
        return version;
    else
        return 0;
}

int wp_cuda_toolkit_version()
{
    return CUDA_VERSION;
}

bool wp_cuda_driver_is_initialized()
{
    return is_cuda_driver_initialized();
}

int wp_nvrtc_supported_arch_count()
{
    int count;
    if (check_nvrtc(nvrtcGetNumSupportedArchs(&count)))
        return count;
    else
        return 0;
}

void wp_nvrtc_supported_archs(int* archs)
{
    if (archs)
    {
        check_nvrtc(nvrtcGetSupportedArchs(archs));
    }
}

int wp_cuda_device_get_count()
{
    int count = 0;
    check_cu(cuDeviceGetCount_f(&count));
    return count;
}
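The two NVRTC wrappers above follow the usual query-then-fill pattern: ask for the count, then pass a buffer of that size. A hedged sketch of how a caller might combine them (hypothetical helper, not part of this file) is shown below:

```cpp
#include <vector>

// Collect the GPU architectures supported by the linked NVRTC library,
// using only the wrappers defined above. The std::vector usage is illustrative.
static std::vector<int> query_nvrtc_archs()
{
    std::vector<int> archs(wp_nvrtc_supported_arch_count());
    if (!archs.empty())
        wp_nvrtc_supported_archs(archs.data());
    return archs;
}
```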
void* wp_cuda_device_get_primary_context(int ordinal)
{
    if (ordinal >= 0 && ordinal < int(g_devices.size()))
    {
        DeviceInfo& device_info = g_devices[ordinal];

        // acquire the primary context if we haven't already
        if (!device_info.primary_context)
            check_cu(cuDevicePrimaryCtxRetain_f(&device_info.primary_context, device_info.device));

        return device_info.primary_context;
    }

    return NULL;
}

const char* wp_cuda_device_get_name(int ordinal)
{
    if (ordinal >= 0 && ordinal < int(g_devices.size()))
        return g_devices[ordinal].name;
    return NULL;
}

int wp_cuda_device_get_arch(int ordinal)
{
    if (ordinal >= 0 && ordinal < int(g_devices.size()))
        return g_devices[ordinal].arch;
    return 0;
}

int wp_cuda_device_get_sm_count(int ordinal)
{
    if (ordinal >= 0 && ordinal < int(g_devices.size()))
        return g_devices[ordinal].sm_count;
    return 0;
}

void wp_cuda_device_get_uuid(int ordinal, char uuid[16])
{
    memcpy(uuid, g_devices[ordinal].uuid.bytes, sizeof(char)*16);
}

int wp_cuda_device_get_pci_domain_id(int ordinal)
{
    if (ordinal >= 0 && ordinal < int(g_devices.size()))
        return g_devices[ordinal].pci_domain_id;
    return -1;
}

int wp_cuda_device_get_pci_bus_id(int ordinal)
{
    if (ordinal >= 0 && ordinal < int(g_devices.size()))
        return g_devices[ordinal].pci_bus_id;
    return -1;
}

int wp_cuda_device_get_pci_device_id(int ordinal)
{
    if (ordinal >= 0 && ordinal < int(g_devices.size()))
        return g_devices[ordinal].pci_device_id;
    return -1;
}

int wp_cuda_device_is_uva(int ordinal)
{
    if (ordinal >= 0 && ordinal < int(g_devices.size()))
        return g_devices[ordinal].is_uva;
    return 0;
}

int wp_cuda_device_is_mempool_supported(int ordinal)
{
    if (ordinal >= 0 && ordinal < int(g_devices.size()))
        return g_devices[ordinal].is_mempool_supported;
    return 0;
}

int wp_cuda_device_is_ipc_supported(int ordinal)
{
    if (ordinal >= 0 && ordinal < int(g_devices.size()))
        return g_devices[ordinal].is_ipc_supported;
    return 0;
}
int wp_cuda_device_set_mempool_release_threshold(int ordinal, uint64_t threshold)
{
    if (ordinal < 0 || ordinal > int(g_devices.size()))
    {
        fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
        return 0;
    }

    if (!g_devices[ordinal].is_mempool_supported)
        return 0;

    cudaMemPool_t pool;
    if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
    {
        fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
        return 0;
    }

    if (!check_cuda(cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold)))
    {
        fprintf(stderr, "Warp error: Failed to set memory pool attribute on device %d\n", ordinal);
        return 0;
    }

    return 1; // success
}

uint64_t wp_cuda_device_get_mempool_release_threshold(int ordinal)
{
    if (ordinal < 0 || ordinal > int(g_devices.size()))
    {
        fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
        return 0;
    }

    if (!g_devices[ordinal].is_mempool_supported)
        return 0;

    cudaMemPool_t pool;
    if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
    {
        fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
        return 0;
    }

    uint64_t threshold = 0;
    if (!check_cuda(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold)))
    {
        fprintf(stderr, "Warp error: Failed to get memory pool release threshold on device %d\n", ordinal);
        return 0;
    }

    return threshold;
}

uint64_t wp_cuda_device_get_mempool_used_mem_current(int ordinal)
{
    if (ordinal < 0 || ordinal > int(g_devices.size()))
    {
        fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
        return 0;
    }

    if (!g_devices[ordinal].is_mempool_supported)
        return 0;

    cudaMemPool_t pool;
    if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
    {
        fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
        return 0;
    }

    uint64_t mem_used = 0;
    if (!check_cuda(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrUsedMemCurrent, &mem_used)))
    {
        fprintf(stderr, "Warp error: Failed to get amount of currently used memory from the memory pool on device %d\n", ordinal);
        return 0;
    }

    return mem_used;
}

uint64_t wp_cuda_device_get_mempool_used_mem_high(int ordinal)
{
    if (ordinal < 0 || ordinal > int(g_devices.size()))
    {
        fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
        return 0;
    }

    if (!g_devices[ordinal].is_mempool_supported)
        return 0;

    cudaMemPool_t pool;
    if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
    {
        fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
        return 0;
    }

    uint64_t mem_high_water_mark = 0;
    if (!check_cuda(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrUsedMemHigh, &mem_high_water_mark)))
    {
        fprintf(stderr, "Warp error: Failed to get memory usage high water mark from the memory pool on device %d\n", ordinal);
        return 0;
    }

    return mem_high_water_mark;
}
void wp_cuda_device_get_memory_info(int ordinal, size_t* free_mem, size_t* total_mem)
{
    // use temporary storage if user didn't specify pointers
    size_t tmp_free_mem, tmp_total_mem;

    if (free_mem)
        *free_mem = 0;
    else
        free_mem = &tmp_free_mem;

    if (total_mem)
        *total_mem = 0;
    else
        total_mem = &tmp_total_mem;

    if (ordinal >= 0 && ordinal < int(g_devices.size()))
    {
        if (g_devices[ordinal].primary_context)
        {
            ContextGuard guard(g_devices[ordinal].primary_context, true);
            check_cu(cuMemGetInfo_f(free_mem, total_mem));
        }
        else
        {
            // if we haven't acquired the primary context yet, acquire it temporarily
            CUcontext primary_context = NULL;
            check_cu(cuDevicePrimaryCtxRetain_f(&primary_context, g_devices[ordinal].device));
            {
                ContextGuard guard(primary_context, true);
                check_cu(cuMemGetInfo_f(free_mem, total_mem));
            }
            check_cu(cuDevicePrimaryCtxRelease_f(g_devices[ordinal].device));
        }
    }
}
void* wp_cuda_context_get_current()
{
    return get_current_context();
}

void wp_cuda_context_set_current(void* context)
{
    CUcontext ctx = static_cast<CUcontext>(context);
    CUcontext prev_ctx = NULL;
    check_cu(cuCtxGetCurrent_f(&prev_ctx));
    if (ctx != prev_ctx)
    {
        check_cu(cuCtxSetCurrent_f(ctx));
    }
}

void wp_cuda_context_push_current(void* context)
{
    check_cu(cuCtxPushCurrent_f(static_cast<CUcontext>(context)));
}

void wp_cuda_context_pop_current()
{
    CUcontext context;
    check_cu(cuCtxPopCurrent_f(&context));
}

void* wp_cuda_context_create(int device_ordinal)
{
    CUcontext ctx = NULL;
    CUdevice device;
    if (check_cu(cuDeviceGet_f(&device, device_ordinal)))
        check_cu(cuCtxCreate_f(&ctx, 0, device));
    return ctx;
}

void wp_cuda_context_destroy(void* context)
{
    if (context)
    {
        CUcontext ctx = static_cast<CUcontext>(context);

        // ensure this is not the current context
        if (ctx == wp_cuda_context_get_current())
            wp_cuda_context_set_current(NULL);

        // release the cached info about this context
        ContextInfo* info = get_context_info(ctx);
        if (info)
        {
            if (info->stream)
                check_cu(cuStreamDestroy_f(info->stream));

            if (info->conditional_module)
                check_cu(cuModuleUnload_f(info->conditional_module));

            g_contexts.erase(ctx);
        }

        check_cu(cuCtxDestroy_f(ctx));
    }
}

void wp_cuda_context_synchronize(void* context)
{
    ContextGuard guard(context);

    check_cu(cuCtxSynchronize_f());

    if (!context)
        context = get_current_context();

    if (run_deferred_actions(context) > 0)
    {
        // ensure deferred asynchronous operations complete
        check_cu(cuCtxSynchronize_f());
    }

    // check_cuda(cudaDeviceGraphMemTrim(wp_cuda_context_get_device_ordinal(context)));
}

uint64_t wp_cuda_context_check(void* context)
{
    ContextGuard guard(context);

    // check errors before syncing
    cudaError_t e = cudaGetLastError();
    check_cuda(e);

    cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
    check_cuda(cudaStreamIsCapturing(get_current_stream(), &status));

    // synchronize if the stream is not capturing
    if (status == cudaStreamCaptureStatusNone)
    {
        check_cuda(cudaDeviceSynchronize());
        e = cudaGetLastError();
    }

    return static_cast<uint64_t>(e);
}
int wp_cuda_context_get_device_ordinal(void* context)
|
|
2346
|
+
{
|
|
2347
|
+
ContextInfo* info = get_context_info(static_cast<CUcontext>(context));
|
|
2348
|
+
return info && info->device_info ? info->device_info->ordinal : -1;
|
|
2349
|
+
}
|
|
2350
|
+
|
|
2351
|
+
int wp_cuda_context_is_primary(void* context)
|
|
2352
|
+
{
|
|
2353
|
+
CUcontext ctx = static_cast<CUcontext>(context);
|
|
2354
|
+
ContextInfo* context_info = get_context_info(ctx);
|
|
2355
|
+
if (!context_info)
|
|
2356
|
+
{
|
|
2357
|
+
fprintf(stderr, "Warp error: Failed to get context info\n");
|
|
2358
|
+
return 0;
|
|
2359
|
+
}
|
|
2360
|
+
|
|
2361
|
+
// if the device primary context is known, check if it matches the given context
|
|
2362
|
+
DeviceInfo* device_info = context_info->device_info;
|
|
2363
|
+
if (device_info->primary_context)
|
|
2364
|
+
return int(ctx == device_info->primary_context);
|
|
2365
|
+
|
|
2366
|
+
// there is no CUDA API to check if a context is primary, but we can temporarily
|
|
2367
|
+
// acquire the device's primary context to check the pointer
|
|
2368
|
+
CUcontext primary_ctx;
|
|
2369
|
+
if (check_cu(cuDevicePrimaryCtxRetain_f(&primary_ctx, device_info->device)))
|
|
2370
|
+
{
|
|
2371
|
+
check_cu(cuDevicePrimaryCtxRelease_f(device_info->device));
|
|
2372
|
+
return int(ctx == primary_ctx);
|
|
2373
|
+
}
|
|
2374
|
+
|
|
2375
|
+
return 0;
|
|
2376
|
+
}
|
|
2377
|
+
|
|
2378
|
+
void* wp_cuda_context_get_stream(void* context)
|
|
2379
|
+
{
|
|
2380
|
+
ContextInfo* info = get_context_info(static_cast<CUcontext>(context));
|
|
2381
|
+
if (info)
|
|
2382
|
+
{
|
|
2383
|
+
return info->stream;
|
|
2384
|
+
}
|
|
2385
|
+
return NULL;
|
|
2386
|
+
}
|
|
2387
|
+
|
|
2388
|
+
void wp_cuda_context_set_stream(void* context, void* stream, int sync)
|
|
2389
|
+
{
|
|
2390
|
+
ContextInfo* context_info = get_context_info(static_cast<CUcontext>(context));
|
|
2391
|
+
if (context_info)
|
|
2392
|
+
{
|
|
2393
|
+
CUstream new_stream = static_cast<CUstream>(stream);
|
|
2394
|
+
|
|
2395
|
+
// check whether we should sync with the previous stream on this device
|
|
2396
|
+
if (sync)
|
|
2397
|
+
{
|
|
2398
|
+
CUstream old_stream = context_info->stream;
|
|
2399
|
+
StreamInfo* old_stream_info = get_stream_info(old_stream);
|
|
2400
|
+
if (old_stream_info)
|
|
2401
|
+
{
|
|
2402
|
+
CUevent cached_event = old_stream_info->cached_event;
|
|
2403
|
+
check_cu(cuEventRecord_f(cached_event, old_stream));
|
|
2404
|
+
check_cu(cuStreamWaitEvent_f(new_stream, cached_event, CU_EVENT_WAIT_DEFAULT));
|
|
2405
|
+
}
|
|
2406
|
+
}
|
|
2407
|
+
|
|
2408
|
+
context_info->stream = new_stream;
|
|
2409
|
+
}
|
|
2410
|
+
}
|
|
2411
|
+
|
|
2412
|
+
int wp_cuda_is_peer_access_supported(int target_ordinal, int peer_ordinal)
|
|
2413
|
+
{
|
|
2414
|
+
int num_devices = int(g_devices.size());
|
|
2415
|
+
|
|
2416
|
+
if (target_ordinal < 0 || target_ordinal > num_devices)
|
|
2417
|
+
{
|
|
2418
|
+
fprintf(stderr, "Warp error: Invalid target device ordinal %d\n", target_ordinal);
|
|
2419
|
+
return 0;
|
|
2420
|
+
}
|
|
2421
|
+
|
|
2422
|
+
if (peer_ordinal < 0 || peer_ordinal > num_devices)
|
|
2423
|
+
{
|
|
2424
|
+
fprintf(stderr, "Warp error: Invalid peer device ordinal %d\n", peer_ordinal);
|
|
2425
|
+
return 0;
|
|
2426
|
+
}
|
|
2427
|
+
|
|
2428
|
+
if (target_ordinal == peer_ordinal)
|
|
2429
|
+
return 1;
|
|
2430
|
+
|
|
2431
|
+
int can_access = 0;
|
|
2432
|
+
check_cuda(cudaDeviceCanAccessPeer(&can_access, peer_ordinal, target_ordinal));
|
|
2433
|
+
|
|
2434
|
+
return can_access;
|
|
2435
|
+
}
|
|
2436
|
+
|
|
2437
|
+
int wp_cuda_is_peer_access_enabled(void* target_context, void* peer_context)
|
|
2438
|
+
{
|
|
2439
|
+
if (!target_context || !peer_context)
|
|
2440
|
+
{
|
|
2441
|
+
fprintf(stderr, "Warp error: invalid CUDA context\n");
|
|
2442
|
+
return 0;
|
|
2443
|
+
}
|
|
2444
|
+
|
|
2445
|
+
if (target_context == peer_context)
|
|
2446
|
+
return 1;
|
|
2447
|
+
|
|
2448
|
+
int target_ordinal = wp_cuda_context_get_device_ordinal(target_context);
|
|
2449
|
+
int peer_ordinal = wp_cuda_context_get_device_ordinal(peer_context);
|
|
2450
|
+
|
|
2451
|
+
// check if peer access is supported
|
|
2452
|
+
int can_access = 0;
|
|
2453
|
+
check_cuda(cudaDeviceCanAccessPeer(&can_access, peer_ordinal, target_ordinal));
|
|
2454
|
+
if (!can_access)
|
|
2455
|
+
return 0;
|
|
2456
|
+
|
|
2457
|
+
// There is no CUDA API to query if peer access is enabled, but we can try to enable it and check the result.
|
|
2458
|
+
|
|
2459
|
+
ContextGuard guard(peer_context, true);
|
|
2460
|
+
|
|
2461
|
+
CUcontext target_ctx = static_cast<CUcontext>(target_context);
|
|
2462
|
+
|
|
2463
|
+
CUresult result = cuCtxEnablePeerAccess_f(target_ctx, 0);
|
|
2464
|
+
if (result == CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED)
|
|
2465
|
+
{
|
|
2466
|
+
return 1;
|
|
2467
|
+
}
|
|
2468
|
+
else if (result == CUDA_SUCCESS)
|
|
2469
|
+
{
|
|
2470
|
+
// undo enablement
|
|
2471
|
+
check_cu(cuCtxDisablePeerAccess_f(target_ctx));
|
|
2472
|
+
return 0;
|
|
2473
|
+
}
|
|
2474
|
+
else
|
|
2475
|
+
{
|
|
2476
|
+
// report error
|
|
2477
|
+
check_cu(result);
|
|
2478
|
+
return 0;
|
|
2479
|
+
}
|
|
2480
|
+
}
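
// -----------------------------------------------------------------------------
// Illustrative usage sketch (editorial addition, not part of the original
// warp-lang source). Shows how the peer-access entry points above might be
// driven by a caller, assuming the wp_* declarations from Warp's native API
// header are in scope and that both arguments are valid CUDA context pointers.
static int example_enable_peer_access(void* ctx_target, void* ctx_peer)
{
    int target_ordinal = wp_cuda_context_get_device_ordinal(ctx_target);
    int peer_ordinal = wp_cuda_context_get_device_ordinal(ctx_peer);

    // nothing can be done if the hardware/driver cannot link the two devices
    if (!wp_cuda_is_peer_access_supported(target_ordinal, peer_ordinal))
        return 0;

    // enable access only if it is not already active; returns 1 on success
    if (!wp_cuda_is_peer_access_enabled(ctx_target, ctx_peer))
        return wp_cuda_set_peer_access_enabled(ctx_target, ctx_peer, 1);

    return 1;
}
// end of illustrative sketch
// -----------------------------------------------------------------------------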

int wp_cuda_set_peer_access_enabled(void* target_context, void* peer_context, int enable)
{
    if (!target_context || !peer_context)
    {
        fprintf(stderr, "Warp error: invalid CUDA context\n");
        return 0;
    }

    if (target_context == peer_context)
        return 1; // no-op

    int target_ordinal = wp_cuda_context_get_device_ordinal(target_context);
    int peer_ordinal = wp_cuda_context_get_device_ordinal(peer_context);

    // check if peer access is supported
    int can_access = 0;
    check_cuda(cudaDeviceCanAccessPeer(&can_access, peer_ordinal, target_ordinal));
    if (!can_access)
    {
        // failure if enabling, success if disabling
        if (enable)
        {
            fprintf(stderr, "Warp error: device %d cannot access device %d\n", peer_ordinal, target_ordinal);
            return 0;
        }
        else
            return 1;
    }

    ContextGuard guard(peer_context, true);

    CUcontext target_ctx = static_cast<CUcontext>(target_context);

    if (enable)
    {
        CUresult status = cuCtxEnablePeerAccess_f(target_ctx, 0);
        if (status != CUDA_SUCCESS && status != CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED)
        {
            check_cu(status);
            fprintf(stderr, "Warp error: failed to enable peer access from device %d to device %d\n", peer_ordinal, target_ordinal);
            return 0;
        }
    }
    else
    {
        CUresult status = cuCtxDisablePeerAccess_f(target_ctx);
        if (status != CUDA_SUCCESS && status != CUDA_ERROR_PEER_ACCESS_NOT_ENABLED)
        {
            check_cu(status);
            fprintf(stderr, "Warp error: failed to disable peer access from device %d to device %d\n", peer_ordinal, target_ordinal);
            return 0;
        }
    }

    return 1; // success
}

int wp_cuda_is_mempool_access_enabled(int target_ordinal, int peer_ordinal)
{
    int num_devices = int(g_devices.size());

    if (target_ordinal < 0 || target_ordinal > num_devices)
    {
        fprintf(stderr, "Warp error: Invalid device ordinal %d\n", target_ordinal);
        return 0;
    }

    if (peer_ordinal < 0 || peer_ordinal > num_devices)
    {
        fprintf(stderr, "Warp error: Invalid peer device ordinal %d\n", peer_ordinal);
        return 0;
    }

    if (target_ordinal == peer_ordinal)
        return 1;

    cudaMemPool_t pool;
    if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, target_ordinal)))
    {
        fprintf(stderr, "Warp error: Failed to get memory pool of device %d\n", target_ordinal);
        return 0;
    }

    cudaMemAccessFlags flags = cudaMemAccessFlagsProtNone;
    cudaMemLocation location;
    location.id = peer_ordinal;
    location.type = cudaMemLocationTypeDevice;
    if (check_cuda(cudaMemPoolGetAccess(&flags, pool, &location)))
        return int(flags != cudaMemAccessFlagsProtNone);

    return 0;
}

int wp_cuda_set_mempool_access_enabled(int target_ordinal, int peer_ordinal, int enable)
{
    int num_devices = int(g_devices.size());

    if (target_ordinal < 0 || target_ordinal > num_devices)
    {
        fprintf(stderr, "Warp error: Invalid device ordinal %d\n", target_ordinal);
        return 0;
    }

    if (peer_ordinal < 0 || peer_ordinal > num_devices)
    {
        fprintf(stderr, "Warp error: Invalid peer device ordinal %d\n", peer_ordinal);
        return 0;
    }

    if (target_ordinal == peer_ordinal)
        return 1; // no-op

    // get the memory pool
    cudaMemPool_t pool;
    if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, target_ordinal)))
    {
        fprintf(stderr, "Warp error: Failed to get memory pool of device %d\n", target_ordinal);
        return 0;
    }

    cudaMemAccessDesc desc;
    desc.location.type = cudaMemLocationTypeDevice;
    desc.location.id = peer_ordinal;

    // only cudaMemAccessFlagsProtReadWrite and cudaMemAccessFlagsProtNone are supported
    if (enable)
        desc.flags = cudaMemAccessFlagsProtReadWrite;
    else
        desc.flags = cudaMemAccessFlagsProtNone;

    if (!check_cuda(cudaMemPoolSetAccess(pool, &desc, 1)))
    {
        fprintf(stderr, "Warp error: Failed to set mempool access from device %d to device %d\n", peer_ordinal, target_ordinal);
        return 0;
    }

    return 1; // success
}
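
// -----------------------------------------------------------------------------
// Illustrative usage sketch (editorial addition, not part of the original
// warp-lang source). Mirrors the peer-access example for the stream-ordered
// memory pool path: device `peer` is granted read/write access to the default
// memory pool of device `target`. Ordinals are assumed to be valid.
static int example_share_mempool(int target, int peer)
{
    if (wp_cuda_is_mempool_access_enabled(target, peer))
        return 1; // already accessible

    // returns 1 on success, 0 on failure (e.g. access not supported)
    return wp_cuda_set_mempool_access_enabled(target, peer, 1);
}
// end of illustrative sketch
// -----------------------------------------------------------------------------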

void wp_cuda_ipc_get_mem_handle(void* ptr, char* out_buffer) {
    CUipcMemHandle memHandle;
    check_cu(cuIpcGetMemHandle_f(&memHandle, (CUdeviceptr)ptr));
    memcpy(out_buffer, memHandle.reserved, CU_IPC_HANDLE_SIZE);
}

void* wp_cuda_ipc_open_mem_handle(void* context, char* handle) {
    ContextGuard guard(context);

    CUipcMemHandle memHandle;
    memcpy(memHandle.reserved, handle, CU_IPC_HANDLE_SIZE);

    CUdeviceptr device_ptr;

    // Strangely, the CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS flag is required
    if (check_cu(cuIpcOpenMemHandle_f(&device_ptr, memHandle, CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS)))
        return (void*) device_ptr;
    else
        return NULL;
}

void wp_cuda_ipc_close_mem_handle(void* ptr) {
    check_cu(cuIpcCloseMemHandle_f((CUdeviceptr) ptr));
}

void wp_cuda_ipc_get_event_handle(void* context, void* event, char* out_buffer) {
    ContextGuard guard(context);

    CUipcEventHandle eventHandle;
    check_cu(cuIpcGetEventHandle_f(&eventHandle, static_cast<CUevent>(event)));
    memcpy(out_buffer, eventHandle.reserved, CU_IPC_HANDLE_SIZE);
}

void* wp_cuda_ipc_open_event_handle(void* context, char* handle) {
    ContextGuard guard(context);

    CUipcEventHandle eventHandle;
    memcpy(eventHandle.reserved, handle, CU_IPC_HANDLE_SIZE);

    CUevent event;

    if (check_cu(cuIpcOpenEventHandle_f(&event, eventHandle)))
        return event;
    else
        return NULL;
}

void* wp_cuda_stream_create(void* context, int priority)
{
    ContextGuard guard(context, true);

    CUstream stream;
    if (check_cu(cuStreamCreateWithPriority_f(&stream, CU_STREAM_DEFAULT, priority)))
    {
        wp_cuda_stream_register(WP_CURRENT_CONTEXT, stream);
        return stream;
    }
    else
        return NULL;
}

void wp_cuda_stream_destroy(void* context, void* stream)
{
    if (!stream)
        return;

    wp_cuda_stream_unregister(context, stream);

    // release temporary radix sort buffer associated with this stream
    radix_sort_release(context, stream);

    check_cu(cuStreamDestroy_f(static_cast<CUstream>(stream)));
}

int wp_cuda_stream_query(void* stream)
{
    CUresult res = cuStreamQuery_f(static_cast<CUstream>(stream));

    if ((res != CUDA_SUCCESS) && (res != CUDA_ERROR_NOT_READY))
    {
        // Abnormal, print out error
        check_cu(res);
    }

    return res;
}

void wp_cuda_stream_register(void* context, void* stream)
{
    if (!stream)
        return;

    ContextGuard guard(context);

    // populate stream info
    StreamInfo& stream_info = g_streams[static_cast<CUstream>(stream)];
    check_cu(cuEventCreate_f(&stream_info.cached_event, CU_EVENT_DISABLE_TIMING));
}

void wp_cuda_stream_unregister(void* context, void* stream)
{
    if (!stream)
        return;

    CUstream cuda_stream = static_cast<CUstream>(stream);

    StreamInfo* stream_info = get_stream_info(cuda_stream);
    if (stream_info)
    {
        // release stream info
        check_cu(cuEventDestroy_f(stream_info->cached_event));
        g_streams.erase(cuda_stream);
    }

    // make sure we don't leave dangling references to this stream
    ContextInfo* context_info = get_context_info(context);
    if (context_info)
    {
        if (cuda_stream == context_info->stream)
            context_info->stream = NULL;
    }
}

void* wp_cuda_stream_get_current()
{
    return get_current_stream();
}

void wp_cuda_stream_synchronize(void* stream)
{
    check_cu(cuStreamSynchronize_f(static_cast<CUstream>(stream)));
}

void wp_cuda_stream_wait_event(void* stream, void* event, bool external)
{
    // the external flag can only be used during graph capture
    if (external && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
    {
        // wait for an external event during graph capture
        check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), CU_EVENT_WAIT_EXTERNAL));
    }
    else
    {
        check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), CU_EVENT_WAIT_DEFAULT));
    }
}

void wp_cuda_stream_wait_stream(void* stream, void* other_stream, void* event, bool external)
{
    unsigned record_flags = CU_EVENT_RECORD_DEFAULT;
    unsigned wait_flags = CU_EVENT_WAIT_DEFAULT;

    // the external flag can only be used during graph capture
    if (external && !g_captures.empty())
    {
        if (wp_cuda_stream_is_capturing(other_stream))
            record_flags = CU_EVENT_RECORD_EXTERNAL;
        if (wp_cuda_stream_is_capturing(stream))
            wait_flags = CU_EVENT_WAIT_EXTERNAL;
    }

    check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(other_stream), record_flags));
    check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), wait_flags));
}
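
// -----------------------------------------------------------------------------
// Illustrative usage sketch (editorial addition, not part of the original
// warp-lang source). Demonstrates the producer/consumer pattern served by
// wp_cuda_stream_wait_stream(): the consumer stream waits on whatever work was
// queued on the producer stream at the time of the call, without blocking the
// host. Assumes the wp_* declarations are in scope and `context` is a valid
// CUDA context pointer.
static void example_stream_handoff(void* context)
{
    void* producer = wp_cuda_stream_create(context, 0);
    void* consumer = wp_cuda_stream_create(context, 0);
    void* event = wp_cuda_event_create(context, CU_EVENT_DISABLE_TIMING);

    // ... enqueue work on `producer` here ...

    // consumer will not start later work until producer's queued work is done
    wp_cuda_stream_wait_stream(consumer, producer, event, false);

    wp_cuda_event_destroy(event);
    wp_cuda_stream_destroy(context, producer);
    wp_cuda_stream_destroy(context, consumer);
}
// end of illustrative sketch
// -----------------------------------------------------------------------------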

int wp_cuda_stream_is_capturing(void* stream)
{
    cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
    check_cuda(cudaStreamIsCapturing(static_cast<cudaStream_t>(stream), &status));

    return int(status != cudaStreamCaptureStatusNone);
}

uint64_t wp_cuda_stream_get_capture_id(void* stream)
{
    return get_capture_id(static_cast<CUstream>(stream));
}

int wp_cuda_stream_get_priority(void* stream)
{
    int priority = 0;
    check_cuda(cuStreamGetPriority_f(static_cast<CUstream>(stream), &priority));

    return priority;
}

void* wp_cuda_event_create(void* context, unsigned flags)
{
    ContextGuard guard(context, true);

    CUevent event;
    if (check_cu(cuEventCreate_f(&event, flags)))
        return event;
    else
        return NULL;
}

void wp_cuda_event_destroy(void* event)
{
    check_cu(cuEventDestroy_f(static_cast<CUevent>(event)));
}

int wp_cuda_event_query(void* event)
{
    CUresult res = cuEventQuery_f(static_cast<CUevent>(event));

    if ((res != CUDA_SUCCESS) && (res != CUDA_ERROR_NOT_READY))
    {
        // Abnormal, print out error
        check_cu(res);
    }

    return res;
}

void wp_cuda_event_record(void* event, void* stream, bool external)
{
    // the external flag can only be used during graph capture
    if (external && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
    {
        // record external event during graph capture (e.g., for timing or when explicitly specified by the user)
        check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(stream), CU_EVENT_RECORD_EXTERNAL));
    }
    else
    {
        check_cu(cuEventRecord_f(static_cast<CUevent>(event), static_cast<CUstream>(stream)));
    }
}

void wp_cuda_event_synchronize(void* event)
{
    check_cu(cuEventSynchronize_f(static_cast<CUevent>(event)));
}

float wp_cuda_event_elapsed_time(void* start_event, void* end_event)
{
    float elapsed = 0.0f;
    cudaEvent_t start = static_cast<cudaEvent_t>(start_event);
    cudaEvent_t end = static_cast<cudaEvent_t>(end_event);
    check_cuda(cudaEventElapsedTime(&elapsed, start, end));
    return elapsed;
}
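
// -----------------------------------------------------------------------------
// Illustrative usage sketch (editorial addition, not part of the original
// warp-lang source). Times a span of work on a stream with the event helpers
// above; CU_EVENT_DEFAULT keeps timing enabled, unlike the cached events used
// internally, which are created with CU_EVENT_DISABLE_TIMING.
static float example_time_stream_section(void* context, void* stream)
{
    void* start = wp_cuda_event_create(context, CU_EVENT_DEFAULT);
    void* stop = wp_cuda_event_create(context, CU_EVENT_DEFAULT);

    wp_cuda_event_record(start, stream, false);
    // ... enqueue the work to be measured on `stream` here ...
    wp_cuda_event_record(stop, stream, false);

    // block the host until the stop event has been reached
    wp_cuda_event_synchronize(stop);
    float ms = wp_cuda_event_elapsed_time(start, stop);

    wp_cuda_event_destroy(start);
    wp_cuda_event_destroy(stop);
    return ms;
}
// end of illustrative sketch
// -----------------------------------------------------------------------------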

bool wp_cuda_graph_begin_capture(void* context, void* stream, int external)
{
    ContextGuard guard(context);

    CUstream cuda_stream = static_cast<CUstream>(stream);
    StreamInfo* stream_info = get_stream_info(cuda_stream);
    if (!stream_info)
    {
        wp::set_error_string("Warp error: unknown stream");
        return false;
    }

    if (external)
    {
        // if it's an external capture, make sure it's already active so we can get the capture id
        cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
        if (!check_cuda(cudaStreamIsCapturing(cuda_stream, &status)))
            return false;
        if (status != cudaStreamCaptureStatusActive)
        {
            wp::set_error_string("Warp error: stream is not capturing");
            return false;
        }
    }
    else
    {
        // start the capture
        if (!check_cuda(cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeThreadLocal)))
            return false;
    }

    uint64_t capture_id = get_capture_id(cuda_stream);

    CaptureInfo* capture = new CaptureInfo();
    capture->stream = cuda_stream;
    capture->id = capture_id;
    capture->external = bool(external);

    // update stream info
    stream_info->capture = capture;

    // add to known captures
    g_captures[capture_id] = capture;

    return true;
}

bool wp_cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
{
    ContextGuard guard(context);

    // check if this is a known stream
    CUstream cuda_stream = static_cast<CUstream>(stream);
    StreamInfo* stream_info = get_stream_info(cuda_stream);
    if (!stream_info)
    {
        wp::set_error_string("Warp error: unknown capture stream");
        return false;
    }

    // check if this stream was used to start a capture
    CaptureInfo* capture = stream_info->capture;
    if (!capture)
    {
        wp::set_error_string("Warp error: stream has no capture started");
        return false;
    }

    // get capture info
    bool external = capture->external;
    uint64_t capture_id = capture->id;
    std::vector<FreeInfo> tmp_allocs = capture->tmp_allocs;

    // clear capture info
    stream_info->capture = NULL;
    g_captures.erase(capture_id);
    delete capture;

    // a lambda to clean up on exit in case of error
    auto clean_up = [cuda_stream, capture_id, external]()
    {
        // unreference outstanding graph allocs so that they will be released with the user reference
        for (auto it = g_graph_allocs.begin(); it != g_graph_allocs.end(); ++it)
        {
            GraphAllocInfo& alloc_info = it->second;
            if (alloc_info.capture_id == capture_id)
                alloc_info.graph_destroyed = true;
        }

        // make sure we terminate the capture
        if (!external)
        {
            cudaGraph_t graph = NULL;
            cudaStreamEndCapture(cuda_stream, &graph);
            cudaGetLastError();
        }
    };

    // get captured graph without ending the capture in case it is external
    cudaGraph_t graph = get_capture_graph(cuda_stream);
    if (!graph)
    {
        clean_up();
        return false;
    }

    // ensure that all forked streams are joined to the main capture stream by manually
    // adding outstanding capture dependencies gathered from the graph leaf nodes
    std::vector<cudaGraphNode_t> stream_dependencies;
    std::vector<cudaGraphNode_t> leaf_nodes;
    if (get_capture_dependencies(cuda_stream, stream_dependencies) && get_graph_leaf_nodes(graph, leaf_nodes))
    {
        // compute set difference to get unjoined dependencies
        std::vector<cudaGraphNode_t> unjoined_dependencies;
        std::sort(stream_dependencies.begin(), stream_dependencies.end());
        std::sort(leaf_nodes.begin(), leaf_nodes.end());
        std::set_difference(leaf_nodes.begin(), leaf_nodes.end(),
                            stream_dependencies.begin(), stream_dependencies.end(),
                            std::back_inserter(unjoined_dependencies));
        if (!unjoined_dependencies.empty())
        {
            check_cu(cuStreamUpdateCaptureDependencies_f(cuda_stream, unjoined_dependencies.data(), unjoined_dependencies.size(),
                                                         CU_STREAM_ADD_CAPTURE_DEPENDENCIES));
            // ensure graph is still valid
            if (get_capture_graph(cuda_stream) != graph)
            {
                clean_up();
                return false;
            }
        }
    }

    // check if this graph has unfreed allocations, which require special handling
    std::vector<void*> unfreed_allocs;
    for (auto it = g_graph_allocs.begin(); it != g_graph_allocs.end(); ++it)
    {
        GraphAllocInfo& alloc_info = it->second;
        if (alloc_info.capture_id == capture_id)
            unfreed_allocs.push_back(it->first);
    }

    if (!unfreed_allocs.empty() || !tmp_allocs.empty())
    {
        // Create a user object that will notify us when the instantiated graph is destroyed.
        // This works for external captures also, since we wouldn't otherwise know when
        // the externally-created graph instance gets deleted.
        // This callback is guaranteed to arrive after the graph has finished executing on the device,
        // not necessarily when cudaGraphExecDestroy() is called.
        GraphDestroyCallbackInfo* graph_info = new GraphDestroyCallbackInfo;
        graph_info->context = context ? context : get_current_context();
        graph_info->unfreed_allocs = unfreed_allocs;
        graph_info->tmp_allocs = tmp_allocs;
        cudaUserObject_t user_object;
        check_cuda(cudaUserObjectCreate(&user_object, graph_info, on_graph_destroy, 1, cudaUserObjectNoDestructorSync));
        check_cuda(cudaGraphRetainUserObject(graph, user_object, 1, cudaGraphUserObjectMove));

        // ensure graph is still valid
        if (get_capture_graph(cuda_stream) != graph)
        {
            clean_up();
            return false;
        }
    }

    // for external captures, we don't instantiate the graph ourselves, so we're done
    if (external)
        return true;

    // end the capture
    if (!check_cuda(cudaStreamEndCapture(cuda_stream, &graph)))
        return false;

    // process deferred free list if no more captures are ongoing
    if (g_captures.empty())
    {
        run_deferred_actions();
    }

    if (graph_ret)
        *graph_ret = graph;

    return true;
}
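
// -----------------------------------------------------------------------------
// Illustrative usage sketch (editorial addition, not part of the original
// warp-lang source). The expected life cycle of a non-external capture using
// the entry points above and below: begin capture, record work, end capture,
// instantiate, launch, then destroy. Error handling is reduced to early
// returns for brevity; declarations are assumed to come from the native API
// header.
static bool example_capture_and_replay(void* context, void* stream)
{
    if (!wp_cuda_graph_begin_capture(context, stream, /*external=*/0))
        return false;

    // ... launch kernels / memcpys on `stream`; they are recorded, not run ...

    void* graph = NULL;
    if (!wp_cuda_graph_end_capture(context, stream, &graph))
        return false;

    void* graph_exec = NULL;
    if (!wp_cuda_graph_create_exec(context, stream, graph, &graph_exec))
        return false;

    bool ok = wp_cuda_graph_launch(graph_exec, stream);

    wp_cuda_graph_exec_destroy(context, graph_exec);
    wp_cuda_graph_destroy(context, graph);
    return ok;
}
// end of illustrative sketch
// -----------------------------------------------------------------------------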

bool wp_capture_debug_dot_print(void* graph, const char *path, uint32_t flags)
{
    if (!check_cuda(cudaGraphDebugDotPrint((cudaGraph_t)graph, path, flags)))
        return false;
    return true;
}

bool wp_cuda_graph_create_exec(void* context, void* stream, void* graph, void** graph_exec_ret)
{
    ContextGuard guard(context);

    cudaGraphExec_t graph_exec = NULL;
    if (!check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, (cudaGraph_t)graph, cudaGraphInstantiateFlagAutoFreeOnLaunch)))
        return false;

    // Usually uploading the graph explicitly is optional, but when updating graph nodes (e.g., indirect dispatch)
    // then the upload is required because otherwise the graph nodes that get updated might not yet be uploaded, which
    // results in undefined behavior.
    CUstream cuda_stream = static_cast<CUstream>(stream);
    if (!check_cuda(cudaGraphUpload(graph_exec, cuda_stream)))
        return false;

    if (graph_exec_ret)
        *graph_exec_ret = graph_exec;

    return true;
}

// Support for conditional graph nodes available with CUDA 12.4+.
#if CUDA_VERSION >= 12040

// CUBIN or PTX data for compiled conditional modules, loaded on demand, keyed on device architecture
using ModuleKey = std::pair<int, bool>; // <arch, use_ptx>
static std::map<ModuleKey, void*> g_conditional_modules;

// Compile module with conditional helper kernels
static void* compile_conditional_module(int arch, bool use_ptx)
{
    static const char* kernel_source = R"(
typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);

extern "C" __global__ void set_conditional_if_handle_kernel(cudaGraphConditionalHandle handle, int* value)
{
    if (threadIdx.x + blockIdx.x * blockDim.x == 0)
        cudaGraphSetConditional(handle, *value);
}

extern "C" __global__ void set_conditional_else_handle_kernel(cudaGraphConditionalHandle handle, int* value)
{
    if (threadIdx.x + blockIdx.x * blockDim.x == 0)
        cudaGraphSetConditional(handle, !*value);
}

extern "C" __global__ void set_conditional_if_else_handles_kernel(cudaGraphConditionalHandle if_handle, cudaGraphConditionalHandle else_handle, int* value)
{
    if (threadIdx.x + blockIdx.x * blockDim.x == 0)
    {
        cudaGraphSetConditional(if_handle, *value);
        cudaGraphSetConditional(else_handle, !*value);
    }
}
)";

    // avoid recompilation
    ModuleKey key = {arch, use_ptx};
    auto it = g_conditional_modules.find(key);
    if (it != g_conditional_modules.end())
        return it->second;

    nvrtcProgram prog;
    if (!check_nvrtc(nvrtcCreateProgram(&prog, kernel_source, "conditional_kernels", 0, NULL, NULL)))
        return NULL;

    char arch_opt[128];
    if (use_ptx)
        snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=compute_%d", arch);
    else
        snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=sm_%d", arch);

    std::vector<const char*> opts;
    opts.push_back(arch_opt);

    const bool print_debug = (std::getenv("WARP_DEBUG") != nullptr);
    if (print_debug)
    {
        printf("NVRTC options (conditional module, arch=%d, use_ptx=%s):\n", arch, use_ptx ? "true" : "false");
        for(auto o: opts) {
            printf("%s\n", o);
        }
    }

    if (!check_nvrtc(nvrtcCompileProgram(prog, int(opts.size()), opts.data())))
    {
        size_t log_size;
        if (check_nvrtc(nvrtcGetProgramLogSize(prog, &log_size)))
        {
            std::vector<char> log(log_size);
            if (check_nvrtc(nvrtcGetProgramLog(prog, log.data())))
                fprintf(stderr, "%s", log.data());
        }
        nvrtcDestroyProgram(&prog);
        return NULL;
    }

    // get output
    char* output = NULL;
    size_t output_size = 0;

    if (use_ptx)
    {
        check_nvrtc(nvrtcGetPTXSize(prog, &output_size));
        if (output_size > 0)
        {
            output = new char[output_size];
            if (check_nvrtc(nvrtcGetPTX(prog, output)))
                g_conditional_modules[key] = output;
        }
    }
    else
    {
        check_nvrtc(nvrtcGetCUBINSize(prog, &output_size));
        if (output_size > 0)
        {
            output = new char[output_size];
            if (check_nvrtc(nvrtcGetCUBIN(prog, output)))
                g_conditional_modules[key] = output;
        }
    }

    nvrtcDestroyProgram(&prog);

    // return CUBIN or PTX data
    return output;
}


// Load module with conditional helper kernels
static CUmodule load_conditional_module(void* context, int arch, bool use_ptx)
{
    ContextInfo* context_info = get_context_info(context);
    if (!context_info)
        return NULL;

    // check if already loaded
    if (context_info->conditional_module)
        return context_info->conditional_module;

    // compile if needed
    void* compiled_module = compile_conditional_module(arch, use_ptx);
    if (!compiled_module)
    {
        fprintf(stderr, "Warp error: Failed to compile conditional kernels\n");
        return NULL;
    }

    // load module (handles both PTX and CUBIN data automatically)
    CUmodule module = NULL;
    if (!check_cu(cuModuleLoadDataEx_f(&module, compiled_module, 0, NULL, NULL)))
    {
        fprintf(stderr, "Warp error: Failed to load conditional kernels module\n");
        return NULL;
    }

    context_info->conditional_module = module;

    return module;
}

static CUfunction get_conditional_kernel(void* context, int arch, bool use_ptx, const char* name)
{
    // load module if needed
    CUmodule module = load_conditional_module(context, arch, use_ptx);
    if (!module)
        return NULL;

    CUfunction kernel;
    if (!check_cu(cuModuleGetFunction_f(&kernel, module, name)))
    {
        fprintf(stderr, "Warp error: Failed to get kernel %s\n", name);
        return NULL;
    }

    return kernel;
}

bool wp_cuda_graph_pause_capture(void* context, void* stream, void** graph_ret)
{
    ContextGuard guard(context);

    CUstream cuda_stream = static_cast<CUstream>(stream);
    if (!check_cuda(cudaStreamEndCapture(cuda_stream, (cudaGraph_t*)graph_ret)))
        return false;
    return true;
}

bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
{
    ContextGuard guard(context);

    CUstream cuda_stream = static_cast<CUstream>(stream);
    cudaGraph_t cuda_graph = static_cast<cudaGraph_t>(graph);

    std::vector<cudaGraphNode_t> leaf_nodes;
    if (!get_graph_leaf_nodes(cuda_graph, leaf_nodes))
        return false;

    if (!check_cuda(cudaStreamBeginCaptureToGraph(cuda_stream,
                                                  cuda_graph,
                                                  leaf_nodes.data(),
                                                  nullptr,
                                                  leaf_nodes.size(),
                                                  cudaStreamCaptureModeThreadLocal)))
        return false;

    return true;
}

// https://developer.nvidia.com/blog/constructing-cuda-graphs-with-dynamic-parameters/#combined_approach
// https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes/
// condition is a gpu pointer
// if_graph_ret and else_graph_ret should be NULL if not needed
bool wp_cuda_graph_insert_if_else(void* context, void* stream, int arch, bool use_ptx, int* condition, void** if_graph_ret, void** else_graph_ret)
{
    bool has_if = if_graph_ret != NULL;
    bool has_else = else_graph_ret != NULL;
    int num_branches = int(has_if) + int(has_else);

    // if neither the IF nor ELSE branches are required, it's a no-op
    if (num_branches == 0)
        return true;

    ContextGuard guard(context);

    CUstream cuda_stream = static_cast<CUstream>(stream);

    // Get the current stream capturing graph
    CUstreamCaptureStatus capture_status = CU_STREAM_CAPTURE_STATUS_NONE;
    cudaGraph_t cuda_graph = NULL;
    const cudaGraphNode_t* capture_deps = NULL;
    size_t dep_count = 0;
    if (!check_cu(cuStreamGetCaptureInfo_f(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
        return false;

    // abort if not capturing
    if (!cuda_graph || capture_status != CU_STREAM_CAPTURE_STATUS_ACTIVE)
    {
        wp::set_error_string("Stream is not capturing");
        return false;
    }

    //int driver_version = wp_cuda_driver_version();

    // IF-ELSE nodes are only supported with CUDA 12.8+
    // Somehow child graphs produce wrong results when an else branch is used
    // Seems to be a bug in the CUDA driver: https://nvbugs/5241330
    if (num_branches == 1 /*|| driver_version >= 12080*/)
    {
        cudaGraphConditionalHandle handle;
        check_cuda(cudaGraphConditionalHandleCreate(&handle, cuda_graph));

        // run a kernel to set the condition handle from the condition pointer
        // (need to negate the condition if only the else branch is used)
        CUfunction kernel;
        if (has_if)
            kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
        else
            kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_else_handle_kernel");

        if (!kernel)
        {
            wp::set_error_string("Failed to get built-in conditional kernel");
            return false;
        }

        void* kernel_args[2];
        kernel_args[0] = &handle;
        kernel_args[1] = &condition;

        if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
            return false;

        if (!check_cu(cuStreamGetCaptureInfo_f(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
            return false;

        // create conditional node
        CUgraphNode condition_node;
        CUgraphNodeParams condition_params = { CU_GRAPH_NODE_TYPE_CONDITIONAL };
        condition_params.conditional.handle = handle;
        condition_params.conditional.type = CU_GRAPH_COND_TYPE_IF;
        condition_params.conditional.size = num_branches;
        condition_params.conditional.ctx = get_current_context();
        if (!check_cu(cuGraphAddNode_f(&condition_node, cuda_graph, capture_deps, NULL, dep_count, &condition_params)))
            return false;

        if (!check_cu(cuStreamUpdateCaptureDependencies_f(cuda_stream, &condition_node, 1, cudaStreamSetCaptureDependencies)))
            return false;

        if (num_branches == 1)
        {
            if (has_if)
                *if_graph_ret = condition_params.conditional.phGraph_out[0];
            else
                *else_graph_ret = condition_params.conditional.phGraph_out[0];
        }
        else
        {
            *if_graph_ret = condition_params.conditional.phGraph_out[0];
            *else_graph_ret = condition_params.conditional.phGraph_out[1];
        }
    }
    else
    {
        // Create IF node followed by an additional IF node with negated condition
        cudaGraphConditionalHandle if_handle, else_handle;
        check_cuda(cudaGraphConditionalHandleCreate(&if_handle, cuda_graph));
        check_cuda(cudaGraphConditionalHandleCreate(&else_handle, cuda_graph));

        CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_else_handles_kernel");
        if (!kernel)
        {
            wp::set_error_string("Failed to get built-in conditional kernel");
            return false;
        }

        void* kernel_args[3];
        kernel_args[0] = &if_handle;
        kernel_args[1] = &else_handle;
        kernel_args[2] = &condition;

        if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
            return false;

        if (!check_cu(cuStreamGetCaptureInfo_f(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
            return false;

        CUgraphNode if_node;
        CUgraphNodeParams if_params = { CU_GRAPH_NODE_TYPE_CONDITIONAL };
        if_params.conditional.handle = if_handle;
        if_params.conditional.type = CU_GRAPH_COND_TYPE_IF;
        if_params.conditional.size = 1;
        if_params.conditional.ctx = get_current_context();
        if (!check_cu(cuGraphAddNode_f(&if_node, cuda_graph, capture_deps, NULL, dep_count, &if_params)))
            return false;

        CUgraphNode else_node;
        CUgraphNodeParams else_params = { CU_GRAPH_NODE_TYPE_CONDITIONAL };
        else_params.conditional.handle = else_handle;
        else_params.conditional.type = CU_GRAPH_COND_TYPE_IF;
        else_params.conditional.size = 1;
        else_params.conditional.ctx = get_current_context();
        if (!check_cu(cuGraphAddNode_f(&else_node, cuda_graph, &if_node, NULL, 1, &else_params)))
            return false;

        if (!check_cu(cuStreamUpdateCaptureDependencies_f(cuda_stream, &else_node, 1, cudaStreamSetCaptureDependencies)))
            return false;

        *if_graph_ret = if_params.conditional.phGraph_out[0];
        *else_graph_ret = else_params.conditional.phGraph_out[0];
    }

    return true;
}
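
// -----------------------------------------------------------------------------
// Illustrative usage sketch (editorial addition, not part of the original
// warp-lang source). Shows how the IF body graph returned above might be
// filled in by a caller, following the pause/resume pattern also used by
// wp_cuda_graph_insert_child_graph() below: insert the conditional node,
// suspend capture of the main graph, redirect capture into the body graph,
// record the branch work, then switch capture back to the main graph.
// `condition` is a device pointer to an int; `arch`/`use_ptx` select the
// helper-kernel binary, as in the functions above.
static bool example_record_if_branch(void* context, void* stream, int arch, bool use_ptx, int* condition)
{
    // create the conditional node in the currently captured graph; only the IF body is requested
    void* if_graph = NULL;
    if (!wp_cuda_graph_insert_if_else(context, stream, arch, use_ptx, condition, &if_graph, NULL))
        return false;

    // suspend capture of the main graph and redirect it into the IF body graph
    void* main_graph = NULL;
    if (!wp_cuda_graph_pause_capture(context, stream, &main_graph))
        return false;
    if (!wp_cuda_graph_resume_capture(context, stream, if_graph))
        return false;

    // ... launch the kernels that should only run when *condition != 0 ...

    // close the body capture and pick the main graph back up
    void* body_graph = NULL;
    if (!wp_cuda_graph_pause_capture(context, stream, &body_graph))
        return false;
    return wp_cuda_graph_resume_capture(context, stream, main_graph);
}
// end of illustrative sketch
// -----------------------------------------------------------------------------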
|
|
3411
|
+
|
|
3412
|
+
// graph node type names for intelligible error reporting
|
|
3413
|
+
static const char* get_graph_node_type_name(CUgraphNodeType type)
|
|
3414
|
+
{
|
|
3415
|
+
static const std::unordered_map<CUgraphNodeType, const char*> names
|
|
3416
|
+
{
|
|
3417
|
+
{CU_GRAPH_NODE_TYPE_KERNEL, "kernel launch"},
|
|
3418
|
+
{CU_GRAPH_NODE_TYPE_MEMCPY, "memcpy"},
|
|
3419
|
+
{CU_GRAPH_NODE_TYPE_MEMSET, "memset"},
|
|
3420
|
+
{CU_GRAPH_NODE_TYPE_HOST, "host execution"},
|
|
3421
|
+
{CU_GRAPH_NODE_TYPE_GRAPH, "graph launch"},
|
|
3422
|
+
{CU_GRAPH_NODE_TYPE_EMPTY, "empty node"},
|
|
3423
|
+
{CU_GRAPH_NODE_TYPE_WAIT_EVENT, "event wait"},
|
|
3424
|
+
{CU_GRAPH_NODE_TYPE_EVENT_RECORD, "event record"},
|
|
3425
|
+
{CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL, "semaphore signal"},
|
|
3426
|
+
{CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT, "semaphore wait"},
|
|
3427
|
+
{CU_GRAPH_NODE_TYPE_MEM_ALLOC, "memory allocation"},
|
|
3428
|
+
{CU_GRAPH_NODE_TYPE_MEM_FREE, "memory deallocation"},
|
|
3429
|
+
{CU_GRAPH_NODE_TYPE_BATCH_MEM_OP, "batched mem op"},
|
|
3430
|
+
{CU_GRAPH_NODE_TYPE_CONDITIONAL, "conditional node"},
|
|
3431
|
+
};
|
|
3432
|
+
|
|
3433
|
+
auto it = names.find(type);
|
|
3434
|
+
if (it != names.end())
|
|
3435
|
+
return it->second;
|
|
3436
|
+
else
|
|
3437
|
+
return "unknown node";
|
|
3438
|
+
}
|
|
3439
|
+
|
|
3440
|
+
// check if a graph can be launched as a child graph
|
|
3441
|
+
static bool is_valid_child_graph(void* child_graph)
|
|
3442
|
+
{
|
|
3443
|
+
// disallowed child graph nodes according to the documentation of cuGraphAddChildGraphNode()
|
|
3444
|
+
static const std::unordered_set<CUgraphNodeType> disallowed_nodes
|
|
3445
|
+
{
|
|
3446
|
+
CU_GRAPH_NODE_TYPE_MEM_ALLOC,
|
|
3447
|
+
CU_GRAPH_NODE_TYPE_MEM_FREE,
|
|
3448
|
+
CU_GRAPH_NODE_TYPE_CONDITIONAL,
|
|
3449
|
+
};
|
|
3450
|
+
|
|
3451
|
+
if (!child_graph)
|
|
3452
|
+
{
|
|
3453
|
+
wp::set_error_string("Child graph is null");
|
|
3454
|
+
return false;
|
|
3455
|
+
}
|
|
3456
|
+
|
|
3457
|
+
size_t num_nodes = 0;
|
|
3458
|
+
if (!check_cuda(cudaGraphGetNodes((cudaGraph_t)child_graph, NULL, &num_nodes)))
|
|
3459
|
+
return false;
|
|
3460
|
+
std::vector<cudaGraphNode_t> nodes(num_nodes);
|
|
3461
|
+
if (!check_cuda(cudaGraphGetNodes((cudaGraph_t)child_graph, nodes.data(), &num_nodes)))
|
|
3462
|
+
return false;
|
|
3463
|
+
|
|
3464
|
+
for (size_t i = 0; i < num_nodes; i++)
|
|
3465
|
+
{
|
|
3466
|
+
// note: we use the driver API to get the node type, otherwise some nodes are not recognized correctly
|
|
3467
|
+
CUgraphNodeType node_type;
|
|
3468
|
+
check_cu(cuGraphNodeGetType_f(nodes[i], &node_type));
|
|
3469
|
+
auto it = disallowed_nodes.find(node_type);
|
|
3470
|
+
if (it != disallowed_nodes.end())
|
|
3471
|
+
{
|
|
3472
|
+
wp::set_error_string("Child graph contains an unsupported operation (%s)", get_graph_node_type_name(node_type));
|
|
3473
|
+
return false;
|
|
3474
|
+
}
|
|
3475
|
+
}
|
|
3476
|
+
|
|
3477
|
+
return true;
|
|
3478
|
+
}
|
|
3479
|
+
|
|
3480
|
+
// check if a graph can be used as a conditional body graph
|
|
3481
|
+
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#condtional-node-body-graph-requirements
|
|
3482
|
+
bool wp_cuda_graph_check_conditional_body(void* body_graph)
|
|
3483
|
+
{
|
|
3484
|
+
static const std::unordered_set<CUgraphNodeType> allowed_nodes
|
|
3485
|
+
{
|
|
3486
|
+
CU_GRAPH_NODE_TYPE_MEMCPY,
|
|
3487
|
+
CU_GRAPH_NODE_TYPE_MEMSET,
|
|
3488
|
+
CU_GRAPH_NODE_TYPE_KERNEL,
|
|
3489
|
+
CU_GRAPH_NODE_TYPE_GRAPH,
|
|
3490
|
+
CU_GRAPH_NODE_TYPE_EMPTY,
|
|
3491
|
+
CU_GRAPH_NODE_TYPE_CONDITIONAL,
|
|
3492
|
+
};
|
|
3493
|
+
|
|
3494
|
+
if (!body_graph)
|
|
3495
|
+
{
|
|
3496
|
+
wp::set_error_string("Conditional body graph is null");
|
|
3497
|
+
return false;
|
|
3498
|
+
}
|
|
3499
|
+
|
|
3500
|
+
size_t num_nodes = 0;
|
|
3501
|
+
if (!check_cuda(cudaGraphGetNodes((cudaGraph_t)body_graph, NULL, &num_nodes)))
|
|
3502
|
+
return false;
|
|
3503
|
+
std::vector<cudaGraphNode_t> nodes(num_nodes);
|
|
3504
|
+
if (!check_cuda(cudaGraphGetNodes((cudaGraph_t)body_graph, nodes.data(), &num_nodes)))
|
|
3505
|
+
return false;
|
|
3506
|
+
|
|
3507
|
+
for (size_t i = 0; i < num_nodes; i++)
|
|
3508
|
+
{
|
|
3509
|
+
// note: we use the driver API to get the node type, otherwise some nodes are not recognized correctly
|
|
3510
|
+
CUgraphNodeType node_type;
|
|
3511
|
+
check_cu(cuGraphNodeGetType_f(nodes[i], &node_type));
|
|
3512
|
+
if (allowed_nodes.find(node_type) == allowed_nodes.end())
|
|
3513
|
+
{
|
|
3514
|
+
wp::set_error_string("Conditional body graph contains an unsupported operation (%s)", get_graph_node_type_name(node_type));
|
|
3515
|
+
return false;
|
|
3516
|
+
}
|
|
3517
|
+
else if (node_type == CU_GRAPH_NODE_TYPE_GRAPH)
|
|
3518
|
+
{
|
|
3519
|
+
// check nested child graphs recursively
|
|
3520
|
+
cudaGraph_t child_graph = NULL;
|
|
3521
|
+
if (!check_cuda(cudaGraphChildGraphNodeGetGraph(nodes[i], &child_graph)))
|
|
3522
|
+
return false;
|
|
3523
|
+
if (!wp_cuda_graph_check_conditional_body(child_graph))
|
|
3524
|
+
return false;
|
|
3525
|
+
}
|
|
3526
|
+
}
|
|
3527
|
+
|
|
3528
|
+
return true;
|
|
3529
|
+
}
|
|
3530
|
+
|
|
3531
|
+
bool wp_cuda_graph_insert_child_graph(void* context, void* stream, void* child_graph)
|
|
3532
|
+
{
|
|
3533
|
+
if (!is_valid_child_graph(child_graph))
|
|
3534
|
+
return false;
|
|
3535
|
+
|
|
3536
|
+
ContextGuard guard(context);
|
|
3537
|
+
|
|
3538
|
+
CUstream cuda_stream = static_cast<CUstream>(stream);
|
|
3539
|
+
|
|
3540
|
+
// Get the current stream capturing graph
|
|
3541
|
+
CUstreamCaptureStatus capture_status = CU_STREAM_CAPTURE_STATUS_NONE;
|
|
3542
|
+
void* cuda_graph = NULL;
|
|
3543
|
+
const CUgraphNode* capture_deps = NULL;
|
|
3544
|
+
size_t dep_count = 0;
|
|
3545
|
+
if (!check_cu(cuStreamGetCaptureInfo_f(cuda_stream, &capture_status, nullptr, (cudaGraph_t*)&cuda_graph, &capture_deps, &dep_count)))
|
|
3546
|
+
return false;
|
|
3547
|
+
|
|
3548
|
+
if (!wp_cuda_graph_pause_capture(context, cuda_stream, &cuda_graph))
|
|
3549
|
+
return false;
|
|
3550
|
+
|
|
3551
|
+
cudaGraphNode_t body_node;
|
|
3552
|
+
if (!check_cuda(cudaGraphAddChildGraphNode(&body_node,
|
|
3553
|
+
static_cast<cudaGraph_t>(cuda_graph),
|
|
3554
|
+
capture_deps, dep_count,
|
|
3555
|
+
static_cast<cudaGraph_t>(child_graph))))
|
|
3556
|
+
return false;
|
|
3557
|
+
|
|
3558
|
+
if (!wp_cuda_graph_resume_capture(context, cuda_stream, cuda_graph))
|
|
3559
|
+
return false;
|
|
3560
|
+
|
|
3561
|
+
if (!check_cu(cuStreamUpdateCaptureDependencies_f(cuda_stream, &body_node, 1, cudaStreamSetCaptureDependencies)))
|
|
3562
|
+
return false;
|
|
3563
|
+
|
|
3564
|
+
return true;
|
|
3565
|
+
}
|
|
3566
|
+
|
|
3567
|
+
bool wp_cuda_graph_insert_while(void* context, void* stream, int arch, bool use_ptx, int* condition, void** body_graph_ret, uint64_t* handle_ret)
|
|
3568
|
+
{
|
|
3569
|
+
// if there's no body, it's a no-op
|
|
3570
|
+
if (!body_graph_ret)
|
|
3571
|
+
return true;
|
|
3572
|
+
|
|
3573
|
+
ContextGuard guard(context);
|
|
3574
|
+
|
|
3575
|
+
CUstream cuda_stream = static_cast<CUstream>(stream);
|
|
3576
|
+
|
|
3577
|
+
// Get the current stream capturing graph
|
|
3578
|
+
CUstreamCaptureStatus capture_status = CU_STREAM_CAPTURE_STATUS_NONE;
|
|
3579
|
+
cudaGraph_t cuda_graph = NULL;
|
|
3580
|
+
const cudaGraphNode_t* capture_deps = NULL;
|
|
3581
|
+
size_t dep_count = 0;
|
|
3582
|
+
if (!check_cu(cuStreamGetCaptureInfo_f(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
|
|
3583
|
+
return false;
|
|
3584
|
+
|
|
3585
|
+
// abort if not capturing
|
|
3586
|
+
if (!cuda_graph || capture_status != CU_STREAM_CAPTURE_STATUS_ACTIVE)
|
|
3587
|
+
{
|
|
3588
|
+
wp::set_error_string("Stream is not capturing");
|
|
3589
|
+
return false;
|
|
3590
|
+
}
|
|
3591
|
+
|
|
3592
|
+
cudaGraphConditionalHandle handle;
|
|
3593
|
+
if (!check_cuda(cudaGraphConditionalHandleCreate(&handle, cuda_graph)))
|
|
3594
|
+
return false;
|
|
3595
|
+
|
|
3596
|
+
// launch a kernel to set the condition handle from condition pointer
|
|
3597
|
+
CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
|
|
3598
|
+
if (!kernel)
|
|
3599
|
+
{
|
|
3600
|
+
wp::set_error_string("Failed to get built-in conditional kernel");
|
|
3601
|
+
return false;
|
|
3602
|
+
}
|
|
3603
|
+
|
|
3604
|
+
void* kernel_args[2];
|
|
3605
|
+
kernel_args[0] = &handle;
|
|
3606
|
+
kernel_args[1] = &condition;
|
|
3607
|
+
|
|
3608
|
+
if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
|
|
3609
|
+
return false;
|
|
3610
|
+
|
|
3611
|
+
if (!check_cu(cuStreamGetCaptureInfo_f(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
|
|
3612
|
+
return false;
|
|
3613
|
+
|
|
3614
|
+
// insert conditional graph node
|
|
3615
|
+
CUgraphNode while_node;
|
|
3616
|
+
CUgraphNodeParams while_params = { CU_GRAPH_NODE_TYPE_CONDITIONAL };
|
|
3617
|
+
while_params.conditional.handle = handle;
|
|
3618
|
+
while_params.conditional.type = CU_GRAPH_COND_TYPE_WHILE;
|
|
3619
|
+
while_params.conditional.size = 1;
|
|
3620
|
+
while_params.conditional.ctx = get_current_context();
|
|
3621
|
+
if (!check_cu(cuGraphAddNode_f(&while_node, cuda_graph, capture_deps, NULL, dep_count, &while_params)))
|
|
3622
|
+
return false;
|
|
3623
|
+
|
|
3624
|
+
if (!check_cu(cuStreamUpdateCaptureDependencies_f(cuda_stream, &while_node, 1, cudaStreamSetCaptureDependencies)))
|
|
3625
|
+
return false;
|
|
3626
|
+
|
|
3627
|
+
*body_graph_ret = while_params.conditional.phGraph_out[0];
|
|
3628
|
+
*handle_ret = handle;
|
|
3629
|
+
|
|
3630
|
+
return true;
|
|
3631
|
+
}
|
|
3632
|
+
|
|
3633
|
+
bool wp_cuda_graph_set_condition(void* context, void* stream, int arch, bool use_ptx, int* condition, uint64_t handle)
|
|
3634
|
+
{
|
|
3635
|
+
ContextGuard guard(context);
|
|
3636
|
+
|
|
3637
|
+
CUstream cuda_stream = static_cast<CUstream>(stream);
|
|
3638
|
+
|
|
3639
|
+
// launch a kernel to set the condition handle from condition pointer
|
|
3640
|
+
CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
|
|
3641
|
+
if (!kernel)
|
|
3642
|
+
{
|
|
3643
|
+
wp::set_error_string("Failed to get built-in conditional kernel");
|
|
3644
|
+
return false;
|
|
3645
|
+
}
|
|
3646
|
+
|
|
3647
|
+
void* kernel_args[2];
|
|
3648
|
+
kernel_args[0] = &handle;
|
|
3649
|
+
kernel_args[1] = &condition;
|
|
3650
|
+
|
|
3651
|
+
if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
|
|
3652
|
+
return false;
|
|
3653
|
+
|
|
3654
|
+
return true;
|
|
3655
|
+
}

#else
// stubs for conditional graph node API if CUDA toolkit is too old.

bool wp_cuda_graph_pause_capture(void* context, void* stream, void** graph_ret)
{
    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
    return false;
}

bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
{
    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
    return false;
}

bool wp_cuda_graph_insert_if_else(void* context, void* stream, int arch, bool use_ptx, int* condition, void** if_graph_ret, void** else_graph_ret)
{
    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
    return false;
}

bool wp_cuda_graph_insert_while(void* context, void* stream, int arch, bool use_ptx, int* condition, void** body_graph_ret, uint64_t* handle_ret)
{
    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
    return false;
}

bool wp_cuda_graph_set_condition(void* context, void* stream, int arch, bool use_ptx, int* condition, uint64_t handle)
{
    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
    return false;
}

bool wp_cuda_graph_insert_child_graph(void* context, void* stream, void* child_graph)
{
    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
    return false;
}

bool wp_cuda_graph_check_conditional_body(void* body_graph)
{
    wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
    return false;
}

#endif // support for conditional graph nodes
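// Note: each stub above fails at run time rather than at compile time: it records a
// message via wp::set_error_string() and returns false, so a Warp build made against an
// older CUDA Toolkit still links, and the missing feature is only reported when a caller
// actually requests a conditional graph node. Illustrative caller sketch (hypothetical
// variables, assuming a valid context/stream and a device-side condition buffer):
//
//     void* if_graph = nullptr;
//     void* else_graph = nullptr;
//     if (!wp_cuda_graph_insert_if_else(context, stream, arch, use_ptx, condition, &if_graph, &else_graph))
//     {
//         // handle the error reported through wp::set_error_string()
//     }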


bool wp_cuda_graph_launch(void* graph_exec, void* stream)
{
    // TODO: allow naming graphs?
    begin_cuda_range(WP_TIMING_GRAPH, stream, get_stream_context(stream), "graph");

    bool result = check_cuda(cudaGraphLaunch((cudaGraphExec_t)graph_exec, (cudaStream_t)stream));

    end_cuda_range(WP_TIMING_GRAPH, stream);

    return result;
}

bool wp_cuda_graph_destroy(void* context, void* graph)
{
    // ensure there are no graph captures in progress
    if (g_captures.empty())
    {
        ContextGuard guard(context);
        return check_cuda(cudaGraphDestroy((cudaGraph_t)graph));
    }
    else
    {
        GraphDestroyInfo info;
        info.context = context ? context : get_current_context();
        info.graph = graph;
        g_deferred_graph_list.push_back(info);
        return true;
    }
}

bool wp_cuda_graph_exec_destroy(void* context, void* graph_exec)
{
    // ensure there are no graph captures in progress
    if (g_captures.empty())
    {
        ContextGuard guard(context);
        return check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
    }
    else
    {
        GraphDestroyInfo info;
        info.context = context ? context : get_current_context();
        info.graph_exec = graph_exec;
        g_deferred_graph_list.push_back(info);
        return true;
    }
}
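// Note: both destroy paths above follow the same pattern used elsewhere in this file
// (see wp_cuda_unload_module below): destroying a graph or graph-exec while any stream
// capture is in progress would invalidate the capture, so when g_captures is non-empty
// the handle is queued on g_deferred_graph_list and released later, once no capture is
// active, instead of being destroyed immediately.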

bool write_file(const char* data, size_t size, std::string filename, const char* mode)
{
    const bool print_debug = (std::getenv("WARP_DEBUG") != nullptr);
    if (print_debug)
    {
        printf("Writing %zu B to %s (%s)\n", size, filename.c_str(), mode);
    }
    FILE* file = fopen(filename.c_str(), mode);
    if (file)
    {
        if (fwrite(data, 1, size, file) != size) {
            fprintf(stderr, "Warp error: Failed to write to output file '%s'\n", filename.c_str());
            fclose(file);
            return false;
        }
        fclose(file);
        return true;
    }
    else
    {
        fprintf(stderr, "Warp error: Failed to open file '%s'\n", filename.c_str());
        return false;
    }
}
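// Example (illustrative, hypothetical buffer): write_file() is the common sink for the
// compiler outputs below; binary artifacts use mode "wb", text output such as PTX uses "wt":
//
//     std::vector<char> cubin = /* e.g. filled by nvrtcGetCUBIN() */;
//     if (!write_file(cubin.data(), cubin.size(), "module.sm_86.cubin", "wb"))
//         return false;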

#if WP_ENABLE_MATHDX
bool check_nvjitlink_result(nvJitLinkHandle handle, nvJitLinkResult result, const char* file, int line)
{
    if (result != NVJITLINK_SUCCESS) {
        fprintf(stderr, "nvJitLink error: %d on %s:%d\n", (int)result, file, line);
        size_t lsize;
        result = nvJitLinkGetErrorLogSize(handle, &lsize);
        if (result == NVJITLINK_SUCCESS && lsize > 0) {
            std::vector<char> log(lsize);
            result = nvJitLinkGetErrorLog(handle, log.data());
            if (result == NVJITLINK_SUCCESS) {
                fprintf(stderr, "%s\n", log.data());
            }
        }
        return false;
    } else {
        return true;
    }
}
#endif

size_t wp_cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, bool compile_time_trace, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes, int* ltoir_input_types)
{
    // use file extension to determine whether to output PTX or CUBIN
    const char* output_ext = strrchr(output_path, '.');
    bool use_ptx = output_ext && strcmp(output_ext + 1, "ptx") == 0;
    const bool print_debug = (std::getenv("WARP_DEBUG") != nullptr);

    // check include dir path len (path + option)
    const int max_path = 4096 + 16;
    if (strlen(include_dir) > max_path)
    {
        fprintf(stderr, "Warp error: Include path too long\n");
        return size_t(-1);
    }

    if (print_debug)
    {
        // Not available in all nvJitLink versions
        // unsigned major = 0;
        // unsigned minor = 0;
        // nvJitLinkVersion(&major, &minor);
        // printf("nvJitLink version %d.%d\n", major, minor);
        int major = 0;
        int minor = 0;
        nvrtcVersion(&major, &minor);
        printf("NVRTC version %d.%d\n", major, minor);
    }

    char include_opt[max_path];
    strcpy(include_opt, "--include-path=");
    strcat(include_opt, include_dir);

    const int max_arch = 128;
    char arch_opt[max_arch];
    char arch_opt_lto[max_arch];

    if (use_ptx)
    {
        snprintf(arch_opt, max_arch, "--gpu-architecture=compute_%d", arch);
        snprintf(arch_opt_lto, max_arch, "-arch=compute_%d", arch);
    }
    else
    {
        snprintf(arch_opt, max_arch, "--gpu-architecture=sm_%d", arch);
        snprintf(arch_opt_lto, max_arch, "-arch=sm_%d", arch);
    }

    std::vector<const char*> opts;
    opts.push_back(arch_opt);
    opts.push_back(include_opt);
    opts.push_back("--std=c++17");

    if (debug)
    {
        opts.push_back("--define-macro=_DEBUG");
        opts.push_back("--generate-line-info");
#ifndef _WIN32
        opts.push_back("--device-debug"); // -G
#endif
    }
    else
    {
        opts.push_back("--define-macro=NDEBUG");

        if (lineinfo)
            opts.push_back("--generate-line-info");
    }

    if (verify_fp)
        opts.push_back("--define-macro=WP_VERIFY_FP");
    else
        opts.push_back("--undefine-macro=WP_VERIFY_FP");

#if WP_ENABLE_MATHDX
    opts.push_back("--define-macro=WP_ENABLE_MATHDX=1");
#else
    opts.push_back("--define-macro=WP_ENABLE_MATHDX=0");
#endif

    if (fast_math)
        opts.push_back("--use_fast_math");

    if (fuse_fp)
        opts.push_back("--fmad=true");
    else
        opts.push_back("--fmad=false");

    std::vector<std::string> stored_options;
    for(int i = 0; i < num_cuda_include_dirs; i++)
    {
        stored_options.push_back(std::string("--include-path=") + cuda_include_dirs[i]);
        opts.push_back(stored_options.back().c_str());
    }

    opts.push_back("--device-as-default-execution-space");
    opts.push_back("--extra-device-vectorization");
    opts.push_back("--restrict");

    if (num_ltoirs > 0)
    {
        opts.push_back("-dlto");
        opts.push_back("--relocatable-device-code=true");
    }

    if (compile_time_trace)
    {
#if CUDA_VERSION >= 12080
        stored_options.push_back(std::string("--fdevice-time-trace=") + std::string(output_path).append("_compile-time-trace.json"));
        opts.push_back(stored_options.back().c_str());
#else
        fprintf(stderr, "Warp warning: CUDA version is less than 12.8, compile_time_trace is not supported\n");
#endif
    }

    nvrtcProgram prog;
    nvrtcResult res;

    res = nvrtcCreateProgram(
        &prog,         // prog
        cuda_src,      // buffer
        program_name,  // name
        0,             // numHeaders
        NULL,          // headers
        NULL);         // includeNames

    if (!check_nvrtc(res))
        return size_t(res);

    if (print_debug)
    {
        printf("NVRTC options:\n");
        for(auto o: opts) {
            printf("%s\n", o);
        }
    }
    res = nvrtcCompileProgram(prog, int(opts.size()), opts.data());

    if (!check_nvrtc(res) || verbose)
    {
        // get program log
        size_t log_size;
        if (check_nvrtc(nvrtcGetProgramLogSize(prog, &log_size)))
        {
            std::vector<char> log(log_size);
            if (check_nvrtc(nvrtcGetProgramLog(prog, log.data())))
            {
                // todo: figure out better way to return this to python
                if (res != NVRTC_SUCCESS)
                    fprintf(stderr, "%s", log.data());
                else
                    fprintf(stdout, "%s", log.data());
            }
        }

        if (res != NVRTC_SUCCESS)
        {
            nvrtcDestroyProgram(&prog);
            return size_t(res);
        }
    }

    nvrtcResult (*get_output_size)(nvrtcProgram, size_t*);
    nvrtcResult (*get_output_data)(nvrtcProgram, char*);
    const char* output_mode;
    if(num_ltoirs > 0) {
#if WP_ENABLE_MATHDX
        get_output_size = nvrtcGetLTOIRSize;
        get_output_data = nvrtcGetLTOIR;
        output_mode = "wb";
#else
        fprintf(stderr, "Warp error: num_ltoirs > 0 but Warp was not built with MathDx support\n");
        return size_t(-1);
#endif
    }
    else if (use_ptx)
    {
        get_output_size = nvrtcGetPTXSize;
        get_output_data = nvrtcGetPTX;
        output_mode = "wt";
    }
    else
    {
        get_output_size = nvrtcGetCUBINSize;
        get_output_data = nvrtcGetCUBIN;
        output_mode = "wb";
    }

    // save output
    size_t output_size;
    res = get_output_size(prog, &output_size);
    if (check_nvrtc(res))
    {
        std::vector<char> output(output_size);
        res = get_output_data(prog, output.data());
        if (check_nvrtc(res))
        {

            // LTOIR case - need an extra step
            if (num_ltoirs > 0)
            {
#if WP_ENABLE_MATHDX
                if(ltoir_input_types == nullptr || ltoirs == nullptr || ltoir_sizes == nullptr) {
                    fprintf(stderr, "Warp error: num_ltoirs > 0 but ltoir_input_types, ltoirs or ltoir_sizes are NULL\n");
                    return size_t(-1);
                }
                nvJitLinkHandle handle = nullptr;
                std::vector<const char *> lopts = {"-dlto", arch_opt_lto};
                if (use_ptx) {
                    lopts.push_back("-ptx");
                }
                if (print_debug)
                {
                    printf("nvJitLink options:\n");
                    for(auto o: lopts) {
                        printf("%s\n", o);
                    }
                }
                if(!check_nvjitlink(handle, nvJitLinkCreate(&handle, lopts.size(), lopts.data())))
                {
                    res = nvrtcResult(-1);
                }
                // Links
                if(std::getenv("WARP_DUMP_LTOIR"))
                {
                    write_file(output.data(), output.size(), "nvrtc_output.ltoir", "wb");
                }
                if(!check_nvjitlink(handle, nvJitLinkAddData(handle, NVJITLINK_INPUT_LTOIR, output.data(), output.size(), "nvrtc_output"))) // NVRTC business
                {
                    res = nvrtcResult(-1);
                }
                for(size_t ltoidx = 0; ltoidx < num_ltoirs; ltoidx++)
                {
                    nvJitLinkInputType input_type = static_cast<nvJitLinkInputType>(ltoir_input_types[ltoidx]);
                    const char* ext = ".unknown";
                    switch(input_type) {
                        case NVJITLINK_INPUT_CUBIN:
                            ext = ".cubin";
                            break;
                        case NVJITLINK_INPUT_LTOIR:
                            ext = ".ltoir";
                            break;
                        case NVJITLINK_INPUT_FATBIN:
                            ext = ".fatbin";
                            break;
                        default:
                            break;
                    }
                    if(std::getenv("WARP_DUMP_LTOIR"))
                    {
                        write_file(ltoirs[ltoidx], ltoir_sizes[ltoidx], std::string("lto_online_") + std::to_string(ltoidx) + ext, "wb");
                    }
                    if(!check_nvjitlink(handle, nvJitLinkAddData(handle, input_type, ltoirs[ltoidx], ltoir_sizes[ltoidx], "lto_online"))) // External LTOIR
                    {
                        res = nvrtcResult(-1);
                    }
                }
                if(!check_nvjitlink(handle, nvJitLinkComplete(handle)))
                {
                    res = nvrtcResult(-1);
                }
                else
                {
                    if(use_ptx)
                    {
                        size_t ptx_size = 0;
                        check_nvjitlink(handle, nvJitLinkGetLinkedPtxSize(handle, &ptx_size));
                        std::vector<char> ptx(ptx_size);
                        check_nvjitlink(handle, nvJitLinkGetLinkedPtx(handle, ptx.data()));
                        output = ptx;
                    }
                    else
                    {
                        size_t cubin_size = 0;
                        check_nvjitlink(handle, nvJitLinkGetLinkedCubinSize(handle, &cubin_size));
                        std::vector<char> cubin(cubin_size);
                        check_nvjitlink(handle, nvJitLinkGetLinkedCubin(handle, cubin.data()));
                        output = cubin;
                    }
                }
                check_nvjitlink(handle, nvJitLinkDestroy(&handle));
#else
                fprintf(stderr, "Warp error: num_ltoirs > 0 but Warp was not built with MathDx support\n");
                return size_t(-1);
#endif
            }

            if(!write_file(output.data(), output.size(), output_path, output_mode)) {
                res = nvrtcResult(-1);
            }
        }
    }

    check_nvrtc(nvrtcDestroyProgram(&prog));

    return res;
}
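// Example (illustrative): a minimal host-side invocation, assuming `src` holds the generated
// CUDA C++ for a Warp module, `warp_include_dir` points at Warp's native headers, and the
// target is a compute capability 8.6 device; the ".cubin" extension selects SASS output,
// ".ptx" would select PTX. The return value is 0 (NVRTC_SUCCESS) on success:
//
//     size_t err = wp_cuda_compile_program(src, "wp_module", 86, warp_include_dir,
//                                          0, nullptr,          // no extra CUDA include dirs
//                                          false, false,        // debug, verbose
//                                          false, true, true,   // verify_fp, fast_math, fuse_fp
//                                          false, false,        // lineinfo, compile_time_trace
//                                          "wp_module.sm_86.cubin",
//                                          0, nullptr, nullptr, nullptr);  // no LTO-IR inputs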

#if WP_ENABLE_MATHDX
bool check_cufftdx_result(commondxStatusType result, const char* file, int line)
{
    if (result != commondxStatusType::COMMONDX_SUCCESS) {
        fprintf(stderr, "libmathdx cuFFTDx error: %d on %s:%d\n", (int)result, file, line);
        return false;
    } else {
        return true;
    }
}

bool check_cublasdx_result(commondxStatusType result, const char* file, int line)
{
    if (result != commondxStatusType::COMMONDX_SUCCESS) {
        fprintf(stderr, "libmathdx cuBLASDx error: %d on %s:%d\n", (int)result, file, line);
        return false;
    } else {
        return true;
    }
}

bool check_cusolver_result(commondxStatusType result, const char* file, int line)
{
    if (result != commondxStatusType::COMMONDX_SUCCESS) {
        fprintf(stderr, "libmathdx cuSOLVER error: %d on %s:%d\n", (int)result, file, line);
        return false;
    } else {
        return true;
    }
}

bool wp_cuda_compile_fft(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int size, int elements_per_thread, int direction, int precision, int* shared_memory_size)
{

    CHECK_ANY(ltoir_output_path != nullptr);
    CHECK_ANY(symbol_name != nullptr);
    CHECK_ANY(shared_memory_size != nullptr);
    // Includes currently unused
    CHECK_ANY(include_dirs == nullptr);
    CHECK_ANY(mathdx_include_dir == nullptr);
    CHECK_ANY(num_include_dirs == 0);

    bool res = true;
    cufftdxDescriptor h;
    CHECK_CUFFTDX(cufftdxCreateDescriptor(&h));

    // CUFFTDX_API_LMEM means each thread starts with a subset of the data
    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_API, cufftdxApi::CUFFTDX_API_LMEM));
    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_SIZE, (long long)size));
    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_DIRECTION, (cufftdxDirection)direction));
    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_PRECISION, (commondxPrecision)precision));
    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_SM, (long long)(arch * 10)));
    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_ELEMENTS_PER_THREAD, (long long)(elements_per_thread)));
    CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_FFTS_PER_BLOCK, 1));

    CHECK_CUFFTDX(cufftdxSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));

    size_t lto_size = 0;
    CHECK_CUFFTDX(cufftdxGetLTOIRSize(h, &lto_size));

    std::vector<char> lto(lto_size);
    CHECK_CUFFTDX(cufftdxGetLTOIR(h, lto.size(), lto.data()));

    long long int smem = 0;
    CHECK_CUFFTDX(cufftdxGetTraitInt64(h, cufftdxTraitType::CUFFTDX_TRAIT_SHARED_MEMORY_SIZE, &smem));
    *shared_memory_size = (int)smem;

    if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
        res = false;
    }

    CHECK_CUFFTDX(cufftdxDestroyDescriptor(h));

    return res;
}
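// Note: the SM operator above is derived from `arch` as arch * 10, so e.g. arch == 89
// (compute capability 8.9) maps to SM 890 in libmathdx terms. Illustrative call
// (hypothetical values; `direction`/`precision` are the enum integers expected by libmathdx)
// producing LTO-IR that is presumably fed back into wp_cuda_compile_program() through its
// `ltoirs` arguments:
//
//     int smem = 0;
//     bool ok = wp_cuda_compile_fft("fft_256.ltoir", "wp_fft_256", 0, nullptr, nullptr,
//                                   89, 256, 8, direction, precision, &smem);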

bool wp_cuda_compile_dot(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int K, int precision_A, int precision_B, int precision_C, int type, int arrangement_A, int arrangement_B, int arrangement_C, int num_threads)
{

    CHECK_ANY(ltoir_output_path != nullptr);
    CHECK_ANY(symbol_name != nullptr);
    // Includes currently unused
    CHECK_ANY(include_dirs == nullptr);
    CHECK_ANY(mathdx_include_dir == nullptr);
    CHECK_ANY(num_include_dirs == 0);

    bool res = true;
    cublasdxDescriptor h;
    CHECK_CUBLASDX(cublasdxCreateDescriptor(&h));

    CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_FUNCTION, cublasdxFunction::CUBLASDX_FUNCTION_MM));
    CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
    CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_API, cublasdxApi::CUBLASDX_API_SMEM));
    std::array<long long int, 3> precisions = {precision_A, precision_B, precision_C};
    CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_PRECISION, 3, precisions.data()));
    CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_SM, (long long)(arch * 10)));
    CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_TYPE, (cublasdxType)type));
    std::array<long long int, 3> block_dim = {num_threads, 1, 1};
    CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_BLOCK_DIM, block_dim.size(), block_dim.data()));
    std::array<long long int, 3> size = {M, N, K};
    CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_SIZE, size.size(), size.data()));
    std::array<long long int, 3> arrangement = {arrangement_A, arrangement_B, arrangement_C};
    CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_ARRANGEMENT, arrangement.size(), arrangement.data()));

    CHECK_CUBLASDX(cublasdxSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));

    size_t lto_size = 0;
    CHECK_CUBLASDX(cublasdxGetLTOIRSize(h, &lto_size));

    std::vector<char> lto(lto_size);
    CHECK_CUBLASDX(cublasdxGetLTOIR(h, lto.size(), lto.data()));

    if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
        res = false;
    }

    CHECK_CUBLASDX(cublasdxDestroyDescriptor(h));

    return res;
}
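// Note: the descriptor above fixes one GEMM instantiation per call: SIZE carries the
// (M, N, K) problem shape, BLOCK_DIM pins execution to `num_threads` threads in x, and the
// three PRECISION/ARRANGEMENT entries describe the A, B and C operands independently. For
// example (hypothetical values), M = N = K = 16 with num_threads = 32 emits LTO-IR for a
// 16x16x16 tile computed cooperatively by a single 32-thread block.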

bool wp_cuda_compile_solver(const char* fatbin_output_path, const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int NRHS, int function, int side, int diag, int precision, int arrangement_A, int arrangement_B, int fill_mode, int num_threads)
{

    CHECK_ANY(ltoir_output_path != nullptr);
    CHECK_ANY(symbol_name != nullptr);
    CHECK_ANY(mathdx_include_dir == nullptr);
    CHECK_ANY(num_include_dirs == 0);
    CHECK_ANY(include_dirs == nullptr);

    bool res = true;

    cusolverdxDescriptor h { 0 };
    CHECK_CUSOLVER(cusolverdxCreateDescriptor(&h));
    std::array<long long int, 3> size = {M, N, NRHS};
    CHECK_CUSOLVER(cusolverdxSetOperatorInt64s(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_SIZE, size.size(), size.data()));
    std::array<long long int, 3> block_dim = {num_threads, 1, 1};
    CHECK_CUSOLVER(cusolverdxSetOperatorInt64s(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_BLOCK_DIM, block_dim.size(), block_dim.data()));
    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_TYPE, cusolverdxType::CUSOLVERDX_TYPE_REAL));
    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_API, cusolverdxApi::CUSOLVERDX_API_SMEM));
    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_FUNCTION, (cusolverdxFunction)function));
    if (side >= 0) {
        CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_SIDE, (cusolverdxSide)side));
    }
    if (diag >= 0) {
        CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_DIAG, (cusolverdxDiag)diag));
    }
    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_PRECISION, (commondxPrecision)precision));
    std::array<long long int, 2> arrangement = {arrangement_A, arrangement_B};
    CHECK_CUSOLVER(cusolverdxSetOperatorInt64s(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_ARRANGEMENT, arrangement.size(), arrangement.data()));
    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_FILL_MODE, (cusolverdxFillMode)fill_mode));
    CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_SM, (long long)(arch * 10)));

    CHECK_CUSOLVER(cusolverdxSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));

    size_t lto_size = 0;
    CHECK_CUSOLVER(cusolverdxGetLTOIRSize(h, &lto_size));

    std::vector<char> lto(lto_size);
    CHECK_CUSOLVER(cusolverdxGetLTOIR(h, lto.size(), lto.data()));

    // This fatbin is universal, i.e. it is the same for any instantiation of a cusolver device function
    size_t fatbin_size = 0;
    CHECK_CUSOLVER(cusolverdxGetUniversalFATBINSize(h, &fatbin_size));

    std::vector<char> fatbin(fatbin_size);
    CHECK_CUSOLVER(cusolverdxGetUniversalFATBIN(h, fatbin.size(), fatbin.data()));

    if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
        res = false;
    }

    if(!write_file(fatbin.data(), fatbin.size(), fatbin_output_path, "wb")) {
        res = false;
    }

    CHECK_CUSOLVER(cusolverdxDestroyDescriptor(h));

    return res;
}

#endif

void* wp_cuda_load_module(void* context, const char* path)
{
    ContextGuard guard(context);

    // use file extension to determine whether to load PTX or CUBIN
    const char* input_ext = strrchr(path, '.');
    bool load_ptx = input_ext && strcmp(input_ext + 1, "ptx") == 0;

    std::vector<char> input;

    FILE* file = fopen(path, "rb");
    if (file)
    {
        fseek(file, 0, SEEK_END);
        size_t length = ftell(file);
        fseek(file, 0, SEEK_SET);

        input.resize(length + 1);
        if (fread(input.data(), 1, length, file) != length)
        {
            fprintf(stderr, "Warp error: Failed to read input file '%s'\n", path);
            fclose(file);
            return NULL;
        }
        fclose(file);

        input[length] = '\0';
    }
    else
    {
        fprintf(stderr, "Warp error: Failed to open input file '%s'\n", path);
        return NULL;
    }

    int driver_cuda_version = 0;
    CUmodule module = NULL;

    if (load_ptx)
    {
        if (check_cu(cuDriverGetVersion_f(&driver_cuda_version)) && driver_cuda_version >= CUDA_VERSION)
        {
            // let the driver compile the PTX

            CUjit_option options[2];
            void *option_vals[2];
            char error_log[8192] = "";
            unsigned int log_size = 8192;
            // Set up loader options
            // Pass a buffer for error message
            options[0] = CU_JIT_ERROR_LOG_BUFFER;
            option_vals[0] = (void*)error_log;
            // Pass the size of the error buffer
            options[1] = CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES;
            option_vals[1] = (void*)(size_t)log_size;

            if (!check_cu(cuModuleLoadDataEx_f(&module, input.data(), 2, options, option_vals)))
            {
                fprintf(stderr, "Warp error: Loading PTX module failed\n");
                // print error log if not empty
                if (*error_log)
                    fprintf(stderr, "PTX loader error:\n%s\n", error_log);
                return NULL;
            }
        }
        else
        {
            // manually compile the PTX and load as CUBIN

            ContextInfo* context_info = get_context_info(static_cast<CUcontext>(context));
            if (!context_info || !context_info->device_info)
            {
                fprintf(stderr, "Warp error: Failed to determine target architecture\n");
                return NULL;
            }

            int arch = context_info->device_info->arch;

            char arch_opt[128];
            sprintf(arch_opt, "--gpu-name=sm_%d", arch);

            const char* compiler_options[] = { arch_opt };

            nvPTXCompilerHandle compiler = NULL;
            if (!check_nvptx(nvPTXCompilerCreate(&compiler, input.size(), input.data())))
                return NULL;

            if (!check_nvptx(nvPTXCompilerCompile(compiler, sizeof(compiler_options) / sizeof(*compiler_options), compiler_options)))
                return NULL;

            size_t cubin_size = 0;
            if (!check_nvptx(nvPTXCompilerGetCompiledProgramSize(compiler, &cubin_size)))
                return NULL;

            std::vector<char> cubin(cubin_size);
            if (!check_nvptx(nvPTXCompilerGetCompiledProgram(compiler, cubin.data())))
                return NULL;

            check_nvptx(nvPTXCompilerDestroy(&compiler));

            if (!check_cu(cuModuleLoadDataEx_f(&module, cubin.data(), 0, NULL, NULL)))
            {
                fprintf(stderr, "Warp CUDA error: Loading module failed\n");
                return NULL;
            }
        }
    }
    else
    {
        // load CUBIN
        if (!check_cu(cuModuleLoadDataEx_f(&module, input.data(), 0, NULL, NULL)))
        {
            fprintf(stderr, "Warp CUDA error: Loading module failed\n");
            return NULL;
        }
    }

    return module;
}
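// Note: the PTX path above has two routes. If the installed driver is at least as new as
// the toolkit Warp was built with (driver_cuda_version >= CUDA_VERSION), the driver
// JIT-compiles the PTX directly via cuModuleLoadDataEx and any loader errors are echoed from
// the error-log buffer; otherwise the PTX is first compiled to a CUBIN for the context's
// architecture with the nvPTXCompiler library and then loaded. A ".cubin" input skips both
// routes and is loaded as-is.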

void wp_cuda_unload_module(void* context, void* module)
{
    // ensure there are no graph captures in progress
    if (g_captures.empty())
    {
        ContextGuard guard(context);
        check_cu(cuModuleUnload_f((CUmodule)module));
    }
    else
    {
        // defer until graph capture completes
        ModuleInfo module_info;
        module_info.context = context ? context : get_current_context();
        module_info.module = module;
        g_deferred_module_list.push_back(module_info);
    }
}


int wp_cuda_get_max_shared_memory(void* context)
{
    ContextInfo* info = get_context_info(context);
    if (!info)
        return -1;

    int max_smem_bytes = info->device_info->max_smem_bytes;
    return max_smem_bytes;
}

bool wp_cuda_configure_kernel_shared_memory(void* kernel, int size)
{
    int requested_smem_bytes = size;

    // configure shared memory
    CUresult res = cuFuncSetAttribute_f((CUfunction)kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, requested_smem_bytes);
    if (res != CUDA_SUCCESS)
        return false;

    return true;
}
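// Note: CUDA requires an explicit opt-in when a launch needs more than the default 48 KiB
// of dynamic shared memory per block; the attribute set above raises the per-kernel limit,
// up to the device maximum reported by wp_cuda_get_max_shared_memory(). Illustrative sketch
// (hypothetical kernel handle):
//
//     int max_bytes = wp_cuda_get_max_shared_memory(context);
//     if (96 * 1024 <= max_bytes)
//         wp_cuda_configure_kernel_shared_memory(kernel, 96 * 1024);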

void* wp_cuda_get_kernel(void* context, void* module, const char* name)
{
    ContextGuard guard(context);

    CUfunction kernel = NULL;
    if (!check_cu(cuModuleGetFunction_f(&kernel, (CUmodule)module, name)))
    {
        fprintf(stderr, "Warp CUDA error: Failed to lookup kernel function %s in module\n", name);
        return NULL;
    }

    g_kernel_names[kernel] = name;
    return kernel;
}

size_t wp_cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, int block_dim, int shared_memory_bytes, void** args, void* stream)
{
    ContextGuard guard(context);

    if (block_dim <= 0)
    {
#if defined(_DEBUG)
        fprintf(stderr, "Warp warning: Launch got block_dim %d. Setting to 256.\n", block_dim);
#endif
        block_dim = 256;
    }

    // CUDA specs up to compute capability 9.0 say the max x-dim grid is 2**31-1, so
    // grid_dim is fine as an int for the near future
    int grid_dim = (dim + block_dim - 1)/block_dim;

    if (max_blocks <= 0) {
        max_blocks = 2147483647;
    }

    if (grid_dim < 0)
    {
#if defined(_DEBUG)
        fprintf(stderr, "Warp warning: Overflow in grid dimensions detected for %zu total elements and 256 threads "
            "per block.\n Setting block count to %d.\n", dim, max_blocks);
#endif
        grid_dim = max_blocks;
    }
    else
    {
        if (grid_dim > max_blocks)
        {
            grid_dim = max_blocks;
        }
    }

    begin_cuda_range(WP_TIMING_KERNEL, stream, context, get_cuda_kernel_name(kernel));

    CUresult res = cuLaunchKernel_f(
        (CUfunction)kernel,
        grid_dim, 1, 1,
        block_dim, 1, 1,
        shared_memory_bytes,
        static_cast<CUstream>(stream),
        args,
        0);

    check_cu(res);

    end_cuda_range(WP_TIMING_KERNEL, stream);

    return res;
}
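// Worked example of the launch-size arithmetic above: for dim = 1,000,000 elements and the
// default block_dim = 256, grid_dim = (1,000,000 + 255) / 256 = 3907 blocks, with the last
// block covering the remaining 1,000,000 - 3906 * 256 = 64 elements. Illustrative end-to-end
// sketch (hypothetical paths, kernel name and arguments; error handling omitted):
//
//     void* module = wp_cuda_load_module(context, "wp_module.sm_86.cubin");
//     void* kernel = wp_cuda_get_kernel(context, module, "example_kernel");
//     void* args[] = { &arg0, &arg1 };
//     wp_cuda_launch_kernel(context, kernel, 1000000, 0, 256, 0, args, stream);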

void wp_cuda_graphics_map(void* context, void* resource)
{
    ContextGuard guard(context);

    check_cu(cuGraphicsMapResources_f(1, (CUgraphicsResource*)resource, get_current_stream()));
}

void wp_cuda_graphics_unmap(void* context, void* resource)
{
    ContextGuard guard(context);

    check_cu(cuGraphicsUnmapResources_f(1, (CUgraphicsResource*)resource, get_current_stream()));
}

void wp_cuda_graphics_device_ptr_and_size(void* context, void* resource, uint64_t* ptr, size_t* size)
{
    ContextGuard guard(context);

    CUdeviceptr device_ptr;
    size_t bytes;
    check_cu(cuGraphicsResourceGetMappedPointer_f(&device_ptr, &bytes, *(CUgraphicsResource*)resource));

    *ptr = device_ptr;
    *size = bytes;
}

void* wp_cuda_graphics_register_gl_buffer(void* context, uint32_t gl_buffer, unsigned int flags)
{
    ContextGuard guard(context);

    CUgraphicsResource *resource = new CUgraphicsResource;
    bool success = check_cu(cuGraphicsGLRegisterBuffer_f(resource, gl_buffer, flags));
    if (!success)
    {
        delete resource;
        return NULL;
    }

    return resource;
}

void wp_cuda_graphics_unregister_resource(void* context, void* resource)
{
    ContextGuard guard(context);

    CUgraphicsResource *res = (CUgraphicsResource*)resource;
    check_cu(cuGraphicsUnregisterResource_f(*res));
    delete res;
}

void wp_cuda_timing_begin(int flags)
{
    g_cuda_timing_state = new CudaTimingState(flags, g_cuda_timing_state);
}

int wp_cuda_timing_get_result_count()
{
    if (g_cuda_timing_state)
        return int(g_cuda_timing_state->ranges.size());
    return 0;
}

void wp_cuda_timing_end(timing_result_t* results, int size)
{
    if (!g_cuda_timing_state)
        return;

    // number of results to write to the user buffer
    int count = std::min(wp_cuda_timing_get_result_count(), size);

    // compute timings and write results
    for (int i = 0; i < count; i++)
    {
        const CudaTimingRange& range = g_cuda_timing_state->ranges[i];
        timing_result_t& result = results[i];
        result.context = range.context;
        result.name = range.name;
        result.flag = range.flag;
        check_cuda(cudaEventElapsedTime(&result.elapsed, range.start, range.end));
    }

    // release events
    for (CudaTimingRange& range : g_cuda_timing_state->ranges)
    {
        check_cu(cuEventDestroy_f(range.start));
        check_cu(cuEventDestroy_f(range.end));
    }

    // restore previous state
    CudaTimingState* parent_state = g_cuda_timing_state->parent;
    delete g_cuda_timing_state;
    g_cuda_timing_state = parent_state;
}
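// Note: timing states form a stack (each wp_cuda_timing_begin() links to the previous state
// and wp_cuda_timing_end() restores it), so timing sessions can nest. Illustrative usage
// sketch (hypothetical work between begin/end, and assuming the WP_TIMING_* flags combine
// as a bitmask):
//
//     wp_cuda_timing_begin(WP_TIMING_KERNEL | WP_TIMING_GRAPH);
//     // ... launch kernels / graphs; each begin_cuda_range() records an event pair ...
//     int n = wp_cuda_timing_get_result_count();
//     std::vector<timing_result_t> results(n);
//     wp_cuda_timing_end(results.data(), n);  // elapsed times in ms via cudaEventElapsedTime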

//#include "spline.inl"
//#include "volume.inl"