warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +139 -0
- warp/__init__.pyi +1 -0
- warp/autograd.py +1142 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +557 -0
- warp/build_dll.py +405 -0
- warp/builtins.py +6855 -0
- warp/codegen.py +3969 -0
- warp/config.py +158 -0
- warp/constants.py +57 -0
- warp/context.py +6812 -0
- warp/dlpack.py +462 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +194 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +193 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +423 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +178 -0
- warp/examples/fem/example_convection_diffusion_dg.py +204 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +220 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_magnetostatics.py +240 -0
- warp/examples/fem/example_mixed_elasticity.py +291 -0
- warp/examples/fem/example_navier_stokes.py +261 -0
- warp/examples/fem/example_nonconforming_contact.py +298 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +352 -0
- warp/examples/fem/utils.py +1000 -0
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +266 -0
- warp/examples/optim/example_cloth_throw.py +228 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_drone.py +870 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +182 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
- warp/examples/optim/example_softbody_properties.py +400 -0
- warp/examples/optim/example_spring_cage.py +245 -0
- warp/examples/optim/example_trajectory.py +227 -0
- warp/examples/sim/example_cartpole.py +143 -0
- warp/examples/sim/example_cloth.py +225 -0
- warp/examples/sim/example_cloth_self_contact.py +322 -0
- warp/examples/sim/example_granular.py +130 -0
- warp/examples/sim/example_granular_collision_sdf.py +202 -0
- warp/examples/sim/example_jacobian_ik.py +244 -0
- warp/examples/sim/example_particle_chain.py +124 -0
- warp/examples/sim/example_quadruped.py +203 -0
- warp/examples/sim/example_rigid_chain.py +203 -0
- warp/examples/sim/example_rigid_contact.py +195 -0
- warp/examples/sim/example_rigid_force.py +133 -0
- warp/examples/sim/example_rigid_gyroscopic.py +115 -0
- warp/examples/sim/example_rigid_soft_contact.py +140 -0
- warp/examples/sim/example_soft_body.py +196 -0
- warp/examples/tile/example_tile_cholesky.py +87 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mlp.py +383 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/examples/tile/example_tile_walker.py +327 -0
- warp/fabric.py +355 -0
- warp/fem/__init__.py +106 -0
- warp/fem/adaptivity.py +508 -0
- warp/fem/cache.py +572 -0
- warp/fem/dirichlet.py +202 -0
- warp/fem/domain.py +411 -0
- warp/fem/field/__init__.py +125 -0
- warp/fem/field/field.py +619 -0
- warp/fem/field/nodal_field.py +326 -0
- warp/fem/field/restriction.py +37 -0
- warp/fem/field/virtual.py +848 -0
- warp/fem/geometry/__init__.py +32 -0
- warp/fem/geometry/adaptive_nanogrid.py +857 -0
- warp/fem/geometry/closest_point.py +84 -0
- warp/fem/geometry/deformed_geometry.py +221 -0
- warp/fem/geometry/element.py +776 -0
- warp/fem/geometry/geometry.py +362 -0
- warp/fem/geometry/grid_2d.py +392 -0
- warp/fem/geometry/grid_3d.py +452 -0
- warp/fem/geometry/hexmesh.py +911 -0
- warp/fem/geometry/nanogrid.py +571 -0
- warp/fem/geometry/partition.py +389 -0
- warp/fem/geometry/quadmesh.py +663 -0
- warp/fem/geometry/tetmesh.py +855 -0
- warp/fem/geometry/trimesh.py +806 -0
- warp/fem/integrate.py +2335 -0
- warp/fem/linalg.py +419 -0
- warp/fem/operator.py +293 -0
- warp/fem/polynomial.py +229 -0
- warp/fem/quadrature/__init__.py +17 -0
- warp/fem/quadrature/pic_quadrature.py +299 -0
- warp/fem/quadrature/quadrature.py +591 -0
- warp/fem/space/__init__.py +228 -0
- warp/fem/space/basis_function_space.py +468 -0
- warp/fem/space/basis_space.py +667 -0
- warp/fem/space/dof_mapper.py +251 -0
- warp/fem/space/function_space.py +309 -0
- warp/fem/space/grid_2d_function_space.py +177 -0
- warp/fem/space/grid_3d_function_space.py +227 -0
- warp/fem/space/hexmesh_function_space.py +257 -0
- warp/fem/space/nanogrid_function_space.py +201 -0
- warp/fem/space/partition.py +367 -0
- warp/fem/space/quadmesh_function_space.py +223 -0
- warp/fem/space/restriction.py +179 -0
- warp/fem/space/shape/__init__.py +143 -0
- warp/fem/space/shape/cube_shape_function.py +1105 -0
- warp/fem/space/shape/shape_function.py +133 -0
- warp/fem/space/shape/square_shape_function.py +926 -0
- warp/fem/space/shape/tet_shape_function.py +834 -0
- warp/fem/space/shape/triangle_shape_function.py +672 -0
- warp/fem/space/tetmesh_function_space.py +271 -0
- warp/fem/space/topology.py +424 -0
- warp/fem/space/trimesh_function_space.py +194 -0
- warp/fem/types.py +99 -0
- warp/fem/utils.py +420 -0
- warp/jax.py +187 -0
- warp/jax_experimental/__init__.py +16 -0
- warp/jax_experimental/custom_call.py +351 -0
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +244 -0
- warp/native/array.h +1145 -0
- warp/native/builtin.h +1800 -0
- warp/native/bvh.cpp +492 -0
- warp/native/bvh.cu +791 -0
- warp/native/bvh.h +554 -0
- warp/native/clang/clang.cpp +536 -0
- warp/native/coloring.cpp +613 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +362 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +646 -0
- warp/native/cuda_util.h +307 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +1878 -0
- warp/native/fabric.h +245 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +87 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1230 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +339 -0
- warp/native/marching.cpp +19 -0
- warp/native/marching.cu +514 -0
- warp/native/marching.h +19 -0
- warp/native/mat.h +2220 -0
- warp/native/mathdx.cpp +87 -0
- warp/native/matnn.h +343 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +404 -0
- warp/native/mesh.h +1980 -0
- warp/native/nanovdb/GridHandle.h +366 -0
- warp/native/nanovdb/HostBuffer.h +590 -0
- warp/native/nanovdb/NanoVDB.h +6624 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1371 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +139 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +364 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +53 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +277 -0
- warp/native/sort.h +33 -0
- warp/native/sparse.cpp +378 -0
- warp/native/sparse.cu +524 -0
- warp/native/spatial.h +657 -0
- warp/native/svd.h +702 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +2584 -0
- warp/native/tile_reduce.h +264 -0
- warp/native/vec.h +1426 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +67 -0
- warp/native/volume.h +969 -0
- warp/native/volume_builder.cu +477 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1082 -0
- warp/native/warp.cu +3636 -0
- warp/native/warp.h +381 -0
- warp/optim/__init__.py +17 -0
- warp/optim/adam.py +163 -0
- warp/optim/linear.py +1137 -0
- warp/optim/sgd.py +112 -0
- warp/paddle.py +407 -0
- warp/render/__init__.py +18 -0
- warp/render/render_opengl.py +3518 -0
- warp/render/render_usd.py +784 -0
- warp/render/utils.py +160 -0
- warp/sim/__init__.py +65 -0
- warp/sim/articulation.py +793 -0
- warp/sim/collide.py +2395 -0
- warp/sim/graph_coloring.py +300 -0
- warp/sim/import_mjcf.py +790 -0
- warp/sim/import_snu.py +227 -0
- warp/sim/import_urdf.py +579 -0
- warp/sim/import_usd.py +894 -0
- warp/sim/inertia.py +324 -0
- warp/sim/integrator.py +242 -0
- warp/sim/integrator_euler.py +1997 -0
- warp/sim/integrator_featherstone.py +2101 -0
- warp/sim/integrator_vbd.py +2048 -0
- warp/sim/integrator_xpbd.py +3292 -0
- warp/sim/model.py +4791 -0
- warp/sim/particles.py +121 -0
- warp/sim/render.py +427 -0
- warp/sim/utils.py +428 -0
- warp/sparse.py +2057 -0
- warp/stubs.py +3333 -0
- warp/tape.py +1203 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +634 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +200 -0
- warp/tests/geometry/test_hash_grid.py +221 -0
- warp/tests/geometry/test_marching_cubes.py +74 -0
- warp/tests/geometry/test_mesh.py +316 -0
- warp/tests/geometry/test_mesh_query_aabb.py +399 -0
- warp/tests/geometry/test_mesh_query_point.py +932 -0
- warp/tests/geometry/test_mesh_query_ray.py +311 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +729 -0
- warp/tests/interop/test_jax.py +371 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/sim/__init__.py +0 -0
- warp/tests/sim/disabled_kinematics.py +244 -0
- warp/tests/sim/flaky_test_sim_grad.py +290 -0
- warp/tests/sim/test_collision.py +604 -0
- warp/tests/sim/test_coloring.py +258 -0
- warp/tests/sim/test_model.py +224 -0
- warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
- warp/tests/sim/test_sim_kinematics.py +98 -0
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +163 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +2972 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +250 -0
- warp/tests/test_atomic.py +153 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +1298 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +810 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +252 -0
- warp/tests/test_context.py +42 -0
- warp/tests/test_copy.py +238 -0
- warp/tests/test_ctypes.py +638 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +97 -0
- warp/tests/test_examples.py +482 -0
- warp/tests/test_fabricarray.py +996 -0
- warp/tests/test_fast_math.py +74 -0
- warp/tests/test_fem.py +2003 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +454 -0
- warp/tests/test_future_annotations.py +98 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +73 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +193 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_mat.py +2089 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +178 -0
- warp/tests/test_mlp.py +282 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +44 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +339 -0
- warp/tests/test_quat.py +2315 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +302 -0
- warp/tests/test_rounding.py +185 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +105 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +582 -0
- warp/tests/test_spatial.py +2229 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +592 -0
- warp/tests/test_struct.py +734 -0
- warp/tests/test_tape.py +204 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +145 -0
- warp/tests/test_types.py +562 -0
- warp/tests/test_utils.py +588 -0
- warp/tests/test_vec.py +1487 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +780 -0
- warp/tests/tile/test_tile_load.py +407 -0
- warp/tests/tile/test_tile_mathdx.py +208 -0
- warp/tests/tile/test_tile_mlp.py +402 -0
- warp/tests/tile/test_tile_reduce.py +447 -0
- warp/tests/tile/test_tile_shared_memory.py +247 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +427 -0
- warp/tests/unittest_utils.py +468 -0
- warp/tests/walkthrough_debug.py +93 -0
- warp/thirdparty/__init__.py +0 -0
- warp/thirdparty/appdirs.py +598 -0
- warp/thirdparty/dlpack.py +145 -0
- warp/thirdparty/unittest_parallel.py +570 -0
- warp/torch.py +391 -0
- warp/types.py +5230 -0
- warp/utils.py +1137 -0
- warp_lang-1.7.0.dist-info/METADATA +516 -0
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- warp_lang-1.7.0.dist-info/WHEEL +5 -0
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp_lang-1.7.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import unittest
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
import warp as wp
|
|
21
|
+
from warp.tests.unittest_utils import *
|
|
22
|
+
|
|
23
|
+
# Debug switch: when True, test_rounding() also prints a table of the NumPy
# equivalents (np.round, np.rint, np.trunc, np.fix, np.floor, np.ceil, np.modf).
compare_to_numpy = False
|
|
24
|
+
# Debug switch: when True, test_rounding() prints the table of Warp results.
print_results = False
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Kernel under test: for each thread index, applies every rounding-related
# builtin to the same input element and stores each result in its own array.
@wp.kernel
def test_kernel(
    x: wp.array(dtype=float),
    x_round: wp.array(dtype=float),
    x_rint: wp.array(dtype=float),
    x_trunc: wp.array(dtype=float),
    x_cast: wp.array(dtype=float),
    x_floor: wp.array(dtype=float),
    x_ceil: wp.array(dtype=float),
    x_frac: wp.array(dtype=float),
):
    i = wp.tid()

    # Load the input once; every output below is derived from this value.
    value = x[i]

    x_round[i] = wp.round(value)
    x_rint[i] = wp.rint(value)
    x_trunc[i] = wp.trunc(value)
    # Round-trip through int to exercise the float -> int -> float cast path.
    x_cast[i] = float(int(value))
    x_floor[i] = wp.floor(value)
    x_ceil[i] = wp.ceil(value)
    x_frac[i] = wp.frac(value)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def test_rounding(test, device):
    """Compare Warp's rounding builtins against a hand-written table.

    Launches ``test_kernel`` over a fixed set of float32 inputs and checks the
    stacked results against ``golden`` with ``assert_np_equal``. Each table row
    is ``[x, round, rint, trunc, cast, floor, ceil, frac]``.

    Args:
        test: The test-case instance supplied by the test harness (unused here).
        device: Device on which the arrays are allocated and the kernel launched.

    The module-level flags ``print_results`` and ``compare_to_numpy`` enable
    optional debug printing of the Warp table and of a NumPy-computed table.
    """
    # Inputs straddle every .1/.5/.9 fraction between -4.9 and 4.9 so that
    # half-way cases and both signs are covered.
    nx = np.array(
        [
            4.9,
            4.5,
            4.1,
            3.9,
            3.5,
            3.1,
            2.9,
            2.5,
            2.1,
            1.9,
            1.5,
            1.1,
            0.9,
            0.5,
            0.1,
            -0.1,
            -0.5,
            -0.9,
            -1.1,
            -1.5,
            -1.9,
            -2.1,
            -2.5,
            -2.9,
            -3.1,
            -3.5,
            -3.9,
            -4.1,
            -4.5,
            -4.9,
        ],
        dtype=np.float32,
    )

    x = wp.array(nx, device=device)
    N = len(x)

    # One output buffer per kernel argument after ``x``, in kernel-argument
    # order: round, rint, trunc, cast, floor, ceil, frac.
    outputs = [wp.empty(N, dtype=float, device=device) for _ in range(7)]

    wp.launch(kernel=test_kernel, dim=N, inputs=[x, *outputs], device=device)

    wp.synchronize()

    # Column 0 is the input; the remaining columns follow the order of ``outputs``.
    tab = np.stack([nx] + [out.numpy().reshape(N) for out in outputs], axis=1)

    golden = np.array(
        [
            [4.9, 5.0, 5.0, 4.0, 4.0, 4.0, 5.0, 0.9],
            [4.5, 5.0, 4.0, 4.0, 4.0, 4.0, 5.0, 0.5],
            [4.1, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 0.1],
            [3.9, 4.0, 4.0, 3.0, 3.0, 3.0, 4.0, 0.9],
            [3.5, 4.0, 4.0, 3.0, 3.0, 3.0, 4.0, 0.5],
            [3.1, 3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 0.1],
            [2.9, 3.0, 3.0, 2.0, 2.0, 2.0, 3.0, 0.9],
            [2.5, 3.0, 2.0, 2.0, 2.0, 2.0, 3.0, 0.5],
            [2.1, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 0.1],
            [1.9, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0, 0.9],
            [1.5, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0, 0.5],
            [1.1, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 0.1],
            [0.9, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.9],
            [0.5, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.5],
            [0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.1],
            [-0.1, -0.0, -0.0, -0.0, 0.0, -1.0, -0.0, -0.1],
            [-0.5, -1.0, -0.0, -0.0, 0.0, -1.0, -0.0, -0.5],
            [-0.9, -1.0, -1.0, -0.0, 0.0, -1.0, -0.0, -0.9],
            [-1.1, -1.0, -1.0, -1.0, -1.0, -2.0, -1.0, -0.1],
            [-1.5, -2.0, -2.0, -1.0, -1.0, -2.0, -1.0, -0.5],
            [-1.9, -2.0, -2.0, -1.0, -1.0, -2.0, -1.0, -0.9],
            [-2.1, -2.0, -2.0, -2.0, -2.0, -3.0, -2.0, -0.1],
            [-2.5, -3.0, -2.0, -2.0, -2.0, -3.0, -2.0, -0.5],
            [-2.9, -3.0, -3.0, -2.0, -2.0, -3.0, -2.0, -0.9],
            [-3.1, -3.0, -3.0, -3.0, -3.0, -4.0, -3.0, -0.1],
            [-3.5, -4.0, -4.0, -3.0, -3.0, -4.0, -3.0, -0.5],
            [-3.9, -4.0, -4.0, -3.0, -3.0, -4.0, -3.0, -0.9],
            [-4.1, -4.0, -4.0, -4.0, -4.0, -5.0, -4.0, -0.1],
            [-4.5, -5.0, -4.0, -4.0, -4.0, -5.0, -4.0, -0.5],
            [-4.9, -5.0, -5.0, -4.0, -4.0, -5.0, -4.0, -0.9],
        ],
        dtype=np.float32,
    )

    assert_np_equal(tab, golden, tol=1e-6)

    # Both debug tables have eight columns, so the header names all eight
    # (the original header omitted "frac").
    header_format = " %5s %5s %5s %5s %5s %5s %5s %5s"

    if print_results:
        # The formatter argument is named ``v`` so it does not shadow the array ``x``.
        np.set_printoptions(formatter={"float": lambda v: "{:6.1f}".format(v).replace(".0", ".")})

        print("----------------------------------------------")
        print(header_format % ("x ", "round", "rint", "trunc", "cast", "floor", "ceil", "frac"))
        print(tab)
        print("----------------------------------------------")

    if compare_to_numpy:
        # NumPy counterparts, for eyeballing differences; nothing is asserted here.
        numpy_tab = np.stack(
            [
                nx,
                np.round(nx),
                np.rint(nx),
                np.trunc(nx),
                np.fix(nx),
                np.floor(nx),
                np.ceil(nx),
                np.modf(nx)[0],
            ],
            axis=1,
        )
        print(header_format % ("x ", "round", "rint", "trunc", "fix", "floor", "ceil", "frac"))
        print(numpy_tab)
        print("----------------------------------------------")
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
class TestRounding(unittest.TestCase):
    """Empty test case; test_rounding is attached to it via add_function_test."""
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
# Attach test_rounding to TestRounding for the devices reported by
# get_test_devices().
devices = get_test_devices()

