warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +139 -0
- warp/__init__.pyi +1 -0
- warp/autograd.py +1142 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +557 -0
- warp/build_dll.py +405 -0
- warp/builtins.py +6855 -0
- warp/codegen.py +3969 -0
- warp/config.py +158 -0
- warp/constants.py +57 -0
- warp/context.py +6812 -0
- warp/dlpack.py +462 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +194 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +193 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +423 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +178 -0
- warp/examples/fem/example_convection_diffusion_dg.py +204 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +220 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_magnetostatics.py +240 -0
- warp/examples/fem/example_mixed_elasticity.py +291 -0
- warp/examples/fem/example_navier_stokes.py +261 -0
- warp/examples/fem/example_nonconforming_contact.py +298 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +352 -0
- warp/examples/fem/utils.py +1000 -0
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +266 -0
- warp/examples/optim/example_cloth_throw.py +228 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_drone.py +870 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +182 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
- warp/examples/optim/example_softbody_properties.py +400 -0
- warp/examples/optim/example_spring_cage.py +245 -0
- warp/examples/optim/example_trajectory.py +227 -0
- warp/examples/sim/example_cartpole.py +143 -0
- warp/examples/sim/example_cloth.py +225 -0
- warp/examples/sim/example_cloth_self_contact.py +322 -0
- warp/examples/sim/example_granular.py +130 -0
- warp/examples/sim/example_granular_collision_sdf.py +202 -0
- warp/examples/sim/example_jacobian_ik.py +244 -0
- warp/examples/sim/example_particle_chain.py +124 -0
- warp/examples/sim/example_quadruped.py +203 -0
- warp/examples/sim/example_rigid_chain.py +203 -0
- warp/examples/sim/example_rigid_contact.py +195 -0
- warp/examples/sim/example_rigid_force.py +133 -0
- warp/examples/sim/example_rigid_gyroscopic.py +115 -0
- warp/examples/sim/example_rigid_soft_contact.py +140 -0
- warp/examples/sim/example_soft_body.py +196 -0
- warp/examples/tile/example_tile_cholesky.py +87 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mlp.py +383 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/examples/tile/example_tile_walker.py +327 -0
- warp/fabric.py +355 -0
- warp/fem/__init__.py +106 -0
- warp/fem/adaptivity.py +508 -0
- warp/fem/cache.py +572 -0
- warp/fem/dirichlet.py +202 -0
- warp/fem/domain.py +411 -0
- warp/fem/field/__init__.py +125 -0
- warp/fem/field/field.py +619 -0
- warp/fem/field/nodal_field.py +326 -0
- warp/fem/field/restriction.py +37 -0
- warp/fem/field/virtual.py +848 -0
- warp/fem/geometry/__init__.py +32 -0
- warp/fem/geometry/adaptive_nanogrid.py +857 -0
- warp/fem/geometry/closest_point.py +84 -0
- warp/fem/geometry/deformed_geometry.py +221 -0
- warp/fem/geometry/element.py +776 -0
- warp/fem/geometry/geometry.py +362 -0
- warp/fem/geometry/grid_2d.py +392 -0
- warp/fem/geometry/grid_3d.py +452 -0
- warp/fem/geometry/hexmesh.py +911 -0
- warp/fem/geometry/nanogrid.py +571 -0
- warp/fem/geometry/partition.py +389 -0
- warp/fem/geometry/quadmesh.py +663 -0
- warp/fem/geometry/tetmesh.py +855 -0
- warp/fem/geometry/trimesh.py +806 -0
- warp/fem/integrate.py +2335 -0
- warp/fem/linalg.py +419 -0
- warp/fem/operator.py +293 -0
- warp/fem/polynomial.py +229 -0
- warp/fem/quadrature/__init__.py +17 -0
- warp/fem/quadrature/pic_quadrature.py +299 -0
- warp/fem/quadrature/quadrature.py +591 -0
- warp/fem/space/__init__.py +228 -0
- warp/fem/space/basis_function_space.py +468 -0
- warp/fem/space/basis_space.py +667 -0
- warp/fem/space/dof_mapper.py +251 -0
- warp/fem/space/function_space.py +309 -0
- warp/fem/space/grid_2d_function_space.py +177 -0
- warp/fem/space/grid_3d_function_space.py +227 -0
- warp/fem/space/hexmesh_function_space.py +257 -0
- warp/fem/space/nanogrid_function_space.py +201 -0
- warp/fem/space/partition.py +367 -0
- warp/fem/space/quadmesh_function_space.py +223 -0
- warp/fem/space/restriction.py +179 -0
- warp/fem/space/shape/__init__.py +143 -0
- warp/fem/space/shape/cube_shape_function.py +1105 -0
- warp/fem/space/shape/shape_function.py +133 -0
- warp/fem/space/shape/square_shape_function.py +926 -0
- warp/fem/space/shape/tet_shape_function.py +834 -0
- warp/fem/space/shape/triangle_shape_function.py +672 -0
- warp/fem/space/tetmesh_function_space.py +271 -0
- warp/fem/space/topology.py +424 -0
- warp/fem/space/trimesh_function_space.py +194 -0
- warp/fem/types.py +99 -0
- warp/fem/utils.py +420 -0
- warp/jax.py +187 -0
- warp/jax_experimental/__init__.py +16 -0
- warp/jax_experimental/custom_call.py +351 -0
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +244 -0
- warp/native/array.h +1145 -0
- warp/native/builtin.h +1800 -0
- warp/native/bvh.cpp +492 -0
- warp/native/bvh.cu +791 -0
- warp/native/bvh.h +554 -0
- warp/native/clang/clang.cpp +536 -0
- warp/native/coloring.cpp +613 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +362 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +646 -0
- warp/native/cuda_util.h +307 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +1878 -0
- warp/native/fabric.h +245 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +87 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1230 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +339 -0
- warp/native/marching.cpp +19 -0
- warp/native/marching.cu +514 -0
- warp/native/marching.h +19 -0
- warp/native/mat.h +2220 -0
- warp/native/mathdx.cpp +87 -0
- warp/native/matnn.h +343 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +404 -0
- warp/native/mesh.h +1980 -0
- warp/native/nanovdb/GridHandle.h +366 -0
- warp/native/nanovdb/HostBuffer.h +590 -0
- warp/native/nanovdb/NanoVDB.h +6624 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1371 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +139 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +364 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +53 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +277 -0
- warp/native/sort.h +33 -0
- warp/native/sparse.cpp +378 -0
- warp/native/sparse.cu +524 -0
- warp/native/spatial.h +657 -0
- warp/native/svd.h +702 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +2584 -0
- warp/native/tile_reduce.h +264 -0
- warp/native/vec.h +1426 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +67 -0
- warp/native/volume.h +969 -0
- warp/native/volume_builder.cu +477 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1082 -0
- warp/native/warp.cu +3636 -0
- warp/native/warp.h +381 -0
- warp/optim/__init__.py +17 -0
- warp/optim/adam.py +163 -0
- warp/optim/linear.py +1137 -0
- warp/optim/sgd.py +112 -0
- warp/paddle.py +407 -0
- warp/render/__init__.py +18 -0
- warp/render/render_opengl.py +3518 -0
- warp/render/render_usd.py +784 -0
- warp/render/utils.py +160 -0
- warp/sim/__init__.py +65 -0
- warp/sim/articulation.py +793 -0
- warp/sim/collide.py +2395 -0
- warp/sim/graph_coloring.py +300 -0
- warp/sim/import_mjcf.py +790 -0
- warp/sim/import_snu.py +227 -0
- warp/sim/import_urdf.py +579 -0
- warp/sim/import_usd.py +894 -0
- warp/sim/inertia.py +324 -0
- warp/sim/integrator.py +242 -0
- warp/sim/integrator_euler.py +1997 -0
- warp/sim/integrator_featherstone.py +2101 -0
- warp/sim/integrator_vbd.py +2048 -0
- warp/sim/integrator_xpbd.py +3292 -0
- warp/sim/model.py +4791 -0
- warp/sim/particles.py +121 -0
- warp/sim/render.py +427 -0
- warp/sim/utils.py +428 -0
- warp/sparse.py +2057 -0
- warp/stubs.py +3333 -0
- warp/tape.py +1203 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +634 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +200 -0
- warp/tests/geometry/test_hash_grid.py +221 -0
- warp/tests/geometry/test_marching_cubes.py +74 -0
- warp/tests/geometry/test_mesh.py +316 -0
- warp/tests/geometry/test_mesh_query_aabb.py +399 -0
- warp/tests/geometry/test_mesh_query_point.py +932 -0
- warp/tests/geometry/test_mesh_query_ray.py +311 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +729 -0
- warp/tests/interop/test_jax.py +371 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/sim/__init__.py +0 -0
- warp/tests/sim/disabled_kinematics.py +244 -0
- warp/tests/sim/flaky_test_sim_grad.py +290 -0
- warp/tests/sim/test_collision.py +604 -0
- warp/tests/sim/test_coloring.py +258 -0
- warp/tests/sim/test_model.py +224 -0
- warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
- warp/tests/sim/test_sim_kinematics.py +98 -0
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +163 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +2972 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +250 -0
- warp/tests/test_atomic.py +153 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +1298 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +810 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +252 -0
- warp/tests/test_context.py +42 -0
- warp/tests/test_copy.py +238 -0
- warp/tests/test_ctypes.py +638 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +97 -0
- warp/tests/test_examples.py +482 -0
- warp/tests/test_fabricarray.py +996 -0
- warp/tests/test_fast_math.py +74 -0
- warp/tests/test_fem.py +2003 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +454 -0
- warp/tests/test_future_annotations.py +98 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +73 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +193 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_mat.py +2089 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +178 -0
- warp/tests/test_mlp.py +282 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +44 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +339 -0
- warp/tests/test_quat.py +2315 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +302 -0
- warp/tests/test_rounding.py +185 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +105 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +582 -0
- warp/tests/test_spatial.py +2229 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +592 -0
- warp/tests/test_struct.py +734 -0
- warp/tests/test_tape.py +204 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +145 -0
- warp/tests/test_types.py +562 -0
- warp/tests/test_utils.py +588 -0
- warp/tests/test_vec.py +1487 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +780 -0
- warp/tests/tile/test_tile_load.py +407 -0
- warp/tests/tile/test_tile_mathdx.py +208 -0
- warp/tests/tile/test_tile_mlp.py +402 -0
- warp/tests/tile/test_tile_reduce.py +447 -0
- warp/tests/tile/test_tile_shared_memory.py +247 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +427 -0
- warp/tests/unittest_utils.py +468 -0
- warp/tests/walkthrough_debug.py +93 -0
- warp/thirdparty/__init__.py +0 -0
- warp/thirdparty/appdirs.py +598 -0
- warp/thirdparty/dlpack.py +145 -0
- warp/thirdparty/unittest_parallel.py +570 -0
- warp/torch.py +391 -0
- warp/types.py +5230 -0
- warp/utils.py +1137 -0
- warp_lang-1.7.0.dist-info/METADATA +516 -0
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- warp_lang-1.7.0.dist-info/WHEEL +5 -0
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp_lang-1.7.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
###########################################################################
|
|
17
|
+
# Example N-Body
|
|
18
|
+
#
|
|
19
|
+
# Shows how to simulate an N-Body gravitational problem using an all-pairs
|
|
20
|
+
# approach with Warp tile primitives.
|
|
21
|
+
#
|
|
22
|
+
# References:
|
|
23
|
+
# L. Nyland, M. Harris, and J. Prins. "Fast N-Body Simulation with
|
|
24
|
+
# CUDA" in GPU Gems 3. H. Nguyen, Addison-Wesley Professional, 2007.
|
|
25
|
+
# https://developer.nvidia.com/gpugems/gpugems3/part-v-physics-simulation/chapter-31-fast-n-body-simulation-cuda
|
|
26
|
+
#
|
|
27
|
+
###########################################################################
|
|
28
|
+
|
|
29
|
+
import argparse
|
|
30
|
+
|
|
31
|
+
import numpy as np
|
|
32
|
+
|
|
33
|
+
import warp as wp
|
|
34
|
+
|
|
35
|
+
wp.init()  # explicitly initialize Warp at import time, before the constants and kernels below are declared
|
|
36
|
+
|
|
37
|
+
DT = wp.constant(0.016)  # timestep used by integrate_bodies_tiled for both the velocity and the position update
|
|
38
|
+
SOFTENING_SQ = wp.constant(0.1**2) # Softening factor for numerical stability (added to the squared distance)
|
|
39
|
+
TILE_SIZE = wp.constant(64)  # bodies per position tile; also passed as block_dim when the kernel is launched
|
|
40
|
+
PARTICLE_MASS = wp.constant(1.0)  # mass factor applied to every pairwise acceleration (all bodies share it)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@wp.func
def body_body_interaction(p0: wp.vec3, pi: wp.vec3):
    """Compute the softened gravitational pull that one body exerts on another.

    Args:
        p0: Position of the body being accelerated.
        pi: Position of the body exerting the pull.

    Returns:
        The acceleration of the body at ``p0``, i.e.
        ``PARTICLE_MASS * (pi - p0) / (|pi - p0|^2 + SOFTENING_SQ)^(3/2)``.
    """
    offset = pi - p0

    # SOFTENING_SQ keeps the denominator strictly positive, so the
    # self-interaction case (offset == 0) yields zero instead of a division by zero.
    softened_len = wp.sqrt(wp.length_sq(offset) + SOFTENING_SQ)
    inv_len = 1.0 / softened_len

    return PARTICLE_MASS * (inv_len * inv_len * inv_len) * offset
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# Advance every body by one timestep using an all-pairs force sum.
#
# Each thread owns one body. The positions of all bodies are streamed through
# tiles of TILE_SIZE entries; every thread accumulates the pull of each body in
# the tile on its own body. The velocity array is updated in place, while the
# new positions go to a separate array so that other threads keep reading the
# positions from the start of the step.
@wp.kernel
def integrate_bodies_tiled(
    old_position: wp.array(dtype=wp.vec3),
    velocity: wp.array(dtype=wp.vec3),
    new_position: wp.array(dtype=wp.vec3),
    num_bodies: int,
):
    body = wp.tid()

    own_pos = old_position[body]

    total_accel = wp.vec3(0.0, 0.0, 0.0)

    # Only whole tiles are visited: the loop bound is num_bodies / TILE_SIZE.
    for tile_index in range(num_bodies / TILE_SIZE):
        pos_tile = wp.tile_load(old_position, shape=TILE_SIZE, offset=tile_index * TILE_SIZE)
        for lane in range(TILE_SIZE):
            other_pos = pos_tile[lane]
            total_accel += body_body_interaction(own_pos, other_pos)

    # Velocity first (in place) ...
    velocity[body] = velocity[body] + total_accel * DT

    # ... then the position, using the freshly updated velocity.
    new_position[body] = old_position[body] + DT * velocity[body]
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class Example:
    """All-pairs N-body gravity demo driven by the ``integrate_bodies_tiled`` kernel.

    Bodies are seeded inside a spherical shell with velocities circulating
    around the z axis. Positions are double-buffered in ``pos_array_0`` (current)
    and ``pos_array_1`` (scratch); the two are swapped after every step.

    Attributes set by ``__init__``:
        num_bodies: Number of simulated bodies.
        graph_scale: Half-extent used for the plot axes (5x the largest radius).
        pos_array_0, pos_array_1, vel_array: Warp ``vec3`` arrays holding the state.
        fig: The matplotlib figure, or ``None`` in headless mode.
        scatter_plot: The matplotlib scatter artist, or ``None`` in headless mode.
    """

    def __init__(self, headless=False, num_bodies=16384):
        """Build the initial state and, unless ``headless``, the 3D scatter plot.

        Args:
            headless: When true, no matplotlib objects are created.
            num_bodies: Number of bodies. Should be a multiple of ``TILE_SIZE``.
        """
        if num_bodies % TILE_SIZE != 0:
            # The kernel iterates over num_bodies / TILE_SIZE whole tiles, so a
            # trailing partial tile is never loaded. Warn instead of raising so
            # that existing callers keep running.
            import warnings

            warnings.warn(
                f"num_bodies={num_bodies} is not a multiple of TILE_SIZE={TILE_SIZE}; "
                "bodies beyond the last full tile will not contribute to the computed accelerations.",
                stacklevel=2,
            )

        self.num_bodies = num_bodies

        # Fixed seed: every run starts from the same configuration.
        rng = np.random.default_rng(42)

        # Sample the surface of a sphere
        phi = np.arccos(1.0 - 2.0 * rng.uniform(low=0.0, high=1.0, size=self.num_bodies))
        theta = rng.uniform(low=0.0, high=2.0 * np.pi, size=self.num_bodies)
        x = np.cos(theta) * np.sin(phi)
        y = np.sin(theta) * np.sin(phi)
        z = np.cos(phi)
        init_pos_np = np.stack((x, y, z), axis=1)

        # Push each unit direction out to a random radius in [inner, outer)
        scale = (num_bodies / 1024) ** (1 / 2)  # Scale factor to maintain a constant density
        inner = 0.9625 * scale
        outer = 1.54 * scale
        radii = inner + (outer - inner) * rng.uniform(size=(self.num_bodies, 1))
        init_pos_np = init_pos_np * radii

        # Initial velocities are perpendicular to both the position and the z axis
        axis = np.array([0.0, 0.0, 1.0])
        v_scale = scale * 3.08
        init_vel_np = v_scale * np.cross(init_pos_np, axis)

        self.graph_scale = np.max(radii) * 5.0
        self.pos_array_0 = wp.array(init_pos_np, dtype=wp.vec3)
        self.pos_array_1 = wp.empty_like(self.pos_array_0)
        self.vel_array = wp.array(init_vel_np, dtype=wp.vec3)

        # Always define both plotting attributes so that headless instances
        # expose the same attribute set as interactive ones.
        self.fig = None
        if headless:
            self.scatter_plot = None
        else:
            self.scatter_plot = self.create_plot()

    def create_plot(self):
        """Create the matplotlib figure and return the 3D scatter artist.

        The figure is stored on ``self.fig``. matplotlib is imported lazily so
        that headless runs do not need it.
        """
        import matplotlib.pyplot as plt

        # Create a figure and a 3D axis for the plot
        self.fig = plt.figure()
        ax = self.fig.add_subplot(111, projection="3d")

        # Scatter plot of initial positions
        point_size = 0.05 * self.graph_scale
        init_pos_np = self.pos_array_0.numpy()
        scatter_plot = ax.scatter(
            init_pos_np[:, 0], init_pos_np[:, 1], init_pos_np[:, 2], s=point_size, c="#76b900", alpha=0.5
        )

        # Set axis limits
        ax.set_xlim(-self.graph_scale, self.graph_scale)
        ax.set_ylim(-self.graph_scale, self.graph_scale)
        ax.set_zlim(-self.graph_scale, self.graph_scale)

        return scatter_plot

    def step(self):
        """Advance the simulation by one timestep and swap the position buffers."""
        wp.launch(
            integrate_bodies_tiled,
            dim=self.num_bodies,
            inputs=[self.pos_array_0, self.vel_array, self.pos_array_1, self.num_bodies],
            block_dim=TILE_SIZE,
        )

        # Swap arrays so that pos_array_0 always holds the latest positions
        (self.pos_array_0, self.pos_array_1) = (self.pos_array_1, self.pos_array_0)

    def render(self):
        """Copy the current positions into the scatter plot; a no-op when headless."""
        if self.scatter_plot is None:
            # Headless instances have nothing to draw into.
            return

        positions_cpu = self.pos_array_0.numpy()

        # Update scatter plot positions
        self.scatter_plot._offsets3d = (
            positions_cpu[:, 0],
            positions_cpu[:, 1],
            positions_cpu[:, 2],
        )

    # Function to update the scatter plot
    def step_and_render(self, frame):
        """Animation callback: run one step, then redraw. ``frame`` is unused."""
        self.step()
        self.render()
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
if __name__ == "__main__":
    # Command-line entry point: parse options, then either animate the
    # simulation with matplotlib or run it headless for a fixed number of frames.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--device", type=str, default=None, help="Override the default Warp device.")
    parser.add_argument("--num_frames", type=int, default=1000, help="Total number of frames.")
    parser.add_argument("-N", help="Number of bodies. Should be a multiple of 64.", type=int, default=16384)
    parser.add_argument(
        "--headless",
        action="store_true",
        help="Run in headless mode, suppressing the opening of any graphical windows.",
    )

    # parse_known_args() ignores options this script does not define.
    args = parser.parse_known_args()[0]

    # Resolve the requested (or default) device and check what it actually is.
    # Comparing the raw string against "cpu" misses the case where no device
    # was given and the default device is the CPU.
    if not wp.get_device(args.device).is_cuda:
        print("This example only runs on CUDA devices.")
        # Same exit status (0) as before, without relying on the site-provided exit() helper.
        raise SystemExit(0)

    with wp.ScopedDevice(args.device):
        example = Example(headless=args.headless, num_bodies=args.N)

        if not args.headless:
            import matplotlib.pyplot as plt
            from matplotlib.animation import FuncAnimation

            # Create the animation. The reference must be kept in a variable:
            # matplotlib stops the animation if the object is garbage collected.
            ani = FuncAnimation(example.fig, example.step_and_render, frames=args.num_frames, interval=50, repeat=False)

            # Display the animation
            plt.show()

        else:
            for _ in range(args.num_frames):
                example.step()
