warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +139 -0
- warp/__init__.pyi +1 -0
- warp/autograd.py +1142 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +557 -0
- warp/build_dll.py +405 -0
- warp/builtins.py +6855 -0
- warp/codegen.py +3969 -0
- warp/config.py +158 -0
- warp/constants.py +57 -0
- warp/context.py +6812 -0
- warp/dlpack.py +462 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +194 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +193 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +423 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +178 -0
- warp/examples/fem/example_convection_diffusion_dg.py +204 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +220 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_magnetostatics.py +240 -0
- warp/examples/fem/example_mixed_elasticity.py +291 -0
- warp/examples/fem/example_navier_stokes.py +261 -0
- warp/examples/fem/example_nonconforming_contact.py +298 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +352 -0
- warp/examples/fem/utils.py +1000 -0
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +266 -0
- warp/examples/optim/example_cloth_throw.py +228 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_drone.py +870 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +182 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
- warp/examples/optim/example_softbody_properties.py +400 -0
- warp/examples/optim/example_spring_cage.py +245 -0
- warp/examples/optim/example_trajectory.py +227 -0
- warp/examples/sim/example_cartpole.py +143 -0
- warp/examples/sim/example_cloth.py +225 -0
- warp/examples/sim/example_cloth_self_contact.py +322 -0
- warp/examples/sim/example_granular.py +130 -0
- warp/examples/sim/example_granular_collision_sdf.py +202 -0
- warp/examples/sim/example_jacobian_ik.py +244 -0
- warp/examples/sim/example_particle_chain.py +124 -0
- warp/examples/sim/example_quadruped.py +203 -0
- warp/examples/sim/example_rigid_chain.py +203 -0
- warp/examples/sim/example_rigid_contact.py +195 -0
- warp/examples/sim/example_rigid_force.py +133 -0
- warp/examples/sim/example_rigid_gyroscopic.py +115 -0
- warp/examples/sim/example_rigid_soft_contact.py +140 -0
- warp/examples/sim/example_soft_body.py +196 -0
- warp/examples/tile/example_tile_cholesky.py +87 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mlp.py +383 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/examples/tile/example_tile_walker.py +327 -0
- warp/fabric.py +355 -0
- warp/fem/__init__.py +106 -0
- warp/fem/adaptivity.py +508 -0
- warp/fem/cache.py +572 -0
- warp/fem/dirichlet.py +202 -0
- warp/fem/domain.py +411 -0
- warp/fem/field/__init__.py +125 -0
- warp/fem/field/field.py +619 -0
- warp/fem/field/nodal_field.py +326 -0
- warp/fem/field/restriction.py +37 -0
- warp/fem/field/virtual.py +848 -0
- warp/fem/geometry/__init__.py +32 -0
- warp/fem/geometry/adaptive_nanogrid.py +857 -0
- warp/fem/geometry/closest_point.py +84 -0
- warp/fem/geometry/deformed_geometry.py +221 -0
- warp/fem/geometry/element.py +776 -0
- warp/fem/geometry/geometry.py +362 -0
- warp/fem/geometry/grid_2d.py +392 -0
- warp/fem/geometry/grid_3d.py +452 -0
- warp/fem/geometry/hexmesh.py +911 -0
- warp/fem/geometry/nanogrid.py +571 -0
- warp/fem/geometry/partition.py +389 -0
- warp/fem/geometry/quadmesh.py +663 -0
- warp/fem/geometry/tetmesh.py +855 -0
- warp/fem/geometry/trimesh.py +806 -0
- warp/fem/integrate.py +2335 -0
- warp/fem/linalg.py +419 -0
- warp/fem/operator.py +293 -0
- warp/fem/polynomial.py +229 -0
- warp/fem/quadrature/__init__.py +17 -0
- warp/fem/quadrature/pic_quadrature.py +299 -0
- warp/fem/quadrature/quadrature.py +591 -0
- warp/fem/space/__init__.py +228 -0
- warp/fem/space/basis_function_space.py +468 -0
- warp/fem/space/basis_space.py +667 -0
- warp/fem/space/dof_mapper.py +251 -0
- warp/fem/space/function_space.py +309 -0
- warp/fem/space/grid_2d_function_space.py +177 -0
- warp/fem/space/grid_3d_function_space.py +227 -0
- warp/fem/space/hexmesh_function_space.py +257 -0
- warp/fem/space/nanogrid_function_space.py +201 -0
- warp/fem/space/partition.py +367 -0
- warp/fem/space/quadmesh_function_space.py +223 -0
- warp/fem/space/restriction.py +179 -0
- warp/fem/space/shape/__init__.py +143 -0
- warp/fem/space/shape/cube_shape_function.py +1105 -0
- warp/fem/space/shape/shape_function.py +133 -0
- warp/fem/space/shape/square_shape_function.py +926 -0
- warp/fem/space/shape/tet_shape_function.py +834 -0
- warp/fem/space/shape/triangle_shape_function.py +672 -0
- warp/fem/space/tetmesh_function_space.py +271 -0
- warp/fem/space/topology.py +424 -0
- warp/fem/space/trimesh_function_space.py +194 -0
- warp/fem/types.py +99 -0
- warp/fem/utils.py +420 -0
- warp/jax.py +187 -0
- warp/jax_experimental/__init__.py +16 -0
- warp/jax_experimental/custom_call.py +351 -0
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +244 -0
- warp/native/array.h +1145 -0
- warp/native/builtin.h +1800 -0
- warp/native/bvh.cpp +492 -0
- warp/native/bvh.cu +791 -0
- warp/native/bvh.h +554 -0
- warp/native/clang/clang.cpp +536 -0
- warp/native/coloring.cpp +613 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +362 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +646 -0
- warp/native/cuda_util.h +307 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +1878 -0
- warp/native/fabric.h +245 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +87 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1230 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +339 -0
- warp/native/marching.cpp +19 -0
- warp/native/marching.cu +514 -0
- warp/native/marching.h +19 -0
- warp/native/mat.h +2220 -0
- warp/native/mathdx.cpp +87 -0
- warp/native/matnn.h +343 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +404 -0
- warp/native/mesh.h +1980 -0
- warp/native/nanovdb/GridHandle.h +366 -0
- warp/native/nanovdb/HostBuffer.h +590 -0
- warp/native/nanovdb/NanoVDB.h +6624 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1371 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +139 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +364 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +53 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +277 -0
- warp/native/sort.h +33 -0
- warp/native/sparse.cpp +378 -0
- warp/native/sparse.cu +524 -0
- warp/native/spatial.h +657 -0
- warp/native/svd.h +702 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +2584 -0
- warp/native/tile_reduce.h +264 -0
- warp/native/vec.h +1426 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +67 -0
- warp/native/volume.h +969 -0
- warp/native/volume_builder.cu +477 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1082 -0
- warp/native/warp.cu +3636 -0
- warp/native/warp.h +381 -0
- warp/optim/__init__.py +17 -0
- warp/optim/adam.py +163 -0
- warp/optim/linear.py +1137 -0
- warp/optim/sgd.py +112 -0
- warp/paddle.py +407 -0
- warp/render/__init__.py +18 -0
- warp/render/render_opengl.py +3518 -0
- warp/render/render_usd.py +784 -0
- warp/render/utils.py +160 -0
- warp/sim/__init__.py +65 -0
- warp/sim/articulation.py +793 -0
- warp/sim/collide.py +2395 -0
- warp/sim/graph_coloring.py +300 -0
- warp/sim/import_mjcf.py +790 -0
- warp/sim/import_snu.py +227 -0
- warp/sim/import_urdf.py +579 -0
- warp/sim/import_usd.py +894 -0
- warp/sim/inertia.py +324 -0
- warp/sim/integrator.py +242 -0
- warp/sim/integrator_euler.py +1997 -0
- warp/sim/integrator_featherstone.py +2101 -0
- warp/sim/integrator_vbd.py +2048 -0
- warp/sim/integrator_xpbd.py +3292 -0
- warp/sim/model.py +4791 -0
- warp/sim/particles.py +121 -0
- warp/sim/render.py +427 -0
- warp/sim/utils.py +428 -0
- warp/sparse.py +2057 -0
- warp/stubs.py +3333 -0
- warp/tape.py +1203 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +634 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +200 -0
- warp/tests/geometry/test_hash_grid.py +221 -0
- warp/tests/geometry/test_marching_cubes.py +74 -0
- warp/tests/geometry/test_mesh.py +316 -0
- warp/tests/geometry/test_mesh_query_aabb.py +399 -0
- warp/tests/geometry/test_mesh_query_point.py +932 -0
- warp/tests/geometry/test_mesh_query_ray.py +311 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +729 -0
- warp/tests/interop/test_jax.py +371 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/sim/__init__.py +0 -0
- warp/tests/sim/disabled_kinematics.py +244 -0
- warp/tests/sim/flaky_test_sim_grad.py +290 -0
- warp/tests/sim/test_collision.py +604 -0
- warp/tests/sim/test_coloring.py +258 -0
- warp/tests/sim/test_model.py +224 -0
- warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
- warp/tests/sim/test_sim_kinematics.py +98 -0
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +163 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +2972 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +250 -0
- warp/tests/test_atomic.py +153 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +1298 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +810 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +252 -0
- warp/tests/test_context.py +42 -0
- warp/tests/test_copy.py +238 -0
- warp/tests/test_ctypes.py +638 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +97 -0
- warp/tests/test_examples.py +482 -0
- warp/tests/test_fabricarray.py +996 -0
- warp/tests/test_fast_math.py +74 -0
- warp/tests/test_fem.py +2003 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +454 -0
- warp/tests/test_future_annotations.py +98 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +73 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +193 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_mat.py +2089 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +178 -0
- warp/tests/test_mlp.py +282 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +44 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +339 -0
- warp/tests/test_quat.py +2315 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +302 -0
- warp/tests/test_rounding.py +185 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +105 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +582 -0
- warp/tests/test_spatial.py +2229 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +592 -0
- warp/tests/test_struct.py +734 -0
- warp/tests/test_tape.py +204 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +145 -0
- warp/tests/test_types.py +562 -0
- warp/tests/test_utils.py +588 -0
- warp/tests/test_vec.py +1487 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +780 -0
- warp/tests/tile/test_tile_load.py +407 -0
- warp/tests/tile/test_tile_mathdx.py +208 -0
- warp/tests/tile/test_tile_mlp.py +402 -0
- warp/tests/tile/test_tile_reduce.py +447 -0
- warp/tests/tile/test_tile_shared_memory.py +247 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +427 -0
- warp/tests/unittest_utils.py +468 -0
- warp/tests/walkthrough_debug.py +93 -0
- warp/thirdparty/__init__.py +0 -0
- warp/thirdparty/appdirs.py +598 -0
- warp/thirdparty/dlpack.py +145 -0
- warp/thirdparty/unittest_parallel.py +570 -0
- warp/torch.py +391 -0
- warp/types.py +5230 -0
- warp/utils.py +1137 -0
- warp_lang-1.7.0.dist-info/METADATA +516 -0
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- warp_lang-1.7.0.dist-info/WHEEL +5 -0
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/tests/test_fp16.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import unittest
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
import warp as wp
|
|
21
|
+
from warp.tests.unittest_utils import *
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@wp.kernel
def load_store_half(f32: wp.array(dtype=wp.float32), f16: wp.array(dtype=wp.float16)):
    # One thread per element: the fp32 input converted to fp16 must equal the
    # reference fp16 element, and the converted value is then written back.
    i = wp.tid()

    converted = wp.float16(f32[i])
    reference = f16[i]
    wp.expect_eq(converted, reference)

    # exercise fp16 stores
    f16[i] = converted
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def test_fp16_conversion(test, device):
    """Round-trip a few values through fp32/fp16 Warp arrays and the load/store kernel."""
    values = [1.0, 2.0, 3.0, -3.14159]
    count = len(values)

    expected_f32 = np.array(values, dtype=np.float32)
    expected_f16 = np.array(values, dtype=np.float16)

    arr_f32 = wp.array(values, dtype=wp.float32, device=device)
    arr_f16 = wp.array(values, dtype=wp.float16, device=device)

    # host-side construction must already agree with NumPy's conversions
    assert_np_equal(expected_f32, arr_f32.numpy())
    assert_np_equal(expected_f16, arr_f16.numpy())

    wp.launch(load_store_half, dim=count, inputs=[arr_f32, arr_f16], device=device)

    # the kernel's fp16 stores must leave the data equal to the NumPy reference
    assert_np_equal(expected_f16, arr_f16.numpy())
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@wp.kernel
def value_load_store_half(f16_value: wp.float16, f16_array: wp.array(dtype=wp.float16)):
    # The scalar fp16 kernel argument must match the first array element ...
    stored = f16_array[0]
    wp.expect_eq(f16_value, stored)

    # ... and writing it back exercises fp16 stores from a by-value parameter.
    f16_array[0] = f16_value
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def test_fp16_kernel_parameter(test, device):
    """Check that fp16 scalars can be passed to kernels, both explicitly and as Python floats."""
    for value in (1.0, 2.0, 3.0, -3.14159):
        expected = np.array([value], dtype=np.float16)

        # explicit wp.float16 kernel argument
        target = wp.array([value], dtype=wp.float16, device=device)
        wp.launch(value_load_store_half, (1,), inputs=[wp.float16(value), target], device=device)
        assert_np_equal(expected, target.numpy())

        # plain Python float, relying on automatic conversion to fp16
        target = wp.array([value], dtype=wp.float16, device=device)
        wp.launch(value_load_store_half, (1,), inputs=[value, target], device=device)
        assert_np_equal(expected, target.numpy())
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@wp.kernel
def mul_half(input: wp.array(dtype=wp.float16), output: wp.array(dtype=wp.float16)):
    # Doubles each fp16 element. The multiply happens in fp32 (via wp.float) and
    # the result is narrowed back to fp16 on store.
    i = wp.tid()
    doubled = wp.float(input[i]) * 2.0
    output[i] = wp.float16(doubled)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def test_fp16_grad(test, device):
    """Gradients must flow through fp16 arrays even when the kernel computes in fp32."""
    rng = np.random.default_rng(123)
    samples = rng.random(size=15).astype(np.float16)
    n = len(samples)

    src = wp.array(samples, dtype=wp.float16, device=device, requires_grad=True)
    dst = wp.zeros_like(src)

    tape = wp.Tape()
    with tape:
        wp.launch(mul_half, dim=n, inputs=[src, dst], device=device)

    # Seed d(dst) with ones. Because dst = 2 * src, d(src) should be all twos.
    seed = wp.array(np.ones(len(dst)), dtype=wp.float16, device=device)
    tape.backward(grads={dst: seed})

    assert_np_equal(src.grad.numpy(), np.ones(n) * 2.0)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class TestFp16(unittest.TestCase):
    """Empty test container; the fp16 cases are attached to it with ``add_function_test``."""
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# CPU (when available) plus every selected CUDA device with arch >= 70.
# Presumably older GPUs are excluded because they lack the fp16 support these
# tests need; confirm.
devices = ["cpu"] if wp.is_cpu_available() else []
devices += [dev for dev in get_selected_cuda_test_devices() if dev.arch >= 70]

