warp-lang 1.6.1-py3-none-win_amd64.whl → 1.7.0-py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +21 -7
- warp/autograd.py +14 -6
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +424 -6
- warp/build_dll.py +20 -20
- warp/builtins.py +467 -368
- warp/codegen.py +193 -125
- warp/config.py +56 -12
- warp/constants.py +14 -6
- warp/context.py +524 -277
- warp/dlpack.py +22 -12
- warp/examples/__init__.py +14 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_api.py +14 -6
- warp/examples/benchmarks/benchmark_cloth.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
- warp/examples/benchmarks/benchmark_gemm.py +82 -48
- warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
- warp/examples/benchmarks/benchmark_launches.py +14 -6
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +14 -6
- warp/examples/core/example_cupy.py +14 -6
- warp/examples/core/example_dem.py +14 -6
- warp/examples/core/example_fluid.py +14 -6
- warp/examples/core/example_graph_capture.py +14 -6
- warp/examples/core/example_marching_cubes.py +14 -6
- warp/examples/core/example_mesh.py +14 -6
- warp/examples/core/example_mesh_intersect.py +14 -6
- warp/examples/core/example_nvdb.py +14 -6
- warp/examples/core/example_raycast.py +14 -6
- warp/examples/core/example_raymarch.py +14 -6
- warp/examples/core/example_render_opengl.py +14 -6
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +14 -6
- warp/examples/core/example_torch.py +14 -6
- warp/examples/core/example_wave.py +14 -6
- warp/examples/fem/example_adaptive_grid.py +14 -6
- warp/examples/fem/example_apic_fluid.py +15 -7
- warp/examples/fem/example_burgers.py +16 -8
- warp/examples/fem/example_convection_diffusion.py +14 -6
- warp/examples/fem/example_convection_diffusion_dg.py +14 -6
- warp/examples/fem/example_deformed_geometry.py +15 -7
- warp/examples/fem/example_diffusion.py +14 -6
- warp/examples/fem/example_diffusion_3d.py +14 -6
- warp/examples/fem/example_diffusion_mgpu.py +14 -6
- warp/examples/fem/example_distortion_energy.py +15 -7
- warp/examples/fem/example_magnetostatics.py +20 -12
- warp/examples/fem/example_mixed_elasticity.py +14 -6
- warp/examples/fem/example_navier_stokes.py +14 -6
- warp/examples/fem/example_nonconforming_contact.py +14 -6
- warp/examples/fem/example_stokes.py +14 -6
- warp/examples/fem/example_stokes_transfer.py +14 -6
- warp/examples/fem/example_streamlines.py +14 -6
- warp/examples/fem/utils.py +24 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +14 -6
- warp/examples/optim/example_cloth_throw.py +14 -6
- warp/examples/optim/example_diffray.py +14 -6
- warp/examples/optim/example_drone.py +14 -6
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +14 -6
- warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
- warp/examples/optim/example_softbody_properties.py +14 -6
- warp/examples/optim/example_spring_cage.py +14 -6
- warp/examples/optim/example_trajectory.py +14 -6
- warp/examples/sim/example_cartpole.py +14 -6
- warp/examples/sim/example_cloth.py +14 -6
- warp/examples/sim/example_cloth_self_contact.py +14 -6
- warp/examples/sim/example_granular.py +14 -6
- warp/examples/sim/example_granular_collision_sdf.py +14 -6
- warp/examples/sim/example_jacobian_ik.py +14 -6
- warp/examples/sim/example_particle_chain.py +14 -6
- warp/examples/sim/example_quadruped.py +14 -6
- warp/examples/sim/example_rigid_chain.py +14 -6
- warp/examples/sim/example_rigid_contact.py +14 -6
- warp/examples/sim/example_rigid_force.py +14 -6
- warp/examples/sim/example_rigid_gyroscopic.py +14 -6
- warp/examples/sim/example_rigid_soft_contact.py +14 -6
- warp/examples/sim/example_soft_body.py +14 -6
- warp/examples/tile/example_tile_cholesky.py +14 -6
- warp/examples/tile/example_tile_convolution.py +14 -6
- warp/examples/tile/example_tile_fft.py +14 -6
- warp/examples/tile/example_tile_filtering.py +14 -6
- warp/examples/tile/example_tile_matmul.py +16 -10
- warp/examples/tile/example_tile_mlp.py +14 -6
- warp/examples/tile/example_tile_nbody.py +14 -6
- warp/examples/tile/example_tile_walker.py +14 -6
- warp/fabric.py +15 -0
- warp/fem/__init__.py +26 -1
- warp/fem/adaptivity.py +19 -4
- warp/fem/cache.py +15 -0
- warp/fem/dirichlet.py +15 -0
- warp/fem/domain.py +15 -0
- warp/fem/field/__init__.py +15 -0
- warp/fem/field/field.py +15 -0
- warp/fem/field/nodal_field.py +37 -68
- warp/fem/field/restriction.py +15 -0
- warp/fem/field/virtual.py +77 -23
- warp/fem/geometry/__init__.py +15 -0
- warp/fem/geometry/adaptive_nanogrid.py +24 -10
- warp/fem/geometry/closest_point.py +16 -1
- warp/fem/geometry/deformed_geometry.py +20 -2
- warp/fem/geometry/element.py +15 -0
- warp/fem/geometry/geometry.py +20 -0
- warp/fem/geometry/grid_2d.py +27 -12
- warp/fem/geometry/grid_3d.py +27 -15
- warp/fem/geometry/hexmesh.py +20 -7
- warp/fem/geometry/nanogrid.py +24 -11
- warp/fem/geometry/partition.py +15 -0
- warp/fem/geometry/quadmesh.py +28 -13
- warp/fem/geometry/tetmesh.py +18 -4
- warp/fem/geometry/trimesh.py +18 -8
- warp/fem/integrate.py +277 -93
- warp/fem/linalg.py +20 -5
- warp/fem/operator.py +15 -0
- warp/fem/polynomial.py +15 -0
- warp/fem/quadrature/__init__.py +15 -0
- warp/fem/quadrature/pic_quadrature.py +52 -22
- warp/fem/quadrature/quadrature.py +209 -25
- warp/fem/space/__init__.py +16 -1
- warp/fem/space/basis_function_space.py +19 -2
- warp/fem/space/basis_space.py +40 -18
- warp/fem/space/dof_mapper.py +15 -0
- warp/fem/space/function_space.py +15 -0
- warp/fem/space/grid_2d_function_space.py +15 -0
- warp/fem/space/grid_3d_function_space.py +15 -0
- warp/fem/space/hexmesh_function_space.py +17 -2
- warp/fem/space/nanogrid_function_space.py +15 -0
- warp/fem/space/partition.py +21 -2
- warp/fem/space/quadmesh_function_space.py +23 -8
- warp/fem/space/restriction.py +15 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +38 -23
- warp/fem/space/shape/shape_function.py +15 -0
- warp/fem/space/shape/square_shape_function.py +27 -12
- warp/fem/space/shape/tet_shape_function.py +15 -0
- warp/fem/space/shape/triangle_shape_function.py +16 -1
- warp/fem/space/tetmesh_function_space.py +18 -3
- warp/fem/space/topology.py +15 -0
- warp/fem/space/trimesh_function_space.py +17 -2
- warp/fem/types.py +15 -0
- warp/fem/utils.py +27 -6
- warp/jax.py +28 -7
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +103 -6
- warp/native/array.h +28 -6
- warp/native/builtin.h +44 -9
- warp/native/bvh.cpp +18 -7
- warp/native/bvh.cu +57 -20
- warp/native/bvh.h +17 -7
- warp/native/clang/clang.cpp +45 -9
- warp/native/coloring.cpp +15 -6
- warp/native/crt.cpp +15 -6
- warp/native/crt.h +15 -6
- warp/native/cuda_crt.h +15 -6
- warp/native/cuda_util.cpp +29 -6
- warp/native/cuda_util.h +17 -6
- warp/native/error.cpp +15 -6
- warp/native/error.h +15 -6
- warp/native/exports.h +85 -63
- warp/native/fabric.h +15 -6
- warp/native/hashgrid.cpp +15 -6
- warp/native/hashgrid.cu +15 -6
- warp/native/hashgrid.h +15 -6
- warp/native/initializer_array.h +15 -6
- warp/native/intersect.h +41 -32
- warp/native/intersect_adj.h +48 -39
- warp/native/intersect_tri.h +17 -0
- warp/native/marching.cpp +16 -0
- warp/native/marching.cu +16 -7
- warp/native/marching.h +17 -0
- warp/native/mat.h +528 -15
- warp/native/mathdx.cpp +15 -6
- warp/native/matnn.h +15 -6
- warp/native/mesh.cpp +15 -6
- warp/native/mesh.cu +15 -6
- warp/native/mesh.h +25 -16
- warp/native/noise.h +15 -6
- warp/native/quat.h +114 -17
- warp/native/rand.h +21 -6
- warp/native/range.h +15 -6
- warp/native/reduce.cpp +15 -6
- warp/native/reduce.cu +15 -6
- warp/native/runlength_encode.cpp +15 -6
- warp/native/runlength_encode.cu +15 -6
- warp/native/scan.cpp +15 -6
- warp/native/scan.cu +15 -6
- warp/native/scan.h +15 -6
- warp/native/solid_angle.h +17 -0
- warp/native/sort.cpp +137 -65
- warp/native/sort.cu +167 -21
- warp/native/sort.h +23 -7
- warp/native/sparse.cpp +58 -28
- warp/native/sparse.cu +67 -23
- warp/native/spatial.h +15 -6
- warp/native/svd.h +131 -6
- warp/native/temp_buffer.h +15 -6
- warp/native/tile.h +316 -111
- warp/native/tile_reduce.h +61 -9
- warp/native/vec.h +83 -13
- warp/native/volume.cpp +100 -119
- warp/native/volume.cu +15 -6
- warp/native/volume.h +15 -6
- warp/native/volume_builder.cu +40 -16
- warp/native/volume_builder.h +21 -6
- warp/native/volume_impl.h +15 -6
- warp/native/warp.cpp +20 -12
- warp/native/warp.cu +114 -16
- warp/native/warp.h +34 -16
- warp/optim/__init__.py +14 -6
- warp/optim/adam.py +14 -6
- warp/optim/linear.py +25 -10
- warp/optim/sgd.py +14 -6
- warp/paddle.py +14 -6
- warp/render/__init__.py +14 -6
- warp/render/render_opengl.py +14 -6
- warp/render/render_usd.py +14 -6
- warp/render/utils.py +14 -6
- warp/sim/__init__.py +14 -7
- warp/sim/articulation.py +18 -10
- warp/sim/collide.py +35 -16
- warp/sim/graph_coloring.py +14 -6
- warp/sim/import_mjcf.py +463 -162
- warp/sim/import_snu.py +14 -7
- warp/sim/import_urdf.py +46 -18
- warp/sim/import_usd.py +14 -7
- warp/sim/inertia.py +14 -6
- warp/sim/integrator.py +14 -6
- warp/sim/integrator_euler.py +19 -11
- warp/sim/integrator_featherstone.py +17 -16
- warp/sim/integrator_vbd.py +222 -8
- warp/sim/integrator_xpbd.py +19 -11
- warp/sim/model.py +56 -19
- warp/sim/particles.py +14 -6
- warp/sim/render.py +14 -6
- warp/sim/utils.py +17 -2
- warp/sparse.py +657 -555
- warp/stubs.py +231 -19
- warp/tape.py +14 -6
- warp/tests/aux_test_class_kernel.py +14 -6
- warp/tests/aux_test_compile_consts_dummy.py +14 -6
- warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
- warp/tests/aux_test_dependent.py +14 -6
- warp/tests/aux_test_grad_customs.py +14 -6
- warp/tests/aux_test_instancing_gc.py +14 -6
- warp/tests/aux_test_module_unload.py +14 -6
- warp/tests/aux_test_name_clash1.py +14 -6
- warp/tests/aux_test_name_clash2.py +14 -6
- warp/tests/aux_test_unresolved_func.py +14 -6
- warp/tests/aux_test_unresolved_symbol.py +14 -6
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_async.py → cuda/test_async.py} +14 -6
- warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
- warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
- warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
- warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
- warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
- warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
- warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
- warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
- warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
- warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
- warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
- warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
- warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
- warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
- warp/tests/run_coverage_serial.py +14 -6
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
- warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
- warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
- warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
- warp/tests/{test_model.py → sim/test_model.py} +55 -7
- warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +14 -6
- warp/tests/test_arithmetic.py +14 -6
- warp/tests/test_array.py +14 -6
- warp/tests/test_array_reduce.py +14 -6
- warp/tests/test_assert.py +14 -6
- warp/tests/test_atomic.py +14 -6
- warp/tests/test_bool.py +15 -7
- warp/tests/test_builtins_resolution.py +14 -6
- warp/tests/test_closest_point_edge_edge.py +14 -6
- warp/tests/test_codegen.py +14 -6
- warp/tests/test_codegen_instancing.py +14 -6
- warp/tests/test_compile_consts.py +14 -6
- warp/tests/test_conditional.py +14 -6
- warp/tests/test_context.py +14 -6
- warp/tests/test_copy.py +14 -6
- warp/tests/test_ctypes.py +14 -6
- warp/tests/test_dense.py +14 -6
- warp/tests/test_devices.py +14 -6
- warp/tests/test_examples.py +42 -42
- warp/tests/test_fabricarray.py +14 -6
- warp/tests/test_fast_math.py +14 -6
- warp/tests/test_fem.py +37 -10
- warp/tests/test_fp16.py +14 -6
- warp/tests/test_func.py +14 -6
- warp/tests/test_future_annotations.py +14 -6
- warp/tests/test_generics.py +14 -6
- warp/tests/test_grad.py +14 -6
- warp/tests/test_grad_customs.py +14 -6
- warp/tests/test_grad_debug.py +14 -6
- warp/tests/test_implicit_init.py +14 -6
- warp/tests/test_import.py +14 -6
- warp/tests/test_indexedarray.py +14 -6
- warp/tests/test_intersect.py +14 -6
- warp/tests/test_iter.py +14 -6
- warp/tests/test_large.py +14 -6
- warp/tests/test_launch.py +14 -6
- warp/tests/test_lerp.py +14 -6
- warp/tests/test_linear_solvers.py +15 -11
- warp/tests/test_lvalue.py +14 -6
- warp/tests/test_mat.py +247 -85
- warp/tests/test_mat_lite.py +14 -6
- warp/tests/test_mat_scalar_ops.py +18 -10
- warp/tests/test_math.py +14 -6
- warp/tests/test_mlp.py +14 -6
- warp/tests/test_module_hashing.py +14 -6
- warp/tests/test_modules_lite.py +14 -6
- warp/tests/test_noise.py +14 -6
- warp/tests/test_operators.py +14 -6
- warp/tests/test_options.py +14 -6
- warp/tests/test_overwrite.py +15 -60
- warp/tests/test_print.py +14 -6
- warp/tests/test_quat.py +81 -52
- warp/tests/test_rand.py +58 -43
- warp/tests/test_reload.py +14 -6
- warp/tests/test_rounding.py +14 -6
- warp/tests/test_runlength_encode.py +14 -6
- warp/tests/test_scalar_ops.py +14 -6
- warp/tests/test_smoothstep.py +14 -6
- warp/tests/test_snippet.py +15 -0
- warp/tests/test_sparse.py +61 -12
- warp/tests/test_spatial.py +89 -6
- warp/tests/test_special_values.py +14 -6
- warp/tests/test_static.py +15 -7
- warp/tests/test_struct.py +14 -6
- warp/tests/test_tape.py +14 -6
- warp/tests/test_transient_module.py +14 -6
- warp/tests/test_triangle_closest_point.py +14 -6
- warp/tests/test_types.py +14 -6
- warp/tests/test_utils.py +98 -10
- warp/tests/test_vec.py +60 -40
- warp/tests/test_vec_lite.py +14 -6
- warp/tests/test_vec_scalar_ops.py +14 -6
- warp/tests/test_verify_fp.py +14 -6
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
- warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
- warp/tests/unittest_serial.py +15 -6
- warp/tests/unittest_suites.py +59 -65
- warp/tests/unittest_utils.py +16 -7
- warp/tests/walkthrough_debug.py +14 -6
- warp/thirdparty/unittest_parallel.py +15 -8
- warp/torch.py +14 -6
- warp/types.py +124 -664
- warp/utils.py +151 -78
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp/examples/optim/example_walker.py +0 -309
- warp/native/cutlass_gemm.cpp +0 -34
- warp/native/cutlass_gemm.cu +0 -373
- warp/tests/test_matmul.py +0 -503
- warp/tests/test_matmul_lite.py +0 -403
- warp/tests/test_vbd.py +0 -378
- warp/tests/unused_test_misc.py +0 -69
- warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
- warp_lang-1.6.1.dist-info/RECORD +0 -419
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/context.py
CHANGED
|
@@ -1,15 +1,22 @@
|
|
|
1
|
-
# Copyright (c) 2022 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
from __future__ import annotations
|
|
9
17
|
|
|
10
18
|
import ast
|
|
11
19
|
import ctypes
|
|
12
|
-
import errno
|
|
13
20
|
import functools
|
|
14
21
|
import hashlib
|
|
15
22
|
import inspect
|
|
@@ -20,13 +27,27 @@ import operator
|
|
|
20
27
|
import os
|
|
21
28
|
import platform
|
|
22
29
|
import sys
|
|
23
|
-
import time
|
|
24
30
|
import types
|
|
25
31
|
import typing
|
|
26
32
|
import weakref
|
|
27
33
|
from copy import copy as shallowcopy
|
|
28
34
|
from pathlib import Path
|
|
29
|
-
from typing import
|
|
35
|
+
from typing import (
|
|
36
|
+
Any,
|
|
37
|
+
Callable,
|
|
38
|
+
Dict,
|
|
39
|
+
List,
|
|
40
|
+
Literal,
|
|
41
|
+
Mapping,
|
|
42
|
+
Optional,
|
|
43
|
+
Sequence,
|
|
44
|
+
Set,
|
|
45
|
+
Tuple,
|
|
46
|
+
TypeVar,
|
|
47
|
+
Union,
|
|
48
|
+
get_args,
|
|
49
|
+
get_origin,
|
|
50
|
+
)
|
|
30
51
|
|
|
31
52
|
import numpy as np
|
|
32
53
|
|
|
@@ -34,7 +55,7 @@ import warp
|
|
|
34
55
|
import warp.build
|
|
35
56
|
import warp.codegen
|
|
36
57
|
import warp.config
|
|
37
|
-
from warp.types import launch_bounds_t
|
|
58
|
+
from warp.types import Array, launch_bounds_t
|
|
38
59
|
|
|
39
60
|
# represents either a built-in or user-defined function
|
|
40
61
|
|
|
@@ -63,10 +84,10 @@ def get_function_args(func):
|
|
|
63
84
|
complex_type_hints = (Any, Callable, Tuple)
|
|
64
85
|
sequence_types = (list, tuple)
|
|
65
86
|
|
|
66
|
-
function_key_counts = {}
|
|
87
|
+
function_key_counts: Dict[str, int] = {}
|
|
67
88
|
|
|
68
89
|
|
|
69
|
-
def generate_unique_function_identifier(key):
|
|
90
|
+
def generate_unique_function_identifier(key: str) -> str:
|
|
70
91
|
# Generate unique identifiers for user-defined functions in native code.
|
|
71
92
|
# - Prevents conflicts when a function is redefined and old versions are still in use.
|
|
72
93
|
# - Prevents conflicts between multiple closures returned from the same function.
|
|
@@ -99,40 +120,40 @@ def generate_unique_function_identifier(key):
|
|
|
99
120
|
class Function:
|
|
100
121
|
def __init__(
|
|
101
122
|
self,
|
|
102
|
-
func,
|
|
103
|
-
key,
|
|
104
|
-
namespace,
|
|
105
|
-
input_types=None,
|
|
106
|
-
value_type=None,
|
|
107
|
-
value_func=None,
|
|
108
|
-
export_func=None,
|
|
109
|
-
dispatch_func=None,
|
|
110
|
-
lto_dispatch_func=None,
|
|
111
|
-
module=None,
|
|
112
|
-
variadic=False,
|
|
113
|
-
initializer_list_func=None,
|
|
114
|
-
export=False,
|
|
115
|
-
doc="",
|
|
116
|
-
group="",
|
|
117
|
-
hidden=False,
|
|
118
|
-
skip_replay=False,
|
|
119
|
-
missing_grad=False,
|
|
120
|
-
generic=False,
|
|
121
|
-
native_func=None,
|
|
122
|
-
defaults=None,
|
|
123
|
-
custom_replay_func=None,
|
|
124
|
-
native_snippet=None,
|
|
125
|
-
adj_native_snippet=None,
|
|
126
|
-
replay_snippet=None,
|
|
127
|
-
skip_forward_codegen=False,
|
|
128
|
-
skip_reverse_codegen=False,
|
|
129
|
-
custom_reverse_num_input_args
|
|
130
|
-
custom_reverse_mode=False,
|
|
131
|
-
overloaded_annotations=None,
|
|
132
|
-
code_transformers=None,
|
|
133
|
-
skip_adding_overload=False,
|
|
134
|
-
require_original_output_arg=False,
|
|
135
|
-
scope_locals
|
|
123
|
+
func: Optional[Callable],
|
|
124
|
+
key: str,
|
|
125
|
+
namespace: str,
|
|
126
|
+
input_types: Optional[Dict[str, Union[type, TypeVar]]] = None,
|
|
127
|
+
value_type: Optional[type] = None,
|
|
128
|
+
value_func: Optional[Callable[[Mapping[str, type], Mapping[str, Any]], type]] = None,
|
|
129
|
+
export_func: Optional[Callable[[Dict[str, type]], Dict[str, type]]] = None,
|
|
130
|
+
dispatch_func: Optional[Callable] = None,
|
|
131
|
+
lto_dispatch_func: Optional[Callable] = None,
|
|
132
|
+
module: Optional[Module] = None,
|
|
133
|
+
variadic: bool = False,
|
|
134
|
+
initializer_list_func: Optional[Callable[[Dict[str, Any], type], bool]] = None,
|
|
135
|
+
export: bool = False,
|
|
136
|
+
doc: str = "",
|
|
137
|
+
group: str = "",
|
|
138
|
+
hidden: bool = False,
|
|
139
|
+
skip_replay: bool = False,
|
|
140
|
+
missing_grad: bool = False,
|
|
141
|
+
generic: bool = False,
|
|
142
|
+
native_func: Optional[str] = None,
|
|
143
|
+
defaults: Optional[Dict[str, Any]] = None,
|
|
144
|
+
custom_replay_func: Optional[Function] = None,
|
|
145
|
+
native_snippet: Optional[str] = None,
|
|
146
|
+
adj_native_snippet: Optional[str] = None,
|
|
147
|
+
replay_snippet: Optional[str] = None,
|
|
148
|
+
skip_forward_codegen: bool = False,
|
|
149
|
+
skip_reverse_codegen: bool = False,
|
|
150
|
+
custom_reverse_num_input_args: int = -1,
|
|
151
|
+
custom_reverse_mode: bool = False,
|
|
152
|
+
overloaded_annotations: Optional[Dict[str, type]] = None,
|
|
153
|
+
code_transformers: Optional[List[ast.NodeTransformer]] = None,
|
|
154
|
+
skip_adding_overload: bool = False,
|
|
155
|
+
require_original_output_arg: bool = False,
|
|
156
|
+
scope_locals: Optional[Dict[str, Any]] = None,
|
|
136
157
|
):
|
|
137
158
|
if code_transformers is None:
|
|
138
159
|
code_transformers = []
|
|
@@ -157,7 +178,7 @@ class Function:
|
|
|
157
178
|
self.native_snippet = native_snippet
|
|
158
179
|
self.adj_native_snippet = adj_native_snippet
|
|
159
180
|
self.replay_snippet = replay_snippet
|
|
160
|
-
self.custom_grad_func = None
|
|
181
|
+
self.custom_grad_func: Optional[Function] = None
|
|
161
182
|
self.require_original_output_arg = require_original_output_arg
|
|
162
183
|
self.generic_parent = None # generic function that was used to instantiate this overload
|
|
163
184
|
|
|
@@ -173,6 +194,7 @@ class Function:
|
|
|
173
194
|
)
|
|
174
195
|
self.missing_grad = missing_grad # whether builtin is missing a corresponding adjoint
|
|
175
196
|
self.generic = generic
|
|
197
|
+
self.mangled_name: Optional[str] = None
|
|
176
198
|
|
|
177
199
|
# allow registering functions with a different name in Python and native code
|
|
178
200
|
if native_func is None:
|
|
@@ -189,8 +211,8 @@ class Function:
|
|
|
189
211
|
# user-defined function
|
|
190
212
|
|
|
191
213
|
# generic and concrete overload lookups by type signature
|
|
192
|
-
self.user_templates = {}
|
|
193
|
-
self.user_overloads = {}
|
|
214
|
+
self.user_templates: Dict[str, Function] = {}
|
|
215
|
+
self.user_overloads: Dict[str, Function] = {}
|
|
194
216
|
|
|
195
217
|
# user defined (Python) function
|
|
196
218
|
self.adj = warp.codegen.Adjoint(
|
|
@@ -221,19 +243,17 @@ class Function:
|
|
|
221
243
|
# builtin function
|
|
222
244
|
|
|
223
245
|
# embedded linked list of all overloads
|
|
224
|
-
# the builtin_functions dictionary holds
|
|
225
|
-
|
|
226
|
-
self.overloads = []
|
|
246
|
+
# the builtin_functions dictionary holds the list head for a given key (func name)
|
|
247
|
+
self.overloads: List[Function] = []
|
|
227
248
|
|
|
228
249
|
# builtin (native) function, canonicalize argument types
|
|
229
|
-
|
|
230
|
-
|
|
250
|
+
if input_types is not None:
|
|
251
|
+
for k, v in input_types.items():
|
|
252
|
+
self.input_types[k] = warp.types.type_to_warp(v)
|
|
231
253
|
|
|
232
254
|
# cache mangled name
|
|
233
255
|
if self.export and self.is_simple():
|
|
234
256
|
self.mangled_name = self.mangle()
|
|
235
|
-
else:
|
|
236
|
-
self.mangled_name = None
|
|
237
257
|
|
|
238
258
|
if not skip_adding_overload:
|
|
239
259
|
self.add_overload(self)
|
|
@@ -264,7 +284,7 @@ class Function:
|
|
|
264
284
|
signature_params.append(param)
|
|
265
285
|
self.signature = inspect.Signature(signature_params)
|
|
266
286
|
|
|
267
|
-
# scope for resolving overloads
|
|
287
|
+
# scope for resolving overloads, the locals() where the function is defined
|
|
268
288
|
if scope_locals is None:
|
|
269
289
|
scope_locals = inspect.currentframe().f_back.f_locals
|
|
270
290
|
|
|
@@ -326,10 +346,10 @@ class Function:
|
|
|
326
346
|
# this function has no overloads, call it like a plain Python function
|
|
327
347
|
return self.func(*args, **kwargs)
|
|
328
348
|
|
|
329
|
-
def is_builtin(self):
|
|
349
|
+
def is_builtin(self) -> bool:
|
|
330
350
|
return self.func is None
|
|
331
351
|
|
|
332
|
-
def is_simple(self):
|
|
352
|
+
def is_simple(self) -> bool:
|
|
333
353
|
if self.variadic:
|
|
334
354
|
return False
|
|
335
355
|
|
|
@@ -343,9 +363,8 @@ class Function:
|
|
|
343
363
|
|
|
344
364
|
return True
|
|
345
365
|
|
|
346
|
-
def mangle(self):
|
|
347
|
-
|
|
348
|
-
# function, e.g.: builtin_normalize_vec3()
|
|
366
|
+
def mangle(self) -> str:
|
|
367
|
+
"""Build a mangled name for the C-exported function, e.g.: `builtin_normalize_vec3()`."""
|
|
349
368
|
|
|
350
369
|
name = "builtin_" + self.key
|
|
351
370
|
|
|
@@ -361,7 +380,7 @@ class Function:
|
|
|
361
380
|
|
|
362
381
|
return "_".join([name, *types])
|
|
363
382
|
|
|
364
|
-
def add_overload(self, f):
|
|
383
|
+
def add_overload(self, f: Function) -> None:
|
|
365
384
|
if self.is_builtin():
|
|
366
385
|
# todo: note that it is an error to add two functions
|
|
367
386
|
# with the exact same signature as this would cause compile
|
|
@@ -376,7 +395,7 @@ class Function:
|
|
|
376
395
|
else:
|
|
377
396
|
# get function signature based on the input types
|
|
378
397
|
sig = warp.types.get_signature(
|
|
379
|
-
f.input_types.values(), func_name=f.key, arg_names=list(f.input_types.keys())
|
|
398
|
+
list(f.input_types.values()), func_name=f.key, arg_names=list(f.input_types.keys())
|
|
380
399
|
)
|
|
381
400
|
|
|
382
401
|
# check if generic
|
|
@@ -385,7 +404,7 @@ class Function:
|
|
|
385
404
|
else:
|
|
386
405
|
self.user_overloads[sig] = f
|
|
387
406
|
|
|
388
|
-
def get_overload(self, arg_types, kwarg_types):
|
|
407
|
+
def get_overload(self, arg_types: List[type], kwarg_types: Mapping[str, type]) -> Optional[Function]:
|
|
389
408
|
assert not self.is_builtin()
|
|
390
409
|
|
|
391
410
|
for f in self.user_overloads.values():
|
|
@@ -438,7 +457,7 @@ class Function:
|
|
|
438
457
|
return f"<Function {self.key}({inputs_str})>"
|
|
439
458
|
|
|
440
459
|
|
|
441
|
-
def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
|
|
460
|
+
def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
|
|
442
461
|
uses_non_warp_array_type = False
|
|
443
462
|
|
|
444
463
|
init()
|
|
@@ -755,37 +774,51 @@ class Kernel:
|
|
|
755
774
|
|
|
756
775
|
|
|
757
776
|
# decorator to register function, @func
|
|
758
|
-
def func(f):
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
777
|
+
def func(f: Optional[Callable] = None, *, name: Optional[str] = None):
|
|
778
|
+
def wrapper(f, *args, **kwargs):
|
|
779
|
+
if name is None:
|
|
780
|
+
key = warp.codegen.make_full_qualified_name(f)
|
|
781
|
+
else:
|
|
782
|
+
key = name
|
|
783
|
+
|
|
784
|
+
scope_locals = inspect.currentframe().f_back.f_back.f_locals
|
|
785
|
+
|
|
786
|
+
m = get_module(f.__module__)
|
|
787
|
+
doc = getattr(f, "__doc__", "") or ""
|
|
788
|
+
Function(
|
|
789
|
+
func=f,
|
|
790
|
+
key=key,
|
|
791
|
+
namespace="",
|
|
792
|
+
module=m,
|
|
793
|
+
value_func=None,
|
|
794
|
+
scope_locals=scope_locals,
|
|
795
|
+
doc=doc.strip(),
|
|
796
|
+
) # value_type not known yet, will be inferred during Adjoint.build()
|
|
797
|
+
|
|
798
|
+
# use the top of the list of overloads for this key
|
|
799
|
+
g = m.functions[key]
|
|
800
|
+
# copy over the function attributes, including docstring
|
|
801
|
+
return functools.update_wrapper(g, f)
|
|
802
|
+
|
|
803
|
+
if f is None:
|
|
804
|
+
# Arguments were passed to the decorator.
|
|
805
|
+
return wrapper
|
|
806
|
+
|
|
807
|
+
return wrapper(f)
|
|
808
|
+
|
|
809
|
+
|
|
810
|
+
def func_native(snippet: str, adj_snippet: Optional[str] = None, replay_snippet: Optional[str] = None):
|
|
782
811
|
"""
|
|
783
812
|
Decorator to register native code snippet, @func_native
|
|
784
813
|
"""
|
|
785
814
|
|
|
786
|
-
|
|
815
|
+
frame = inspect.currentframe()
|
|
816
|
+
if frame is None or frame.f_back is None:
|
|
817
|
+
scope_locals = {}
|
|
818
|
+
else:
|
|
819
|
+
scope_locals = frame.f_back.f_locals
|
|
787
820
|
|
|
788
|
-
def snippet_func(f):
|
|
821
|
+
def snippet_func(f: Callable) -> Callable:
|
|
789
822
|
name = warp.codegen.make_full_qualified_name(f)
|
|
790
823
|
|
|
791
824
|
m = get_module(f.__module__)
|
|
@@ -957,22 +990,71 @@ def func_replay(forward_fn):
|
|
|
957
990
|
return wrapper
|
|
958
991
|
|
|
959
992
|
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
993
|
+
def kernel(
|
|
994
|
+
f: Optional[Callable] = None,
|
|
995
|
+
*,
|
|
996
|
+
enable_backward: Optional[bool] = None,
|
|
997
|
+
module: Optional[Union[Module, Literal["unique"]]] = None,
|
|
998
|
+
):
|
|
999
|
+
"""
|
|
1000
|
+
Decorator to register a Warp kernel from a Python function.
|
|
1001
|
+
The function must be defined with type annotations for all arguments.
|
|
1002
|
+
The function must not return anything.
|
|
1003
|
+
|
|
1004
|
+
Example::
|
|
1005
|
+
|
|
1006
|
+
@wp.kernel
|
|
1007
|
+
def my_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)):
|
|
1008
|
+
tid = wp.tid()
|
|
1009
|
+
b[tid] = a[tid] + 1.0
|
|
1010
|
+
|
|
1011
|
+
|
|
1012
|
+
@wp.kernel(enable_backward=False)
|
|
1013
|
+
def my_kernel_no_backward(a: wp.array(dtype=float, ndim=2), x: float):
|
|
1014
|
+
# the backward pass will not be generated
|
|
1015
|
+
i, j = wp.tid()
|
|
1016
|
+
a[i, j] = x
|
|
1017
|
+
|
|
1018
|
+
|
|
1019
|
+
@wp.kernel(module="unique")
|
|
1020
|
+
def my_kernel_unique_module(a: wp.array(dtype=float), b: wp.array(dtype=float)):
|
|
1021
|
+
# the kernel will be registered in new unique module created just for this
|
|
1022
|
+
# kernel and its dependent functions and structs
|
|
1023
|
+
tid = wp.tid()
|
|
1024
|
+
b[tid] = a[tid] + 1.0
|
|
1025
|
+
|
|
1026
|
+
Args:
|
|
1027
|
+
f: The function to be registered as a kernel.
|
|
1028
|
+
enable_backward: If False, the backward pass will not be generated.
|
|
1029
|
+
module: The :class:`warp.context.Module` to which the kernel belongs. Alternatively, if a string `"unique"` is provided, the kernel is assigned to a new module named after the kernel name and hash. If None, the module is inferred from the function's module.
|
|
1030
|
+
|
|
1031
|
+
Returns:
|
|
1032
|
+
The registered kernel.
|
|
1033
|
+
"""
|
|
1034
|
+
|
|
963
1035
|
def wrapper(f, *args, **kwargs):
|
|
964
1036
|
options = {}
|
|
965
1037
|
|
|
966
1038
|
if enable_backward is not None:
|
|
967
1039
|
options["enable_backward"] = enable_backward
|
|
968
1040
|
|
|
969
|
-
|
|
1041
|
+
if module is None:
|
|
1042
|
+
m = get_module(f.__module__)
|
|
1043
|
+
elif module == "unique":
|
|
1044
|
+
m = Module(f.__name__, None)
|
|
1045
|
+
else:
|
|
1046
|
+
m = module
|
|
970
1047
|
k = Kernel(
|
|
971
1048
|
func=f,
|
|
972
1049
|
key=warp.codegen.make_full_qualified_name(f),
|
|
973
1050
|
module=m,
|
|
974
1051
|
options=options,
|
|
975
1052
|
)
|
|
1053
|
+
if module == "unique":
|
|
1054
|
+
# add the hash to the module name
|
|
1055
|
+
hasher = warp.context.ModuleHasher(m)
|
|
1056
|
+
k.module.name = f"{k.key}_{hasher.module_hash.hex()[:8]}"
|
|
1057
|
+
|
|
976
1058
|
k = functools.update_wrapper(k, f)
|
|
977
1059
|
return k
|
|
978
1060
|
|
|
@@ -984,7 +1066,7 @@ def kernel(f=None, *, enable_backward=None):
|
|
|
984
1066
|
|
|
985
1067
|
|
|
986
1068
|
# decorator to register struct, @struct
|
|
987
|
-
def struct(c):
|
|
1069
|
+
def struct(c: type):
|
|
988
1070
|
m = get_module(c.__module__)
|
|
989
1071
|
s = warp.codegen.Struct(cls=c, key=warp.codegen.make_full_qualified_name(c), module=m)
|
|
990
1072
|
s = functools.update_wrapper(s, c)
|
|
@@ -1097,47 +1179,47 @@ scalar_types.update({x: x._wp_scalar_type_ for x in warp.types.vector_types})
|
|
|
1097
1179
|
|
|
1098
1180
|
|
|
1099
1181
|
def add_builtin(
|
|
1100
|
-
key,
|
|
1101
|
-
input_types=None,
|
|
1102
|
-
constraint=None,
|
|
1103
|
-
value_type=None,
|
|
1104
|
-
value_func=None,
|
|
1105
|
-
export_func=None,
|
|
1106
|
-
dispatch_func=None,
|
|
1107
|
-
lto_dispatch_func=None,
|
|
1108
|
-
doc="",
|
|
1109
|
-
namespace="wp::",
|
|
1110
|
-
variadic=False,
|
|
1182
|
+
key: str,
|
|
1183
|
+
input_types: Optional[Dict[str, Union[type, TypeVar]]] = None,
|
|
1184
|
+
constraint: Optional[Callable[[Mapping[str, type]], bool]] = None,
|
|
1185
|
+
value_type: Optional[type] = None,
|
|
1186
|
+
value_func: Optional[Callable] = None,
|
|
1187
|
+
export_func: Optional[Callable] = None,
|
|
1188
|
+
dispatch_func: Optional[Callable] = None,
|
|
1189
|
+
lto_dispatch_func: Optional[Callable] = None,
|
|
1190
|
+
doc: str = "",
|
|
1191
|
+
namespace: str = "wp::",
|
|
1192
|
+
variadic: bool = False,
|
|
1111
1193
|
initializer_list_func=None,
|
|
1112
|
-
export=True,
|
|
1113
|
-
group="Other",
|
|
1114
|
-
hidden=False,
|
|
1115
|
-
skip_replay=False,
|
|
1116
|
-
missing_grad=False,
|
|
1117
|
-
native_func=None,
|
|
1118
|
-
defaults=None,
|
|
1119
|
-
require_original_output_arg=False,
|
|
1194
|
+
export: bool = True,
|
|
1195
|
+
group: str = "Other",
|
|
1196
|
+
hidden: bool = False,
|
|
1197
|
+
skip_replay: bool = False,
|
|
1198
|
+
missing_grad: bool = False,
|
|
1199
|
+
native_func: Optional[str] = None,
|
|
1200
|
+
defaults: Optional[Dict[str, Any]] = None,
|
|
1201
|
+
require_original_output_arg: bool = False,
|
|
1120
1202
|
):
|
|
1121
1203
|
"""Main entry point to register a new built-in function.
|
|
1122
1204
|
|
|
1123
1205
|
Args:
|
|
1124
|
-
key
|
|
1206
|
+
key: Function name. Multiple overloaded functions can be registered
|
|
1125
1207
|
under the same name as long as their signature differ.
|
|
1126
|
-
input_types
|
|
1208
|
+
input_types: Signature of the user-facing function.
|
|
1127
1209
|
Variadic arguments are supported by prefixing the parameter names
|
|
1128
1210
|
with asterisks as in `*args` and `**kwargs`. Generic arguments are
|
|
1129
1211
|
supported with types such as `Any`, `Float`, `Scalar`, etc.
|
|
1130
|
-
constraint
|
|
1212
|
+
constraint: For functions that define generic arguments and
|
|
1131
1213
|
are to be exported, this callback is used to specify whether some
|
|
1132
1214
|
combination of inferred arguments are valid or not.
|
|
1133
|
-
value_type
|
|
1134
|
-
value_func
|
|
1215
|
+
value_type: Type returned by the function.
|
|
1216
|
+
value_func: Callback used to specify the return type when
|
|
1135
1217
|
`value_type` isn't enough.
|
|
1136
|
-
export_func
|
|
1218
|
+
export_func: Callback used during the context stage to specify
|
|
1137
1219
|
the signature of the underlying C++ function, not accounting for
|
|
1138
1220
|
the template parameters.
|
|
1139
1221
|
If not provided, `input_types` is used.
|
|
1140
|
-
dispatch_func
|
|
1222
|
+
dispatch_func: Callback used during the codegen stage to specify
|
|
1141
1223
|
the runtime and template arguments to be passed to the underlying C++
|
|
1142
1224
|
function. In other words, this allows defining a mapping between
|
|
1143
1225
|
the signatures of the user-facing and the C++ functions, and even to
|
|
@@ -1145,27 +1227,26 @@ def add_builtin(
|
|
|
1145
1227
|
The arguments returned must be of type `codegen.Var`.
|
|
1146
1228
|
If not provided, all arguments passed by the users when calling
|
|
1147
1229
|
the built-in are passed as-is as runtime arguments to the C++ function.
|
|
1148
|
-
lto_dispatch_func
|
|
1230
|
+
lto_dispatch_func: Same as dispatch_func, but takes an 'option' dict
|
|
1149
1231
|
as extra argument (indicating tile_size and target architecture) and returns
|
|
1150
1232
|
an LTO-IR buffer as extra return value
|
|
1151
|
-
doc
|
|
1233
|
+
doc: Used to generate the Python's docstring and the HTML documentation.
|
|
1152
1234
|
namespace: Namespace for the underlying C++ function.
|
|
1153
|
-
variadic
|
|
1154
|
-
initializer_list_func
|
|
1155
|
-
when passing the arguments to the underlying
|
|
1156
|
-
|
|
1235
|
+
variadic: Whether the function declares variadic arguments.
|
|
1236
|
+
initializer_list_func: Callback to determine whether to use the
|
|
1237
|
+
initializer list syntax when passing the arguments to the underlying
|
|
1238
|
+
C++ function.
|
|
1239
|
+
export: Whether the function is to be exposed to the Python
|
|
1157
1240
|
interpreter so that it becomes available from within the `warp`
|
|
1158
1241
|
module.
|
|
1159
|
-
group
|
|
1160
|
-
hidden
|
|
1161
|
-
skip_replay
|
|
1242
|
+
group: Classification used for the documentation.
|
|
1243
|
+
hidden: Whether to add that function into the documentation.
|
|
1244
|
+
skip_replay: Whether operation will be performed during
|
|
1162
1245
|
the forward replay in the backward pass.
|
|
1163
|
-
missing_grad
|
|
1164
|
-
|
|
1165
|
-
|
|
1166
|
-
|
|
1167
|
-
in `input_types`.
|
|
1168
|
-
require_original_output_arg (bool): Used during the codegen stage to
|
|
1246
|
+
missing_grad: Whether the function is missing a corresponding adjoint.
|
|
1247
|
+
native_func: Name of the underlying C++ function.
|
|
1248
|
+
defaults: Default values for the parameters defined in `input_types`.
|
|
1249
|
+
require_original_output_arg: Used during the codegen stage to
|
|
1169
1250
|
specify whether an adjoint parameter corresponding to the return
|
|
1170
1251
|
value should be included in the signature of the backward function.
|
|
1171
1252
|
"""
|
|
@@ -1347,19 +1428,14 @@ def add_builtin(
|
|
|
1347
1428
|
def register_api_function(
|
|
1348
1429
|
function: Function,
|
|
1349
1430
|
group: str = "Other",
|
|
1350
|
-
hidden=False,
|
|
1431
|
+
hidden: bool = False,
|
|
1351
1432
|
):
|
|
1352
1433
|
"""Main entry point to register a Warp Python function to be part of the Warp API and appear in the documentation.
|
|
1353
1434
|
|
|
1354
1435
|
Args:
|
|
1355
|
-
function
|
|
1356
|
-
group
|
|
1357
|
-
|
|
1358
|
-
Variadic arguments are supported by prefixing the parameter names
|
|
1359
|
-
with asterisks as in `*args` and `**kwargs`. Generic arguments are
|
|
1360
|
-
supported with types such as `Any`, `Float`, `Scalar`, etc.
|
|
1361
|
-
value_type (Any): Type returned by the function.
|
|
1362
|
-
hidden (bool): Whether to add that function into the documentation.
|
|
1436
|
+
function: Warp function to be registered.
|
|
1437
|
+
group: Classification used for the documentation.
|
|
1438
|
+
hidden: Whether to add that function into the documentation.
|
|
1363
1439
|
"""
|
|
1364
1440
|
function.group = group
|
|
1365
1441
|
function.hidden = hidden
|
|
@@ -1367,10 +1443,10 @@ def register_api_function(
|
|
|
1367
1443
|
|
|
1368
1444
|
|
|
1369
1445
|
# global dictionary of modules
|
|
1370
|
-
user_modules = {}
|
|
1446
|
+
user_modules: Dict[str, Module] = {}
|
|
1371
1447
|
|
|
1372
1448
|
|
|
1373
|
-
def get_module(name):
|
|
1449
|
+
def get_module(name: str) -> Module:
|
|
1374
1450
|
# some modules might be manually imported using `importlib` without being
|
|
1375
1451
|
# registered into `sys.modules`
|
|
1376
1452
|
parent = sys.modules.get(name, None)
|
|
@@ -1452,13 +1528,16 @@ class ModuleHasher:
|
|
|
1452
1528
|
if warp.config.verify_fp:
|
|
1453
1529
|
ch.update(bytes("verify_fp", "utf-8"))
|
|
1454
1530
|
|
|
1531
|
+
# line directives, e.g. for Nsight Compute
|
|
1532
|
+
ch.update(bytes(ctypes.c_int(warp.config.line_directives)))
|
|
1533
|
+
|
|
1455
1534
|
# build config
|
|
1456
1535
|
ch.update(bytes(warp.config.mode, "utf-8"))
|
|
1457
1536
|
|
|
1458
1537
|
# save the module hash
|
|
1459
1538
|
self.module_hash = ch.digest()
|
|
1460
1539
|
|
|
1461
|
-
def hash_kernel(self, kernel):
|
|
1540
|
+
def hash_kernel(self, kernel: Kernel) -> bytes:
|
|
1462
1541
|
# NOTE: We only hash non-generic kernels, so we don't traverse kernel overloads here.
|
|
1463
1542
|
|
|
1464
1543
|
ch = hashlib.sha256()
|
|
@@ -1472,7 +1551,7 @@ class ModuleHasher:
|
|
|
1472
1551
|
|
|
1473
1552
|
return h
|
|
1474
1553
|
|
|
1475
|
-
def hash_function(self, func):
|
|
1554
|
+
def hash_function(self, func: Function) -> bytes:
|
|
1476
1555
|
# NOTE: This method hashes all possible overloads that a function call could resolve to.
|
|
1477
1556
|
# The exact overload will be resolved at build time, when the argument types are known.
|
|
1478
1557
|
|
|
@@ -1487,7 +1566,7 @@ class ModuleHasher:
|
|
|
1487
1566
|
ch.update(bytes(func.key, "utf-8"))
|
|
1488
1567
|
|
|
1489
1568
|
# include all concrete and generic overloads
|
|
1490
|
-
overloads = {**func.user_overloads, **func.user_templates}
|
|
1569
|
+
overloads: Dict[str, Function] = {**func.user_overloads, **func.user_templates}
|
|
1491
1570
|
for sig in sorted(overloads.keys()):
|
|
1492
1571
|
ovl = overloads[sig]
|
|
1493
1572
|
|
|
@@ -1518,7 +1597,7 @@ class ModuleHasher:
|
|
|
1518
1597
|
|
|
1519
1598
|
return h
|
|
1520
1599
|
|
|
1521
|
-
def hash_adjoint(self, adj):
|
|
1600
|
+
def hash_adjoint(self, adj: warp.codegen.Adjoint) -> bytes:
|
|
1522
1601
|
# NOTE: We don't cache adjoint hashes, because adjoints are always unique.
|
|
1523
1602
|
# Even instances of generic kernels and functions have unique adjoints with
|
|
1524
1603
|
# different argument types.
|
|
@@ -1567,7 +1646,7 @@ class ModuleHasher:
|
|
|
1567
1646
|
|
|
1568
1647
|
return ch.digest()
|
|
1569
1648
|
|
|
1570
|
-
def get_constant_bytes(self, value):
|
|
1649
|
+
def get_constant_bytes(self, value) -> bytes:
|
|
1571
1650
|
if isinstance(value, int):
|
|
1572
1651
|
# this also handles builtins.bool
|
|
1573
1652
|
return bytes(ctypes.c_int(value))
|
|
@@ -1585,7 +1664,7 @@ class ModuleHasher:
|
|
|
1585
1664
|
else:
|
|
1586
1665
|
raise TypeError(f"Invalid constant type: {type(value)}")
|
|
1587
1666
|
|
|
1588
|
-
def get_module_hash(self):
|
|
1667
|
+
def get_module_hash(self) -> bytes:
|
|
1589
1668
|
return self.module_hash
|
|
1590
1669
|
|
|
1591
1670
|
def get_unique_kernels(self):
|
|
@@ -1602,6 +1681,7 @@ class ModuleBuilder:
|
|
|
1602
1681
|
self.fatbins = {} # map from <some identifier> to fatbins, to add at link time
|
|
1603
1682
|
self.ltoirs = {} # map from lto symbol to lto binary
|
|
1604
1683
|
self.ltoirs_decl = {} # map from lto symbol to lto forward declaration
|
|
1684
|
+
self.shared_memory_bytes = {} # map from lto symbol to shared memory requirements
|
|
1605
1685
|
|
|
1606
1686
|
if hasher is None:
|
|
1607
1687
|
hasher = ModuleHasher(module)
|
|
@@ -1718,9 +1798,9 @@ class ModuleBuilder:
|
|
|
1718
1798
|
|
|
1719
1799
|
# add headers
|
|
1720
1800
|
if device == "cpu":
|
|
1721
|
-
source = warp.codegen.cpu_module_header.format(
|
|
1801
|
+
source = warp.codegen.cpu_module_header.format(block_dim=self.options["block_dim"]) + source
|
|
1722
1802
|
else:
|
|
1723
|
-
source = warp.codegen.cuda_module_header.format(
|
|
1803
|
+
source = warp.codegen.cuda_module_header.format(block_dim=self.options["block_dim"]) + source
|
|
1724
1804
|
|
|
1725
1805
|
return source
|
|
1726
1806
|
|
|
@@ -1757,7 +1837,7 @@ class ModuleExec:
|
|
|
1757
1837
|
runtime.llvm.unload_obj(self.handle.encode("utf-8"))
|
|
1758
1838
|
|
|
1759
1839
|
# lookup and cache kernel entry points
|
|
1760
|
-
def get_kernel_hooks(self, kernel):
|
|
1840
|
+
def get_kernel_hooks(self, kernel) -> KernelHooks:
|
|
1761
1841
|
# Use kernel.adj as a unique key for cache lookups instead of the kernel itself.
|
|
1762
1842
|
# This avoids holding a reference to the kernel and is faster than using
|
|
1763
1843
|
# a WeakKeyDictionary with kernels as keys.
|
|
@@ -1830,7 +1910,7 @@ class ModuleExec:
|
|
|
1830
1910
|
# creates a hash of the function to use for checking
|
|
1831
1911
|
# build cache
|
|
1832
1912
|
class Module:
|
|
1833
|
-
def __init__(self, name, loader):
|
|
1913
|
+
def __init__(self, name: Optional[str], loader=None):
|
|
1834
1914
|
self.name = name if name is not None else "None"
|
|
1835
1915
|
|
|
1836
1916
|
self.loader = loader
|
|
@@ -1870,7 +1950,7 @@ class Module:
|
|
|
1870
1950
|
"enable_backward": warp.config.enable_backward,
|
|
1871
1951
|
"fast_math": False,
|
|
1872
1952
|
"fuse_fp": True,
|
|
1873
|
-
"lineinfo":
|
|
1953
|
+
"lineinfo": warp.config.lineinfo,
|
|
1874
1954
|
"cuda_output": None, # supported values: "ptx", "cubin", or None (automatic)
|
|
1875
1955
|
"mode": warp.config.mode,
|
|
1876
1956
|
"block_dim": 256,
|
|
@@ -2073,7 +2153,11 @@ class Module:
|
|
|
2073
2153
|
use_ptx = True
|
|
2074
2154
|
|
|
2075
2155
|
if use_ptx:
|
|
2076
|
-
|
|
2156
|
+
# use the default PTX arch if the device supports it
|
|
2157
|
+
if warp.config.ptx_target_arch is not None:
|
|
2158
|
+
output_arch = min(device.arch, warp.config.ptx_target_arch)
|
|
2159
|
+
else:
|
|
2160
|
+
output_arch = min(device.arch, runtime.default_ptx_arch)
|
|
2077
2161
|
output_name = f"{module_name_short}.sm{output_arch}.ptx"
|
|
2078
2162
|
else:
|
|
2079
2163
|
output_arch = device.arch
|
|
@@ -2186,34 +2270,8 @@ class Module:
|
|
|
2186
2270
|
# -----------------------------------------------------------
|
|
2187
2271
|
# update cache
|
|
2188
2272
|
|
|
2189
|
-
def safe_rename(src, dst, attempts=5, delay=0.1):
|
|
2190
|
-
for i in range(attempts):
|
|
2191
|
-
try:
|
|
2192
|
-
os.rename(src, dst)
|
|
2193
|
-
return
|
|
2194
|
-
except FileExistsError:
|
|
2195
|
-
return
|
|
2196
|
-
except OSError as e:
|
|
2197
|
-
if e.errno == errno.ENOTEMPTY:
|
|
2198
|
-
# if directory exists we assume another process
|
|
2199
|
-
# got there first, in which case we will copy
|
|
2200
|
-
# our output to the directory manually in second step
|
|
2201
|
-
return
|
|
2202
|
-
else:
|
|
2203
|
-
# otherwise assume directory creation failed e.g.: access denied
|
|
2204
|
-
# on Windows we see occasional failures to rename directories due to
|
|
2205
|
-
# some process holding a lock on a file to be moved to workaround
|
|
2206
|
-
# this we make multiple attempts to rename with some delay
|
|
2207
|
-
if i < attempts - 1:
|
|
2208
|
-
time.sleep(delay)
|
|
2209
|
-
else:
|
|
2210
|
-
print(
|
|
2211
|
-
f"Could not update Warp cache with module binaries, trying to rename {build_dir} to {module_dir}, error {e}"
|
|
2212
|
-
)
|
|
2213
|
-
raise e
|
|
2214
|
-
|
|
2215
2273
|
# try to move process outputs to cache
|
|
2216
|
-
safe_rename(build_dir, module_dir)
|
|
2274
|
+
warp.build.safe_rename(build_dir, module_dir)
|
|
2217
2275
|
|
|
2218
2276
|
if os.path.exists(module_dir):
|
|
2219
2277
|
if not os.path.exists(binary_path):
|
|
@@ -2286,7 +2344,7 @@ class Module:
|
|
|
2286
2344
|
self.failed_builds = set()
|
|
2287
2345
|
|
|
2288
2346
|
# lookup kernel entry points based on name, called after compilation / module load
|
|
2289
|
-
def get_kernel_hooks(self, kernel, device):
|
|
2347
|
+
def get_kernel_hooks(self, kernel, device: Device) -> KernelHooks:
|
|
2290
2348
|
module_exec = self.execs.get((device.context, self.options["block_dim"]))
|
|
2291
2349
|
if module_exec is not None:
|
|
2292
2350
|
return module_exec.get_kernel_hooks(kernel)
|
|
@@ -2441,6 +2499,7 @@ class Event:
|
|
|
2441
2499
|
raise RuntimeError(f"Device {device} is not a CUDA device")
|
|
2442
2500
|
|
|
2443
2501
|
self.device = device
|
|
2502
|
+
self.enable_timing = enable_timing
|
|
2444
2503
|
|
|
2445
2504
|
if cuda_event is not None:
|
|
2446
2505
|
self.cuda_event = cuda_event
|
|
@@ -2490,6 +2549,17 @@ class Event:
|
|
|
2490
2549
|
else:
|
|
2491
2550
|
raise RuntimeError(f"Device {self.device} does not support IPC.")
|
|
2492
2551
|
|
|
2552
|
+
@property
|
|
2553
|
+
def is_complete(self) -> bool:
|
|
2554
|
+
"""A boolean indicating whether all work on the stream when the event was recorded has completed.
|
|
2555
|
+
|
|
2556
|
+
This property may not be accessed during a graph capture on any stream.
|
|
2557
|
+
"""
|
|
2558
|
+
|
|
2559
|
+
result_code = runtime.core.cuda_event_query(self.cuda_event)
|
|
2560
|
+
|
|
2561
|
+
return result_code == 0
|
|
2562
|
+
|
|
2493
2563
|
def __del__(self):
|
|
2494
2564
|
if not self.owner:
|
|
2495
2565
|
return
|
|
@@ -2504,7 +2574,7 @@ class Stream:
|
|
|
2504
2574
|
instance.owner = False
|
|
2505
2575
|
return instance
|
|
2506
2576
|
|
|
2507
|
-
def __init__(self, device:
|
|
2577
|
+
def __init__(self, device: Union["Device", str, None] = None, priority: int = 0, **kwargs):
|
|
2508
2578
|
"""Initialize the stream on a device with an optional specified priority.
|
|
2509
2579
|
|
|
2510
2580
|
Args:
|
|
@@ -2520,7 +2590,7 @@ class Stream:
|
|
|
2520
2590
|
Raises:
|
|
2521
2591
|
RuntimeError: If function is called before Warp has completed
|
|
2522
2592
|
initialization with a ``device`` that is not an instance of
|
|
2523
|
-
:class:`Device
|
|
2593
|
+
:class:`Device <warp.context.Device>`.
|
|
2524
2594
|
RuntimeError: ``device`` is not a CUDA Device.
|
|
2525
2595
|
RuntimeError: The stream could not be created on the device.
|
|
2526
2596
|
TypeError: The requested stream priority is not an integer.
|
|
@@ -2588,7 +2658,7 @@ class Stream:
|
|
|
2588
2658
|
f"Event from device {event.device} cannot be recorded on stream from device {self.device}"
|
|
2589
2659
|
)
|
|
2590
2660
|
|
|
2591
|
-
runtime.core.cuda_event_record(event.cuda_event, self.cuda_stream)
|
|
2661
|
+
runtime.core.cuda_event_record(event.cuda_event, self.cuda_stream, event.enable_timing)
|
|
2592
2662
|
|
|
2593
2663
|
return event
|
|
2594
2664
|
|
|
@@ -2622,6 +2692,17 @@ class Stream:
|
|
|
2622
2692
|
|
|
2623
2693
|
runtime.core.cuda_stream_wait_stream(self.cuda_stream, other_stream.cuda_stream, event.cuda_event)
|
|
2624
2694
|
|
|
2695
|
+
@property
|
|
2696
|
+
def is_complete(self) -> bool:
|
|
2697
|
+
"""A boolean indicating whether all work on the stream has completed.
|
|
2698
|
+
|
|
2699
|
+
This property may not be accessed during a graph capture on any stream.
|
|
2700
|
+
"""
|
|
2701
|
+
|
|
2702
|
+
result_code = runtime.core.cuda_stream_query(self.cuda_stream)
|
|
2703
|
+
|
|
2704
|
+
return result_code == 0
|
|
2705
|
+
|
|
2625
2706
|
@property
|
|
2626
2707
|
def is_capturing(self) -> bool:
|
|
2627
2708
|
"""A boolean indicating whether a graph capture is currently ongoing on this stream."""
|
|
@@ -2944,18 +3025,14 @@ Devicelike = Union[Device, str, None]
|
|
|
2944
3025
|
|
|
2945
3026
|
|
|
2946
3027
|
class Graph:
|
|
2947
|
-
def __new__(cls, *args, **kwargs):
|
|
2948
|
-
instance = super(Graph, cls).__new__(cls)
|
|
2949
|
-
instance.graph_exec = None
|
|
2950
|
-
return instance
|
|
2951
|
-
|
|
2952
3028
|
def __init__(self, device: Device, capture_id: int):
|
|
2953
3029
|
self.device = device
|
|
2954
3030
|
self.capture_id = capture_id
|
|
2955
|
-
self.module_execs = set()
|
|
3031
|
+
self.module_execs: Set[ModuleExec] = set()
|
|
3032
|
+
self.graph_exec: Optional[ctypes.c_void_p] = None
|
|
2956
3033
|
|
|
2957
3034
|
def __del__(self):
|
|
2958
|
-
if not self.graph_exec:
|
|
3035
|
+
if not hasattr(self, "graph_exec") or not hasattr(self, "device") or not self.graph_exec:
|
|
2959
3036
|
return
|
|
2960
3037
|
|
|
2961
3038
|
# use CUDA context guard to avoid side effects during garbage collection
|
|
@@ -3197,6 +3274,43 @@ class Runtime:
|
|
|
3197
3274
|
self.core.radix_sort_pairs_float_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
|
|
3198
3275
|
self.core.radix_sort_pairs_float_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
|
|
3199
3276
|
|
|
3277
|
+
self.core.radix_sort_pairs_int64_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
|
|
3278
|
+
self.core.radix_sort_pairs_int64_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
|
|
3279
|
+
|
|
3280
|
+
self.core.segmented_sort_pairs_int_host.argtypes = [
|
|
3281
|
+
ctypes.c_uint64,
|
|
3282
|
+
ctypes.c_uint64,
|
|
3283
|
+
ctypes.c_int,
|
|
3284
|
+
ctypes.c_uint64,
|
|
3285
|
+
ctypes.c_uint64,
|
|
3286
|
+
ctypes.c_int,
|
|
3287
|
+
]
|
|
3288
|
+
self.core.segmented_sort_pairs_int_device.argtypes = [
|
|
3289
|
+
ctypes.c_uint64,
|
|
3290
|
+
ctypes.c_uint64,
|
|
3291
|
+
ctypes.c_int,
|
|
3292
|
+
ctypes.c_uint64,
|
|
3293
|
+
ctypes.c_uint64,
|
|
3294
|
+
ctypes.c_int,
|
|
3295
|
+
]
|
|
3296
|
+
|
|
3297
|
+
self.core.segmented_sort_pairs_float_host.argtypes = [
|
|
3298
|
+
ctypes.c_uint64,
|
|
3299
|
+
ctypes.c_uint64,
|
|
3300
|
+
ctypes.c_int,
|
|
3301
|
+
ctypes.c_uint64,
|
|
3302
|
+
ctypes.c_uint64,
|
|
3303
|
+
ctypes.c_int,
|
|
3304
|
+
]
|
|
3305
|
+
self.core.segmented_sort_pairs_float_device.argtypes = [
|
|
3306
|
+
ctypes.c_uint64,
|
|
3307
|
+
ctypes.c_uint64,
|
|
3308
|
+
ctypes.c_int,
|
|
3309
|
+
ctypes.c_uint64,
|
|
3310
|
+
ctypes.c_uint64,
|
|
3311
|
+
ctypes.c_int,
|
|
3312
|
+
]
|
|
3313
|
+
|
|
3200
3314
|
self.core.runlength_encode_int_host.argtypes = [
|
|
3201
3315
|
ctypes.c_uint64,
|
|
3202
3316
|
ctypes.c_uint64,
|
|
@@ -3277,26 +3391,6 @@ class Runtime:
|
|
|
3277
3391
|
self.core.hash_grid_update_device.argtypes = [ctypes.c_uint64, ctypes.c_float, ctypes.c_void_p]
|
|
3278
3392
|
self.core.hash_grid_reserve_device.argtypes = [ctypes.c_uint64, ctypes.c_int]
|
|
3279
3393
|
|
|
3280
|
-
self.core.cutlass_gemm.argtypes = [
|
|
3281
|
-
ctypes.c_void_p,
|
|
3282
|
-
ctypes.c_int,
|
|
3283
|
-
ctypes.c_int,
|
|
3284
|
-
ctypes.c_int,
|
|
3285
|
-
ctypes.c_int,
|
|
3286
|
-
ctypes.c_char_p,
|
|
3287
|
-
ctypes.c_void_p,
|
|
3288
|
-
ctypes.c_void_p,
|
|
3289
|
-
ctypes.c_void_p,
|
|
3290
|
-
ctypes.c_void_p,
|
|
3291
|
-
ctypes.c_float,
|
|
3292
|
-
ctypes.c_float,
|
|
3293
|
-
ctypes.c_bool,
|
|
3294
|
-
ctypes.c_bool,
|
|
3295
|
-
ctypes.c_bool,
|
|
3296
|
-
ctypes.c_int,
|
|
3297
|
-
]
|
|
3298
|
-
self.core.cutlass_gemm.restype = ctypes.c_bool
|
|
3299
|
-
|
|
3300
3394
|
self.core.volume_create_host.argtypes = [ctypes.c_void_p, ctypes.c_uint64, ctypes.c_bool, ctypes.c_bool]
|
|
3301
3395
|
self.core.volume_create_host.restype = ctypes.c_uint64
|
|
3302
3396
|
self.core.volume_get_tiles_host.argtypes = [
|
|
@@ -3327,36 +3421,18 @@ class Runtime:
         ]
         self.core.volume_destroy_device.argtypes = [ctypes.c_uint64]

-        self.core.volume_f_from_tiles_device.argtypes = [
+        self.core.volume_from_tiles_device.argtypes = [
             ctypes.c_void_p,
             ctypes.c_void_p,
             ctypes.c_int,
             ctypes.c_float * 9,
             ctypes.c_float * 3,
             ctypes.c_bool,
-            ctypes.c_float,
-        ]
-        self.core.volume_f_from_tiles_device.restype = ctypes.c_uint64
-        self.core.volume_v_from_tiles_device.argtypes = [
             ctypes.c_void_p,
-            ctypes.c_void_p,
-            ctypes.c_int,
-            ctypes.c_float * 9,
-            ctypes.c_float * 3,
-            ctypes.c_bool,
-            ctypes.c_float * 3,
-        ]
-        self.core.volume_v_from_tiles_device.restype = ctypes.c_uint64
-        self.core.volume_i_from_tiles_device.argtypes = [
-            ctypes.c_void_p,
-            ctypes.c_void_p,
-            ctypes.c_int,
-            ctypes.c_float * 9,
-            ctypes.c_float * 3,
-            ctypes.c_bool,
-            ctypes.c_int,
+            ctypes.c_uint32,
+            ctypes.c_char_p,
         ]
-        self.core.volume_i_from_tiles_device.restype = ctypes.c_uint64
+        self.core.volume_from_tiles_device.restype = ctypes.c_uint64
         self.core.volume_index_from_tiles_device.argtypes = [
             ctypes.c_void_p,
             ctypes.c_void_p,
@@ -3425,6 +3501,7 @@ class Runtime:
             ctypes.POINTER(ctypes.c_int),  # tpl_cols
             ctypes.c_void_p,  # tpl_values
             ctypes.c_bool,  # prune_numerical_zeros
+            ctypes.c_bool,  # masked
             ctypes.POINTER(ctypes.c_int),  # bsr_offsets
             ctypes.POINTER(ctypes.c_int),  # bsr_columns
             ctypes.c_void_p,  # bsr_values
@@ -3459,8 +3536,6 @@ class Runtime:
         self.core.is_cuda_enabled.restype = ctypes.c_int
         self.core.is_cuda_compatibility_enabled.argtypes = None
         self.core.is_cuda_compatibility_enabled.restype = ctypes.c_int
-        self.core.is_cutlass_enabled.argtypes = None
-        self.core.is_cutlass_enabled.restype = ctypes.c_int
         self.core.is_mathdx_enabled.argtypes = None
         self.core.is_mathdx_enabled.restype = ctypes.c_int

@@ -3494,6 +3569,10 @@ class Runtime:
         self.core.cuda_device_set_mempool_release_threshold.restype = ctypes.c_int
         self.core.cuda_device_get_mempool_release_threshold.argtypes = [ctypes.c_int]
         self.core.cuda_device_get_mempool_release_threshold.restype = ctypes.c_uint64
+        self.core.cuda_device_get_mempool_used_mem_current.argtypes = [ctypes.c_int]
+        self.core.cuda_device_get_mempool_used_mem_current.restype = ctypes.c_uint64
+        self.core.cuda_device_get_mempool_used_mem_high.argtypes = [ctypes.c_int]
+        self.core.cuda_device_get_mempool_used_mem_high.restype = ctypes.c_uint64
         self.core.cuda_device_get_memory_info.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p]
         self.core.cuda_device_get_memory_info.restype = None
         self.core.cuda_device_get_uuid.argtypes = [ctypes.c_int, ctypes.c_char * 16]
@@ -3563,6 +3642,8 @@ class Runtime:
         self.core.cuda_stream_create.restype = ctypes.c_void_p
         self.core.cuda_stream_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
         self.core.cuda_stream_destroy.restype = None
+        self.core.cuda_stream_query.argtypes = [ctypes.c_void_p]
+        self.core.cuda_stream_query.restype = ctypes.c_int
         self.core.cuda_stream_register.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
         self.core.cuda_stream_register.restype = None
         self.core.cuda_stream_unregister.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
@@ -3584,7 +3665,9 @@ class Runtime:
         self.core.cuda_event_create.restype = ctypes.c_void_p
         self.core.cuda_event_destroy.argtypes = [ctypes.c_void_p]
         self.core.cuda_event_destroy.restype = None
-        self.core.cuda_event_record.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+        self.core.cuda_event_query.argtypes = [ctypes.c_void_p]
+        self.core.cuda_event_query.restype = ctypes.c_int
+        self.core.cuda_event_record.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_bool]
         self.core.cuda_event_record.restype = None
         self.core.cuda_event_synchronize.argtypes = [ctypes.c_void_p]
         self.core.cuda_event_synchronize.restype = None
@@ -3833,9 +3916,20 @@ class Runtime:
                 cuda_device_count = len(self.cuda_devices)
             else:
                 self.set_default_device("cuda:0")
+
+            # the minimum PTX architecture that supports all of Warp's features
+            self.default_ptx_arch = 75
+
+            # Update the default PTX architecture based on devices present in the system.
+            # Use the lowest architecture among devices that meet the minimum architecture requirement.
+            # Devices below the required minimum will use the highest architecture they support.
+            eligible_archs = [d.arch for d in self.cuda_devices if d.arch >= self.default_ptx_arch]
+            if eligible_archs:
+                self.default_ptx_arch = min(eligible_archs)
         else:
             # CUDA not available
             self.set_default_device("cpu")
+            self.default_ptx_arch = None

         # initialize kernel cache
         warp.build.init_kernel_cache(warp.config.kernel_cache_dir)
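The selection above picks the lowest architecture among the installed GPUs that still meets Warp's minimum (SM 7.5), so one PTX target works on every eligible device. A standalone sketch of that rule; pick_default_ptx_arch is illustrative only, not a Warp API:

    def pick_default_ptx_arch(device_archs, minimum=75):
        # architectures that satisfy the minimum requirement
        eligible = [arch for arch in device_archs if arch >= minimum]
        # lowest eligible architecture, otherwise keep the minimum;
        # devices below the minimum compile to the best arch they support
        return min(eligible) if eligible else minimum

    print(pick_default_ptx_arch([86, 75, 120]))  # 75
    print(pick_default_ptx_arch([61]))           # 75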
@@ -3848,6 +3942,11 @@ class Runtime:
         greeting = []

         greeting.append(f"Warp {warp.config.version} initialized:")
+
+        # Add git commit hash to greeting if available
+        if warp.config._git_commit_hash is not None:
+            greeting.append(f"   Git commit: {warp.config._git_commit_hash}")
+
         if cuda_device_count > 0:
             # print CUDA version info
             greeting.append(
@@ -4200,7 +4299,7 @@ def set_device(ident: Devicelike) -> None:
     device.make_current()


-def map_cuda_device(alias: str, context: ctypes.c_void_p = None) -> Device:
+def map_cuda_device(alias: str, context: Optional[ctypes.c_void_p] = None) -> Device:
     """Assign a device alias to a CUDA context.

     This function can be used to create a wp.Device for an external CUDA context.
@@ -4228,7 +4327,13 @@ def unmap_cuda_device(alias: str) -> None:


 def is_mempool_supported(device: Devicelike) -> bool:
-    """Check if CUDA memory pool allocators are available on the device."""
+    """Check if CUDA memory pool allocators are available on the device.
+
+    Parameters:
+        device: The :class:`Device <warp.context.Device>` or device identifier
+          for which the query is to be performed.
+          If ``None``, the default device will be used.
+    """

     init()

@@ -4238,7 +4343,13 @@ def is_mempool_supported(device: Devicelike) -> bool:


 def is_mempool_enabled(device: Devicelike) -> bool:
-    """Check if CUDA memory pool allocators are enabled on the device."""
+    """Check if CUDA memory pool allocators are enabled on the device.
+
+    Parameters:
+        device: The :class:`Device <warp.context.Device>` or device identifier
+          for which the query is to be performed.
+          If ``None``, the default device will be used.
+    """

     init()

@@ -4258,6 +4369,11 @@ def set_mempool_enabled(device: Devicelike, enable: bool) -> None:
     to Warp. The preferred solution is to enable memory pool access using :func:`set_mempool_access_enabled`.
     If peer access is not supported, then the default CUDA allocators must be used to pre-allocate the memory
     prior to graph capture.
+
+    Parameters:
+        device: The :class:`Device <warp.context.Device>` or device identifier
+          for which the operation is to be performed.
+          If ``None``, the default device will be used.
     """

     init()
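These queries pair with the enable/disable setter documented above; a minimal sketch, assuming a CUDA-capable device is present:

    import warp as wp

    wp.init()

    device = "cuda:0"
    if wp.is_mempool_supported(device):
        # opt in to the CUDA memory pool allocator for this device
        wp.set_mempool_enabled(device, True)
        print(wp.is_mempool_enabled(device))  # True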
@@ -4288,6 +4404,18 @@ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, floa
     Values between 0 and 1 are interpreted as fractions of available memory. For example, 0.5 means
     half of the device's physical memory. Greater values are interpreted as an absolute number of bytes.
     For example, 1024**3 means one GiB of memory.
+
+    Parameters:
+        device: The :class:`Device <warp.context.Device>` or device identifier
+          for which the operation is to be performed.
+          If ``None``, the default device will be used.
+        threshold: An integer representing a number of bytes, or a ``float`` between 0 and 1,
+          specifying the desired release threshold.
+
+    Raises:
+        ValueError: If ``device`` is not a CUDA device.
+        RuntimeError: If ``device`` is a CUDA device, but does not support memory pools.
+        RuntimeError: Failed to set the memory pool release threshold.
     """

     init()
@@ -4309,8 +4437,21 @@ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, floa
         raise RuntimeError(f"Failed to set memory pool release threshold for device {device}")


-def get_mempool_release_threshold(device: Devicelike) -> int:
-    """Get the CUDA memory pool release threshold on the device."""
+def get_mempool_release_threshold(device: Devicelike = None) -> int:
+    """Get the CUDA memory pool release threshold on the device.
+
+    Parameters:
+        device: The :class:`Device <warp.context.Device>` or device identifier
+          for which the query is to be performed.
+          If ``None``, the default device will be used.
+
+    Returns:
+        The memory pool release threshold in bytes.
+
+    Raises:
+        ValueError: If ``device`` is not a CUDA device.
+        RuntimeError: If ``device`` is a CUDA device, but does not support memory pools.
+    """

     init()

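A short usage sketch of the threshold API documented above, assuming "cuda:0" supports memory pools:

    import warp as wp

    wp.init()

    # keep up to half of the device's physical memory cached in the pool
    wp.set_mempool_release_threshold("cuda:0", 0.5)

    # the getter always reports the threshold as an absolute byte count
    print(wp.get_mempool_release_threshold("cuda:0"))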
@@ -4325,6 +4466,64 @@ def get_mempool_release_threshold(device: Devicelike) -> int:
     return runtime.core.cuda_device_get_mempool_release_threshold(device.ordinal)


+def get_mempool_used_mem_current(device: Devicelike = None) -> int:
+    """Get the amount of memory from the device's memory pool that is currently in use by the application.
+
+    Parameters:
+        device: The :class:`Device <warp.context.Device>` or device identifier
+          for which the query is to be performed.
+          If ``None``, the default device will be used.
+
+    Returns:
+        The amount of memory used in bytes.
+
+    Raises:
+        ValueError: If ``device`` is not a CUDA device.
+        RuntimeError: If ``device`` is a CUDA device, but does not support memory pools.
+    """
+
+    init()
+
+    device = runtime.get_device(device)
+
+    if not device.is_cuda:
+        raise ValueError("Memory pools are only supported on CUDA devices")
+
+    if not device.is_mempool_supported:
+        raise RuntimeError(f"Device {device} does not support memory pools")
+
+    return runtime.core.cuda_device_get_mempool_used_mem_current(device.ordinal)
+
+
+def get_mempool_used_mem_high(device: Devicelike = None) -> int:
+    """Get the application's memory usage high-water mark from the device's CUDA memory pool.
+
+    Parameters:
+        device: The :class:`Device <warp.context.Device>` or device identifier
+          for which the query is to be performed.
+          If ``None``, the default device will be used.
+
+    Returns:
+        The high-water mark of memory used from the memory pool in bytes.
+
+    Raises:
+        ValueError: If ``device`` is not a CUDA device.
+        RuntimeError: If ``device`` is a CUDA device, but does not support memory pools.
+    """
+
+    init()
+
+    device = runtime.get_device(device)
+
+    if not device.is_cuda:
+        raise ValueError("Memory pools are only supported on CUDA devices")
+
+    if not device.is_mempool_supported:
+        raise RuntimeError(f"Device {device} does not support memory pools")
+
+    return runtime.core.cuda_device_get_mempool_used_mem_high(device.ordinal)
+
+
 def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike) -> bool:
     """Check if `peer_device` can directly access the memory of `target_device` on this system.

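The two new queries are useful for tracking pool growth around allocations; a minimal sketch, assuming memory pool allocators are enabled on "cuda:0" (the default when supported):

    import warp as wp

    wp.init()

    a = wp.zeros(1_000_000, dtype=wp.float32, device="cuda:0")

    print(wp.get_mempool_used_mem_current("cuda:0"))  # bytes currently in use by the app
    print(wp.get_mempool_used_mem_high("cuda:0"))     # high-water mark in bytes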
@@ -4527,7 +4726,7 @@ def wait_event(event: Event):
     get_stream().wait_event(event)


-def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: Optional[bool] = True):
+def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: bool = True):
     """Get the elapsed time between two recorded events.

     Both events must have been previously recorded with
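For reference, a minimal timing sketch using this function, assuming a CUDA device named "cuda:0" is available:

    import warp as wp

    wp.init()

    @wp.kernel
    def inc(x: wp.array(dtype=float)):
        i = wp.tid()
        x[i] = x[i] + 1.0

    with wp.ScopedDevice("cuda:0"):
        x = wp.zeros(1_000_000, dtype=float)
        start = wp.Event(enable_timing=True)
        stop = wp.Event(enable_timing=True)

        wp.record_event(start)
        wp.launch(inc, dim=x.shape[0], inputs=[x])
        wp.record_event(stop)

        # with synchronize=True (the default) this waits for `stop` before reading the timer
        print(wp.get_event_elapsed_time(start, stop))  # milliseconds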
@@ -4552,7 +4751,7 @@ def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: Op
     return runtime.core.cuda_event_elapsed_time(start_event.cuda_event, end_event.cuda_event)


-def wait_stream(other_stream: Stream, event: Event = None):
+def wait_stream(other_stream: Stream, event: Optional[Event] = None):
     """Convenience function for calling :meth:`Stream.wait_stream` on the current stream.

     Args:
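A short cross-stream synchronization sketch using this helper, assuming a CUDA device is present:

    import warp as wp

    wp.init()

    @wp.kernel
    def fill(x: wp.array(dtype=float), value: float):
        i = wp.tid()
        x[i] = value

    x = wp.zeros(1024, dtype=float, device="cuda:0")
    other = wp.Stream("cuda:0")

    # queue work on a secondary stream, then make the current stream wait for it
    with wp.ScopedStream(other):
        wp.launch(fill, dim=1024, inputs=[x, 3.0])
    wp.wait_stream(other)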
@@ -4719,7 +4918,7 @@ class RegisteredGLBuffer:


 def zeros(
-    shape: Tuple = None,
+    shape: Union[int, Tuple[int, ...], List[int], None] = None,
     dtype=float,
     device: Devicelike = None,
     requires_grad: bool = False,
@@ -4747,7 +4946,7 @@ def zeros(


 def zeros_like(
-    src: Array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
+    src: Array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
 ) -> warp.array:
     """Return a zero-initialized array with the same type and dimension of another array

@@ -4769,7 +4968,7 @@ def zeros_like(


 def ones(
-    shape: Tuple = None,
+    shape: Union[int, Tuple[int, ...], List[int], None] = None,
     dtype=float,
     device: Devicelike = None,
     requires_grad: bool = False,
@@ -4793,7 +4992,7 @@ def ones(


 def ones_like(
-    src: Array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
+    src: Array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
 ) -> warp.array:
     """Return a one-initialized array with the same type and dimension of another array

@@ -4811,7 +5010,7 @@ def ones_like(


 def full(
-    shape: Tuple = None,
+    shape: Union[int, Tuple[int, ...], List[int], None] = None,
     value=0,
     dtype=Any,
     device: Devicelike = None,
@@ -4877,7 +5076,11 @@ def full(


 def full_like(
-    src: Array, value: Any, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
+    src: Array,
+    value: Any,
+    device: Devicelike = None,
+    requires_grad: Optional[bool] = None,
+    pinned: Optional[bool] = None,
 ) -> warp.array:
     """Return an array with all elements initialized to the given value with the same type and dimension of another array

@@ -4899,7 +5102,9 @@ def full_like(
     return arr


-def clone(src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None) -> warp.array:
+def clone(
+    src: warp.array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
+) -> warp.array:
     """Clone an existing array, allocates a copy of the src memory

     Args:
@@ -4920,7 +5125,7 @@ def clone(src: warp.array, device: Devicelike = None, requires_grad: bool = None


 def empty(
-    shape: Tuple = None,
+    shape: Union[int, Tuple[int, ...], List[int], None] = None,
     dtype=float,
     device: Devicelike = None,
     requires_grad: bool = False,
@@ -4953,7 +5158,7 @@ def empty(


 def empty_like(
-    src: Array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
+    src: Array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
 ) -> warp.array:
     """Return an uninitialized array with the same type and dimension of another array

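The widened shape annotation matches what these constructors already accept at runtime; a quick sketch of the accepted forms:

    import warp as wp

    wp.init()

    a = wp.zeros(1024, dtype=wp.vec3)           # int shape
    b = wp.ones((16, 16), dtype=float)          # tuple shape
    c = wp.full([2, 3], 7.0, dtype=wp.float64)  # list shape, fill value 7.0
    d = wp.empty_like(b)                        # same shape/dtype/device as `b`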
@@ -5185,8 +5390,6 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
         ) from e


-# represents all data required for a kernel launch
-# so that launches can be replayed quickly, use `wp.launch(..., record_cmd=True)`
 class Launch:
     """Represents all data required for a kernel launch so that launches can be replayed quickly.

@@ -5457,7 +5660,7 @@ def launch(
         max_blocks: The maximum number of CUDA thread blocks to use.
             Only has an effect for CUDA kernel launches.
            If negative or zero, the maximum hardware value will be used.
-        block_dim: The number of threads per block.
+        block_dim: The number of threads per block (always 1 for "cpu" devices).
     """

     init()
@@ -5468,6 +5671,9 @@
     else:
         device = runtime.get_device(device)

+    if device == "cpu":
+        block_dim = 1
+
     # check function is a Kernel
     if not isinstance(kernel, Kernel):
         raise RuntimeError("Error launching kernel, can only launch functions decorated with @wp.kernel.")
@@ -5700,6 +5906,18 @@ def launch_tiled(*args, **kwargs):
             "Launch block dimension 'block_dim' argument should be passed via. keyword args for wp.launch_tiled()"
         )

+    if "device" in kwargs:
+        device = kwargs["device"]
+    else:
+        # todo: this doesn't consider the case where device
+        # is passed through positional args
+        device = None
+
+    # force the block_dim to 1 if running on "cpu"
+    device = runtime.get_device(device)
+    if device.is_cpu:
+        kwargs["block_dim"] = 1
+
     dim = kwargs["dim"]
     if not isinstance(dim, list):
         dim = list(dim) if isinstance(dim, tuple) else [dim]
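Both wp.launch and wp.launch_tiled now pin block_dim to 1 on CPU devices; on CUDA it still selects the number of threads per block. A minimal sketch, assuming a CUDA device named "cuda:0":

    import warp as wp

    wp.init()

    @wp.kernel
    def scale(x: wp.array(dtype=float), y: wp.array(dtype=float), s: float):
        i = wp.tid()
        y[i] = s * x[i]

    n = 1024
    x = wp.ones(n, dtype=float, device="cuda:0")
    y = wp.zeros(n, dtype=float, device="cuda:0")

    # 128 threads per CUDA block; the same call with device="cpu" would force block_dim=1
    wp.launch(scale, dim=n, inputs=[x, y, 2.0], device="cuda:0", block_dim=128)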
@@ -5868,6 +6086,7 @@ def set_module_options(options: Dict[str, Any], module: Optional[Any] = None):

     * **mode**: The compilation mode to use, can be "debug", or "release", defaults to the value of ``warp.config.mode``.
     * **max_unroll**: The maximum fixed-size loop to unroll, defaults to the value of ``warp.config.max_unroll``.
+    * **block_dim**: The default number of threads to assign to each block
 
     Args:

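A short sketch of setting the new per-module option from the calling module:

    import warp as wp

    # compile kernels in this module with 256 threads per block by default
    wp.set_module_options({"block_dim": 256})

    print(wp.get_module_options()["block_dim"])  # 256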
@@ -5893,7 +6112,12 @@ def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
     return get_module(m.__name__).options


-def capture_begin(device: Devicelike = None, stream=None, force_module_load=None, external=False):
+def capture_begin(
+    device: Devicelike = None,
+    stream: Optional[Stream] = None,
+    force_module_load: Optional[bool] = None,
+    external: bool = False,
+):
     """Begin capture of a CUDA graph

     Captures all subsequent kernel launches and memory operations on CUDA devices.
@@ -5960,16 +6184,15 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=None
     runtime.captures[capture_id] = graph


-def capture_end(device: Devicelike = None, stream: Stream = None) -> Graph:
-    """
+def capture_end(device: Devicelike = None, stream: Optional[Stream] = None) -> Graph:
+    """End the capture of a CUDA graph.

     Args:
-
         device: The CUDA device where capture began
         stream: The CUDA stream where capture began

     Returns:
-        A Graph object that can be launched with :func:`~warp.capture_launch()`
+        A :class:`Graph` object that can be launched with :func:`~warp.capture_launch()`
     """

     if stream is not None:
@@ -6003,12 +6226,12 @@ def capture_end(device: Devicelike = None, stream: Stream = None) -> Graph:
     return graph


-def capture_launch(graph: Graph, stream: Stream = None):
+def capture_launch(graph: Graph, stream: Optional[Stream] = None):
     """Launch a previously captured CUDA graph

     Args:
-        graph: A Graph as returned by :func:`~warp.capture_end()`
-        stream: A Stream to launch the graph on
+        graph: A :class:`Graph` as returned by :func:`~warp.capture_end()`
+        stream: A :class:`Stream` to launch the graph on
     """

     if stream is not None:
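Putting the three capture functions together; a minimal sketch, assuming a CUDA device and using the try/finally pattern so a failed capture is still ended:

    import warp as wp

    wp.init()

    @wp.kernel
    def inc(x: wp.array(dtype=float)):
        i = wp.tid()
        x[i] = x[i] + 1.0

    device = "cuda:0"
    x = wp.zeros(1024, dtype=float, device=device)

    wp.capture_begin(device)
    try:
        wp.launch(inc, dim=1024, inputs=[x], device=device)
    finally:
        graph = wp.capture_end(device)

    # replay the captured work cheaply
    for _ in range(10):
        wp.capture_launch(graph)

    wp.synchronize_device(device)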
@@ -6024,24 +6247,28 @@ def capture_launch(graph: Graph, stream: Stream = None):


 def copy(
-    dest: warp.array, src: warp.array, dest_offset: int = 0, src_offset: int = 0, count: int = 0, stream: Stream = None
+    dest: warp.array,
+    src: warp.array,
+    dest_offset: int = 0,
+    src_offset: int = 0,
+    count: int = 0,
+    stream: Optional[Stream] = None,
 ):
     """Copy array contents from `src` to `dest`.

     Args:
-        dest: Destination array, must be at least as big as source buffer
+        dest: Destination array, must be at least as large as source buffer
         src: Source array
         dest_offset: Element offset in the destination array
         src_offset: Element offset in the source array
         count: Number of array elements to copy (will copy all elements if set to 0)
-        stream: The stream on which to perform the copy (optional)
+        stream: The stream on which to perform the copy

         The stream, if specified, can be from any device. If the stream is omitted, then Warp selects a stream based on the following rules:
         (1) If the destination array is on a CUDA device, use the current stream on the destination device.
         (2) Otherwise, if the source array is on a CUDA device, use the current stream on the source device.

         If neither source nor destination are on a CUDA device, no stream is used for the copy.
-
     """

     from warp.context import runtime
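A small sketch of the keyword-argument form documented above, copying a slice between devices:

    import warp as wp

    wp.init()

    src = wp.full(8, 3.0, dtype=float, device="cuda:0")
    dst = wp.zeros(16, dtype=float, device="cpu")

    # copy src[2:6] into dst[10:14]; with no `stream` given, Warp picks the
    # current stream of the destination (or source) CUDA device
    wp.copy(dst, src, dest_offset=10, src_offset=2, count=4)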
@@ -6266,8 +6493,8 @@ def type_str(t):
             return f"Transformation[{type_str(t._wp_scalar_type_)}]"

         raise TypeError("Invalid vector or matrix dimensions")
-    elif
-        args_repr = ", ".join(type_str(x) for x in
+    elif get_origin(t) in (list, tuple):
+        args_repr = ", ".join(type_str(x) for x in get_args(t))
         return f"{t._name}[{args_repr}]"
     elif t is Ellipsis:
         return "..."
@@ -6423,6 +6650,26 @@ def export_functions_rst(file): # pragma: no cover
 def export_stubs(file):  # pragma: no cover
     """Generates stub file for auto-complete of builtin functions"""

+    # Add copyright notice
+    print(
+        """# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""",
+        file=file,
+    )
+
     print(
         "# Autogenerated file, do not edit, this file provides stubs for builtins autocomplete in VSCode, PyCharm, etc",
         file=file,