add_function_test(
    TestRounding,
    "test_rounding",
    test_rounding,
    devices=devices,
)


if __name__ == "__main__":
    # When run as a script, clear the kernel cache before running the tests.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import unittest
|
|
17
|
+
from functools import partial
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
|
|
21
|
+
import warp as wp
|
|
22
|
+
from warp.tests.unittest_utils import *
|
|
23
|
+
from warp.utils import runlength_encode
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def test_runlength_encode_int(test, device, n):
    """Check runlength_encode on ``n`` sorted random integers against np.unique.

    Args:
        test: Test-case instance used for assertions.
        device: Device on which the Warp arrays are allocated.
        n: Number of input values (0 exercises the empty-input case).
    """
    # Fixed seed keeps the input identical from run to run.
    generator = np.random.default_rng(123)
    sorted_input = np.sort(generator.integers(-10, high=10, size=n, dtype=int))

    # Reference answer: distinct values of the sorted input and their counts.
    expected_values, expected_counts = np.unique(sorted_input, return_counts=True)

    values = wp.array(sorted_input, device=device, dtype=int)

    # Output buffers are sized like the input.
    out_values = wp.empty_like(values)
    out_counts = wp.empty_like(values)

    run_count = runlength_encode(values, out_values, out_counts)

    test.assertEqual(run_count, len(expected_values))

    # Only the first ``run_count`` entries of each output are compared.
    head = slice(None, run_count)
    assert_np_equal(out_values.numpy()[head], expected_values[head])
    assert_np_equal(out_counts.numpy()[head], expected_counts[head])
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def test_runlength_encode_error_insufficient_storage(test, device):
    """Check that runlength_encode rejects output arrays smaller than the input.

    Two cases are run on ``device``: an undersized ``run_values`` array and an
    undersized ``run_lengths`` array. The original hard-coded ``"cpu"`` for the
    second case, so it never ran on the device under test.

    Args:
        test: Test-case instance used for assertions.
        device: Device on which all arrays are allocated.
    """
    value_count = 123

    # (run_values size, run_lengths size): each case under-sizes exactly one output.
    size_cases = ((1, value_count), (value_count, 1))

    for run_values_size, run_lengths_size in size_cases:
        values = wp.zeros(value_count, dtype=int, device=device)
        run_values = wp.empty(run_values_size, dtype=int, device=device)
        run_lengths = wp.empty(run_lengths_size, dtype=int, device=device)
        with test.assertRaisesRegex(
            RuntimeError,
            r"Output array storage sizes must be at least equal to value_count$",
        ):
            runlength_encode(values, run_values, run_lengths)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def test_runlength_encode_error_dtypes_mismatch(test, device):
    """Check that runlength_encode rejects a run_values array whose dtype differs from values.

    Args:
        test: Test-case instance used for assertions.
        device: Device on which all arrays are allocated.
    """
    expected_message = r"values and run_values data types do not match$"

    int_input = wp.zeros(123, dtype=int, device=device)
    # Deliberately float while the input is int.
    float_run_values = wp.empty(123, dtype=float, device=device)
    lengths_out = wp.empty_like(int_input, device=device)

    with test.assertRaisesRegex(RuntimeError, expected_message):
        runlength_encode(int_input, float_run_values, lengths_out)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def test_runlength_encode_error_run_length_unsupported_dtype(test, device):
    """Check that runlength_encode rejects a run_lengths array that is not int32.

    Args:
        test: Test-case instance used for assertions.
        device: Device on which all arrays are allocated.
    """
    expected_message = r"run_lengths array must be of type int32$"

    int_input = wp.zeros(123, dtype=int, device=device)
    values_out = wp.empty(123, dtype=int, device=device)
    # Deliberately float: this is the argument the error message refers to.
    float_run_lengths = wp.empty(123, dtype=float, device=device)

    with test.assertRaisesRegex(RuntimeError, expected_message):
        runlength_encode(int_input, values_out, float_run_lengths)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def test_runlength_encode_error_run_count_unsupported_dtype(test, device):
    """Check that runlength_encode rejects a run_count array that is not int32.

    Args:
        test: Test-case instance used for assertions.
        device: Device on which all arrays are allocated.
    """
    expected_message = r"run_count array must be of type int32$"

    int_input = wp.zeros(123, dtype=int, device=device)
    values_out = wp.empty_like(int_input, device=device)
    lengths_out = wp.empty_like(int_input, device=device)
    # Deliberately float: this is the argument the error message refers to.
    float_run_count = wp.empty(shape=(1,), dtype=float, device=device)

    with test.assertRaisesRegex(RuntimeError, expected_message):
        runlength_encode(int_input, values_out, lengths_out, run_count=float_run_count)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def test_runlength_encode_error_unsupported_dtype(test, device):
    """Check that runlength_encode raises "Unsupported data type" for float input values.

    Args:
        test: Test-case instance used for assertions.
        device: Device on which all arrays are allocated.
    """
    expected_message = r"Unsupported data type$"

    # values and run_values share the float dtype; run_lengths is int.
    float_input = wp.zeros(123, dtype=float, device=device)
    float_values_out = wp.empty(123, dtype=float, device=device)
    lengths_out = wp.empty(123, dtype=int, device=device)

    with test.assertRaisesRegex(RuntimeError, expected_message):
        runlength_encode(float_input, float_values_out, lengths_out)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