|
|
@@ -0,0 +1,327 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
###########################################################################
|
|
17
|
+
# Example Tile Walker
|
|
18
|
+
#
|
|
19
|
+
# Trains a tetrahedral mesh quadruped to run. Feeds 8 time-varying input
|
|
20
|
+
# phases as inputs into a single layer fully connected network with a tanh
|
|
21
|
+
# activation function. Interprets the output of the network as tet
|
|
22
|
+
# activations, which are fed into the wp.sim soft mesh model. This is
|
|
23
|
+
# simulated forward in time and then evaluated based on the center of mass
|
|
24
|
+
# momentum of the mesh.
|
|
25
|
+
#
|
|
26
|
+
# This example uses the Warp tile API, which as of Warp 1.6 is the
|
|
27
|
+
# recommended way to handle matrix multiplication. example_walker.py in
|
|
28
|
+
# examples/optim demonstrates the old way of doing matrix multiplication,
|
|
29
|
+
# wp.matmul(), which will be deprecated in a future version.
|
|
30
|
+
#
|
|
31
|
+
###########################################################################
|
|
32
|
+
|
|
33
|
+
import math
|
|
34
|
+
import os
|
|
35
|
+
|
|
36
|
+
import numpy as np
|
|
37
|
+
from pxr import Gf, Usd, UsdGeom
|
|
38
|
+
|
|
39
|
+
import warp as wp
|
|
40
|
+
import warp.examples
|
|
41
|
+
import warp.optim
|
|
42
|
+
import warp.sim
|
|
43
|
+
import warp.sim.render
|
|
44
|
+
|
|
45
|
+
PHASE_COUNT = 8  # number of phase inputs; row count of the phase tile loaded by the network kernel
|
|
46
|
+
PHASE_STEP = wp.constant((2.0 * math.pi) / PHASE_COUNT)  # phase offset between consecutive inputs: one full turn split evenly
|
|
47
|
+
PHASE_FREQ = wp.constant(5.0)  # multiplies sim_time inside the sine in compute_phases
|
|
48
|
+
ACTIVATION_STRENGTH = wp.constant(0.3)  # scale applied to the tanh output, so activations lie within +/-0.3
|
|
49
|
+
|
|
50
|
+
TILE_TETS = wp.constant(8)  # rows of the weight matrix (one per tet) handled by each tile of the network kernel
|
|
51
|
+
TILE_THREADS = 64  # NOTE(review): not used in the visible part of the file; presumably the block_dim for the network launch - confirm
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# Accumulate |x| + |y| - z of every entry of `com` into loss[0].
# Minimizing the result rewards a positive z component and penalizes any
# x or y component, regardless of sign.
@wp.kernel
def loss_kernel(com: wp.array(dtype=wp.vec3), loss: wp.array(dtype=float)):
    tid = wp.tid()
    v = com[tid]

    # sqrt(c * c) is the magnitude of component c
    abs_x = wp.sqrt(v[0] * v[0])
    abs_y = wp.sqrt(v[1] * v[1])

    wp.atomic_add(loss, 0, abs_x + abs_y - v[2])
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# Add velocities[tid] / n to com[0]. When launched over n entries, com[0]
# accumulates the mean of the velocities.
@wp.kernel
def com_kernel(velocities: wp.array(dtype=wp.vec3), n: int, com: wp.array(dtype=wp.vec3)):
    tid = wp.tid()
    share = velocities[tid] / wp.float32(n)
    wp.atomic_add(com, 0, share)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# Fill phases[tid] with a sine wave of the simulation time; entry `tid` is
