warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +334 -0
- warp/__init__.pyi +5856 -0
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1077 -0
- warp/_src/build.py +620 -0
- warp/_src/build_dll.py +642 -0
- warp/_src/builtins.py +10555 -0
- warp/_src/codegen.py +4361 -0
- warp/_src/config.py +178 -0
- warp/_src/constants.py +59 -0
- warp/_src/context.py +8352 -0
- warp/_src/dlpack.py +464 -0
- warp/_src/fabric.py +362 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +510 -0
- warp/_src/fem/cache.py +689 -0
- warp/_src/fem/dirichlet.py +190 -0
- warp/_src/fem/domain.py +553 -0
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +703 -0
- warp/_src/fem/field/nodal_field.py +403 -0
- warp/_src/fem/field/restriction.py +39 -0
- warp/_src/fem/field/virtual.py +1021 -0
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
- warp/_src/fem/geometry/closest_point.py +99 -0
- warp/_src/fem/geometry/deformed_geometry.py +277 -0
- warp/_src/fem/geometry/element.py +854 -0
- warp/_src/fem/geometry/geometry.py +693 -0
- warp/_src/fem/geometry/grid_2d.py +478 -0
- warp/_src/fem/geometry/grid_3d.py +539 -0
- warp/_src/fem/geometry/hexmesh.py +956 -0
- warp/_src/fem/geometry/nanogrid.py +660 -0
- warp/_src/fem/geometry/partition.py +483 -0
- warp/_src/fem/geometry/quadmesh.py +597 -0
- warp/_src/fem/geometry/tetmesh.py +762 -0
- warp/_src/fem/geometry/trimesh.py +588 -0
- warp/_src/fem/integrate.py +2507 -0
- warp/_src/fem/linalg.py +385 -0
- warp/_src/fem/operator.py +398 -0
- warp/_src/fem/polynomial.py +231 -0
- warp/_src/fem/quadrature/__init__.py +17 -0
- warp/_src/fem/quadrature/pic_quadrature.py +318 -0
- warp/_src/fem/quadrature/quadrature.py +665 -0
- warp/_src/fem/space/__init__.py +248 -0
- warp/_src/fem/space/basis_function_space.py +499 -0
- warp/_src/fem/space/basis_space.py +681 -0
- warp/_src/fem/space/dof_mapper.py +253 -0
- warp/_src/fem/space/function_space.py +312 -0
- warp/_src/fem/space/grid_2d_function_space.py +179 -0
- warp/_src/fem/space/grid_3d_function_space.py +229 -0
- warp/_src/fem/space/hexmesh_function_space.py +255 -0
- warp/_src/fem/space/nanogrid_function_space.py +199 -0
- warp/_src/fem/space/partition.py +435 -0
- warp/_src/fem/space/quadmesh_function_space.py +222 -0
- warp/_src/fem/space/restriction.py +221 -0
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
- warp/_src/fem/space/shape/shape_function.py +134 -0
- warp/_src/fem/space/shape/square_shape_function.py +928 -0
- warp/_src/fem/space/shape/tet_shape_function.py +829 -0
- warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
- warp/_src/fem/space/tetmesh_function_space.py +270 -0
- warp/_src/fem/space/topology.py +461 -0
- warp/_src/fem/space/trimesh_function_space.py +193 -0
- warp/_src/fem/types.py +114 -0
- warp/_src/fem/utils.py +488 -0
- warp/_src/jax.py +188 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +389 -0
- warp/_src/jax_experimental/ffi.py +1286 -0
- warp/_src/jax_experimental/xla_ffi.py +658 -0
- warp/_src/marching_cubes.py +710 -0
- warp/_src/math.py +416 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +165 -0
- warp/_src/optim/linear.py +1608 -0
- warp/_src/optim/sgd.py +114 -0
- warp/_src/paddle.py +408 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +291 -0
- warp/_src/render/render_opengl.py +3638 -0
- warp/_src/render/render_usd.py +939 -0
- warp/_src/render/utils.py +162 -0
- warp/_src/sparse.py +2718 -0
- warp/_src/tape.py +1208 -0
- warp/_src/thirdparty/__init__.py +0 -0
- warp/_src/thirdparty/appdirs.py +598 -0
- warp/_src/thirdparty/dlpack.py +145 -0
- warp/_src/thirdparty/unittest_parallel.py +676 -0
- warp/_src/torch.py +393 -0
- warp/_src/types.py +5888 -0
- warp/_src/utils.py +1695 -0
- warp/autograd.py +33 -0
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +29 -0
- warp/build_dll.py +24 -0
- warp/codegen.py +24 -0
- warp/constants.py +24 -0
- warp/context.py +33 -0
- warp/dlpack.py +24 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +195 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +290 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/distributed/example_jacobi_mpi.py +506 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +469 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +181 -0
- warp/examples/fem/example_convection_diffusion_dg.py +225 -0
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +225 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +242 -0
- warp/examples/fem/example_mixed_elasticity.py +293 -0
- warp/examples/fem/example_navier_stokes.py +263 -0
- warp/examples/fem/example_nonconforming_contact.py +300 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +357 -0
- warp/examples/fem/utils.py +1047 -0
- warp/examples/interop/example_jax_callable.py +146 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +232 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +88 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/examples/tile/example_tile_mlp.py +385 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/fabric.py +24 -0
- warp/fem/__init__.py +173 -0
- warp/fem/adaptivity.py +26 -0
- warp/fem/cache.py +30 -0
- warp/fem/dirichlet.py +24 -0
- warp/fem/field/__init__.py +24 -0
- warp/fem/field/field.py +26 -0
- warp/fem/geometry/__init__.py +21 -0
- warp/fem/geometry/closest_point.py +31 -0
- warp/fem/linalg.py +38 -0
- warp/fem/operator.py +32 -0
- warp/fem/polynomial.py +29 -0
- warp/fem/space/__init__.py +22 -0
- warp/fem/space/basis_space.py +24 -0
- warp/fem/space/shape/__init__.py +68 -0
- warp/fem/space/topology.py +24 -0
- warp/fem/types.py +24 -0
- warp/fem/utils.py +32 -0
- warp/jax.py +29 -0
- warp/jax_experimental/__init__.py +29 -0
- warp/jax_experimental/custom_call.py +29 -0
- warp/jax_experimental/ffi.py +39 -0
- warp/jax_experimental/xla_ffi.py +24 -0
- warp/marching_cubes.py +24 -0
- warp/math.py +37 -0
- warp/native/array.h +1687 -0
- warp/native/builtin.h +2327 -0
- warp/native/bvh.cpp +562 -0
- warp/native/bvh.cu +826 -0
- warp/native/bvh.h +555 -0
- warp/native/clang/clang.cpp +541 -0
- warp/native/coloring.cpp +622 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +568 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +677 -0
- warp/native/cuda_util.h +313 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +2023 -0
- warp/native/fabric.h +246 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +89 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1253 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +348 -0
- warp/native/mat.h +5189 -0
- warp/native/mathdx.cpp +93 -0
- warp/native/matnn.h +221 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +406 -0
- warp/native/mesh.h +2097 -0
- warp/native/nanovdb/GridHandle.h +533 -0
- warp/native/nanovdb/HostBuffer.h +591 -0
- warp/native/nanovdb/NanoVDB.h +6246 -0
- warp/native/nanovdb/NodeManager.h +323 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1664 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +145 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +363 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +55 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +286 -0
- warp/native/sort.h +35 -0
- warp/native/sparse.cpp +241 -0
- warp/native/sparse.cu +435 -0
- warp/native/spatial.h +1306 -0
- warp/native/svd.h +727 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +4124 -0
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +838 -0
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +2199 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +68 -0
- warp/native/volume.h +970 -0
- warp/native/volume_builder.cu +483 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1143 -0
- warp/native/warp.cu +4604 -0
- warp/native/warp.h +358 -0
- warp/optim/__init__.py +20 -0
- warp/optim/adam.py +24 -0
- warp/optim/linear.py +35 -0
- warp/optim/sgd.py +24 -0
- warp/paddle.py +24 -0
- warp/py.typed +0 -0
- warp/render/__init__.py +22 -0
- warp/render/imgui_manager.py +29 -0
- warp/render/render_opengl.py +24 -0
- warp/render/render_usd.py +24 -0
- warp/render/utils.py +24 -0
- warp/sparse.py +51 -0
- warp/tape.py +24 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_aot.py +7 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_conditional_captures.py +1147 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +691 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +335 -0
- warp/tests/geometry/test_hash_grid.py +259 -0
- warp/tests/geometry/test_marching_cubes.py +294 -0
- warp/tests/geometry/test_mesh.py +318 -0
- warp/tests/geometry/test_mesh_query_aabb.py +392 -0
- warp/tests/geometry/test_mesh_query_point.py +935 -0
- warp/tests/geometry/test_mesh_query_ray.py +323 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +730 -0
- warp/tests/interop/test_jax.py +1673 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/test_adam.py +162 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +3756 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +303 -0
- warp/tests/test_atomic.py +336 -0
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +732 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +974 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +298 -0
- warp/tests/test_context.py +35 -0
- warp/tests/test_copy.py +319 -0
- warp/tests/test_ctypes.py +618 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +127 -0
- warp/tests/test_enum.py +136 -0
- warp/tests/test_examples.py +424 -0
- warp/tests/test_fabricarray.py +998 -0
- warp/tests/test_fast_math.py +72 -0
- warp/tests/test_fem.py +2204 -0
- warp/tests/test_fixedarray.py +229 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +501 -0
- warp/tests/test_future_annotations.py +100 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +103 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +223 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_map.py +526 -0
- warp/tests/test_mat.py +3515 -0
- warp/tests/test_mat_assign_copy.py +178 -0
- warp/tests/test_mat_constructors.py +573 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +212 -0
- warp/tests/test_module_aot.py +287 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +70 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +408 -0
- warp/tests/test_quat.py +2653 -0
- warp/tests/test_quat_assign_copy.py +145 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +303 -0
- warp/tests/test_rounding.py +157 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +133 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +845 -0
- warp/tests/test_spatial.py +2859 -0
- warp/tests/test_spatial_assign_copy.py +160 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +640 -0
- warp/tests/test_struct.py +901 -0
- warp/tests/test_tape.py +242 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +192 -0
- warp/tests/test_tuple.py +361 -0
- warp/tests/test_types.py +615 -0
- warp/tests/test_utils.py +594 -0
- warp/tests/test_vec.py +1408 -0
- warp/tests/test_vec_assign_copy.py +143 -0
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/test_version.py +75 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +1519 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +608 -0
- warp/tests/tile/test_tile_load.py +724 -0
- warp/tests/tile/test_tile_mathdx.py +156 -0
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_mlp.py +400 -0
- warp/tests/tile/test_tile_reduce.py +950 -0
- warp/tests/tile/test_tile_shared_memory.py +376 -0
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +430 -0
- warp/tests/unittest_utils.py +469 -0
- warp/tests/walkthrough_debug.py +95 -0
- warp/torch.py +24 -0
- warp/types.py +51 -0
- warp/utils.py +31 -0
- warp_lang-1.10.0.dist-info/METADATA +459 -0
- warp_lang-1.10.0.dist-info/RECORD +468 -0
- warp_lang-1.10.0.dist-info/WHEEL +5 -0
- warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp_lang-1.10.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1673 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import unittest
|
|
18
|
+
from functools import partial
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
|
|
23
|
+
import warp as wp
|
|
24
|
+
from warp._src.jax import get_jax_device
|
|
25
|
+
from warp.tests.unittest_utils import *
|
|
26
|
+
|
|
27
|
+
# default array size for tests
|
|
28
|
+
ARRAY_SIZE = 1024**2  # 1,048,576 elements
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# basic kernel with one input and output
|
|
32
|
+
@wp.kernel
def triple_kernel(input: wp.array(dtype=float), output: wp.array(dtype=float)):
    # each thread triples one element of `input` and stores it at the same index of `output`
    i = wp.tid()
    value = input[i]
    output[i] = 3.0 * value
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# generic kernel with one scalar input and output
|
|
39
|
+
@wp.kernel
def triple_kernel_scalar(input: wp.array(dtype=Any), output: wp.array(dtype=Any)):
    # the factor is built with the array's own dtype so the product keeps that type
    i = wp.tid()
    three = input.dtype(3)
    output[i] = three * input[i]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# generic kernel with one vector/matrix input and output