# Passed as the ``devices`` argument to every add_function_test() call below.
devices = get_test_devices()
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class TestRunlengthEncode(unittest.TestCase):
    """Test case for runlength_encode.

    The two methods defined here mix "cpu" and "cuda:0" arrays and are skipped
    unless CUDA is available; further tests are attached with add_function_test.
    """

    @unittest.skipUnless(wp.is_cuda_available(), "Requires CUDA")
    def test_runlength_encode_error_devices_mismatch(self):
        # ``values`` always lives on the CPU; each case puts at least one of
        # the two output arrays on "cuda:0".
        expected_message = r"Array storage devices do not match$"
        output_devices = (
            ("cuda:0", "cuda:0"),
            ("cpu", "cuda:0"),
            ("cuda:0", "cpu"),
        )

        for run_values_device, run_lengths_device in output_devices:
            values = wp.zeros(123, dtype=int, device="cpu")
            run_values = wp.empty_like(values, device=run_values_device)
            run_lengths = wp.empty_like(values, device=run_lengths_device)
            with self.assertRaisesRegex(RuntimeError, expected_message):
                runlength_encode(values, run_values, run_lengths)

    @unittest.skipUnless(wp.is_cuda_available(), "Requires CUDA")
    def test_runlength_encode_error_run_count_device_mismatch(self):
        # The three main arrays agree on "cpu"; only ``run_count`` is on "cuda:0".
        expected_message = r"run_count storage device does not match other arrays$"

        cpu_values = wp.zeros(123, dtype=int, device="cpu")
        cpu_run_values = wp.empty_like(cpu_values, device="cpu")
        cpu_run_lengths = wp.empty_like(cpu_values, device="cpu")
        cuda_run_count = wp.empty(shape=(1,), dtype=int, device="cuda:0")

        with self.assertRaisesRegex(RuntimeError, expected_message):
            runlength_encode(cpu_values, cpu_run_values, cpu_run_lengths, run_count=cuda_run_count)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
