warp-lang 1.6.1__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of warp-lang might be problematic.
- warp/__init__.py +21 -7
- warp/autograd.py +14 -6
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +424 -6
- warp/build_dll.py +20 -20
- warp/builtins.py +467 -368
- warp/codegen.py +193 -125
- warp/config.py +56 -12
- warp/constants.py +14 -6
- warp/context.py +524 -277
- warp/dlpack.py +22 -12
- warp/examples/__init__.py +14 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_api.py +14 -6
- warp/examples/benchmarks/benchmark_cloth.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
- warp/examples/benchmarks/benchmark_gemm.py +82 -48
- warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
- warp/examples/benchmarks/benchmark_launches.py +14 -6
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +14 -6
- warp/examples/core/example_cupy.py +14 -6
- warp/examples/core/example_dem.py +14 -6
- warp/examples/core/example_fluid.py +14 -6
- warp/examples/core/example_graph_capture.py +14 -6
- warp/examples/core/example_marching_cubes.py +14 -6
- warp/examples/core/example_mesh.py +14 -6
- warp/examples/core/example_mesh_intersect.py +14 -6
- warp/examples/core/example_nvdb.py +14 -6
- warp/examples/core/example_raycast.py +14 -6
- warp/examples/core/example_raymarch.py +14 -6
- warp/examples/core/example_render_opengl.py +14 -6
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +14 -6
- warp/examples/core/example_torch.py +14 -6
- warp/examples/core/example_wave.py +14 -6
- warp/examples/fem/example_adaptive_grid.py +14 -6
- warp/examples/fem/example_apic_fluid.py +15 -7
- warp/examples/fem/example_burgers.py +16 -8
- warp/examples/fem/example_convection_diffusion.py +14 -6
- warp/examples/fem/example_convection_diffusion_dg.py +14 -6
- warp/examples/fem/example_deformed_geometry.py +15 -7
- warp/examples/fem/example_diffusion.py +14 -6
- warp/examples/fem/example_diffusion_3d.py +14 -6
- warp/examples/fem/example_diffusion_mgpu.py +14 -6
- warp/examples/fem/example_distortion_energy.py +15 -7
- warp/examples/fem/example_magnetostatics.py +20 -12
- warp/examples/fem/example_mixed_elasticity.py +14 -6
- warp/examples/fem/example_navier_stokes.py +14 -6
- warp/examples/fem/example_nonconforming_contact.py +14 -6
- warp/examples/fem/example_stokes.py +14 -6
- warp/examples/fem/example_stokes_transfer.py +14 -6
- warp/examples/fem/example_streamlines.py +14 -6
- warp/examples/fem/utils.py +24 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +14 -6
- warp/examples/optim/example_cloth_throw.py +14 -6
- warp/examples/optim/example_diffray.py +14 -6
- warp/examples/optim/example_drone.py +14 -6
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +14 -6
- warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
- warp/examples/optim/example_softbody_properties.py +14 -6
- warp/examples/optim/example_spring_cage.py +14 -6
- warp/examples/optim/example_trajectory.py +14 -6
- warp/examples/sim/example_cartpole.py +14 -6
- warp/examples/sim/example_cloth.py +14 -6
- warp/examples/sim/example_cloth_self_contact.py +14 -6
- warp/examples/sim/example_granular.py +14 -6
- warp/examples/sim/example_granular_collision_sdf.py +14 -6
- warp/examples/sim/example_jacobian_ik.py +14 -6
- warp/examples/sim/example_particle_chain.py +14 -6
- warp/examples/sim/example_quadruped.py +14 -6
- warp/examples/sim/example_rigid_chain.py +14 -6
- warp/examples/sim/example_rigid_contact.py +14 -6
- warp/examples/sim/example_rigid_force.py +14 -6
- warp/examples/sim/example_rigid_gyroscopic.py +14 -6
- warp/examples/sim/example_rigid_soft_contact.py +14 -6
- warp/examples/sim/example_soft_body.py +14 -6
- warp/examples/tile/example_tile_cholesky.py +14 -6
- warp/examples/tile/example_tile_convolution.py +14 -6
- warp/examples/tile/example_tile_fft.py +14 -6
- warp/examples/tile/example_tile_filtering.py +14 -6
- warp/examples/tile/example_tile_matmul.py +16 -10
- warp/examples/tile/example_tile_mlp.py +14 -6
- warp/examples/tile/example_tile_nbody.py +14 -6
- warp/examples/tile/example_tile_walker.py +14 -6
- warp/fabric.py +15 -0
- warp/fem/__init__.py +26 -1
- warp/fem/adaptivity.py +19 -4
- warp/fem/cache.py +15 -0
- warp/fem/dirichlet.py +15 -0
- warp/fem/domain.py +15 -0
- warp/fem/field/__init__.py +15 -0
- warp/fem/field/field.py +15 -0
- warp/fem/field/nodal_field.py +37 -68
- warp/fem/field/restriction.py +15 -0
- warp/fem/field/virtual.py +77 -23
- warp/fem/geometry/__init__.py +15 -0
- warp/fem/geometry/adaptive_nanogrid.py +24 -10
- warp/fem/geometry/closest_point.py +16 -1
- warp/fem/geometry/deformed_geometry.py +20 -2
- warp/fem/geometry/element.py +15 -0
- warp/fem/geometry/geometry.py +20 -0
- warp/fem/geometry/grid_2d.py +27 -12
- warp/fem/geometry/grid_3d.py +27 -15
- warp/fem/geometry/hexmesh.py +20 -7
- warp/fem/geometry/nanogrid.py +24 -11
- warp/fem/geometry/partition.py +15 -0
- warp/fem/geometry/quadmesh.py +28 -13
- warp/fem/geometry/tetmesh.py +18 -4
- warp/fem/geometry/trimesh.py +18 -8
- warp/fem/integrate.py +277 -93
- warp/fem/linalg.py +20 -5
- warp/fem/operator.py +15 -0
- warp/fem/polynomial.py +15 -0
- warp/fem/quadrature/__init__.py +15 -0
- warp/fem/quadrature/pic_quadrature.py +52 -22
- warp/fem/quadrature/quadrature.py +209 -25
- warp/fem/space/__init__.py +16 -1
- warp/fem/space/basis_function_space.py +19 -2
- warp/fem/space/basis_space.py +40 -18
- warp/fem/space/dof_mapper.py +15 -0
- warp/fem/space/function_space.py +15 -0
- warp/fem/space/grid_2d_function_space.py +15 -0
- warp/fem/space/grid_3d_function_space.py +15 -0
- warp/fem/space/hexmesh_function_space.py +17 -2
- warp/fem/space/nanogrid_function_space.py +15 -0
- warp/fem/space/partition.py +21 -2
- warp/fem/space/quadmesh_function_space.py +23 -8
- warp/fem/space/restriction.py +15 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +38 -23
- warp/fem/space/shape/shape_function.py +15 -0
- warp/fem/space/shape/square_shape_function.py +27 -12
- warp/fem/space/shape/tet_shape_function.py +15 -0
- warp/fem/space/shape/triangle_shape_function.py +16 -1
- warp/fem/space/tetmesh_function_space.py +18 -3
- warp/fem/space/topology.py +15 -0
- warp/fem/space/trimesh_function_space.py +17 -2
- warp/fem/types.py +15 -0
- warp/fem/utils.py +27 -6
- warp/jax.py +28 -7
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +103 -6
- warp/native/array.h +28 -6
- warp/native/builtin.h +44 -9
- warp/native/bvh.cpp +18 -7
- warp/native/bvh.cu +57 -20
- warp/native/bvh.h +17 -7
- warp/native/clang/clang.cpp +45 -9
- warp/native/coloring.cpp +15 -6
- warp/native/crt.cpp +15 -6
- warp/native/crt.h +15 -6
- warp/native/cuda_crt.h +15 -6
- warp/native/cuda_util.cpp +29 -6
- warp/native/cuda_util.h +17 -6
- warp/native/error.cpp +15 -6
- warp/native/error.h +15 -6
- warp/native/exports.h +85 -63
- warp/native/fabric.h +15 -6
- warp/native/hashgrid.cpp +15 -6
- warp/native/hashgrid.cu +15 -6
- warp/native/hashgrid.h +15 -6
- warp/native/initializer_array.h +15 -6
- warp/native/intersect.h +41 -32
- warp/native/intersect_adj.h +48 -39
- warp/native/intersect_tri.h +17 -0
- warp/native/marching.cpp +16 -0
- warp/native/marching.cu +16 -7
- warp/native/marching.h +17 -0
- warp/native/mat.h +528 -15
- warp/native/mathdx.cpp +15 -6
- warp/native/matnn.h +15 -6
- warp/native/mesh.cpp +15 -6
- warp/native/mesh.cu +15 -6
- warp/native/mesh.h +25 -16
- warp/native/noise.h +15 -6
- warp/native/quat.h +114 -17
- warp/native/rand.h +21 -6
- warp/native/range.h +15 -6
- warp/native/reduce.cpp +15 -6
- warp/native/reduce.cu +15 -6
- warp/native/runlength_encode.cpp +15 -6
- warp/native/runlength_encode.cu +15 -6
- warp/native/scan.cpp +15 -6
- warp/native/scan.cu +15 -6
- warp/native/scan.h +15 -6
- warp/native/solid_angle.h +17 -0
- warp/native/sort.cpp +137 -65
- warp/native/sort.cu +167 -21
- warp/native/sort.h +23 -7
- warp/native/sparse.cpp +58 -28
- warp/native/sparse.cu +67 -23
- warp/native/spatial.h +15 -6
- warp/native/svd.h +131 -6
- warp/native/temp_buffer.h +15 -6
- warp/native/tile.h +316 -111
- warp/native/tile_reduce.h +61 -9
- warp/native/vec.h +83 -13
- warp/native/volume.cpp +100 -119
- warp/native/volume.cu +15 -6
- warp/native/volume.h +15 -6
- warp/native/volume_builder.cu +40 -16
- warp/native/volume_builder.h +21 -6
- warp/native/volume_impl.h +15 -6
- warp/native/warp.cpp +20 -12
- warp/native/warp.cu +114 -16
- warp/native/warp.h +34 -16
- warp/optim/__init__.py +14 -6
- warp/optim/adam.py +14 -6
- warp/optim/linear.py +25 -10
- warp/optim/sgd.py +14 -6
- warp/paddle.py +14 -6
- warp/render/__init__.py +14 -6
- warp/render/render_opengl.py +14 -6
- warp/render/render_usd.py +14 -6
- warp/render/utils.py +14 -6
- warp/sim/__init__.py +14 -7
- warp/sim/articulation.py +18 -10
- warp/sim/collide.py +35 -16
- warp/sim/graph_coloring.py +14 -6
- warp/sim/import_mjcf.py +463 -162
- warp/sim/import_snu.py +14 -7
- warp/sim/import_urdf.py +46 -18
- warp/sim/import_usd.py +14 -7
- warp/sim/inertia.py +14 -6
- warp/sim/integrator.py +14 -6
- warp/sim/integrator_euler.py +19 -11
- warp/sim/integrator_featherstone.py +17 -16
- warp/sim/integrator_vbd.py +222 -8
- warp/sim/integrator_xpbd.py +19 -11
- warp/sim/model.py +56 -19
- warp/sim/particles.py +14 -6
- warp/sim/render.py +14 -6
- warp/sim/utils.py +17 -2
- warp/sparse.py +657 -555
- warp/stubs.py +231 -19
- warp/tape.py +14 -6
- warp/tests/aux_test_class_kernel.py +14 -6
- warp/tests/aux_test_compile_consts_dummy.py +14 -6
- warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
- warp/tests/aux_test_dependent.py +14 -6
- warp/tests/aux_test_grad_customs.py +14 -6
- warp/tests/aux_test_instancing_gc.py +14 -6
- warp/tests/aux_test_module_unload.py +14 -6
- warp/tests/aux_test_name_clash1.py +14 -6
- warp/tests/aux_test_name_clash2.py +14 -6
- warp/tests/aux_test_unresolved_func.py +14 -6
- warp/tests/aux_test_unresolved_symbol.py +14 -6
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_async.py → cuda/test_async.py} +14 -6
- warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
- warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
- warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
- warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
- warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
- warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
- warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
- warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
- warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
- warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
- warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
- warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
- warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
- warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
- warp/tests/run_coverage_serial.py +14 -6
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
- warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
- warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
- warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
- warp/tests/{test_model.py → sim/test_model.py} +55 -7
- warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +14 -6
- warp/tests/test_arithmetic.py +14 -6
- warp/tests/test_array.py +14 -6
- warp/tests/test_array_reduce.py +14 -6
- warp/tests/test_assert.py +14 -6
- warp/tests/test_atomic.py +14 -6
- warp/tests/test_bool.py +15 -7
- warp/tests/test_builtins_resolution.py +14 -6
- warp/tests/test_closest_point_edge_edge.py +14 -6
- warp/tests/test_codegen.py +14 -6
- warp/tests/test_codegen_instancing.py +14 -6
- warp/tests/test_compile_consts.py +14 -6
- warp/tests/test_conditional.py +14 -6
- warp/tests/test_context.py +14 -6
- warp/tests/test_copy.py +14 -6
- warp/tests/test_ctypes.py +14 -6
- warp/tests/test_dense.py +14 -6
- warp/tests/test_devices.py +14 -6
- warp/tests/test_examples.py +42 -42
- warp/tests/test_fabricarray.py +14 -6
- warp/tests/test_fast_math.py +14 -6
- warp/tests/test_fem.py +37 -10
- warp/tests/test_fp16.py +14 -6
- warp/tests/test_func.py +14 -6
- warp/tests/test_future_annotations.py +14 -6
- warp/tests/test_generics.py +14 -6
- warp/tests/test_grad.py +14 -6
- warp/tests/test_grad_customs.py +14 -6
- warp/tests/test_grad_debug.py +14 -6
- warp/tests/test_implicit_init.py +14 -6
- warp/tests/test_import.py +14 -6
- warp/tests/test_indexedarray.py +14 -6
- warp/tests/test_intersect.py +14 -6
- warp/tests/test_iter.py +14 -6
- warp/tests/test_large.py +14 -6
- warp/tests/test_launch.py +14 -6
- warp/tests/test_lerp.py +14 -6
- warp/tests/test_linear_solvers.py +15 -11
- warp/tests/test_lvalue.py +14 -6
- warp/tests/test_mat.py +247 -85
- warp/tests/test_mat_lite.py +14 -6
- warp/tests/test_mat_scalar_ops.py +18 -10
- warp/tests/test_math.py +14 -6
- warp/tests/test_mlp.py +14 -6
- warp/tests/test_module_hashing.py +14 -6
- warp/tests/test_modules_lite.py +14 -6
- warp/tests/test_noise.py +14 -6
- warp/tests/test_operators.py +14 -6
- warp/tests/test_options.py +14 -6
- warp/tests/test_overwrite.py +15 -60
- warp/tests/test_print.py +14 -6
- warp/tests/test_quat.py +81 -52
- warp/tests/test_rand.py +58 -43
- warp/tests/test_reload.py +14 -6
- warp/tests/test_rounding.py +14 -6
- warp/tests/test_runlength_encode.py +14 -6
- warp/tests/test_scalar_ops.py +14 -6
- warp/tests/test_smoothstep.py +14 -6
- warp/tests/test_snippet.py +15 -0
- warp/tests/test_sparse.py +61 -12
- warp/tests/test_spatial.py +89 -6
- warp/tests/test_special_values.py +14 -6
- warp/tests/test_static.py +15 -7
- warp/tests/test_struct.py +14 -6
- warp/tests/test_tape.py +14 -6
- warp/tests/test_transient_module.py +14 -6
- warp/tests/test_triangle_closest_point.py +14 -6
- warp/tests/test_types.py +14 -6
- warp/tests/test_utils.py +98 -10
- warp/tests/test_vec.py +60 -40
- warp/tests/test_vec_lite.py +14 -6
- warp/tests/test_vec_scalar_ops.py +14 -6
- warp/tests/test_verify_fp.py +14 -6
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
- warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
- warp/tests/unittest_serial.py +15 -6
- warp/tests/unittest_suites.py +59 -65
- warp/tests/unittest_utils.py +16 -7
- warp/tests/walkthrough_debug.py +14 -6
- warp/thirdparty/unittest_parallel.py +15 -8
- warp/torch.py +14 -6
- warp/types.py +124 -664
- warp/utils.py +151 -78
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp/examples/optim/example_walker.py +0 -309
- warp/native/cutlass_gemm.cpp +0 -34
- warp/native/cutlass_gemm.cu +0 -373
- warp/tests/test_matmul.py +0 -503
- warp/tests/test_matmul_lite.py +0 -403
- warp/tests/test_vbd.py +0 -378
- warp/tests/unused_test_misc.py +0 -69
- warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
- warp_lang-1.6.1.dist-info/RECORD +0 -419
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/codegen.py
CHANGED
@@ -1,9 +1,17 @@
-# Copyright (c) 2022 NVIDIA CORPORATION.
-#
-#
-#
-#
-#
+# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

 from __future__ import annotations