|
|
46
|
+
@wp.kernel
def triple_kernel_vecmat(input: wp.array(dtype=Any), output: wp.array(dtype=Any)):
    # the factor is built from `input.dtype.dtype`, i.e. the dtype attribute of the array's element type
    i = wp.tid()
    three = input.dtype.dtype(3)
    output[i] = three * input[i]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@wp.kernel
def inc_1d_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    # y[i] = x[i] + 1 for the element handled by this thread
    i = wp.tid()
    value = x[i]
    y[i] = value + 1.0
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@wp.kernel
def inc_2d_kernel(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
    # y[row, col] = x[row, col] + 1 for the element handled by this thread
    row, col = wp.tid()
    value = x[row, col]
    y[row, col] = value + 1.0
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# kernel with multiple inputs and outputs
|
|
65
|
+
@wp.kernel
def multiarg_kernel(
    # inputs
    a: wp.array(dtype=float),
    b: wp.array(dtype=float),
    c: wp.array(dtype=float),
    # outputs
    ab: wp.array(dtype=float),
    bc: wp.array(dtype=float),
):
    # per element: ab = a + b and bc = b + c
    i = wp.tid()
    shared = b[i]
    ab[i] = a[i] + shared
    bc[i] = shared + c[i]
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# various types for testing
|
|
81
|
+
scalar_types = wp._src.types.scalar_types

# for each dimension 2..4 and each scalar type, collect one vector type
# and one square matrix type
vector_types = []
matrix_types = []
for dim in (2, 3, 4):
    vector_types.extend(wp.vec(dim, T) for T in scalar_types)
    matrix_types.extend(wp.mat((dim, dim), T) for T in scalar_types)

# explicitly overload generic kernels to avoid module reloading during tests
for T in scalar_types:
    wp.overload(triple_kernel_scalar, [wp.array(dtype=T), wp.array(dtype=T)])
for T in (*vector_types, *matrix_types):
    wp.overload(triple_kernel_vecmat, [wp.array(dtype=T), wp.array(dtype=T)])
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _jax_version():
|
|
97
|
+
try:
|
|
98
|
+
import jax
|
|
99
|
+
|
|
100
|
+
return jax.__version_info__
|
|
101
|
+
except ImportError:
|
|
102
|
+
return (0, 0, 0)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def test_dtype_from_jax(test, device):
    """Check that ``wp.dtype_from_jax`` maps JAX scalar types to the matching Warp types.

    Each JAX type is converted twice: once as given and once wrapped in ``jp.dtype``.
    ``device`` is accepted but not used by this test.
    """
    import jax.numpy as jp

    # (JAX type, expected Warp type), checked in this order
    type_pairs = (
        (jp.float16, wp.float16),
        (jp.float32, wp.float32),
        (jp.float64, wp.float64),
        (jp.int8, wp.int8),
        (jp.int16, wp.int16),
        (jp.int32, wp.int32),
        (jp.int64, wp.int64),
        (jp.uint8, wp.uint8),
        (jp.uint16, wp.uint16),
        (jp.uint32, wp.uint32),
        (jp.uint64, wp.uint64),
        (jp.bool_, wp.bool),
    )

    for jax_type, warp_type in type_pairs:
        test.assertEqual(wp.dtype_from_jax(jax_type), warp_type)
        test.assertEqual(wp.dtype_from_jax(jp.dtype(jax_type)), warp_type)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def test_dtype_to_jax(test, device):
    """Check that ``wp.dtype_to_jax`` maps Warp scalar types to the matching JAX types.

    ``device`` is accepted but not used by this test.
    """
    import jax.numpy as jp

    # (Warp type, expected JAX type), checked in this order
    type_pairs = (
        (wp.float16, jp.float16),
        (wp.float32, jp.float32),
        (wp.float64, jp.float64),
        (wp.int8, jp.int8),
        (wp.int16, jp.int16),
        (wp.int32, jp.int32),
        (wp.int64, jp.int64),
        (wp.uint8, jp.uint8),
        (wp.uint16, jp.uint16),
        (wp.uint32, jp.uint32),
        (wp.uint64, jp.uint64),
        (wp.bool, jp.bool_),
    )

    for warp_type, jax_type in type_pairs:
        test.assertEqual(wp.dtype_to_jax(warp_type), jax_type)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def test_device_conversion(test, device):
    """Check that converting ``device`` to a JAX device and back yields an equal Warp device."""
    round_tripped = wp.device_from_jax(wp.device_to_jax(device))
    test.assertEqual(round_tripped, device)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def test_jax_kernel_basic(test, device, use_ffi=False):
    """Run ``triple_kernel`` through ``jax_kernel`` under ``jax.jit`` and compare against NumPy.

    Args:
        test: Test case object (not used directly here).
        device: Warp device the jitted function is run on.
        use_ffi: Take ``jax_kernel`` from ``warp.jax_experimental.ffi`` when true,
            otherwise from ``warp.jax_experimental.custom_call``.
    """
    import jax.numpy as jp

    if not use_ffi:
        from warp.jax_experimental.custom_call import jax_kernel

        # quiet=True suppresses deprecation warnings
        jax_triple = jax_kernel(triple_kernel, quiet=True)
    else:
        from warp.jax_experimental.ffi import jax_kernel

        jax_triple = jax_kernel(triple_kernel)

    count = ARRAY_SIZE

    @jax.jit
    def run():
        return jax_triple(jp.arange(count, dtype=jp.float32))

    # run on the given device
    with jax.default_device(wp.device_to_jax(device)):
        tripled = run()

    wp.synchronize_device(device)

    expected = 3 * np.arange(count, dtype=np.float32)
    assert_np_equal(np.asarray(tripled).reshape((count,)), expected)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def test_jax_kernel_scalar(test, device, use_ffi=False):
    """Run the generic scalar tripling kernel through ``jax_kernel`` for every type in ``scalar_types``.

    Each scalar type gets its own sub-test; the JAX result is compared against
    ``3 * np.arange(...)`` computed with the corresponding NumPy dtype.

    Args:
        test: Test case object, used for ``subTest``.
        device: Warp device the jitted function is run on.
        use_ffi: Take ``jax_kernel`` from ``warp.jax_experimental.ffi`` when true,
            otherwise from ``warp.jax_experimental.custom_call``.
    """
    import jax.numpy as jp

    if use_ffi:
        from warp.jax_experimental.ffi import jax_kernel
    else:
        from warp.jax_experimental.custom_call import jax_kernel

    # only the custom_call wrapper is passed quiet=True
    kwargs = {} if use_ffi else {"quiet": True}

    # deliberately small so that tripling arange yields small values
    # NOTE(review): 3 * 63 = 189 still exceeds the signed 8-bit range; for such types the
    # comparison presumably relies on the kernel and NumPy wrapping identically - confirm
    count = 64

    for T in scalar_types:
        jp_dtype = wp.dtype_to_jax(T)
        np_dtype = wp.dtype_to_numpy(T)

        with test.subTest(msg=T.__name__):
            # get the concrete overload
            overload = triple_kernel_scalar.add_overload([wp.array(dtype=T), wp.array(dtype=T)])
            jax_triple = jax_kernel(overload, **kwargs)

            # loop-dependent values are bound as defaults so each jitted function keeps its own
            @jax.jit
            def run(jax_triple=jax_triple, jp_dtype=jp_dtype):
                return jax_triple(jp.arange(count, dtype=jp_dtype))

            # run on the given device
            with jax.default_device(wp.device_to_jax(device)):
                tripled = run()

            wp.synchronize_device(device)

            expected = 3 * np.arange(count, dtype=np_dtype)
            assert_np_equal(np.asarray(tripled).reshape((count,)), expected)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def test_jax_kernel_vecmat(test, device, use_ffi=False):
    """Run the generic vector/matrix tripling kernel through ``jax_kernel`` for every vector and matrix type.

    Each type in ``vector_types`` and ``matrix_types`` gets its own sub-test. The input is an
    ``arange`` of scalars reshaped to ``(n, *T._shape_)``, and the result is compared against
    the same array tripled in NumPy.

    Args:
        test: Test case object, used for ``subTest``.
        device: Warp device the jitted function is run on.
        use_ffi: Take ``jax_kernel`` from ``warp.jax_experimental.ffi`` when true,
            otherwise from ``warp.jax_experimental.custom_call``.
    """
    import jax.numpy as jp

    if use_ffi:
        from warp.jax_experimental.ffi import jax_kernel
    else:
        from warp.jax_experimental.custom_call import jax_kernel

    # only the custom_call wrapper is passed quiet=True
    kwargs = {} if use_ffi else {"quiet": True}

    for T in (*vector_types, *matrix_types):
        scalar_type = T._wp_scalar_type_
        jp_dtype = wp.dtype_to_jax(scalar_type)
        np_dtype = wp.dtype_to_numpy(scalar_type)

        # keep the total scalar count at or below 64 so the tripled values stay small
        count = 64 // T._length_
        scalar_shape = (count, *T._shape_)
        scalar_len = count * T._length_

        with test.subTest(msg=T.__name__):
            # get the concrete overload
            overload = triple_kernel_vecmat.add_overload([wp.array(dtype=T), wp.array(dtype=T)])
            jax_triple = jax_kernel(overload, **kwargs)

            # loop-dependent values are bound as defaults so each jitted function keeps its own
            @jax.jit
            def run(jax_triple=jax_triple, jp_dtype=jp_dtype, scalar_len=scalar_len, scalar_shape=scalar_shape):
                flat = jp.arange(scalar_len, dtype=jp_dtype)
                return jax_triple(flat.reshape(scalar_shape))

            # run on the given device
            with jax.default_device(wp.device_to_jax(device)):
                tripled = run()

            wp.synchronize_device(device)

            expected = 3 * np.arange(scalar_len, dtype=np_dtype).reshape(scalar_shape)
            assert_np_equal(np.asarray(tripled).reshape(scalar_shape), expected)
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def test_jax_kernel_multiarg(test, device, use_ffi=False):
    """Run ``multiarg_kernel`` (three inputs, two outputs) through ``jax_kernel`` and check both outputs.

    The inputs are constant arrays of 1, 2 and 3, so the expected outputs are
    constant arrays of 3 (a + b) and 5 (b + c).

    Args:
        test: Test case object (not used directly here).
        device: Warp device the jitted function is run on.
        use_ffi: Take ``jax_kernel`` from ``warp.jax_experimental.ffi`` (called with
            ``num_outputs=2``) when true, otherwise from ``warp.jax_experimental.custom_call``.
    """
    import jax.numpy as jp

    if not use_ffi:
        from warp.jax_experimental.custom_call import jax_kernel

        jax_multiarg = jax_kernel(multiarg_kernel, quiet=True)
    else:
        from warp.jax_experimental.ffi import jax_kernel

        jax_multiarg = jax_kernel(multiarg_kernel, num_outputs=2)

    count = ARRAY_SIZE

    @jax.jit
    def run():
        ones = jp.full(count, 1, dtype=jp.float32)
        twos = jp.full(count, 2, dtype=jp.float32)
        threes = jp.full(count, 3, dtype=jp.float32)
        return jax_multiarg(ones, twos, threes)

    # run on the given device
    with jax.default_device(wp.device_to_jax(device)):
        sum_ab, sum_bc = run()

    wp.synchronize_device(device)

    host_ab = np.asarray(sum_ab)
    host_bc = np.asarray(sum_bc)

    assert_np_equal(host_ab, np.full(count, 3, dtype=np.float32))
    assert_np_equal(host_bc, np.full(count, 5, dtype=np.float32))
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def test_jax_kernel_launch_dims(test, device, use_ffi=False):
    """Check ``jax_kernel`` with explicit ``launch_dims`` smaller than the input, in 1D and 2D.

    The kernels are launched over ``(n - 2,)`` and ``(n - 2, m - 2)`` while the inputs have
    shapes ``(n,)`` and ``(n, m)``; the results are reshaped to the launch dimensions and
    compared against the incremented input values.

    Args:
        test: Test case object (not used directly here).
        device: Warp device the jitted functions are run on.
        use_ffi: Take ``jax_kernel`` from ``warp.jax_experimental.ffi`` when true,
            otherwise from ``warp.jax_experimental.custom_call``.
    """
    import jax.numpy as jp

    if use_ffi:
        from warp.jax_experimental.ffi import jax_kernel
    else:
        from warp.jax_experimental.custom_call import jax_kernel

    # only the custom_call wrapper is passed quiet=True
    kwargs = {} if use_ffi else {"quiet": True}

    n = 64
    m = 32

    # launch dimensions are intentionally smaller than the input dimensions
    dims_1d = (n - 2,)
    dims_2d = (n - 2, m - 2)

    # Test with 1D launch dims
    jax_inc_1d = jax_kernel(inc_1d_kernel, launch_dims=dims_1d, **kwargs)

    @jax.jit
    def run_1d():
        return jax_inc_1d(jp.arange(n, dtype=jp.float32))

    # Test with 2D launch dims
    jax_inc_2d = jax_kernel(inc_2d_kernel, launch_dims=dims_2d, **kwargs)

    @jax.jit
    def run_2d():
        threes = jp.zeros((n, m), dtype=jp.float32) + 3.0
        return jax_inc_2d(threes)

    # run on the given device
    with jax.default_device(wp.device_to_jax(device)):
        out_1d = run_1d()
        out_2d = run_2d()

    wp.synchronize_device(device)

    result_1d = np.asarray(out_1d).reshape(dims_1d)
    result_2d = np.asarray(out_2d).reshape(dims_2d)

    expected_1d = np.arange(n - 2, dtype=np.float32) + 1.0
    expected_2d = np.full(dims_2d, 4.0, dtype=np.float32)

    assert_np_equal(result_1d, expected_1d)
    assert_np_equal(result_2d, expected_2d)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
# =========================================================================================================
|
|
357
|
+
# JAX FFI
|
|
358
|
+
# =========================================================================================================
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
# Element-wise sum: output[i] = a[i] + b[i], one thread per element.
@wp.kernel
def add_kernel(
    a: wp.array(dtype=float),
    b: wp.array(dtype=float),
    output: wp.array(dtype=float),
):
    i = wp.tid()
    output[i] = a[i] + b[i]
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
# Element-wise out[i] = alpha * x[i] + y[i], one thread per element.
@wp.kernel
def axpy_kernel(
    x: wp.array(dtype=float),
    y: wp.array(dtype=float),
    alpha: float,
    out: wp.array(dtype=float),
):
    i = wp.tid()
    out[i] = alpha * x[i] + y[i]
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
# Writes the sine and cosine of each input angle to two separate output arrays.
@wp.kernel
def sincos_kernel(
    angle: wp.array(dtype=float),
    sin_out: wp.array(dtype=float),
    cos_out: wp.array(dtype=float),
):
    i = wp.tid()
    sin_out[i] = wp.sin(angle[i])
    cos_out[i] = wp.cos(angle[i])
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
# Fills output[i] with a 3x3 diagonal matrix whose diagonal is
# (d, 2d, 3d) with d = i + 1. The kernel takes no input arrays.
@wp.kernel
def diagonal_kernel(output: wp.array(dtype=wp.mat33)):
    i = wp.tid()
    base = float(i + 1)
    output[i] = wp.mat33(
        base, 0.0, 0.0,
        0.0, base * 2.0, 0.0,
        0.0, 0.0, base * 3.0,
    )
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
# Multiplies every element of a float array by the scalar `s`.
@wp.kernel
def scale_kernel(
    a: wp.array(dtype=float),
    s: float,
    output: wp.array(dtype=float),
):
    i = wp.tid()
    output[i] = a[i] * s
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
# Multiplies every vec2 element of an array by the scalar `s`.
@wp.kernel
def scale_vec_kernel(
    a: wp.array(dtype=wp.vec2),
    s: float,
    output: wp.array(dtype=wp.vec2),
):
    i = wp.tid()
    output[i] = a[i] * s
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
# Accumulates `a` into `b` element-wise; `b` is both read and written.
@wp.kernel
def accum_kernel(
    a: wp.array(dtype=float),
    b: wp.array(dtype=float),
):
    i = wp.tid()
    b[i] += a[i]
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
# Naive matrix product c = a @ b, one thread per output entry.
# Launch dims should be (N, M); threads outside that range do nothing.
@wp.kernel
def matmul_kernel(
    a: wp.array2d(dtype=float),  # NxK
    b: wp.array2d(dtype=float),  # KxM
    c: wp.array2d(dtype=float),  # NxM
):
    row, col = wp.tid()

    N = a.shape[0]
    K = a.shape[1]
    M = b.shape[1]

    # bounds guard for threads launched past the output extents
    if row >= N or col >= M:
        return

    acc = wp.float32(0)
    for k in range(K):
        acc += a[row, k] * b[k, col]

    c[row, col] = acc
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
# Adds `a` into `b` in place and writes 2 * a into `c`.
@wp.kernel
def in_out_kernel(
    a: wp.array(dtype=float),  # read only
    b: wp.array(dtype=float),  # read and written
    c: wp.array(dtype=float),  # written only
):
    i = wp.tid()
    b[i] += a[i]
    c[i] = 2.0 * a[i]
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
# Two outputs per element: c = a + b and d = s * a.
@wp.kernel
def multi_out_kernel(
    a: wp.array(dtype=float),
    b: wp.array(dtype=float),
    s: float,
    c: wp.array(dtype=float),
    d: wp.array(dtype=float),
):
    i = wp.tid()
    c[i] = a[i] + b[i]
    d[i] = s * a[i]
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
# Two outputs per element: c = a * a and d = a * b * s.
@wp.kernel
def multi_out_kernel_v2(
    a: wp.array(dtype=float),
    b: wp.array(dtype=float),
    s: float,
    c: wp.array(dtype=float),
    d: wp.array(dtype=float),
):
    i = wp.tid()
    c[i] = a[i] * a[i]
    d[i] = a[i] * b[i] * s
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
# Two outputs per element: c = a ** 2.0 (power operator, not a product)
# and d = a * b * s.
@wp.kernel
def multi_out_kernel_v3(
    a: wp.array(dtype=float),
    b: wp.array(dtype=float),
    s: float,
    c: wp.array(dtype=float),
    d: wp.array(dtype=float),
):
    i = wp.tid()
    c[i] = a[i] ** 2.0
    d[i] = a[i] * b[i] * s
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
# Per element: c = (a * s + b) ** 2.0.
@wp.kernel
def scale_sum_square_kernel(
    a: wp.array(dtype=float),
    b: wp.array(dtype=float),
    s: float,
    c: wp.array(dtype=float),
):
    i = wp.tid()
    c[i] = (a[i] * s + b[i]) ** 2.0
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
# Plain Python function that launches two Warp kernels.
# Arguments carry Warp-style type annotations, in the same way kernels do.
# Inputs: `a` (floats), `b` (vec2), `s` (scalar). Outputs: `c` = a * s, `d` = b * s.
def scale_func(
    # inputs
    a: wp.array(dtype=float),
    b: wp.array(dtype=wp.vec2),
    s: float,
    # outputs
    c: wp.array(dtype=float),
    d: wp.array(dtype=wp.vec2),
):
    # scalar array
    wp.launch(
        scale_kernel,
        dim=a.shape,
        inputs=[a, s],
        outputs=[c],
    )
    # vec2 array
    wp.launch(
        scale_vec_kernel,
        dim=b.shape,
        inputs=[b, s],
        outputs=[d],
    )
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
def in_out_func(
    a: wp.array(dtype=float),  # read only
    b: wp.array(dtype=float),  # read and written
    c: wp.array(dtype=float),  # written only
):
    # Writes 2 * a into `c`, then accumulates `a` into `b`.
    count = a.size
    wp.launch(scale_kernel, dim=count, inputs=[a, 2.0], outputs=[c])
    # this launch modifies `b` in place
    wp.launch(accum_kernel, dim=count, inputs=[a, b])
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def double_func(
    # inputs
    a: wp.array(dtype=float),
    # outputs
    b: wp.array(dtype=float),
):
    # b = 2 * a, computed by launching `scale_kernel` with a fixed factor
    wp.launch(
        scale_kernel,
        dim=a.shape,
        inputs=[a, 2.0],
        outputs=[b],
    )
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_add(test, device):
    # FFI `jax_kernel` with two inputs and one output: arange(n) + ones(n).
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    jax_add = jax_kernel(add_kernel)

    def build_and_add():
        count = ARRAY_SIZE
        ramp = jp.arange(count, dtype=jp.float32)
        ones = jp.ones(count, dtype=jp.float32)
        return jax_add(ramp, ones)

    f = jax.jit(build_and_add)

    with jax.default_device(wp.device_to_jax(device)):
        (y,) = f()

    wp.synchronize_device(device)

    # arange(n) + 1 == arange(1, n + 1)
    assert_np_equal(np.asarray(y), np.arange(1, ARRAY_SIZE + 1, dtype=np.float32))
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_sincos(test, device):
    # FFI `jax_kernel` with one input and two outputs (sine and cosine).
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    jax_sincos = jax_kernel(sincos_kernel, num_outputs=2)

    n = ARRAY_SIZE

    @jax.jit
    def f():
        angles = jp.linspace(0, 2 * jp.pi, n, dtype=jp.float32)
        return jax_sincos(angles)

    with jax.default_device(wp.device_to_jax(device)):
        s, c = f()

    wp.synchronize_device(device)

    # NumPy reference over the same angle grid
    ref_angles = np.linspace(0, 2 * np.pi, n, dtype=np.float32)

    assert_np_equal(np.asarray(s), np.sin(ref_angles), tol=1e-4)
    assert_np_equal(np.asarray(c), np.cos(ref_angles), tol=1e-4)
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_diagonal(test, device):
    # FFI `jax_kernel` with no inputs and one output: with no input arrays,
    # the launch dimensions passed at call time determine the output size.
    from warp.jax_experimental.ffi import jax_kernel

    jax_diagonal = jax_kernel(diagonal_kernel)

    @jax.jit
    def f():
        # launch dimensions determine output size
        return jax_diagonal(launch_dims=4)

    with jax.default_device(wp.device_to_jax(device)):
        (d,) = f()

    # Synchronize after the jitted call, as the sibling tests do. The
    # original synchronized before the call, where no work had been launched.
    wp.synchronize_device(device)

    result = np.asarray(d)

    # element i holds diag(i + 1, 2 * (i + 1), 3 * (i + 1))
    expected = np.array(
        [
            [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]],
            [[2.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 6.0]],
            [[3.0, 0.0, 0.0], [0.0, 6.0, 0.0], [0.0, 0.0, 9.0]],
            [[4.0, 0.0, 0.0], [0.0, 8.0, 0.0], [0.0, 0.0, 12.0]],
        ],
        dtype=np.float32,
    )

    assert_np_equal(result, expected)
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_in_out(test, device):
    # FFI `jax_kernel` where argument `b` is declared as in-out.
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    f = jax.jit(jax_kernel(in_out_kernel, num_outputs=2, in_out_argnames=["b"]))

    with jax.default_device(wp.device_to_jax(device)):
        a = jp.ones(ARRAY_SIZE, dtype=jp.float32)
        b = jp.arange(ARRAY_SIZE, dtype=jp.float32)
        b, c = f(a, b)

    wp.synchronize_device(device)

    # b was incremented by a (all ones); c is 2 * a
    expected_b = np.arange(1, ARRAY_SIZE + 1, dtype=np.float32)
    expected_c = np.full(ARRAY_SIZE, 2, dtype=np.float32)

    assert_np_equal(b, expected_b)
    assert_np_equal(c, expected_c)
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_scale_vec_constant(test, device):
    # Scale an array of vec2 by a scalar that is a constant inside the jitted function.
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    jax_scale_vec = jax_kernel(scale_vec_kernel)

    vec_shape = (ARRAY_SIZE // 2, 2)

    @jax.jit
    def f():
        # array of vec2
        vecs = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape(vec_shape)
        factor = 2.0
        return jax_scale_vec(vecs, factor)

    with jax.default_device(wp.device_to_jax(device)):
        (b,) = f()

    wp.synchronize_device(device)

    expected = 2 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape(vec_shape)

    assert_np_equal(b, expected)
|
|
635
|
+
|
|
636
|
+
|
|
637
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_scale_vec_static(test, device):
    # Scale an array of vec2 by a scalar passed as a static jit argument.
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    jax_scale_vec = jax_kernel(scale_vec_kernel)

    # NOTE: scalar arguments must be static compile-time constants
    @partial(jax.jit, static_argnames=["s"])
    def f(a, s):
        return jax_scale_vec(a, s)

    vec_shape = (ARRAY_SIZE // 2, 2)

    # array of vec2
    a = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape(vec_shape)
    s = 3.0

    with jax.default_device(wp.device_to_jax(device)):
        (b,) = f(a, s)

    wp.synchronize_device(device)

    expected = 3 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape(vec_shape)

    assert_np_equal(b, expected)
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_launch_dims_default(test, device):
    # Default launch dims are given when the kernel is wrapped and are not
    # repeated at the call site.
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    N, M, K = 3, 4, 2

    jax_matmul = jax_kernel(matmul_kernel, launch_dims=(N, M))

    @jax.jit
    def f():
        lhs = jp.full((N, K), 2, dtype=jp.float32)
        rhs = jp.full((K, M), 3, dtype=jp.float32)

        # use default launch dims
        return jax_matmul(lhs, rhs)

    with jax.default_device(wp.device_to_jax(device)):
        (result,) = f()

    wp.synchronize_device(device)

    # every entry is the sum of K = 2 products 2 * 3
    expected = np.full((3, 4), 12, dtype=np.float32)

    test.assertEqual(result.shape, expected.shape)
    assert_np_equal(result, expected)
|
|
692
|
+
|
|
693
|
+
|
|
694
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_launch_dims_custom(test, device):
    # Launch dims are supplied per call, with two different shapes in the
    # same jitted function.
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    jax_matmul = jax_kernel(matmul_kernel)

    @jax.jit
    def f():
        # first product: (3x2) @ (2x4), custom launch dims
        N1, M1, K1 = 3, 4, 2
        a1 = jp.full((N1, K1), 2, dtype=jp.float32)
        b1 = jp.full((K1, M1), 3, dtype=jp.float32)
        out1 = jax_matmul(a1, b1, launch_dims=(N1, M1))

        # second product: (4x2) @ (2x3), different custom launch dims
        N2, M2, K2 = 4, 3, 2
        a2 = jp.full((N2, K2), 2, dtype=jp.float32)
        b2 = jp.full((K2, M2), 3, dtype=jp.float32)
        out2 = jax_matmul(a2, b2, launch_dims=(N2, M2))

        # each call returns a sequence of outputs; take the only entry
        return out1[0], out2[0]

    with jax.default_device(wp.device_to_jax(device)):
        result1, result2 = f()

    wp.synchronize_device(device)

    expected1 = np.full((3, 4), 12, dtype=np.float32)
    expected2 = np.full((4, 3), 12, dtype=np.float32)

    test.assertEqual(result1.shape, expected1.shape)
    test.assertEqual(result2.shape, expected2.shape)
    assert_np_equal(result1, expected1)
    assert_np_equal(result2, expected2)
|
|
733
|
+
|
|
734
|
+
|
|
735
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_callable_scale_constant(test, device):
    # `jax_callable` over `scale_func`: scale a float array and a vec2 array
    # by a scalar that is a constant inside the jitted function.
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    jax_func = jax_callable(scale_func, num_outputs=2)

    vec_shape = (ARRAY_SIZE // 2, 2)

    @jax.jit
    def f():
        # inputs
        a = jp.arange(ARRAY_SIZE, dtype=jp.float32)
        b = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape(vec_shape)  # wp.vec2
        s = 2.0

        # output shapes, keyed by the output argument names of `scale_func`
        c, d = jax_func(a, b, s, output_dims={"c": a.shape, "d": b.shape})

        return c, d

    with jax.default_device(wp.device_to_jax(device)):
        result1, result2 = f()

    wp.synchronize_device(device)

    base = np.arange(ARRAY_SIZE, dtype=np.float32)

    assert_np_equal(result1, 2 * base)
    assert_np_equal(result2, 2 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape(vec_shape))
|
|
768
|
+
|
|
769
|
+
|
|
770
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_callable_scale_static(test, device):
    # `jax_callable` over `scale_func`: scale a float array and a vec2 array
    # by a scalar passed as a static jit argument.
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    jax_func = jax_callable(scale_func, num_outputs=2)

    # NOTE: scalar arguments must be static compile-time constants
    @partial(jax.jit, static_argnames=["s"])
    def f(a, b, s):
        # output shapes, keyed by the output argument names of `scale_func`
        shapes = {"c": a.shape, "d": b.shape}
        c, d = jax_func(a, b, s, output_dims=shapes)
        return c, d

    vec_shape = (ARRAY_SIZE // 2, 2)

    with jax.default_device(wp.device_to_jax(device)):
        # inputs
        a = jp.arange(ARRAY_SIZE, dtype=jp.float32)
        b = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape(vec_shape)  # wp.vec2
        s = 3.0
        result1, result2 = f(a, b, s)

    wp.synchronize_device(device)

    expected1 = 3 * np.arange(ARRAY_SIZE, dtype=np.float32)
    expected2 = 3 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape(vec_shape)

    assert_np_equal(result1, expected1)
    assert_np_equal(result2, expected2)
|
|
803
|
+
|
|
804
|
+
|
|
805
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_callable_in_out(test, device):
    # `jax_callable` over `in_out_func`, with argument `b` declared as in-out.
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    f = jax.jit(jax_callable(in_out_func, num_outputs=2, in_out_argnames=["b"]))

    with jax.default_device(wp.device_to_jax(device)):
        a = jp.ones(ARRAY_SIZE, dtype=jp.float32)
        b = jp.arange(ARRAY_SIZE, dtype=jp.float32)
        b, c = f(a, b)

    wp.synchronize_device(device)

    # b was incremented by a (all ones); c is 2 * a
    expected_b = np.arange(1, ARRAY_SIZE + 1, dtype=np.float32)
    expected_c = np.full(ARRAY_SIZE, 2, dtype=np.float32)

    assert_np_equal(b, expected_b)
    assert_np_equal(c, expected_c)
|
|
825
|
+
|
|
826
|
+
|
|
827
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_callable_graph_cache(test, device):
    # Exercise the graph cache of `jax_callable` in `GraphMode.WARP`:
    # growth under the default limit, clearing one callable, a per-callable
    # limit, clearing all callables, and a custom default limit.
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import (
        GraphMode,
        clear_jax_callable_graph_cache,
        get_jax_callable_default_graph_cache_max,
        jax_callable,
        set_jax_callable_default_graph_cache_max,
    )

    # --- default cache settings ---

    jax_double = jax_callable(double_func, graph_mode=GraphMode.WARP)
    f = jax.jit(jax_double)
    keep_alive = []

    test.assertEqual(jax_double.graph_cache_max, get_jax_callable_default_graph_cache_max())

    with jax.default_device(wp.device_to_jax(device)):
        for step in range(10):
            length = 10 + step
            src = jp.arange(length, dtype=jp.float32)
            (doubled,) = f(src)

            assert_np_equal(doubled, 2 * np.arange(length, dtype=np.float32))

            # the cache is expected to gain one entry per iteration
            test.assertEqual(jax_double.graph_cache_size, step + 1)

            # hold on to the JAX array so its memory is not reused, which forces a new graph capture each time
            keep_alive.append(src)

    # --- clearing one callable's cache ---

    clear_jax_callable_graph_cache(jax_double)

    test.assertEqual(jax_double.graph_cache_size, 0)

    # --- custom per-callable cache limit ---

    graph_cache_max = 5
    jax_double = jax_callable(double_func, graph_mode=GraphMode.WARP, graph_cache_max=graph_cache_max)
    f = jax.jit(jax_double)
    keep_alive = []

    test.assertEqual(jax_double.graph_cache_max, graph_cache_max)

    with jax.default_device(wp.device_to_jax(device)):
        for step in range(10):
            length = 10 + step
            src = jp.arange(length, dtype=jp.float32)
            (doubled,) = f(src)

            assert_np_equal(doubled, 2 * np.arange(length, dtype=np.float32))

            # the cache size must never exceed the configured limit
            test.assertEqual(jax_double.graph_cache_size, min(step + 1, graph_cache_max))

            # hold on to the JAX array so its memory is not reused, which forces a new graph capture
            keep_alive.append(src)

    # --- clearing all callables' caches ---

    clear_jax_callable_graph_cache()

    with wp.jax_experimental.ffi._FFI_REGISTRY_LOCK:
        for registered in wp.jax_experimental.ffi._FFI_CALLABLE_REGISTRY.values():
            test.assertEqual(registered.graph_cache_size, 0)

    # --- custom default cache limit ---

    saved_max = get_jax_callable_default_graph_cache_max()
    try:
        set_jax_callable_default_graph_cache_max(5)
        jax_double = jax_callable(double_func, graph_mode=GraphMode.WARP)
        f = jax.jit(jax_double)
        keep_alive = []

        test.assertEqual(jax_double.graph_cache_max, get_jax_callable_default_graph_cache_max())

        with jax.default_device(wp.device_to_jax(device)):
            for step in range(10):
                length = 10 + step
                src = jp.arange(length, dtype=jp.float32)
                (doubled,) = f(src)

                assert_np_equal(doubled, 2 * np.arange(length, dtype=np.float32))

                # the cache size must never exceed the default limit
                test.assertEqual(
                    jax_double.graph_cache_size,
                    min(step + 1, get_jax_callable_default_graph_cache_max()),
                )

                # hold on to the JAX array so its memory is not reused, which forces a new graph capture
                keep_alive.append(src)

        clear_jax_callable_graph_cache()

    finally:
        # always restore the global default so other tests are unaffected
        set_jax_callable_default_graph_cache_max(saved_max)
|
|
932
|
+
|
|
933
|
+
|
|
934
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_callable_pmap_mul(test, device):
    # Run `double_func` through `jax_callable` under `jax.pmap`, with one
    # row of the input per local JAX device.
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    doubler = jax_callable(double_func, num_outputs=1)

    ndev = jax.local_device_count()
    # at least 64 elements per device
    per_device = max(ARRAY_SIZE // ndev, 64)
    x = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))

    def per_device_func(v):
        # the callable returns a one-element sequence
        (doubled,) = doubler(v)
        return doubled

    y = jax.pmap(per_device_func)(x)

    wp.synchronize()

    assert_np_equal(np.asarray(y), 2 * np.asarray(x))
|
|
956
|
+
|
|
957
|
+
|
|
958
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_callable_pmap_multi_output(test, device):
    # `jax_callable` with two outputs under `jax.pmap`; the outputs are
    # summed so that both are exercised.
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    def multi_out_py(
        a: wp.array(dtype=float),
        b: wp.array(dtype=float),
        s: float,
        c: wp.array(dtype=float),
        d: wp.array(dtype=float),
    ):
        wp.launch(
            multi_out_kernel,
            dim=a.shape,
            inputs=[a, b, s],
            outputs=[c, d],
        )

    wrapped = jax_callable(multi_out_py, num_outputs=2)

    ndev = jax.local_device_count()
    per_device = max(ARRAY_SIZE // ndev, 64)
    sharded_shape = (ndev, per_device)

    a = jp.arange(ndev * per_device, dtype=jp.float32).reshape(sharded_shape)
    b = jp.ones(sharded_shape, dtype=jp.float32)
    s = 3.0

    def per_device_func(aa, bb):
        c, d = wrapped(aa, bb, s)
        # simple combine to exercise both outputs
        return c + d

    out = jax.pmap(per_device_func)(a, b)

    wp.synchronize()

    # NumPy reference: c = a + b, d = s * a
    a_np = np.arange(ndev * per_device, dtype=np.float32).reshape(sharded_shape)
    b_np = np.ones(sharded_shape, dtype=np.float32)
    ref = (a_np + b_np) + s * a_np

    assert_np_equal(np.asarray(out), ref)
|
|
994
|
+
|
|
995
|
+
|
|
996
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_callable_pmap_multi_stage(test, device):
    # `jax_callable` that launches two kernels in sequence (add, then axpy)
    # under `jax.pmap`; the intermediate and final outputs are both returned.
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    def multi_stage_py(
        a: wp.array(dtype=float),
        b: wp.array(dtype=float),
        alpha: float,
        tmp: wp.array(dtype=float),
        out: wp.array(dtype=float),
    ):
        # stage 1: tmp = a + b
        wp.launch(add_kernel, dim=a.shape, inputs=[a, b], outputs=[tmp])
        # stage 2: out = alpha * tmp + b
        wp.launch(axpy_kernel, dim=a.shape, inputs=[tmp, b, alpha], outputs=[out])

    wrapped = jax_callable(multi_stage_py, num_outputs=2)

    ndev = jax.local_device_count()
    per_device = max(ARRAY_SIZE // ndev, 64)
    sharded_shape = (ndev, per_device)

    a = jp.arange(ndev * per_device, dtype=jp.float32).reshape(sharded_shape)
    b = jp.ones(sharded_shape, dtype=jp.float32)
    alpha = 2.5

    def per_device_func(aa, bb):
        tmp, out = wrapped(aa, bb, alpha)
        return tmp + out

    combined = jax.pmap(per_device_func)(a, b)

    wp.synchronize()

    # NumPy reference of both stages
    a_np = np.arange(ndev * per_device, dtype=np.float32).reshape(sharded_shape)
    b_np = np.ones(sharded_shape, dtype=np.float32)
    tmp_ref = a_np + b_np
    out_ref = alpha * (a_np + b_np) + b_np

    assert_np_equal(np.asarray(combined), tmp_ref + out_ref)
|
|
1035
|
+
|
|
1036
|
+
|
|
1037
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_callback(test, device):
    # Register a Python function as an FFI callback and invoke it through
    # `jax.ffi.ffi_call` with two input arrays and one scalar attribute.
    import jax.numpy as jp

    from warp.jax_experimental.ffi import register_ffi_callback

    # the Python function to call
    def warp_func(inputs, outputs, attrs, ctx):
        # input arrays
        a, b = inputs[0], inputs[1]

        # scalar attributes
        s = attrs["scale"]

        # output arrays
        c, d = outputs[0], outputs[1]

        # launch on the CUDA stream exposed by the callback context
        wp_device = wp.device_from_jax(get_jax_device())
        stream = wp.Stream(wp_device, cuda_stream=ctx.stream)

        with wp.ScopedStream(stream):
            # launch with arrays of scalars
            wp.launch(scale_kernel, dim=a.shape, inputs=[a, s], outputs=[c])

            # launch with arrays of vec2
            # NOTE: the input shapes are from JAX arrays, we need to strip the inner dimension for vec2 arrays
            wp.launch(scale_vec_kernel, dim=b.shape[0], inputs=[b, s], outputs=[d])

    # register callback
    register_ffi_callback("warp_func", warp_func)

    n = ARRAY_SIZE
    vec_shape = (n // 2, 2)

    with jax.default_device(wp.device_to_jax(device)):
        # inputs
        a = jp.arange(n, dtype=jp.float32)
        b = jp.arange(n, dtype=jp.float32).reshape(vec_shape)  # array of wp.vec2
        s = 2.0

        # set up call
        out_types = [
            jax.ShapeDtypeStruct(a.shape, jp.float32),
            jax.ShapeDtypeStruct(b.shape, jp.float32),  # array of wp.vec2
        ]
        call = jax.ffi.ffi_call("warp_func", out_types)

        # call it
        c, d = call(a, b, scale=s)

    wp.synchronize_device(device)

    assert_np_equal(c, 2 * np.arange(ARRAY_SIZE, dtype=np.float32))
    assert_np_equal(d, 2 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2)))
|
|
1093
|
+
|
|
1094
|
+
|
|
1095
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_autodiff_simple(test, device):
    # Differentiate through a `jax_kernel` created with `enable_backward=True`:
    # grad of sum((a * s + b) ** 2) with respect to `a` and `b`.
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    jax_func = jax_kernel(
        scale_sum_square_kernel,
        num_outputs=1,
        enable_backward=True,
    )

    from functools import partial

    @partial(jax.jit, static_argnames=["s"])
    def loss(a, b, s):
        (out,) = jax_func(a, b, s)[:1]
        return jp.sum(out)

    n = ARRAY_SIZE
    a = jp.arange(n, dtype=jp.float32)
    b = jp.ones(n, dtype=jp.float32)
    s = 2.0

    with jax.default_device(wp.device_to_jax(device)):
        da, db = jax.grad(loss, argnums=(0, 1))(a, b, s)

    wp.synchronize_device(device)

    # reference gradients, per element:
    #   d/da (a*s + b)^2 = 2*(a*s + b) * s
    #   d/db (a*s + b)^2 = 2*(a*s + b)
    a_np = np.arange(n, dtype=np.float32)
    b_np = np.ones(n, dtype=np.float32)
    inner = a_np * s + b_np
    ref_da = 2.0 * inner * s
    ref_db = 2.0 * inner

    assert_np_equal(np.asarray(da), ref_da)
    assert_np_equal(np.asarray(db), ref_db)
|
|
1135
|
+
|
|
1136
|
+
|
|
1137
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_autodiff_jit_of_grad_simple(test, device):
    # Same gradient check as the simple autodiff test, but composed as
    # jit(grad(...)) with the scalar marked static on the jit.
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    jax_func = jax_kernel(scale_sum_square_kernel, num_outputs=1, enable_backward=True)

    def loss(a, b, s):
        return jp.sum(jax_func(a, b, s)[0])

    grad_fn = jax.grad(loss, argnums=(0, 1))

    # more typical: jit(grad(...)) with static scalar
    def _grad_wrapper(a, b, s):
        return grad_fn(a, b, s)

    jitted_grad = jax.jit(_grad_wrapper, static_argnames=("s",))

    n = ARRAY_SIZE
    a = jp.arange(n, dtype=jp.float32)
    b = jp.ones(n, dtype=jp.float32)
    s = 2.0

    with jax.default_device(wp.device_to_jax(device)):
        da, db = jitted_grad(a, b, s)

    wp.synchronize_device(device)

    # reference gradients of sum((a*s + b)^2)
    a_np = np.arange(n, dtype=np.float32)
    b_np = np.ones(n, dtype=np.float32)
    inner = a_np * s + b_np
    ref_da = 2.0 * inner * s
    ref_db = 2.0 * inner

    assert_np_equal(np.asarray(da), ref_da)
    assert_np_equal(np.asarray(db), ref_db)
|
|
1172
|
+
|
|
1173
|
+
|
|
1174
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_autodiff_multi_output(test, device):
    # Checks gradients of a loss that sums both outputs of multi_out_kernel_v3,
    # with jax.grad evaluated inside a jitted function.
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    jax_func = jax_kernel(multi_out_kernel_v3, num_outputs=2, enable_backward=True)

    n = ARRAY_SIZE
    a = jp.arange(n, dtype=jp.float32)
    b = jp.ones(n, dtype=jp.float32)
    # single source of truth for the scalar: used both by the kernel call and by
    # the reference gradients below
    s = 2.0

    def caller(fn, a, b, s):
        c, d = fn(a, b, s)
        return jp.sum(c + d)

    @jax.jit
    def grads(a, b):
        # `s` is captured from the enclosing scope as a plain Python float, so it
        # reaches the kernel as a constant instead of a traced jit argument
        return jax.grad(lambda a, b: caller(jax_func, a, b, s), argnums=(0, 1))(a, b)

    with jax.default_device(wp.device_to_jax(device)):
        da, db = grads(a, b)

    wp.synchronize_device(device)

    a_np = np.arange(n, dtype=np.float32)
    b_np = np.ones(n, dtype=np.float32)
    # d/da sum(c+d) = 2*a + b*s
    ref_da = 2.0 * a_np + b_np * s
    # d/db sum(c+d) = a*s
    ref_db = a_np * s

    assert_np_equal(np.asarray(da), ref_da)
    assert_np_equal(np.asarray(db), ref_db)
|
|
1214
|
+
|
|
1215
|
+
|
|
1216
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_autodiff_jit_of_grad_multi_output(test, device):
    # Checks jax.jit(jax.grad(...)) on a loss that sums both outputs of
    # multi_out_kernel_v3; the scalar `s` is declared static for the jit.
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    kernel_fn = jax_kernel(multi_out_kernel_v3, num_outputs=2, enable_backward=True)

    def loss(a, b, s):
        first, second = kernel_fn(a, b, s)
        return jp.sum(first + second)

    grad_fn = jax.grad(loss, argnums=(0, 1))

    def grad_wrt_arrays(a, b, s):
        return grad_fn(a, b, s)

    jitted_grad = jax.jit(grad_wrt_arrays, static_argnames=("s",))

    # inputs
    s = 2.0
    n = ARRAY_SIZE
    a = jp.arange(n, dtype=jp.float32)
    b = jp.ones(n, dtype=jp.float32)

    with jax.default_device(wp.device_to_jax(device)):
        da, db = jitted_grad(a, b, s)

    wp.synchronize_device(device)

    # NumPy reference gradients: 2*a + b*s w.r.t. a, and a*s w.r.t. b
    a_np = np.arange(n, dtype=np.float32)
    b_np = np.ones(n, dtype=np.float32)
    expected = {
        "da": 2.0 * a_np + b_np * s,
        "db": a_np * s,
    }

    assert_np_equal(np.asarray(da), expected["da"])
    assert_np_equal(np.asarray(db), expected["db"])
|
|
1249
|
+
|
|
1250
|
+
|
|
1251
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_autodiff_2d(test, device):
    # Checks the gradient of sum(inc_2d_kernel(a)) for a 2D float32 input; the
    # expected gradient is an array of ones with the same shape as the input.
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    kernel_fn = jax_kernel(inc_2d_kernel, num_outputs=1, enable_backward=True)

    def summed_output(a):
        # index 0 selects the single declared kernel output
        return jp.sum(kernel_fn(a)[0])

    loss = jax.jit(summed_output)

    shape = (8, 6)
    n, m = shape
    a = jp.arange(n * m, dtype=jp.float32).reshape(shape)

    with jax.default_device(wp.device_to_jax(device)):
        grad_fn = jax.grad(loss, argnums=(0,))
        (da,) = grad_fn(a)

    wp.synchronize_device(device)

    assert_np_equal(np.asarray(da), np.ones(shape, dtype=np.float32))
|
|
1275
|
+
|
|
1276
|
+
|
|
1277
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_autodiff_vec2(test, device):
    # Checks the gradient of sum(scale_vec_kernel(a, s)) w.r.t. `a`, where `a` is
    # passed as an (n // 2, 2) float32 array and `s` is a static scalar.
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    kernel_fn = jax_kernel(scale_vec_kernel, num_outputs=1, enable_backward=True)

    def summed_output(a, s):
        # index 0 selects the single declared kernel output
        return jp.sum(kernel_fn(a, s)[0])

    loss = jax.jit(summed_output, static_argnames=("s",))

    n = ARRAY_SIZE
    shape = (n // 2, 2)
    a = jp.arange(n, dtype=jp.float32).reshape(shape)
    s = 3.0

    with jax.default_device(wp.device_to_jax(device)):
        grad_fn = jax.grad(loss, argnums=(0,))
        (da,) = grad_fn(a, s)

    wp.synchronize_device(device)

    # d/da sum(a*s) = s
    expected = np.full(shape, s, dtype=np.float32)
    assert_np_equal(np.asarray(da), expected)
|
|
1305
|
+
|
|
1306
|
+
|
|
1307
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_autodiff_mat22(test, device):
    # Checks the gradient of sum(a * s) for an array of 2x2 matrices scaled by a
    # static scalar; every entry of the expected gradient equals `s`.
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    @wp.kernel
    def scale_mat_kernel(a: wp.array(dtype=wp.mat22), s: float, out: wp.array(dtype=wp.mat22)):
        tid = wp.tid()
        out[tid] = a[tid] * s

    kernel_fn = jax_kernel(scale_mat_kernel, num_outputs=1, enable_backward=True)

    def summed_output(a, s):
        # index 0 selects the single declared kernel output
        return jp.sum(kernel_fn(a, s)[0])

    loss = jax.jit(summed_output, static_argnames=("s",))

    n = 12  # must be divisible by 4 for 2x2 matrices
    shape = (n // 4, 2, 2)
    a = jp.arange(n, dtype=jp.float32).reshape(shape)
    s = 2.5

    with jax.default_device(wp.device_to_jax(device)):
        grad_fn = jax.grad(loss, argnums=(0,))
        (da,) = grad_fn(a, s)

    wp.synchronize_device(device)

    expected = np.full(shape, s, dtype=np.float32)
    assert_np_equal(np.asarray(da), expected)
|
|
1339
|
+
|
|
1340
|
+
|
|
1341
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_autodiff_static_required(test, device):
    # Checks gradients of a loss built on scale_sum_square_kernel when jax.grad is
    # called directly (no jax.jit wrapper here) and the scalar `s` is passed as a
    # plain Python float.
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    # Require explicit static_argnames for scalar s
    kernel_fn = jax_kernel(scale_sum_square_kernel, num_outputs=1, enable_backward=True)

    def loss(a, b, s):
        # index 0 selects the single declared kernel output
        return jp.sum(kernel_fn(a, b, s)[0])

    # inputs
    s = 1.5
    n = ARRAY_SIZE
    a = jp.arange(n, dtype=jp.float32)
    b = jp.ones(n, dtype=jp.float32)

    with jax.default_device(wp.device_to_jax(device)):
        grad_fn = jax.grad(loss, argnums=(0, 1))
        da, db = grad_fn(a, b, s)

    wp.synchronize_device(device)

    # NumPy reference gradients: 2*(a*s + b)*s w.r.t. a, and 2*(a*s + b) w.r.t. b
    a_np = np.arange(n, dtype=np.float32)
    b_np = np.ones(n, dtype=np.float32)
    ref_db = 2.0 * (a_np * s + b_np)
    ref_da = ref_db * s

    for actual, expected in ((da, ref_da), (db, ref_db)):
        assert_np_equal(np.asarray(actual), expected)
|
|
1372
|
+
|
|
1373
|
+
|
|
1374
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_autodiff_pmap_triple(test, device):
    # Checks jax.pmap(jax.grad(...)) over all local JAX devices for a loss built
    # on triple_kernel; every gradient entry is expected to be 3.0.
    # The `device` argument is not used: the input is split across devices by pmap.
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    kernel_fn = jax_kernel(triple_kernel, num_outputs=1, enable_backward=True)

    # one row of the input per local device
    ndev = jax.local_device_count()
    per_device = ARRAY_SIZE // ndev
    sharded_shape = (ndev, per_device)
    x = jp.arange(ndev * per_device, dtype=jp.float32).reshape(sharded_shape)

    def per_device_loss(x):
        # index 0 selects the single declared kernel output
        return jp.sum(kernel_fn(x)[0])

    pmapped_grad = jax.pmap(jax.grad(per_device_loss))
    grads = pmapped_grad(x)

    wp.synchronize()

    expected = np.full(sharded_shape, 3.0, dtype=np.float32)
    assert_np_equal(np.asarray(grads), expected)
|
|
1396
|
+
|
|
1397
|
+
|
|
1398
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_autodiff_pmap_multi_output(test, device):
    # Checks jax.pmap(jax.grad(...)) over all local JAX devices for a loss that
    # sums both outputs of multi_out_kernel_v2.
    # The `device` argument is not used: the inputs are split across devices by pmap.
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    kernel_fn = jax_kernel(multi_out_kernel_v2, num_outputs=2, enable_backward=True)

    # one row of each input per local device
    ndev = jax.local_device_count()
    per_device = ARRAY_SIZE // ndev
    total = ndev * per_device
    sharded_shape = (ndev, per_device)

    a = jp.arange(total, dtype=jp.float32).reshape(sharded_shape)
    b = jp.arange(total, dtype=jp.float32).reshape(sharded_shape)
    s = 2.0

    def per_dev_loss(aa, bb):
        # `s` is taken from the enclosing scope rather than passed through pmap
        first, second = kernel_fn(aa, bb, s)
        return jp.sum(first + second)

    pmapped_grad = jax.pmap(jax.grad(per_dev_loss, argnums=(0, 1)))
    da, db = pmapped_grad(a, b)

    wp.synchronize()

    # NumPy reference gradients: 2*a + b*s w.r.t. a, and a*s w.r.t. b
    a_np = np.arange(total, dtype=np.float32).reshape(sharded_shape)
    b_np = np.arange(total, dtype=np.float32).reshape(sharded_shape)
    ref_da = 2.0 * a_np + b_np * s
    ref_db = a_np * s

    for actual, expected in ((da, ref_da), (db, ref_db)):
        assert_np_equal(np.asarray(actual), expected)
|
|
1427
|
+
|
|
1428
|
+
|
|
1429
|
+
class TestJax(unittest.TestCase):
    """Empty test case; test methods are attached at import time via add_function_test()."""
|
|
1431
|
+
|
|
1432
|
+
|
|
1433
|
+
# try adding Jax tests if Jax is installed correctly
try:
    # prevent Jax from gobbling up GPU memory
    os.environ.update(
        {
            "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
            "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",
        }
    )

    import jax

    # NOTE: we must enable 64-bit types in Jax to test the full gamut of types
    jax.config.update("jax_enable_x64", True)

    # check which Warp devices work with Jax
    # CUDA devices may fail if Jax cannot find a CUDA Toolkit
    test_devices = get_test_devices()
    jax_compatible_devices = []
    jax_compatible_cuda_devices = []
    for d in test_devices:
        try:
            # probe the device by creating and updating a small array on it
            with jax.default_device(wp.device_to_jax(d)):
                j = jax.numpy.arange(10, dtype=jax.numpy.float32)
                j += 1
            jax_compatible_devices.append(d)
            if d.is_cuda:
                jax_compatible_cuda_devices.append(d)
        except Exception as e:
            print(f"Skipping Jax DLPack tests on device '{d}' due to exception: {e}")

    # dtype conversion tests do not need a device
    for _name, _func in (
        ("test_dtype_from_jax", test_dtype_from_jax),
        ("test_dtype_to_jax", test_dtype_to_jax),
    ):
        add_function_test(TestJax, _name, _func, devices=None)

    if jax_compatible_devices:
        add_function_test(TestJax, "test_device_conversion", test_device_conversion, devices=jax_compatible_devices)

    if jax_compatible_cuda_devices:
        # tests for both custom_call and ffi variants of jax_kernel(), selected by installed JAX version
        _version_info = jax.__version_info__
        if _version_info >= (0, 8, 0):
            # only ffi supported
            ffi_opts = [True]
        elif _version_info >= (0, 5, 0):
            # both custom_call and ffi supported
            ffi_opts = [False, True]
        elif _version_info >= (0, 4, 25):
            # only custom_call supported
            ffi_opts = [False]
        else:
            # no interop supported
            ffi_opts = []

        for use_ffi in ffi_opts:
            suffix = "ffi" if use_ffi else "cc"
            for _stem, _func in (
                ("basic", test_jax_kernel_basic),
                ("scalar", test_jax_kernel_scalar),
                ("vecmat", test_jax_kernel_vecmat),
                ("multiarg", test_jax_kernel_multiarg),
                ("launch_dims", test_jax_kernel_launch_dims),
            ):
                add_function_test(
                    TestJax,
                    f"test_jax_kernel_{_stem}_{suffix}",
                    _func,
                    devices=jax_compatible_cuda_devices,
                    use_ffi=use_ffi,
                )

        # (test name, test function, devices) in registration order
        _cuda = jax_compatible_cuda_devices
        _ffi_tests = (
            # ffi.jax_kernel() tests
            ("test_ffi_jax_kernel_add", test_ffi_jax_kernel_add, _cuda),
            ("test_ffi_jax_kernel_sincos", test_ffi_jax_kernel_sincos, _cuda),
            ("test_ffi_jax_kernel_diagonal", test_ffi_jax_kernel_diagonal, _cuda),
            ("test_ffi_jax_kernel_in_out", test_ffi_jax_kernel_in_out, _cuda),
            ("test_ffi_jax_kernel_scale_vec_constant", test_ffi_jax_kernel_scale_vec_constant, _cuda),
            ("test_ffi_jax_kernel_scale_vec_static", test_ffi_jax_kernel_scale_vec_static, _cuda),
            ("test_ffi_jax_kernel_launch_dims_default", test_ffi_jax_kernel_launch_dims_default, _cuda),
            ("test_ffi_jax_kernel_launch_dims_custom", test_ffi_jax_kernel_launch_dims_custom, _cuda),
            # ffi.jax_callable() tests
            ("test_ffi_jax_callable_scale_constant", test_ffi_jax_callable_scale_constant, _cuda),
            ("test_ffi_jax_callable_scale_static", test_ffi_jax_callable_scale_static, _cuda),
            ("test_ffi_jax_callable_in_out", test_ffi_jax_callable_in_out, _cuda),
            ("test_ffi_jax_callable_graph_cache", test_ffi_jax_callable_graph_cache, _cuda),
            # pmap tests
            ("test_ffi_jax_callable_pmap_multi_output", test_ffi_jax_callable_pmap_multi_output, None),
            ("test_ffi_jax_callable_pmap_mul", test_ffi_jax_callable_pmap_mul, None),
            ("test_ffi_jax_callable_pmap_multi_stage", test_ffi_jax_callable_pmap_multi_stage, None),
            # ffi callback tests
            ("test_ffi_callback", test_ffi_callback, _cuda),
            # autodiff tests
            ("test_ffi_jax_kernel_autodiff_simple", test_ffi_jax_kernel_autodiff_simple, _cuda),
            (
                "test_ffi_jax_kernel_autodiff_jit_of_grad_simple",
                test_ffi_jax_kernel_autodiff_jit_of_grad_simple,
                _cuda,
            ),
            ("test_ffi_jax_kernel_autodiff_multi_output", test_ffi_jax_kernel_autodiff_multi_output, _cuda),
            (
                "test_ffi_jax_kernel_autodiff_jit_of_grad_multi_output",
                test_ffi_jax_kernel_autodiff_jit_of_grad_multi_output,
                _cuda,
            ),
            ("test_ffi_jax_kernel_autodiff_2d", test_ffi_jax_kernel_autodiff_2d, _cuda),
            ("test_ffi_jax_kernel_autodiff_vec2", test_ffi_jax_kernel_autodiff_vec2, _cuda),
            ("test_ffi_jax_kernel_autodiff_mat22", test_ffi_jax_kernel_autodiff_mat22, _cuda),
            ("test_ffi_jax_kernel_autodiff_static_required", test_ffi_jax_kernel_autodiff_static_required, _cuda),
            # autodiff with pmap tests
            ("test_ffi_jax_kernel_autodiff_pmap_triple", test_ffi_jax_kernel_autodiff_pmap_triple, None),
            (
                "test_ffi_jax_kernel_autodiff_pmap_multi_output",
                test_ffi_jax_kernel_autodiff_pmap_multi_output,
                None,
            ),
        )
        for _name, _func, _devices in _ffi_tests:
            add_function_test(TestJax, _name, _func, devices=_devices)

except Exception as e:
    print(f"Skipping Jax tests due to exception: {e}")
|
|
1669
|
+
|
|
1670
|
+
|
|
1671
|
+
if __name__ == "__main__":
    # when run as a script: clear Warp's kernel cache first, then run the
    # tests with verbose output
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
|