warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +139 -0
- warp/__init__.pyi +1 -0
- warp/autograd.py +1142 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +557 -0
- warp/build_dll.py +405 -0
- warp/builtins.py +6855 -0
- warp/codegen.py +3969 -0
- warp/config.py +158 -0
- warp/constants.py +57 -0
- warp/context.py +6812 -0
- warp/dlpack.py +462 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +194 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +193 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +423 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +178 -0
- warp/examples/fem/example_convection_diffusion_dg.py +204 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +220 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_magnetostatics.py +240 -0
- warp/examples/fem/example_mixed_elasticity.py +291 -0
- warp/examples/fem/example_navier_stokes.py +261 -0
- warp/examples/fem/example_nonconforming_contact.py +298 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +352 -0
- warp/examples/fem/utils.py +1000 -0
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +266 -0
- warp/examples/optim/example_cloth_throw.py +228 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_drone.py +870 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +182 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
- warp/examples/optim/example_softbody_properties.py +400 -0
- warp/examples/optim/example_spring_cage.py +245 -0
- warp/examples/optim/example_trajectory.py +227 -0
- warp/examples/sim/example_cartpole.py +143 -0
- warp/examples/sim/example_cloth.py +225 -0
- warp/examples/sim/example_cloth_self_contact.py +322 -0
- warp/examples/sim/example_granular.py +130 -0
- warp/examples/sim/example_granular_collision_sdf.py +202 -0
- warp/examples/sim/example_jacobian_ik.py +244 -0
- warp/examples/sim/example_particle_chain.py +124 -0
- warp/examples/sim/example_quadruped.py +203 -0
- warp/examples/sim/example_rigid_chain.py +203 -0
- warp/examples/sim/example_rigid_contact.py +195 -0
- warp/examples/sim/example_rigid_force.py +133 -0
- warp/examples/sim/example_rigid_gyroscopic.py +115 -0
- warp/examples/sim/example_rigid_soft_contact.py +140 -0
- warp/examples/sim/example_soft_body.py +196 -0
- warp/examples/tile/example_tile_cholesky.py +87 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mlp.py +383 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/examples/tile/example_tile_walker.py +327 -0
- warp/fabric.py +355 -0
- warp/fem/__init__.py +106 -0
- warp/fem/adaptivity.py +508 -0
- warp/fem/cache.py +572 -0
- warp/fem/dirichlet.py +202 -0
- warp/fem/domain.py +411 -0
- warp/fem/field/__init__.py +125 -0
- warp/fem/field/field.py +619 -0
- warp/fem/field/nodal_field.py +326 -0
- warp/fem/field/restriction.py +37 -0
- warp/fem/field/virtual.py +848 -0
- warp/fem/geometry/__init__.py +32 -0
- warp/fem/geometry/adaptive_nanogrid.py +857 -0
- warp/fem/geometry/closest_point.py +84 -0
- warp/fem/geometry/deformed_geometry.py +221 -0
- warp/fem/geometry/element.py +776 -0
- warp/fem/geometry/geometry.py +362 -0
- warp/fem/geometry/grid_2d.py +392 -0
- warp/fem/geometry/grid_3d.py +452 -0
- warp/fem/geometry/hexmesh.py +911 -0
- warp/fem/geometry/nanogrid.py +571 -0
- warp/fem/geometry/partition.py +389 -0
- warp/fem/geometry/quadmesh.py +663 -0
- warp/fem/geometry/tetmesh.py +855 -0
- warp/fem/geometry/trimesh.py +806 -0
- warp/fem/integrate.py +2335 -0
- warp/fem/linalg.py +419 -0
- warp/fem/operator.py +293 -0
- warp/fem/polynomial.py +229 -0
- warp/fem/quadrature/__init__.py +17 -0
- warp/fem/quadrature/pic_quadrature.py +299 -0
- warp/fem/quadrature/quadrature.py +591 -0
- warp/fem/space/__init__.py +228 -0
- warp/fem/space/basis_function_space.py +468 -0
- warp/fem/space/basis_space.py +667 -0
- warp/fem/space/dof_mapper.py +251 -0
- warp/fem/space/function_space.py +309 -0
- warp/fem/space/grid_2d_function_space.py +177 -0
- warp/fem/space/grid_3d_function_space.py +227 -0
- warp/fem/space/hexmesh_function_space.py +257 -0
- warp/fem/space/nanogrid_function_space.py +201 -0
- warp/fem/space/partition.py +367 -0
- warp/fem/space/quadmesh_function_space.py +223 -0
- warp/fem/space/restriction.py +179 -0
- warp/fem/space/shape/__init__.py +143 -0
- warp/fem/space/shape/cube_shape_function.py +1105 -0
- warp/fem/space/shape/shape_function.py +133 -0
- warp/fem/space/shape/square_shape_function.py +926 -0
- warp/fem/space/shape/tet_shape_function.py +834 -0
- warp/fem/space/shape/triangle_shape_function.py +672 -0
- warp/fem/space/tetmesh_function_space.py +271 -0
- warp/fem/space/topology.py +424 -0
- warp/fem/space/trimesh_function_space.py +194 -0
- warp/fem/types.py +99 -0
- warp/fem/utils.py +420 -0
- warp/jax.py +187 -0
- warp/jax_experimental/__init__.py +16 -0
- warp/jax_experimental/custom_call.py +351 -0
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +244 -0
- warp/native/array.h +1145 -0
- warp/native/builtin.h +1800 -0
- warp/native/bvh.cpp +492 -0
- warp/native/bvh.cu +791 -0
- warp/native/bvh.h +554 -0
- warp/native/clang/clang.cpp +536 -0
- warp/native/coloring.cpp +613 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +362 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +646 -0
- warp/native/cuda_util.h +307 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +1878 -0
- warp/native/fabric.h +245 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +87 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1230 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +339 -0
- warp/native/marching.cpp +19 -0
- warp/native/marching.cu +514 -0
- warp/native/marching.h +19 -0
- warp/native/mat.h +2220 -0
- warp/native/mathdx.cpp +87 -0
- warp/native/matnn.h +343 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +404 -0
- warp/native/mesh.h +1980 -0
- warp/native/nanovdb/GridHandle.h +366 -0
- warp/native/nanovdb/HostBuffer.h +590 -0
- warp/native/nanovdb/NanoVDB.h +6624 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1371 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +139 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +364 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +53 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +277 -0
- warp/native/sort.h +33 -0
- warp/native/sparse.cpp +378 -0
- warp/native/sparse.cu +524 -0
- warp/native/spatial.h +657 -0
- warp/native/svd.h +702 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +2584 -0
- warp/native/tile_reduce.h +264 -0
- warp/native/vec.h +1426 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +67 -0
- warp/native/volume.h +969 -0
- warp/native/volume_builder.cu +477 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1082 -0
- warp/native/warp.cu +3636 -0
- warp/native/warp.h +381 -0
- warp/optim/__init__.py +17 -0
- warp/optim/adam.py +163 -0
- warp/optim/linear.py +1137 -0
- warp/optim/sgd.py +112 -0
- warp/paddle.py +407 -0
- warp/render/__init__.py +18 -0
- warp/render/render_opengl.py +3518 -0
- warp/render/render_usd.py +784 -0
- warp/render/utils.py +160 -0
- warp/sim/__init__.py +65 -0
- warp/sim/articulation.py +793 -0
- warp/sim/collide.py +2395 -0
- warp/sim/graph_coloring.py +300 -0
- warp/sim/import_mjcf.py +790 -0
- warp/sim/import_snu.py +227 -0
- warp/sim/import_urdf.py +579 -0
- warp/sim/import_usd.py +894 -0
- warp/sim/inertia.py +324 -0
- warp/sim/integrator.py +242 -0
- warp/sim/integrator_euler.py +1997 -0
- warp/sim/integrator_featherstone.py +2101 -0
- warp/sim/integrator_vbd.py +2048 -0
- warp/sim/integrator_xpbd.py +3292 -0
- warp/sim/model.py +4791 -0
- warp/sim/particles.py +121 -0
- warp/sim/render.py +427 -0
- warp/sim/utils.py +428 -0
- warp/sparse.py +2057 -0
- warp/stubs.py +3333 -0
- warp/tape.py +1203 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +634 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +200 -0
- warp/tests/geometry/test_hash_grid.py +221 -0
- warp/tests/geometry/test_marching_cubes.py +74 -0
- warp/tests/geometry/test_mesh.py +316 -0
- warp/tests/geometry/test_mesh_query_aabb.py +399 -0
- warp/tests/geometry/test_mesh_query_point.py +932 -0
- warp/tests/geometry/test_mesh_query_ray.py +311 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +729 -0
- warp/tests/interop/test_jax.py +371 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/sim/__init__.py +0 -0
- warp/tests/sim/disabled_kinematics.py +244 -0
- warp/tests/sim/flaky_test_sim_grad.py +290 -0
- warp/tests/sim/test_collision.py +604 -0
- warp/tests/sim/test_coloring.py +258 -0
- warp/tests/sim/test_model.py +224 -0
- warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
- warp/tests/sim/test_sim_kinematics.py +98 -0
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +163 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +2972 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +250 -0
- warp/tests/test_atomic.py +153 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +1298 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +810 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +252 -0
- warp/tests/test_context.py +42 -0
- warp/tests/test_copy.py +238 -0
- warp/tests/test_ctypes.py +638 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +97 -0
- warp/tests/test_examples.py +482 -0
- warp/tests/test_fabricarray.py +996 -0
- warp/tests/test_fast_math.py +74 -0
- warp/tests/test_fem.py +2003 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +454 -0
- warp/tests/test_future_annotations.py +98 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +73 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +193 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_mat.py +2089 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +178 -0
- warp/tests/test_mlp.py +282 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +44 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +339 -0
- warp/tests/test_quat.py +2315 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +302 -0
- warp/tests/test_rounding.py +185 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +105 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +582 -0
- warp/tests/test_spatial.py +2229 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +592 -0
- warp/tests/test_struct.py +734 -0
- warp/tests/test_tape.py +204 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +145 -0
- warp/tests/test_types.py +562 -0
- warp/tests/test_utils.py +588 -0
- warp/tests/test_vec.py +1487 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +780 -0
- warp/tests/tile/test_tile_load.py +407 -0
- warp/tests/tile/test_tile_mathdx.py +208 -0
- warp/tests/tile/test_tile_mlp.py +402 -0
- warp/tests/tile/test_tile_reduce.py +447 -0
- warp/tests/tile/test_tile_shared_memory.py +247 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +427 -0
- warp/tests/unittest_utils.py +468 -0
- warp/tests/walkthrough_debug.py +93 -0
- warp/thirdparty/__init__.py +0 -0
- warp/thirdparty/appdirs.py +598 -0
- warp/thirdparty/dlpack.py +145 -0
- warp/thirdparty/unittest_parallel.py +570 -0
- warp/torch.py +391 -0
- warp/types.py +5230 -0
- warp/utils.py +1137 -0
- warp_lang-1.7.0.dist-info/METADATA +516 -0
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- warp_lang-1.7.0.dist-info/WHEEL +5 -0
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/native/warp.cu
ADDED
|
@@ -0,0 +1,3636 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
16
|
+
*/
|
|
17
|
+
|
|
18
|
+
#include "warp.h"
|
|
19
|
+
#include "scan.h"
|
|
20
|
+
#include "cuda_util.h"
|
|
21
|
+
#include "error.h"
|
|
22
|
+
|
|
23
|
+
#include <cstdlib>
|
|
24
|
+
#include <fstream>
|
|
25
|
+
#include <nvrtc.h>
|
|
26
|
+
#include <nvPTXCompiler.h>
|
|
27
|
+
#if WP_ENABLE_MATHDX
|
|
28
|
+
#include <nvJitLink.h>
|
|
29
|
+
#include <libmathdx.h>
|
|
30
|
+
#endif
|
|
31
|
+
|
|
32
|
+
#include <array>
|
|
33
|
+
#include <algorithm>
|
|
34
|
+
#include <iterator>
|
|
35
|
+
#include <list>
|
|
36
|
+
#include <map>
|
|
37
|
+
#include <string>
|
|
38
|
+
#include <unordered_map>
|
|
39
|
+
#include <unordered_set>
|
|
40
|
+
#include <vector>
|
|
41
|
+
|
|
42
|
+
// Call-site wrappers: each forwards __FILE__/__LINE__ to the matching
// check_*_result() reporter and evaluates to that reporter's bool result
// (true on success).
#define check_any(result) (check_generic(result, __FILE__, __LINE__))
#define check_nvrtc(code) (check_nvrtc_result(code, __FILE__, __LINE__))
#define check_nvptx(code) (check_nvptx_result(code, __FILE__, __LINE__))
#define check_nvjitlink(handle, code) (check_nvjitlink_result(handle, code, __FILE__, __LINE__))
#define check_cufftdx(code) (check_cufftdx_result(code, __FILE__, __LINE__))
#define check_cublasdx(code) (check_cublasdx_result(code, __FILE__, __LINE__))
#define check_cusolver(code) (check_cusolver_result(code, __FILE__, __LINE__))

// Early-return variants: evaluate `code` through the matching check_* macro and,
// on failure, execute `return false` from the *enclosing* function (so they are
// only usable in functions whose return type is constructible from bool).
// The local uses a prefixed name so it cannot capture an identifier that appears
// inside `code` (e.g. a caller variable named `out`).
#define CHECK_ANY(code) \
{ \
    do { \
        bool wp_check_ok = (check_any(code)); \
        if(!wp_check_ok) { \
            return wp_check_ok; \
        } \
    } while(0); \
}
#define CHECK_CUFFTDX(code) \
{ \
    do { \
        bool wp_check_ok = (check_cufftdx(code)); \
        if(!wp_check_ok) { \
            return wp_check_ok; \
        } \
    } while(0); \
}
// Fixed: previously routed through check_cufftdx, reporting cuBLASDx failures
// via the cuFFTDx reporter.
#define CHECK_CUBLASDX(code) \
{ \
    do { \
        bool wp_check_ok = (check_cublasdx(code)); \
        if(!wp_check_ok) { \
            return wp_check_ok; \
        } \
    } while(0); \
}
#define CHECK_CUSOLVER(code) \
{ \
    do { \
        bool wp_check_ok = (check_cusolver(code)); \
        if(!wp_check_ok) { \
            return wp_check_ok; \
        } \
    } while(0); \
}
|
|
85
|
+
|
|
86
|
+
// Report a failed NVRTC call on stderr.
// Returns true when `result` is NVRTC_SUCCESS. Otherwise prints the numeric code,
// the description from nvrtcGetErrorString(), and the call site, then returns false.
bool check_nvrtc_result(nvrtcResult result, const char* file, int line)
{
    const bool succeeded = (result == NVRTC_SUCCESS);
    if (!succeeded)
    {
        fprintf(stderr, "Warp NVRTC compilation error %u: %s (%s:%d)\n",
                unsigned(result), nvrtcGetErrorString(result), file, line);
    }
    return succeeded;
}
|
|
95
|
+
|
|
96
|
+
// Report a failed nvPTXCompiler call on stderr.
// Returns true for NVPTXCOMPILE_SUCCESS. Otherwise maps the result code to a short
// description (codes not listed below print as "Unknown error"), prints it with the
// numeric code and the call site, and returns false.
bool check_nvptx_result(nvPTXCompileResult result, const char* file, int line)
{
    if (result == NVPTXCOMPILE_SUCCESS)
        return true;

    const char* const description = [result]() -> const char*
    {
        switch (result)
        {
        case NVPTXCOMPILE_ERROR_INVALID_COMPILER_HANDLE:
            return "Invalid compiler handle";
        case NVPTXCOMPILE_ERROR_INVALID_INPUT:
            return "Invalid input";
        case NVPTXCOMPILE_ERROR_COMPILATION_FAILURE:
            return "Compilation failure";
        case NVPTXCOMPILE_ERROR_INTERNAL:
            return "Internal error";
        case NVPTXCOMPILE_ERROR_OUT_OF_MEMORY:
            return "Out of memory";
        case NVPTXCOMPILE_ERROR_COMPILER_INVOCATION_INCOMPLETE:
            return "Incomplete compiler invocation";
        case NVPTXCOMPILE_ERROR_UNSUPPORTED_PTX_VERSION:
            return "Unsupported PTX version";
        default:
            return "Unknown error";
        }
    }();

    fprintf(stderr, "Warp PTX compilation error %u: %s (%s:%d)\n", unsigned(result), description, file, line);
    return false;
}
|
|
133
|
+
|
|
134
|
+
// Generic check for APIs that signal failure with a zero result.
// Returns true when `result` is non-zero. Otherwise prints the (zero) value and
// the call site to stderr and returns false.
bool check_generic(int result, const char* file, int line)
{
    if (result)
        return true;

    fprintf(stderr, "Error %d on %s:%d\n", (int)result, file, line);
    return false;
}
|
|
143
|
+
|
|
144
|
+
// Cached properties of one CUDA device; one entry per ordinal in g_devices, filled by cuda_init().
// Fields keep the initializer values below if their query is skipped or fails.
struct DeviceInfo
{
    // capacity of `name`; passed as the buffer length to cuDeviceGetName
    static constexpr int kNameLen = 128;

    CUdevice device = -1;             // driver handle from cuDeviceGet
    CUuuid uuid = {0};                // from cuDeviceGetUuid
    int ordinal = -1;                 // index of this entry in g_devices
    int pci_domain_id = -1;           // CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID
    int pci_bus_id = -1;              // CU_DEVICE_ATTRIBUTE_PCI_BUS_ID
    int pci_device_id = -1;           // CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID
    char name[kNameLen] = "";         // from cuDeviceGetName
    int arch = 0;                     // compute capability encoded as 10 * major + minor
    int is_uva = 0;                   // CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING
    int is_mempool_supported = 0;     // CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED
    // CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED, forced to 0 on integrated devices;
    // stays -1 (not queried) when built without CUDA_VERSION >= 12000
    int is_ipc_supported = -1;
    int max_smem_bytes = 0;           // CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN
    // NOTE(review): not set by cuda_init(); presumably filled lazily elsewhere — confirm
    CUcontext primary_context = NULL;
};
|
|
162
|
+
|
|
163
|
+
// Per-context state cached in g_contexts; entries are created lazily by get_context_info().
struct ContextInfo
{
    // device owning this context; points at an entry of g_devices (via g_device_map)
    DeviceInfo* device_info = NULL;

    // the current stream, managed from Python (see cuda_context_set_stream() and cuda_context_get_stream())
    CUstream stream = NULL;
};
|
|
170
|
+
|
|
171
|
+
// Bookkeeping for a graph capture; g_captures maps the CUDA capture id to one of these.
struct CaptureInfo
{
    CUstream stream = NULL; // the main stream where capture begins and ends
    uint64_t id = 0; // unique capture id from CUDA
    bool external = false; // whether this is an external capture
};
|
|
177
|
+
|
|
178
|
+
// Per-stream state cached in g_streams (Warp-created and registered external streams).
struct StreamInfo
{
    CUevent cached_event = NULL; // event used for stream synchronization (cached to avoid creating temporary events)
    CaptureInfo* capture = NULL; // capture info (only if started on this stream)
};
|
|
183
|
+
|
|
184
|
+
// Per-graph bookkeeping.
struct GraphInfo
{
    // NOTE(review): presumably graph allocations that were not freed inside the graph
    // (see GraphAllocInfo); confirm against the code that populates this list
    std::vector<void*> unfreed_allocs;
};
|
|
188
|
+
|
|
189
|
+
// Information for graph allocations that are not freed by the graph.
// These allocations have a shared ownership:
// - The graph instance allocates/maps the memory on each launch, even if the user reference is released.
// - The user reference must remain valid even if the graph is destroyed.
// The memory will be freed once the user reference is released and the graph is destroyed.
// Entries live in g_graph_allocs, keyed by the allocated pointer.
struct GraphAllocInfo
{
    uint64_t capture_id = 0; // presumably the id of the capture that made the allocation — confirm
    void* context = NULL;    // presumably the CUDA context owning the allocation — confirm
    bool ref_exists = false; // whether user reference still exists
    bool graph_destroyed = false; // whether graph instance was destroyed
};
|
|
201
|
+
|
|
202
|
+
// Information used when deferring deallocations.
// Entries are queued by deferred_free() into g_deferred_free_list.
struct FreeInfo
{
    void* context = NULL;  // context made current (via ContextGuard) when the entry is freed
    void* ptr = NULL;      // pointer to release
    bool is_async = false; // NOTE(review): presumably selects the stream-ordered free path — confirm in free_deferred_allocs()
};
|
|
209
|
+
|
|
210
|
+
// Information used when deferring module unloading.
// Entries are queued in g_deferred_module_list.
struct ModuleInfo
{
    void* context = NULL; // presumably the context the module was loaded in — confirm
    void* module = NULL;  // opaque module handle to unload
};
|
|
216
|
+
|
|
217
|
+
// maps CUfunction handles to kernel names
// NOTE(review): presumably used for diagnostics; confirm at the points of use
static std::unordered_map<CUfunction, std::string> g_kernel_names;

// cached info for all devices, indexed by ordinal
static std::vector<DeviceInfo> g_devices;

// maps CUdevice to DeviceInfo
// The stored pointers point into g_devices (assigned in cuda_init()), so g_devices must not
// be resized or reallocated while these entries are in use.
static std::map<CUdevice, DeviceInfo*> g_device_map;

// cached info for all known contexts
// (std::map, so pointers to entries stay valid across insertions)
static std::map<CUcontext, ContextInfo> g_contexts;

// cached info for all known streams (including registered external streams)
static std::unordered_map<CUstream, StreamInfo> g_streams;

// Ongoing graph captures registered using wp.capture_begin().
// This maps the capture id to the stream where capture was started.
// See cuda_graph_begin_capture(), cuda_graph_end_capture(), and free_device_async().
static std::unordered_map<uint64_t, CaptureInfo*> g_captures;

// Memory allocated during graph capture requires special handling.
// See alloc_device_async() and free_device_async().
static std::unordered_map<void*, GraphAllocInfo> g_graph_allocs;

// Memory that cannot be freed immediately gets queued here.
// Call free_deferred_allocs() to release.
static std::vector<FreeInfo> g_deferred_free_list;