@@ -18,7 +26,7 @@ import re
 import sys
 import textwrap
 import types
-from typing import Any, Callable, Dict, Mapping, Optional, Sequence
+from typing import Any, Callable, Dict, Mapping, Optional, Sequence, get_args, get_origin

 import warp.config
 from warp.types import *
@@ -49,7 +57,7 @@ class WarpCodegenKeyError(KeyError):


 # map operator to function name
-builtin_operators = {}
+builtin_operators: Dict[type[ast.AST], str] = {}

 # see https://www.ics.uci.edu/~pattis/ICS-31/lectures/opexp.pdf for a
 # nice overview of python operators
@@ -114,16 +122,6 @@ def get_closure_cell_contents(obj):
     return None


-def get_type_origin(tp):
-    # Compatible version of `typing.get_origin()` for Python 3.7 and older.
-    return getattr(tp, "__origin__", None)
-
-
-def get_type_args(tp):
-    # Compatible version of `typing.get_args()` for Python 3.7 and older.
-    return getattr(tp, "__args__", ())
-
-
 def eval_annotations(annotations: Mapping[str, Any], obj: Any) -> Mapping[str, Any]:
     """Un-stringize annotations caused by `from __future__ import annotations` of PEP 563."""
     # Implementation backported from `inspect.get_annotations()` for Python 3.9 and older.
