warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +139 -0
- warp/__init__.pyi +1 -0
- warp/autograd.py +1142 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +557 -0
- warp/build_dll.py +405 -0
- warp/builtins.py +6855 -0
- warp/codegen.py +3969 -0
- warp/config.py +158 -0
- warp/constants.py +57 -0
- warp/context.py +6812 -0
- warp/dlpack.py +462 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +194 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +193 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +423 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +178 -0
- warp/examples/fem/example_convection_diffusion_dg.py +204 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +220 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_magnetostatics.py +240 -0
- warp/examples/fem/example_mixed_elasticity.py +291 -0
- warp/examples/fem/example_navier_stokes.py +261 -0
- warp/examples/fem/example_nonconforming_contact.py +298 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +352 -0
- warp/examples/fem/utils.py +1000 -0
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +266 -0
- warp/examples/optim/example_cloth_throw.py +228 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_drone.py +870 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +182 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
- warp/examples/optim/example_softbody_properties.py +400 -0
- warp/examples/optim/example_spring_cage.py +245 -0
- warp/examples/optim/example_trajectory.py +227 -0
- warp/examples/sim/example_cartpole.py +143 -0
- warp/examples/sim/example_cloth.py +225 -0
- warp/examples/sim/example_cloth_self_contact.py +322 -0
- warp/examples/sim/example_granular.py +130 -0
- warp/examples/sim/example_granular_collision_sdf.py +202 -0
- warp/examples/sim/example_jacobian_ik.py +244 -0
- warp/examples/sim/example_particle_chain.py +124 -0
- warp/examples/sim/example_quadruped.py +203 -0
- warp/examples/sim/example_rigid_chain.py +203 -0
- warp/examples/sim/example_rigid_contact.py +195 -0
- warp/examples/sim/example_rigid_force.py +133 -0
- warp/examples/sim/example_rigid_gyroscopic.py +115 -0
- warp/examples/sim/example_rigid_soft_contact.py +140 -0
- warp/examples/sim/example_soft_body.py +196 -0
- warp/examples/tile/example_tile_cholesky.py +87 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mlp.py +383 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/examples/tile/example_tile_walker.py +327 -0
- warp/fabric.py +355 -0
- warp/fem/__init__.py +106 -0
- warp/fem/adaptivity.py +508 -0
- warp/fem/cache.py +572 -0
- warp/fem/dirichlet.py +202 -0
- warp/fem/domain.py +411 -0
- warp/fem/field/__init__.py +125 -0
- warp/fem/field/field.py +619 -0
- warp/fem/field/nodal_field.py +326 -0
- warp/fem/field/restriction.py +37 -0
- warp/fem/field/virtual.py +848 -0
- warp/fem/geometry/__init__.py +32 -0
- warp/fem/geometry/adaptive_nanogrid.py +857 -0
- warp/fem/geometry/closest_point.py +84 -0
- warp/fem/geometry/deformed_geometry.py +221 -0
- warp/fem/geometry/element.py +776 -0
- warp/fem/geometry/geometry.py +362 -0
- warp/fem/geometry/grid_2d.py +392 -0
- warp/fem/geometry/grid_3d.py +452 -0
- warp/fem/geometry/hexmesh.py +911 -0
- warp/fem/geometry/nanogrid.py +571 -0
- warp/fem/geometry/partition.py +389 -0
- warp/fem/geometry/quadmesh.py +663 -0
- warp/fem/geometry/tetmesh.py +855 -0
- warp/fem/geometry/trimesh.py +806 -0
- warp/fem/integrate.py +2335 -0
- warp/fem/linalg.py +419 -0
- warp/fem/operator.py +293 -0
- warp/fem/polynomial.py +229 -0
- warp/fem/quadrature/__init__.py +17 -0
- warp/fem/quadrature/pic_quadrature.py +299 -0
- warp/fem/quadrature/quadrature.py +591 -0
- warp/fem/space/__init__.py +228 -0
- warp/fem/space/basis_function_space.py +468 -0
- warp/fem/space/basis_space.py +667 -0
- warp/fem/space/dof_mapper.py +251 -0
- warp/fem/space/function_space.py +309 -0
- warp/fem/space/grid_2d_function_space.py +177 -0
- warp/fem/space/grid_3d_function_space.py +227 -0
- warp/fem/space/hexmesh_function_space.py +257 -0
- warp/fem/space/nanogrid_function_space.py +201 -0
- warp/fem/space/partition.py +367 -0
- warp/fem/space/quadmesh_function_space.py +223 -0
- warp/fem/space/restriction.py +179 -0
- warp/fem/space/shape/__init__.py +143 -0
- warp/fem/space/shape/cube_shape_function.py +1105 -0
- warp/fem/space/shape/shape_function.py +133 -0
- warp/fem/space/shape/square_shape_function.py +926 -0
- warp/fem/space/shape/tet_shape_function.py +834 -0
- warp/fem/space/shape/triangle_shape_function.py +672 -0
- warp/fem/space/tetmesh_function_space.py +271 -0
- warp/fem/space/topology.py +424 -0
- warp/fem/space/trimesh_function_space.py +194 -0
- warp/fem/types.py +99 -0
- warp/fem/utils.py +420 -0
- warp/jax.py +187 -0
- warp/jax_experimental/__init__.py +16 -0
- warp/jax_experimental/custom_call.py +351 -0
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +244 -0
- warp/native/array.h +1145 -0
- warp/native/builtin.h +1800 -0
- warp/native/bvh.cpp +492 -0
- warp/native/bvh.cu +791 -0
- warp/native/bvh.h +554 -0
- warp/native/clang/clang.cpp +536 -0
- warp/native/coloring.cpp +613 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +362 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +646 -0
- warp/native/cuda_util.h +307 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +1878 -0
- warp/native/fabric.h +245 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +87 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1230 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +339 -0
- warp/native/marching.cpp +19 -0
- warp/native/marching.cu +514 -0
- warp/native/marching.h +19 -0
- warp/native/mat.h +2220 -0
- warp/native/mathdx.cpp +87 -0
- warp/native/matnn.h +343 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +404 -0
- warp/native/mesh.h +1980 -0
- warp/native/nanovdb/GridHandle.h +366 -0
- warp/native/nanovdb/HostBuffer.h +590 -0
- warp/native/nanovdb/NanoVDB.h +6624 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1371 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +139 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +364 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +53 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +277 -0
- warp/native/sort.h +33 -0
- warp/native/sparse.cpp +378 -0
- warp/native/sparse.cu +524 -0
- warp/native/spatial.h +657 -0
- warp/native/svd.h +702 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +2584 -0
- warp/native/tile_reduce.h +264 -0
- warp/native/vec.h +1426 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +67 -0
- warp/native/volume.h +969 -0
- warp/native/volume_builder.cu +477 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1082 -0
- warp/native/warp.cu +3636 -0
- warp/native/warp.h +381 -0
- warp/optim/__init__.py +17 -0
- warp/optim/adam.py +163 -0
- warp/optim/linear.py +1137 -0
- warp/optim/sgd.py +112 -0
- warp/paddle.py +407 -0
- warp/render/__init__.py +18 -0
- warp/render/render_opengl.py +3518 -0
- warp/render/render_usd.py +784 -0
- warp/render/utils.py +160 -0
- warp/sim/__init__.py +65 -0
- warp/sim/articulation.py +793 -0
- warp/sim/collide.py +2395 -0
- warp/sim/graph_coloring.py +300 -0
- warp/sim/import_mjcf.py +790 -0
- warp/sim/import_snu.py +227 -0
- warp/sim/import_urdf.py +579 -0
- warp/sim/import_usd.py +894 -0
- warp/sim/inertia.py +324 -0
- warp/sim/integrator.py +242 -0
- warp/sim/integrator_euler.py +1997 -0
- warp/sim/integrator_featherstone.py +2101 -0
- warp/sim/integrator_vbd.py +2048 -0
- warp/sim/integrator_xpbd.py +3292 -0
- warp/sim/model.py +4791 -0
- warp/sim/particles.py +121 -0
- warp/sim/render.py +427 -0
- warp/sim/utils.py +428 -0
- warp/sparse.py +2057 -0
- warp/stubs.py +3333 -0
- warp/tape.py +1203 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +634 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +200 -0
- warp/tests/geometry/test_hash_grid.py +221 -0
- warp/tests/geometry/test_marching_cubes.py +74 -0
- warp/tests/geometry/test_mesh.py +316 -0
- warp/tests/geometry/test_mesh_query_aabb.py +399 -0
- warp/tests/geometry/test_mesh_query_point.py +932 -0
- warp/tests/geometry/test_mesh_query_ray.py +311 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +729 -0
- warp/tests/interop/test_jax.py +371 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/sim/__init__.py +0 -0
- warp/tests/sim/disabled_kinematics.py +244 -0
- warp/tests/sim/flaky_test_sim_grad.py +290 -0
- warp/tests/sim/test_collision.py +604 -0
- warp/tests/sim/test_coloring.py +258 -0
- warp/tests/sim/test_model.py +224 -0
- warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
- warp/tests/sim/test_sim_kinematics.py +98 -0
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +163 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +2972 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +250 -0
- warp/tests/test_atomic.py +153 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +1298 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +810 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +252 -0
- warp/tests/test_context.py +42 -0
- warp/tests/test_copy.py +238 -0
- warp/tests/test_ctypes.py +638 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +97 -0
- warp/tests/test_examples.py +482 -0
- warp/tests/test_fabricarray.py +996 -0
- warp/tests/test_fast_math.py +74 -0
- warp/tests/test_fem.py +2003 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +454 -0
- warp/tests/test_future_annotations.py +98 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +73 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +193 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_mat.py +2089 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +178 -0
- warp/tests/test_mlp.py +282 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +44 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +339 -0
- warp/tests/test_quat.py +2315 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +302 -0
- warp/tests/test_rounding.py +185 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +105 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +582 -0
- warp/tests/test_spatial.py +2229 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +592 -0
- warp/tests/test_struct.py +734 -0
- warp/tests/test_tape.py +204 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +145 -0
- warp/tests/test_types.py +562 -0
- warp/tests/test_utils.py +588 -0
- warp/tests/test_vec.py +1487 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +780 -0
- warp/tests/tile/test_tile_load.py +407 -0
- warp/tests/tile/test_tile_mathdx.py +208 -0
- warp/tests/tile/test_tile_mlp.py +402 -0
- warp/tests/tile/test_tile_reduce.py +447 -0
- warp/tests/tile/test_tile_shared_memory.py +247 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +427 -0
- warp/tests/unittest_utils.py +468 -0
- warp/tests/walkthrough_debug.py +93 -0
- warp/thirdparty/__init__.py +0 -0
- warp/thirdparty/appdirs.py +598 -0
- warp/thirdparty/dlpack.py +145 -0
- warp/thirdparty/unittest_parallel.py +570 -0
- warp/torch.py +391 -0
- warp/types.py +5230 -0
- warp/utils.py +1137 -0
- warp_lang-1.7.0.dist-info/METADATA +516 -0
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- warp_lang-1.7.0.dist-info/WHEEL +5 -0
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp_lang-1.7.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,646 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
16
|
+
*/
|
|
17
|
+
|
|
18
|
+
#if WP_ENABLE_CUDA
|
|
19
|
+
|
|
20
|
+
#include "cuda_util.h"
|
|
21
|
+
#include "error.h"
|
|
22
|
+
|
|
23
|
+
#if defined(_WIN32)
|
|
24
|
+
#define WIN32_LEAN_AND_MEAN
|
|
25
|
+
#define NOMINMAX
|
|
26
|
+
#include <windows.h>
|
|
27
|
+
#include <wingdi.h> // needed for OpenGL includes
|
|
28
|
+
#elif defined(__linux__)
|
|
29
|
+
#include <dlfcn.h>
|
|
30
|
+
#endif
|
|
31
|
+
|
|
32
|
+
#include <set>
|
|
33
|
+
#include <stack>
|
|
34
|
+
|
|
35
|
+
// the minimum CUDA version required from the driver
|
|
36
|
+
#define WP_CUDA_DRIVER_VERSION 11040
|
|
37
|
+
|
|
38
|
+
// the minimum CUDA Toolkit version required to build Warp
|
|
39
|
+
#define WP_CUDA_TOOLKIT_VERSION 11050
|
|
40
|
+
|
|
41
|
+
// check if the CUDA Toolkit is too old
|
|
42
|
+
#if CUDA_VERSION < WP_CUDA_TOOLKIT_VERSION
|
|
43
|
+
#error Building Warp requires CUDA Toolkit version 11.5 or higher
|
|
44
|
+
#endif
|
|
45
|
+
|
|
46
|
+
// Avoid including <cudaGLTypedefs.h>, which requires OpenGL headers to be installed.
|
|
47
|
+
// We define our own GL types, based on the spec here: https://www.khronos.org/opengl/wiki/OpenGL_Type
|
|
48
|
+
namespace wp
|
|
49
|
+
{
|
|
50
|
+
typedef uint32_t GLuint;
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
// function prototypes adapted from <cudaGLTypedefs.h>
|
|
54
|
+
typedef CUresult (CUDAAPI *PFN_cuGraphicsGLRegisterBuffer_v3000)(CUgraphicsResource *pCudaResource, wp::GLuint buffer, unsigned int Flags);
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
// function pointers to driver API entry points
|
|
58
|
+
// these are explicitly versioned according to cudaTypedefs.h from CUDA Toolkit WP_CUDA_TOOLKIT_VERSION
|
|
59
|
+
#if CUDA_VERSION < 12000
|
|
60
|
+
static PFN_cuGetProcAddress_v11030 pfn_cuGetProcAddress;
|
|
61
|
+
#else
|
|
62
|
+
static PFN_cuGetProcAddress_v12000 pfn_cuGetProcAddress;
|
|
63
|
+
#endif
|
|
64
|
+
static PFN_cuDriverGetVersion_v2020 pfn_cuDriverGetVersion;
|
|
65
|
+
static PFN_cuGetErrorName_v6000 pfn_cuGetErrorName;
|
|
66
|
+
static PFN_cuGetErrorString_v6000 pfn_cuGetErrorString;
|
|
67
|
+
static PFN_cuInit_v2000 pfn_cuInit;
|
|
68
|
+
static PFN_cuDeviceGet_v2000 pfn_cuDeviceGet;
|
|
69
|
+
static PFN_cuDeviceGetCount_v2000 pfn_cuDeviceGetCount;
|
|
70
|
+
static PFN_cuDeviceGetName_v2000 pfn_cuDeviceGetName;
|
|
71
|
+
static PFN_cuDeviceGetAttribute_v2000 pfn_cuDeviceGetAttribute;
|
|
72
|
+
static PFN_cuDeviceGetUuid_v11040 pfn_cuDeviceGetUuid;
|
|
73
|
+
static PFN_cuDevicePrimaryCtxRetain_v7000 pfn_cuDevicePrimaryCtxRetain;
|
|
74
|
+
static PFN_cuDevicePrimaryCtxRelease_v11000 pfn_cuDevicePrimaryCtxRelease;
|
|
75
|
+
static PFN_cuDeviceCanAccessPeer_v4000 pfn_cuDeviceCanAccessPeer;
|
|
76
|
+
static PFN_cuMemGetInfo_v3020 pfn_cuMemGetInfo;
|
|
77
|
+
static PFN_cuCtxGetCurrent_v4000 pfn_cuCtxGetCurrent;
|
|
78
|
+
static PFN_cuCtxSetCurrent_v4000 pfn_cuCtxSetCurrent;
|
|
79
|
+
static PFN_cuCtxPushCurrent_v4000 pfn_cuCtxPushCurrent;
|
|
80
|
+
static PFN_cuCtxPopCurrent_v4000 pfn_cuCtxPopCurrent;
|
|
81
|
+
static PFN_cuCtxSynchronize_v2000 pfn_cuCtxSynchronize;
|
|
82
|
+
static PFN_cuCtxGetDevice_v2000 pfn_cuCtxGetDevice;
|
|
83
|
+
static PFN_cuCtxCreate_v3020 pfn_cuCtxCreate;
|
|
84
|
+
static PFN_cuCtxDestroy_v4000 pfn_cuCtxDestroy;
|
|
85
|
+
static PFN_cuCtxEnablePeerAccess_v4000 pfn_cuCtxEnablePeerAccess;
|
|
86
|
+
static PFN_cuCtxDisablePeerAccess_v4000 pfn_cuCtxDisablePeerAccess;
|
|
87
|
+
static PFN_cuStreamCreate_v2000 pfn_cuStreamCreate;
|
|
88
|
+
static PFN_cuStreamDestroy_v4000 pfn_cuStreamDestroy;
|
|
89
|
+
static PFN_cuStreamQuery_v2000 pfn_cuStreamQuery;
|
|
90
|
+
static PFN_cuStreamSynchronize_v2000 pfn_cuStreamSynchronize;
|
|
91
|
+
static PFN_cuStreamWaitEvent_v3020 pfn_cuStreamWaitEvent;
|
|
92
|
+
static PFN_cuStreamGetCtx_v9020 pfn_cuStreamGetCtx;
|
|
93
|
+
static PFN_cuStreamGetCaptureInfo_v11030 pfn_cuStreamGetCaptureInfo;
|
|
94
|
+
static PFN_cuStreamUpdateCaptureDependencies_v11030 pfn_cuStreamUpdateCaptureDependencies;
|
|
95
|
+
static PFN_cuStreamCreateWithPriority_v5050 pfn_cuStreamCreateWithPriority;
|
|
96
|
+
static PFN_cuStreamGetPriority_v5050 pfn_cuStreamGetPriority;
|
|
97
|
+
static PFN_cuEventCreate_v2000 pfn_cuEventCreate;
|
|
98
|
+
static PFN_cuEventDestroy_v4000 pfn_cuEventDestroy;
|
|
99
|
+
static PFN_cuEventQuery_v2000 pfn_cuEventQuery;
|
|
100
|
+
static PFN_cuEventRecord_v2000 pfn_cuEventRecord;
|
|
101
|
+
static PFN_cuEventRecordWithFlags_v11010 pfn_cuEventRecordWithFlags;
|
|
102
|
+
static PFN_cuEventSynchronize_v2000 pfn_cuEventSynchronize;
|
|
103
|
+
static PFN_cuModuleLoadDataEx_v2010 pfn_cuModuleLoadDataEx;
|
|
104
|
+
static PFN_cuModuleUnload_v2000 pfn_cuModuleUnload;
|
|
105
|
+
static PFN_cuModuleGetFunction_v2000 pfn_cuModuleGetFunction;
|
|
106
|
+
static PFN_cuLaunchKernel_v4000 pfn_cuLaunchKernel;
|
|
107
|
+
static PFN_cuMemcpyPeerAsync_v4000 pfn_cuMemcpyPeerAsync;
|
|
108
|
+
static PFN_cuPointerGetAttribute_v4000 pfn_cuPointerGetAttribute;
|
|
109
|
+
static PFN_cuGraphicsMapResources_v3000 pfn_cuGraphicsMapResources;
|
|
110
|
+
static PFN_cuGraphicsUnmapResources_v3000 pfn_cuGraphicsUnmapResources;
|
|
111
|
+
static PFN_cuGraphicsResourceGetMappedPointer_v3020 pfn_cuGraphicsResourceGetMappedPointer;
|
|
112
|
+
static PFN_cuGraphicsGLRegisterBuffer_v3000 pfn_cuGraphicsGLRegisterBuffer;
|
|
113
|
+
static PFN_cuGraphicsUnregisterResource_v3000 pfn_cuGraphicsUnregisterResource;
|
|
114
|
+
static PFN_cuModuleGetGlobal_v3020 pfn_cuModuleGetGlobal;
|
|
115
|
+
static PFN_cuFuncSetAttribute_v9000 pfn_cuFuncSetAttribute;
|
|
116
|
+
static PFN_cuIpcGetEventHandle_v4010 pfn_cuIpcGetEventHandle;
|
|
117
|
+
static PFN_cuIpcOpenEventHandle_v4010 pfn_cuIpcOpenEventHandle;
|
|
118
|
+
static PFN_cuIpcGetMemHandle_v4010 pfn_cuIpcGetMemHandle;
|
|
119
|
+
static PFN_cuIpcOpenMemHandle_v11000 pfn_cuIpcOpenMemHandle;
|
|
120
|
+
static PFN_cuIpcCloseMemHandle_v4010 pfn_cuIpcCloseMemHandle;
|
|
121
|
+
|
|
122
|
+
static bool cuda_driver_initialized = false;
|
|
123
|
+
|
|
124
|
+
bool ContextGuard::always_restore = false;
|
|
125
|
+
|
|
126
|
+
CudaTimingState* g_cuda_timing_state = NULL;
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
// Extract the major component of a CUDA version number encoded as
// 1000 * major + 10 * minor (e.g. 11040 -> 11).
static inline int get_major(int version)
{
    const int major = version / 1000;
    return major;
}
|
|
133
|
+
|
|
134
|
+
// Extract the minor component of a CUDA version number encoded as
// 1000 * major + 10 * minor (e.g. 11040 -> 4).
static inline int get_minor(int version)
{
    const int below_major = version % 1000;
    return below_major / 10;
}
|
|
138
|
+
|
|
139
|
+
// Get versioned driver entry point. The version argument should match the function pointer type.
|
|
140
|
+
// For example, to initialize PFN_cuCtxCreate_v3020 use version 3020.
|
|
141
|
+
static bool get_driver_entry_point(const char* name, int version, void** pfn)
|
|
142
|
+
{
|
|
143
|
+
if (!pfn_cuGetProcAddress || !name || !pfn)
|
|
144
|
+
return false;
|
|
145
|
+
|
|
146
|
+
#if CUDA_VERSION < 12000
|
|
147
|
+
CUresult r = pfn_cuGetProcAddress(name, pfn, version, CU_GET_PROC_ADDRESS_DEFAULT);
|
|
148
|
+
#else
|
|
149
|
+
CUresult r = pfn_cuGetProcAddress(name, pfn, version, CU_GET_PROC_ADDRESS_DEFAULT, NULL);
|
|
150
|
+
#endif
|
|
151
|
+
|
|
152
|
+
if (r != CUDA_SUCCESS)
|
|
153
|
+
{
|
|
154
|
+
fprintf(stderr, "Warp CUDA error: Failed to get driver entry point '%s' (CUDA error %u)\n", name, unsigned(r));
|
|
155
|
+
return false;
|
|
156
|
+
}
|
|
157
|
+
|
|
158
|
+
return true;
|
|
159
|
+
}
|
|
160
|
+
|
|
161
|
+
bool init_cuda_driver()
|
|
162
|
+
{
|
|
163
|
+
#if defined(_WIN32)
|
|
164
|
+
static HMODULE hCudaDriver = LoadLibraryA("nvcuda.dll");
|
|
165
|
+
if (hCudaDriver == NULL) {
|
|
166
|
+
fprintf(stderr, "Warp CUDA error: Could not open nvcuda.dll.\n");
|
|
167
|
+
return false;
|
|
168
|
+
}
|
|
169
|
+
pfn_cuGetProcAddress = (PFN_cuGetProcAddress)GetProcAddress(hCudaDriver, "cuGetProcAddress");
|
|
170
|
+
#elif defined(__linux__)
|
|
171
|
+
static void* hCudaDriver = dlopen("libcuda.so", RTLD_NOW);
|
|
172
|
+
if (hCudaDriver == NULL) {
|
|
173
|
+
// WSL and possibly other systems might require the .1 suffix
|
|
174
|
+
hCudaDriver = dlopen("libcuda.so.1", RTLD_NOW);
|
|
175
|
+
if (hCudaDriver == NULL) {
|
|
176
|
+
fprintf(stderr, "Warp CUDA error: Could not open libcuda.so.\n");
|
|
177
|
+
return false;
|
|
178
|
+
}
|
|
179
|
+
}
|
|
180
|
+
pfn_cuGetProcAddress = (PFN_cuGetProcAddress)dlsym(hCudaDriver, "cuGetProcAddress");
|
|
181
|
+
#endif
|
|
182
|
+
|
|
183
|
+
if (!pfn_cuGetProcAddress)
|
|
184
|
+
{
|
|
185
|
+
fprintf(stderr, "Warp CUDA error: Failed to get function cuGetProcAddress\n");
|
|
186
|
+
return false;
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
// check the CUDA driver version and report an error if it's too low
|
|
190
|
+
int driver_version = 0;
|
|
191
|
+
if (get_driver_entry_point("cuDriverGetVersion", 2020, &(void*&)pfn_cuDriverGetVersion) &&
|
|
192
|
+
check_cu(pfn_cuDriverGetVersion(&driver_version)))
|
|
193
|
+
{
|
|
194
|
+
if (driver_version < WP_CUDA_DRIVER_VERSION)
|
|
195
|
+
{
|
|
196
|
+
fprintf(stderr, "Warp CUDA error: Warp requires CUDA driver %d.%d or higher, but the current driver only supports CUDA %d.%d\n",
|
|
197
|
+
get_major(WP_CUDA_DRIVER_VERSION), get_minor(WP_CUDA_DRIVER_VERSION),
|
|
198
|
+
get_major(driver_version), get_minor(driver_version));
|
|
199
|
+
return false;
|
|
200
|
+
}
|
|
201
|
+
}
|
|
202
|
+
else
|
|
203
|
+
{
|
|
204
|
+
fprintf(stderr, "Warp CUDA warning: Unable to determine CUDA driver version\n");
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
// initialize driver entry points
|
|
208
|
+
get_driver_entry_point("cuGetErrorString", 6000, &(void*&)pfn_cuGetErrorString);
|
|
209
|
+
get_driver_entry_point("cuGetErrorName", 6000, &(void*&)pfn_cuGetErrorName);
|
|
210
|
+
get_driver_entry_point("cuInit", 2000, &(void*&)pfn_cuInit);
|
|
211
|
+
get_driver_entry_point("cuDeviceGet", 2000, &(void*&)pfn_cuDeviceGet);
|
|
212
|
+
get_driver_entry_point("cuDeviceGetCount", 2000, &(void*&)pfn_cuDeviceGetCount);
|
|
213
|
+
get_driver_entry_point("cuDeviceGetName", 2000, &(void*&)pfn_cuDeviceGetName);
|
|
214
|
+
get_driver_entry_point("cuDeviceGetAttribute", 2000, &(void*&)pfn_cuDeviceGetAttribute);
|
|
215
|
+
get_driver_entry_point("cuDeviceGetUuid", 110400, &(void*&)pfn_cuDeviceGetUuid);
|
|
216
|
+
get_driver_entry_point("cuDevicePrimaryCtxRetain", 7000, &(void*&)pfn_cuDevicePrimaryCtxRetain);
|
|
217
|
+
get_driver_entry_point("cuDevicePrimaryCtxRelease", 11000, &(void*&)pfn_cuDevicePrimaryCtxRelease);
|
|
218
|
+
get_driver_entry_point("cuDeviceCanAccessPeer", 4000, &(void*&)pfn_cuDeviceCanAccessPeer);
|
|
219
|
+
get_driver_entry_point("cuMemGetInfo", 3020, &(void*&)pfn_cuMemGetInfo);
|
|
220
|
+
get_driver_entry_point("cuCtxSetCurrent", 4000, &(void*&)pfn_cuCtxSetCurrent);
|
|
221
|
+
get_driver_entry_point("cuCtxGetCurrent", 4000, &(void*&)pfn_cuCtxGetCurrent);
|
|
222
|
+
get_driver_entry_point("cuCtxPushCurrent", 4000, &(void*&)pfn_cuCtxPushCurrent);
|
|
223
|
+
get_driver_entry_point("cuCtxPopCurrent", 4000, &(void*&)pfn_cuCtxPopCurrent);
|
|
224
|
+
get_driver_entry_point("cuCtxSynchronize", 2000, &(void*&)pfn_cuCtxSynchronize);
|
|
225
|
+
get_driver_entry_point("cuCtxGetDevice", 2000, &(void*&)pfn_cuCtxGetDevice);
|
|
226
|
+
get_driver_entry_point("cuCtxCreate", 3020, &(void*&)pfn_cuCtxCreate);
|
|
227
|
+
get_driver_entry_point("cuCtxDestroy", 4000, &(void*&)pfn_cuCtxDestroy);
|
|
228
|
+
get_driver_entry_point("cuCtxEnablePeerAccess", 4000, &(void*&)pfn_cuCtxEnablePeerAccess);
|
|
229
|
+
get_driver_entry_point("cuCtxDisablePeerAccess", 4000, &(void*&)pfn_cuCtxDisablePeerAccess);
|
|
230
|
+
get_driver_entry_point("cuStreamCreate", 2000, &(void*&)pfn_cuStreamCreate);
|
|
231
|
+
get_driver_entry_point("cuStreamDestroy", 4000, &(void*&)pfn_cuStreamDestroy);
|
|
232
|
+
get_driver_entry_point("cuStreamQuery", 2000, &(void*&)pfn_cuStreamQuery);
|
|
233
|
+
get_driver_entry_point("cuStreamSynchronize", 2000, &(void*&)pfn_cuStreamSynchronize);
|
|
234
|
+
get_driver_entry_point("cuStreamWaitEvent", 3020, &(void*&)pfn_cuStreamWaitEvent);
|
|
235
|
+
get_driver_entry_point("cuStreamGetCtx", 9020, &(void*&)pfn_cuStreamGetCtx);
|
|
236
|
+
get_driver_entry_point("cuStreamGetCaptureInfo", 11030, &(void*&)pfn_cuStreamGetCaptureInfo);
|
|
237
|
+
get_driver_entry_point("cuStreamUpdateCaptureDependencies", 11030, &(void*&)pfn_cuStreamUpdateCaptureDependencies);
|
|
238
|
+
get_driver_entry_point("cuStreamCreateWithPriority", 5050, &(void*&)pfn_cuStreamCreateWithPriority);
|
|
239
|
+
get_driver_entry_point("cuStreamGetPriority", 5050, &(void*&)pfn_cuStreamGetPriority);
|
|
240
|
+
get_driver_entry_point("cuEventCreate", 2000, &(void*&)pfn_cuEventCreate);
|
|
241
|
+
get_driver_entry_point("cuEventDestroy", 4000, &(void*&)pfn_cuEventDestroy);
|
|
242
|
+
get_driver_entry_point("cuEventQuery", 2000, &(void*&)pfn_cuEventQuery);
|
|
243
|
+
get_driver_entry_point("cuEventRecord", 2000, &(void*&)pfn_cuEventRecord);
|
|
244
|
+
get_driver_entry_point("cuEventRecordWithFlags", 11010, &(void*&)pfn_cuEventRecordWithFlags);
|
|
245
|
+
get_driver_entry_point("cuEventSynchronize", 2000, &(void*&)pfn_cuEventSynchronize);
|
|
246
|
+
get_driver_entry_point("cuModuleLoadDataEx", 2010, &(void*&)pfn_cuModuleLoadDataEx);
|
|
247
|
+
get_driver_entry_point("cuModuleUnload", 2000, &(void*&)pfn_cuModuleUnload);
|
|
248
|
+
get_driver_entry_point("cuModuleGetFunction", 2000, &(void*&)pfn_cuModuleGetFunction);
|
|
249
|
+
get_driver_entry_point("cuLaunchKernel", 4000, &(void*&)pfn_cuLaunchKernel);
|
|
250
|
+
get_driver_entry_point("cuMemcpyPeerAsync", 4000, &(void*&)pfn_cuMemcpyPeerAsync);
|
|
251
|
+
get_driver_entry_point("cuPointerGetAttribute", 4000, &(void*&)pfn_cuPointerGetAttribute);
|
|
252
|
+
get_driver_entry_point("cuGraphicsMapResources", 3000, &(void*&)pfn_cuGraphicsMapResources);
|
|
253
|
+
get_driver_entry_point("cuGraphicsUnmapResources", 3000, &(void*&)pfn_cuGraphicsUnmapResources);
|
|
254
|
+
get_driver_entry_point("cuGraphicsResourceGetMappedPointer", 3020, &(void*&)pfn_cuGraphicsResourceGetMappedPointer);
|
|
255
|
+
get_driver_entry_point("cuGraphicsGLRegisterBuffer", 3000, &(void*&)pfn_cuGraphicsGLRegisterBuffer);
|
|
256
|
+
get_driver_entry_point("cuGraphicsUnregisterResource", 3000, &(void*&)pfn_cuGraphicsUnregisterResource);
|
|
257
|
+
get_driver_entry_point("cuModuleGetGlobal", 3020, &(void*&)pfn_cuModuleGetGlobal);
|
|
258
|
+
get_driver_entry_point("cuFuncSetAttribute", 9000, &(void*&)pfn_cuFuncSetAttribute);
|
|
259
|
+
get_driver_entry_point("cuIpcGetEventHandle", 4010, &(void*&)pfn_cuIpcGetEventHandle);
|
|
260
|
+
get_driver_entry_point("cuIpcOpenEventHandle", 4010, &(void*&)pfn_cuIpcOpenEventHandle);
|
|
261
|
+
get_driver_entry_point("cuIpcGetMemHandle", 4010, &(void*&)pfn_cuIpcGetMemHandle);
|
|
262
|
+
get_driver_entry_point("cuIpcOpenMemHandle", 11000, &(void*&)pfn_cuIpcOpenMemHandle);
|
|
263
|
+
get_driver_entry_point("cuIpcCloseMemHandle", 4010, &(void*&)pfn_cuIpcCloseMemHandle);
|
|
264
|
+
|
|
265
|
+
if (pfn_cuInit)
|
|
266
|
+
cuda_driver_initialized = check_cu(pfn_cuInit(0));
|
|
267
|
+
|
|
268
|
+
return cuda_driver_initialized;
|
|
269
|
+
}
|
|
270
|
+
|
|
271
|
+
bool is_cuda_driver_initialized()
|
|
272
|
+
{
|
|
273
|
+
return cuda_driver_initialized;
|
|
274
|
+
}
|
|
275
|
+
|
|
276
|
+
// Checks a CUDA runtime API status code.
//
// code - status returned by a cuda* runtime call
// func - name of the calling function (for the error message)
// file - source file of the call site
// line - source line of the call site
//
// Returns true on cudaSuccess. Otherwise records a formatted message (numeric
// code, cudaGetErrorString text, call site) through wp::set_error_string and
// returns false.
bool check_cuda_result(cudaError_t code, const char* func, const char* file, int line)
{
    if (code != cudaSuccess)
    {
        wp::set_error_string("Warp CUDA error %u: %s (in function %s, %s:%d)", unsigned(code), cudaGetErrorString(code), func, file, line);
        return false;
    }

    return true;
}
|
|
284
|
+
|
|
285
|
+
// Checks a CUDA driver API status code.
//
// result - status returned by a cu* driver call
// func   - name of the calling function (for the error message)
// file   - source file of the call site
// line   - source line of the call site
//
// Returns true on CUDA_SUCCESS. Otherwise records an error message through
// wp::set_error_string and returns false. The driver's textual description is
// included only when the cuGetErrorString entry point was resolved and yielded
// a non-null string.
bool check_cu_result(CUresult result, const char* func, const char* file, int line)
{
    if (result != CUDA_SUCCESS)
    {
        const char* description = NULL;

        // The error-string entry point may be missing (e.g. driver not loaded).
        if (pfn_cuGetErrorString)
            pfn_cuGetErrorString(result, &description);

        if (description == NULL)
            wp::set_error_string("Warp CUDA error %u (in function %s, %s:%d)", unsigned(result), func, file, line);
        else
            wp::set_error_string("Warp CUDA error %u: %s (in function %s, %s:%d)", unsigned(result), description, func, file, line);

        return false;
    }

    return true;
}
|
|
301
|
+
|
|
302
|
+
// Collects the current capture dependency nodes of a stream.
//
// stream           - the stream to query via cuStreamGetCaptureInfo_f
// dependencies_ret - cleared, then filled with the reported dependency nodes
//
// Returns false if the capture-info query fails (the error is recorded by
// check_cu). Otherwise returns true; the output stays empty when the driver
// reports no dependencies.
bool get_capture_dependencies(CUstream stream, std::vector<CUgraphNode>& dependencies_ret)
{
    dependencies_ret.clear();

    CUstreamCaptureStatus capture_status;
    const CUgraphNode* deps = NULL;
    size_t deps_count = 0;

    if (!check_cu(cuStreamGetCaptureInfo_f(stream, &capture_status, NULL, NULL, &deps, &deps_count)))
        return false;

    if (deps != NULL && deps_count > 0)
        dependencies_ret.assign(deps, deps + deps_count);

    return true;
}
|
|
316
|
+
|
|
317
|
+
// Collects the leaf nodes of a CUDA graph, i.e. nodes with no dependents.
//
// graph          - graph to inspect; a null graph is rejected
// leaf_nodes_ret - cleared on entry, then filled with every node whose
//                  dependent count reported by cudaGraphNodeGetDependentNodes
//                  is zero
//
// Returns false if graph is null or any runtime query fails (errors are
// recorded by check_cuda). Returns true otherwise, including for an empty
// graph, which yields no leaves.
bool get_graph_leaf_nodes(cudaGraph_t graph, std::vector<cudaGraphNode_t>& leaf_nodes_ret)
{
    // Clear up front so the output never carries stale data (matches
    // get_capture_dependencies).
    leaf_nodes_ret.clear();

    if (!graph)
        return false;

    // First query: node count only.
    size_t node_count = 0;
    if (!check_cuda(cudaGraphGetNodes(graph, NULL, &node_count)))
        return false;

    if (node_count == 0)
        return true;

    // Second query: fetch the nodes. The runtime writes back the number of
    // nodes actually obtained and sets unfilled entries to NULL, so trim to
    // that count rather than iterate over NULL handles.
    std::vector<cudaGraphNode_t> nodes(node_count);
    if (!check_cuda(cudaGraphGetNodes(graph, nodes.data(), &node_count)))
        return false;
    nodes.resize(node_count);

    for (cudaGraphNode_t node : nodes)
    {
        size_t dependent_count = 0;
        if (!check_cuda(cudaGraphNodeGetDependentNodes(node, NULL, &dependent_count)))
            return false;

        if (dependent_count == 0)
            leaf_nodes_ret.push_back(node);
    }

    return true;
}
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
#define DRIVER_ENTRY_POINT_ERROR driver_entry_point_error(__FUNCTION__)
|
|
347
|
+
|
|
348
|
+
static CUresult driver_entry_point_error(const char* function)
|
|
349
|
+
{
|
|
350
|
+
fprintf(stderr, "Warp CUDA error: Function %s: a suitable driver entry point was not found\n", function);
|
|
351
|
+
return (CUresult)cudaErrorCallRequiresNewerDriver; // this matches what cudart would do
|
|
352
|
+
}
|
|
353
|
+
|
|
354
|
+
// Driver / error / init wrappers.
// Each wrapper forwards to its resolved driver entry point (pfn_*). When the
// pointer is null, it returns DRIVER_ENTRY_POINT_ERROR, which logs to stderr.

CUresult cuDriverGetVersion_f(int* version)
{
    if (!pfn_cuDriverGetVersion)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuDriverGetVersion(version);
}

CUresult cuGetErrorName_f(CUresult result, const char** pstr)
{
    if (!pfn_cuGetErrorName)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuGetErrorName(result, pstr);
}

CUresult cuGetErrorString_f(CUresult result, const char** pstr)
{
    if (!pfn_cuGetErrorString)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuGetErrorString(result, pstr);
}

CUresult cuInit_f(unsigned int flags)
{
    if (!pfn_cuInit)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuInit(flags);
}
|
|
373
|
+
|
|
374
|
+
// Device and memory-info wrappers.
// Each forwards to its resolved pfn_* entry point, or returns
// DRIVER_ENTRY_POINT_ERROR when the entry point is missing.
// Exception: cuDeviceGetCount_f reports zero devices instead of failing.

CUresult cuDeviceGet_f(CUdevice *dev, int ordinal)
{
    if (!pfn_cuDeviceGet)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuDeviceGet(dev, ordinal);
}

CUresult cuDeviceGetCount_f(int* count)
{
    if (!pfn_cuDeviceGetCount)
    {
        // allow calling this function even if CUDA is not available
        if (count != NULL)
            *count = 0;
        return CUDA_SUCCESS;
    }
    return pfn_cuDeviceGetCount(count);
}

CUresult cuDeviceGetName_f(char* name, int len, CUdevice dev)
{
    if (!pfn_cuDeviceGetName)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuDeviceGetName(name, len, dev);
}

CUresult cuDeviceGetAttribute_f(int* value, CUdevice_attribute attrib, CUdevice dev)
{
    if (!pfn_cuDeviceGetAttribute)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuDeviceGetAttribute(value, attrib, dev);
}

CUresult cuDeviceGetUuid_f(CUuuid* uuid, CUdevice dev)
{
    if (!pfn_cuDeviceGetUuid)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuDeviceGetUuid(uuid, dev);
}

CUresult cuDevicePrimaryCtxRetain_f(CUcontext* ctx, CUdevice dev)
{
    if (!pfn_cuDevicePrimaryCtxRetain)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuDevicePrimaryCtxRetain(ctx, dev);
}

CUresult cuDevicePrimaryCtxRelease_f(CUdevice dev)
{
    if (!pfn_cuDevicePrimaryCtxRelease)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuDevicePrimaryCtxRelease(dev);
}

CUresult cuDeviceCanAccessPeer_f(int* can_access, CUdevice dev, CUdevice peer_dev)
{
    if (!pfn_cuDeviceCanAccessPeer)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuDeviceCanAccessPeer(can_access, dev, peer_dev);
}

CUresult cuMemGetInfo_f(size_t* free, size_t* total)
{
    if (!pfn_cuMemGetInfo)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuMemGetInfo(free, total);
}
|
|
425
|
+
|
|
426
|
+
// Context wrappers.
// Each forwards to its resolved pfn_* entry point, or returns
// DRIVER_ENTRY_POINT_ERROR when the entry point is missing.

CUresult cuCtxGetCurrent_f(CUcontext* ctx)
{
    if (!pfn_cuCtxGetCurrent)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuCtxGetCurrent(ctx);
}

CUresult cuCtxSetCurrent_f(CUcontext ctx)
{
    if (!pfn_cuCtxSetCurrent)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuCtxSetCurrent(ctx);
}

CUresult cuCtxPushCurrent_f(CUcontext ctx)
{
    if (!pfn_cuCtxPushCurrent)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuCtxPushCurrent(ctx);
}

CUresult cuCtxPopCurrent_f(CUcontext* ctx)
{
    if (!pfn_cuCtxPopCurrent)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuCtxPopCurrent(ctx);
}

CUresult cuCtxSynchronize_f()
{
    if (!pfn_cuCtxSynchronize)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuCtxSynchronize();
}

CUresult cuCtxGetDevice_f(CUdevice* dev)
{
    if (!pfn_cuCtxGetDevice)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuCtxGetDevice(dev);
}

CUresult cuCtxCreate_f(CUcontext* ctx, unsigned int flags, CUdevice dev)
{
    if (!pfn_cuCtxCreate)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuCtxCreate(ctx, flags, dev);
}

CUresult cuCtxDestroy_f(CUcontext ctx)
{
    if (!pfn_cuCtxDestroy)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuCtxDestroy(ctx);
}

CUresult cuCtxEnablePeerAccess_f(CUcontext peer_ctx, unsigned int flags)
{
    if (!pfn_cuCtxEnablePeerAccess)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuCtxEnablePeerAccess(peer_ctx, flags);
}

CUresult cuCtxDisablePeerAccess_f(CUcontext peer_ctx)
{
    if (!pfn_cuCtxDisablePeerAccess)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuCtxDisablePeerAccess(peer_ctx);
}
|
|
475
|
+
|
|
476
|
+
// Stream wrappers.
// Each forwards to its resolved pfn_* entry point, or returns
// DRIVER_ENTRY_POINT_ERROR when the entry point is missing.

CUresult cuStreamCreate_f(CUstream* stream, unsigned int flags)
{
    if (!pfn_cuStreamCreate)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuStreamCreate(stream, flags);
}

CUresult cuStreamDestroy_f(CUstream stream)
{
    if (!pfn_cuStreamDestroy)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuStreamDestroy(stream);
}

CUresult cuStreamQuery_f(CUstream stream)
{
    if (!pfn_cuStreamQuery)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuStreamQuery(stream);
}

CUresult cuStreamSynchronize_f(CUstream stream)
{
    if (!pfn_cuStreamSynchronize)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuStreamSynchronize(stream);
}

CUresult cuStreamWaitEvent_f(CUstream stream, CUevent event, unsigned int flags)
{
    if (!pfn_cuStreamWaitEvent)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuStreamWaitEvent(stream, event, flags);
}

CUresult cuStreamGetCtx_f(CUstream stream, CUcontext* pctx)
{
    if (!pfn_cuStreamGetCtx)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuStreamGetCtx(stream, pctx);
}

CUresult cuStreamGetCaptureInfo_f(CUstream stream, CUstreamCaptureStatus *captureStatus_out, cuuint64_t *id_out, CUgraph *graph_out, const CUgraphNode **dependencies_out, size_t *numDependencies_out)
{
    if (!pfn_cuStreamGetCaptureInfo)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuStreamGetCaptureInfo(stream, captureStatus_out, id_out, graph_out, dependencies_out, numDependencies_out);
}

CUresult cuStreamUpdateCaptureDependencies_f(CUstream stream, CUgraphNode *dependencies, size_t numDependencies, unsigned int flags)
{
    if (!pfn_cuStreamUpdateCaptureDependencies)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuStreamUpdateCaptureDependencies(stream, dependencies, numDependencies, flags);
}

CUresult cuStreamCreateWithPriority_f(CUstream* phStream, unsigned int flags, int priority)
{
    if (!pfn_cuStreamCreateWithPriority)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuStreamCreateWithPriority(phStream, flags, priority);
}

CUresult cuStreamGetPriority_f(CUstream hStream, int* priority)
{
    if (!pfn_cuStreamGetPriority)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuStreamGetPriority(hStream, priority);
}
|
|
525
|
+
|
|
526
|
+
// Event wrappers.
// Each forwards to its resolved pfn_* entry point, or returns
// DRIVER_ENTRY_POINT_ERROR when the entry point is missing.

CUresult cuEventCreate_f(CUevent* event, unsigned int flags)
{
    if (!pfn_cuEventCreate)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuEventCreate(event, flags);
}

CUresult cuEventDestroy_f(CUevent event)
{
    if (!pfn_cuEventDestroy)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuEventDestroy(event);
}

CUresult cuEventQuery_f(CUevent event)
{
    if (!pfn_cuEventQuery)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuEventQuery(event);
}

CUresult cuEventRecord_f(CUevent event, CUstream stream)
{
    if (!pfn_cuEventRecord)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuEventRecord(event, stream);
}

CUresult cuEventRecordWithFlags_f(CUevent event, CUstream stream, unsigned int flags)
{
    if (!pfn_cuEventRecordWithFlags)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuEventRecordWithFlags(event, stream, flags);
}

CUresult cuEventSynchronize_f(CUevent event)
{
    if (!pfn_cuEventSynchronize)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuEventSynchronize(event);
}
|
|
555
|
+
|
|
556
|
+
// Module loading and kernel launch wrappers.
// Each forwards to its resolved pfn_* entry point, or returns
// DRIVER_ENTRY_POINT_ERROR when the entry point is missing.

CUresult cuModuleLoadDataEx_f(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues)
{
    if (!pfn_cuModuleLoadDataEx)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuModuleLoadDataEx(module, image, numOptions, options, optionValues);
}

CUresult cuModuleUnload_f(CUmodule hmod)
{
    if (!pfn_cuModuleUnload)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuModuleUnload(hmod);
}

CUresult cuModuleGetFunction_f(CUfunction *hfunc, CUmodule hmod, const char *name)
{
    if (!pfn_cuModuleGetFunction)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuModuleGetFunction(hfunc, hmod, name);
}

CUresult cuLaunchKernel_f(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra)
{
    if (!pfn_cuLaunchKernel)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuLaunchKernel(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra);
}
|
|
575
|
+
|
|
576
|
+
// Peer copy and pointer-attribute wrappers.
// Each forwards to its resolved pfn_* entry point, or returns
// DRIVER_ENTRY_POINT_ERROR when the entry point is missing.

CUresult cuMemcpyPeerAsync_f(CUdeviceptr dst_ptr, CUcontext dst_ctx, CUdeviceptr src_ptr, CUcontext src_ctx, size_t n, CUstream stream)
{
    if (!pfn_cuMemcpyPeerAsync)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuMemcpyPeerAsync(dst_ptr, dst_ctx, src_ptr, src_ctx, n, stream);
}

CUresult cuPointerGetAttribute_f(void* data, CUpointer_attribute attribute, CUdeviceptr ptr)
{
    if (!pfn_cuPointerGetAttribute)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuPointerGetAttribute(data, attribute, ptr);
}
|
|
585
|
+
|
|
586
|
+
// Graphics-interop wrappers.
// Each forwards to its resolved pfn_* entry point, or returns
// DRIVER_ENTRY_POINT_ERROR when the entry point is missing.

CUresult cuGraphicsMapResources_f(unsigned int count, CUgraphicsResource* resources, CUstream stream)
{
    if (!pfn_cuGraphicsMapResources)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuGraphicsMapResources(count, resources, stream);
}

CUresult cuGraphicsUnmapResources_f(unsigned int count, CUgraphicsResource* resources, CUstream hStream)
{
    if (!pfn_cuGraphicsUnmapResources)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuGraphicsUnmapResources(count, resources, hStream);
}

CUresult cuGraphicsResourceGetMappedPointer_f(CUdeviceptr* pDevPtr, size_t* pSize, CUgraphicsResource resource)
{
    if (!pfn_cuGraphicsResourceGetMappedPointer)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuGraphicsResourceGetMappedPointer(pDevPtr, pSize, resource);
}

// The buffer id is passed to the driver as wp::GLuint.
CUresult cuGraphicsGLRegisterBuffer_f(CUgraphicsResource *pCudaResource, unsigned int buffer, unsigned int flags)
{
    if (!pfn_cuGraphicsGLRegisterBuffer)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuGraphicsGLRegisterBuffer(pCudaResource, (wp::GLuint) buffer, flags);
}

CUresult cuGraphicsUnregisterResource_f(CUgraphicsResource resource)
{
    if (!pfn_cuGraphicsUnregisterResource)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuGraphicsUnregisterResource(resource);
}
|
|
610
|
+
|
|
611
|
+
// Module-global lookup and function-attribute wrappers.
// Each forwards to its resolved pfn_* entry point, or returns
// DRIVER_ENTRY_POINT_ERROR when the entry point is missing.

CUresult cuModuleGetGlobal_f(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, const char* name)
{
    if (!pfn_cuModuleGetGlobal)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuModuleGetGlobal(dptr, bytes, hmod, name);
}

CUresult cuFuncSetAttribute_f(CUfunction hfunc, CUfunction_attribute attrib, int value)
{
    if (!pfn_cuFuncSetAttribute)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuFuncSetAttribute(hfunc, attrib, value);
}
|
|
620
|
+
|
|
621
|
+
// Inter-process communication (IPC) wrappers.
// Each forwards to its resolved pfn_* entry point, or returns
// DRIVER_ENTRY_POINT_ERROR when the entry point is missing.

CUresult cuIpcGetEventHandle_f(CUipcEventHandle *pHandle, CUevent event)
{
    if (!pfn_cuIpcGetEventHandle)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuIpcGetEventHandle(pHandle, event);
}

CUresult cuIpcOpenEventHandle_f(CUevent *phEvent, CUipcEventHandle handle)
{
    if (!pfn_cuIpcOpenEventHandle)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuIpcOpenEventHandle(phEvent, handle);
}

CUresult cuIpcGetMemHandle_f(CUipcMemHandle *pHandle, CUdeviceptr dptr)
{
    if (!pfn_cuIpcGetMemHandle)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuIpcGetMemHandle(pHandle, dptr);
}

CUresult cuIpcOpenMemHandle_f(CUdeviceptr *pdptr, CUipcMemHandle handle, unsigned int flags)
{
    if (!pfn_cuIpcOpenMemHandle)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuIpcOpenMemHandle(pdptr, handle, flags);
}

CUresult cuIpcCloseMemHandle_f(CUdeviceptr dptr)
{
    if (!pfn_cuIpcCloseMemHandle)
        return DRIVER_ENTRY_POINT_ERROR;
    return pfn_cuIpcCloseMemHandle(dptr);
}
|
|
645
|
+
|
|
646
|
+
#endif // WP_ENABLE_CUDA
|