add_function_test(TestFp16, "test_fp16_conversion", test_fp16_conversion, devices=devices)
add_function_test(TestFp16, "test_fp16_grad", test_fp16_grad, devices=devices)
add_function_test(TestFp16, "test_fp16_kernel_parameter", test_fp16_kernel_parameter, devices=devices)


if __name__ == "__main__":
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
|
warp/tests/test_func.py
ADDED
|
@@ -0,0 +1,454 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import math
|
|
17
|
+
import unittest
|
|
18
|
+
from typing import Any, Tuple
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
|
|
22
|
+
import warp as wp
|
|
23
|
+
from warp.tests.unittest_utils import *
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@wp.func
def sqr(x: float):
    # x squared; no return annotation, so Warp infers the return type
    return x * x


# test nested user function calls
# and explicit return type hints
@wp.func
def cube(x: float) -> float:
    # x cubed, built on top of the user function sqr()
    return x * sqr(x)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Three overloads share the name ``custom``. The argument type (int, float or
# vec3) selects which one runs; each adds one to its (first) component.
@wp.func
def custom(x: int):
    return x + 1


@wp.func
def custom(x: float):
    return x + 1.0


@wp.func
def custom(x: wp.vec3):
    offset = wp.vec3(1.0, 0.0, 0.0)
    return x + offset
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@wp.func
def noreturn(x: wp.vec3):
    # A user function with no return statement. It reassigns its own argument
    # and expects to be called with x == (1, 0, 0).
    x = x + wp.vec3(0.0, 1.0, 0.0)

    wp.expect_eq(x, wp.vec3(1.0, 1.0, 0.0))


@wp.kernel
def test_overload_func():
    # Each call resolves to a different ``custom`` overload based on the argument type.
    as_int = custom(1)
    as_float = custom(1.0)
    as_vec = custom(wp.vec3(1.0, 0.0, 0.0))

    wp.expect_eq(as_int, 2)
    wp.expect_eq(as_float, 2.0)
    wp.expect_eq(as_vec, wp.vec3(2.0, 0.0, 0.0))

    # a function without a return value can be called as a plain statement
    noreturn(wp.vec3(1.0, 0.0, 0.0))
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@wp.func
def foo(x: int):
    # Superseded: the identically-typed definition that follows replaces this
    # one, so test_override_func must never observe x * 2.
    return x * 2


@wp.func
def foo(x: int):
    # active foo(int) overload
    return x * 3


@wp.kernel
def test_override_func():
    # must dispatch to the last-registered foo(int), i.e. 1 * 3
    result = foo(1)
    wp.expect_eq(result, 3)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def test_func_closure_capture(test, device):
    """Kernels created by a factory must capture the user function handed to the factory."""

    def make_closure_kernel(func):
        # ``func`` is a free variable of the kernel body, resolved from this closure
        def closure_kernel_fn(data: wp.array(dtype=float), expected: float):
            f = func(data[wp.tid()])
            wp.expect_eq(f, expected)

        return wp.Kernel(func=closure_kernel_fn)

    # both kernels are built before anything is launched, as in the original flow
    sqr_kernel = make_closure_kernel(sqr)
    cube_kernel = make_closure_kernel(cube)

    data = wp.array([2.0], dtype=float, device=device)

    # 2.0 squared is 4.0 and cubed is 8.0
    for kernel, expected in ((sqr_kernel, 4.0), (cube_kernel, 8.0)):
        wp.launch(kernel, dim=data.shape, inputs=[data, expected], device=device)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@wp.func
def test_func(param1: wp.int32, param2: wp.int32, param3: wp.int32) -> wp.float32:
    # Ignores its arguments and returns 1.0. The kernel below passes the
    # annotated return value straight into a builtin (wp.lerp).
    return 1.0


@wp.kernel
def test_return_kernel(test_data: wp.array(dtype=wp.float32)):
    # Both lerp endpoints come directly from test_func calls (each 1.0),
    # so every element ends up as 1.0.
    i = wp.tid()
    test_data[i] = wp.lerp(test_func(0, 1, 2), test_func(0, 1, 2), 0.5)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def test_return_func(test, device):
    """Launch ``test_return_kernel`` and verify its output.

    Every element should equal ``lerp(1.0, 1.0, 0.5) == 1.0``, because
    ``test_func`` always returns 1.0.
    """
    test_data = wp.zeros(100, dtype=wp.float32, device=device)
    wp.launch(kernel=test_return_kernel, dim=test_data.size, inputs=[test_data], device=device)

    # Previously the output was never inspected, so a wrong result (e.g. the
    # untouched zeros) would have passed silently.
    assert_np_equal(test_data.numpy(), np.ones(test_data.size, dtype=np.float32))
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
@wp.func
def multi_valued_func(a: wp.float32, b: wp.float32):
    # returns four values at once: sum, difference, product and quotient
    total = a + b
    diff = a - b
    prod = a * b
    quot = a / b
    return total, diff, prod, quot


def test_multi_valued_func(test, device):
    """A user function returning several values must unpack correctly inside a kernel."""

    @wp.kernel
    def test_multi_valued_kernel(test_data1: wp.array(dtype=wp.float32), test_data2: wp.array(dtype=wp.float32)):
        tid = wp.tid()
        x, y = test_data1[tid], test_data2[tid]
        r_sum, r_diff, r_prod, r_quot = multi_valued_func(x, y)
        wp.expect_eq(r_sum, x + y)
        wp.expect_eq(r_diff, x - y)
        wp.expect_eq(r_prod, x * y)
        wp.expect_eq(r_quot, x / y)

    count = 100
    lhs = wp.array(np.arange(count), dtype=wp.float32, device=device)
    # counts down from 100 to 1, so the quotient never divides by zero
    rhs = wp.array(np.arange(count, 0, -1), dtype=wp.float32, device=device)
    wp.launch(kernel=test_multi_valued_kernel, dim=lhs.size, inputs=[lhs, rhs], device=device)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
@wp.kernel
def test_func_defaults():
    # Relies on expect_near's default tolerance to accept a 1e-6 difference.
    wp.expect_near(1.0, 1.0 + 1.0e-6)

    # An explicitly passed, looser tolerance must also be honoured.
    wp.expect_near(1.0, 1.1, 0.5)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
@wp.func
def sign(x: float):
    # Shadows the builtin of the same name. It returns a sentinel constant so the
    # kernel below can tell which implementation was called.
    return 123.0


@wp.kernel
def test_builtin_shadowing():
    # an unqualified ``sign`` must bind to the user function, not the builtin
    wp.expect_eq(sign(1.23), 123.0)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
@wp.func
def user_func_with_defaults(a: int = 123, b: int = 234) -> int:
    # sum of both arguments; either or both may fall back to their defaults
    return a + b


@wp.kernel
def user_func_with_defaults_kernel():
    # all defaults: 123 + 234
    wp.expect_eq(user_func_with_defaults(), 357)
    # positional a, default b: 111 + 234
    wp.expect_eq(user_func_with_defaults(111), 345)
    # both positional: 111 + 222
    wp.expect_eq(user_func_with_defaults(111, 222), 333)
    # keyword a, default b: 111 + 234
    wp.expect_eq(user_func_with_defaults(a=111), 345)
    # default a, keyword b: 123 + 111
    wp.expect_eq(user_func_with_defaults(b=111), 234)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def test_user_func_with_defaults(test, device):
    """Default and keyword arguments of a user function must work in kernels and from Python.

    Args:
        test: the ``unittest.TestCase`` instance the case is attached to.
        device: device the kernel is launched on.
    """
    wp.launch(user_func_with_defaults_kernel, dim=1, device=device)

    # Same argument combinations, evaluated on the Python side. TestCase
    # assertions are used instead of bare ``assert``: ``python -O`` strips bare
    # asserts, which would make these checks pass vacuously.
    test.assertEqual(user_func_with_defaults(), 357)
    test.assertEqual(user_func_with_defaults(111), 345)
    test.assertEqual(user_func_with_defaults(111, 222), 333)
    test.assertEqual(user_func_with_defaults(a=111), 345)
    test.assertEqual(user_func_with_defaults(b=111), 234)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
@wp.func
def user_func_return_multiple_values(a: int, b: float) -> Tuple[int, float]:
    # returns (a + a, b * b) as a two-element tuple
    doubled = a + a
    squared = b * b
    return doubled, squared


@wp.kernel
def test_user_func_return_multiple_values():
    # 123 + 123 == 246 and 234.0 * 234.0 == 54756.0
    first, second = user_func_return_multiple_values(123, 234.0)
    wp.expect_eq(first, 246)
    wp.expect_eq(second, 54756.0)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
@wp.func
def user_func_overload(
    b: wp.array(dtype=Any),
    i: int,
):
    # Generic over the array dtype: returns element i of b scaled by 2.0.
    return b[i] * 2.0


@wp.kernel
def user_func_overload_resolution_kernel(
    a: wp.array(dtype=Any),
    b: wp.array(dtype=Any),
):
    # Writes 2 * b[idx] into a[idx]. Presumably a separate overload is
    # generated for each concrete dtype the kernel is launched with.
    idx = wp.tid()
    a[idx] = user_func_overload(b, idx)
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def test_user_func_overload_resolution(test, device):
    """The same generic kernel/function must resolve for both vec3 and float arrays.

    Args:
        test: the ``unittest.TestCase`` instance the case is attached to.
        device: device on which the arrays are allocated and the kernel runs.
    """
    # Previously neither the arrays nor the launches received ``device``, so
    # every parametrized run silently used the default device instead.
    a0 = wp.array((1, 2, 3), dtype=wp.vec3, device=device)
    b0 = wp.array((2, 3, 4), dtype=wp.vec3, device=device)

    a1 = wp.array((5,), dtype=float, device=device)
    b1 = wp.array((6,), dtype=float, device=device)

    wp.launch(user_func_overload_resolution_kernel, a0.shape, (a0, b0), device=device)
    wp.launch(user_func_overload_resolution_kernel, a1.shape, (a1, b1), device=device)

    # the kernel writes 2 * b into a
    assert_np_equal(a0.numpy()[0], (4, 6, 8))
    test.assertEqual(a1.numpy()[0], 12)
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
@wp.func
def user_func_return_none() -> None:
    # Empty body; only checks that an explicit ``-> None`` annotation is accepted.
    pass


@wp.kernel
def test_return_annotation_none() -> None:
    # kernels may carry a ``-> None`` annotation as well
    user_func_return_none()
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
# All devices selected by the test harness (presumably passed to
# add_function_test when the TestFunc cases are registered).
devices = get_test_devices()
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
class TestFunc(unittest.TestCase):
    """Python-scope tests for Warp user-defined and native (built-in) functions.

    Kernel- and device-based tests are attached to this class at module level
    via ``add_kernel_test`` / ``add_function_test``.
    """

    def test_user_func_export(self):
        """Call the overloaded user-defined function ``custom`` from Python with int, float and vec3 inputs."""
        i = custom(1)
        f = custom(1.0)
        v = custom(wp.vec3(1.0, 0.0, 0.0))

        self.assertEqual(i, 2)
        self.assertEqual(f, 2.0)
        assert_np_equal(np.array([*v]), np.array([2.0, 0.0, 0.0]))

    def test_native_func_export(self):
        """Call native Warp functions and operators on quat/vector/matrix/transform values from Python.

        The checks use ``self.assert*`` rather than bare ``assert``. Bare asserts are
        removed when Python runs with ``-O``, which would silently skip these checks.
        """
        q = wp.quat(0.0, 0.0, 0.0, 1.0)
        assert_np_equal(np.array([*q]), np.array([0.0, 0.0, 0.0, 1.0]))

        # Rotation of 2 rad about +x: expected components are (sin(1), 0, 0, cos(1)).
        r = wp.quat_from_axis_angle(wp.vec3(1.0, 0.0, 0.0), 2.0)
        assert_np_equal(np.array([*r]), np.array([0.8414709568023682, 0.0, 0.0, 0.5403022170066833]), tol=1.0e-3)

        q = wp.quat(1.0, 2.0, 3.0, 4.0)
        q = wp.normalize(q) * 2.0
        assert_np_equal(
            np.array([*q]),
            np.array([0.18257418274879456, 0.3651483654975891, 0.547722578048706, 0.7302967309951782]) * 2.0,
            tol=1.0e-3,
        )

        v2 = wp.vec2(1.0, 2.0)
        v2 = wp.normalize(v2) * 2.0
        assert_np_equal(np.array([*v2]), np.array([0.4472135901451111, 0.8944271802902222]) * 2.0, tol=1.0e-3)

        v3 = wp.vec3(1.0, 2.0, 3.0)
        v3 = wp.normalize(v3) * 2.0
        assert_np_equal(
            np.array([*v3]), np.array([0.26726123690605164, 0.5345224738121033, 0.8017836809158325]) * 2.0, tol=1.0e-3
        )

        v4 = wp.vec4(1.0, 2.0, 3.0, 4.0)
        v4 = wp.normalize(v4) * 2.0
        assert_np_equal(
            np.array([*v4]),
            np.array([0.18257418274879456, 0.3651483654975891, 0.547722578048706, 0.7302967309951782]) * 2.0,
            tol=1.0e-3,
        )

        # Vector arithmetic: in-place, binary and unary operators.
        v = wp.vec2(0.0)
        v += wp.vec2(1.0, 1.0)
        self.assertEqual(v, wp.vec2(1.0, 1.0))
        v -= wp.vec2(1.0, 1.0)
        self.assertEqual(v, wp.vec2(0.0, 0.0))
        v = wp.vec2(2.0, 2.0) - wp.vec2(1.0, 1.0)
        self.assertEqual(v, wp.vec2(1.0, 1.0))
        v *= 2.0
        self.assertEqual(v, wp.vec2(2.0, 2.0))
        v = v * 2.0
        self.assertEqual(v, wp.vec2(4.0, 4.0))
        v = v / 2.0
        self.assertEqual(v, wp.vec2(2.0, 2.0))
        v /= 2.0
        self.assertEqual(v, wp.vec2(1.0, 1.0))
        v = -v
        self.assertEqual(v, wp.vec2(-1.0, -1.0))
        v = +v
        self.assertEqual(v, wp.vec2(-1.0, -1.0))

        m22 = wp.mat22(1.0, 2.0, 3.0, 4.0)
        m22 = m22 + m22

        self.assertEqual(m22[1, 1], 8.0)
        self.assertEqual(str(m22), "[[2.0, 4.0],\n [6.0, 8.0]]")

        t = wp.transform(
            wp.vec3(1.0, 2.0, 3.0),
            wp.quat(4.0, 5.0, 6.0, 7.0),
        )
        self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))
        self.assertSequenceEqual(
            t * wp.transform(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0), (396.0, 432.0, 720.0, 56.0, 70.0, 84.0, -28.0)
        )
        self.assertSequenceEqual(
            t * wp.transform((1.0, 2.0, 3.0), (4.0, 5.0, 6.0, 7.0)), (396.0, 432.0, 720.0, 56.0, 70.0, 84.0, -28.0)
        )

        # Transform constructor forms: default, keyword, mixed, positional, copy and unpacked.
        t = wp.transform()
        self.assertSequenceEqual(t, (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0))

        t = wp.transform(p=(1.0, 2.0, 3.0), q=(4.0, 5.0, 6.0, 7.0))
        self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))

        t = wp.transform(q=(4.0, 5.0, 6.0, 7.0), p=(1.0, 2.0, 3.0))
        self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))

        t = wp.transform((1.0, 2.0, 3.0), q=(4.0, 5.0, 6.0, 7.0))
        self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))

        t = wp.transform(p=(1.0, 2.0, 3.0))
        self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 1.0))

        t = wp.transform(q=(4.0, 5.0, 6.0, 7.0))
        self.assertSequenceEqual(t, (0.0, 0.0, 0.0, 4.0, 5.0, 6.0, 7.0))

        t = wp.transform((1.0, 2.0, 3.0), (4.0, 5.0, 6.0, 7.0))
        self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))

        t = wp.transform(p=wp.vec3(1.0, 2.0, 3.0), q=wp.quat(4.0, 5.0, 6.0, 7.0))
        self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))

        t = wp.transform(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
        self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))

        t = wp.transform(wp.transform(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))
        self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))

        t = wp.transform(*wp.transform(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))
        self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))

        # Component-wise addition / subtraction of float transforms.
        transformf = wp.types.transformation(dtype=float)

        t = wp.transformf((1.0, 2.0, 3.0), (4.0, 5.0, 6.0, 7.0))
        self.assertSequenceEqual(
            t + transformf((2.0, 3.0, 4.0), (5.0, 6.0, 7.0, 8.0)),
            (3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0),
        )
        self.assertSequenceEqual(
            t - transformf((2.0, 3.0, 4.0), (5.0, 6.0, 7.0, 8.0)),
            (-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0),
        )

        f = wp.sin(math.pi * 0.5)
        self.assertAlmostEqual(f, 1.0, places=3)

        # Matrix arithmetic: in-place, binary and unary operators, plus matrix product.
        m = wp.mat22(0.0, 0.0, 0.0, 0.0)
        m += wp.mat22(1.0, 1.0, 1.0, 1.0)
        self.assertEqual(m, wp.mat22(1.0, 1.0, 1.0, 1.0))
        m -= wp.mat22(1.0, 1.0, 1.0, 1.0)
        self.assertEqual(m, wp.mat22(0.0, 0.0, 0.0, 0.0))
        m = wp.mat22(2.0, 2.0, 2.0, 2.0) - wp.mat22(1.0, 1.0, 1.0, 1.0)
        self.assertEqual(m, wp.mat22(1.0, 1.0, 1.0, 1.0))
        m *= 2.0
        self.assertEqual(m, wp.mat22(2.0, 2.0, 2.0, 2.0))
        m = m * 2.0
        self.assertEqual(m, wp.mat22(4.0, 4.0, 4.0, 4.0))
        m = m / 2.0
        self.assertEqual(m, wp.mat22(2.0, 2.0, 2.0, 2.0))
        m /= 2.0
        self.assertEqual(m, wp.mat22(1.0, 1.0, 1.0, 1.0))
        m = -m
        self.assertEqual(m, wp.mat22(-1.0, -1.0, -1.0, -1.0))
        m = +m
        self.assertEqual(m, wp.mat22(-1.0, -1.0, -1.0, -1.0))
        m = m * m
        self.assertEqual(m, wp.mat22(2.0, 2.0, 2.0, 2.0))

    def test_native_function_error_resolution(self):
        """Multiplying a float32 matrix by a float64 matrix must raise a descriptive RuntimeError."""
        a = wp.mat22f(1.0, 2.0, 3.0, 4.0)
        b = wp.mat22d(1.0, 2.0, 3.0, 4.0)
        with self.assertRaisesRegex(
            RuntimeError,
            r"^Couldn't find a function 'mul' compatible with " r"the arguments 'mat22f, mat22d'$",
        ):
            a * b
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
# Attach the kernels and Python-side test functions defined in this module to TestFunc,
# passing `devices` to every registration.
add_kernel_test(
    TestFunc,
    name="test_overload_func",
    kernel=test_overload_func,
    dim=1,
    devices=devices,
)
add_function_test(
    TestFunc,
    name="test_return_func",
    func=test_return_func,
    devices=devices,
)
add_kernel_test(
    TestFunc,
    name="test_override_func",
    kernel=test_override_func,
    dim=1,
    devices=devices,
)
add_function_test(
    TestFunc,
    name="test_func_closure_capture",
    func=test_func_closure_capture,
    devices=devices,
)
add_function_test(
    TestFunc,
    name="test_multi_valued_func",
    func=test_multi_valued_func,
    devices=devices,
)
add_kernel_test(
    TestFunc,
    name="test_func_defaults",
    kernel=test_func_defaults,
    dim=1,
    devices=devices,
)
add_kernel_test(
    TestFunc,
    name="test_builtin_shadowing",
    kernel=test_builtin_shadowing,
    dim=1,
    devices=devices,
)
add_function_test(
    TestFunc,
    name="test_user_func_with_defaults",
    func=test_user_func_with_defaults,
    devices=devices,
)
add_kernel_test(
    TestFunc,
    name="test_user_func_return_multiple_values",
    kernel=test_user_func_return_multiple_values,
    dim=1,
    devices=devices,
)
add_function_test(
    TestFunc,
    name="test_user_func_overload_resolution",
    func=test_user_func_overload_resolution,
    devices=devices,
)
add_kernel_test(
    TestFunc,
    name="test_return_annotation_none",
    kernel=test_return_annotation_none,
    dim=1,
    devices=devices,
)