# The same test function is registered twice with different sizes; n=0 covers empty input.
add_function_test(
    TestRunlengthEncode, "test_runlength_encode_int", partial(test_runlength_encode_int, n=100), devices=devices
)
add_function_test(
    TestRunlengthEncode, "test_runlength_encode_empty", partial(test_runlength_encode_int, n=0), devices=devices
)

# Error-path tests, registered in the same order as before.
for _test_name, _test_func in (
    ("test_runlength_encode_error_insufficient_storage", test_runlength_encode_error_insufficient_storage),
    ("test_runlength_encode_error_dtypes_mismatch", test_runlength_encode_error_dtypes_mismatch),
    (
        "test_runlength_encode_error_run_length_unsupported_dtype",
        test_runlength_encode_error_run_length_unsupported_dtype,
    ),
    (
        "test_runlength_encode_error_run_count_unsupported_dtype",
        test_runlength_encode_error_run_count_unsupported_dtype,
    ),
    ("test_runlength_encode_error_unsupported_dtype", test_runlength_encode_error_unsupported_dtype),
):
    add_function_test(TestRunlengthEncode, _test_name, _test_func, devices=devices)
del _test_name, _test_func


if __name__ == "__main__":
    # Calls wp.clear_kernel_cache() before handing control to the unittest runner.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import unittest
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
import warp as wp
|
|
21
|
+
from warp.tests.unittest_utils import *
|
|
22
|
+
|
|
23
|
+
# NumPy scalar types used to parameterise the tests, grouped by kind.
np_signed_int_types = [np.int8, np.int16, np.int32, np.int64, np.byte]
np_unsigned_int_types = [np.uint8, np.uint16, np.uint32, np.uint64, np.ubyte]
np_float_types = [np.float16, np.float32, np.float64]