@@ -407,12 +405,14 @@ class StructInstance:


 class Struct:
-
+    hash: bytes
+
+    def __init__(self, cls: type, key: str, module: warp.context.Module):
         self.cls = cls
         self.module = module
         self.key = key
+        self.vars: Dict[str, Var] = {}

-        self.vars = {}
         annotations = get_annotations(self.cls)
         for label, type in annotations.items():
             self.vars[label] = Var(label, type)
@@ -583,11 +583,11 @@ class Reference:
         self.value_type = value_type


-def is_reference(type):
+def is_reference(type: Any) -> builtins.bool:
     return isinstance(type, Reference)


-def strip_reference(arg):
+def strip_reference(arg: Any) -> Any:
     if is_reference(arg):
         return arg.value_type
     else:
@@ -615,7 +615,15 @@ def compute_type_str(base_name, template_params):


 class Var:
-    def __init__(
+    def __init__(
+        self,
+        label: str,
+        type: type,
+        requires_grad: builtins.bool = False,
+        constant: Optional[builtins.bool] = None,
+        prefix: builtins.bool = True,
+        relative_lineno: Optional[int] = None,
+    ):
         # convert built-in types to wp types
         if type == float:
             type = float32
@@ -638,11 +646,14 @@ class Var:
         # used to associate a view array Var with its parent array Var
         self.parent = None

+        # Used to associate the variable with the Python statement that resulted in it being created.
+        self.relative_lineno = relative_lineno
+
     def __str__(self):
         return self.label

     @staticmethod
-    def type_to_ctype(t, value_type=False):
+    def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
         if is_array(t):
             if hasattr(t.dtype, "_wp_generic_type_str_"):
                 dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
@@ -673,7 +684,7 @@ class Var:
         else:
             return f"wp::{t.__name__}"

-    def ctype(self, value_type=False):
+    def ctype(self, value_type: builtins.bool = False) -> str:
         return Var.type_to_ctype(self.type, value_type)

     def emit(self, prefix: str = "var"):
@@ -795,7 +806,7 @@ def func_match_args(func, arg_types, kwarg_types):
     return True


-def get_arg_type(arg: Union[Var, Any]):
+def get_arg_type(arg: Union[Var, Any]) -> type:
     if isinstance(arg, str):
         return str

@@ -811,7 +822,7 @@ def get_arg_type(arg: Union[Var, Any]):
     return type(arg)


-def get_arg_value(arg:
+def get_arg_value(arg: Any) -> Any:
     if isinstance(arg, Sequence):
         return tuple(get_arg_value(x) for x in arg)

@@ -859,6 +870,9 @@ class Adjoint:
                 "please save it on a file and use `importlib` if needed."
             ) from e

+        # Indicates where the function definition starts (excludes decorators)
+        adj.fun_def_lineno = None
+
         # get function source code
         adj.source = inspect.getsource(func)
         # ensures that indented class methods can be parsed as kernels
@@ -933,9 +947,6 @@ class Adjoint:
         # for unit testing errors being spit out from kernels.
         adj.skip_build = False

-        # Collect the LTOIR required at link-time
-        adj.ltoirs = []
-
         # allocate extra space for a function call that requires its
         # own shared memory space, we treat shared memory as a stack
         # where each function pushes and pops space off, the extra
@@ -1125,7 +1136,7 @@ class Adjoint:
             name = str(index)

         # allocate new variable
-        v = Var(name, type=type, constant=constant)
+        v = Var(name, type=type, constant=constant, relative_lineno=adj.lineno)

         adj.variables.append(v)

@@ -1150,11 +1161,44 @@ class Adjoint:

         return var

-
-
+    def get_line_directive(adj, statement: str, relative_lineno: Optional[int] = None) -> Optional[str]:
+        """Get a line directive for the given statement.
+
+        Args:
+            statement: The statement to get the line directive for.
+            relative_lineno: The line number of the statement relative to the function.
+
+        Returns:
+            A line directive for the given statement, or None if no line directive is needed.
+        """
+
+        # lineinfo is enabled by default in debug mode regardless of the builder option, don't want to unnecessarily
+        # emit line directives in generated code if it's not being compiled with line information
+        lineinfo_enabled = (
+            adj.builder_options.get("lineinfo", False) or adj.builder_options.get("mode", "release") == "debug"
+        )
+
+        if relative_lineno is not None and lineinfo_enabled and warp.config.line_directives:
+            is_comment = statement.strip().startswith("//")
+            if not is_comment:
+                line = relative_lineno + adj.fun_lineno
+                # Convert backslashes to forward slashes for CUDA compatibility
+                normalized_path = adj.filename.replace("\\", "/")
+                return f'#line {line} "{normalized_path}"'
+        return None
+
+    def add_forward(adj, statement: str, replay: Optional[str] = None, skip_replay: builtins.bool = False) -> None:
+        """Append a statement to the forward pass."""
+
+        if line_directive := adj.get_line_directive(statement, adj.lineno):
+            adj.blocks[-1].body_forward.append(line_directive)
+
         adj.blocks[-1].body_forward.append(adj.indentation + statement)

         if not skip_replay:
+            if line_directive:
+                adj.blocks[-1].body_replay.append(line_directive)
+
             if replay:
                 # if custom replay specified then output it
                 adj.blocks[-1].body_replay.append(adj.indentation + replay)
@@ -1163,9 +1207,14 @@ class Adjoint:
                 adj.blocks[-1].body_replay.append(adj.indentation + statement)

     # append a statement to the reverse pass
-    def add_reverse(adj, statement):
+    def add_reverse(adj, statement: str) -> None:
+        """Append a statement to the reverse pass."""
+
         adj.blocks[-1].body_reverse.append(adj.indentation + statement)

+        if line_directive := adj.get_line_directive(statement, adj.lineno):
+            adj.blocks[-1].body_reverse.append(line_directive)
+
     def add_constant(adj, n):
         output = adj.add_var(type=type(n), constant=n)
         return output
@@ -1273,7 +1322,7 @@ class Adjoint:

         # Bind the positional and keyword arguments to the function's signature
         # in order to process them as Python does it.
-        bound_args = func.signature.bind(*args, **kwargs)
+        bound_args: inspect.BoundArguments = func.signature.bind(*args, **kwargs)

         # Type args are the “compile time” argument values we get from codegen.
         # For example, when calling `wp.vec3f(...)` from within a kernel,
@@ -1616,6 +1665,8 @@ class Adjoint:
         adj.blocks[-1].body_reverse.extend(reversed(reverse))

     def emit_FunctionDef(adj, node):
+        adj.fun_def_lineno = node.lineno
+
         for f in node.body:
             # Skip variable creation for standalone constants, including docstrings
             if isinstance(f, ast.Expr) and isinstance(f.value, ast.Constant):
@@ -1680,7 +1731,7 @@ class Adjoint:

             if var1 != var2:
                 # insert a phi function that selects var1, var2 based on cond
-                out = adj.add_builtin_call("
+                out = adj.add_builtin_call("where", [cond, var2, var1])
                 adj.symbols[sym] = out

         symbols_prev = adj.symbols.copy()
@@ -1704,7 +1755,7 @@ class Adjoint:
             if var1 != var2:
                 # insert a phi function that selects var1, var2 based on cond
                 # note the reversed order of vars since we want to use !cond as our select
-                out = adj.add_builtin_call("
+                out = adj.add_builtin_call("where", [cond, var1, var2])
                 adj.symbols[sym] = out

     def emit_Compare(adj, node):
@@ -1848,25 +1899,6 @@ class Adjoint:
             ) from e
             raise WarpCodegenAttributeError(f"Error, `{node.attr}` is not an attribute of '{aggregate}'") from e

-    def emit_String(adj, node):
-        # string constant
-        return adj.add_constant(node.s)
-
-    def emit_Num(adj, node):
-        # lookup constant, if it has already been assigned then return existing var
-        key = (node.n, type(node.n))
-
-        if key in adj.symbols:
-            return adj.symbols[key]
-        else:
-            out = adj.add_constant(node.n)
-            adj.symbols[key] = out
-            return out
-
-    def emit_Ellipsis(adj, node):
-        # stubbed @wp.native_func
-        return
-
     def emit_Assert(adj, node):
         # eval condition
         cond = adj.eval(node.test)
@@ -1878,24 +1910,11 @@ class Adjoint:

         adj.add_forward(f'assert(("{escaped_segment}",{cond.emit()}));')

-    def emit_NameConstant(adj, node):
-        if node.value:
-            return adj.add_constant(node.value)
-        elif node.value is None:
-            raise WarpCodegenTypeError("None type unsupported")
-        else:
-            return adj.add_constant(False)
-
     def emit_Constant(adj, node):
-        if
-
-        elif isinstance(node, ast.Num):
-            return adj.emit_Num(node)
-        elif isinstance(node, ast.Ellipsis):
-            return adj.emit_Ellipsis(node)
+        if node.value is None:
+            raise WarpCodegenTypeError("None type unsupported")
         else:
-
-            return adj.emit_NameConstant(node)
+            return adj.add_constant(node.value)

     def emit_BinOp(adj, node):
         # evaluate binary operator arguments
@@ -1989,10 +2008,11 @@
         adj.end_while()

     def eval_num(adj, a):
-        if isinstance(a, ast.
-            return True, a.
-        if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.
-
+        if isinstance(a, ast.Constant):
+            return True, a.value
+        if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Constant):
+            # Negative constant
+            return True, -a.operand.value

         # try and resolve the expression to an object
         # e.g.: wp.constant in the globals scope
@@ -2522,8 +2542,8 @@
                     f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
                 )
             else:
-                if
-                    out = adj.add_builtin_call("
+                if warp.config.enable_vector_component_overwrites:
+                    out = adj.add_builtin_call("assign_copy", [target, *indices, rhs])

                     # re-point target symbol to out var
                     for id in adj.symbols:
@@ -2531,8 +2551,7 @@
                             adj.symbols[id] = out
                             break
                 else:
-
-                    adj.add_builtin_call("store", [attr, rhs])
+                    adj.add_builtin_call("assign_inplace", [target, *indices, rhs])

             else:
                 raise WarpCodegenError(
@@ -2575,8 +2594,8 @@
                 attr = adj.add_builtin_call("indexref", [aggregate, index])
                 adj.add_builtin_call("store", [attr, rhs])
             else:
-                if
-                    out = adj.add_builtin_call("
+                if warp.config.enable_vector_component_overwrites:
+                    out = adj.add_builtin_call("assign_copy", [aggregate, index, rhs])

                     # re-point target symbol to out var
                     for id in adj.symbols:
@@ -2584,8 +2603,7 @@
                             adj.symbols[id] = out
                             break
                 else:
-
-                    adj.add_builtin_call("store", [attr, rhs])
+                    adj.add_builtin_call("assign_inplace", [aggregate, index, rhs])

         else:
             attr = adj.emit_Attribute(lhs)
@@ -2691,10 +2709,12 @@

         elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
             if isinstance(node.op, ast.Add):
-                adj.add_builtin_call("
+                adj.add_builtin_call("add_inplace", [target, *indices, rhs])
             elif isinstance(node.op, ast.Sub):
-                adj.add_builtin_call("
+                adj.add_builtin_call("sub_inplace", [target, *indices, rhs])
             else:
+                if warp.config.verbose:
+                    print(f"Warning: in-place op {node.op} is not differentiable")
                 make_new_assign_statement()
                 return

@@ -2724,9 +2744,6 @@
         ast.BoolOp: emit_BoolOp,
         ast.Name: emit_Name,
         ast.Attribute: emit_Attribute,
-        ast.Str: emit_String,  # Deprecated in 3.8; use Constant
-        ast.Num: emit_Num,  # Deprecated in 3.8; use Constant
-        ast.NameConstant: emit_NameConstant,  # Deprecated in 3.8; use Constant
         ast.Constant: emit_Constant,
         ast.BinOp: emit_BinOp,
         ast.UnaryOp: emit_UnaryOp,
@@ -2736,14 +2753,13 @@
         ast.Continue: emit_Continue,
         ast.Expr: emit_Expr,
         ast.Call: emit_Call,
-        ast.Index: emit_Index,  # Deprecated in 3.
+        ast.Index: emit_Index,  # Deprecated in 3.9
         ast.Subscript: emit_Subscript,
         ast.Assign: emit_Assign,
         ast.Return: emit_Return,
         ast.AugAssign: emit_AugAssign,
         ast.Tuple: emit_Tuple,
         ast.Pass: emit_Pass,
-        ast.Ellipsis: emit_Ellipsis,
         ast.Assert: emit_Assert,
     }

@@ -2939,12 +2955,16 @@

         # We want to replace the expression code in-place,
         # so reparse it to get the correct column info.
-        len_value_locs = []
+        len_value_locs: List[Tuple[int, int, int]] = []
         expr_tree = ast.parse(static_code)
         assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
         expr_root = expr_tree.body[0].value
         for expr_node in ast.walk(expr_root):
-            if
+            if (
+                isinstance(expr_node, ast.Call)
+                and getattr(expr_node.func, "id", None) == "len"
+                and len(expr_node.args) == 1
+            ):
                 len_expr = static_code[expr_node.col_offset : expr_node.end_col_offset]
                 try:
                     len_value = eval(len_expr, len_expr_ctx)
@@ -3102,9 +3122,9 @@

         local_variables = set()  # Track local variables appearing on the LHS so we know when variables are shadowed

-        constants = {}
-        types = {}
-        functions = {}
+        constants: Dict[str, Any] = {}
+        types: Dict[Union[Struct, type], Any] = {}
+        functions: Dict[warp.context.Function, Any] = {}

         for node in ast.walk(adj.tree):
             if isinstance(node, ast.Name) and node.id not in local_variables:
@@ -3147,7 +3167,7 @@
 # code generation

 cpu_module_header = """
-#define WP_TILE_BLOCK_DIM {
+#define WP_TILE_BLOCK_DIM {block_dim}
 #define WP_NO_CRT
 #include "builtin.h"

@@ -3166,7 +3186,7 @@ cpu_module_header = """
 """

 cuda_module_header = """
-#define WP_TILE_BLOCK_DIM {
+#define WP_TILE_BLOCK_DIM {block_dim}
 #define WP_NO_CRT
 #include "builtin.h"

@@ -3189,6 +3209,7 @@ struct {name}
 {{
 {struct_body}

+    {defaulted_constructor_def}
     CUDA_CALLABLE {name}({forward_args})
     {forward_initializers}
     {{
@@ -3231,53 +3252,53 @@ static void adj_{name}(

 cuda_forward_function_template = """
 // {filename}:{lineno}
-static CUDA_CALLABLE {return_type} {name}(
+{line_directive}static CUDA_CALLABLE {return_type} {name}(
     {forward_args})
 {{
-{forward_body}}}
+{forward_body}{line_directive}}}

 """

 cuda_reverse_function_template = """
 // {filename}:{lineno}
-static CUDA_CALLABLE void adj_{name}(
+{line_directive}static CUDA_CALLABLE void adj_{name}(
     {reverse_args})
 {{
-{reverse_body}}}
+{reverse_body}{line_directive}}}

 """

 cuda_kernel_template_forward = """

-extern "C" __global__ void {name}_cuda_kernel_forward(
+{line_directive}extern "C" __global__ void {name}_cuda_kernel_forward(
     {forward_args})
 {{
-    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
-         _idx < dim.size;
-         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
+{line_directive}    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+{line_directive}         _idx < dim.size;
+{line_directive}         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
     {{
         // reset shared memory allocator
-        wp::tile_alloc_shared(0, true);
+{line_directive}        wp::tile_alloc_shared(0, true);

-{forward_body}    }}
-}}
+{forward_body}{line_directive}    }}
+{line_directive}}}

 """

 cuda_kernel_template_backward = """

-extern "C" __global__ void {name}_cuda_kernel_backward(
+{line_directive}extern "C" __global__ void {name}_cuda_kernel_backward(
     {reverse_args})
 {{
-    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
-         _idx < dim.size;
-         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
+{line_directive}    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+{line_directive}         _idx < dim.size;
+{line_directive}         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
     {{
         // reset shared memory allocator
-        wp::tile_alloc_shared(0, true);
+{line_directive}        wp::tile_alloc_shared(0, true);

-{reverse_body}    }}
-}}
+{reverse_body}{line_directive}    }}
+{line_directive}}}

 """

@@ -3307,10 +3328,17 @@ extern "C" {{
 WP_API void {name}_cpu_forward(
     {forward_args})
 {{
-
+    for (size_t task_index = 0; task_index < dim.size; ++task_index)
     {{
+        // init shared memory allocator
+        wp::tile_alloc_shared(0, true);
+
         {name}_cpu_kernel_forward(
             {forward_params});
+
+        // check shared memory allocator
+        wp::tile_alloc_shared(0, false, true);
+
     }}
 }}

@@ -3327,8 +3355,14 @@ WP_API void {name}_cpu_backward(
 {{
     for (size_t task_index = 0; task_index < dim.size; ++task_index)
     {{
+        // initialize shared memory allocator
+        wp::tile_alloc_shared(0, true);
+
         {name}_cpu_kernel_backward(
             {reverse_params});
+
+        // check shared memory allocator
+        wp::tile_alloc_shared(0, false, true);
     }}
 }}

@@ -3410,7 +3444,7 @@ def indent(args, stops=1):


 # generates a C function name based on the python function name
-def make_full_qualified_name(func):
+def make_full_qualified_name(func: Union[str, Callable]) -> str:
     if not isinstance(func, str):
         func = func.__qualname__
     return re.sub("[^0-9a-zA-Z_]+", "", func.replace(".", "__"))
@@ -3440,7 +3474,8 @@ def codegen_struct(struct, device="cpu", indent_size=4):
     # forward args
     for label, var in struct.vars.items():
         var_ctype = var.ctype()
-
+        default_arg_def = " = {}" if forward_args else ""
+        forward_args.append(f"{var_ctype} const& {label}{default_arg_def}")
         reverse_args.append(f"{var_ctype} const&")

         namespace = "wp::" if var_ctype.startswith("wp::") or var_ctype == "bool" else ""
@@ -3464,6 +3499,9 @@ def codegen_struct(struct, device="cpu", indent_size=4):

     reverse_args.append(name + " & adj_ret")

+    # explicitly defaulted default constructor if no default constructor has been defined
+    defaulted_constructor_def = f"{name}() = default;" if forward_args else ""
+
     return struct_template.format(
         name=name,
         struct_body="".join([indent_block + l for l in body]),
@@ -3473,6 +3511,7 @@ def codegen_struct(struct, device="cpu", indent_size=4):
         reverse_body="".join(reverse_body),
         prefix_add_body="".join(prefix_add_body),
         atomic_add_body="".join(atomic_add_body),
+        defaulted_constructor_def=defaulted_constructor_def,
     )


@@ -3502,6 +3541,9 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
         else:
             lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]

+        if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
+            lines.insert(-1, f"{line_directive}\n")
+
     # forward pass
     lines += ["//---------\n"]
     lines += ["// forward\n"]
@@ -3509,7 +3551,7 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
     for f in adj.blocks[0].body_forward:
         lines += [f + "\n"]

-    return "".join(
+    return "".join(l.lstrip() if l.lstrip().startswith("#line") else indent_block + l for l in lines)


 def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
@@ -3539,6 +3581,9 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
         else:
             lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]

+        if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
+            lines.insert(-1, f"{line_directive}\n")
+
     # dual vars
     lines += ["//---------\n"]
     lines += ["// dual vars\n"]
@@ -3559,6 +3604,9 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
         else:
             lines += [f"{ctype} {name} = {{}};\n"]

+        if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno):
+            lines.insert(-1, f"{line_directive}\n")
+
     # forward pass
     lines += ["//---------\n"]
     lines += ["// forward\n"]
@@ -3579,7 +3627,7 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
     else:
         lines += ["return;\n"]

-    return "".join(
+    return "".join(l.lstrip() if l.lstrip().startswith("#line") else indent_block + l for l in lines)


 def codegen_func(adj, c_func_name: str, device="cpu", options=None):
@@ -3587,11 +3635,11 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
         options = {}

     if adj.return_var is not None and "return" in adj.arg_types:
-        if
-            if len(
+        if get_origin(adj.arg_types["return"]) is tuple:
+            if len(get_args(adj.arg_types["return"])) != len(adj.return_var):
                 raise WarpCodegenError(
                     f"The function `{adj.fun_name}` has its return type "
-                    f"annotated as a tuple of {len(
+                    f"annotated as a tuple of {len(get_args(adj.arg_types['return']))} elements "
                     f"but the code returns {len(adj.return_var)} values."
                 )
             elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var)):
@@ -3600,7 +3648,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
                     f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
                     f"but the code returns a tuple with types `({', '.join(warp.context.type_str(x.type) for x in adj.return_var)})`."
                 )
-        elif len(adj.return_var) > 1 and
+        elif len(adj.return_var) > 1 and get_origin(adj.arg_types["return"]) is not tuple:
             raise WarpCodegenError(
                 f"The function `{adj.fun_name}` has its return type "
                 f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
@@ -3613,6 +3661,13 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
                 f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
             )

+    # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
+    # This is used as a catch-all C-to-Python source line mapping for any code that does not have
+    # a direct mapping to a Python source line.
+    func_line_directive = ""
+    if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
+        func_line_directive = f"{line_directive}\n"
+
     # forward header
     if adj.return_var is not None and len(adj.return_var) == 1:
         return_type = adj.return_var[0].ctype()
@@ -3676,6 +3731,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
             forward_body=forward_body,
             filename=adj.filename,
             lineno=adj.fun_lineno,
+            line_directive=func_line_directive,
         )

     if not adj.skip_reverse_codegen:
@@ -3694,6 +3750,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
             reverse_body=reverse_body,
             filename=adj.filename,
             lineno=adj.fun_lineno,
+            line_directive=func_line_directive,
         )

     return s
@@ -3736,6 +3793,7 @@ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
         forward_body=snippet,
         filename=adj.filename,
         lineno=adj.fun_lineno,
+        line_directive="",
     )

     if replay_snippet is not None:
@@ -3746,6 +3804,7 @@ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
             forward_body=replay_snippet,
             filename=adj.filename,
             lineno=adj.fun_lineno,
+            line_directive="",
         )

     if adj_snippet:
@@ -3761,6 +3820,7 @@ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
             reverse_body=reverse_body,
             filename=adj.filename,
             lineno=adj.fun_lineno,
+            line_directive="",
         )

     return s
@@ -3773,6 +3833,13 @@ def codegen_kernel(kernel, device, options):

     adj = kernel.adj

+    # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
+    # This is used as a catch-all C-to-Python source line mapping for any code that does not have
+    # a direct mapping to a Python source line.
+    func_line_directive = ""
+    if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
+        func_line_directive = f"{line_directive}\n"
+
     if device == "cpu":
         template_forward = cpu_kernel_template_forward
         template_backward = cpu_kernel_template_backward
@@ -3800,6 +3867,7 @@ def codegen_kernel(kernel, device, options):
         {
             "forward_args": indent(forward_args),
             "forward_body": forward_body,
+            "line_directive": func_line_directive,
         }
     )
     template += template_forward