if __name__ == "__main__":
    # Clear Warp's kernel cache first (presumably so kernels are rebuilt from source), then run.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# The `from __future__ import annotations` import below is what this test file exercises.
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import unittest
|
|
20
|
+
|
|
21
|
+
import warp as wp
|
|
22
|
+
from warp.tests.unittest_utils import *
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@wp.struct
class FooData:
    # Warp struct passed to the kernel built by `create_kernel_3`, which writes `x + y` to its output.
    # Field declaration order defines the struct layout.
    x: float
    y: float
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class Foo:
    # Plain Python class exposing a Warp struct type and a Warp function as attributes.
    # `create_kernel_3` reaches both through an instance: `foo.Data` in a type hint and
    # `foo.compute()` in the kernel body.
    Data = FooData

    @wp.func
    def compute():
        # Intentionally empty: it only exists so the kernel body references `foo`.
        pass
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@wp.kernel
def kernel_1(
    out: wp.array(dtype=float),
):
    # Reads the thread index but never writes `out`. With `from __future__ import annotations`
    # in effect, the `out` annotation above is stored as a string that Warp must resolve.
    tid = wp.tid()
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@wp.kernel
def kernel_2(
    out: wp.array(dtype=float),
):
    # Every thread stores the constant 1.23 into its own element of `out`.
    i = wp.tid()
    out[i] = 1.23
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def create_kernel_3(foo: Foo):
    """Build a kernel whose ``data`` parameter is annotated with ``foo.Data``.

    Args:
        foo: Object providing the ``Data`` struct type and the ``compute`` Warp function.

    Returns:
        A ``wp.Kernel`` wrapping the nested ``fn``, which writes ``data.x + data.y``
        into ``out[tid]``.
    """

    def fn(
        data: foo.Data,
        out: wp.array(dtype=float),
    ):
        tid = wp.tid()

        # Referencing a variable in a type hint like `foo.Data` isn't officially
        # accepted by Python, but it's still used in some places (e.g.: `warp.fem`),
        # where it works only because the variable is referenced within the function,
        # which causes it to be promoted to a closure variable. Without that,
        # it wouldn't be possible to resolve `foo` and to evaluate the `foo.Data`
        # string to its corresponding type.
        foo.compute()

        out[tid] = data.x + data.y

    return wp.Kernel(func=fn)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def test_future_annotations(test, device):
    """Launch kernels defined under ``from __future__ import annotations`` and check the final output.

    ``kernel_3`` comes from ``create_kernel_3``, whose ``data: foo.Data`` annotation can only be
    resolved through the ``foo`` closure variable.

    Args:
        test: The ``unittest.TestCase`` instance running this function.
        device: The Warp device this test variant was registered for (may be ``None``).
    """
    foo = Foo()
    foo_data = FooData()
    foo_data.x = 1.23
    foo_data.y = 2.34

    # Pass ``device`` explicitly so the allocation and all launches target the device
    # this test was registered for.
    out = wp.empty(1, dtype=float, device=device)

    kernel_3 = create_kernel_3(foo)

    wp.launch(kernel_1, dim=out.shape, outputs=(out,), device=device)
    wp.launch(kernel_2, dim=out.shape, outputs=(out,), device=device)
    wp.launch(kernel_3, dim=out.shape, inputs=(foo_data,), outputs=(out,), device=device)

    # kernel_3 runs last and writes data.x + data.y. Checking the result ensures the
    # closure-resolved `foo.Data` annotation produced a working kernel, not just one that launches.
    test.assertAlmostEqual(float(out.numpy()[0]), 1.23 + 2.34, places=5)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class TestFutureAnnotations(unittest.TestCase):
    """Test case for kernels defined under ``from __future__ import annotations``.

    It has no methods of its own; ``add_function_test`` attaches them at module level.
    """

    pass
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
# Attach test_future_annotations as a TestFutureAnnotations method (no explicit device list).
add_function_test(
    TestFutureAnnotations,
    name="test_future_annotations",
    func=test_future_annotations,
)


if __name__ == "__main__":
    # Clear Warp's kernel cache first (presumably so kernels are rebuilt from source), then run.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
|