# Combined lists: signed before unsigned, integers before floats.
np_int_types = [*np_signed_int_types, *np_unsigned_int_types]
np_scalar_types = [*np_int_types, *np_float_types]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def test_py_arithmetic_ops(test, device, dtype):
    """Check unary and binary arithmetic operators on Warp scalars at Python scope.

    Args:
        test: The ``unittest.TestCase`` instance used for assertions.
        device: Unused by this test; part of the test-function signature.
        dtype: NumPy scalar type, mapped to the Warp scalar type under test.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def expected(value):
        # Integer types: cast to the correct integer type to simulate wrapping.
        return wptype._type_(value).value if wptype in wp.types.int_types else value

    one = wptype(1)
    test.assertAlmostEqual(+one, expected(1))
    test.assertAlmostEqual(-one, expected(-1))
    test.assertAlmostEqual(one + wptype(5), expected(6))
    test.assertAlmostEqual(one - wptype(5), expected(-4))
    test.assertAlmostEqual(one % wptype(2), expected(1))

    two = wptype(2)
    test.assertAlmostEqual(two * wptype(2), expected(4))
    test.assertAlmostEqual(wptype(2) * two, expected(4))
    test.assertAlmostEqual(two / wptype(2), expected(1))
    test.assertAlmostEqual(wptype(24) / two, expected(12))
    test.assertAlmostEqual(two % wptype(2), expected(0))
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def test_py_math_ops(test, device, dtype):
    """Check a few Warp math built-ins called on Warp scalars at Python scope.

    ``wp.abs`` is checked for every scalar type; ``wp.sin`` and ``wp.radians`` are only
    checked when ``dtype`` is one of ``np_float_types``.

    Args:
        test: The ``unittest.TestCase`` instance used for assertions.
        device: Unused by this test; part of the test-function signature.
        dtype: NumPy scalar type, mapped to the Warp scalar type under test.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    a = wptype(1)
    test.assertAlmostEqual(wp.abs(a), 1)

    if dtype in np_float_types:
        # Reduced precision (places=3 / places=5) is used for these comparisons.
        test.assertAlmostEqual(wp.sin(a), 0.84147098480789650488, places=3)
        test.assertAlmostEqual(wp.radians(a), 0.01745329251994329577, places=5)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# NOTE(review): the registrations below pass ``devices=None``, so this value is not
