warp-lang 1.6.1__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +21 -7
- warp/autograd.py +14 -6
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +424 -6
- warp/build_dll.py +20 -20
- warp/builtins.py +467 -368
- warp/codegen.py +193 -125
- warp/config.py +56 -12
- warp/constants.py +14 -6
- warp/context.py +524 -277
- warp/dlpack.py +22 -12
- warp/examples/__init__.py +14 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_api.py +14 -6
- warp/examples/benchmarks/benchmark_cloth.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
- warp/examples/benchmarks/benchmark_gemm.py +82 -48
- warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
- warp/examples/benchmarks/benchmark_launches.py +14 -6
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +14 -6
- warp/examples/core/example_cupy.py +14 -6
- warp/examples/core/example_dem.py +14 -6
- warp/examples/core/example_fluid.py +14 -6
- warp/examples/core/example_graph_capture.py +14 -6
- warp/examples/core/example_marching_cubes.py +14 -6
- warp/examples/core/example_mesh.py +14 -6
- warp/examples/core/example_mesh_intersect.py +14 -6
- warp/examples/core/example_nvdb.py +14 -6
- warp/examples/core/example_raycast.py +14 -6
- warp/examples/core/example_raymarch.py +14 -6
- warp/examples/core/example_render_opengl.py +14 -6
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +14 -6
- warp/examples/core/example_torch.py +14 -6
- warp/examples/core/example_wave.py +14 -6
- warp/examples/fem/example_adaptive_grid.py +14 -6
- warp/examples/fem/example_apic_fluid.py +15 -7
- warp/examples/fem/example_burgers.py +16 -8
- warp/examples/fem/example_convection_diffusion.py +14 -6
- warp/examples/fem/example_convection_diffusion_dg.py +14 -6
- warp/examples/fem/example_deformed_geometry.py +15 -7
- warp/examples/fem/example_diffusion.py +14 -6
- warp/examples/fem/example_diffusion_3d.py +14 -6
- warp/examples/fem/example_diffusion_mgpu.py +14 -6
- warp/examples/fem/example_distortion_energy.py +15 -7
- warp/examples/fem/example_magnetostatics.py +20 -12
- warp/examples/fem/example_mixed_elasticity.py +14 -6
- warp/examples/fem/example_navier_stokes.py +14 -6
- warp/examples/fem/example_nonconforming_contact.py +14 -6
- warp/examples/fem/example_stokes.py +14 -6
- warp/examples/fem/example_stokes_transfer.py +14 -6
- warp/examples/fem/example_streamlines.py +14 -6
- warp/examples/fem/utils.py +24 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +14 -6
- warp/examples/optim/example_cloth_throw.py +14 -6
- warp/examples/optim/example_diffray.py +14 -6
- warp/examples/optim/example_drone.py +14 -6
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +14 -6
- warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
- warp/examples/optim/example_softbody_properties.py +14 -6
- warp/examples/optim/example_spring_cage.py +14 -6
- warp/examples/optim/example_trajectory.py +14 -6
- warp/examples/sim/example_cartpole.py +14 -6
- warp/examples/sim/example_cloth.py +14 -6
- warp/examples/sim/example_cloth_self_contact.py +14 -6
- warp/examples/sim/example_granular.py +14 -6
- warp/examples/sim/example_granular_collision_sdf.py +14 -6
- warp/examples/sim/example_jacobian_ik.py +14 -6
- warp/examples/sim/example_particle_chain.py +14 -6
- warp/examples/sim/example_quadruped.py +14 -6
- warp/examples/sim/example_rigid_chain.py +14 -6
- warp/examples/sim/example_rigid_contact.py +14 -6
- warp/examples/sim/example_rigid_force.py +14 -6
- warp/examples/sim/example_rigid_gyroscopic.py +14 -6
- warp/examples/sim/example_rigid_soft_contact.py +14 -6
- warp/examples/sim/example_soft_body.py +14 -6
- warp/examples/tile/example_tile_cholesky.py +14 -6
- warp/examples/tile/example_tile_convolution.py +14 -6
- warp/examples/tile/example_tile_fft.py +14 -6
- warp/examples/tile/example_tile_filtering.py +14 -6
- warp/examples/tile/example_tile_matmul.py +16 -10
- warp/examples/tile/example_tile_mlp.py +14 -6
- warp/examples/tile/example_tile_nbody.py +14 -6
- warp/examples/tile/example_tile_walker.py +14 -6
- warp/fabric.py +15 -0
- warp/fem/__init__.py +26 -1
- warp/fem/adaptivity.py +19 -4
- warp/fem/cache.py +15 -0
- warp/fem/dirichlet.py +15 -0
- warp/fem/domain.py +15 -0
- warp/fem/field/__init__.py +15 -0
- warp/fem/field/field.py +15 -0
- warp/fem/field/nodal_field.py +37 -68
- warp/fem/field/restriction.py +15 -0
- warp/fem/field/virtual.py +77 -23
- warp/fem/geometry/__init__.py +15 -0
- warp/fem/geometry/adaptive_nanogrid.py +24 -10
- warp/fem/geometry/closest_point.py +16 -1
- warp/fem/geometry/deformed_geometry.py +20 -2
- warp/fem/geometry/element.py +15 -0
- warp/fem/geometry/geometry.py +20 -0
- warp/fem/geometry/grid_2d.py +27 -12
- warp/fem/geometry/grid_3d.py +27 -15
- warp/fem/geometry/hexmesh.py +20 -7
- warp/fem/geometry/nanogrid.py +24 -11
- warp/fem/geometry/partition.py +15 -0
- warp/fem/geometry/quadmesh.py +28 -13
- warp/fem/geometry/tetmesh.py +18 -4
- warp/fem/geometry/trimesh.py +18 -8
- warp/fem/integrate.py +277 -93
- warp/fem/linalg.py +20 -5
- warp/fem/operator.py +15 -0
- warp/fem/polynomial.py +15 -0
- warp/fem/quadrature/__init__.py +15 -0
- warp/fem/quadrature/pic_quadrature.py +52 -22
- warp/fem/quadrature/quadrature.py +209 -25
- warp/fem/space/__init__.py +16 -1
- warp/fem/space/basis_function_space.py +19 -2
- warp/fem/space/basis_space.py +40 -18
- warp/fem/space/dof_mapper.py +15 -0
- warp/fem/space/function_space.py +15 -0
- warp/fem/space/grid_2d_function_space.py +15 -0
- warp/fem/space/grid_3d_function_space.py +15 -0
- warp/fem/space/hexmesh_function_space.py +17 -2
- warp/fem/space/nanogrid_function_space.py +15 -0
- warp/fem/space/partition.py +21 -2
- warp/fem/space/quadmesh_function_space.py +23 -8
- warp/fem/space/restriction.py +15 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +38 -23
- warp/fem/space/shape/shape_function.py +15 -0
- warp/fem/space/shape/square_shape_function.py +27 -12
- warp/fem/space/shape/tet_shape_function.py +15 -0
- warp/fem/space/shape/triangle_shape_function.py +16 -1
- warp/fem/space/tetmesh_function_space.py +18 -3
- warp/fem/space/topology.py +15 -0
- warp/fem/space/trimesh_function_space.py +17 -2
- warp/fem/types.py +15 -0
- warp/fem/utils.py +27 -6
- warp/jax.py +28 -7
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +103 -6
- warp/native/array.h +28 -6
- warp/native/builtin.h +44 -9
- warp/native/bvh.cpp +18 -7
- warp/native/bvh.cu +57 -20
- warp/native/bvh.h +17 -7
- warp/native/clang/clang.cpp +45 -9
- warp/native/coloring.cpp +15 -6
- warp/native/crt.cpp +15 -6
- warp/native/crt.h +15 -6
- warp/native/cuda_crt.h +15 -6
- warp/native/cuda_util.cpp +29 -6
- warp/native/cuda_util.h +17 -6
- warp/native/error.cpp +15 -6
- warp/native/error.h +15 -6
- warp/native/exports.h +85 -63
- warp/native/fabric.h +15 -6
- warp/native/hashgrid.cpp +15 -6
- warp/native/hashgrid.cu +15 -6
- warp/native/hashgrid.h +15 -6
- warp/native/initializer_array.h +15 -6
- warp/native/intersect.h +41 -32
- warp/native/intersect_adj.h +48 -39
- warp/native/intersect_tri.h +17 -0
- warp/native/marching.cpp +16 -0
- warp/native/marching.cu +16 -7
- warp/native/marching.h +17 -0
- warp/native/mat.h +528 -15
- warp/native/mathdx.cpp +15 -6
- warp/native/matnn.h +15 -6
- warp/native/mesh.cpp +15 -6
- warp/native/mesh.cu +15 -6
- warp/native/mesh.h +25 -16
- warp/native/noise.h +15 -6
- warp/native/quat.h +114 -17
- warp/native/rand.h +21 -6
- warp/native/range.h +15 -6
- warp/native/reduce.cpp +15 -6
- warp/native/reduce.cu +15 -6
- warp/native/runlength_encode.cpp +15 -6
- warp/native/runlength_encode.cu +15 -6
- warp/native/scan.cpp +15 -6
- warp/native/scan.cu +15 -6
- warp/native/scan.h +15 -6
- warp/native/solid_angle.h +17 -0
- warp/native/sort.cpp +137 -65
- warp/native/sort.cu +167 -21
- warp/native/sort.h +23 -7
- warp/native/sparse.cpp +58 -28
- warp/native/sparse.cu +67 -23
- warp/native/spatial.h +15 -6
- warp/native/svd.h +131 -6
- warp/native/temp_buffer.h +15 -6
- warp/native/tile.h +316 -111
- warp/native/tile_reduce.h +61 -9
- warp/native/vec.h +83 -13
- warp/native/volume.cpp +100 -119
- warp/native/volume.cu +15 -6
- warp/native/volume.h +15 -6
- warp/native/volume_builder.cu +40 -16
- warp/native/volume_builder.h +21 -6
- warp/native/volume_impl.h +15 -6
- warp/native/warp.cpp +20 -12
- warp/native/warp.cu +114 -16
- warp/native/warp.h +34 -16
- warp/optim/__init__.py +14 -6
- warp/optim/adam.py +14 -6
- warp/optim/linear.py +25 -10
- warp/optim/sgd.py +14 -6
- warp/paddle.py +14 -6
- warp/render/__init__.py +14 -6
- warp/render/render_opengl.py +14 -6
- warp/render/render_usd.py +14 -6
- warp/render/utils.py +14 -6
- warp/sim/__init__.py +14 -7
- warp/sim/articulation.py +18 -10
- warp/sim/collide.py +35 -16
- warp/sim/graph_coloring.py +14 -6
- warp/sim/import_mjcf.py +463 -162
- warp/sim/import_snu.py +14 -7
- warp/sim/import_urdf.py +46 -18
- warp/sim/import_usd.py +14 -7
- warp/sim/inertia.py +14 -6
- warp/sim/integrator.py +14 -6
- warp/sim/integrator_euler.py +19 -11
- warp/sim/integrator_featherstone.py +17 -16
- warp/sim/integrator_vbd.py +222 -8
- warp/sim/integrator_xpbd.py +19 -11
- warp/sim/model.py +56 -19
- warp/sim/particles.py +14 -6
- warp/sim/render.py +14 -6
- warp/sim/utils.py +17 -2
- warp/sparse.py +657 -555
- warp/stubs.py +231 -19
- warp/tape.py +14 -6
- warp/tests/aux_test_class_kernel.py +14 -6
- warp/tests/aux_test_compile_consts_dummy.py +14 -6
- warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
- warp/tests/aux_test_dependent.py +14 -6
- warp/tests/aux_test_grad_customs.py +14 -6
- warp/tests/aux_test_instancing_gc.py +14 -6
- warp/tests/aux_test_module_unload.py +14 -6
- warp/tests/aux_test_name_clash1.py +14 -6
- warp/tests/aux_test_name_clash2.py +14 -6
- warp/tests/aux_test_unresolved_func.py +14 -6
- warp/tests/aux_test_unresolved_symbol.py +14 -6
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_async.py → cuda/test_async.py} +14 -6
- warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
- warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
- warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
- warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
- warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
- warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
- warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
- warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
- warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
- warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
- warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
- warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
- warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
- warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
- warp/tests/run_coverage_serial.py +14 -6
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
- warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
- warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
- warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
- warp/tests/{test_model.py → sim/test_model.py} +55 -7
- warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +14 -6
- warp/tests/test_arithmetic.py +14 -6
- warp/tests/test_array.py +14 -6
- warp/tests/test_array_reduce.py +14 -6
- warp/tests/test_assert.py +14 -6
- warp/tests/test_atomic.py +14 -6
- warp/tests/test_bool.py +15 -7
- warp/tests/test_builtins_resolution.py +14 -6
- warp/tests/test_closest_point_edge_edge.py +14 -6
- warp/tests/test_codegen.py +14 -6
- warp/tests/test_codegen_instancing.py +14 -6
- warp/tests/test_compile_consts.py +14 -6
- warp/tests/test_conditional.py +14 -6
- warp/tests/test_context.py +14 -6
- warp/tests/test_copy.py +14 -6
- warp/tests/test_ctypes.py +14 -6
- warp/tests/test_dense.py +14 -6
- warp/tests/test_devices.py +14 -6
- warp/tests/test_examples.py +42 -42
- warp/tests/test_fabricarray.py +14 -6
- warp/tests/test_fast_math.py +14 -6
- warp/tests/test_fem.py +37 -10
- warp/tests/test_fp16.py +14 -6
- warp/tests/test_func.py +14 -6
- warp/tests/test_future_annotations.py +14 -6
- warp/tests/test_generics.py +14 -6
- warp/tests/test_grad.py +14 -6
- warp/tests/test_grad_customs.py +14 -6
- warp/tests/test_grad_debug.py +14 -6
- warp/tests/test_implicit_init.py +14 -6
- warp/tests/test_import.py +14 -6
- warp/tests/test_indexedarray.py +14 -6
- warp/tests/test_intersect.py +14 -6
- warp/tests/test_iter.py +14 -6
- warp/tests/test_large.py +14 -6
- warp/tests/test_launch.py +14 -6
- warp/tests/test_lerp.py +14 -6
- warp/tests/test_linear_solvers.py +15 -11
- warp/tests/test_lvalue.py +14 -6
- warp/tests/test_mat.py +247 -85
- warp/tests/test_mat_lite.py +14 -6
- warp/tests/test_mat_scalar_ops.py +18 -10
- warp/tests/test_math.py +14 -6
- warp/tests/test_mlp.py +14 -6
- warp/tests/test_module_hashing.py +14 -6
- warp/tests/test_modules_lite.py +14 -6
- warp/tests/test_noise.py +14 -6
- warp/tests/test_operators.py +14 -6
- warp/tests/test_options.py +14 -6
- warp/tests/test_overwrite.py +15 -60
- warp/tests/test_print.py +14 -6
- warp/tests/test_quat.py +81 -52
- warp/tests/test_rand.py +58 -43
- warp/tests/test_reload.py +14 -6
- warp/tests/test_rounding.py +14 -6
- warp/tests/test_runlength_encode.py +14 -6
- warp/tests/test_scalar_ops.py +14 -6
- warp/tests/test_smoothstep.py +14 -6
- warp/tests/test_snippet.py +15 -0
- warp/tests/test_sparse.py +61 -12
- warp/tests/test_spatial.py +89 -6
- warp/tests/test_special_values.py +14 -6
- warp/tests/test_static.py +15 -7
- warp/tests/test_struct.py +14 -6
- warp/tests/test_tape.py +14 -6
- warp/tests/test_transient_module.py +14 -6
- warp/tests/test_triangle_closest_point.py +14 -6
- warp/tests/test_types.py +14 -6
- warp/tests/test_utils.py +98 -10
- warp/tests/test_vec.py +60 -40
- warp/tests/test_vec_lite.py +14 -6
- warp/tests/test_vec_scalar_ops.py +14 -6
- warp/tests/test_verify_fp.py +14 -6
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
- warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
- warp/tests/unittest_serial.py +15 -6
- warp/tests/unittest_suites.py +59 -65
- warp/tests/unittest_utils.py +16 -7
- warp/tests/walkthrough_debug.py +14 -6
- warp/thirdparty/unittest_parallel.py +15 -8
- warp/torch.py +14 -6
- warp/types.py +124 -664
- warp/utils.py +151 -78
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp/examples/optim/example_walker.py +0 -309
- warp/native/cutlass_gemm.cpp +0 -34
- warp/native/cutlass_gemm.cu +0 -373
- warp/tests/test_matmul.py +0 -503
- warp/tests/test_matmul_lite.py +0 -403
- warp/tests/test_vbd.py +0 -378
- warp/tests/unused_test_misc.py +0 -69
- warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
- warp_lang-1.6.1.dist-info/RECORD +0 -419
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2024 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
import unittest
|
|
9
17
|
|
|
@@ -12,8 +20,6 @@ import numpy as np
|
|
|
12
20
|
import warp as wp
|
|
13
21
|
from warp.tests.unittest_utils import *
|
|
14
22
|
|
|
15
|
-
wp.init() # For wp.context.runtime.core.is_mathdx_enabled()
|
|
16
|
-
|
|
17
23
|
TILE_M = wp.constant(8)
|
|
18
24
|
TILE_N = wp.constant(4)
|
|
19
25
|
TILE_K = wp.constant(8)
|
|
@@ -208,7 +214,6 @@ def test_tile_binary_map(test, device):
|
|
|
208
214
|
assert_np_equal(B_wp.grad.numpy(), B_grad)
|
|
209
215
|
|
|
210
216
|
|
|
211
|
-
@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
|
|
212
217
|
def test_tile_grouped_gemm(test, device):
|
|
213
218
|
@wp.kernel
|
|
214
219
|
def tile_grouped_gemm(A: wp.array3d(dtype=float), B: wp.array3d(dtype=float), C: wp.array3d(dtype=float)):
|
|
@@ -248,60 +253,62 @@ def test_tile_grouped_gemm(test, device):
|
|
|
248
253
|
assert_np_equal(C_wp.numpy(), C, 1e-6)
|
|
249
254
|
|
|
250
255
|
|
|
251
|
-
|
|
252
|
-
def
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
256
|
+
def test_tile_gemm(dtype):
|
|
257
|
+
def test(test, device):
|
|
258
|
+
@wp.kernel
|
|
259
|
+
def tile_gemm(A: wp.array2d(dtype=dtype), B: wp.array2d(dtype=dtype), C: wp.array2d(dtype=dtype)):
|
|
260
|
+
# output tile index
|
|
261
|
+
i, j = wp.tid()
|
|
257
262
|
|
|
258
|
-
|
|
263
|
+
sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=dtype)
|
|
259
264
|
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
265
|
+
M = A.shape[0]
|
|
266
|
+
N = B.shape[1]
|
|
267
|
+
K = A.shape[1]
|
|
263
268
|
|
|
264
|
-
|
|
269
|
+
count = int(K / TILE_K)
|
|
265
270
|
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
271
|
+
for k in range(0, count):
|
|
272
|
+
a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
|
|
273
|
+
b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
|
|
269
274
|
|
|
270
|
-
|
|
271
|
-
|
|
275
|
+
# sum += a*b
|
|
276
|
+
wp.tile_matmul(a, b, sum)
|
|
272
277
|
|
|
273
|
-
|
|
278
|
+
wp.tile_store(C, sum, offset=(i * TILE_M, j * TILE_N))
|
|
274
279
|
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
280
|
+
M = TILE_M * 7
|
|
281
|
+
K = TILE_K * 6
|
|
282
|
+
N = TILE_N * 5
|
|
278
283
|
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
284
|
+
rng = np.random.default_rng(42)
|
|
285
|
+
A = rng.random((M, K), dtype=float).astype(wp.dtype_to_numpy(dtype))
|
|
286
|
+
B = rng.random((K, N), dtype=float).astype(wp.dtype_to_numpy(dtype))
|
|
287
|
+
C = np.zeros((M, N), dtype=float).astype(wp.dtype_to_numpy(dtype))
|
|
283
288
|
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
289
|
+
A_wp = wp.array(A, requires_grad=True, device=device)
|
|
290
|
+
B_wp = wp.array(B, requires_grad=True, device=device)
|
|
291
|
+
C_wp = wp.array(C, requires_grad=True, device=device)
|
|
287
292
|
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
293
|
+
with wp.Tape() as tape:
|
|
294
|
+
wp.launch_tiled(
|
|
295
|
+
tile_gemm,
|
|
296
|
+
dim=(int(M / TILE_M), int(N / TILE_N)),
|
|
297
|
+
inputs=[A_wp, B_wp, C_wp],
|
|
298
|
+
block_dim=TILE_DIM,
|
|
299
|
+
device=device,
|
|
300
|
+
)
|
|
296
301
|
|
|
297
|
-
|
|
302
|
+
assert_np_equal(C_wp.numpy(), A @ B, tol=1.0e-1)
|
|
298
303
|
|
|
299
|
-
|
|
304
|
+
adj_C = np.ones_like(C)
|
|
300
305
|
|
|
301
|
-
|
|
306
|
+
tape.backward(grads={C_wp: wp.array(adj_C, device=device)})
|
|
302
307
|
|
|
303
|
-
|
|
304
|
-
|
|
308
|
+
assert_np_equal(A_wp.grad.numpy(), adj_C @ B.T, tol=1.0e-1)
|
|
309
|
+
assert_np_equal(B_wp.grad.numpy(), A.T @ adj_C, 1.0e-1)
|
|
310
|
+
|
|
311
|
+
return test
|
|
305
312
|
|
|
306
313
|
|
|
307
314
|
@wp.kernel
|
|
@@ -542,7 +549,6 @@ def test_tile_transpose(test, device):
|
|
|
542
549
|
assert_np_equal(output.numpy(), input.numpy().T)
|
|
543
550
|
|
|
544
551
|
|
|
545
|
-
@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
|
|
546
552
|
def test_tile_transpose_matmul(test, device):
|
|
547
553
|
@wp.kernel
|
|
548
554
|
def test_tile_transpose_matmul_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
|
|
@@ -564,9 +570,36 @@ def test_tile_transpose_matmul(test, device):
|
|
|
564
570
|
|
|
565
571
|
|
|
566
572
|
@wp.kernel
|
|
567
|
-
def
|
|
573
|
+
def test_tile_broadcast_add_1d_kernel(
|
|
574
|
+
input_a: wp.array(dtype=float), input_b: wp.array(dtype=float), output: wp.array(dtype=float)
|
|
575
|
+
):
|
|
576
|
+
a = wp.tile_load(input_a, shape=(10,))
|
|
577
|
+
b = wp.tile_load(input_b, shape=(1,))
|
|
578
|
+
|
|
579
|
+
c = wp.tile_broadcast(b, shape=(10,))
|
|
580
|
+
d = a + c
|
|
581
|
+
|
|
582
|
+
wp.tile_store(output, d)
|
|
583
|
+
|
|
584
|
+
|
|
585
|
+
def test_tile_broadcast_add_1d(test, device):
|
|
586
|
+
N = 10
|
|
587
|
+
|
|
588
|
+
# implicit 1-dim ([1], 1)
|
|
589
|
+
a = wp.array(np.arange(0, N, dtype=np.float32), device=device)
|
|
590
|
+
b = wp.array(np.ones(1, dtype=np.float32), device=device)
|
|
591
|
+
out = wp.zeros((N,), dtype=float, device=device)
|
|
592
|
+
|
|
593
|
+
wp.launch_tiled(test_tile_broadcast_add_1d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
|
|
594
|
+
|
|
595
|
+
assert_np_equal(out.numpy(), a.numpy() + b.numpy())
|
|
596
|
+
|
|
597
|
+
|
|
598
|
+
@wp.kernel
|
|
599
|
+
def test_tile_broadcast_add_2d_kernel(
|
|
568
600
|
input_a: wp.array2d(dtype=float), input_b: wp.array(dtype=float), output: wp.array2d(dtype=float)
|
|
569
601
|
):
|
|
602
|
+
# implicit 1-dim ([1], 10)
|
|
570
603
|
a = wp.tile_load(input_a, shape=(10, 10))
|
|
571
604
|
b = wp.tile_load(input_b, shape=10)
|
|
572
605
|
|
|
@@ -576,7 +609,7 @@ def test_tile_broadcast_add_kernel(
|
|
|
576
609
|
wp.tile_store(output, d)
|
|
577
610
|
|
|
578
611
|
|
|
579
|
-
def
|
|
612
|
+
def test_tile_broadcast_add_2d(test, device):
|
|
580
613
|
M = 10
|
|
581
614
|
N = 10
|
|
582
615
|
|
|
@@ -584,7 +617,62 @@ def test_tile_broadcast_add(test, device):
|
|
|
584
617
|
b = wp.array(np.arange(0, N, dtype=np.float32), device=device)
|
|
585
618
|
out = wp.zeros((M, N), dtype=float, device=device)
|
|
586
619
|
|
|
587
|
-
wp.launch_tiled(
|
|
620
|
+
wp.launch_tiled(test_tile_broadcast_add_2d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
|
|
621
|
+
|
|
622
|
+
assert_np_equal(out.numpy(), a.numpy() + b.numpy())
|
|
623
|
+
|
|
624
|
+
|
|
625
|
+
@wp.kernel
|
|
626
|
+
def test_tile_broadcast_add_3d_kernel(
|
|
627
|
+
input_a: wp.array3d(dtype=float), input_b: wp.array3d(dtype=float), output: wp.array3d(dtype=float)
|
|
628
|
+
):
|
|
629
|
+
a = wp.tile_load(input_a, shape=(4, 10, 12))
|
|
630
|
+
b = wp.tile_load(input_b, shape=(4, 10, 1))
|
|
631
|
+
|
|
632
|
+
c = wp.tile_broadcast(b, shape=(4, 10, 12))
|
|
633
|
+
d = a + c
|
|
634
|
+
|
|
635
|
+
wp.tile_store(output, d)
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
def test_tile_broadcast_add_3d(test, device):
|
|
639
|
+
M = 4
|
|
640
|
+
N = 10
|
|
641
|
+
O = 12
|
|
642
|
+
|
|
643
|
+
# explicit 1-dim (M, N, 1) to (M, N, O)
|
|
644
|
+
a = wp.array(np.ones((M, N, O), dtype=np.float32), device=device)
|
|
645
|
+
b = wp.array(np.arange(0, M * N, dtype=np.float32).reshape((M, N, 1)), device=device)
|
|
646
|
+
out = wp.zeros((M, N, O), dtype=float, device=device)
|
|
647
|
+
|
|
648
|
+
wp.launch_tiled(test_tile_broadcast_add_3d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
|
|
649
|
+
assert_np_equal(out.numpy(), a.numpy() + b.numpy())
|
|
650
|
+
|
|
651
|
+
|
|
652
|
+
@wp.kernel
|
|
653
|
+
def test_tile_broadcast_add_4d_kernel(
|
|
654
|
+
input_a: wp.array4d(dtype=float), input_b: wp.array4d(dtype=float), output: wp.array4d(dtype=float)
|
|
655
|
+
):
|
|
656
|
+
a = wp.tile_load(input_a, shape=(4, 10, 5, 6))
|
|
657
|
+
b = wp.tile_load(input_b, shape=(4, 1, 5, 1))
|
|
658
|
+
c = wp.tile_broadcast(b, shape=(4, 10, 5, 6))
|
|
659
|
+
d = a + c
|
|
660
|
+
|
|
661
|
+
wp.tile_store(output, d)
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
def test_tile_broadcast_add_4d(test, device):
|
|
665
|
+
M = 4
|
|
666
|
+
N = 10
|
|
667
|
+
O = 5
|
|
668
|
+
P = 6
|
|
669
|
+
|
|
670
|
+
# explicit 1-dims (M, 1, O, 1) to (M, N, O, P)
|
|
671
|
+
a = wp.array(np.ones((M, N, O, P), dtype=np.float32), device=device)
|
|
672
|
+
b = wp.array(np.arange(0, M * O, dtype=np.float32).reshape((M, 1, O, 1)), device=device)
|
|
673
|
+
out = wp.zeros((M, N, O, P), dtype=float, device=device)
|
|
674
|
+
|
|
675
|
+
wp.launch_tiled(test_tile_broadcast_add_4d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
|
|
588
676
|
|
|
589
677
|
assert_np_equal(out.numpy(), a.numpy() + b.numpy())
|
|
590
678
|
|
|
@@ -657,7 +745,7 @@ def test_tile_print(test, device):
|
|
|
657
745
|
wp.synchronize()
|
|
658
746
|
|
|
659
747
|
|
|
660
|
-
devices =
|
|
748
|
+
devices = get_test_devices()
|
|
661
749
|
|
|
662
750
|
|
|
663
751
|
class TestTile(unittest.TestCase):
|
|
@@ -669,15 +757,20 @@ add_function_test(TestTile, "test_tile_copy_2d", test_tile_copy_2d, devices=devi
|
|
|
669
757
|
add_function_test(TestTile, "test_tile_unary_map", test_tile_unary_map, devices=devices)
|
|
670
758
|
add_function_test(TestTile, "test_tile_binary_map", test_tile_binary_map, devices=devices)
|
|
671
759
|
add_function_test(TestTile, "test_tile_grouped_gemm", test_tile_grouped_gemm, devices=devices)
|
|
672
|
-
add_function_test(TestTile, "
|
|
760
|
+
add_function_test(TestTile, "test_tile_gemm_fp16", test_tile_gemm(wp.float16), devices=devices)
|
|
761
|
+
add_function_test(TestTile, "test_tile_gemm_fp32", test_tile_gemm(wp.float32), devices=devices)
|
|
762
|
+
add_function_test(TestTile, "test_tile_gemm_fp64", test_tile_gemm(wp.float64), devices=devices)
|
|
673
763
|
add_function_test(TestTile, "test_tile_transpose", test_tile_transpose, devices=devices)
|
|
674
764
|
add_function_test(TestTile, "test_tile_transpose_matmul", test_tile_transpose_matmul, devices=devices)
|
|
675
765
|
add_function_test(TestTile, "test_tile_operators", test_tile_operators, devices=devices)
|
|
676
|
-
add_function_test(TestTile, "test_tile_sum", test_tile_sum, devices=devices)
|
|
766
|
+
add_function_test(TestTile, "test_tile_sum", test_tile_sum, devices=devices, check_output=False)
|
|
677
767
|
add_function_test(TestTile, "test_tile_sum_launch", test_tile_sum_launch, devices=devices)
|
|
678
768
|
add_function_test(TestTile, "test_tile_extract", test_tile_extract, devices=devices)
|
|
679
769
|
add_function_test(TestTile, "test_tile_extract_repeated", test_tile_extract_repeated, devices=devices)
|
|
680
|
-
add_function_test(TestTile, "
|
|
770
|
+
add_function_test(TestTile, "test_tile_broadcast_add_1d", test_tile_broadcast_add_1d, devices=devices)
|
|
771
|
+
add_function_test(TestTile, "test_tile_broadcast_add_2d", test_tile_broadcast_add_2d, devices=devices)
|
|
772
|
+
add_function_test(TestTile, "test_tile_broadcast_add_3d", test_tile_broadcast_add_3d, devices=devices)
|
|
773
|
+
add_function_test(TestTile, "test_tile_broadcast_add_4d", test_tile_broadcast_add_4d, devices=devices)
|
|
681
774
|
add_function_test(TestTile, "test_tile_broadcast_grad", test_tile_broadcast_grad, devices=devices)
|
|
682
775
|
add_function_test(TestTile, "test_tile_len", test_tile_len, devices=devices)
|
|
683
776
|
add_function_test(TestTile, "test_tile_print", test_tile_print, devices=devices, check_output=False)
|
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2024 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
import unittest
|
|
9
17
|
|
|
@@ -368,7 +376,7 @@ def test_tile_load_fortran(test, device):
|
|
|
368
376
|
assert_array_equal(B_wp.grad, A_wp.grad)
|
|
369
377
|
|
|
370
378
|
|
|
371
|
-
devices =
|
|
379
|
+
devices = get_test_devices()
|
|
372
380
|
|
|
373
381
|
|
|
374
382
|
class TestTileLoad(unittest.TestCase):
|
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2024 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
import functools
|
|
9
17
|
import unittest
|
|
@@ -84,6 +92,7 @@ def tile_math_fft_kernel_vec2d(gx: wp.array2d(dtype=wp.vec2d), gy: wp.array2d(dt
|
|
|
84
92
|
wp.tile_store(gy, xy)
|
|
85
93
|
|
|
86
94
|
|
|
95
|
+
@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
|
|
87
96
|
def test_tile_math_fft(test, device, wp_dtype):
|
|
88
97
|
np_real_dtype = {wp.vec2f: np.float32, wp.vec2d: np.float64}[wp_dtype]
|
|
89
98
|
np_cplx_dtype = {wp.vec2f: np.complex64, wp.vec2d: np.complex128}[wp_dtype]
|
|
@@ -164,31 +173,33 @@ def test_tile_math_cholesky(test, device):
|
|
|
164
173
|
# TODO: implement and test backward pass
|
|
165
174
|
|
|
166
175
|
|
|
167
|
-
|
|
176
|
+
all_devices = get_test_devices()
|
|
177
|
+
cuda_devices = get_cuda_test_devices()
|
|
168
178
|
|
|
169
179
|
|
|
170
|
-
@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
|
|
171
180
|
class TestTileMathDx(unittest.TestCase):
|
|
172
181
|
pass
|
|
173
182
|
|
|
174
183
|
|
|
175
184
|
# check_output=False so we can enable libmathdx's logging without failing the tests
|
|
176
|
-
add_function_test(TestTileMathDx, "test_tile_math_matmul", test_tile_math_matmul, devices=devices, check_output=False)
|
|
177
185
|
add_function_test(
|
|
178
|
-
TestTileMathDx, "
|
|
186
|
+
TestTileMathDx, "test_tile_math_matmul", test_tile_math_matmul, devices=all_devices, check_output=False
|
|
187
|
+
)
|
|
188
|
+
add_function_test(
|
|
189
|
+
TestTileMathDx, "test_tile_math_cholesky", test_tile_math_cholesky, devices=all_devices, check_output=False
|
|
179
190
|
)
|
|
180
191
|
add_function_test(
|
|
181
192
|
TestTileMathDx,
|
|
182
193
|
"test_tile_math_fft_vec2f",
|
|
183
194
|
functools.partial(test_tile_math_fft, wp_dtype=wp.vec2f),
|
|
184
|
-
devices=
|
|
195
|
+
devices=cuda_devices,
|
|
185
196
|
check_output=False,
|
|
186
197
|
)
|
|
187
198
|
add_function_test(
|
|
188
199
|
TestTileMathDx,
|
|
189
200
|
"test_tile_math_fft_vec2d",
|
|
190
201
|
functools.partial(test_tile_math_fft, wp_dtype=wp.vec2d),
|
|
191
|
-
devices=
|
|
202
|
+
devices=cuda_devices,
|
|
192
203
|
check_output=False,
|
|
193
204
|
)
|
|
194
205
|
|
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2024 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
import os
|
|
9
17
|
|
|
@@ -14,11 +22,6 @@ import warp.examples
|
|
|
14
22
|
import warp.optim
|
|
15
23
|
from warp.tests.unittest_utils import *
|
|
16
24
|
|
|
17
|
-
wp.init()
|
|
18
|
-
|
|
19
|
-
# needs to be constant for the whole module
|
|
20
|
-
NUM_THREADS = 32
|
|
21
|
-
|
|
22
25
|
|
|
23
26
|
def create_layer(rng, dim_in, dim_hid, dtype=float):
|
|
24
27
|
w = rng.uniform(-1.0 / np.sqrt(dim_in), 1.0 / np.sqrt(dim_in), (dim_hid, dim_in))
|
|
@@ -37,10 +40,12 @@ def create_array(rng, dim_in, dim_hid, dtype=float):
|
|
|
37
40
|
return a
|
|
38
41
|
|
|
39
42
|
|
|
40
|
-
@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
|
|
41
43
|
def test_multi_layer_nn(test, device):
|
|
42
44
|
import torch as tc
|
|
43
45
|
|
|
46
|
+
if device.is_cuda and not wp.context.runtime.core.is_mathdx_enabled():
|
|
47
|
+
test.skipTest("Skipping test on CUDA device without MathDx (tolerance)")
|
|
48
|
+
|
|
44
49
|
NUM_FREQ = wp.constant(8)
|
|
45
50
|
|
|
46
51
|
DIM_IN = wp.constant(4 * NUM_FREQ) # sin,cos for both x,y at each frequency
|
|
@@ -52,7 +57,13 @@ def test_multi_layer_nn(test, device):
|
|
|
52
57
|
|
|
53
58
|
BATCH_SIZE = min(512, int((IMG_WIDTH * IMG_HEIGHT) / 8))
|
|
54
59
|
|
|
60
|
+
if device.is_cpu:
|
|
61
|
+
NUM_THREADS = 1
|
|
62
|
+
else:
|
|
63
|
+
NUM_THREADS = 32
|
|
64
|
+
|
|
55
65
|
dtype = wp.float16
|
|
66
|
+
npdtype = wp.types.warp_type_to_np_dtype[dtype]
|
|
56
67
|
|
|
57
68
|
@wp.func
|
|
58
69
|
def relu(x: dtype):
|
|
@@ -66,7 +77,7 @@ def test_multi_layer_nn(test, device):
|
|
|
66
77
|
def zero(loss: wp.array(dtype=float)):
|
|
67
78
|
loss[0] = 0.0
|
|
68
79
|
|
|
69
|
-
@wp.kernel
|
|
80
|
+
@wp.kernel(module="unique")
|
|
70
81
|
def compute(
|
|
71
82
|
batches: wp.array(dtype=int),
|
|
72
83
|
input: wp.array2d(dtype=dtype),
|
|
@@ -162,7 +173,9 @@ def test_multi_layer_nn(test, device):
|
|
|
162
173
|
input = create_array(rng, IMG_WIDTH * IMG_HEIGHT, DIM_IN, dtype=dtype)
|
|
163
174
|
output = create_array(rng, IMG_WIDTH * IMG_HEIGHT, DIM_OUT)
|
|
164
175
|
|
|
165
|
-
reference_np =
|
|
176
|
+
reference_np = (
|
|
177
|
+
np.load(os.path.join(os.path.dirname(__file__), "..", "assets", "pixel.npy"), allow_pickle=True) / 255.0
|
|
178
|
+
)
|
|
166
179
|
reference = wp.array(reference_np, dtype=float)
|
|
167
180
|
|
|
168
181
|
assert reference.shape[1] == IMG_WIDTH * IMG_HEIGHT
|
|
@@ -224,7 +237,7 @@ def test_multi_layer_nn(test, device):
|
|
|
224
237
|
z_np = np.maximum(weights_3.numpy() @ z_np + bias_3.numpy(), 0.0)
|
|
225
238
|
|
|
226
239
|
# test numpy forward
|
|
227
|
-
assert_np_equal(output.numpy()[:, indices], z_np, tol=1.0e-2)
|
|
240
|
+
assert_np_equal(output.numpy()[:, indices].astype(npdtype), z_np, tol=1.0e-2)
|
|
228
241
|
|
|
229
242
|
# torch
|
|
230
243
|
input_tc = tc.tensor(input.numpy()[:, indices], requires_grad=True, device=torch_device)
|
|
@@ -252,7 +265,9 @@ def test_multi_layer_nn(test, device):
|
|
|
252
265
|
l_tc.backward()
|
|
253
266
|
|
|
254
267
|
# test torch
|
|
255
|
-
assert_np_equal(
|
|
268
|
+
assert_np_equal(
|
|
269
|
+
z_tc.cpu().detach().numpy(), output.numpy()[:, indices].astype(npdtype), tol=1.0e-2
|
|
270
|
+
)
|
|
256
271
|
assert_np_equal(weights_0.grad.numpy(), weights_0_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
|
|
257
272
|
assert_np_equal(bias_0.grad.numpy(), bias_0_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
|
|
258
273
|
assert_np_equal(weights_1.grad.numpy(), weights_1_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
|
|
@@ -269,7 +284,6 @@ def test_multi_layer_nn(test, device):
|
|
|
269
284
|
test.assertLess(loss.numpy()[0], 0.002)
|
|
270
285
|
|
|
271
286
|
|
|
272
|
-
@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
|
|
273
287
|
def test_single_layer_nn(test, device):
|
|
274
288
|
import torch as tc
|
|
275
289
|
|
|
@@ -279,11 +293,16 @@ def test_single_layer_nn(test, device):
|
|
|
279
293
|
|
|
280
294
|
NUM_BLOCKS = 56
|
|
281
295
|
|
|
296
|
+
if device.is_cpu:
|
|
297
|
+
NUM_THREADS = 1
|
|
298
|
+
else:
|
|
299
|
+
NUM_THREADS = 32
|
|
300
|
+
|
|
282
301
|
@wp.func
|
|
283
302
|
def relu(x: float):
|
|
284
303
|
return wp.max(x, 0.0)
|
|
285
304
|
|
|
286
|
-
@wp.kernel
|
|
305
|
+
@wp.kernel(module="unique")
|
|
287
306
|
def compute(
|
|
288
307
|
input: wp.array2d(dtype=float),
|
|
289
308
|
weights: wp.array2d(dtype=float),
|
|
@@ -345,7 +364,6 @@ try:
|
|
|
345
364
|
import torch
|
|
346
365
|
|
|
347
366
|
# check which Warp devices work with Torch
|
|
348
|
-
# CUDA devices may fail if Torch was not compiled with CUDA support
|
|
349
367
|
torch_compatible_devices = []
|
|
350
368
|
torch_compatible_cuda_devices = []
|
|
351
369
|
|
|
@@ -364,7 +382,7 @@ try:
|
|
|
364
382
|
"test_single_layer_nn",
|
|
365
383
|
test_single_layer_nn,
|
|
366
384
|
check_output=False,
|
|
367
|
-
devices=
|
|
385
|
+
devices=torch_compatible_devices,
|
|
368
386
|
)
|
|
369
387
|
add_function_test(
|
|
370
388
|
TestTileMLP,
|
|
@@ -380,4 +398,5 @@ except Exception as e:
|
|
|
380
398
|
|
|
381
399
|
if __name__ == "__main__":
|
|
382
400
|
wp.clear_kernel_cache()
|
|
401
|
+
wp.clear_lto_cache()
|
|
383
402
|
unittest.main(verbosity=2, failfast=True)
|
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
# Copyright (c) 2024 NVIDIA CORPORATION.
|
|
2
|
-
#
|
|
3
|
-
#
|
|
4
|
-
#
|
|
5
|
-
#
|
|
6
|
-
#
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
7
15
|
|
|
8
16
|
import unittest
|
|
9
17
|
|
|
@@ -168,6 +176,64 @@ def test_tile_reduce_custom(test, device):
|
|
|
168
176
|
test.assertAlmostEqual(prod_wp[i], prod_np, places=4)
|
|
169
177
|
|
|
170
178
|
|
|
179
|
+
@wp.struct
|
|
180
|
+
class KeyValue:
|
|
181
|
+
key: wp.int32
|
|
182
|
+
value: wp.float32
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@wp.func
|
|
186
|
+
def kv_max(a: KeyValue, b: KeyValue) -> KeyValue:
|
|
187
|
+
return wp.where(a.value < b.value, b, a)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
@wp.kernel
|
|
191
|
+
def initialize_key_value(values: wp.array2d(dtype=wp.float32), keyvalues: wp.array2d(dtype=KeyValue)):
|
|
192
|
+
batch, idx = wp.tid()
|
|
193
|
+
keyvalues[batch, idx] = KeyValue(idx, values[batch, idx])
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
@wp.kernel(enable_backward=False)
|
|
197
|
+
def tile_reduce_custom_struct_kernel(values: wp.array2d(dtype=KeyValue), res: wp.array(dtype=KeyValue)):
|
|
198
|
+
# output tile index
|
|
199
|
+
i = wp.tid()
|
|
200
|
+
|
|
201
|
+
t = wp.tile_load(values, shape=(1, TILE_DIM), offset=(i, 0))
|
|
202
|
+
|
|
203
|
+
max_el = wp.tile_reduce(kv_max, t)
|
|
204
|
+
wp.tile_store(res, max_el, offset=i)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def test_tile_reduce_custom_struct(test, device):
|
|
208
|
+
batch_count = 56
|
|
209
|
+
|
|
210
|
+
N = TILE_DIM
|
|
211
|
+
|
|
212
|
+
rng = np.random.default_rng(42)
|
|
213
|
+
input = rng.random((batch_count, N), dtype=np.float32)
|
|
214
|
+
|
|
215
|
+
input_wp = wp.array(input, dtype=wp.float32, device=device)
|
|
216
|
+
keyvalues_wp = wp.empty(input_wp.shape, dtype=KeyValue, device=device)
|
|
217
|
+
|
|
218
|
+
wp.launch(initialize_key_value, dim=[batch_count, N], inputs=[input_wp], outputs=[keyvalues_wp], device=device)
|
|
219
|
+
|
|
220
|
+
output_wp = wp.empty(batch_count, dtype=KeyValue, device=device)
|
|
221
|
+
|
|
222
|
+
wp.launch_tiled(
|
|
223
|
+
tile_reduce_custom_struct_kernel,
|
|
224
|
+
dim=[batch_count],
|
|
225
|
+
inputs=[keyvalues_wp],
|
|
226
|
+
outputs=[output_wp],
|
|
227
|
+
block_dim=TILE_DIM,
|
|
228
|
+
device=device,
|
|
229
|
+
)
|
|
230
|
+
|
|
231
|
+
prod_wp = np.array([k for k, v in output_wp.numpy()])
|
|
232
|
+
expected = np.argmax(input, axis=1)
|
|
233
|
+
|
|
234
|
+
assert_np_equal(prod_wp, expected)
|
|
235
|
+
|
|
236
|
+
|
|
171
237
|
@wp.kernel
|
|
172
238
|
def tile_grouped_sum_kernel(input: wp.array3d(dtype=float), output: wp.array(dtype=float)):
|
|
173
239
|
# output tile index
|
|
@@ -357,7 +423,7 @@ def test_tile_arange(test, device):
|
|
|
357
423
|
assert_np_equal(output.numpy()[4], np.arange(17, 0, -1))
|
|
358
424
|
|
|
359
425
|
|
|
360
|
-
devices =
|
|
426
|
+
devices = get_test_devices()
|
|
361
427
|
|
|
362
428
|
|
|
363
429
|
class TestTileReduce(unittest.TestCase):
|
|
@@ -368,6 +434,7 @@ add_function_test(TestTileReduce, "test_tile_reduce_sum", test_tile_reduce_sum,
|
|
|
368
434
|
add_function_test(TestTileReduce, "test_tile_reduce_min", test_tile_reduce_min, devices=devices)
|
|
369
435
|
add_function_test(TestTileReduce, "test_tile_reduce_max", test_tile_reduce_max, devices=devices)
|
|
370
436
|
add_function_test(TestTileReduce, "test_tile_reduce_custom", test_tile_reduce_custom, devices=devices)
|
|
437
|
+
add_function_test(TestTileReduce, "test_tile_reduce_custom_struct", test_tile_reduce_custom_struct, devices=devices)
|
|
371
438
|
add_function_test(TestTileReduce, "test_tile_reduce_grouped_sum", test_tile_reduce_sum, devices=devices)
|
|
372
439
|
add_function_test(TestTileReduce, "test_tile_reduce_simt", test_tile_reduce_simt, devices=devices)
|
|
373
440
|
add_function_test(TestTileReduce, "test_tile_ones", test_tile_ones, devices=devices)
|