warp-lang 1.7.0: warp_lang-1.7.0-py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of warp-lang might be problematic.
- warp/__init__.py +139 -0
- warp/__init__.pyi +1 -0
- warp/autograd.py +1142 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +557 -0
- warp/build_dll.py +405 -0
- warp/builtins.py +6855 -0
- warp/codegen.py +3969 -0
- warp/config.py +158 -0
- warp/constants.py +57 -0
- warp/context.py +6812 -0
- warp/dlpack.py +462 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +194 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +193 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +423 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +178 -0
- warp/examples/fem/example_convection_diffusion_dg.py +204 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +220 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_magnetostatics.py +240 -0
- warp/examples/fem/example_mixed_elasticity.py +291 -0
- warp/examples/fem/example_navier_stokes.py +261 -0
- warp/examples/fem/example_nonconforming_contact.py +298 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +352 -0
- warp/examples/fem/utils.py +1000 -0
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +266 -0
- warp/examples/optim/example_cloth_throw.py +228 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_drone.py +870 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +182 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
- warp/examples/optim/example_softbody_properties.py +400 -0
- warp/examples/optim/example_spring_cage.py +245 -0
- warp/examples/optim/example_trajectory.py +227 -0
- warp/examples/sim/example_cartpole.py +143 -0
- warp/examples/sim/example_cloth.py +225 -0
- warp/examples/sim/example_cloth_self_contact.py +322 -0
- warp/examples/sim/example_granular.py +130 -0
- warp/examples/sim/example_granular_collision_sdf.py +202 -0
- warp/examples/sim/example_jacobian_ik.py +244 -0
- warp/examples/sim/example_particle_chain.py +124 -0
- warp/examples/sim/example_quadruped.py +203 -0
- warp/examples/sim/example_rigid_chain.py +203 -0
- warp/examples/sim/example_rigid_contact.py +195 -0
- warp/examples/sim/example_rigid_force.py +133 -0
- warp/examples/sim/example_rigid_gyroscopic.py +115 -0
- warp/examples/sim/example_rigid_soft_contact.py +140 -0
- warp/examples/sim/example_soft_body.py +196 -0
- warp/examples/tile/example_tile_cholesky.py +87 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mlp.py +383 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/examples/tile/example_tile_walker.py +327 -0
- warp/fabric.py +355 -0
- warp/fem/__init__.py +106 -0
- warp/fem/adaptivity.py +508 -0
- warp/fem/cache.py +572 -0
- warp/fem/dirichlet.py +202 -0
- warp/fem/domain.py +411 -0
- warp/fem/field/__init__.py +125 -0
- warp/fem/field/field.py +619 -0
- warp/fem/field/nodal_field.py +326 -0
- warp/fem/field/restriction.py +37 -0
- warp/fem/field/virtual.py +848 -0
- warp/fem/geometry/__init__.py +32 -0
- warp/fem/geometry/adaptive_nanogrid.py +857 -0
- warp/fem/geometry/closest_point.py +84 -0
- warp/fem/geometry/deformed_geometry.py +221 -0
- warp/fem/geometry/element.py +776 -0
- warp/fem/geometry/geometry.py +362 -0
- warp/fem/geometry/grid_2d.py +392 -0
- warp/fem/geometry/grid_3d.py +452 -0
- warp/fem/geometry/hexmesh.py +911 -0
- warp/fem/geometry/nanogrid.py +571 -0
- warp/fem/geometry/partition.py +389 -0
- warp/fem/geometry/quadmesh.py +663 -0
- warp/fem/geometry/tetmesh.py +855 -0
- warp/fem/geometry/trimesh.py +806 -0
- warp/fem/integrate.py +2335 -0
- warp/fem/linalg.py +419 -0
- warp/fem/operator.py +293 -0
- warp/fem/polynomial.py +229 -0
- warp/fem/quadrature/__init__.py +17 -0
- warp/fem/quadrature/pic_quadrature.py +299 -0
- warp/fem/quadrature/quadrature.py +591 -0
- warp/fem/space/__init__.py +228 -0
- warp/fem/space/basis_function_space.py +468 -0
- warp/fem/space/basis_space.py +667 -0
- warp/fem/space/dof_mapper.py +251 -0
- warp/fem/space/function_space.py +309 -0
- warp/fem/space/grid_2d_function_space.py +177 -0
- warp/fem/space/grid_3d_function_space.py +227 -0
- warp/fem/space/hexmesh_function_space.py +257 -0
- warp/fem/space/nanogrid_function_space.py +201 -0
- warp/fem/space/partition.py +367 -0
- warp/fem/space/quadmesh_function_space.py +223 -0
- warp/fem/space/restriction.py +179 -0
- warp/fem/space/shape/__init__.py +143 -0
- warp/fem/space/shape/cube_shape_function.py +1105 -0
- warp/fem/space/shape/shape_function.py +133 -0
- warp/fem/space/shape/square_shape_function.py +926 -0
- warp/fem/space/shape/tet_shape_function.py +834 -0
- warp/fem/space/shape/triangle_shape_function.py +672 -0
- warp/fem/space/tetmesh_function_space.py +271 -0
- warp/fem/space/topology.py +424 -0
- warp/fem/space/trimesh_function_space.py +194 -0
- warp/fem/types.py +99 -0
- warp/fem/utils.py +420 -0
- warp/jax.py +187 -0
- warp/jax_experimental/__init__.py +16 -0
- warp/jax_experimental/custom_call.py +351 -0
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +244 -0
- warp/native/array.h +1145 -0
- warp/native/builtin.h +1800 -0
- warp/native/bvh.cpp +492 -0
- warp/native/bvh.cu +791 -0
- warp/native/bvh.h +554 -0
- warp/native/clang/clang.cpp +536 -0
- warp/native/coloring.cpp +613 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +362 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +646 -0
- warp/native/cuda_util.h +307 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +1878 -0
- warp/native/fabric.h +245 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +87 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1230 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +339 -0
- warp/native/marching.cpp +19 -0
- warp/native/marching.cu +514 -0
- warp/native/marching.h +19 -0
- warp/native/mat.h +2220 -0
- warp/native/mathdx.cpp +87 -0
- warp/native/matnn.h +343 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +404 -0
- warp/native/mesh.h +1980 -0
- warp/native/nanovdb/GridHandle.h +366 -0
- warp/native/nanovdb/HostBuffer.h +590 -0
- warp/native/nanovdb/NanoVDB.h +6624 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1371 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +139 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +364 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +53 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +277 -0
- warp/native/sort.h +33 -0
- warp/native/sparse.cpp +378 -0
- warp/native/sparse.cu +524 -0
- warp/native/spatial.h +657 -0
- warp/native/svd.h +702 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +2584 -0
- warp/native/tile_reduce.h +264 -0
- warp/native/vec.h +1426 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +67 -0
- warp/native/volume.h +969 -0
- warp/native/volume_builder.cu +477 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1082 -0
- warp/native/warp.cu +3636 -0
- warp/native/warp.h +381 -0
- warp/optim/__init__.py +17 -0
- warp/optim/adam.py +163 -0
- warp/optim/linear.py +1137 -0
- warp/optim/sgd.py +112 -0
- warp/paddle.py +407 -0
- warp/render/__init__.py +18 -0
- warp/render/render_opengl.py +3518 -0
- warp/render/render_usd.py +784 -0
- warp/render/utils.py +160 -0
- warp/sim/__init__.py +65 -0
- warp/sim/articulation.py +793 -0
- warp/sim/collide.py +2395 -0
- warp/sim/graph_coloring.py +300 -0
- warp/sim/import_mjcf.py +790 -0
- warp/sim/import_snu.py +227 -0
- warp/sim/import_urdf.py +579 -0
- warp/sim/import_usd.py +894 -0
- warp/sim/inertia.py +324 -0
- warp/sim/integrator.py +242 -0
- warp/sim/integrator_euler.py +1997 -0
- warp/sim/integrator_featherstone.py +2101 -0
- warp/sim/integrator_vbd.py +2048 -0
- warp/sim/integrator_xpbd.py +3292 -0
- warp/sim/model.py +4791 -0
- warp/sim/particles.py +121 -0
- warp/sim/render.py +427 -0
- warp/sim/utils.py +428 -0
- warp/sparse.py +2057 -0
- warp/stubs.py +3333 -0
- warp/tape.py +1203 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +634 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +200 -0
- warp/tests/geometry/test_hash_grid.py +221 -0
- warp/tests/geometry/test_marching_cubes.py +74 -0
- warp/tests/geometry/test_mesh.py +316 -0
- warp/tests/geometry/test_mesh_query_aabb.py +399 -0
- warp/tests/geometry/test_mesh_query_point.py +932 -0
- warp/tests/geometry/test_mesh_query_ray.py +311 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +729 -0
- warp/tests/interop/test_jax.py +371 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/sim/__init__.py +0 -0
- warp/tests/sim/disabled_kinematics.py +244 -0
- warp/tests/sim/flaky_test_sim_grad.py +290 -0
- warp/tests/sim/test_collision.py +604 -0
- warp/tests/sim/test_coloring.py +258 -0
- warp/tests/sim/test_model.py +224 -0
- warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
- warp/tests/sim/test_sim_kinematics.py +98 -0
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +163 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +2972 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +250 -0
- warp/tests/test_atomic.py +153 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +1298 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +810 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +252 -0
- warp/tests/test_context.py +42 -0
- warp/tests/test_copy.py +238 -0
- warp/tests/test_ctypes.py +638 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +97 -0
- warp/tests/test_examples.py +482 -0
- warp/tests/test_fabricarray.py +996 -0
- warp/tests/test_fast_math.py +74 -0
- warp/tests/test_fem.py +2003 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +454 -0
- warp/tests/test_future_annotations.py +98 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +73 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +193 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_mat.py +2089 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +178 -0
- warp/tests/test_mlp.py +282 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +44 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +339 -0
- warp/tests/test_quat.py +2315 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +302 -0
- warp/tests/test_rounding.py +185 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +105 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +582 -0
- warp/tests/test_spatial.py +2229 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +592 -0
- warp/tests/test_struct.py +734 -0
- warp/tests/test_tape.py +204 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +145 -0
- warp/tests/test_types.py +562 -0
- warp/tests/test_utils.py +588 -0
- warp/tests/test_vec.py +1487 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +780 -0
- warp/tests/tile/test_tile_load.py +407 -0
- warp/tests/tile/test_tile_mathdx.py +208 -0
- warp/tests/tile/test_tile_mlp.py +402 -0
- warp/tests/tile/test_tile_reduce.py +447 -0
- warp/tests/tile/test_tile_shared_memory.py +247 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +427 -0
- warp/tests/unittest_utils.py +468 -0
- warp/tests/walkthrough_debug.py +93 -0
- warp/thirdparty/__init__.py +0 -0
- warp/thirdparty/appdirs.py +598 -0
- warp/thirdparty/dlpack.py +145 -0
- warp/thirdparty/unittest_parallel.py +570 -0
- warp/torch.py +391 -0
- warp/types.py +5230 -0
- warp/utils.py +1137 -0
- warp_lang-1.7.0.dist-info/METADATA +516 -0
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- warp_lang-1.7.0.dist-info/WHEEL +5 -0
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,120 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import taichi as ti
+
+
+@ti.func
+def step(x):
+    ret = 0.0
+    if x < 0:
+        ret = 1
+    return ret
+
+
+@ti.data_oriented
+class TiIntegrator:
+    @ti.kernel
+    def eval_springs(self):
+        for tid in range(self.cloth.num_springs):
+            i = self.spring_indices[2 * tid]
+            j = self.spring_indices[2 * tid + 1]
+
+            ke = self.spring_stiffness[tid]
+            kd = self.spring_damping[tid]
+            rest = self.spring_lengths[tid]
+
+            xi = self.positions[i]
+            xj = self.positions[j]
+
+            vi = self.velocities[i]
+            vj = self.velocities[j]
+
+            xij = xi - xj
+            vij = vi - vj
+
+            l = xij.norm()
+            dir = xij.normalized()
+
+            c = l - rest
+            dcdt = dir.dot(vij)
+
+            fs = dir * (ke * c + kd * dcdt)
+
+            self.forces[i] -= fs
+            self.forces[j] += fs
+
+    @ti.kernel
+    def integrate_particles(self, dt: ti.f32):
+        for tid in range(self.cloth.num_particles):
+            x0 = self.positions[tid]
+            v0 = self.velocities[tid]
+            f0 = self.forces[tid]
+            w = self.inv_mass[tid]
+
+            g = ti.Vector([0.0, 0.0, 0.0])
+
+            if w > 0.0:
+                g = ti.Vector([0.0, -9.81, 0.0])
+
+            v1 = v0 + (f0 * w + g) * dt
+            x1 = x0 + v1 * dt
+
+            self.positions[tid] = x1
+            self.velocities[tid] = v1
+            self.forces[tid] = ti.Vector([0.0, 0.0, 0.0])
+
+    def __init__(self, cloth, device):
+        if device == "cpu":
+            ti.init(arch=ti.cpu)
+        elif device == "cuda":
+            ti.init(arch=ti.gpu)
+        else:
+            raise RuntimeError("Unsupported Taichi device")
+
+        self.cloth = cloth
+
+        self.positions = ti.Vector.field(3, dtype=ti.f32, shape=self.cloth.num_particles)
+        self.velocities = ti.Vector.field(3, dtype=ti.f32, shape=self.cloth.num_particles)
+        self.inv_mass = ti.field(ti.f32, shape=self.cloth.num_particles)
+
+        self.spring_indices = ti.field(ti.i32, shape=self.cloth.num_springs * 2)
+        self.spring_lengths = ti.field(ti.f32, shape=self.cloth.num_springs)
+        self.spring_stiffness = ti.field(ti.f32, shape=self.cloth.num_springs)
+        self.spring_damping = ti.field(ti.f32, shape=self.cloth.num_springs)
+
+        self.forces = ti.Vector.field(3, dtype=ti.f32, shape=self.cloth.num_particles)
+
+        # upload data
+        self.positions.from_numpy(cloth.positions)
+        self.velocities.from_numpy(cloth.velocities)
+        self.inv_mass.from_numpy(cloth.inv_masses)
+        self.forces.from_numpy(np.zeros_like(self.cloth.velocities))
+
+        self.spring_indices.from_numpy(cloth.spring_indices)
+        self.spring_lengths.from_numpy(cloth.spring_lengths)
+        self.spring_stiffness.from_numpy(cloth.spring_stiffness)
+        self.spring_damping.from_numpy(cloth.spring_damping)
+
+    def simulate(self, dt, substeps):
+        sim_dt = dt / substeps
+
+        for _s in range(substeps):
+            self.eval_springs()
+
+            self.integrate_particles(sim_dt)
+
+        return self.positions.to_numpy()
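By its +120 line count this hunk corresponds to warp/examples/benchmarks/benchmark_cloth_taichi.py from the file list: the Taichi port of the cloth spring benchmark that the other benchmark_cloth_*.py backends also implement. Each backend exposes the same two-call interface: construct an integrator from a cloth description plus a device string, then call simulate(dt, substeps) once per frame. The driver sketch below is hypothetical; the Cloth stand-in, its attribute names, and the import path are assumptions inferred from the fields the class reads, not the real benchmark_cloth.py harness shipped in the package.

# Hypothetical driver; the real harness lives in benchmark_cloth.py and builds a
# full cloth grid. All names introduced here are assumptions.
import numpy as np

from benchmark_cloth_taichi import TiIntegrator  # assumes the benchmarks dir is on sys.path


class Cloth:
    """Minimal stand-in exposing the attributes the integrator expects."""

    def __init__(self, positions, spring_indices, spring_lengths, spring_stiffness, spring_damping, inv_masses):
        self.positions = np.asarray(positions, dtype=np.float32)
        self.velocities = np.zeros_like(self.positions)
        self.inv_masses = np.asarray(inv_masses, dtype=np.float32)
        self.spring_indices = np.asarray(spring_indices, dtype=np.int32)
        self.spring_lengths = np.asarray(spring_lengths, dtype=np.float32)
        self.spring_stiffness = np.asarray(spring_stiffness, dtype=np.float32)
        self.spring_damping = np.asarray(spring_damping, dtype=np.float32)
        self.num_particles = len(self.positions)
        self.num_springs = len(self.spring_lengths)


# two particles joined by one spring; the top particle is pinned (inv_mass == 0)
cloth = Cloth(
    positions=[[0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
    spring_indices=[0, 1],
    spring_lengths=[1.0],
    spring_stiffness=[1000.0],
    spring_damping=[1.0],
    inv_masses=[0.0, 1.0],
)

integrator = TiIntegrator(cloth, "cpu")
for _frame in range(60):
    positions = integrator.simulate(1.0 / 60.0, 16)  # 60 Hz frame, 16 substeps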
@@ -0,0 +1,153 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warp as wp
+
+wp.clear_kernel_cache()
+
+
+@wp.kernel
+def eval_springs(
+    x: wp.array(dtype=wp.vec3),
+    v: wp.array(dtype=wp.vec3),
+    spring_indices: wp.array(dtype=int),
+    spring_rest_lengths: wp.array(dtype=float),
+    spring_stiffness: wp.array(dtype=float),
+    spring_damping: wp.array(dtype=float),
+    f: wp.array(dtype=wp.vec3),
+):
+    tid = wp.tid()
+
+    i = spring_indices[tid * 2 + 0]
+    j = spring_indices[tid * 2 + 1]
+
+    ke = spring_stiffness[tid]
+    kd = spring_damping[tid]
+    rest = spring_rest_lengths[tid]
+
+    xi = x[i]
+    xj = x[j]
+
+    vi = v[i]
+    vj = v[j]
+
+    xij = xi - xj
+    vij = vi - vj
+
+    l = wp.length(xij)
+    l_inv = 1.0 / l
+
+    # normalized spring direction
+    dir = xij * l_inv
+
+    c = l - rest
+    dcdt = wp.dot(dir, vij)
+
+    # damping based on relative velocity.
+    fs = dir * (ke * c + kd * dcdt)
+
+    wp.atomic_sub(f, i, fs)
+    wp.atomic_add(f, j, fs)
+
+
+@wp.kernel
+def integrate_particles(
+    x: wp.array(dtype=wp.vec3),
+    v: wp.array(dtype=wp.vec3),
+    f: wp.array(dtype=wp.vec3),
+    w: wp.array(dtype=float),
+    dt: float,
+):
+    tid = wp.tid()
+
+    x0 = x[tid]
+    v0 = v[tid]
+    f0 = f[tid]
+    inv_mass = w[tid]
+
+    g = wp.vec3()
+
+    # treat particles with inv_mass == 0 as kinematic
+    if inv_mass > 0.0:
+        g = wp.vec3(0.0, 0.0 - 9.81, 0.0)
+
+    # simple semi-implicit Euler. v1 = v0 + a dt, x1 = x0 + v1 dt
+    v1 = v0 + (f0 * inv_mass + g) * dt
+    x1 = x0 + v1 * dt
+
+    x[tid] = x1
+    v[tid] = v1
+
+    # clear forces
+    f[tid] = wp.vec3()
+
+
+class WpIntegrator:
+    def __init__(self, cloth, device):
+        self.device = wp.get_device(device)
+
+        with wp.ScopedDevice(self.device):
+            self.positions = wp.from_numpy(cloth.positions, dtype=wp.vec3)
+            self.positions_host = wp.from_numpy(cloth.positions, dtype=wp.vec3, device="cpu")
+            self.invmass = wp.from_numpy(cloth.inv_masses, dtype=float)
+
+            self.velocities = wp.zeros(cloth.num_particles, dtype=wp.vec3)
+            self.forces = wp.zeros(cloth.num_particles, dtype=wp.vec3)
+
+            self.spring_indices = wp.from_numpy(cloth.spring_indices, dtype=int)
+            self.spring_lengths = wp.from_numpy(cloth.spring_lengths, dtype=float)
+            self.spring_stiffness = wp.from_numpy(cloth.spring_stiffness, dtype=float)
+            self.spring_damping = wp.from_numpy(cloth.spring_damping, dtype=float)
+
+        self.cloth = cloth
+
+    def simulate(self, dt, substeps):
+        sim_dt = dt / substeps
+
+        for _s in range(substeps):
+            wp.launch(
+                kernel=eval_springs,
+                dim=self.cloth.num_springs,
+                inputs=[
+                    self.positions,
+                    self.velocities,
+                    self.spring_indices,
+                    self.spring_lengths,
+                    self.spring_stiffness,
+                    self.spring_damping,
+                    self.forces,
+                ],
+                outputs=[],
+                device=self.device,
+            )
+
+            # integrate
+            wp.launch(
+                kernel=integrate_particles,
+                dim=self.cloth.num_particles,
+                inputs=[self.positions, self.velocities, self.forces, self.invmass, sim_dt],
+                outputs=[],
+                device=self.device,
+            )
+
+        # copy data back to host
+        if self.device.is_cuda:
+            wp.copy(self.positions_host, self.positions)
+            wp.synchronize()
+
+            return self.positions_host.numpy()
+
+        else:
+            return self.positions.numpy()
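This +153 line hunk is warp/examples/benchmarks/benchmark_cloth_warp.py, the Warp backend of the same benchmark: two @wp.kernel functions launched once per substep, with a separate CPU-side array kept so positions can be read back from CUDA devices. A short usage sketch follows, reusing the hypothetical Cloth stand-in from the previous note; the import path is again an assumption, not part of the shipped package.

# Hypothetical usage, mirroring the Taichi sketch above; `cloth` is the same stand-in.
import warp as wp

from benchmark_cloth_warp import WpIntegrator  # assumed import path

device = "cuda:0" if wp.get_cuda_device_count() > 0 else "cpu"
integrator = WpIntegrator(cloth, device)

# simulate() launches eval_springs and integrate_particles once per substep; on
# CUDA devices it copies positions into the CPU mirror array and synchronizes
# before returning a NumPy array.
positions = integrator.simulate(1.0 / 60.0, 16)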
@@ -0,0 +1,164 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Compare GEMM performance between Torch and Warp (Tiled).
+
+This script can be used to identify optimal tile parameters for a fixed-size
+matrix multiplication.
+"""
+
+from itertools import product
+from statistics import mean, stdev
+from typing import List
+
+import numpy as np
+import torch
+
+import warp as wp
+
+
+# returns a kernel to compute a GEMM given m,n,k tile sizes
+def create_gemm_kernel(m, n, k):
+    TILE_M = m
+    TILE_N = n
+    TILE_K = k
+
+    @wp.kernel
+    def gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
+        i, j = wp.tid()
+        sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32)
+
+        count = A.shape[1] // TILE_K
+
+        for k in range(count):
+            a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
+            b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
+
+            wp.tile_matmul(a, b, sum)
+
+        wp.tile_store(output, sum, offset=(i * TILE_M, j * TILE_N))
+
+    return gemm
+
+
+def benchmark_torch(A: torch.Tensor, B: torch.Tensor, warm_up: int, iterations: int):
+    # warm-up
+    for _ in range(warm_up):
+        torch.matmul(A, B)
+
+    torch.cuda.synchronize()
+
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
+    timing_results = []
+
+    for _i in range(iterations):
+        start_event.record()
+        torch.matmul(A, B)
+        end_event.record()
+
+        torch.cuda.synchronize()
+        timing_results.append(start_event.elapsed_time(end_event))
+
+    return mean(timing_results), stdev(timing_results)
+
+
+def benchmark_warp(A: wp.array, B: wp.array, config: List[int], warm_up: int, iterations: int):
+    TILE_M = config[0]
+    TILE_N = config[1]
+    TILE_K = config[2]
+    BLOCK_DIM = config[3]
+
+    mlp = create_gemm_kernel(TILE_M, TILE_N, TILE_K)
+
+    M = A.shape[0]
+    N = B.shape[1]
+
+    output = wp.zeros((M, N), dtype=float)
+
+    # create launch command
+    cmd = wp.launch_tiled(
+        kernel=mlp,
+        dim=[M // TILE_M, N // TILE_N],
+        inputs=[A, B, output],
+        block_dim=BLOCK_DIM,
+        record_cmd=True,
+    )
+
+    # warm-up
+    for _ in range(warm_up):
+        cmd.launch()
+
+    # check output
+    if warm_up > 0:
+        try:
+            np.testing.assert_allclose(output.numpy(), A.numpy() @ B.numpy(), atol=1e-3, rtol=1e-3)
+        except AssertionError as e:
+            print(f"Failed with {TILE_M=}, {TILE_N=}, {TILE_K=}, {BLOCK_DIM=}")
+            raise e
+
+    # benchmark
+    with wp.ScopedTimer("warp", print=False, synchronize=True, cuda_filter=wp.TIMING_KERNEL) as timer:
+        for _ in range(iterations):
+            cmd.launch()
+
+    timing_results = [result.elapsed for result in timer.timing_results]
+
+    return mean(timing_results), stdev(timing_results)
+
+
+if __name__ == "__main__":
+    torch.backends.cuda.matmul.allow_tf32 = False  # Disable TF32 for matrix multiplications
+    torch.backends.cudnn.allow_tf32 = False  # Disable TF32 for cuDNN operations
+
+    wp.init()
+    wp.clear_kernel_cache()
+    wp.set_module_options({"fast_math": True, "enable_backward": False})
+
+    tile_m = [8, 16, 32, 64]
+    tile_n = [8, 16, 32, 64]
+    tile_k = [8, 16, 64]
+    block = [32, 64, 128]
+
+    M = 1024
+    N = 1024
+    K = 1024
+    print(f"{M=}, {N=}, {K=}")
+
+    A = torch.randn(M, K).cuda()
+    B = torch.randn(K, N).cuda()
+
+    iterations = 100
+    warm_up = 5
+
+    time_torch_mean, time_torch_std = benchmark_torch(A, B, warm_up, iterations)
+    print(f"Torch: {time_torch_mean:.6g}±{time_torch_std:.2g} ms")
+
+    configs = list(product(tile_m, tile_n, tile_k, block))
+
+    wp.config.quiet = True
+
+    # header
+    print(
+        f"{'TILE_M':<8s} {'TILE_N':<8s} {'TILE_K':<8s} {'BLOCK':<8s} {'Time (ms)':<10s} {'Std dev (ms)':<14s} {'Warp/Torch':<12s}"
+    )
+    print("-" * 79)
+
+    for c in configs:
+        time_mean, time_std = benchmark_warp(wp.from_torch(A), wp.from_torch(B), c, warm_up, iterations)
+        print(
+            f"{c[0]:<8d} {c[1]:<8d} {c[2]:<8d} {c[3]:<8d} {time_mean:<10.6g} {time_std:<#14.2g} {time_mean / time_torch_mean:<12.6g}"
+        )
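This +164 line hunk is warp/examples/benchmarks/benchmark_gemm.py, which sweeps (TILE_M, TILE_N, TILE_K, BLOCK_DIM) combinations for a fixed 1024x1024x1024 GEMM and prints each configuration's kernel time relative to a Torch baseline. When comparing rows it is sometimes handier to look at throughput than raw time; the helper below is a hypothetical addition, not part of the script, that converts a mean time in milliseconds to TFLOP/s using the usual 2*M*N*K operation count.

def gemm_tflops(m: int, n: int, k: int, time_ms: float) -> float:
    # a dense GEMM performs 2*M*N*K floating-point operations
    # (one multiply and one add per inner-product term)
    return (2.0 * m * n * k) / (time_ms * 1e-3) / 1e12


# example: a 1024x1024x1024 GEMM completing in 0.25 ms sustains about 8.6 TFLOP/s
print(f"{gemm_tflops(1024, 1024, 1024, 0.25):.3g} TFLOP/s")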
@@ -0,0 +1,166 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+
+import paddle
+
+import warp as wp
+
+
+def create_simple_kernel(dtype):
+    def simple_kernel(
+        a: wp.array(dtype=dtype),
+        b: wp.array(dtype=dtype),
+        c: wp.array(dtype=dtype),
+        d: wp.array(dtype=dtype),
+        e: wp.array(dtype=dtype),
+    ):
+        pass
+
+    return wp.Kernel(simple_kernel)
+
+
+def test_from_paddle(kernel, num_iters, array_size, device, warp_dtype=None):
+    warp_device = wp.get_device(device)
+    paddle_device = wp.device_to_paddle(warp_device)
+
+    if hasattr(warp_dtype, "_shape_"):
+        paddle_shape = (array_size, *warp_dtype._shape_)
+        paddle_dtype = wp.dtype_to_paddle(warp_dtype._wp_scalar_type_)
+    else:
+        paddle_shape = (array_size,)
+        paddle_dtype = paddle.float32 if warp_dtype is None else wp.dtype_to_paddle(warp_dtype)
+
+    _a = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
+    _b = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
+    _c = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
+    _d = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
+    _e = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
+
+    wp.synchronize()
+
+    # profiler = Profiler(interval=0.000001)
+    # profiler.start()
+
+    t1 = time.time_ns()
+
+    for _ in range(num_iters):
+        a = wp.from_paddle(_a, dtype=warp_dtype)
+        b = wp.from_paddle(_b, dtype=warp_dtype)
+        c = wp.from_paddle(_c, dtype=warp_dtype)
+        d = wp.from_paddle(_d, dtype=warp_dtype)
+        e = wp.from_paddle(_e, dtype=warp_dtype)
+        wp.launch(kernel, dim=array_size, inputs=[a, b, c, d, e])
+
+    t2 = time.time_ns()
+    print(f"{(t2 - t1) / 1_000_000:8.0f} ms from_paddle(...)")
+
+    # profiler.stop()
+    # profiler.print()
+
+
+def test_array_ctype_from_paddle(kernel, num_iters, array_size, device, warp_dtype=None):
+    warp_device = wp.get_device(device)
+    paddle_device = wp.device_to_paddle(warp_device)
+
+    if hasattr(warp_dtype, "_shape_"):
+        paddle_shape = (array_size, *warp_dtype._shape_)
+        paddle_dtype = wp.dtype_to_paddle(warp_dtype._wp_scalar_type_)
+    else:
+        paddle_shape = (array_size,)
+        paddle_dtype = paddle.float32 if warp_dtype is None else wp.dtype_to_paddle(warp_dtype)
+
+    _a = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
+    _b = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
+    _c = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
+    _d = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
+    _e = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
+
+    wp.synchronize()
+
+    # profiler = Profiler(interval=0.000001)
+    # profiler.start()
+
+    t1 = time.time_ns()
+
+    for _ in range(num_iters):
+        a = wp.from_paddle(_a, dtype=warp_dtype, return_ctype=True)
+        b = wp.from_paddle(_b, dtype=warp_dtype, return_ctype=True)
+        c = wp.from_paddle(_c, dtype=warp_dtype, return_ctype=True)
+        d = wp.from_paddle(_d, dtype=warp_dtype, return_ctype=True)
+        e = wp.from_paddle(_e, dtype=warp_dtype, return_ctype=True)
+        wp.launch(kernel, dim=array_size, inputs=[a, b, c, d, e])
+
+    t2 = time.time_ns()
+    print(f"{(t2 - t1) / 1_000_000:8.0f} ms from_paddle(..., return_ctype=True)")
+
+    # profiler.stop()
+    # profiler.print()
+
+
+def test_direct_from_paddle(kernel, num_iters, array_size, device, warp_dtype=None):
+    warp_device = wp.get_device(device)
+    paddle_device = wp.device_to_paddle(warp_device)
+
+    if hasattr(warp_dtype, "_shape_"):
+        paddle_shape = (array_size, *warp_dtype._shape_)
+        paddle_dtype = wp.dtype_to_paddle(warp_dtype._wp_scalar_type_)
+    else:
+        paddle_shape = (array_size,)
+        paddle_dtype = paddle.float32 if warp_dtype is None else wp.dtype_to_paddle(warp_dtype)
+
+    _a = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
+    _b = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
+    _c = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
+    _d = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
+    _e = paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device)
+
+    wp.synchronize()
+
+    # profiler = Profiler(interval=0.000001)
+    # profiler.start()
+
+    t1 = time.time_ns()
+
+    for _ in range(num_iters):
+        wp.launch(kernel, dim=array_size, inputs=[_a, _b, _c, _d, _e])
+
+    t2 = time.time_ns()
+    print(f"{(t2 - t1) / 1_000_000:8.0f} ms direct from paddle")
+
+    # profiler.stop()
+    # profiler.print()
+
+
+wp.init()
+
+params = [
+    # (warp_dtype arg, kernel)
+    (None, create_simple_kernel(wp.float32)),
+    (wp.float32, create_simple_kernel(wp.float32)),
+    (wp.vec3f, create_simple_kernel(wp.vec3f)),
+    (wp.mat22f, create_simple_kernel(wp.mat22f)),
+]

+wp.load_module()
+
+num_iters = 100000
+
+for warp_dtype, kernel in params:
+    print(f"\ndtype={wp.context.type_str(warp_dtype)}")
+    test_from_paddle(kernel, num_iters, 10, "cuda:0", warp_dtype=warp_dtype)
+    test_array_ctype_from_paddle(kernel, num_iters, 10, "cuda:0", warp_dtype=warp_dtype)
+    test_direct_from_paddle(kernel, num_iters, 10, "cuda:0", warp_dtype=warp_dtype)
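This final +166 line hunk is warp/examples/benchmarks/benchmark_interop_paddle.py, which times three ways of feeding Paddle tensors to wp.launch: converting with wp.from_paddle on every iteration, converting with return_ctype=True to skip wp.array construction, and passing the Paddle tensors directly. A fourth variant, sketched below as a hypothetical addition (not part of the shipped file), converts once outside the timed loop and reuses the resulting Warp arrays, which isolates the wp.launch cost itself; it follows the same setup and timing pattern as test_from_paddle.

import time

import paddle

import warp as wp


def test_reuse_arrays_from_paddle(kernel, num_iters, array_size, device, warp_dtype=None):
    warp_device = wp.get_device(device)
    paddle_device = wp.device_to_paddle(warp_device)

    if hasattr(warp_dtype, "_shape_"):
        paddle_shape = (array_size, *warp_dtype._shape_)
        paddle_dtype = wp.dtype_to_paddle(warp_dtype._wp_scalar_type_)
    else:
        paddle_shape = (array_size,)
        paddle_dtype = paddle.float32 if warp_dtype is None else wp.dtype_to_paddle(warp_dtype)

    tensors = [paddle.zeros(paddle_shape, dtype=paddle_dtype).to(device=paddle_device) for _ in range(5)]

    # convert once, outside the timed loop, and reuse the resulting wp.array objects
    arrays = [wp.from_paddle(t, dtype=warp_dtype) for t in tensors]

    wp.synchronize()

    t1 = time.time_ns()

    for _ in range(num_iters):
        wp.launch(kernel, dim=array_size, inputs=arrays)

    t2 = time.time_ns()
    print(f"{(t2 - t1) / 1_000_000:8.0f} ms convert once, reuse wp.array")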