# referenced by the code that follows in this file -- confirm whether it is still needed.
devices = get_test_devices()
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class TestScalarOps(unittest.TestCase):
    """Empty test case; its test methods are attached below via ``add_function_test``."""
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
# Register one arithmetic test and one math test per NumPy scalar type.
for dtype in np_scalar_types:
    _registrations = (
        (f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops),
        (f"test_py_math_ops_{dtype.__name__}", test_py_math_ops),
    )
    for _test_name, _test_func in _registrations:
        add_function_test(TestScalarOps, _test_name, _test_func, devices=None, dtype=dtype)


if __name__ == "__main__":
    # Calls wp.clear_kernel_cache() before handing control to the unittest runner.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=True)
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import unittest
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
|
|
22
|
+
import warp as wp
|
|
23
|
+
from warp.tests.unittest_utils import *
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class TestData:
    """One smoothstep test case: inputs, expected output and optional expected adjoints.

    ``a``, ``b`` and ``t`` are the arguments passed to ``wp.smoothstep``; ``expected`` is the
    value the kernel output is compared against. The ``expected_adj_*`` fields hold the
    gradients expected after the backward pass and default to ``None``.
    """

    # Despite the ``Test`` prefix this is a plain data record, not a test class. Opting out of
    # pytest collection avoids a warning about a ``Test*`` class that defines ``__init__``.
    # The attribute is deliberately unannotated so ``@dataclass`` does not treat it as a field.
    __test__ = False

    a: Any
    b: Any
    t: float
    expected: Any
    expected_adj_a: Any = None
    expected_adj_b: Any = None
    expected_adj_t: float = None

    def check_backwards(self):
        """Return ``True`` only when all three expected adjoints are set (not ``None``)."""
        return self.expected_adj_a is not None and self.expected_adj_b is not None and self.expected_adj_t is not None
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# Smoothstep reference cases, keyed by the Warp data type the kernel is built for.
TEST_DATA = {
    wp.float32: (
        TestData(
            a=1.0,
            b=2.0,
            t=1.5,
            expected=0.5,
            expected_adj_a=-0.75,
            expected_adj_b=-0.75,
            expected_adj_t=1.5,
        ),
        TestData(
            a=-1.0,
            b=2.0,
            t=-0.25,
            expected=0.15625,
            expected_adj_a=-0.28125,
            expected_adj_b=-0.09375,
            expected_adj_t=0.375,
        ),
        # t above b: output 1.0 with all-zero adjoints.
        TestData(
            a=0.0,
            b=1.0,
            t=9.9,
            expected=1.0,
            expected_adj_a=0.0,
            expected_adj_b=0.0,
            expected_adj_t=0.0,
        ),
        # t below a: output 0.0 with all-zero adjoints.
        TestData(
            a=0.0,
            b=1.0,
            t=-9.9,
            expected=0.0,
            expected_adj_a=0.0,
            expected_adj_b=0.0,
            expected_adj_t=0.0,
        ),
    ),
}
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def test_smoothstep(test, device):
    """Launch a one-element ``wp.smoothstep`` kernel for every case in ``TEST_DATA``.

    The forward result is compared against ``expected``. When a case provides all expected
    adjoints, the tape is run backwards and the gradients of ``a``, ``b`` and ``t`` are
    compared as well.

    Args:
        test: The ``unittest.TestCase`` instance (not used directly by this function).
        device: Device on which the arrays are allocated and the kernel is launched.
    """

    # Kept verbatim: this function is handed to wp.Kernel below.
    def make_kernel_fn(data_type):
        def fn(
            a: wp.array(dtype=data_type),
            b: wp.array(dtype=data_type),
            t: wp.array(dtype=float),
            out: wp.array(dtype=data_type),
        ):
            out[0] = wp.smoothstep(a[0], b[0], t[0])

        return fn

    def single_value_array(value, dtype):
        # One-element array that takes part in gradient tracking.
        return wp.array([value], dtype=dtype, device=device, requires_grad=True)

    for data_type, cases in TEST_DATA.items():
        kernel = wp.Kernel(
            func=make_kernel_fn(data_type),
            key=f"test_smoothstep{data_type.__name__}_kernel",
        )

        for case in cases:
            a_arr = single_value_array(case.a, data_type)
            b_arr = single_value_array(case.b, data_type)
            t_arr = single_value_array(case.t, float)
            out_arr = wp.array(
                [0] * wp.types.type_length(data_type), dtype=data_type, device=device, requires_grad=True
            )

            with wp.Tape() as tape:
                wp.launch(kernel, dim=1, inputs=[a_arr, b_arr, t_arr, out_arr], device=device)

            assert_np_equal(out_arr.numpy(), np.array([case.expected]), tol=1e-6)

            if not case.check_backwards():
                continue

            tape.backward(out_arr)

            gradient_checks = (
                (a_arr, case.expected_adj_a),
                (b_arr, case.expected_adj_b),
                (t_arr, case.expected_adj_t),
            )
            for source_array, expected_adjoint in gradient_checks:
                assert_np_equal(tape.gradients[source_array].numpy(), np.array([expected_adjoint]), tol=1e-6)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
# Devices passed as ``devices=`` to the ``add_function_test`` registration below.
# NOTE(review): presumably the devices available to this test run -- confirm in warp.tests.unittest_utils.
devices = get_test_devices()
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class TestSmoothstep(unittest.TestCase):
    """Empty test case; its test method is attached below via ``add_function_test``."""
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
# Attach the smoothstep test to the test case for each entry in ``devices``.
add_function_test(
    TestSmoothstep,
    "test_smoothstep",
    test_smoothstep,
    devices=devices,
)


if __name__ == "__main__":
    # Calls wp.clear_kernel_cache() before handing control to the unittest runner.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
|