# shifted by tid * PHASE_STEP relative to entry 0.
@wp.kernel
def compute_phases(phases: wp.array(dtype=float), sim_time: float):
    tid = wp.tid()
    angle = PHASE_FREQ * sim_time + wp.float32(tid) * PHASE_STEP
    phases[tid] = wp.sin(angle)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@wp.func
def tanh(x: float):
    """Return ``tanh(x)`` scaled by ``ACTIVATION_STRENGTH``."""
    squashed = wp.tanh(x)
    return squashed * ACTIVATION_STRENGTH
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# Single fully connected layer evaluated one tile of TILE_TETS output rows at a
# time: tet_activations[rows, 0] = tanh(weights[rows, :] @ phases) with the
# scaled tanh defined above.
@wp.kernel
def network(
    phases: wp.array2d(dtype=float), weights: wp.array2d(dtype=float), tet_activations: wp.array2d(dtype=float)
):
    # output tile index
    i = wp.tid()

    # first weight/output row covered by this tile
    row_offset = i * TILE_TETS

    # GEMM: (TILE_TETS x PHASE_COUNT) weights times the (PHASE_COUNT x 1) phase column
    phase_tile = wp.tile_load(phases, shape=(PHASE_COUNT, 1))
    weight_tile = wp.tile_load(weights, shape=(TILE_TETS, PHASE_COUNT), offset=(row_offset, 0))
    pre_activation = wp.tile_matmul(weight_tile, phase_tile)

    # activation, applied element-wise, then written back to the matching output rows
    activated = wp.tile_map(tanh, pre_activation)
    wp.tile_store(tet_activations, activated, offset=(row_offset, 0))
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class Example:
    """Optimize the weights of a linear network with tanh activation that drives tet activations.

    Each call to :meth:`step` simulates ``num_frames`` frames, accumulates a loss from
    per-frame velocity averages, back-propagates through the recorded tape and applies
    one Adam update to the weights.
    """

    def __init__(self, stage_path="example_tile_walker.usd", verbose=False, num_frames=300):
        """Load the mesh asset, build the model and allocate all per-frame buffers.

        Args:
            stage_path: Output USD path handed to the renderer; a falsy value disables rendering.
            verbose: Enables the per-phase timers and the per-iteration loss printout.
            num_frames: Number of frames simulated in every training iteration.
        """
        self.verbose = verbose

        # timing
        fps = 60
        self.frame_dt = 1.0 / fps
        self.num_frames = num_frames

        self.sim_substeps = 80
        self.sim_dt = self.frame_dt / self.sim_substeps
        self.sim_time = 0.0

        # training
        self.iter = 0
        self.train_rate = 0.025

        self.phase_count = PHASE_COUNT

        self.render_time = 0.0

        # bear
        asset_path = os.path.join(warp.examples.get_asset_directory(), "bear.usd")
        asset_stage = Usd.Stage.Open(asset_path)

        geom = UsdGeom.Mesh(asset_stage.GetPrimAtPath("/root/bear"))
        local_points = geom.GetPointsAttr().Get()

        # bring the mesh points into world space before handing them to the builder
        xform = Gf.Matrix4f(geom.ComputeLocalToWorldTransform(0.0))
        world_points = [xform.Transform(local_points[idx]) for idx in range(len(local_points))]

        self.points = [wp.vec3(point) for point in world_points]
        self.tet_indices = geom.GetPrim().GetAttribute("tetraIndices").Get()

        # sim model
        builder = wp.sim.ModelBuilder()
        builder.add_soft_mesh(
            pos=wp.vec3(0.0, 0.5, 0.0),
            rot=wp.quat_identity(),
            scale=1.0,
            vel=wp.vec3(0.0, 0.0, 0.0),
            vertices=self.points,
            indices=self.tet_indices,
            density=1.0,
            k_mu=2000.0,
            k_lambda=2000.0,
            k_damp=2.0,
            tri_ke=0.0,
            tri_ka=1e-8,
            tri_kd=0.0,
            tri_drag=0.0,
            tri_lift=0.0,
        )

        # finalize model
        self.model = builder.finalize(requires_grad=True)
        self.control = self.model.control()

        self.model.soft_contact_ke = 2.0e3
        self.model.soft_contact_kd = 0.1
        self.model.soft_contact_kf = 10.0
        self.model.soft_contact_mu = 0.7

        particle_radii = wp.zeros(self.model.particle_count, dtype=float)
        particle_radii.fill_(0.05)
        self.model.particle_radius = particle_radii
        self.model.ground = True

        # allocate sim states: one per substep of every frame, plus the initial state
        state_count = self.num_frames * self.sim_substeps + 1
        self.states = [self.model.state(requires_grad=True) for _ in range(state_count)]

        # initialize the integrator.
        self.integrator = wp.sim.SemiImplicitIntegrator()

        # model input: one phase vector per frame
        self.phases = [
            wp.zeros(self.phase_count, dtype=float, requires_grad=True) for _ in range(self.num_frames)
        ]

        # weights matrix for linear network
        rng = np.random.default_rng(42)
        k = 1.0 / self.phase_count
        initial_weights = rng.uniform(-np.sqrt(k), np.sqrt(k), (self.model.tet_count, self.phase_count))
        self.weights = wp.array(initial_weights, dtype=float, requires_grad=True)

        # tanh activation layer array: one activation vector per frame
        self.tet_activations = [
            wp.zeros(self.model.tet_count, dtype=float, requires_grad=True) for _ in range(self.num_frames)
        ]

        # optimization
        self.loss = wp.zeros(1, dtype=float, requires_grad=True)
        self.coms = [wp.zeros(1, dtype=wp.vec3, requires_grad=True) for _ in range(self.num_frames)]
        self.optimizer = warp.optim.Adam([self.weights.flatten()], lr=self.train_rate)

        # rendering
        self.renderer = wp.sim.render.SimRenderer(self.model, stage_path) if stage_path else None

        # capture forward/backward passes
        self.use_cuda_graph = wp.get_device().is_cuda
        if self.use_cuda_graph:
            with wp.ScopedCapture() as capture:
                self._record_forward_backward()
            self.graph = capture.graph

    def _record_forward_backward(self):
        """Record every frame's forward pass on a fresh tape, then back-propagate the loss."""
        self.tape = wp.Tape()
        with self.tape:
            for frame in range(self.num_frames):
                self.forward(frame)
        self.tape.backward(self.loss)

    def forward(self, frame):
        """Run the network, the simulation substeps and the loss accumulation for one frame.

        Args:
            frame: Index of the frame; selects the per-frame buffers and the state range used.
        """
        frame_phases = self.phases[frame]
        frame_activations = self.tet_activations[frame]

        with wp.ScopedTimer("network", active=self.verbose):
            # build sinusoidal input phases
            wp.launch(kernel=compute_phases, dim=self.phase_count, inputs=[frame_phases, self.sim_time])

            # apply linear network with tanh activation
            wp.launch_tiled(
                kernel=network,
                dim=math.ceil(self.model.tet_count / TILE_TETS),
                inputs=[frame_phases.reshape((self.phase_count, 1)), self.weights],
                outputs=[frame_activations.reshape((self.model.tet_count, 1))],
                block_dim=TILE_THREADS,
            )
            self.control.tet_activations = frame_activations

        # index of the state this frame starts from
        first_state = frame * self.sim_substeps

        with wp.ScopedTimer("simulate", active=self.verbose):
            # run simulation loop
            for substep in range(self.sim_substeps):
                state_in = self.states[first_state + substep]
                state_out = self.states[first_state + substep + 1]

                state_in.clear_forces()
                self.integrator.simulate(self.model, state_in, state_out, self.sim_dt, self.control)
                self.sim_time += self.sim_dt

        with wp.ScopedTimer("loss", active=self.verbose):
            # compute center of mass velocity from the state at the end of the frame
            end_state = self.states[(frame + 1) * self.sim_substeps]
            wp.launch(
                com_kernel,
                dim=self.model.particle_count,
                inputs=[end_state.particle_qd, self.model.particle_count, self.coms[frame]],
                outputs=[],
            )
            # compute loss
            wp.launch(loss_kernel, dim=1, inputs=[self.coms[frame], self.loss], outputs=[])

    def step(self):
        """Run one training iteration and reset the buffers for the next one."""
        with wp.ScopedTimer("step"):
            if self.use_cuda_graph:
                wp.capture_launch(self.graph)
            else:
                self._record_forward_backward()

            # optimization
            weight_grad = self.weights.grad.flatten()
            self.optimizer.step([weight_grad])

            loss = self.loss.numpy()
            if self.verbose:
                print(f"Iteration {self.iter}: {loss}")

            # reset sim
            self.sim_time = 0.0
            self.states[0] = self.model.state(requires_grad=True)

            # clear grads and zero arrays for next iteration
            self.tape.zero()
            self.loss.zero_()
            for com in self.coms:
                com.zero_()

            self.iter += 1

    def render(self):
        """Write the state at every frame boundary to the renderer; no-op without a renderer."""
        if self.renderer is None:
            return

        with wp.ScopedTimer("render"):
            for frame in range(self.num_frames + 1):
                frame_state = self.states[frame * self.sim_substeps]

                self.renderer.begin_frame(self.render_time)
                self.renderer.render(frame_state)
                self.renderer.end_frame()

                self.render_time += self.frame_dt
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
if __name__ == "__main__":
    import argparse

    # command-line options; unknown arguments are tolerated and ignored
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--device", type=str, default=None, help="Override the default Warp device.")
    parser.add_argument(
        "--stage_path",
        # the literal string "None" disables USD output
        type=lambda x: None if x == "None" else str(x),
        default="example_tile_walker.usd",
        help="Path to the output USD file.",
    )
    parser.add_argument("--num_frames", type=int, default=300, help="Total number of frames per training iteration.")
    parser.add_argument("--train_iters", type=int, default=30, help="Total number of training iterations.")
    parser.add_argument("--verbose", action="store_true", help="Print out additional status messages during execution.")

    args, _unknown = parser.parse_known_args()

    with wp.ScopedDevice(args.device):
        example = Example(stage_path=args.stage_path, verbose=args.verbose, num_frames=args.num_frames)

        # train, rendering the trajectory after every iteration
        iteration = 0
        while iteration < args.train_iters:
            example.step()
            example.render()
            iteration += 1

        if example.renderer:
            example.renderer.save()
|