// Modules that cannot be unloaded immediately get queued here.
// Call unload_deferred_modules() to release.
static std::vector<ModuleInfo> g_deferred_module_list;
|
|
247
|
+
|
|
248
|
+
// Set the process-wide ContextGuard restore policy by storing `always_restore`
// in ContextGuard::always_restore.
void cuda_set_context_restore_policy(bool always_restore)
{
    ContextGuard::always_restore = always_restore;
}
|
|
252
|
+
|
|
253
|
+
// Return the current ContextGuard restore policy (ContextGuard::always_restore) as an int,
// suitable for the C API.
int cuda_get_context_restore_policy()
{
    return int(ContextGuard::always_restore);
}
|
|
257
|
+
|
|
258
|
+
int cuda_init()
|
|
259
|
+
{
|
|
260
|
+
if (!init_cuda_driver())
|
|
261
|
+
return -1;
|
|
262
|
+
|
|
263
|
+
int device_count = 0;
|
|
264
|
+
if (check_cu(cuDeviceGetCount_f(&device_count)))
|
|
265
|
+
{
|
|
266
|
+
g_devices.resize(device_count);
|
|
267
|
+
|
|
268
|
+
for (int i = 0; i < device_count; i++)
|
|
269
|
+
{
|
|
270
|
+
CUdevice device;
|
|
271
|
+
if (check_cu(cuDeviceGet_f(&device, i)))
|
|
272
|
+
{
|
|
273
|
+
// query device info
|
|
274
|
+
g_devices[i].device = device;
|
|
275
|
+
g_devices[i].ordinal = i;
|
|
276
|
+
check_cu(cuDeviceGetName_f(g_devices[i].name, DeviceInfo::kNameLen, device));
|
|
277
|
+
check_cu(cuDeviceGetUuid_f(&g_devices[i].uuid, device));
|
|
278
|
+
check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_domain_id, CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID, device));
|
|
279
|
+
check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_bus_id, CU_DEVICE_ATTRIBUTE_PCI_BUS_ID, device));
|
|
280
|
+
check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_device_id, CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, device));
|
|
281
|
+
check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_uva, CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, device));
|
|
282
|
+
check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_mempool_supported, CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, device));
|
|
283
|
+
#ifdef CUDA_VERSION
|
|
284
|
+
#if CUDA_VERSION >= 12000
|
|
285
|
+
int device_attribute_integrated = 0;
|
|
286
|
+
check_cu(cuDeviceGetAttribute_f(&device_attribute_integrated, CU_DEVICE_ATTRIBUTE_INTEGRATED, device));
|
|
287
|
+
if (device_attribute_integrated == 0)
|
|
288
|
+
{
|
|
289
|
+
check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_ipc_supported, CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED, device));
|
|
290
|
+
}
|
|
291
|
+
else
|
|
292
|
+
{
|
|
293
|
+
// integrated devices do not support CUDA IPC
|
|
294
|
+
g_devices[i].is_ipc_supported = 0;
|
|
295
|
+
}
|
|
296
|
+
#endif
|
|
297
|
+
#endif
|
|
298
|
+
check_cu(cuDeviceGetAttribute_f(&g_devices[i].max_smem_bytes, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device));
|
|
299
|
+
int major = 0;
|
|
300
|
+
int minor = 0;
|
|
301
|
+
check_cu(cuDeviceGetAttribute_f(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device));
|
|
302
|
+
check_cu(cuDeviceGetAttribute_f(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device));
|
|
303
|
+
g_devices[i].arch = 10 * major + minor;
|
|
304
|
+
|
|
305
|
+
g_device_map[device] = &g_devices[i];
|
|
306
|
+
}
|
|
307
|
+
else
|
|
308
|
+
{
|
|
309
|
+
return -1;
|
|
310
|
+
}
|
|
311
|
+
}
|
|
312
|
+
}
|
|
313
|
+
else
|
|
314
|
+
{
|
|
315
|
+
return -1;
|
|
316
|
+
}
|
|
317
|
+
|
|
318
|
+
// initialize default timing state
|
|
319
|
+
static CudaTimingState default_timing_state(0, NULL);
|
|
320
|
+
g_cuda_timing_state = &default_timing_state;
|
|
321
|
+
|
|
322
|
+
return 0;
|
|
323
|
+
}
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
// Return the CUDA context current on the calling thread, or NULL when the
// cuCtxGetCurrent query fails (the failure is reported by check_cu()).
static inline CUcontext get_current_context()
{
    CUcontext current = NULL;
    if (!check_cu(cuCtxGetCurrent_f(&current)))
        return NULL;
    return current;
}
|
|
334
|
+
|
|
335
|
+
// Return the stream tracked for `context` by cuda_context_get_stream(), as a CUstream.
// NOTE(review): handling of a NULL `context` is defined by cuda_context_get_stream() — confirm.
static inline CUstream get_current_stream(void* context=NULL)
{
    auto stream = cuda_context_get_stream(context);
    return static_cast<CUstream>(stream);
}
|
|
339
|
+
|
|
340
|
+
// Look up, or lazily create, the cached ContextInfo for `ctx`.
// A NULL `ctx` means the context current on the calling thread.
// Returns NULL if:
//   - there is no current context,
//   - the context's device cannot be queried, or
//   - that device is not present in g_device_map.
// The returned pointer refers to an entry of g_contexts (a std::map), so it stays
// valid across later insertions.
static ContextInfo* get_context_info(CUcontext ctx)
{
    if (!ctx)
    {
        ctx = get_current_context();
        if (!ctx)
            return NULL;
    }

    auto it = g_contexts.find(ctx);
    if (it != g_contexts.end())
        return &it->second;

    // previously unseen context, add the info
    ContextGuard guard(ctx, true);

    CUdevice device;
    if (!check_cu(cuCtxGetDevice_f(&device)))
        return NULL;

    // Use find() instead of operator[]: operator[] would insert a NULL entry for an
    // unknown device, and the dereference below would then crash.
    auto dev_it = g_device_map.find(device);
    if (dev_it == g_device_map.end() || !dev_it->second)
    {
        fprintf(stderr, "Warp error: CUDA context belongs to a device that was not enumerated by cuda_init() (%s:%d)\n", __FILE__, __LINE__);
        return NULL;
    }
    DeviceInfo* device_info = dev_it->second;

    // workaround for https://nvbugspro.nvidia.com/bug/4456003
    if (device_info->is_mempool_supported)
    {
        void* dummy = NULL;
        check_cuda(cudaMallocAsync(&dummy, 1, NULL));
        check_cuda(cudaFreeAsync(dummy, NULL));
    }

    ContextInfo context_info;
    context_info.device_info = device_info;
    auto result = g_contexts.insert(std::make_pair(ctx, context_info));
    return &result.first->second;
}
|
|
381
|
+
|
|
382
|
+
// Overload for an opaque context handle; forwards to get_context_info(CUcontext).
static inline ContextInfo* get_context_info(void* context)
{
    CUcontext cu_context = static_cast<CUcontext>(context);
    return get_context_info(cu_context);
}
|
|
386
|
+
|
|
387
|
+
// Returns the StreamInfo record registered in g_streams for `stream`,
// or NULL if the stream has no record.
static inline StreamInfo* get_stream_info(CUstream stream)
{
    auto found = g_streams.find(stream);
    if (found == g_streams.end())
        return NULL;
    return &found->second;
}
|
|
395
|
+
|
|
396
|
+
// Queues `ptr` for later release by free_deferred_allocs().
// A NULL context is replaced by the context current on the calling thread.
// is_async selects cudaFreeAsync (true) or cudaFree (false) when the entry is released.
static void deferred_free(void* ptr, void* context, bool is_async)
{
    FreeInfo entry;
    entry.ptr = ptr;
    if (context)
        entry.context = context;
    else
        entry.context = get_current_context();
    entry.is_async = is_async;

    g_deferred_free_list.push_back(entry);
}
|
|
404
|
+
|
|
405
|
+
// Releases queued deferred frees that belong to `context`; a NULL context matches every entry.
// Does nothing while any graph capture is tracked in g_captures.
// Returns the number of pointers released.
static int free_deferred_allocs(void* context = NULL)
{
    if (!g_captures.empty() || g_deferred_free_list.empty())
        return 0;

    int released = 0;
    auto it = g_deferred_free_list.begin();
    while (it != g_deferred_free_list.end())
    {
        // leave entries of other contexts in the queue when a specific context was requested
        if (context && it->context != context)
        {
            ++it;
            continue;
        }

        ContextGuard guard(it->context);

        if (!it->is_async)
        {
            check_cuda(cudaFree(it->ptr));
        }
        else
        {
            // this could be a regular stream-ordered allocation or a graph allocation
            cudaError_t status = cudaFreeAsync(it->ptr, NULL);
            if (status == cudaErrorInvalidValue)
            {
                // This can happen if we try to release the pointer but the graph was
                // never launched, so the memory isn't mapped.
                // This is fine, so clear the error.
                cudaGetLastError();
            }
            else if (status != cudaSuccess)
            {
                // something else went wrong, report error
                check_cuda(status);
            }
        }

        ++released;
        it = g_deferred_free_list.erase(it);
    }

    return released;
}
|
|
457
|
+
|
|
458
|
+
// Unloads queued modules that belong to `context`; a NULL context matches every entry.
// Does nothing while any graph capture is tracked in g_captures.
// Returns the number of modules unloaded.
static int unload_deferred_modules(void* context = NULL)
{
    if (!g_captures.empty() || g_deferred_module_list.empty())
        return 0;

    int unloaded = 0;
    auto it = g_deferred_module_list.begin();
    while (it != g_deferred_module_list.end())
    {
        // keep modules of other contexts queued when a specific context was requested
        if (context && it->context != context)
        {
            ++it;
            continue;
        }

        cuda_unload_module(it->context, it->module);
        ++unloaded;
        it = g_deferred_module_list.erase(it);
    }

    return unloaded;
}
|
|
482
|
+
|
|
483
|
+
// Graph-destruction callback; user_data is expected to point to a heap-allocated GraphInfo.
// Each allocation the graph still owns is handled in one of two ways:
//   - if a user reference exists, the allocation is marked graph_destroyed;
//   - otherwise it is queued for deferred release, because CUDA functions cannot be
//     called from this callback.
// The GraphInfo is deleted at the end.
static void CUDART_CB on_graph_destroy(void* user_data)
{
    GraphInfo* graph_info = static_cast<GraphInfo*>(user_data);
    if (graph_info == NULL)
        return;

    for (void* ptr : graph_info->unfreed_allocs)
    {
        auto alloc_iter = g_graph_allocs.find(ptr);
        if (alloc_iter == g_graph_allocs.end())
            continue;

        GraphAllocInfo& alloc_info = alloc_iter->second;
        if (!alloc_info.ref_exists)
        {
            // the pointer can be freed, but we can't call CUDA functions in this callback, so defer it
            deferred_free(ptr, alloc_info.context, true);
            g_graph_allocs.erase(alloc_iter);
        }
        else
        {
            // unreference from graph so the pointer will be deallocated when the user reference goes away
            alloc_info.graph_destroyed = true;
        }
    }

    delete graph_info;
}
|
|
512
|
+
|
|
513
|
+
// Returns the name registered in g_kernel_names for this CUfunction handle,
// or "unknown_kernel" if none was registered.
static inline const char* get_cuda_kernel_name(void* kernel)
{
    auto found = g_kernel_names.find(static_cast<CUfunction>(kernel));
    if (found == g_kernel_names.end())
        return "unknown_kernel";
    return found->second.c_str();
}
|
|
522
|
+
|
|
523
|
+
|
|
524
|
+
// Allocates s bytes of page-locked host memory with cudaMallocHost.
// Failures are reported through check_cuda.
void* alloc_pinned(size_t s)
{
    void* host_ptr = NULL;
    check_cuda(cudaMallocHost(&host_ptr, s));
    return host_ptr;
}
|
|
530
|
+
|
|
531
|
+
// Releases page-locked host memory (e.g. from alloc_pinned) with cudaFreeHost.
// The result is passed through check_cuda, consistent with the other allocation helpers,
// so that failures are reported rather than silently dropped.
void free_pinned(void* ptr)
{
    check_cuda(cudaFreeHost(ptr));
}
|
|
535
|
+
|
|
536
|
+
// Allocates s bytes of device memory in `context`.
// Uses the stream-ordered allocator when the context's device supports mempools,
// and the default allocator otherwise.
void* alloc_device(void* context, size_t s)
{
    const int ordinal = cuda_context_get_device_ordinal(context);

    if (!cuda_device_is_mempool_supported(ordinal))
        return alloc_device_default(context, s);

    // use stream-ordered allocator if available
    return alloc_device_async(context, s);
}
|
|
546
|
+
|
|
547
|
+
// Frees device memory in `context`.
// Uses the stream-ordered path when the context's device supports mempools,
// and the default path otherwise (mirrors alloc_device).
void free_device(void* context, void* ptr)
{
    const int ordinal = cuda_context_get_device_ordinal(context);

    if (!cuda_device_is_mempool_supported(ordinal))
    {
        free_device_default(context, ptr);
        return;
    }

    // use stream-ordered allocator if available
    free_device_async(context, ptr);
}
|
|
557
|
+
|
|
558
|
+
// Allocates s bytes with cudaMalloc while `context` is bound to the calling thread.
void* alloc_device_default(void* context, size_t s)
{
    ContextGuard guard(context);

    void* device_ptr = NULL;
    check_cuda(cudaMalloc(&device_ptr, s));
    return device_ptr;
}
|
|
567
|
+
|
|
568
|
+
// Frees memory from alloc_device_default with cudaFree.
// While any graph capture is in progress, the free is queued via deferred_free instead.
void free_device_default(void* context, void* ptr)
{
    ContextGuard guard(context);

    if (!g_captures.empty())
    {
        // we must defer the operation until graph captures complete
        deferred_free(ptr, context, false);
        return;
    }

    check_cuda(cudaFree(ptr));
}
|
|
583
|
+
|
|
584
|
+
// Allocates s bytes with cudaMallocAsync on the context's stream.
// Returns NULL if no ContextInfo is available for the context.
// If the stream is capturing and the capture is tracked in g_captures, the allocation is
// also recorded in g_graph_allocs so that free_device_async / on_graph_destroy can manage it.
void* alloc_device_async(void* context, size_t s)
{
    // stream-ordered allocations don't rely on the current context,
    // but we set the context here for consistent behaviour
    ContextGuard guard(context);

    ContextInfo* info = get_context_info(context);
    if (info == NULL)
        return NULL;

    CUstream stream = info->stream;

    void* ptr = NULL;
    check_cuda(cudaMallocAsync(&ptr, s, stream));

    // nothing more to do unless the allocation succeeded on a capturing stream
    if (!ptr || !cuda_stream_is_capturing(stream))
        return ptr;

    // only allocations made during a known capture are recorded
    uint64_t capture_id = get_capture_id(stream);
    if (g_captures.find(capture_id) != g_captures.end())
    {
        GraphAllocInfo record;
        record.capture_id = capture_id;
        record.context = context ? context : get_current_context();
        record.ref_exists = true;       // user reference created and returned here
        record.graph_destroyed = false; // graph not destroyed yet
        g_graph_allocs[ptr] = record;
    }

    return ptr;
}
|
|
622
|
+
|
|
623
|
+
// Frees memory obtained from alloc_device_async.
// Regular stream-ordered allocations and allocations made during graph capture are handled
// differently; see the comments on each path below.
void free_device_async(void* context, void* ptr)
{
    // stream-ordered allocators generally don't rely on the current context,
    // but we set the context here for consistent behaviour
    ContextGuard guard(context);

    // NB: Stream-ordered deallocations are tricky, because the memory could still be used on another stream
    // or even multiple streams. To avoid use-after-free errors, we need to ensure that all preceding work
    // completes before releasing the memory. The strategy is different for regular stream-ordered allocations
    // and allocations made during graph capture. See below for details.

    auto alloc_iter = g_graph_allocs.find(ptr);

    // Path 1: not a graph allocation.
    if (alloc_iter == g_graph_allocs.end())
    {
        if (!g_captures.empty())
        {
            // We must defer the free operation until graph capture completes.
            deferred_free(ptr, context, true);
            return;
        }

        // cudaFreeAsync on the null stream does not block or trigger synchronization, but it postpones
        // the deallocation until a synchronization point is reached, so preceding work on this pointer
        // should safely complete.
        check_cuda(cudaFreeAsync(ptr, NULL));
        return;
    }

    // Path 2: the allocation was made during graph capture.
    GraphAllocInfo& alloc_info = alloc_iter->second;
    auto capture_iter = g_captures.find(alloc_info.capture_id);

    if (capture_iter != g_captures.end())
    {
        // The capture is still active.
        // Add a mem free node. Use all current leaf nodes as dependencies to ensure that all prior
        // work completes before deallocating. This works with both Warp-initiated and external captures
        // and avoids the need to explicitly track all streams used during the capture.
        CaptureInfo* capture = capture_iter->second;
        cudaGraph_t graph = get_capture_graph(capture->stream);
        std::vector<cudaGraphNode_t> leaf_nodes;
        if (graph && get_graph_leaf_nodes(graph, leaf_nodes))
        {
            cudaGraphNode_t free_node;
            check_cuda(cudaGraphAddMemFreeNode(&free_node, graph, leaf_nodes.data(), leaf_nodes.size(), ptr));
        }

        // we're done with this allocation, it's owned by the graph
        g_graph_allocs.erase(alloc_iter);
        return;
    }

    // The capture has ended.
    if (!alloc_info.graph_destroyed)
    {
        // graph still exists
        // unreference the pointer so it will be deallocated once the graph instance is destroyed
        alloc_info.ref_exists = false;
        return;
    }

    // The owning graph was already destroyed, so the pointer can be freed now.
    if (!g_captures.empty())
    {
        // We must defer the operation until graph capture completes.
        deferred_free(ptr, context, true);
    }
    else
    {
        // try to free the pointer now
        cudaError_t res = cudaFreeAsync(ptr, NULL);
        if (res == cudaErrorInvalidValue)
        {
            // This can happen if we try to release the pointer but the graph was
            // never launched, so the memory isn't mapped.
            // This is fine, so clear the error.
            cudaGetLastError();
        }
        else
        {
            // check for other errors
            check_cuda(res);
        }
    }

    // we're done with this allocation
    g_graph_allocs.erase(alloc_iter);
}
|
|
720
|
+
|
|
721
|
+
// Enqueues a host-to-device copy of n bytes with cudaMemcpyAsync.
// WP_CURRENT_STREAM selects the current stream of `context`.
// Returns the check_cuda result of the enqueue.
bool memcpy_h2d(void* context, void* dest, void* src, size_t n, void* stream)
{
    ContextGuard guard(context);

    CUstream cuda_stream = (stream == WP_CURRENT_STREAM)
        ? get_current_stream(context)
        : static_cast<CUstream>(stream);

    begin_cuda_range(WP_TIMING_MEMCPY, cuda_stream, context, "memcpy HtoD");
    const bool ok = check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyHostToDevice, cuda_stream));
    end_cuda_range(WP_TIMING_MEMCPY, cuda_stream);

    return ok;
}
|
|
739
|
+
|
|
740
|
+
// Enqueues a device-to-host copy of n bytes with cudaMemcpyAsync.
// WP_CURRENT_STREAM selects the current stream of `context`.
// Returns the check_cuda result of the enqueue.
bool memcpy_d2h(void* context, void* dest, void* src, size_t n, void* stream)
{
    ContextGuard guard(context);

    CUstream cuda_stream = (stream == WP_CURRENT_STREAM)
        ? get_current_stream(context)
        : static_cast<CUstream>(stream);

    begin_cuda_range(WP_TIMING_MEMCPY, cuda_stream, context, "memcpy DtoH");
    const bool ok = check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyDeviceToHost, cuda_stream));
    end_cuda_range(WP_TIMING_MEMCPY, cuda_stream);

    return ok;
}
|
|
758
|
+
|
|
759
|
+
// Enqueues a device-to-device copy of n bytes with cudaMemcpyAsync.
// WP_CURRENT_STREAM selects the current stream of `context`.
// Returns the check_cuda result of the enqueue.
bool memcpy_d2d(void* context, void* dest, void* src, size_t n, void* stream)
{
    ContextGuard guard(context);

    CUstream cuda_stream = (stream == WP_CURRENT_STREAM)
        ? get_current_stream(context)
        : static_cast<CUstream>(stream);

    begin_cuda_range(WP_TIMING_MEMCPY, cuda_stream, context, "memcpy DtoD");
    const bool ok = check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyDeviceToDevice, cuda_stream));
    end_cuda_range(WP_TIMING_MEMCPY, cuda_stream);

    return ok;
}
|
|
777
|
+
|
|
778
|
+
// Copies n bytes between allocations that may belong to different contexts/devices.
// WP_CURRENT_STREAM selects the current stream of dst_context.
// Returns true if the copy was enqueued successfully.
bool memcpy_p2p(void* dst_context, void* dst, void* src_context, void* src, size_t n, void* stream)
{
    // No function-scope ContextGuard:
    //   - cuMemcpyPeerAsync receives both contexts as arguments;
    //   - the capture path binds the destination/source context explicitly.

    CUstream cuda_stream;
    if (stream != WP_CURRENT_STREAM)
        cuda_stream = static_cast<CUstream>(stream);
    else
        cuda_stream = get_current_stream(dst_context);

    // Notes:
    // - cuMemcpyPeerAsync() works fine with both regular and pooled allocations (cudaMalloc() and cudaMallocAsync(), respectively)
    //   when not capturing a graph.
    // - cuMemcpyPeerAsync() is not supported during graph capture, so we must use cudaMemcpyAsync() with kind=cudaMemcpyDefault.
    // - cudaMemcpyAsync() works fine with regular allocations, but doesn't work with pooled allocations
    //   unless mempool access has been enabled.
    // - There is no reliable way to check if mempool access is enabled during graph capture,
    //   because cudaMemPoolGetAccess() cannot be called during graph capture.
    // - CUDA will report error 1 (invalid argument) if cudaMemcpyAsync() is called but mempool access is not enabled.

    // Query the resolved stream, not the raw argument. The argument may be the WP_CURRENT_STREAM
    // selector rather than the stream the copy is actually enqueued on.
    if (!cuda_stream_is_capturing(cuda_stream))
    {
        begin_cuda_range(WP_TIMING_MEMCPY, cuda_stream, get_stream_context(cuda_stream), "memcpy PtoP");

        bool result = check_cu(cuMemcpyPeerAsync_f(
            (CUdeviceptr)dst, (CUcontext)dst_context,
            (CUdeviceptr)src, (CUcontext)src_context,
            n, cuda_stream));

        end_cuda_range(WP_TIMING_MEMCPY, cuda_stream);

        return result;
    }
    else
    {
        cudaError_t result = cudaSuccess;

        // cudaMemcpyAsync() is sensitive to the bound context to resolve pointer locations.
        // It fails with cudaErrorInvalidValue if it cannot resolve an argument.
        // We first try the copy in the destination context, then if it fails we retry in the source context.
        // The cudaErrorInvalidValue error doesn't cause graph capture to fail, so it's ok to retry.
        // Since this trial-and-error shenanigans only happens during capture, there
        // is no perf impact when the graph is launched.
        // For bonus points, this approach simplifies memory pool access requirements.
        // Access only needs to be enabled one way, either from the source device to the destination device
        // or vice versa.
        {
            // try doing the copy in the destination context
            ContextGuard dst_guard(dst_context);
            result = cudaMemcpyAsync(dst, src, n, cudaMemcpyDefault, cuda_stream);

            if (result != cudaSuccess)
            {
                // clear error in destination context
                cudaGetLastError();

                // try doing the copy in the source context
                ContextGuard src_guard(src_context);
                result = cudaMemcpyAsync(dst, src, n, cudaMemcpyDefault, cuda_stream);

                // clear error in source context
                cudaGetLastError();
            }
        }

        // If the copy failed, try to detect if mempool allocations are involved to generate a helpful error message.
        if (!check_cuda(result))
        {
            if (result == cudaErrorInvalidValue && src != NULL && dst != NULL)
            {
                // check if either of the pointers was allocated from a mempool
                void* src_mempool = NULL;
                void* dst_mempool = NULL;
                cuPointerGetAttribute_f(&src_mempool, CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE, (CUdeviceptr)src);
                cuPointerGetAttribute_f(&dst_mempool, CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE, (CUdeviceptr)dst);
                cudaGetLastError(); // clear any errors
                // check if either of the pointers was allocated during graph capture
                auto src_alloc = g_graph_allocs.find(src);
                auto dst_alloc = g_graph_allocs.find(dst);
                if (src_mempool != NULL || src_alloc != g_graph_allocs.end() ||
                    dst_mempool != NULL || dst_alloc != g_graph_allocs.end())
                {
                    wp::append_error_string("*** CUDA mempool allocations were used in a peer-to-peer copy during graph capture.");
                    wp::append_error_string("*** This operation fails if mempool access is not enabled between the peer devices.");
                    wp::append_error_string("*** Either enable mempool access between the devices or use the default CUDA allocator");
                    wp::append_error_string("*** to pre-allocate the arrays before graph capture begins.");
                }
            }

            return false;
        }

        return true;
    }
}
|
|
873
|
+
|
|
874
|
+
|
|
875
|
+
// Writes the 32-bit `value` into each of the first n ints of dest; one thread per element.
__global__ void memset_kernel(int* dest, int value, size_t n)
{
    const size_t idx = static_cast<size_t>(blockIdx.x) * static_cast<size_t>(blockDim.x)
                     + static_cast<size_t>(threadIdx.x);
    if (idx >= n)
        return;

    dest[idx] = value;
}
|
|
884
|
+
|
|
885
|
+
// Sets n bytes at device pointer `dest` to the byte value `value` (cudaMemsetAsync semantics).
// The operation is enqueued on the current stream of `context`, like the memcpy_* helpers.
void memset_device(void* context, void* dest, int value, size_t n)
{
    ContextGuard guard(context);

    cudaStream_t stream = get_current_stream(context);

    begin_cuda_range(WP_TIMING_MEMSET, stream, context, "memset");

    // cudaMemsetAsync handles any length and alignment. It writes `value` as a byte pattern,
    // which is the contract of this function.
    check_cuda(cudaMemsetAsync(dest, value, n, stream));

    end_cuda_range(WP_TIMING_MEMSET, stream);
}
|
|
907
|
+
|
|
908
|
+
// fill memory buffer with a value: generic memtile kernel using memcpy for each element
// Thread t copies the srcsize-byte pattern at `src` to byte offset t * srcsize of `dst`.
__global__ void memtile_kernel(void* dst, const void* src, size_t srcsize, size_t n)
{
    const size_t idx = static_cast<size_t>(blockIdx.x) * static_cast<size_t>(blockDim.x)
                     + static_cast<size_t>(threadIdx.x);
    if (idx >= n)
        return;

    int8_t* out = static_cast<int8_t*>(dst) + idx * srcsize;
    memcpy(out, src, srcsize);
}
|
|
917
|
+
|
|
918
|
+
// Typed fill: writes `value` to each of the first n elements of dst.
// This should be faster than memtile_kernel, but dst must be properly aligned for T.
template <typename T>
__global__ void memtile_value_kernel(T* dst, T value, size_t n)
{
    const size_t idx = static_cast<size_t>(blockIdx.x) * static_cast<size_t>(blockDim.x)
                     + static_cast<size_t>(threadIdx.x);
    if (idx < n)
        dst[idx] = value;
}
|
|
928
|
+
|
|
929
|
+
// Fills n consecutive elements of srcsize bytes at device pointer `dst` with the value that
// `src` points to. `src` is read directly on the host in the typed paths and uploaded with
// cudaMemcpyHostToDevice in the generic path. All paths enqueue their work on the current stream.
void memtile_device(void* context, void* dst, const void* src, size_t srcsize, size_t n)
{
    ContextGuard guard(context);

    size_t dst_addr = reinterpret_cast<size_t>(dst);
    size_t src_addr = reinterpret_cast<size_t>(src);

    // try memtile_value first because it should be faster, but we need to ensure proper alignment
    if (srcsize == 8 && (dst_addr & 7) == 0 && (src_addr & 7) == 0)
    {
        int64_t* p = reinterpret_cast<int64_t*>(dst);
        int64_t value = *reinterpret_cast<const int64_t*>(src);
        wp_launch_device(WP_CURRENT_CONTEXT, memtile_value_kernel, n, (p, value, n));
    }
    else if (srcsize == 4 && (dst_addr & 3) == 0 && (src_addr & 3) == 0)
    {
        int32_t* p = reinterpret_cast<int32_t*>(dst);
        int32_t value = *reinterpret_cast<const int32_t*>(src);
        wp_launch_device(WP_CURRENT_CONTEXT, memtile_value_kernel, n, (p, value, n));
    }
    else if (srcsize == 2 && (dst_addr & 1) == 0 && (src_addr & 1) == 0)
    {
        int16_t* p = reinterpret_cast<int16_t*>(dst);
        int16_t value = *reinterpret_cast<const int16_t*>(src);
        wp_launch_device(WP_CURRENT_CONTEXT, memtile_value_kernel, n, (p, value, n));
    }
    else if (srcsize == 1)
    {
        // Use the stream-ordered memset on the current stream, like the other paths.
        // The synchronous cudaMemset runs on the legacy default stream: it is not ordered with
        // work on the current stream and is not permitted during stream capture.
        check_cuda(cudaMemsetAsync(dst, *reinterpret_cast<const int8_t*>(src), n, get_current_stream()));
    }
    else
    {
        // generic version

        // copy value to device memory
        // TODO: use a persistent stream-local staging buffer to avoid allocs?
        void* src_devptr = alloc_device(WP_CURRENT_CONTEXT, srcsize);
        if (!src_devptr)
            return; // allocation failure was already reported; don't launch with a NULL source

        check_cuda(cudaMemcpyAsync(src_devptr, src, srcsize, cudaMemcpyHostToDevice, get_current_stream()));

        wp_launch_device(WP_CURRENT_CONTEXT, memtile_kernel, n, (dst, src_devptr, srcsize, n));

        free_device(WP_CURRENT_CONTEXT, src_devptr);
    }
}
|
|
974
|
+
|
|
975
|
+
|
|
976
|
+
// Copies n elements of elem_size bytes. Element i is read at byte offset src_idx * src_stride
// and written at byte offset dst_idx * dst_stride. src_idx and dst_idx come from the optional
// index arrays and default to i.
// Offsets are computed in 64 bits: int * int overflows (UB) once an offset reaches 2 GiB.
static __global__ void array_copy_1d_kernel(void* dst, const void* src,
                                            int dst_stride, int src_stride,
                                            const int* dst_indices, const int* src_indices,
                                            int n, int elem_size)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        int src_idx = src_indices ? src_indices[i] : i;
        int dst_idx = dst_indices ? dst_indices[i] : i;
        const char* p = (const char*)src + (int64_t)src_idx * src_stride;
        char* q = (char*)dst + (int64_t)dst_idx * dst_stride;
        memcpy(q, p, elem_size);
    }
}
|
|
991
|
+
|
|
992
|
+
// 2D strided/indexed copy. tid is split row-major into (i, j) over `shape`. Each dimension may
// be remapped by an optional index array, and byte offsets are accumulated from per-dimension
// byte strides in 64 bits to avoid int overflow for large arrays.
static __global__ void array_copy_2d_kernel(void* dst, const void* src,
                                            wp::vec_t<2, int> dst_strides, wp::vec_t<2, int> src_strides,
                                            wp::vec_t<2, const int*> dst_indices, wp::vec_t<2, const int*> src_indices,
                                            wp::vec_t<2, int> shape, int elem_size)
{
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    int n = shape[1];
    int i = tid / n;
    int j = tid % n; // always < shape[1] by construction
    if (i < shape[0])
    {
        int src_idx0 = src_indices[0] ? src_indices[0][i] : i;
        int dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
        int src_idx1 = src_indices[1] ? src_indices[1][j] : j;
        int dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
        const char* p = (const char*)src + (int64_t)src_idx0 * src_strides[0]
                                         + (int64_t)src_idx1 * src_strides[1];
        char* q = (char*)dst + (int64_t)dst_idx0 * dst_strides[0]
                             + (int64_t)dst_idx1 * dst_strides[1];
        memcpy(q, p, elem_size);
    }
}
|
|
1012
|
+
|
|
1013
|
+
// 3D strided/indexed copy. tid is split row-major into (i, j, k) over `shape`. Each dimension may
// be remapped by an optional index array, and byte offsets are accumulated in 64 bits to avoid
// int overflow for large arrays.
static __global__ void array_copy_3d_kernel(void* dst, const void* src,
                                            wp::vec_t<3, int> dst_strides, wp::vec_t<3, int> src_strides,
                                            wp::vec_t<3, const int*> dst_indices, wp::vec_t<3, const int*> src_indices,
                                            wp::vec_t<3, int> shape, int elem_size)
{
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    int n = shape[1];
    int o = shape[2];
    int i = tid / (n * o);
    int j = tid % (n * o) / o;
    int k = tid % o; // always < shape[2] by construction
    if (i < shape[0] && j < shape[1])
    {
        int src_idx0 = src_indices[0] ? src_indices[0][i] : i;
        int dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
        int src_idx1 = src_indices[1] ? src_indices[1][j] : j;
        int dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
        int src_idx2 = src_indices[2] ? src_indices[2][k] : k;
        int dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
        const char* p = (const char*)src + (int64_t)src_idx0 * src_strides[0]
                                         + (int64_t)src_idx1 * src_strides[1]
                                         + (int64_t)src_idx2 * src_strides[2];
        char* q = (char*)dst + (int64_t)dst_idx0 * dst_strides[0]
                             + (int64_t)dst_idx1 * dst_strides[1]
                             + (int64_t)dst_idx2 * dst_strides[2];
        memcpy(q, p, elem_size);
    }
}
|
|
1041
|
+
|
|
1042
|
+
// 4D strided/indexed copy. tid is split row-major into (i, j, k, l) over `shape`. Each dimension
// may be remapped by an optional index array, and byte offsets are accumulated in 64 bits to
// avoid int overflow for large arrays.
// The shape locals are named n1..n3 so they are no longer shadowed by the pointer `p`.
static __global__ void array_copy_4d_kernel(void* dst, const void* src,
                                            wp::vec_t<4, int> dst_strides, wp::vec_t<4, int> src_strides,
                                            wp::vec_t<4, const int*> dst_indices, wp::vec_t<4, const int*> src_indices,
                                            wp::vec_t<4, int> shape, int elem_size)
{
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    int n1 = shape[1];
    int n2 = shape[2];
    int n3 = shape[3];
    int i = tid / (n1 * n2 * n3);
    int j = tid % (n1 * n2 * n3) / (n2 * n3);
    int k = tid % (n2 * n3) / n3;
    int l = tid % n3; // always < shape[3] by construction
    if (i < shape[0] && j < shape[1] && k < shape[2])
    {
        int src_idx0 = src_indices[0] ? src_indices[0][i] : i;
        int dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
        int src_idx1 = src_indices[1] ? src_indices[1][j] : j;
        int dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
        int src_idx2 = src_indices[2] ? src_indices[2][k] : k;
        int dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
        int src_idx3 = src_indices[3] ? src_indices[3][l] : l;
        int dst_idx3 = dst_indices[3] ? dst_indices[3][l] : l;
        const char* p = (const char*)src + (int64_t)src_idx0 * src_strides[0]
                                         + (int64_t)src_idx1 * src_strides[1]
                                         + (int64_t)src_idx2 * src_strides[2]
                                         + (int64_t)src_idx3 * src_strides[3];
        char* q = (char*)dst + (int64_t)dst_idx0 * dst_strides[0]
                             + (int64_t)dst_idx1 * dst_strides[1]
                             + (int64_t)dst_idx2 * dst_strides[2]
                             + (int64_t)dst_idx3 * dst_strides[3];
        memcpy(q, p, elem_size);
    }
}
|
|
1076
|
+
|
|
1077
|
+
|
|
1078
|
+
// Copies every element of fabric array `src` into a strided (optionally indexed) destination.
// The destination byte offset is computed in 64 bits to avoid int overflow for large arrays.
static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src,
                                                     void* dst_data, int dst_stride, const int* dst_indices,
                                                     int elem_size)
{
    int tid = blockIdx.x * blockDim.x + threadIdx.x;

    if (tid < src.size)
    {
        int dst_idx = dst_indices ? dst_indices[tid] : tid;
        void* dst_ptr = (char*)dst_data + (int64_t)dst_idx * dst_stride;
        const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
        memcpy(dst_ptr, src_ptr, elem_size);
    }
}
|
|
1092
|
+
|
|
1093
|
+
// Copies the elements selected by indexed fabric array `src` into a strided (optionally indexed)
// destination. The destination byte offset is computed in 64 bits to avoid int overflow.
static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricarray_t<void> src,
                                                             void* dst_data, int dst_stride, const int* dst_indices,
                                                             int elem_size)
{
    int tid = blockIdx.x * blockDim.x + threadIdx.x;

    if (tid < src.size)
    {
        int src_index = src.indices[tid];
        int dst_idx = dst_indices ? dst_indices[tid] : tid;
        void* dst_ptr = (char*)dst_data + (int64_t)dst_idx * dst_stride;
        const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
        memcpy(dst_ptr, src_ptr, elem_size);
    }
}
|
|
1108
|
+
|
|
1109
|
+
// Fills every element of fabric array `dst` from a strided (optionally indexed) source.
// The source byte offset is computed in 64 bits to avoid int overflow for large arrays.
static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
                                                   const void* src_data, int src_stride, const int* src_indices,
                                                   int elem_size)
{
    int tid = blockIdx.x * blockDim.x + threadIdx.x;

    if (tid < dst.size)
    {
        int src_idx = src_indices ? src_indices[tid] : tid;
        const void* src_ptr = (const char*)src_data + (int64_t)src_idx * src_stride;
        void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
        memcpy(dst_ptr, src_ptr, elem_size);
    }
}
|
|
1123
|
+
|
|
1124
|
+
// Fills the elements selected by indexed fabric array `dst` from a strided (optionally indexed)
// source. The source byte offset is computed in 64 bits to avoid int overflow for large arrays.
static __global__ void array_copy_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst,
                                                           const void* src_data, int src_stride, const int* src_indices,
                                                           int elem_size)
{
    int tid = blockIdx.x * blockDim.x + threadIdx.x;

    if (tid < dst.size)
    {
        int src_idx = src_indices ? src_indices[tid] : tid;
        const void* src_ptr = (const char*)src_data + (int64_t)src_idx * src_stride;
        int dst_idx = dst.indices[tid];
        void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
        memcpy(dst_ptr, src_ptr, elem_size);
    }
}
|
|
1139
|
+
|
|
1140
|
+
|
|
1141
|
+
// Element-wise copy between two fabric arrays: thread `tid` copies element
// `tid` of `src` into element `tid` of `dst`. Bounded by `dst.size`.
static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src, int elem_size)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= dst.size)
        return;

    const void* from = fabricarray_element_ptr(src, tid, elem_size);
    void* to = fabricarray_element_ptr(dst, tid, elem_size);
    memcpy(to, from, elem_size);
}
|
|
1152
|
+
|
|
1153
|
+
|
|
1154
|
+
// Copy a fabric array into an indexed fabric array: thread `tid` copies
// element `tid` of `src` into element `dst.indices[tid]` of `dst.fa`.
static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src, int elem_size)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= dst.size)
        return;

    const void* from = fabricarray_element_ptr(src, tid, elem_size);
    const int out_slot = dst.indices[tid];
    void* to = fabricarray_element_ptr(dst.fa, out_slot, elem_size);
    memcpy(to, from, elem_size);
}
|
|
1166
|
+
|
|
1167
|
+
|
|
1168
|
+
// Copy an indexed fabric array into a fabric array: thread `tid` copies
// element `src.indices[tid]` of `src.fa` into element `tid` of `dst`.
static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, int elem_size)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= dst.size)
        return;

    const int in_slot = src.indices[tid];
    const void* from = fabricarray_element_ptr(src.fa, in_slot, elem_size);
    void* to = fabricarray_element_ptr(dst, tid, elem_size);
    memcpy(to, from, elem_size);
}
|
|
1180
|
+
|
|
1181
|
+
|
|
1182
|
+
// Copy between two indexed fabric arrays: thread `tid` copies element
// `src.indices[tid]` of `src.fa` into element `dst.indices[tid]` of `dst.fa`.
static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, int elem_size)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= dst.size)
        return;

    const int in_slot = src.indices[tid];
    const int out_slot = dst.indices[tid];
    const void* from = fabricarray_element_ptr(src.fa, in_slot, elem_size);
    void* to = fabricarray_element_ptr(dst.fa, out_slot, elem_size);
    memcpy(to, from, elem_size);
}
|
|
1195
|
+
|
|
1196
|
+
|
|
1197
|
+
// Copy `elem_size`-byte elements from `src` to `dst` on the device of `context`.
//
// `src_type` / `dst_type` are wp::ARRAY_TYPE_* tags that say how to interpret
// the `src` / `dst` descriptor pointers: regular array, indexed array, fabric
// array, or indexed fabric array. Fabric arrays are treated as 1D.
//
// Returns false for NULL descriptors (silently). Also returns false, after
// printing to stderr, for an unknown array type, mismatched dimensionality,
// shape or size, or an unsupported dimensionality. Otherwise every path
// (fabric or not) returns the result of check_cuda(cudaGetLastError()) after
// the copy kernel launch, so launch failures are reported.
WP_API bool array_copy_device(void* context, void* dst, void* src, int dst_type, int src_type, int elem_size)
{
    if (!src || !dst)
        return false;

    const void* src_data = NULL;
    void* dst_data = NULL;
    int src_ndim = 0;
    int dst_ndim = 0;
    const int* src_shape = NULL;
    const int* dst_shape = NULL;
    const int* src_strides = NULL;
    const int* dst_strides = NULL;
    const int*const* src_indices = NULL;
    const int*const* dst_indices = NULL;

    const wp::fabricarray_t<void>* src_fabricarray = NULL;
    wp::fabricarray_t<void>* dst_fabricarray = NULL;

    const wp::indexedfabricarray_t<void>* src_indexedfabricarray = NULL;
    wp::indexedfabricarray_t<void>* dst_indexedfabricarray = NULL;

    // regular arrays use a NULL index pointer per dimension (identity indexing)
    const int* null_indices[wp::ARRAY_MAX_DIMS] = { NULL };

    // unpack the source descriptor
    if (src_type == wp::ARRAY_TYPE_REGULAR)
    {
        const wp::array_t<void>& src_arr = *static_cast<const wp::array_t<void>*>(src);
        src_data = src_arr.data;
        src_ndim = src_arr.ndim;
        src_shape = src_arr.shape.dims;
        src_strides = src_arr.strides;
        src_indices = null_indices;
    }
    else if (src_type == wp::ARRAY_TYPE_INDEXED)
    {
        const wp::indexedarray_t<void>& src_arr = *static_cast<const wp::indexedarray_t<void>*>(src);
        src_data = src_arr.arr.data;
        src_ndim = src_arr.arr.ndim;
        src_shape = src_arr.shape.dims;
        src_strides = src_arr.arr.strides;
        src_indices = src_arr.indices;
    }
    else if (src_type == wp::ARRAY_TYPE_FABRIC)
    {
        src_fabricarray = static_cast<const wp::fabricarray_t<void>*>(src);
        src_ndim = 1;
    }
    else if (src_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
    {
        src_indexedfabricarray = static_cast<const wp::indexedfabricarray_t<void>*>(src);
        src_ndim = 1;
    }
    else
    {
        fprintf(stderr, "Warp copy error: Invalid array type (%d)\n", src_type);
        return false;
    }

    // unpack the destination descriptor
    if (dst_type == wp::ARRAY_TYPE_REGULAR)
    {
        const wp::array_t<void>& dst_arr = *static_cast<const wp::array_t<void>*>(dst);
        dst_data = dst_arr.data;
        dst_ndim = dst_arr.ndim;
        dst_shape = dst_arr.shape.dims;
        dst_strides = dst_arr.strides;
        dst_indices = null_indices;
    }
    else if (dst_type == wp::ARRAY_TYPE_INDEXED)
    {
        const wp::indexedarray_t<void>& dst_arr = *static_cast<const wp::indexedarray_t<void>*>(dst);
        dst_data = dst_arr.arr.data;
        dst_ndim = dst_arr.arr.ndim;
        dst_shape = dst_arr.shape.dims;
        dst_strides = dst_arr.arr.strides;
        dst_indices = dst_arr.indices;
    }
    else if (dst_type == wp::ARRAY_TYPE_FABRIC)
    {
        dst_fabricarray = static_cast<wp::fabricarray_t<void>*>(dst);
        dst_ndim = 1;
    }
    else if (dst_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
    {
        dst_indexedfabricarray = static_cast<wp::indexedfabricarray_t<void>*>(dst);
        dst_ndim = 1;
    }
    else
    {
        fprintf(stderr, "Warp copy error: Invalid array type (%d)\n", dst_type);
        return false;
    }

    if (src_ndim != dst_ndim)
    {
        fprintf(stderr, "Warp copy error: Incompatible array dimensionalities (%d and %d)\n", src_ndim, dst_ndim);
        return false;
    }

    ContextGuard guard(context);

    // handle fabric arrays (1D; the ndim check above guarantees any regular
    // or indexed partner is also 1D, so only shape[0] / strides[0] are used).
    // Every launch below is followed by the same cudaGetLastError() check as
    // the regular-array path, so launch failures are not reported as success.
    if (dst_fabricarray)
    {
        size_t n = dst_fabricarray->size;
        if (src_fabricarray)
        {
            // copy from fabric to fabric
            if (src_fabricarray->size != n)
            {
                fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
                return false;
            }
            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_to_fabric_kernel, n,
                             (*dst_fabricarray, *src_fabricarray, elem_size));
            return check_cuda(cudaGetLastError());
        }
        else if (src_indexedfabricarray)
        {
            // copy from fabric indexed to fabric
            if (src_indexedfabricarray->size != n)
            {
                fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
                return false;
            }
            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_indexed_to_fabric_kernel, n,
                             (*dst_fabricarray, *src_indexedfabricarray, elem_size));
            return check_cuda(cudaGetLastError());
        }
        else
        {
            // copy to fabric
            if (size_t(src_shape[0]) != n)
            {
                fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
                return false;
            }
            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_to_fabric_kernel, n,
                             (*dst_fabricarray, src_data, src_strides[0], src_indices[0], elem_size));
            return check_cuda(cudaGetLastError());
        }
    }
    if (dst_indexedfabricarray)
    {
        size_t n = dst_indexedfabricarray->size;
        if (src_fabricarray)
        {
            // copy from fabric to fabric indexed
            if (src_fabricarray->size != n)
            {
                fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
                return false;
            }
            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_to_fabric_indexed_kernel, n,
                             (*dst_indexedfabricarray, *src_fabricarray, elem_size));
            return check_cuda(cudaGetLastError());
        }
        else if (src_indexedfabricarray)
        {
            // copy from fabric indexed to fabric indexed
            if (src_indexedfabricarray->size != n)
            {
                fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
                return false;
            }
            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_indexed_to_fabric_indexed_kernel, n,
                             (*dst_indexedfabricarray, *src_indexedfabricarray, elem_size));
            return check_cuda(cudaGetLastError());
        }
        else
        {
            // copy to fabric indexed
            if (size_t(src_shape[0]) != n)
            {
                fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
                return false;
            }
            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_to_fabric_indexed_kernel, n,
                             (*dst_indexedfabricarray, src_data, src_strides[0], src_indices[0], elem_size));
            return check_cuda(cudaGetLastError());
        }
    }
    else if (src_fabricarray)
    {
        // copy from fabric
        size_t n = src_fabricarray->size;
        if (size_t(dst_shape[0]) != n)
        {
            fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
            return false;
        }
        wp_launch_device(WP_CURRENT_CONTEXT, array_copy_from_fabric_kernel, n,
                         (*src_fabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size));
        return check_cuda(cudaGetLastError());
    }
    else if (src_indexedfabricarray)
    {
        // copy from fabric indexed
        size_t n = src_indexedfabricarray->size;
        if (size_t(dst_shape[0]) != n)
        {
            fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
            return false;
        }
        wp_launch_device(WP_CURRENT_CONTEXT, array_copy_from_fabric_indexed_kernel, n,
                         (*src_indexedfabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size));
        return check_cuda(cudaGetLastError());
    }

    // regular / indexed arrays: shapes must match exactly; n is the element count
    size_t n = 1;
    for (int i = 0; i < src_ndim; i++)
    {
        if (src_shape[i] != dst_shape[i])
        {
            fprintf(stderr, "Warp copy error: Incompatible array shapes\n");
            return false;
        }
        n *= src_shape[i];
    }

    switch (src_ndim)
    {
    case 1:
    {
        wp_launch_device(WP_CURRENT_CONTEXT, array_copy_1d_kernel, n, (dst_data, src_data,
                                                                       dst_strides[0], src_strides[0],
                                                                       dst_indices[0], src_indices[0],
                                                                       src_shape[0], elem_size));
        break;
    }
    case 2:
    {
        wp::vec_t<2, int> shape_v(src_shape[0], src_shape[1]);
        wp::vec_t<2, int> src_strides_v(src_strides[0], src_strides[1]);
        wp::vec_t<2, int> dst_strides_v(dst_strides[0], dst_strides[1]);
        wp::vec_t<2, const int*> src_indices_v(src_indices[0], src_indices[1]);
        wp::vec_t<2, const int*> dst_indices_v(dst_indices[0], dst_indices[1]);

        wp_launch_device(WP_CURRENT_CONTEXT, array_copy_2d_kernel, n, (dst_data, src_data,
                                                                       dst_strides_v, src_strides_v,
                                                                       dst_indices_v, src_indices_v,
                                                                       shape_v, elem_size));
        break;
    }
    case 3:
    {
        wp::vec_t<3, int> shape_v(src_shape[0], src_shape[1], src_shape[2]);
        wp::vec_t<3, int> src_strides_v(src_strides[0], src_strides[1], src_strides[2]);
        wp::vec_t<3, int> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2]);
        wp::vec_t<3, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2]);
        wp::vec_t<3, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2]);

        wp_launch_device(WP_CURRENT_CONTEXT, array_copy_3d_kernel, n, (dst_data, src_data,
                                                                       dst_strides_v, src_strides_v,
                                                                       dst_indices_v, src_indices_v,
                                                                       shape_v, elem_size));
        break;
    }
    case 4:
    {
        wp::vec_t<4, int> shape_v(src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
        wp::vec_t<4, int> src_strides_v(src_strides[0], src_strides[1], src_strides[2], src_strides[3]);
        wp::vec_t<4, int> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2], dst_strides[3]);
        wp::vec_t<4, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2], src_indices[3]);
        wp::vec_t<4, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2], dst_indices[3]);

        wp_launch_device(WP_CURRENT_CONTEXT, array_copy_4d_kernel, n, (dst_data, src_data,
                                                                       dst_strides_v, src_strides_v,
                                                                       dst_indices_v, src_indices_v,
                                                                       shape_v, elem_size));
        break;
    }
    default:
        fprintf(stderr, "Warp copy error: invalid array dimensionality (%d)\n", src_ndim);
        return false;
    }

    return check_cuda(cudaGetLastError());
}
|
|
1475
|
+
|
|
1476
|
+
|
|
1477
|
+
// Write the `value_size`-byte `value` into each of the first `n` logical
// slots of a strided 1D buffer. Slot `tid` maps to `indices[tid]` when
// `indices` is non-NULL, otherwise to `tid`.
static __global__ void array_fill_1d_kernel(void* data,
                                            int n,
                                            int stride,
                                            const int* indices,
                                            const void* value,
                                            int value_size)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= n)
        return;

    const int slot = (indices != NULL) ? indices[tid] : tid;
    char* dst = static_cast<char*>(data) + slot * stride;
    memcpy(dst, value, value_size);
}
|
|
1492
|
+
|
|
1493
|
+
// Write `value` into every element of a 2D (optionally indexed) strided array.
// The linear thread id is split into (row, col) with the column varying
// fastest; `col` is always < shape[1] by construction, so only the row is
// bounds-checked. A non-NULL `indices[d]` remaps coordinate d.
static __global__ void array_fill_2d_kernel(void* data,
                                            wp::vec_t<2, int> shape,
                                            wp::vec_t<2, int> strides,
                                            wp::vec_t<2, const int*> indices,
                                            const void* value,
                                            int value_size)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    const int cols = shape[1];
    const int row = tid / cols;
    const int col = tid % cols;
    if (row >= shape[0])
        return;

    const int r = indices[0] ? indices[0][row] : row;
    const int c = indices[1] ? indices[1][col] : col;
    char* dst = static_cast<char*>(data) + r * strides[0] + c * strides[1];
    memcpy(dst, value, value_size);
}
|
|
1512
|
+
|
|
1513
|
+
// Write `value` into every element of a 3D (optionally indexed) strided array.
// The linear thread id is split into (i, j, k) with k varying fastest; k is
// in range by construction, so only i and j are bounds-checked. A non-NULL
// `indices[d]` remaps coordinate d.
static __global__ void array_fill_3d_kernel(void* data,
                                            wp::vec_t<3, int> shape,
                                            wp::vec_t<3, int> strides,
                                            wp::vec_t<3, const int*> indices,
                                            const void* value,
                                            int value_size)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    const int dim1 = shape[1];
    const int dim2 = shape[2];
    const int plane = dim1 * dim2;

    const int i = tid / plane;
    const int j = tid % plane / dim2;
    const int k = tid % dim2;
    if (i >= shape[0] || j >= shape[1])
        return;

    const int a = indices[0] ? indices[0][i] : i;
    const int b = indices[1] ? indices[1][j] : j;
    const int c = indices[2] ? indices[2][k] : k;
    char* dst = static_cast<char*>(data) + a * strides[0] + b * strides[1] + c * strides[2];
    memcpy(dst, value, value_size);
}
|
|
1535
|
+
|
|
1536
|
+
// Write `value` into every element of a 4D (optionally indexed) strided array.
// The linear thread id is split into (i, j, k, l) with l varying fastest; l
// is in range by construction, so only i, j and k are bounds-checked. A
// non-NULL `indices[d]` remaps coordinate d.
static __global__ void array_fill_4d_kernel(void* data,
                                            wp::vec_t<4, int> shape,
                                            wp::vec_t<4, int> strides,
                                            wp::vec_t<4, const int*> indices,
                                            const void* value,
                                            int value_size)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    const int d1 = shape[1];
    const int d2 = shape[2];
    const int d3 = shape[3];
    const int volume = d1 * d2 * d3;  // elements per step of i
    const int slab = d2 * d3;         // elements per step of j

    const int i = tid / volume;
    const int j = tid % volume / slab;
    const int k = tid % slab / d3;
    const int l = tid % d3;
    if (i >= shape[0] || j >= shape[1] || k >= shape[2])
        return;

    const int a = indices[0] ? indices[0][i] : i;
    const int b = indices[1] ? indices[1][j] : j;
    const int c = indices[2] ? indices[2][k] : k;
    const int d = indices[3] ? indices[3][l] : l;
    char* dst = static_cast<char*>(data) + a * strides[0] + b * strides[1] + c * strides[2] + d * strides[3];
    memcpy(dst, value, value_size);
}
|
|
1561
|
+
|
|
1562
|
+
|
|
1563
|
+
// Write `value` into fabric element `tid` for every tid < fa.size.
static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value, int value_size)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= fa.size)
        return;

    void* to = fabricarray_element_ptr(fa, tid, value_size);
    memcpy(to, value, value_size);
}
|
|
1572
|
+
|
|
1573
|
+
|
|
1574
|
+
// Write `value` into fabric element `ifa.indices[tid]` for every tid < ifa.size.
// Indices at or beyond `ifa.fa.size` are skipped. Because the index is cast
// to size_t, a negative index becomes a large value and is skipped too.
static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value, int value_size)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= ifa.size)
        return;

    const size_t slot = size_t(ifa.indices[tid]);
    if (slot >= ifa.fa.size)
        return;

    void* to = fabricarray_element_ptr(ifa.fa, slot, value_size);
    memcpy(to, value, value_size);
}
|
|
1587
|
+
|
|
1588
|
+
|
|
1589
|
+
// Fill every element of a Warp array on the device of `context` with the
// `value_size`-byte value at host address `value_ptr`.
//
// `arr_type` is a wp::ARRAY_TYPE_* tag that says how `arr_ptr` is
// interpreted: regular, indexed, fabric, or indexed fabric array. NULL
// arguments return silently. An invalid type or dimensionality is reported
// on stderr and nothing is filled.
//
// The value is staged into a temporary device buffer. That buffer is
// released on every path after it is allocated; the fabric paths previously
// returned early and leaked it. Fabric arrays have no ndim/shape, so the
// launch size is taken from their `size` field; previously it defaulted to 1.
WP_API void array_fill_device(void* context, void* arr_ptr, int arr_type, const void* value_ptr, int value_size)
{
    if (!arr_ptr || !value_ptr)
        return;

    void* data = NULL;
    int ndim = 0;
    const int* shape = NULL;
    const int* strides = NULL;
    const int*const* indices = NULL;

    wp::fabricarray_t<void>* fa = NULL;
    wp::indexedfabricarray_t<void>* ifa = NULL;

    // regular arrays use a NULL index pointer per dimension (identity indexing)
    const int* null_indices[wp::ARRAY_MAX_DIMS] = { NULL };

    if (arr_type == wp::ARRAY_TYPE_REGULAR)
    {
        wp::array_t<void>& arr = *static_cast<wp::array_t<void>*>(arr_ptr);
        data = arr.data;
        ndim = arr.ndim;
        shape = arr.shape.dims;
        strides = arr.strides;
        indices = null_indices;
    }
    else if (arr_type == wp::ARRAY_TYPE_INDEXED)
    {
        wp::indexedarray_t<void>& ia = *static_cast<wp::indexedarray_t<void>*>(arr_ptr);
        data = ia.arr.data;
        ndim = ia.arr.ndim;
        shape = ia.shape.dims;
        strides = ia.arr.strides;
        indices = ia.indices;
    }
    else if (arr_type == wp::ARRAY_TYPE_FABRIC)
    {
        fa = static_cast<wp::fabricarray_t<void>*>(arr_ptr);
    }
    else if (arr_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
    {
        ifa = static_cast<wp::indexedfabricarray_t<void>*>(arr_ptr);
    }
    else
    {
        fprintf(stderr, "Warp fill error: Invalid array type id %d\n", arr_type);
        return;
    }

    // number of elements to fill, which is also the launch size
    size_t n = 1;
    if (fa)
    {
        n = fa->size;
    }
    else if (ifa)
    {
        n = ifa->size;
    }
    else
    {
        // validate before touching shape[] or allocating the staging buffer
        if (ndim < 1 || ndim > 4)
        {
            fprintf(stderr, "Warp fill error: invalid array dimensionality (%d)\n", ndim);
            return;
        }
        for (int i = 0; i < ndim; i++)
            n *= shape[i];
    }

    ContextGuard guard(context);

    // copy value to device memory
    // TODO: use a persistent stream-local staging buffer to avoid allocs?
    void* value_devptr = alloc_device(WP_CURRENT_CONTEXT, value_size);
    check_cuda(cudaMemcpyAsync(value_devptr, value_ptr, value_size, cudaMemcpyHostToDevice, get_current_stream()));

    if (fa)
    {
        wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_kernel, n,
                         (*fa, value_devptr, value_size));
    }
    else if (ifa)
    {
        wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_indexed_kernel, n,
                         (*ifa, value_devptr, value_size));
    }
    else
    {
        // handle regular or indexed arrays
        switch (ndim)
        {
        case 1:
        {
            wp_launch_device(WP_CURRENT_CONTEXT, array_fill_1d_kernel, n,
                             (data, shape[0], strides[0], indices[0], value_devptr, value_size));
            break;
        }
        case 2:
        {
            wp::vec_t<2, int> shape_v(shape[0], shape[1]);
            wp::vec_t<2, int> strides_v(strides[0], strides[1]);
            wp::vec_t<2, const int*> indices_v(indices[0], indices[1]);
            wp_launch_device(WP_CURRENT_CONTEXT, array_fill_2d_kernel, n,
                             (data, shape_v, strides_v, indices_v, value_devptr, value_size));
            break;
        }
        case 3:
        {
            wp::vec_t<3, int> shape_v(shape[0], shape[1], shape[2]);
            wp::vec_t<3, int> strides_v(strides[0], strides[1], strides[2]);
            wp::vec_t<3, const int*> indices_v(indices[0], indices[1], indices[2]);
            wp_launch_device(WP_CURRENT_CONTEXT, array_fill_3d_kernel, n,
                             (data, shape_v, strides_v, indices_v, value_devptr, value_size));
            break;
        }
        case 4:
        {
            wp::vec_t<4, int> shape_v(shape[0], shape[1], shape[2], shape[3]);
            wp::vec_t<4, int> strides_v(strides[0], strides[1], strides[2], strides[3]);
            wp::vec_t<4, const int*> indices_v(indices[0], indices[1], indices[2], indices[3]);
            wp_launch_device(WP_CURRENT_CONTEXT, array_fill_4d_kernel, n,
                             (data, shape_v, strides_v, indices_v, value_devptr, value_size));
            break;
        }
        default:
            // unreachable: ndim was validated before allocation
            break;
        }
    }

    // release the staging buffer on every path, fabric paths included
    free_device(WP_CURRENT_CONTEXT, value_devptr);
}
|
|
1705
|
+
|
|
1706
|
+
// Run scan_device over `len` ints. `in` and `out` carry raw device addresses
// as integers; `inclusive` selects an inclusive or exclusive scan.
void array_scan_int_device(uint64_t in, uint64_t out, int len, bool inclusive)
{
    const int* src = reinterpret_cast<const int*>(in);
    int* dst = reinterpret_cast<int*>(out);
    scan_device(src, dst, len, inclusive);
}
|
|
1710
|
+
|
|
1711
|
+
// Run scan_device over `len` floats. `in` and `out` carry raw device addresses
// as integers; `inclusive` selects an inclusive or exclusive scan.
void array_scan_float_device(uint64_t in, uint64_t out, int len, bool inclusive)
{
    const float* src = reinterpret_cast<const float*>(in);
    float* dst = reinterpret_cast<float*>(out);
    scan_device(src, dst, len, inclusive);
}
|
|
1715
|
+
|
|
1716
|
+
// Return the version reported by cuDriverGetVersion, or 0 if that call fails.
int cuda_driver_version()
{
    int version = 0;
    return check_cu(cuDriverGetVersion_f(&version)) ? version : 0;
}
|
|
1724
|
+
|
|
1725
|
+
// Return the compile-time CUDA_VERSION this library was built with.
int cuda_toolkit_version()
{
    const int toolkit_version = CUDA_VERSION;
    return toolkit_version;
}
|
|
1729
|
+
|
|
1730
|
+
// Report whether the CUDA driver has been initialized (delegates to
// is_cuda_driver_initialized()).
bool cuda_driver_is_initialized()
{
    const bool initialized = is_cuda_driver_initialized();
    return initialized;
}
|
|
1734
|
+
|
|
1735
|
+
// Return the number of architectures NVRTC reports as supported, or 0 if
// the NVRTC query fails.
int nvrtc_supported_arch_count()
{
    int count = 0;
    const bool ok = check_nvrtc(nvrtcGetNumSupportedArchs(&count));
    return ok ? count : 0;
}
|
|
1743
|
+
|
|
1744
|
+
// Write NVRTC's supported architectures into `archs`. Does nothing when
// `archs` is NULL. The caller presumably sizes the buffer with
// nvrtc_supported_arch_count(); that is not checked here.
void nvrtc_supported_archs(int* archs)
{
    if (archs == NULL)
        return;

    check_nvrtc(nvrtcGetSupportedArchs(archs));
}
|
|
1751
|
+
|
|
1752
|
+
// Return the CUDA device count reported by the driver (0 if the query fails).
int cuda_device_get_count()
{
    int num_devices = 0;
    check_cu(cuDeviceGetCount_f(&num_devices));
    return num_devices;
}
|
|
1758
|
+
|
|
1759
|
+
// Return the primary context of device `ordinal`, retaining it on first use
// and caching it in g_devices. Returns NULL for an out-of-range ordinal.
void* cuda_device_get_primary_context(int ordinal)
{
    if (ordinal < 0 || ordinal >= int(g_devices.size()))
        return NULL;

    DeviceInfo& info = g_devices[ordinal];

    // lazily acquire the primary context
    if (info.primary_context == NULL)
        check_cu(cuDevicePrimaryCtxRetain_f(&info.primary_context, info.device));

    return info.primary_context;
}
|
|
1774
|
+
|
|
1775
|
+
// Return the cached name of device `ordinal`, or NULL for an invalid ordinal.
const char* cuda_device_get_name(int ordinal)
{
    if (ordinal < 0 || ordinal >= int(g_devices.size()))
        return NULL;

    return g_devices[ordinal].name;
}
|
|
1781
|
+
|
|
1782
|
+
// Return the cached `arch` of device `ordinal`, or 0 for an invalid ordinal.
int cuda_device_get_arch(int ordinal)
{
    if (ordinal < 0 || ordinal >= int(g_devices.size()))
        return 0;

    return g_devices[ordinal].arch;
}
|
|
1788
|
+
|
|
1789
|
+
// Copy the 16-byte UUID of device `ordinal` into `uuid`.
// For an out-of-range ordinal the buffer is zero-filled instead. The
// original code had no bounds check here (unlike every sibling getter) and
// indexed g_devices out of range.
void cuda_device_get_uuid(int ordinal, char uuid[16])
{
    if (ordinal < 0 || ordinal >= int(g_devices.size()))
    {
        memset(uuid, 0, sizeof(char)*16);
        return;
    }

    memcpy(uuid, g_devices[ordinal].uuid.bytes, sizeof(char)*16);
}
|
|
1793
|
+
|
|
1794
|
+
// Return the cached PCI domain id of device `ordinal`, or -1 if the ordinal is invalid.
int cuda_device_get_pci_domain_id(int ordinal)
{
    if (ordinal < 0 || ordinal >= int(g_devices.size()))
        return -1;

    return g_devices[ordinal].pci_domain_id;
}
|
|
1800
|
+
|
|
1801
|
+
// Return the cached PCI bus id of device `ordinal`, or -1 if the ordinal is invalid.
int cuda_device_get_pci_bus_id(int ordinal)
{
    if (ordinal < 0 || ordinal >= int(g_devices.size()))
        return -1;

    return g_devices[ordinal].pci_bus_id;
}
|
|
1807
|
+
|
|
1808
|
+
// Return the cached PCI device id of device `ordinal`, or -1 if the ordinal is invalid.
int cuda_device_get_pci_device_id(int ordinal)
{
    if (ordinal < 0 || ordinal >= int(g_devices.size()))
        return -1;

    return g_devices[ordinal].pci_device_id;
}
|
|
1814
|
+
|
|
1815
|
+
// Return the cached `is_uva` flag of device `ordinal`, or 0 if the ordinal is invalid.
int cuda_device_is_uva(int ordinal)
{
    if (ordinal < 0 || ordinal >= int(g_devices.size()))
        return 0;

    return g_devices[ordinal].is_uva;
}
|
|
1821
|
+
|
|
1822
|
+
// Return the cached `is_mempool_supported` flag of device `ordinal`, or 0 if
// the ordinal is invalid.
int cuda_device_is_mempool_supported(int ordinal)
{
    if (ordinal < 0 || ordinal >= int(g_devices.size()))
        return 0;

    return g_devices[ordinal].is_mempool_supported;
}
|
|
1828
|
+
|
|
1829
|
+
// Return the cached `is_ipc_supported` flag of device `ordinal`, or 0 if the
// ordinal is invalid.
int cuda_device_is_ipc_supported(int ordinal)
{
    if (ordinal < 0 || ordinal >= int(g_devices.size()))
        return 0;

    return g_devices[ordinal].is_ipc_supported;
}
|
|
1835
|
+
|
|
1836
|
+
// Set cudaMemPoolAttrReleaseThreshold on the default memory pool of device
// `ordinal` to `threshold`.
//
// Returns 1 on success. Returns 0 for an invalid ordinal, a device without
// memory-pool support, or a failing CUDA call. Every case except the
// unsupported-device one prints a message to stderr.
int cuda_device_set_mempool_release_threshold(int ordinal, uint64_t threshold)
{
    // valid ordinals are [0, size); `>=` (not `>`) keeps g_devices[size] from being read
    if (ordinal < 0 || ordinal >= int(g_devices.size()))
    {
        fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
        return 0;
    }

    if (!g_devices[ordinal].is_mempool_supported)
        return 0;

    cudaMemPool_t pool;
    if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
    {
        fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
        return 0;
    }

    if (!check_cuda(cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold)))
    {
        fprintf(stderr, "Warp error: Failed to set memory pool attribute on device %d\n", ordinal);
        return 0;
    }

    return 1; // success
}
|
|
1862
|
+
|
|
1863
|
+
// Return cudaMemPoolAttrReleaseThreshold of the default memory pool of
// device `ordinal`.
//
// Returns 0 for an invalid ordinal, a device without memory-pool support, or
// a failing CUDA call. Every case except the unsupported-device one prints a
// message to stderr.
uint64_t cuda_device_get_mempool_release_threshold(int ordinal)
{
    // valid ordinals are [0, size); `>=` (not `>`) keeps g_devices[size] from being read
    if (ordinal < 0 || ordinal >= int(g_devices.size()))
    {
        fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
        return 0;
    }

    if (!g_devices[ordinal].is_mempool_supported)
        return 0;

    cudaMemPool_t pool;
    if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
    {
        fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
        return 0;
    }

    uint64_t threshold = 0;
    if (!check_cuda(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold)))
    {
        fprintf(stderr, "Warp error: Failed to get memory pool release threshold on device %d\n", ordinal);
        return 0;
    }

    return threshold;
}
|
|
1890
|
+
|
|
1891
|
+
// Return cudaMemPoolAttrUsedMemCurrent of the default memory pool of device
// `ordinal`.
//
// Returns 0 for an invalid ordinal, a device without memory-pool support, or
// a failing CUDA call. Every case except the unsupported-device one prints a
// message to stderr.
uint64_t cuda_device_get_mempool_used_mem_current(int ordinal)
{
    // valid ordinals are [0, size); `>=` (not `>`) keeps g_devices[size] from being read
    if (ordinal < 0 || ordinal >= int(g_devices.size()))
    {
        fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
        return 0;
    }

    if (!g_devices[ordinal].is_mempool_supported)
        return 0;

    cudaMemPool_t pool;
    if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
    {
        fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
        return 0;
    }

    uint64_t mem_used = 0;
    if (!check_cuda(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrUsedMemCurrent, &mem_used)))
    {
        fprintf(stderr, "Warp error: Failed to get amount of currently used memory from the memory pool on device %d\n", ordinal);
        return 0;
    }

    return mem_used;
}
|
|
1918
|
+
|
|
1919
|
+
// Return cudaMemPoolAttrUsedMemHigh (the high-water mark) of the default
// memory pool of device `ordinal`.
//
// Returns 0 for an invalid ordinal, a device without memory-pool support, or
// a failing CUDA call. Every case except the unsupported-device one prints a
// message to stderr.
uint64_t cuda_device_get_mempool_used_mem_high(int ordinal)
{
    // valid ordinals are [0, size); `>=` (not `>`) keeps g_devices[size] from being read
    if (ordinal < 0 || ordinal >= int(g_devices.size()))
    {
        fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
        return 0;
    }

    if (!g_devices[ordinal].is_mempool_supported)
        return 0;

    cudaMemPool_t pool;
    if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
    {
        fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
        return 0;
    }

    uint64_t mem_high_water_mark = 0;
    if (!check_cuda(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrUsedMemHigh, &mem_high_water_mark)))
    {
        fprintf(stderr, "Warp error: Failed to get memory usage high water mark from the memory pool on device %d\n", ordinal);
        return 0;
    }

    return mem_high_water_mark;
}
|
|
1946
|
+
|
|
1947
|
+
// Query free and total memory of device `ordinal` via cuMemGetInfo.
// Either output pointer may be NULL. Caller-provided outputs are zeroed
// first, so they stay 0 for an invalid ordinal. If the device's primary
// context has not been acquired yet, it is retained only for the duration of
// the query and then released.
void cuda_device_get_memory_info(int ordinal, size_t* free_mem, size_t* total_mem)
{
    // redirect unrequested outputs to local scratch storage
    size_t scratch_free = 0;
    size_t scratch_total = 0;
    size_t* out_free = free_mem ? free_mem : &scratch_free;
    size_t* out_total = total_mem ? total_mem : &scratch_total;
    *out_free = 0;
    *out_total = 0;

    if (ordinal < 0 || ordinal >= int(g_devices.size()))
        return;

    DeviceInfo& info = g_devices[ordinal];

    if (info.primary_context)
    {
        ContextGuard guard(info.primary_context, true);
        check_cu(cuMemGetInfo_f(out_free, out_total));
        return;
    }

    // primary context not acquired yet: retain it temporarily for the query
    CUcontext temp_context = NULL;
    check_cu(cuDevicePrimaryCtxRetain_f(&temp_context, info.device));
    {
        ContextGuard guard(temp_context, true);
        check_cu(cuMemGetInfo_f(out_free, out_total));
    }
    check_cu(cuDevicePrimaryCtxRelease_f(info.device));
}
|
|
1982
|
+
|
|
1983
|
+
|
|
1984
|
+
void* cuda_context_get_current()
|
|
1985
|
+
{
|
|
1986
|
+
return get_current_context();
|
|
1987
|
+
}
|
|
1988
|
+
|
|
1989
|
+
void cuda_context_set_current(void* context)
|
|
1990
|
+
{
|
|
1991
|
+
CUcontext ctx = static_cast<CUcontext>(context);
|
|
1992
|
+
CUcontext prev_ctx = NULL;
|
|
1993
|
+
check_cu(cuCtxGetCurrent_f(&prev_ctx));
|
|
1994
|
+
if (ctx != prev_ctx)
|
|
1995
|
+
{
|
|
1996
|
+
check_cu(cuCtxSetCurrent_f(ctx));
|
|
1997
|
+
}
|
|
1998
|
+
}
|
|
1999
|
+
|
|
2000
|
+
void cuda_context_push_current(void* context)
|
|
2001
|
+
{
|
|
2002
|
+
check_cu(cuCtxPushCurrent_f(static_cast<CUcontext>(context)));
|
|
2003
|
+
}
|
|
2004
|
+
|
|
2005
|
+
void cuda_context_pop_current()
|
|
2006
|
+
{
|
|
2007
|
+
CUcontext context;
|
|
2008
|
+
check_cu(cuCtxPopCurrent_f(&context));
|
|
2009
|
+
}
|
|
2010
|
+
|
|
2011
|
+
void* cuda_context_create(int device_ordinal)
|
|
2012
|
+
{
|
|
2013
|
+
CUcontext ctx = NULL;
|
|
2014
|
+
CUdevice device;
|
|
2015
|
+
if (check_cu(cuDeviceGet_f(&device, device_ordinal)))
|
|
2016
|
+
check_cu(cuCtxCreate_f(&ctx, 0, device));
|
|
2017
|
+
return ctx;
|
|
2018
|
+
}
|
|
2019
|
+
|
|
2020
|
+
void cuda_context_destroy(void* context)
|
|
2021
|
+
{
|
|
2022
|
+
if (context)
|
|
2023
|
+
{
|
|
2024
|
+
CUcontext ctx = static_cast<CUcontext>(context);
|
|
2025
|
+
|
|
2026
|
+
// ensure this is not the current context
|
|
2027
|
+
if (ctx == cuda_context_get_current())
|
|
2028
|
+
cuda_context_set_current(NULL);
|
|
2029
|
+
|
|
2030
|
+
// release the cached info about this context
|
|
2031
|
+
ContextInfo* info = get_context_info(ctx);
|
|
2032
|
+
if (info)
|
|
2033
|
+
{
|
|
2034
|
+
if (info->stream)
|
|
2035
|
+
check_cu(cuStreamDestroy_f(info->stream));
|
|
2036
|
+
|
|
2037
|
+
g_contexts.erase(ctx);
|
|
2038
|
+
}
|
|
2039
|
+
|
|
2040
|
+
check_cu(cuCtxDestroy_f(ctx));
|
|
2041
|
+
}
|
|
2042
|
+
}
|
|
2043
|
+
|
|
2044
|
+
void cuda_context_synchronize(void* context)
|
|
2045
|
+
{
|
|
2046
|
+
ContextGuard guard(context);
|
|
2047
|
+
|
|
2048
|
+
check_cu(cuCtxSynchronize_f());
|
|
2049
|
+
|
|
2050
|
+
if (free_deferred_allocs(context ? context : get_current_context()) > 0)
|
|
2051
|
+
{
|
|
2052
|
+
// ensure deferred asynchronous deallocations complete
|
|
2053
|
+
check_cu(cuCtxSynchronize_f());
|
|
2054
|
+
}
|
|
2055
|
+
|
|
2056
|
+
unload_deferred_modules(context);
|
|
2057
|
+
|
|
2058
|
+
// check_cuda(cudaDeviceGraphMemTrim(cuda_context_get_device_ordinal(context)));
|
|
2059
|
+
}
|
|
2060
|
+
|
|
2061
|
+
uint64_t cuda_context_check(void* context)
|
|
2062
|
+
{
|
|
2063
|
+
ContextGuard guard(context);
|
|
2064
|
+
|
|
2065
|
+
// check errors before syncing
|
|
2066
|
+
cudaError_t e = cudaGetLastError();
|
|
2067
|
+
check_cuda(e);
|
|
2068
|
+
|
|
2069
|
+
cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
|
|
2070
|
+
check_cuda(cudaStreamIsCapturing(get_current_stream(), &status));
|
|
2071
|
+
|
|
2072
|
+
// synchronize if the stream is not capturing
|
|
2073
|
+
if (status == cudaStreamCaptureStatusNone)
|
|
2074
|
+
{
|
|
2075
|
+
check_cuda(cudaDeviceSynchronize());
|
|
2076
|
+
e = cudaGetLastError();
|
|
2077
|
+
}
|
|
2078
|
+
|
|
2079
|
+
return static_cast<uint64_t>(e);
|
|
2080
|
+
}
|
|
2081
|
+
|
|
2082
|
+
|
|
2083
|
+
int cuda_context_get_device_ordinal(void* context)
|
|
2084
|
+
{
|
|
2085
|
+
ContextInfo* info = get_context_info(static_cast<CUcontext>(context));
|
|
2086
|
+
return info && info->device_info ? info->device_info->ordinal : -1;
|
|
2087
|
+
}
|
|
2088
|
+
|
|
2089
|
+
int cuda_context_is_primary(void* context)
|
|
2090
|
+
{
|
|
2091
|
+
CUcontext ctx = static_cast<CUcontext>(context);
|
|
2092
|
+
ContextInfo* context_info = get_context_info(ctx);
|
|
2093
|
+
if (!context_info)
|
|
2094
|
+
{
|
|
2095
|
+
fprintf(stderr, "Warp error: Failed to get context info\n");
|
|
2096
|
+
return 0;
|
|
2097
|
+
}
|
|
2098
|
+
|
|
2099
|
+
// if the device primary context is known, check if it matches the given context
|
|
2100
|
+
DeviceInfo* device_info = context_info->device_info;
|
|
2101
|
+
if (device_info->primary_context)
|
|
2102
|
+
return int(ctx == device_info->primary_context);
|
|
2103
|
+
|
|
2104
|
+
// there is no CUDA API to check if a context is primary, but we can temporarily
|
|
2105
|
+
// acquire the device's primary context to check the pointer
|
|
2106
|
+
CUcontext primary_ctx;
|
|
2107
|
+
if (check_cu(cuDevicePrimaryCtxRetain_f(&primary_ctx, device_info->device)))
|
|
2108
|
+
{
|
|
2109
|
+
check_cu(cuDevicePrimaryCtxRelease_f(device_info->device));
|
|
2110
|
+
return int(ctx == primary_ctx);
|
|
2111
|
+
}
|
|
2112
|
+
|
|
2113
|
+
return 0;
|
|
2114
|
+
}
|
|
2115
|
+
|
|
2116
|
+
void* cuda_context_get_stream(void* context)
|
|
2117
|
+
{
|
|
2118
|
+
ContextInfo* info = get_context_info(static_cast<CUcontext>(context));
|
|
2119
|
+
if (info)
|
|
2120
|
+
{
|
|
2121
|
+
return info->stream;
|
|
2122
|
+
}
|
|
2123
|
+
return NULL;
|
|
2124
|
+
}
|
|
2125
|
+
|
|
2126
|
+
void cuda_context_set_stream(void* context, void* stream, int sync)
|
|
2127
|
+
{
|
|
2128
|
+
ContextInfo* context_info = get_context_info(static_cast<CUcontext>(context));
|
|
2129
|
+
if (context_info)
|
|
2130
|
+
{
|
|
2131
|
+
CUstream new_stream = static_cast<CUstream>(stream);
|
|
2132
|
+
|
|
2133
|
+
// check whether we should sync with the previous stream on this device
|
|
2134
|
+
if (sync)
|
|
2135
|
+
{
|
|
2136
|
+
CUstream old_stream = context_info->stream;
|
|
2137
|
+
StreamInfo* old_stream_info = get_stream_info(old_stream);
|
|
2138
|
+
if (old_stream_info)
|
|
2139
|
+
{
|
|
2140
|
+
CUevent cached_event = old_stream_info->cached_event;
|
|
2141
|
+
check_cu(cuEventRecord_f(cached_event, old_stream));
|
|
2142
|
+
check_cu(cuStreamWaitEvent_f(new_stream, cached_event, CU_EVENT_WAIT_DEFAULT));
|
|
2143
|
+
}
|
|
2144
|
+
}
|
|
2145
|
+
|
|
2146
|
+
context_info->stream = new_stream;
|
|
2147
|
+
}
|
|
2148
|
+
}
|
|
2149
|
+
|
|
2150
|
+
int cuda_is_peer_access_supported(int target_ordinal, int peer_ordinal)
|
|
2151
|
+
{
|
|
2152
|
+
int num_devices = int(g_devices.size());
|
|
2153
|
+
|
|
2154
|
+
if (target_ordinal < 0 || target_ordinal > num_devices)
|
|
2155
|
+
{
|
|
2156
|
+
fprintf(stderr, "Warp error: Invalid target device ordinal %d\n", target_ordinal);
|
|
2157
|
+
return 0;
|
|
2158
|
+
}
|
|
2159
|
+
|
|
2160
|
+
if (peer_ordinal < 0 || peer_ordinal > num_devices)
|
|
2161
|
+
{
|
|
2162
|
+
fprintf(stderr, "Warp error: Invalid peer device ordinal %d\n", peer_ordinal);
|
|
2163
|
+
return 0;
|
|
2164
|
+
}
|
|
2165
|
+
|
|
2166
|
+
if (target_ordinal == peer_ordinal)
|
|
2167
|
+
return 1;
|
|
2168
|
+
|
|
2169
|
+
int can_access = 0;
|
|
2170
|
+
check_cuda(cudaDeviceCanAccessPeer(&can_access, peer_ordinal, target_ordinal));
|
|
2171
|
+
|
|
2172
|
+
return can_access;
|
|
2173
|
+
}
|
|
2174
|
+
|
|
2175
|
+
int cuda_is_peer_access_enabled(void* target_context, void* peer_context)
|
|
2176
|
+
{
|
|
2177
|
+
if (!target_context || !peer_context)
|
|
2178
|
+
{
|
|
2179
|
+
fprintf(stderr, "Warp error: invalid CUDA context\n");
|
|
2180
|
+
return 0;
|
|
2181
|
+
}
|
|
2182
|
+
|
|
2183
|
+
if (target_context == peer_context)
|
|
2184
|
+
return 1;
|
|
2185
|
+
|
|
2186
|
+
int target_ordinal = cuda_context_get_device_ordinal(target_context);
|
|
2187
|
+
int peer_ordinal = cuda_context_get_device_ordinal(peer_context);
|
|
2188
|
+
|
|
2189
|
+
// check if peer access is supported
|
|
2190
|
+
int can_access = 0;
|
|
2191
|
+
check_cuda(cudaDeviceCanAccessPeer(&can_access, peer_ordinal, target_ordinal));
|
|
2192
|
+
if (!can_access)
|
|
2193
|
+
return 0;
|
|
2194
|
+
|
|
2195
|
+
// There is no CUDA API to query if peer access is enabled, but we can try to enable it and check the result.
|
|
2196
|
+
|
|
2197
|
+
ContextGuard guard(peer_context, true);
|
|
2198
|
+
|
|
2199
|
+
CUcontext target_ctx = static_cast<CUcontext>(target_context);
|
|
2200
|
+
|
|
2201
|
+
CUresult result = cuCtxEnablePeerAccess_f(target_ctx, 0);
|
|
2202
|
+
if (result == CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED)
|
|
2203
|
+
{
|
|
2204
|
+
return 1;
|
|
2205
|
+
}
|
|
2206
|
+
else if (result == CUDA_SUCCESS)
|
|
2207
|
+
{
|
|
2208
|
+
// undo enablement
|
|
2209
|
+
check_cu(cuCtxDisablePeerAccess_f(target_ctx));
|
|
2210
|
+
return 0;
|
|
2211
|
+
}
|
|
2212
|
+
else
|
|
2213
|
+
{
|
|
2214
|
+
// report error
|
|
2215
|
+
check_cu(result);
|
|
2216
|
+
return 0;
|
|
2217
|
+
}
|
|
2218
|
+
}
|
|
2219
|
+
|
|
2220
|
+
int cuda_set_peer_access_enabled(void* target_context, void* peer_context, int enable)
|
|
2221
|
+
{
|
|
2222
|
+
if (!target_context || !peer_context)
|
|
2223
|
+
{
|
|
2224
|
+
fprintf(stderr, "Warp error: invalid CUDA context\n");
|
|
2225
|
+
return 0;
|
|
2226
|
+
}
|
|
2227
|
+
|
|
2228
|
+
if (target_context == peer_context)
|
|
2229
|
+
return 1; // no-op
|
|
2230
|
+
|
|
2231
|
+
int target_ordinal = cuda_context_get_device_ordinal(target_context);
|
|
2232
|
+
int peer_ordinal = cuda_context_get_device_ordinal(peer_context);
|
|
2233
|
+
|
|
2234
|
+
// check if peer access is supported
|
|
2235
|
+
int can_access = 0;
|
|
2236
|
+
check_cuda(cudaDeviceCanAccessPeer(&can_access, peer_ordinal, target_ordinal));
|
|
2237
|
+
if (!can_access)
|
|
2238
|
+
{
|
|
2239
|
+
// failure if enabling, success if disabling
|
|
2240
|
+
if (enable)
|
|
2241
|
+
{
|
|
2242
|
+
fprintf(stderr, "Warp error: device %d cannot access device %d\n", peer_ordinal, target_ordinal);
|
|
2243
|
+
return 0;
|
|
2244
|
+
}
|
|
2245
|
+
else
|
|
2246
|
+
return 1;
|
|
2247
|
+
}
|
|
2248
|
+
|
|
2249
|
+
ContextGuard guard(peer_context, true);
|
|
2250
|
+
|
|
2251
|
+
CUcontext target_ctx = static_cast<CUcontext>(target_context);
|
|
2252
|
+
|
|
2253
|
+
if (enable)
|
|
2254
|
+
{
|
|
2255
|
+
CUresult status = cuCtxEnablePeerAccess_f(target_ctx, 0);
|
|
2256
|
+
if (status != CUDA_SUCCESS && status != CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED)
|
|
2257
|
+
{
|
|
2258
|
+
check_cu(status);
|
|
2259
|
+
fprintf(stderr, "Warp error: failed to enable peer access from device %d to device %d\n", peer_ordinal, target_ordinal);
|
|
2260
|
+
return 0;
|
|
2261
|
+
}
|
|
2262
|
+
}
|
|
2263
|
+
else
|
|
2264
|
+
{
|
|
2265
|
+
CUresult status = cuCtxDisablePeerAccess_f(target_ctx);
|
|
2266
|
+
if (status != CUDA_SUCCESS && status != CUDA_ERROR_PEER_ACCESS_NOT_ENABLED)
|
|
2267
|
+
{
|
|
2268
|
+
check_cu(status);
|
|
2269
|
+
fprintf(stderr, "Warp error: failed to disable peer access from device %d to device %d\n", peer_ordinal, target_ordinal);
|
|
2270
|
+
return 0;
|
|
2271
|
+
}
|
|
2272
|
+
}
|
|
2273
|
+
|
|
2274
|
+
return 1; // success
|
|
2275
|
+
}
|
|
2276
|
+
|
|
2277
|
+
int cuda_is_mempool_access_enabled(int target_ordinal, int peer_ordinal)
|
|
2278
|
+
{
|
|
2279
|
+
int num_devices = int(g_devices.size());
|
|
2280
|
+
|
|
2281
|
+
if (target_ordinal < 0 || target_ordinal > num_devices)
|
|
2282
|
+
{
|
|
2283
|
+
fprintf(stderr, "Warp error: Invalid device ordinal %d\n", target_ordinal);
|
|
2284
|
+
return 0;
|
|
2285
|
+
}
|
|
2286
|
+
|
|
2287
|
+
if (peer_ordinal < 0 || peer_ordinal > num_devices)
|
|
2288
|
+
{
|
|
2289
|
+
fprintf(stderr, "Warp error: Invalid peer device ordinal %d\n", peer_ordinal);
|
|
2290
|
+
return 0;
|
|
2291
|
+
}
|
|
2292
|
+
|
|
2293
|
+
if (target_ordinal == peer_ordinal)
|
|
2294
|
+
return 1;
|
|
2295
|
+
|
|
2296
|
+
cudaMemPool_t pool;
|
|
2297
|
+
if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, target_ordinal)))
|
|
2298
|
+
{
|
|
2299
|
+
fprintf(stderr, "Warp error: Failed to get memory pool of device %d\n", target_ordinal);
|
|
2300
|
+
return 0;
|
|
2301
|
+
}
|
|
2302
|
+
|
|
2303
|
+
cudaMemAccessFlags flags = cudaMemAccessFlagsProtNone;
|
|
2304
|
+
cudaMemLocation location;
|
|
2305
|
+
location.id = peer_ordinal;
|
|
2306
|
+
location.type = cudaMemLocationTypeDevice;
|
|
2307
|
+
if (check_cuda(cudaMemPoolGetAccess(&flags, pool, &location)))
|
|
2308
|
+
return int(flags != cudaMemAccessFlagsProtNone);
|
|
2309
|
+
|
|
2310
|
+
return 0;
|
|
2311
|
+
}
|
|
2312
|
+
|
|
2313
|
+
int cuda_set_mempool_access_enabled(int target_ordinal, int peer_ordinal, int enable)
|
|
2314
|
+
{
|
|
2315
|
+
int num_devices = int(g_devices.size());
|
|
2316
|
+
|
|
2317
|
+
if (target_ordinal < 0 || target_ordinal > num_devices)
|
|
2318
|
+
{
|
|
2319
|
+
fprintf(stderr, "Warp error: Invalid device ordinal %d\n", target_ordinal);
|
|
2320
|
+
return 0;
|
|
2321
|
+
}
|
|
2322
|
+
|
|
2323
|
+
if (peer_ordinal < 0 || peer_ordinal > num_devices)
|
|
2324
|
+
{
|
|
2325
|
+
fprintf(stderr, "Warp error: Invalid peer device ordinal %d\n", peer_ordinal);
|
|
2326
|
+
return 0;
|
|
2327
|
+
}
|
|
2328
|
+
|
|
2329
|
+
if (target_ordinal == peer_ordinal)
|
|
2330
|
+
return 1; // no-op
|
|
2331
|
+
|
|
2332
|
+
// get the memory pool
|
|
2333
|
+
cudaMemPool_t pool;
|
|
2334
|
+
if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, target_ordinal)))
|
|
2335
|
+
{
|
|
2336
|
+
fprintf(stderr, "Warp error: Failed to get memory pool of device %d\n", target_ordinal);
|
|
2337
|
+
return 0;
|
|
2338
|
+
}
|
|
2339
|
+
|
|
2340
|
+
cudaMemAccessDesc desc;
|
|
2341
|
+
desc.location.type = cudaMemLocationTypeDevice;
|
|
2342
|
+
desc.location.id = peer_ordinal;
|
|
2343
|
+
|
|
2344
|
+
// only cudaMemAccessFlagsProtReadWrite and cudaMemAccessFlagsProtNone are supported
|
|
2345
|
+
if (enable)
|
|
2346
|
+
desc.flags = cudaMemAccessFlagsProtReadWrite;
|
|
2347
|
+
else
|
|
2348
|
+
desc.flags = cudaMemAccessFlagsProtNone;
|
|
2349
|
+
|
|
2350
|
+
if (!check_cuda(cudaMemPoolSetAccess(pool, &desc, 1)))
|
|
2351
|
+
{
|
|
2352
|
+
fprintf(stderr, "Warp error: Failed to set mempool access from device %d to device %d\n", peer_ordinal, target_ordinal);
|
|
2353
|
+
return 0;
|
|
2354
|
+
}
|
|
2355
|
+
|
|
2356
|
+
return 1; // success
|
|
2357
|
+
}
|
|
2358
|
+
|
|
2359
|
+
void cuda_ipc_get_mem_handle(void* ptr, char* out_buffer) {
|
|
2360
|
+
CUipcMemHandle memHandle;
|
|
2361
|
+
check_cu(cuIpcGetMemHandle_f(&memHandle, (CUdeviceptr)ptr));
|
|
2362
|
+
memcpy(out_buffer, memHandle.reserved, CU_IPC_HANDLE_SIZE);
|
|
2363
|
+
}
|
|
2364
|
+
|
|
2365
|
+
void* cuda_ipc_open_mem_handle(void* context, char* handle) {
|
|
2366
|
+
ContextGuard guard(context);
|
|
2367
|
+
|
|
2368
|
+
CUipcMemHandle memHandle;
|
|
2369
|
+
memcpy(memHandle.reserved, handle, CU_IPC_HANDLE_SIZE);
|
|
2370
|
+
|
|
2371
|
+
CUdeviceptr device_ptr;
|
|
2372
|
+
|
|
2373
|
+
// Strangely, the CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS flag is required
|
|
2374
|
+
if check_cu(cuIpcOpenMemHandle_f(&device_ptr, memHandle, CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS))
|
|
2375
|
+
return (void*) device_ptr;
|
|
2376
|
+
else
|
|
2377
|
+
return NULL;
|
|
2378
|
+
}
|
|
2379
|
+
|
|
2380
|
+
void cuda_ipc_close_mem_handle(void* ptr) {
|
|
2381
|
+
check_cu(cuIpcCloseMemHandle_f((CUdeviceptr) ptr));
|
|
2382
|
+
}
|
|
2383
|
+
|
|
2384
|
+
void cuda_ipc_get_event_handle(void* context, void* event, char* out_buffer) {
|
|
2385
|
+
ContextGuard guard(context);
|
|
2386
|
+
|
|
2387
|
+
CUipcEventHandle eventHandle;
|
|
2388
|
+
check_cu(cuIpcGetEventHandle_f(&eventHandle, static_cast<CUevent>(event)));
|
|
2389
|
+
memcpy(out_buffer, eventHandle.reserved, CU_IPC_HANDLE_SIZE);
|
|
2390
|
+
}
|
|
2391
|
+
|
|
2392
|
+
void* cuda_ipc_open_event_handle(void* context, char* handle) {
|
|
2393
|
+
ContextGuard guard(context);
|
|
2394
|
+
|
|
2395
|
+
CUipcEventHandle eventHandle;
|
|
2396
|
+
memcpy(eventHandle.reserved, handle, CU_IPC_HANDLE_SIZE);
|
|
2397
|
+
|
|
2398
|
+
CUevent event;
|
|
2399
|
+
|
|
2400
|
+
if (check_cu(cuIpcOpenEventHandle_f(&event, eventHandle)))
|
|
2401
|
+
return event;
|
|
2402
|
+
else
|
|
2403
|
+
return NULL;
|
|
2404
|
+
}
|
|
2405
|
+
|
|
2406
|
+
void* cuda_stream_create(void* context, int priority)
|
|
2407
|
+
{
|
|
2408
|
+
ContextGuard guard(context, true);
|
|
2409
|
+
|
|
2410
|
+
CUstream stream;
|
|
2411
|
+
if (check_cu(cuStreamCreateWithPriority_f(&stream, CU_STREAM_DEFAULT, priority)))
|
|
2412
|
+
{
|
|
2413
|
+
cuda_stream_register(WP_CURRENT_CONTEXT, stream);
|
|
2414
|
+
return stream;
|
|
2415
|
+
}
|
|
2416
|
+
else
|
|
2417
|
+
return NULL;
|
|
2418
|
+
}
|
|
2419
|
+
|
|
2420
|
+
void cuda_stream_destroy(void* context, void* stream)
|
|
2421
|
+
{
|
|
2422
|
+
if (!stream)
|
|
2423
|
+
return;
|
|
2424
|
+
|
|
2425
|
+
cuda_stream_unregister(context, stream);
|
|
2426
|
+
|
|
2427
|
+
check_cu(cuStreamDestroy_f(static_cast<CUstream>(stream)));
|
|
2428
|
+
}
|
|
2429
|
+
|
|
2430
|
+
int cuda_stream_query(void* stream)
|
|
2431
|
+
{
|
|
2432
|
+
CUresult res = cuStreamQuery_f(static_cast<CUstream>(stream));
|
|
2433
|
+
|
|
2434
|
+
if ((res != CUDA_SUCCESS) && (res != CUDA_ERROR_NOT_READY))
|
|
2435
|
+
{
|
|
2436
|
+
// Abnormal, print out error
|
|
2437
|
+
check_cu(res);
|
|
2438
|
+
}
|
|
2439
|
+
|
|
2440
|
+
return res;
|
|
2441
|
+
}
|
|
2442
|
+
|
|
2443
|
+
void cuda_stream_register(void* context, void* stream)
|
|
2444
|
+
{
|
|
2445
|
+
if (!stream)
|
|
2446
|
+
return;
|
|
2447
|
+
|
|
2448
|
+
ContextGuard guard(context);
|
|
2449
|
+
|
|
2450
|
+
// populate stream info
|
|
2451
|
+
StreamInfo& stream_info = g_streams[static_cast<CUstream>(stream)];
|
|
2452
|
+
check_cu(cuEventCreate_f(&stream_info.cached_event, CU_EVENT_DISABLE_TIMING));
|
|
2453
|
+
}
|
|
2454
|
+
|
|
2455
|
+
void cuda_stream_unregister(void* context, void* stream)
|
|
2456
|
+
{
|
|
2457
|
+
if (!stream)
|
|
2458
|
+
return;
|
|
2459
|
+
|
|
2460
|
+
CUstream cuda_stream = static_cast<CUstream>(stream);
|
|
2461
|
+
|
|
2462
|
+
StreamInfo* stream_info = get_stream_info(cuda_stream);
|
|
2463
|
+
if (stream_info)
|
|
2464
|
+
{
|
|
2465
|
+
// release stream info
|
|
2466
|
+
check_cu(cuEventDestroy_f(stream_info->cached_event));
|
|
2467
|
+
g_streams.erase(cuda_stream);
|
|
2468
|
+
}
|
|
2469
|
+
|
|
2470
|
+
// make sure we don't leave dangling references to this stream
|
|
2471
|
+
ContextInfo* context_info = get_context_info(context);
|
|
2472
|
+
if (context_info)
|
|
2473
|
+
{
|
|
2474
|
+
if (cuda_stream == context_info->stream)
|
|
2475
|
+
context_info->stream = NULL;
|
|
2476
|
+
}
|
|
2477
|
+
}
|
|
2478
|
+
|
|
2479
|
+
void* cuda_stream_get_current()
|
|
2480
|
+
{
|
|
2481
|
+
return get_current_stream();
|
|
2482
|
+
}
|
|
2483
|
+
|
|
2484
|
+
void cuda_stream_synchronize(void* stream)
|
|
2485
|
+
{
|
|
2486
|
+
check_cu(cuStreamSynchronize_f(static_cast<CUstream>(stream)));
|
|
2487
|
+
}
|
|
2488
|
+
|
|
2489
|
+
void cuda_stream_wait_event(void* stream, void* event)
|
|
2490
|
+
{
|
|
2491
|
+
check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), 0));
|
|
2492
|
+
}
|
|
2493
|
+
|
|
2494
|
+
void cuda_stream_wait_stream(void* stream, void* other_stream, void* event)
|
|
2495
|
+
{
|
|
2496
|
+
check_cu(cuEventRecord_f(static_cast<CUevent>(event), static_cast<CUstream>(other_stream)));
|
|
2497
|
+
check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), 0));
|
|
2498
|
+
}
|
|
2499
|
+
|
|
2500
|
+
int cuda_stream_is_capturing(void* stream)
|
|
2501
|
+
{
|
|
2502
|
+
cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
|
|
2503
|
+
check_cuda(cudaStreamIsCapturing(static_cast<cudaStream_t>(stream), &status));
|
|
2504
|
+
|
|
2505
|
+
return int(status != cudaStreamCaptureStatusNone);
|
|
2506
|
+
}
|
|
2507
|
+
|
|
2508
|
+
uint64_t cuda_stream_get_capture_id(void* stream)
|
|
2509
|
+
{
|
|
2510
|
+
return get_capture_id(static_cast<CUstream>(stream));
|
|
2511
|
+
}
|
|
2512
|
+
|
|
2513
|
+
int cuda_stream_get_priority(void* stream)
|
|
2514
|
+
{
|
|
2515
|
+
int priority = 0;
|
|
2516
|
+
check_cuda(cuStreamGetPriority_f(static_cast<CUstream>(stream), &priority));
|
|
2517
|
+
|
|
2518
|
+
return priority;
|
|
2519
|
+
}
|
|
2520
|
+
|
|
2521
|
+
void* cuda_event_create(void* context, unsigned flags)
|
|
2522
|
+
{
|
|
2523
|
+
ContextGuard guard(context, true);
|
|
2524
|
+
|
|
2525
|
+
CUevent event;
|
|
2526
|
+
if (check_cu(cuEventCreate_f(&event, flags)))
|
|
2527
|
+
return event;
|
|
2528
|
+
else
|
|
2529
|
+
return NULL;
|
|
2530
|
+
}
|
|
2531
|
+
|
|
2532
|
+
void cuda_event_destroy(void* event)
|
|
2533
|
+
{
|
|
2534
|
+
check_cu(cuEventDestroy_f(static_cast<CUevent>(event)));
|
|
2535
|
+
}
|
|
2536
|
+
|
|
2537
|
+
int cuda_event_query(void* event)
|
|
2538
|
+
{
|
|
2539
|
+
CUresult res = cuEventQuery_f(static_cast<CUevent>(event));
|
|
2540
|
+
|
|
2541
|
+
if ((res != CUDA_SUCCESS) && (res != CUDA_ERROR_NOT_READY))
|
|
2542
|
+
{
|
|
2543
|
+
// Abnormal, print out error
|
|
2544
|
+
check_cu(res);
|
|
2545
|
+
}
|
|
2546
|
+
|
|
2547
|
+
return res;
|
|
2548
|
+
}
|
|
2549
|
+
|
|
2550
|
+
void cuda_event_record(void* event, void* stream, bool timing)
|
|
2551
|
+
{
|
|
2552
|
+
if (timing && !g_captures.empty() && cuda_stream_is_capturing(stream))
|
|
2553
|
+
{
|
|
2554
|
+
// record timing event during graph capture
|
|
2555
|
+
check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(stream), CU_EVENT_RECORD_EXTERNAL));
|
|
2556
|
+
}
|
|
2557
|
+
else
|
|
2558
|
+
{
|
|
2559
|
+
check_cu(cuEventRecord_f(static_cast<CUevent>(event), static_cast<CUstream>(stream)));
|
|
2560
|
+
}
|
|
2561
|
+
}
|
|
2562
|
+
|
|
2563
|
+
void cuda_event_synchronize(void* event)
|
|
2564
|
+
{
|
|
2565
|
+
check_cu(cuEventSynchronize_f(static_cast<CUevent>(event)));
|
|
2566
|
+
}
|
|
2567
|
+
|
|
2568
|
+
float cuda_event_elapsed_time(void* start_event, void* end_event)
|
|
2569
|
+
{
|
|
2570
|
+
float elapsed = 0.0f;
|
|
2571
|
+
cudaEvent_t start = static_cast<cudaEvent_t>(start_event);
|
|
2572
|
+
cudaEvent_t end = static_cast<cudaEvent_t>(end_event);
|
|
2573
|
+
check_cuda(cudaEventElapsedTime(&elapsed, start, end));
|
|
2574
|
+
return elapsed;
|
|
2575
|
+
}
|
|
2576
|
+
|
|
2577
|
+
bool cuda_graph_begin_capture(void* context, void* stream, int external)
|
|
2578
|
+
{
|
|
2579
|
+
ContextGuard guard(context);
|
|
2580
|
+
|
|
2581
|
+
CUstream cuda_stream = static_cast<CUstream>(stream);
|
|
2582
|
+
StreamInfo* stream_info = get_stream_info(cuda_stream);
|
|
2583
|
+
if (!stream_info)
|
|
2584
|
+
{
|
|
2585
|
+
wp::set_error_string("Warp error: unknown stream");
|
|
2586
|
+
return false;
|
|
2587
|
+
}
|
|
2588
|
+
|
|
2589
|
+
if (external)
|
|
2590
|
+
{
|
|
2591
|
+
// if it's an external capture, make sure it's already active so we can get the capture id
|
|
2592
|
+
cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
|
|
2593
|
+
if (!check_cuda(cudaStreamIsCapturing(cuda_stream, &status)))
|
|
2594
|
+
return false;
|
|
2595
|
+
if (status != cudaStreamCaptureStatusActive)
|
|
2596
|
+
{
|
|
2597
|
+
wp::set_error_string("Warp error: stream is not capturing");
|
|
2598
|
+
return false;
|
|
2599
|
+
}
|
|
2600
|
+
}
|
|
2601
|
+
else
|
|
2602
|
+
{
|
|
2603
|
+
// start the capture
|
|
2604
|
+
if (!check_cuda(cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeGlobal)))
|
|
2605
|
+
return false;
|
|
2606
|
+
}
|
|
2607
|
+
|
|
2608
|
+
uint64_t capture_id = get_capture_id(cuda_stream);
|
|
2609
|
+
|
|
2610
|
+
CaptureInfo* capture = new CaptureInfo();
|
|
2611
|
+
capture->stream = cuda_stream;
|
|
2612
|
+
capture->id = capture_id;
|
|
2613
|
+
capture->external = bool(external);
|
|
2614
|
+
|
|
2615
|
+
// update stream info
|
|
2616
|
+
stream_info->capture = capture;
|
|
2617
|
+
|
|
2618
|
+
// add to known captures
|
|
2619
|
+
g_captures[capture_id] = capture;
|
|
2620
|
+
|
|
2621
|
+
return true;
|
|
2622
|
+
}
|
|
2623
|
+
|
|
2624
|
+
bool cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
|
|
2625
|
+
{
|
|
2626
|
+
ContextGuard guard(context);
|
|
2627
|
+
|
|
2628
|
+
// check if this is a known stream
|
|
2629
|
+
CUstream cuda_stream = static_cast<CUstream>(stream);
|
|
2630
|
+
StreamInfo* stream_info = get_stream_info(cuda_stream);
|
|
2631
|
+
if (!stream_info)
|
|
2632
|
+
{
|
|
2633
|
+
wp::set_error_string("Warp error: unknown capture stream");
|
|
2634
|
+
return false;
|
|
2635
|
+
}
|
|
2636
|
+
|
|
2637
|
+
// check if this stream was used to start a capture
|
|
2638
|
+
CaptureInfo* capture = stream_info->capture;
|
|
2639
|
+
if (!capture)
|
|
2640
|
+
{
|
|
2641
|
+
wp::set_error_string("Warp error: stream has no capture started");
|
|
2642
|
+
return false;
|
|
2643
|
+
}
|
|
2644
|
+
|
|
2645
|
+
// get capture info
|
|
2646
|
+
bool external = capture->external;
|
|
2647
|
+
uint64_t capture_id = capture->id;
|
|
2648
|
+
|
|
2649
|
+
// clear capture info
|
|
2650
|
+
stream_info->capture = NULL;
|
|
2651
|
+
g_captures.erase(capture_id);
|
|
2652
|
+
delete capture;
|
|
2653
|
+
|
|
2654
|
+
// a lambda to clean up on exit in case of error
|
|
2655
|
+
auto clean_up = [cuda_stream, capture_id, external]()
|
|
2656
|
+
{
|
|
2657
|
+
// unreference outstanding graph allocs so that they will be released with the user reference
|
|
2658
|
+
for (auto it = g_graph_allocs.begin(); it != g_graph_allocs.end(); ++it)
|
|
2659
|
+
{
|
|
2660
|
+
GraphAllocInfo& alloc_info = it->second;
|
|
2661
|
+
if (alloc_info.capture_id == capture_id)
|
|
2662
|
+
alloc_info.graph_destroyed = true;
|
|
2663
|
+
}
|
|
2664
|
+
|
|
2665
|
+
// make sure we terminate the capture
|
|
2666
|
+
if (!external)
|
|
2667
|
+
{
|
|
2668
|
+
cudaGraph_t graph = NULL;
|
|
2669
|
+
cudaStreamEndCapture(cuda_stream, &graph);
|
|
2670
|
+
cudaGetLastError();
|
|
2671
|
+
}
|
|
2672
|
+
};
|
|
2673
|
+
|
|
2674
|
+
// get captured graph without ending the capture in case it is external
|
|
2675
|
+
cudaGraph_t graph = get_capture_graph(cuda_stream);
|
|
2676
|
+
if (!graph)
|
|
2677
|
+
{
|
|
2678
|
+
clean_up();
|
|
2679
|
+
return false;
|
|
2680
|
+
}
|
|
2681
|
+
|
|
2682
|
+
// ensure that all forked streams are joined to the main capture stream by manually
|
|
2683
|
+
// adding outstanding capture dependencies gathered from the graph leaf nodes
|
|
2684
|
+
std::vector<cudaGraphNode_t> stream_dependencies;
|
|
2685
|
+
std::vector<cudaGraphNode_t> leaf_nodes;
|
|
2686
|
+
if (get_capture_dependencies(cuda_stream, stream_dependencies) && get_graph_leaf_nodes(graph, leaf_nodes))
|
|
2687
|
+
{
|
|
2688
|
+
// compute set difference to get unjoined dependencies
|
|
2689
|
+
std::vector<cudaGraphNode_t> unjoined_dependencies;
|
|
2690
|
+
std::sort(stream_dependencies.begin(), stream_dependencies.end());
|
|
2691
|
+
std::sort(leaf_nodes.begin(), leaf_nodes.end());
|
|
2692
|
+
std::set_difference(leaf_nodes.begin(), leaf_nodes.end(),
|
|
2693
|
+
stream_dependencies.begin(), stream_dependencies.end(),
|
|
2694
|
+
std::back_inserter(unjoined_dependencies));
|
|
2695
|
+
if (!unjoined_dependencies.empty())
|
|
2696
|
+
{
|
|
2697
|
+
check_cu(cuStreamUpdateCaptureDependencies_f(cuda_stream, unjoined_dependencies.data(), unjoined_dependencies.size(),
|
|
2698
|
+
CU_STREAM_ADD_CAPTURE_DEPENDENCIES));
|
|
2699
|
+
// ensure graph is still valid
|
|
2700
|
+
if (get_capture_graph(cuda_stream) != graph)
|
|
2701
|
+
{
|
|
2702
|
+
clean_up();
|
|
2703
|
+
return false;
|
|
2704
|
+
}
|
|
2705
|
+
}
|
|
2706
|
+
}
|
|
2707
|
+
|
|
2708
|
+
// check if this graph has unfreed allocations, which require special handling
|
|
2709
|
+
std::vector<void*> unfreed_allocs;
|
|
2710
|
+
for (auto it = g_graph_allocs.begin(); it != g_graph_allocs.end(); ++it)
|
|
2711
|
+
{
|
|
2712
|
+
GraphAllocInfo& alloc_info = it->second;
|
|
2713
|
+
if (alloc_info.capture_id == capture_id)
|
|
2714
|
+
unfreed_allocs.push_back(it->first);
|
|
2715
|
+
}
|
|
2716
|
+
|
|
2717
|
+
if (!unfreed_allocs.empty())
|
|
2718
|
+
{
|
|
2719
|
+
// Create a user object that will notify us when the instantiated graph is destroyed.
|
|
2720
|
+
// This works for external captures also, since we wouldn't otherwise know when
|
|
2721
|
+
// the externally-created graph instance gets deleted.
|
|
2722
|
+
// This callback is guaranteed to arrive after the graph has finished executing on the device,
|
|
2723
|
+
// not necessarily when cudaGraphExecDestroy() is called.
|
|
2724
|
+
GraphInfo* graph_info = new GraphInfo;
|
|
2725
|
+
graph_info->unfreed_allocs = unfreed_allocs;
|
|
2726
|
+
cudaUserObject_t user_object;
|
|
2727
|
+
check_cuda(cudaUserObjectCreate(&user_object, graph_info, on_graph_destroy, 1, cudaUserObjectNoDestructorSync));
|
|
2728
|
+
check_cuda(cudaGraphRetainUserObject(graph, user_object, 1, cudaGraphUserObjectMove));
|
|
2729
|
+
|
|
2730
|
+
// ensure graph is still valid
|
|
2731
|
+
if (get_capture_graph(cuda_stream) != graph)
|
|
2732
|
+
{
|
|
2733
|
+
clean_up();
|
|
2734
|
+
return false;
|
|
2735
|
+
}
|
|
2736
|
+
}
|
|
2737
|
+
|
|
2738
|
+
// for external captures, we don't instantiate the graph ourselves, so we're done
|
|
2739
|
+
if (external)
|
|
2740
|
+
return true;
|
|
2741
|
+
|
|
2742
|
+
cudaGraphExec_t graph_exec = NULL;
|
|
2743
|
+
|
|
2744
|
+
// end the capture
|
|
2745
|
+
if (!check_cuda(cudaStreamEndCapture(cuda_stream, &graph)))
|
|
2746
|
+
return false;
|
|
2747
|
+
|
|
2748
|
+
// enable to create debug GraphVis visualization of graph
|
|
2749
|
+
// cudaGraphDebugDotPrint(graph, "graph.dot", cudaGraphDebugDotFlagsVerbose);
|
|
2750
|
+
|
|
2751
|
+
// can use after CUDA 11.4 to permit graphs to capture cudaMallocAsync() operations
|
|
2752
|
+
if (!check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch)))
|
|
2753
|
+
return false;
|
|
2754
|
+
|
|
2755
|
+
// free source graph
|
|
2756
|
+
check_cuda(cudaGraphDestroy(graph));
|
|
2757
|
+
|
|
2758
|
+
// process deferred free list if no more captures are ongoing
|
|
2759
|
+
if (g_captures.empty())
|
|
2760
|
+
{
|
|
2761
|
+
free_deferred_allocs();
|
|
2762
|
+
unload_deferred_modules();
|
|
2763
|
+
}
|
|
2764
|
+
|
|
2765
|
+
if (graph_ret)
|
|
2766
|
+
*graph_ret = graph_exec;
|
|
2767
|
+
|
|
2768
|
+
return true;
|
|
2769
|
+
}
|
|
2770
|
+
|
|
2771
|
+
bool cuda_graph_launch(void* graph_exec, void* stream)
|
|
2772
|
+
{
|
|
2773
|
+
// TODO: allow naming graphs?
|
|
2774
|
+
begin_cuda_range(WP_TIMING_GRAPH, stream, get_stream_context(stream), "graph");
|
|
2775
|
+
|
|
2776
|
+
bool result = check_cuda(cudaGraphLaunch((cudaGraphExec_t)graph_exec, (cudaStream_t)stream));
|
|
2777
|
+
|
|
2778
|
+
end_cuda_range(WP_TIMING_GRAPH, stream);
|
|
2779
|
+
|
|
2780
|
+
return result;
|
|
2781
|
+
}
|
|
2782
|
+
|
|
2783
|
+
bool cuda_graph_destroy(void* context, void* graph_exec)
|
|
2784
|
+
{
|
|
2785
|
+
ContextGuard guard(context);
|
|
2786
|
+
|
|
2787
|
+
return check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
|
|
2788
|
+
}
|
|
2789
|
+
|
|
2790
|
+
// Write `size` bytes starting at `data` to `filename`, opened with the fopen() `mode` string.
//
// When the WARP_DEBUG environment variable is set, the write is logged to stdout.
// Returns true only if the file was opened, every byte was written, and the file was closed
// successfully; every failure is reported on stderr and returns false. The file handle is
// closed on all paths once it has been opened.
bool write_file(const char* data, size_t size, std::string filename, const char* mode)
{
    const bool print_debug = (std::getenv("WARP_DEBUG") != nullptr);
    if (print_debug)
    {
        printf("Writing %zu B to %s (%s)\n", size, filename.c_str(), mode);
    }

    FILE* file = fopen(filename.c_str(), mode);
    if (!file)
    {
        fprintf(stderr, "Warp error: Failed to open file '%s'\n", filename.c_str());
        return false;
    }

    const bool write_ok = (fwrite(data, 1, size, file) == size);
    // always close the handle (the old code leaked it when fwrite failed); fclose() also
    // flushes buffered data, so its failure means the output may be incomplete
    const bool close_ok = (fclose(file) == 0);

    if (!write_ok || !close_ok)
    {
        fprintf(stderr, "Warp error: Failed to write to output file '%s'\n", filename.c_str());
        return false;
    }

    return true;
}
|
|
2813
|
+
|
|
2814
|
+
#if WP_ENABLE_MATHDX
|
|
2815
|
+
// Report an nvJitLink failure and return whether `result` indicates success.
//
// On failure, prints the numeric result code with the caller's file/line to stderr, followed by
// the linker's error log for `handle` when one can be retrieved and is non-empty.
bool check_nvjitlink_result(nvJitLinkHandle handle, nvJitLinkResult result, const char* file, int line)
{
    if (result == NVJITLINK_SUCCESS)
        return true;

    fprintf(stderr, "nvJitLink error: %d on %s:%d\n", (int)result, file, line);

    // best effort: append the linker's own error log
    size_t log_length = 0;
    if (nvJitLinkGetErrorLogSize(handle, &log_length) != NVJITLINK_SUCCESS || log_length == 0)
        return false;

    std::vector<char> log_text(log_length);
    if (nvJitLinkGetErrorLog(handle, log_text.data()) == NVJITLINK_SUCCESS)
        fprintf(stderr, "%s\n", log_text.data());

    return false;
}
|
|
2833
|
+
#endif
|
|
2834
|
+
|
|
2835
|
+
size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes, int* ltoir_input_types)
|
|
2836
|
+
{
|
|
2837
|
+
// use file extension to determine whether to output PTX or CUBIN
|
|
2838
|
+
const char* output_ext = strrchr(output_path, '.');
|
|
2839
|
+
bool use_ptx = output_ext && strcmp(output_ext + 1, "ptx") == 0;
|
|
2840
|
+
const bool print_debug = (std::getenv("WARP_DEBUG") != nullptr);
|
|
2841
|
+
|
|
2842
|
+
// check include dir path len (path + option)
|
|
2843
|
+
const int max_path = 4096 + 16;
|
|
2844
|
+
if (strlen(include_dir) > max_path)
|
|
2845
|
+
{
|
|
2846
|
+
fprintf(stderr, "Warp error: Include path too long\n");
|
|
2847
|
+
return size_t(-1);
|
|
2848
|
+
}
|
|
2849
|
+
|
|
2850
|
+
if (print_debug)
|
|
2851
|
+
{
|
|
2852
|
+
// Not available in all nvJitLink versions
|
|
2853
|
+
// unsigned major = 0;
|
|
2854
|
+
// unsigned minor = 0;
|
|
2855
|
+
// nvJitLinkVersion(&major, &minor);
|
|
2856
|
+
// printf("nvJitLink version %d.%d\n", major, minor);
|
|
2857
|
+
int major = 0;
|
|
2858
|
+
int minor = 0;
|
|
2859
|
+
nvrtcVersion(&major, &minor);
|
|
2860
|
+
printf("NVRTC version %d.%d\n", major, minor);
|
|
2861
|
+
}
|
|
2862
|
+
|
|
2863
|
+
char include_opt[max_path];
|
|
2864
|
+
strcpy(include_opt, "--include-path=");
|
|
2865
|
+
strcat(include_opt, include_dir);
|
|
2866
|
+
|
|
2867
|
+
const int max_arch = 128;
|
|
2868
|
+
char arch_opt[max_arch];
|
|
2869
|
+
char arch_opt_lto[max_arch];
|
|
2870
|
+
|
|
2871
|
+
if (use_ptx)
|
|
2872
|
+
{
|
|
2873
|
+
snprintf(arch_opt, max_arch, "--gpu-architecture=compute_%d", arch);
|
|
2874
|
+
snprintf(arch_opt_lto, max_arch, "-arch=compute_%d", arch);
|
|
2875
|
+
}
|
|
2876
|
+
else
|
|
2877
|
+
{
|
|
2878
|
+
snprintf(arch_opt, max_arch, "--gpu-architecture=sm_%d", arch);
|
|
2879
|
+
snprintf(arch_opt_lto, max_arch, "-arch=sm_%d", arch);
|
|
2880
|
+
}
|
|
2881
|
+
|
|
2882
|
+
std::vector<const char*> opts;
|
|
2883
|
+
opts.push_back(arch_opt);
|
|
2884
|
+
opts.push_back(include_opt);
|
|
2885
|
+
opts.push_back("--std=c++17");
|
|
2886
|
+
|
|
2887
|
+
if (debug)
|
|
2888
|
+
{
|
|
2889
|
+
opts.push_back("--define-macro=_DEBUG");
|
|
2890
|
+
opts.push_back("--generate-line-info");
|
|
2891
|
+
|
|
2892
|
+
// disabling since it causes issues with `Unresolved extern function 'cudaGetParameterBufferV2'
|
|
2893
|
+
//opts.push_back("--device-debug");
|
|
2894
|
+
}
|
|
2895
|
+
else
|
|
2896
|
+
{
|
|
2897
|
+
opts.push_back("--define-macro=NDEBUG");
|
|
2898
|
+
|
|
2899
|
+
if (lineinfo)
|
|
2900
|
+
opts.push_back("--generate-line-info");
|
|
2901
|
+
}
|
|
2902
|
+
|
|
2903
|
+
if (verify_fp)
|
|
2904
|
+
opts.push_back("--define-macro=WP_VERIFY_FP");
|
|
2905
|
+
else
|
|
2906
|
+
opts.push_back("--undefine-macro=WP_VERIFY_FP");
|
|
2907
|
+
|
|
2908
|
+
#if WP_ENABLE_MATHDX
|
|
2909
|
+
opts.push_back("--define-macro=WP_ENABLE_MATHDX=1");
|
|
2910
|
+
#else
|
|
2911
|
+
opts.push_back("--define-macro=WP_ENABLE_MATHDX=0");
|
|
2912
|
+
#endif
|
|
2913
|
+
|
|
2914
|
+
if (fast_math)
|
|
2915
|
+
opts.push_back("--use_fast_math");
|
|
2916
|
+
|
|
2917
|
+
if (fuse_fp)
|
|
2918
|
+
opts.push_back("--fmad=true");
|
|
2919
|
+
else
|
|
2920
|
+
opts.push_back("--fmad=false");
|
|
2921
|
+
|
|
2922
|
+
std::vector<std::string> cuda_include_opt;
|
|
2923
|
+
for(int i = 0; i < num_cuda_include_dirs; i++)
|
|
2924
|
+
{
|
|
2925
|
+
cuda_include_opt.push_back(std::string("--include-path=") + cuda_include_dirs[i]);
|
|
2926
|
+
opts.push_back(cuda_include_opt.back().c_str());
|
|
2927
|
+
}
|
|
2928
|
+
|
|
2929
|
+
opts.push_back("--device-as-default-execution-space");
|
|
2930
|
+
opts.push_back("--extra-device-vectorization");
|
|
2931
|
+
opts.push_back("--restrict");
|
|
2932
|
+
|
|
2933
|
+
if (num_ltoirs > 0)
|
|
2934
|
+
{
|
|
2935
|
+
opts.push_back("-dlto");
|
|
2936
|
+
opts.push_back("--relocatable-device-code=true");
|
|
2937
|
+
}
|
|
2938
|
+
|
|
2939
|
+
nvrtcProgram prog;
|
|
2940
|
+
nvrtcResult res;
|
|
2941
|
+
|
|
2942
|
+
res = nvrtcCreateProgram(
|
|
2943
|
+
&prog, // prog
|
|
2944
|
+
cuda_src, // buffer
|
|
2945
|
+
program_name, // name
|
|
2946
|
+
0, // numHeaders
|
|
2947
|
+
NULL, // headers
|
|
2948
|
+
NULL); // includeNames
|
|
2949
|
+
|
|
2950
|
+
if (!check_nvrtc(res))
|
|
2951
|
+
return size_t(res);
|
|
2952
|
+
|
|
2953
|
+
if (print_debug)
|
|
2954
|
+
{
|
|
2955
|
+
printf("NVRTC options:\n");
|
|
2956
|
+
for(auto o: opts) {
|
|
2957
|
+
printf("%s\n", o);
|
|
2958
|
+
}
|
|
2959
|
+
}
|
|
2960
|
+
res = nvrtcCompileProgram(prog, int(opts.size()), opts.data());
|
|
2961
|
+
|
|
2962
|
+
if (!check_nvrtc(res) || verbose)
|
|
2963
|
+
{
|
|
2964
|
+
// get program log
|
|
2965
|
+
size_t log_size;
|
|
2966
|
+
if (check_nvrtc(nvrtcGetProgramLogSize(prog, &log_size)))
|
|
2967
|
+
{
|
|
2968
|
+
std::vector<char> log(log_size);
|
|
2969
|
+
if (check_nvrtc(nvrtcGetProgramLog(prog, log.data())))
|
|
2970
|
+
{
|
|
2971
|
+
// todo: figure out better way to return this to python
|
|
2972
|
+
if (res != NVRTC_SUCCESS)
|
|
2973
|
+
fprintf(stderr, "%s", log.data());
|
|
2974
|
+
else
|
|
2975
|
+
fprintf(stdout, "%s", log.data());
|
|
2976
|
+
}
|
|
2977
|
+
}
|
|
2978
|
+
|
|
2979
|
+
if (res != NVRTC_SUCCESS)
|
|
2980
|
+
{
|
|
2981
|
+
nvrtcDestroyProgram(&prog);
|
|
2982
|
+
return size_t(res);
|
|
2983
|
+
}
|
|
2984
|
+
}
|
|
2985
|
+
|
|
2986
|
+
nvrtcResult (*get_output_size)(nvrtcProgram, size_t*);
|
|
2987
|
+
nvrtcResult (*get_output_data)(nvrtcProgram, char*);
|
|
2988
|
+
const char* output_mode;
|
|
2989
|
+
if(num_ltoirs > 0) {
|
|
2990
|
+
#if WP_ENABLE_MATHDX
|
|
2991
|
+
get_output_size = nvrtcGetLTOIRSize;
|
|
2992
|
+
get_output_data = nvrtcGetLTOIR;
|
|
2993
|
+
output_mode = "wb";
|
|
2994
|
+
#else
|
|
2995
|
+
fprintf(stderr, "Warp error: num_ltoirs > 0 but Warp was not built with MathDx support\n");
|
|
2996
|
+
return size_t(-1);
|
|
2997
|
+
#endif
|
|
2998
|
+
}
|
|
2999
|
+
else if (use_ptx)
|
|
3000
|
+
{
|
|
3001
|
+
get_output_size = nvrtcGetPTXSize;
|
|
3002
|
+
get_output_data = nvrtcGetPTX;
|
|
3003
|
+
output_mode = "wt";
|
|
3004
|
+
}
|
|
3005
|
+
else
|
|
3006
|
+
{
|
|
3007
|
+
get_output_size = nvrtcGetCUBINSize;
|
|
3008
|
+
get_output_data = nvrtcGetCUBIN;
|
|
3009
|
+
output_mode = "wb";
|
|
3010
|
+
}
|
|
3011
|
+
|
|
3012
|
+
// save output
|
|
3013
|
+
size_t output_size;
|
|
3014
|
+
res = get_output_size(prog, &output_size);
|
|
3015
|
+
if (check_nvrtc(res))
|
|
3016
|
+
{
|
|
3017
|
+
std::vector<char> output(output_size);
|
|
3018
|
+
res = get_output_data(prog, output.data());
|
|
3019
|
+
if (check_nvrtc(res))
|
|
3020
|
+
{
|
|
3021
|
+
|
|
3022
|
+
// LTOIR case - need an extra step
|
|
3023
|
+
if (num_ltoirs > 0)
|
|
3024
|
+
{
|
|
3025
|
+
#if WP_ENABLE_MATHDX
|
|
3026
|
+
if(ltoir_input_types == nullptr || ltoirs == nullptr || ltoir_sizes == nullptr) {
|
|
3027
|
+
fprintf(stderr, "Warp error: num_ltoirs > 0 but ltoir_input_types, ltoirs or ltoir_sizes are NULL\n");
|
|
3028
|
+
return size_t(-1);
|
|
3029
|
+
}
|
|
3030
|
+
nvJitLinkHandle handle;
|
|
3031
|
+
std::vector<const char *> lopts = {"-dlto", arch_opt_lto};
|
|
3032
|
+
if (use_ptx) {
|
|
3033
|
+
lopts.push_back("-ptx");
|
|
3034
|
+
}
|
|
3035
|
+
if (print_debug)
|
|
3036
|
+
{
|
|
3037
|
+
printf("nvJitLink options:\n");
|
|
3038
|
+
for(auto o: lopts) {
|
|
3039
|
+
printf("%s\n", o);
|
|
3040
|
+
}
|
|
3041
|
+
}
|
|
3042
|
+
if(!check_nvjitlink(handle, nvJitLinkCreate(&handle, lopts.size(), lopts.data())))
|
|
3043
|
+
{
|
|
3044
|
+
res = nvrtcResult(-1);
|
|
3045
|
+
}
|
|
3046
|
+
// Links
|
|
3047
|
+
if(std::getenv("WARP_DUMP_LTOIR"))
|
|
3048
|
+
{
|
|
3049
|
+
write_file(output.data(), output.size(), "nvrtc_output.ltoir", "wb");
|
|
3050
|
+
}
|
|
3051
|
+
if(!check_nvjitlink(handle, nvJitLinkAddData(handle, NVJITLINK_INPUT_LTOIR, output.data(), output.size(), "nvrtc_output"))) // NVRTC business
|
|
3052
|
+
{
|
|
3053
|
+
res = nvrtcResult(-1);
|
|
3054
|
+
}
|
|
3055
|
+
for(size_t ltoidx = 0; ltoidx < num_ltoirs; ltoidx++)
|
|
3056
|
+
{
|
|
3057
|
+
nvJitLinkInputType input_type = static_cast<nvJitLinkInputType>(ltoir_input_types[ltoidx]);
|
|
3058
|
+
const char* ext = ".unknown";
|
|
3059
|
+
switch(input_type) {
|
|
3060
|
+
case NVJITLINK_INPUT_CUBIN:
|
|
3061
|
+
ext = ".cubin";
|
|
3062
|
+
break;
|
|
3063
|
+
case NVJITLINK_INPUT_LTOIR:
|
|
3064
|
+
ext = ".ltoir";
|
|
3065
|
+
break;
|
|
3066
|
+
case NVJITLINK_INPUT_FATBIN:
|
|
3067
|
+
ext = ".fatbin";
|
|
3068
|
+
break;
|
|
3069
|
+
default:
|
|
3070
|
+
break;
|
|
3071
|
+
}
|
|
3072
|
+
if(std::getenv("WARP_DUMP_LTOIR"))
|
|
3073
|
+
{
|
|
3074
|
+
write_file(ltoirs[ltoidx], ltoir_sizes[ltoidx], std::string("lto_online_") + std::to_string(ltoidx) + ext, "wb");
|
|
3075
|
+
}
|
|
3076
|
+
if(!check_nvjitlink(handle, nvJitLinkAddData(handle, input_type, ltoirs[ltoidx], ltoir_sizes[ltoidx], "lto_online"))) // External LTOIR
|
|
3077
|
+
{
|
|
3078
|
+
res = nvrtcResult(-1);
|
|
3079
|
+
}
|
|
3080
|
+
}
|
|
3081
|
+
if(!check_nvjitlink(handle, nvJitLinkComplete(handle)))
|
|
3082
|
+
{
|
|
3083
|
+
res = nvrtcResult(-1);
|
|
3084
|
+
}
|
|
3085
|
+
else
|
|
3086
|
+
{
|
|
3087
|
+
if(use_ptx)
|
|
3088
|
+
{
|
|
3089
|
+
size_t ptx_size = 0;
|
|
3090
|
+
check_nvjitlink(handle, nvJitLinkGetLinkedPtxSize(handle, &ptx_size));
|
|
3091
|
+
std::vector<char> ptx(ptx_size);
|
|
3092
|
+
check_nvjitlink(handle, nvJitLinkGetLinkedPtx(handle, ptx.data()));
|
|
3093
|
+
output = ptx;
|
|
3094
|
+
}
|
|
3095
|
+
else
|
|
3096
|
+
{
|
|
3097
|
+
size_t cubin_size = 0;
|
|
3098
|
+
check_nvjitlink(handle, nvJitLinkGetLinkedCubinSize(handle, &cubin_size));
|
|
3099
|
+
std::vector<char> cubin(cubin_size);
|
|
3100
|
+
check_nvjitlink(handle, nvJitLinkGetLinkedCubin(handle, cubin.data()));
|
|
3101
|
+
output = cubin;
|
|
3102
|
+
}
|
|
3103
|
+
}
|
|
3104
|
+
check_nvjitlink(handle, nvJitLinkDestroy(&handle));
|
|
3105
|
+
#else
|
|
3106
|
+
fprintf(stderr, "Warp error: num_ltoirs > 0 but Warp was not built with MathDx support\n");
|
|
3107
|
+
return size_t(-1);
|
|
3108
|
+
#endif
|
|
3109
|
+
}
|
|
3110
|
+
|
|
3111
|
+
if(!write_file(output.data(), output.size(), output_path, output_mode)) {
|
|
3112
|
+
res = nvrtcResult(-1);
|
|
3113
|
+
}
|
|
3114
|
+
}
|
|
3115
|
+
}
|
|
3116
|
+
|
|
3117
|
+
check_nvrtc(nvrtcDestroyProgram(&prog));
|
|
3118
|
+
|
|
3119
|
+
return res;
|
|
3120
|
+
}
|
|
3121
|
+
|
|
3122
|
+
#if WP_ENABLE_MATHDX
|
|
3123
|
+
// Return whether a libmathdx cuFFTDx call succeeded; on failure, log the status code together
// with the caller's file and line to stderr.
bool check_cufftdx_result(commondxStatusType result, const char* file, int line)
{
    const bool ok = (result == commondxStatusType::COMMONDX_SUCCESS);
    if (!ok)
        fprintf(stderr, "libmathdx cuFFTDx error: %d on %s:%d\n", (int)result, file, line);
    return ok;
}
|
|
3132
|
+
|
|
3133
|
+
// Return whether a libmathdx cuBLASDx call succeeded; on failure, log the status code together
// with the caller's file and line to stderr.
bool check_cublasdx_result(commondxStatusType result, const char* file, int line)
{
    const bool ok = (result == commondxStatusType::COMMONDX_SUCCESS);
    if (!ok)
        fprintf(stderr, "libmathdx cuBLASDx error: %d on %s:%d\n", (int)result, file, line);
    return ok;
}
|
|
3142
|
+
|
|
3143
|
+
// Return whether a libmathdx cuSOLVER call succeeded; on failure, log the status code together
// with the caller's file and line to stderr.
bool check_cusolver_result(commondxStatusType result, const char* file, int line)
{
    const bool ok = (result == commondxStatusType::COMMONDX_SUCCESS);
    if (!ok)
        fprintf(stderr, "libmathdx cuSOLVER error: %d on %s:%d\n", (int)result, file, line);
    return ok;
}
|
|
3152
|
+
|
|
3153
|
+
bool cuda_compile_fft(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int size, int elements_per_thread, int direction, int precision, int* shared_memory_size)
|
|
3154
|
+
{
|
|
3155
|
+
|
|
3156
|
+
CHECK_ANY(ltoir_output_path != nullptr);
|
|
3157
|
+
CHECK_ANY(symbol_name != nullptr);
|
|
3158
|
+
CHECK_ANY(shared_memory_size != nullptr);
|
|
3159
|
+
// Includes currently unused
|
|
3160
|
+
CHECK_ANY(include_dirs == nullptr);
|
|
3161
|
+
CHECK_ANY(mathdx_include_dir == nullptr);
|
|
3162
|
+
CHECK_ANY(num_include_dirs == 0);
|
|
3163
|
+
|
|
3164
|
+
bool res = true;
|
|
3165
|
+
cufftdxHandle h;
|
|
3166
|
+
CHECK_CUFFTDX(cufftdxCreate(&h));
|
|
3167
|
+
|
|
3168
|
+
// CUFFTDX_API_BLOCK_LMEM means each thread starts with a subset of the data
|
|
3169
|
+
CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_API, cufftdxApi::CUFFTDX_API_BLOCK_LMEM));
|
|
3170
|
+
CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
|
|
3171
|
+
CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_SIZE, (long long)size));
|
|
3172
|
+
CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_DIRECTION, (cufftdxDirection)direction));
|
|
3173
|
+
CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_PRECISION, (commondxPrecision)precision));
|
|
3174
|
+
CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_SM, (long long)(arch * 10)));
|
|
3175
|
+
CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_ELEMENTS_PER_THREAD, (long long)(elements_per_thread)));
|
|
3176
|
+
CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_FFTS_PER_BLOCK, 1));
|
|
3177
|
+
|
|
3178
|
+
CHECK_CUFFTDX(cufftdxSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
|
|
3179
|
+
|
|
3180
|
+
size_t lto_size = 0;
|
|
3181
|
+
CHECK_CUFFTDX(cufftdxGetLTOIRSize(h, <o_size));
|
|
3182
|
+
|
|
3183
|
+
std::vector<char> lto(lto_size);
|
|
3184
|
+
CHECK_CUFFTDX(cufftdxGetLTOIR(h, lto.size(), lto.data()));
|
|
3185
|
+
|
|
3186
|
+
long long int smem = 0;
|
|
3187
|
+
CHECK_CUFFTDX(cufftdxGetTraitInt64(h, cufftdxTraitType::CUFFTDX_TRAIT_SHARED_MEMORY_SIZE, &smem));
|
|
3188
|
+
*shared_memory_size = (int)smem;
|
|
3189
|
+
|
|
3190
|
+
if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
|
|
3191
|
+
res = false;
|
|
3192
|
+
}
|
|
3193
|
+
|
|
3194
|
+
CHECK_CUFFTDX(cufftdxDestroy(h));
|
|
3195
|
+
|
|
3196
|
+
return res;
|
|
3197
|
+
}
|
|
3198
|
+
|
|
3199
|
+
bool cuda_compile_dot(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int K, int precision_A, int precision_B, int precision_C, int type, int arrangement_A, int arrangement_B, int arrangement_C, int num_threads)
|
|
3200
|
+
{
|
|
3201
|
+
|
|
3202
|
+
CHECK_ANY(ltoir_output_path != nullptr);
|
|
3203
|
+
CHECK_ANY(symbol_name != nullptr);
|
|
3204
|
+
// Includes currently unused
|
|
3205
|
+
CHECK_ANY(include_dirs == nullptr);
|
|
3206
|
+
CHECK_ANY(mathdx_include_dir == nullptr);
|
|
3207
|
+
CHECK_ANY(num_include_dirs == 0);
|
|
3208
|
+
|
|
3209
|
+
bool res = true;
|
|
3210
|
+
cublasdxHandle h;
|
|
3211
|
+
CHECK_CUBLASDX(cublasdxCreate(&h));
|
|
3212
|
+
|
|
3213
|
+
CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_FUNCTION, cublasdxFunction::CUBLASDX_FUNCTION_MM));
|
|
3214
|
+
CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
|
|
3215
|
+
CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_API, cublasdxApi::CUBLASDX_API_BLOCK_SMEM));
|
|
3216
|
+
std::array<long long int, 3> precisions = {precision_A, precision_B, precision_C};
|
|
3217
|
+
CHECK_CUBLASDX(cublasdxSetOperatorInt64Array(h, cublasdxOperatorType::CUBLASDX_OPERATOR_PRECISION, 3, precisions.data()));
|
|
3218
|
+
CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_SM, (long long)(arch * 10)));
|
|
3219
|
+
CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_TYPE, (cublasdxType)type));
|
|
3220
|
+
std::array<long long int, 3> block_dim = {num_threads, 1, 1};
|
|
3221
|
+
CHECK_CUBLASDX(cublasdxSetOperatorInt64Array(h, cublasdxOperatorType::CUBLASDX_OPERATOR_BLOCK_DIM, block_dim.size(), block_dim.data()));
|
|
3222
|
+
std::array<long long int, 3> size = {M, N, K};
|
|
3223
|
+
CHECK_CUBLASDX(cublasdxSetOperatorInt64Array(h, cublasdxOperatorType::CUBLASDX_OPERATOR_SIZE, size.size(), size.data()));
|
|
3224
|
+
std::array<long long int, 3> arrangement = {arrangement_A, arrangement_B, arrangement_C};
|
|
3225
|
+
CHECK_CUBLASDX(cublasdxSetOperatorInt64Array(h, cublasdxOperatorType::CUBLASDX_OPERATOR_ARRANGEMENT, arrangement.size(), arrangement.data()));
|
|
3226
|
+
|
|
3227
|
+
CHECK_CUBLASDX(cublasdxSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
|
|
3228
|
+
|
|
3229
|
+
size_t lto_size = 0;
|
|
3230
|
+
CHECK_CUBLASDX(cublasdxGetLTOIRSize(h, <o_size));
|
|
3231
|
+
|
|
3232
|
+
std::vector<char> lto(lto_size);
|
|
3233
|
+
CHECK_CUBLASDX(cublasdxGetLTOIR(h, lto.size(), lto.data()));
|
|
3234
|
+
|
|
3235
|
+
if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
|
|
3236
|
+
res = false;
|
|
3237
|
+
}
|
|
3238
|
+
|
|
3239
|
+
CHECK_CUBLASDX(cublasdxDestroy(h));
|
|
3240
|
+
|
|
3241
|
+
return res;
|
|
3242
|
+
}
|
|
3243
|
+
|
|
3244
|
+
bool cuda_compile_solver(const char* fatbin_output_path, const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int function, int precision, int fill_mode, int num_threads)
|
|
3245
|
+
{
|
|
3246
|
+
|
|
3247
|
+
CHECK_ANY(ltoir_output_path != nullptr);
|
|
3248
|
+
CHECK_ANY(symbol_name != nullptr);
|
|
3249
|
+
CHECK_ANY(mathdx_include_dir == nullptr);
|
|
3250
|
+
CHECK_ANY(num_include_dirs == 0);
|
|
3251
|
+
CHECK_ANY(include_dirs == nullptr);
|
|
3252
|
+
|
|
3253
|
+
bool res = true;
|
|
3254
|
+
|
|
3255
|
+
cusolverHandle h { 0 };
|
|
3256
|
+
CHECK_CUSOLVER(cusolverCreate(&h));
|
|
3257
|
+
long long int size[2] = {M, N};
|
|
3258
|
+
long long int block_dim[3] = {num_threads, 1, 1};
|
|
3259
|
+
CHECK_CUSOLVER(cusolverSetOperatorInt64Array(h, cusolverOperatorType::CUSOLVER_OPERATOR_SIZE, 2, size));
|
|
3260
|
+
CHECK_CUSOLVER(cusolverSetOperatorInt64Array(h, cusolverOperatorType::CUSOLVER_OPERATOR_BLOCK_DIM, 3, block_dim));
|
|
3261
|
+
CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_TYPE, cusolverType::CUSOLVER_TYPE_REAL));
|
|
3262
|
+
CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_API, cusolverApi::CUSOLVER_API_BLOCK_SMEM));
|
|
3263
|
+
CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_FUNCTION, (cusolverFunction)function));
|
|
3264
|
+
CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
|
|
3265
|
+
CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_PRECISION, (commondxPrecision)precision));
|
|
3266
|
+
CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_FILL_MODE, (cusolverFillMode)fill_mode));
|
|
3267
|
+
CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_SM, (long long)(arch * 10)));
|
|
3268
|
+
|
|
3269
|
+
CHECK_CUSOLVER(cusolverSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
|
|
3270
|
+
|
|
3271
|
+
size_t lto_size = 0;
|
|
3272
|
+
CHECK_CUSOLVER(cusolverGetLTOIRSize(h, <o_size));
|
|
3273
|
+
|
|
3274
|
+
std::vector<char> lto(lto_size);
|
|
3275
|
+
CHECK_CUSOLVER(cusolverGetLTOIR(h, lto.size(), lto.data()));
|
|
3276
|
+
|
|
3277
|
+
// This fatbin is universal, ie it is the same for any instantiations of a cusolver device function
|
|
3278
|
+
size_t fatbin_size = 0;
|
|
3279
|
+
CHECK_CUSOLVER(cusolverGetUniversalFATBINSize(h, &fatbin_size));
|
|
3280
|
+
|
|
3281
|
+
std::vector<char> fatbin(fatbin_size);
|
|
3282
|
+
CHECK_CUSOLVER(cusolverGetUniversalFATBIN(h, fatbin.size(), fatbin.data()));
|
|
3283
|
+
|
|
3284
|
+
if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
|
|
3285
|
+
res = false;
|
|
3286
|
+
}
|
|
3287
|
+
|
|
3288
|
+
if(!write_file(fatbin.data(), fatbin.size(), fatbin_output_path, "wb")) {
|
|
3289
|
+
res = false;
|
|
3290
|
+
}
|
|
3291
|
+
|
|
3292
|
+
CHECK_CUSOLVER(cusolverDestroy(h));
|
|
3293
|
+
|
|
3294
|
+
return res;
|
|
3295
|
+
}
|
|
3296
|
+
|
|
3297
|
+
#endif
|
|
3298
|
+
|
|
3299
|
+
void* cuda_load_module(void* context, const char* path)
|
|
3300
|
+
{
|
|
3301
|
+
ContextGuard guard(context);
|
|
3302
|
+
|
|
3303
|
+
// use file extension to determine whether to load PTX or CUBIN
|
|
3304
|
+
const char* input_ext = strrchr(path, '.');
|
|
3305
|
+
bool load_ptx = input_ext && strcmp(input_ext + 1, "ptx") == 0;
|
|
3306
|
+
|
|
3307
|
+
std::vector<char> input;
|
|
3308
|
+
|
|
3309
|
+
FILE* file = fopen(path, "rb");
|
|
3310
|
+
if (file)
|
|
3311
|
+
{
|
|
3312
|
+
fseek(file, 0, SEEK_END);
|
|
3313
|
+
size_t length = ftell(file);
|
|
3314
|
+
fseek(file, 0, SEEK_SET);
|
|
3315
|
+
|
|
3316
|
+
input.resize(length + 1);
|
|
3317
|
+
if (fread(input.data(), 1, length, file) != length)
|
|
3318
|
+
{
|
|
3319
|
+
fprintf(stderr, "Warp error: Failed to read input file '%s'\n", path);
|
|
3320
|
+
fclose(file);
|
|
3321
|
+
return NULL;
|
|
3322
|
+
}
|
|
3323
|
+
fclose(file);
|
|
3324
|
+
|
|
3325
|
+
input[length] = '\0';
|
|
3326
|
+
}
|
|
3327
|
+
else
|
|
3328
|
+
{
|
|
3329
|
+
fprintf(stderr, "Warp error: Failed to open input file '%s'\n", path);
|
|
3330
|
+
return NULL;
|
|
3331
|
+
}
|
|
3332
|
+
|
|
3333
|
+
int driver_cuda_version = 0;
|
|
3334
|
+
CUmodule module = NULL;
|
|
3335
|
+
|
|
3336
|
+
if (load_ptx)
|
|
3337
|
+
{
|
|
3338
|
+
if (check_cu(cuDriverGetVersion_f(&driver_cuda_version)) && driver_cuda_version >= CUDA_VERSION)
|
|
3339
|
+
{
|
|
3340
|
+
// let the driver compile the PTX
|
|
3341
|
+
|
|
3342
|
+
CUjit_option options[2];
|
|
3343
|
+
void *option_vals[2];
|
|
3344
|
+
char error_log[8192] = "";
|
|
3345
|
+
unsigned int log_size = 8192;
|
|
3346
|
+
// Set up loader options
|
|
3347
|
+
// Pass a buffer for error message
|
|
3348
|
+
options[0] = CU_JIT_ERROR_LOG_BUFFER;
|
|
3349
|
+
option_vals[0] = (void*)error_log;
|
|
3350
|
+
// Pass the size of the error buffer
|
|
3351
|
+
options[1] = CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES;
|
|
3352
|
+
option_vals[1] = (void*)(size_t)log_size;
|
|
3353
|
+
|
|
3354
|
+
if (!check_cu(cuModuleLoadDataEx_f(&module, input.data(), 2, options, option_vals)))
|
|
3355
|
+
{
|
|
3356
|
+
fprintf(stderr, "Warp error: Loading PTX module failed\n");
|
|
3357
|
+
// print error log if not empty
|
|
3358
|
+
if (*error_log)
|
|
3359
|
+
fprintf(stderr, "PTX loader error:\n%s\n", error_log);
|
|
3360
|
+
return NULL;
|
|
3361
|
+
}
|
|
3362
|
+
}
|
|
3363
|
+
else
|
|
3364
|
+
{
|
|
3365
|
+
// manually compile the PTX and load as CUBIN
|
|
3366
|
+
|
|
3367
|
+
ContextInfo* context_info = get_context_info(static_cast<CUcontext>(context));
|
|
3368
|
+
if (!context_info || !context_info->device_info)
|
|
3369
|
+
{
|
|
3370
|
+
fprintf(stderr, "Warp error: Failed to determine target architecture\n");
|
|
3371
|
+
return NULL;
|
|
3372
|
+
}
|
|
3373
|
+
|
|
3374
|
+
int arch = context_info->device_info->arch;
|
|
3375
|
+
|
|
3376
|
+
char arch_opt[128];
|
|
3377
|
+
sprintf(arch_opt, "--gpu-name=sm_%d", arch);
|
|
3378
|
+
|
|
3379
|
+
const char* compiler_options[] = { arch_opt };
|
|
3380
|
+
|
|
3381
|
+
nvPTXCompilerHandle compiler = NULL;
|
|
3382
|
+
if (!check_nvptx(nvPTXCompilerCreate(&compiler, input.size(), input.data())))
|
|
3383
|
+
return NULL;
|
|
3384
|
+
|
|
3385
|
+
if (!check_nvptx(nvPTXCompilerCompile(compiler, sizeof(compiler_options) / sizeof(*compiler_options), compiler_options)))
|
|
3386
|
+
return NULL;
|
|
3387
|
+
|
|
3388
|
+
size_t cubin_size = 0;
|
|
3389
|
+
if (!check_nvptx(nvPTXCompilerGetCompiledProgramSize(compiler, &cubin_size)))
|
|
3390
|
+
return NULL;
|
|
3391
|
+
|
|
3392
|
+
std::vector<char> cubin(cubin_size);
|
|
3393
|
+
if (!check_nvptx(nvPTXCompilerGetCompiledProgram(compiler, cubin.data())))
|
|
3394
|
+
return NULL;
|
|
3395
|
+
|
|
3396
|
+
check_nvptx(nvPTXCompilerDestroy(&compiler));
|
|
3397
|
+
|
|
3398
|
+
if (!check_cu(cuModuleLoadDataEx_f(&module, cubin.data(), 0, NULL, NULL)))
|
|
3399
|
+
{
|
|
3400
|
+
fprintf(stderr, "Warp CUDA error: Loading module failed\n");
|
|
3401
|
+
return NULL;
|
|
3402
|
+
}
|
|
3403
|
+
}
|
|
3404
|
+
}
|
|
3405
|
+
else
|
|
3406
|
+
{
|
|
3407
|
+
// load CUBIN
|
|
3408
|
+
if (!check_cu(cuModuleLoadDataEx_f(&module, input.data(), 0, NULL, NULL)))
|
|
3409
|
+
{
|
|
3410
|
+
fprintf(stderr, "Warp CUDA error: Loading module failed\n");
|
|
3411
|
+
return NULL;
|
|
3412
|
+
}
|
|
3413
|
+
}
|
|
3414
|
+
|
|
3415
|
+
return module;
|
|
3416
|
+
}
|
|
3417
|
+
|
|
3418
|
+
void cuda_unload_module(void* context, void* module)
|
|
3419
|
+
{
|
|
3420
|
+
// ensure there are no graph captures in progress
|
|
3421
|
+
if (g_captures.empty())
|
|
3422
|
+
{
|
|
3423
|
+
ContextGuard guard(context);
|
|
3424
|
+
check_cu(cuModuleUnload_f((CUmodule)module));
|
|
3425
|
+
}
|
|
3426
|
+
else
|
|
3427
|
+
{
|
|
3428
|
+
// defer until graph capture completes
|
|
3429
|
+
ModuleInfo module_info;
|
|
3430
|
+
module_info.context = context ? context : get_current_context();
|
|
3431
|
+
module_info.module = module;
|
|
3432
|
+
g_deferred_module_list.push_back(module_info);
|
|
3433
|
+
}
|
|
3434
|
+
}
|
|
3435
|
+
|
|
3436
|
+
|
|
3437
|
+
int cuda_get_max_shared_memory(void* context)
|
|
3438
|
+
{
|
|
3439
|
+
ContextInfo* info = get_context_info(context);
|
|
3440
|
+
if (!info)
|
|
3441
|
+
return -1;
|
|
3442
|
+
|
|
3443
|
+
int max_smem_bytes = info->device_info->max_smem_bytes;
|
|
3444
|
+
return max_smem_bytes;
|
|
3445
|
+
}
|
|
3446
|
+
|
|
3447
|
+
bool cuda_configure_kernel_shared_memory(void* kernel, int size)
|
|
3448
|
+
{
|
|
3449
|
+
int requested_smem_bytes = size;
|
|
3450
|
+
|
|
3451
|
+
// configure shared memory
|
|
3452
|
+
CUresult res = cuFuncSetAttribute_f((CUfunction)kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, requested_smem_bytes);
|
|
3453
|
+
if (res != CUDA_SUCCESS)
|
|
3454
|
+
return false;
|
|
3455
|
+
|
|
3456
|
+
return true;
|
|
3457
|
+
}
|
|
3458
|
+
|
|
3459
|
+
void* cuda_get_kernel(void* context, void* module, const char* name)
|
|
3460
|
+
{
|
|
3461
|
+
ContextGuard guard(context);
|
|
3462
|
+
|
|
3463
|
+
CUfunction kernel = NULL;
|
|
3464
|
+
if (!check_cu(cuModuleGetFunction_f(&kernel, (CUmodule)module, name)))
|
|
3465
|
+
{
|
|
3466
|
+
fprintf(stderr, "Warp CUDA error: Failed to lookup kernel function %s in module\n", name);
|
|
3467
|
+
return NULL;
|
|
3468
|
+
}
|
|
3469
|
+
|
|
3470
|
+
g_kernel_names[kernel] = name;
|
|
3471
|
+
return kernel;
|
|
3472
|
+
}
|
|
3473
|
+
|
|
3474
|
+
size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, int block_dim, int shared_memory_bytes, void** args, void* stream)
|
|
3475
|
+
{
|
|
3476
|
+
ContextGuard guard(context);
|
|
3477
|
+
|
|
3478
|
+
if (block_dim <= 0)
|
|
3479
|
+
{
|
|
3480
|
+
#if defined(_DEBUG)
|
|
3481
|
+
fprintf(stderr, "Warp warning: Launch got block_dim %d. Setting to 256.\n", block_dim);
|
|
3482
|
+
#endif
|
|
3483
|
+
block_dim = 256;
|
|
3484
|
+
}
|
|
3485
|
+
|
|
3486
|
+
// CUDA specs up to compute capability 9.0 says the max x-dim grid is 2**31-1, so
|
|
3487
|
+
// grid_dim is fine as an int for the near future
|
|
3488
|
+
int grid_dim = (dim + block_dim - 1)/block_dim;
|
|
3489
|
+
|
|
3490
|
+
if (max_blocks <= 0) {
|
|
3491
|
+
max_blocks = 2147483647;
|
|
3492
|
+
}
|
|
3493
|
+
|
|
3494
|
+
if (grid_dim < 0)
|
|
3495
|
+
{
|
|
3496
|
+
#if defined(_DEBUG)
|
|
3497
|
+
fprintf(stderr, "Warp warning: Overflow in grid dimensions detected for %zu total elements and 256 threads "
|
|
3498
|
+
"per block.\n Setting block count to %d.\n", dim, max_blocks);
|
|
3499
|
+
#endif
|
|
3500
|
+
grid_dim = max_blocks;
|
|
3501
|
+
}
|
|
3502
|
+
else
|
|
3503
|
+
{
|
|
3504
|
+
if (grid_dim > max_blocks)
|
|
3505
|
+
{
|
|
3506
|
+
grid_dim = max_blocks;
|
|
3507
|
+
}
|
|
3508
|
+
}
|
|
3509
|
+
|
|
3510
|
+
begin_cuda_range(WP_TIMING_KERNEL, stream, context, get_cuda_kernel_name(kernel));
|
|
3511
|
+
|
|
3512
|
+
CUresult res = cuLaunchKernel_f(
|
|
3513
|
+
(CUfunction)kernel,
|
|
3514
|
+
grid_dim, 1, 1,
|
|
3515
|
+
block_dim, 1, 1,
|
|
3516
|
+
shared_memory_bytes,
|
|
3517
|
+
static_cast<CUstream>(stream),
|
|
3518
|
+
args,
|
|
3519
|
+
0);
|
|
3520
|
+
|
|
3521
|
+
check_cu(res);
|
|
3522
|
+
|
|
3523
|
+
end_cuda_range(WP_TIMING_KERNEL, stream);
|
|
3524
|
+
|
|
3525
|
+
return res;
|
|
3526
|
+
}
|
|
3527
|
+
|
|
3528
|
+
void cuda_graphics_map(void* context, void* resource)
|
|
3529
|
+
{
|
|
3530
|
+
ContextGuard guard(context);
|
|
3531
|
+
|
|
3532
|
+
check_cu(cuGraphicsMapResources_f(1, (CUgraphicsResource*)resource, get_current_stream()));
|
|
3533
|
+
}
|
|
3534
|
+
|
|
3535
|
+
void cuda_graphics_unmap(void* context, void* resource)
|
|
3536
|
+
{
|
|
3537
|
+
ContextGuard guard(context);
|
|
3538
|
+
|
|
3539
|
+
check_cu(cuGraphicsUnmapResources_f(1, (CUgraphicsResource*)resource, get_current_stream()));
|
|
3540
|
+
}
|
|
3541
|
+
|
|
3542
|
+
void cuda_graphics_device_ptr_and_size(void* context, void* resource, uint64_t* ptr, size_t* size)
|
|
3543
|
+
{
|
|
3544
|
+
ContextGuard guard(context);
|
|
3545
|
+
|
|
3546
|
+
CUdeviceptr device_ptr;
|
|
3547
|
+
size_t bytes;
|
|
3548
|
+
check_cu(cuGraphicsResourceGetMappedPointer_f(&device_ptr, &bytes, *(CUgraphicsResource*)resource));
|
|
3549
|
+
|
|
3550
|
+
*ptr = device_ptr;
|
|
3551
|
+
*size = bytes;
|
|
3552
|
+
}
|
|
3553
|
+
|
|
3554
|
+
void* cuda_graphics_register_gl_buffer(void* context, uint32_t gl_buffer, unsigned int flags)
|
|
3555
|
+
{
|
|
3556
|
+
ContextGuard guard(context);
|
|
3557
|
+
|
|
3558
|
+
CUgraphicsResource *resource = new CUgraphicsResource;
|
|
3559
|
+
bool success = check_cu(cuGraphicsGLRegisterBuffer_f(resource, gl_buffer, flags));
|
|
3560
|
+
if (!success)
|
|
3561
|
+
{
|
|
3562
|
+
delete resource;
|
|
3563
|
+
return NULL;
|
|
3564
|
+
}
|
|
3565
|
+
|
|
3566
|
+
return resource;
|
|
3567
|
+
}
|
|
3568
|
+
|
|
3569
|
+
void cuda_graphics_unregister_resource(void* context, void* resource)
|
|
3570
|
+
{
|
|
3571
|
+
ContextGuard guard(context);
|
|
3572
|
+
|
|
3573
|
+
CUgraphicsResource *res = (CUgraphicsResource*)resource;
|
|
3574
|
+
check_cu(cuGraphicsUnregisterResource_f(*res));
|
|
3575
|
+
delete res;
|
|
3576
|
+
}
|
|
3577
|
+
|
|
3578
|
+
void cuda_timing_begin(int flags)
|
|
3579
|
+
{
|
|
3580
|
+
g_cuda_timing_state = new CudaTimingState(flags, g_cuda_timing_state);
|
|
3581
|
+
}
|
|
3582
|
+
|
|
3583
|
+
int cuda_timing_get_result_count()
|
|
3584
|
+
{
|
|
3585
|
+
if (g_cuda_timing_state)
|
|
3586
|
+
return int(g_cuda_timing_state->ranges.size());
|
|
3587
|
+
return 0;
|
|
3588
|
+
}
|
|
3589
|
+
|
|
3590
|
+
// Finish the innermost timing session started by cuda_timing_begin.
//   results - caller-owned buffer that receives up to `size` timing results;
//             may be NULL, in which case no results are written
//   size    - capacity of `results`, in elements
// All events recorded by the session are destroyed, whether or not their result
// fit in the buffer. The previous (parent) timing state then becomes active
// again. Does nothing when no session is active.
void cuda_timing_end(timing_result_t* results, int size)
{
    if (!g_cuda_timing_state)
        return;

    // number of results to write to the user buffer (none if there is no buffer)
    int count = results ? std::min(cuda_timing_get_result_count(), size) : 0;

    // compute timings and write results
    // NOTE(review): cudaEventElapsedTime fails with cudaErrorNotReady unless both
    // events have completed; presumably callers synchronize first - confirm.
    for (int i = 0; i < count; i++)
    {
        const CudaTimingRange& range = g_cuda_timing_state->ranges[i];
        timing_result_t& result = results[i];
        result.context = range.context;
        result.name = range.name;
        result.flag = range.flag;
        check_cuda(cudaEventElapsedTime(&result.elapsed, range.start, range.end));
    }

    // release events of every recorded range, including ones not written above
    for (const CudaTimingRange& range : g_cuda_timing_state->ranges)
    {
        check_cu(cuEventDestroy_f(range.start));
        check_cu(cuEventDestroy_f(range.end));
    }

    // restore previous state
    CudaTimingState* parent_state = g_cuda_timing_state->parent;
    delete g_cuda_timing_state;
    g_cuda_timing_state = parent_state;
}
|
|
3621
|
+
|
|
3622
|
+
// impl. files
|
|
3623
|
+
#include "bvh.cu"
|
|
3624
|
+
#include "mesh.cu"
|
|
3625
|
+
#include "sort.cu"
|
|
3626
|
+
#include "hashgrid.cu"
|
|
3627
|
+
#include "reduce.cu"
|
|
3628
|
+
#include "runlength_encode.cu"
|
|
3629
|
+
#include "scan.cu"
|
|
3630
|
+
#include "marching.cu"
|
|
3631
|
+
#include "sparse.cu"
|
|
3632
|
+
#include "volume.cu"
|
|
3633
|
+
#include "volume_builder.cu"
|
|
3634
|
+
|
|
3635
|
+
//#include "spline.inl"
|
|
3636
|
+
//#include "volume.inl"
|