warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang has been flagged as potentially problematic; see the advisory details on the package registry page for more information.
- warp/__init__.py +139 -0
- warp/__init__.pyi +1 -0
- warp/autograd.py +1142 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +557 -0
- warp/build_dll.py +405 -0
- warp/builtins.py +6855 -0
- warp/codegen.py +3969 -0
- warp/config.py +158 -0
- warp/constants.py +57 -0
- warp/context.py +6812 -0
- warp/dlpack.py +462 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +194 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +193 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +423 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +178 -0
- warp/examples/fem/example_convection_diffusion_dg.py +204 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +220 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_magnetostatics.py +240 -0
- warp/examples/fem/example_mixed_elasticity.py +291 -0
- warp/examples/fem/example_navier_stokes.py +261 -0
- warp/examples/fem/example_nonconforming_contact.py +298 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +352 -0
- warp/examples/fem/utils.py +1000 -0
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +266 -0
- warp/examples/optim/example_cloth_throw.py +228 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_drone.py +870 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +182 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
- warp/examples/optim/example_softbody_properties.py +400 -0
- warp/examples/optim/example_spring_cage.py +245 -0
- warp/examples/optim/example_trajectory.py +227 -0
- warp/examples/sim/example_cartpole.py +143 -0
- warp/examples/sim/example_cloth.py +225 -0
- warp/examples/sim/example_cloth_self_contact.py +322 -0
- warp/examples/sim/example_granular.py +130 -0
- warp/examples/sim/example_granular_collision_sdf.py +202 -0
- warp/examples/sim/example_jacobian_ik.py +244 -0
- warp/examples/sim/example_particle_chain.py +124 -0
- warp/examples/sim/example_quadruped.py +203 -0
- warp/examples/sim/example_rigid_chain.py +203 -0
- warp/examples/sim/example_rigid_contact.py +195 -0
- warp/examples/sim/example_rigid_force.py +133 -0
- warp/examples/sim/example_rigid_gyroscopic.py +115 -0
- warp/examples/sim/example_rigid_soft_contact.py +140 -0
- warp/examples/sim/example_soft_body.py +196 -0
- warp/examples/tile/example_tile_cholesky.py +87 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mlp.py +383 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/examples/tile/example_tile_walker.py +327 -0
- warp/fabric.py +355 -0
- warp/fem/__init__.py +106 -0
- warp/fem/adaptivity.py +508 -0
- warp/fem/cache.py +572 -0
- warp/fem/dirichlet.py +202 -0
- warp/fem/domain.py +411 -0
- warp/fem/field/__init__.py +125 -0
- warp/fem/field/field.py +619 -0
- warp/fem/field/nodal_field.py +326 -0
- warp/fem/field/restriction.py +37 -0
- warp/fem/field/virtual.py +848 -0
- warp/fem/geometry/__init__.py +32 -0
- warp/fem/geometry/adaptive_nanogrid.py +857 -0
- warp/fem/geometry/closest_point.py +84 -0
- warp/fem/geometry/deformed_geometry.py +221 -0
- warp/fem/geometry/element.py +776 -0
- warp/fem/geometry/geometry.py +362 -0
- warp/fem/geometry/grid_2d.py +392 -0
- warp/fem/geometry/grid_3d.py +452 -0
- warp/fem/geometry/hexmesh.py +911 -0
- warp/fem/geometry/nanogrid.py +571 -0
- warp/fem/geometry/partition.py +389 -0
- warp/fem/geometry/quadmesh.py +663 -0
- warp/fem/geometry/tetmesh.py +855 -0
- warp/fem/geometry/trimesh.py +806 -0
- warp/fem/integrate.py +2335 -0
- warp/fem/linalg.py +419 -0
- warp/fem/operator.py +293 -0
- warp/fem/polynomial.py +229 -0
- warp/fem/quadrature/__init__.py +17 -0
- warp/fem/quadrature/pic_quadrature.py +299 -0
- warp/fem/quadrature/quadrature.py +591 -0
- warp/fem/space/__init__.py +228 -0
- warp/fem/space/basis_function_space.py +468 -0
- warp/fem/space/basis_space.py +667 -0
- warp/fem/space/dof_mapper.py +251 -0
- warp/fem/space/function_space.py +309 -0
- warp/fem/space/grid_2d_function_space.py +177 -0
- warp/fem/space/grid_3d_function_space.py +227 -0
- warp/fem/space/hexmesh_function_space.py +257 -0
- warp/fem/space/nanogrid_function_space.py +201 -0
- warp/fem/space/partition.py +367 -0
- warp/fem/space/quadmesh_function_space.py +223 -0
- warp/fem/space/restriction.py +179 -0
- warp/fem/space/shape/__init__.py +143 -0
- warp/fem/space/shape/cube_shape_function.py +1105 -0
- warp/fem/space/shape/shape_function.py +133 -0
- warp/fem/space/shape/square_shape_function.py +926 -0
- warp/fem/space/shape/tet_shape_function.py +834 -0
- warp/fem/space/shape/triangle_shape_function.py +672 -0
- warp/fem/space/tetmesh_function_space.py +271 -0
- warp/fem/space/topology.py +424 -0
- warp/fem/space/trimesh_function_space.py +194 -0
- warp/fem/types.py +99 -0
- warp/fem/utils.py +420 -0
- warp/jax.py +187 -0
- warp/jax_experimental/__init__.py +16 -0
- warp/jax_experimental/custom_call.py +351 -0
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +244 -0
- warp/native/array.h +1145 -0
- warp/native/builtin.h +1800 -0
- warp/native/bvh.cpp +492 -0
- warp/native/bvh.cu +791 -0
- warp/native/bvh.h +554 -0
- warp/native/clang/clang.cpp +536 -0
- warp/native/coloring.cpp +613 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +362 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +646 -0
- warp/native/cuda_util.h +307 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +1878 -0
- warp/native/fabric.h +245 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +87 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1230 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +339 -0
- warp/native/marching.cpp +19 -0
- warp/native/marching.cu +514 -0
- warp/native/marching.h +19 -0
- warp/native/mat.h +2220 -0
- warp/native/mathdx.cpp +87 -0
- warp/native/matnn.h +343 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +404 -0
- warp/native/mesh.h +1980 -0
- warp/native/nanovdb/GridHandle.h +366 -0
- warp/native/nanovdb/HostBuffer.h +590 -0
- warp/native/nanovdb/NanoVDB.h +6624 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1371 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +139 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +364 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +53 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +277 -0
- warp/native/sort.h +33 -0
- warp/native/sparse.cpp +378 -0
- warp/native/sparse.cu +524 -0
- warp/native/spatial.h +657 -0
- warp/native/svd.h +702 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +2584 -0
- warp/native/tile_reduce.h +264 -0
- warp/native/vec.h +1426 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +67 -0
- warp/native/volume.h +969 -0
- warp/native/volume_builder.cu +477 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1082 -0
- warp/native/warp.cu +3636 -0
- warp/native/warp.h +381 -0
- warp/optim/__init__.py +17 -0
- warp/optim/adam.py +163 -0
- warp/optim/linear.py +1137 -0
- warp/optim/sgd.py +112 -0
- warp/paddle.py +407 -0
- warp/render/__init__.py +18 -0
- warp/render/render_opengl.py +3518 -0
- warp/render/render_usd.py +784 -0
- warp/render/utils.py +160 -0
- warp/sim/__init__.py +65 -0
- warp/sim/articulation.py +793 -0
- warp/sim/collide.py +2395 -0
- warp/sim/graph_coloring.py +300 -0
- warp/sim/import_mjcf.py +790 -0
- warp/sim/import_snu.py +227 -0
- warp/sim/import_urdf.py +579 -0
- warp/sim/import_usd.py +894 -0
- warp/sim/inertia.py +324 -0
- warp/sim/integrator.py +242 -0
- warp/sim/integrator_euler.py +1997 -0
- warp/sim/integrator_featherstone.py +2101 -0
- warp/sim/integrator_vbd.py +2048 -0
- warp/sim/integrator_xpbd.py +3292 -0
- warp/sim/model.py +4791 -0
- warp/sim/particles.py +121 -0
- warp/sim/render.py +427 -0
- warp/sim/utils.py +428 -0
- warp/sparse.py +2057 -0
- warp/stubs.py +3333 -0
- warp/tape.py +1203 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +634 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +200 -0
- warp/tests/geometry/test_hash_grid.py +221 -0
- warp/tests/geometry/test_marching_cubes.py +74 -0
- warp/tests/geometry/test_mesh.py +316 -0
- warp/tests/geometry/test_mesh_query_aabb.py +399 -0
- warp/tests/geometry/test_mesh_query_point.py +932 -0
- warp/tests/geometry/test_mesh_query_ray.py +311 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +729 -0
- warp/tests/interop/test_jax.py +371 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/sim/__init__.py +0 -0
- warp/tests/sim/disabled_kinematics.py +244 -0
- warp/tests/sim/flaky_test_sim_grad.py +290 -0
- warp/tests/sim/test_collision.py +604 -0
- warp/tests/sim/test_coloring.py +258 -0
- warp/tests/sim/test_model.py +224 -0
- warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
- warp/tests/sim/test_sim_kinematics.py +98 -0
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +163 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +2972 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +250 -0
- warp/tests/test_atomic.py +153 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +1298 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +810 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +252 -0
- warp/tests/test_context.py +42 -0
- warp/tests/test_copy.py +238 -0
- warp/tests/test_ctypes.py +638 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +97 -0
- warp/tests/test_examples.py +482 -0
- warp/tests/test_fabricarray.py +996 -0
- warp/tests/test_fast_math.py +74 -0
- warp/tests/test_fem.py +2003 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +454 -0
- warp/tests/test_future_annotations.py +98 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +73 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +193 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_mat.py +2089 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +178 -0
- warp/tests/test_mlp.py +282 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +44 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +339 -0
- warp/tests/test_quat.py +2315 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +302 -0
- warp/tests/test_rounding.py +185 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +105 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +582 -0
- warp/tests/test_spatial.py +2229 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +592 -0
- warp/tests/test_struct.py +734 -0
- warp/tests/test_tape.py +204 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +145 -0
- warp/tests/test_types.py +562 -0
- warp/tests/test_utils.py +588 -0
- warp/tests/test_vec.py +1487 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +780 -0
- warp/tests/tile/test_tile_load.py +407 -0
- warp/tests/tile/test_tile_mathdx.py +208 -0
- warp/tests/tile/test_tile_mlp.py +402 -0
- warp/tests/tile/test_tile_reduce.py +447 -0
- warp/tests/tile/test_tile_shared_memory.py +247 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +427 -0
- warp/tests/unittest_utils.py +468 -0
- warp/tests/walkthrough_debug.py +93 -0
- warp/thirdparty/__init__.py +0 -0
- warp/thirdparty/appdirs.py +598 -0
- warp/thirdparty/dlpack.py +145 -0
- warp/thirdparty/unittest_parallel.py +570 -0
- warp/torch.py +391 -0
- warp/types.py +5230 -0
- warp/utils.py +1137 -0
- warp_lang-1.7.0.dist-info/METADATA +516 -0
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- warp_lang-1.7.0.dist-info/WHEEL +5 -0
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp_lang-1.7.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,729 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import ctypes
|
|
17
|
+
import os
|
|
18
|
+
import unittest
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
|
|
22
|
+
import warp as wp
|
|
23
|
+
from warp.tests.unittest_utils import *
|
|
24
|
+
|
|
25
|
+
# Number of elements used for the interop arrays below (1 Mi float32 values).
N = 1024 * 1024
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _jax_version():
|
|
29
|
+
try:
|
|
30
|
+
import jax
|
|
31
|
+
|
|
32
|
+
return jax.__version_info__
|
|
33
|
+
except (ImportError, AttributeError):
|
|
34
|
+
return (0, 0, 0)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@wp.kernel
def inc(a: wp.array(dtype=float)):
    # Increment every element of ``a`` by 1.0; launched with one thread per element.
    tid = wp.tid()
    a[tid] = a[tid] + 1.0
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def test_dlpack_warp_to_warp(test, device):
    """Round-trip a Warp array through DLPack and verify the result is a
    zero-copy alias of the original (same pointer, metadata, and contents)."""
    src = wp.array(data=np.arange(N, dtype=np.float32), device=device)
    alias = wp.from_dlpack(wp.to_dlpack(src))

    # the round-tripped array must share storage and metadata with the source
    for attr in ("ptr", "device", "dtype", "shape", "strides"):
        test.assertEqual(getattr(src, attr), getattr(alias, attr))

    assert_np_equal(src.numpy(), alias.numpy())

    # mutating through the alias must be visible through the original
    wp.launch(inc, dim=alias.size, inputs=[alias], device=device)

    assert_np_equal(src.numpy(), alias.numpy())
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def test_dlpack_dtypes_and_shapes(test, device):
    """Exercise DLPack round-trips across scalar, vector, and matrix dtypes,
    including dtype reinterpretation (e.g. signed <-> unsigned) and
    vector/matrix <-> scalar reshaping via ``wp.from_dlpack(..., dtype=...)``."""

    # automatically determine scalar dtype
    def wrap_scalar_tensor_implicit(dtype):
        a1 = wp.zeros(N, dtype=dtype, device=device)
        a2 = wp.from_dlpack(wp.to_dlpack(a1))

        # zero-copy alias: same pointer and identical metadata
        test.assertEqual(a1.ptr, a2.ptr)
        test.assertEqual(a1.device, a2.device)
        test.assertEqual(a1.dtype, a2.dtype)
        test.assertEqual(a1.shape, a2.shape)
        test.assertEqual(a1.strides, a2.strides)

    # explicitly specify scalar dtype
    def wrap_scalar_tensor_explicit(dtype, target_dtype):
        a1 = wp.zeros(N, dtype=dtype, device=device)
        a2 = wp.from_dlpack(wp.to_dlpack(a1), dtype=target_dtype)

        # same storage, but the wrapped array reports the requested dtype
        test.assertEqual(a1.ptr, a2.ptr)
        test.assertEqual(a1.device, a2.device)
        test.assertEqual(a1.dtype, dtype)
        test.assertEqual(a2.dtype, target_dtype)
        test.assertEqual(a1.shape, a2.shape)
        test.assertEqual(a1.strides, a2.strides)

    # convert vector arrays to scalar arrays
    def wrap_vector_to_scalar_tensor(vec_dtype):
        scalar_type = vec_dtype._wp_scalar_type_
        scalar_size = ctypes.sizeof(vec_dtype._type_)

        a1 = wp.zeros(N, dtype=vec_dtype, device=device)
        a2 = wp.from_dlpack(wp.to_dlpack(a1), dtype=scalar_type)

        # the vector components become an extra trailing dimension
        test.assertEqual(a1.ptr, a2.ptr)
        test.assertEqual(a1.device, a2.device)
        test.assertEqual(a2.ndim, a1.ndim + 1)
        test.assertEqual(a1.dtype, vec_dtype)
        test.assertEqual(a2.dtype, scalar_type)
        test.assertEqual(a2.shape, (*a1.shape, vec_dtype._length_))
        test.assertEqual(a2.strides, (*a1.strides, scalar_size))

    # convert scalar arrays to vector arrays
    def wrap_scalar_to_vector_tensor(vec_dtype):
        scalar_type = vec_dtype._wp_scalar_type_
        scalar_size = ctypes.sizeof(vec_dtype._type_)

        a1 = wp.zeros((N, vec_dtype._length_), dtype=scalar_type, device=device)
        a2 = wp.from_dlpack(wp.to_dlpack(a1), dtype=vec_dtype)

        # the trailing dimension is absorbed into the vector dtype
        test.assertEqual(a1.ptr, a2.ptr)
        test.assertEqual(a1.device, a2.device)
        test.assertEqual(a2.ndim, a1.ndim - 1)
        test.assertEqual(a1.dtype, scalar_type)
        test.assertEqual(a2.dtype, vec_dtype)
        test.assertEqual(a1.shape, (*a2.shape, vec_dtype._length_))
        test.assertEqual(a1.strides, (*a2.strides, scalar_size))

    # convert matrix arrays to scalar arrays
    def wrap_matrix_to_scalar_tensor(mat_dtype):
        scalar_type = mat_dtype._wp_scalar_type_
        scalar_size = ctypes.sizeof(mat_dtype._type_)

        a1 = wp.zeros(N, dtype=mat_dtype, device=device)
        a2 = wp.from_dlpack(wp.to_dlpack(a1), dtype=scalar_type)

        # matrix components become two extra trailing dimensions (row-major)
        test.assertEqual(a1.ptr, a2.ptr)
        test.assertEqual(a1.device, a2.device)
        test.assertEqual(a2.ndim, a1.ndim + 2)
        test.assertEqual(a1.dtype, mat_dtype)
        test.assertEqual(a2.dtype, scalar_type)
        test.assertEqual(a2.shape, (*a1.shape, *mat_dtype._shape_))
        test.assertEqual(a2.strides, (*a1.strides, scalar_size * mat_dtype._shape_[1], scalar_size))

    # convert scalar arrays to matrix arrays
    def wrap_scalar_to_matrix_tensor(mat_dtype):
        scalar_type = mat_dtype._wp_scalar_type_
        scalar_size = ctypes.sizeof(mat_dtype._type_)

        a1 = wp.zeros((N, *mat_dtype._shape_), dtype=scalar_type, device=device)
        a2 = wp.from_dlpack(wp.to_dlpack(a1), dtype=mat_dtype)

        # the two trailing dimensions are absorbed into the matrix dtype
        test.assertEqual(a1.ptr, a2.ptr)
        test.assertEqual(a1.device, a2.device)
        test.assertEqual(a2.ndim, a1.ndim - 2)
        test.assertEqual(a1.dtype, scalar_type)
        test.assertEqual(a2.dtype, mat_dtype)
        test.assertEqual(a1.shape, (*a2.shape, *mat_dtype._shape_))
        test.assertEqual(a1.strides, (*a2.strides, scalar_size * mat_dtype._shape_[1], scalar_size))

    for t in wp.types.scalar_types:
        wrap_scalar_tensor_implicit(t)

    for t in wp.types.scalar_types:
        wrap_scalar_tensor_explicit(t, t)

    # test signed/unsigned conversions
    wrap_scalar_tensor_explicit(wp.int8, wp.uint8)
    wrap_scalar_tensor_explicit(wp.uint8, wp.int8)
    wrap_scalar_tensor_explicit(wp.int16, wp.uint16)
    wrap_scalar_tensor_explicit(wp.uint16, wp.int16)
    wrap_scalar_tensor_explicit(wp.int32, wp.uint32)
    wrap_scalar_tensor_explicit(wp.uint32, wp.int32)
    wrap_scalar_tensor_explicit(wp.int64, wp.uint64)
    wrap_scalar_tensor_explicit(wp.uint64, wp.int64)

    # generic vectors of length 2-5 for every scalar type, plus the named vector types
    vec_types = []
    for t in wp.types.scalar_types:
        for vec_len in [2, 3, 4, 5]:
            vec_types.append(wp.types.vector(vec_len, t))

    vec_types.append(wp.quath)
    vec_types.append(wp.quatf)
    vec_types.append(wp.quatd)
    vec_types.append(wp.transformh)
    vec_types.append(wp.transformf)
    vec_types.append(wp.transformd)
    vec_types.append(wp.spatial_vectorh)
    vec_types.append(wp.spatial_vectorf)
    vec_types.append(wp.spatial_vectord)

    for vec_type in vec_types:
        wrap_vector_to_scalar_tensor(vec_type)
        wrap_scalar_to_vector_tensor(vec_type)

    # square and non-square matrix shapes for every scalar type
    mat_shapes = [(2, 2), (3, 3), (4, 4), (5, 5), (2, 3), (3, 2), (3, 4), (4, 3)]
    mat_types = []
    for t in wp.types.scalar_types:
        for mat_shape in mat_shapes:
            mat_types.append(wp.types.matrix(mat_shape, t))

    mat_types.append(wp.spatial_matrixh)
    mat_types.append(wp.spatial_matrixf)
    mat_types.append(wp.spatial_matrixd)

    for mat_type in mat_types:
        wrap_matrix_to_scalar_tensor(mat_type)
        wrap_scalar_to_matrix_tensor(mat_type)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def test_dlpack_stream_arg(test, device):
    """Exercise valid and invalid values of the ``stream`` argument accepted by
    ``array.__dlpack__()`` on both CPU and CUDA devices."""

    data = np.arange(10)

    def verify(capsule):
        # re-wrap the capsule and confirm it still aliases the original data
        wrapped = wp.dlpack._from_dlpack(capsule)
        assert_np_equal(wrapped.numpy(), data)

    with wp.ScopedDevice(device):
        a = wp.array(data=data)

        # stream arguments supported on every device type
        verify(a.__dlpack__())
        verify(a.__dlpack__(stream=None))
        verify(a.__dlpack__(stream=-1))

        if device.is_cuda:
            # 0 = default stream, 1 = legacy default, 2 = per-thread default
            for valid in (0, 1, 2):
                verify(a.__dlpack__(stream=valid))

            # an explicitly created custom stream
            custom = wp.Stream(device)
            verify(a.__dlpack__(stream=custom.cuda_stream))

            # anything below -1 or non-integer is rejected on CUDA
            expected_error = r"DLPack stream must None or an integer >= -1"
            for invalid in (-2, "nope"):
                with test.assertRaisesRegex(TypeError, expected_error):
                    verify(a.__dlpack__(stream=invalid))
        else:
            # CPU devices only accept None or -1
            expected_error = r"DLPack stream must be None or -1 for CPU device"
            for invalid in (0, 1, 2, 1742, -2, "nope"):
                with test.assertRaisesRegex(TypeError, expected_error):
                    verify(a.__dlpack__(stream=invalid))
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def test_dlpack_warp_to_torch(test, device):
    """Share a Warp array with PyTorch through the legacy DLPack capsule API
    and verify zero-copy aliasing in both directions."""
    import torch.utils.dlpack

    a = wp.array(data=np.arange(N, dtype=np.float32), device=device)
    t = torch.utils.dlpack.from_dlpack(wp.to_dlpack(a))

    # torch strides are in elements; warp strides are in bytes
    elem_bytes = wp.types.type_size_in_bytes(a.dtype)

    test.assertEqual(a.ptr, t.data_ptr())
    test.assertEqual(a.device, wp.device_from_torch(t.device))
    test.assertEqual(a.dtype, wp.dtype_from_torch(t.dtype))
    test.assertEqual(a.shape, tuple(t.shape))
    test.assertEqual(a.strides, tuple(elem_bytes * s for s in t.stride()))

    assert_np_equal(a.numpy(), t.cpu().numpy())

    # mutate from the warp side; the torch tensor must see the change
    wp.launch(inc, dim=a.size, inputs=[a], device=device)
    assert_np_equal(a.numpy(), t.cpu().numpy())

    # mutate from the torch side; the warp array must see the change
    t += 1
    assert_np_equal(a.numpy(), t.cpu().numpy())
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def test_dlpack_warp_to_torch_v2(test, device):
    """Same as test_dlpack_warp_to_torch, but hands the Warp array directly to
    torch so the newer ``__dlpack__()`` protocol is used."""
    import torch.utils.dlpack

    a = wp.array(data=np.arange(N, dtype=np.float32), device=device)

    # torch consumes the warp array via __dlpack__/__dlpack_device__
    t = torch.utils.dlpack.from_dlpack(a)

    # torch strides are in elements; warp strides are in bytes
    elem_bytes = wp.types.type_size_in_bytes(a.dtype)
    metadata_pairs = [
        (a.ptr, t.data_ptr()),
        (a.device, wp.device_from_torch(t.device)),
        (a.dtype, wp.dtype_from_torch(t.dtype)),
        (a.shape, tuple(t.shape)),
        (a.strides, tuple(s * elem_bytes for s in t.stride())),
    ]
    for expected, actual in metadata_pairs:
        test.assertEqual(expected, actual)

    assert_np_equal(a.numpy(), t.cpu().numpy())

    # updates made on either side must be visible on the other
    wp.launch(inc, dim=a.size, inputs=[a], device=device)
    assert_np_equal(a.numpy(), t.cpu().numpy())

    t += 1
    assert_np_equal(a.numpy(), t.cpu().numpy())
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def test_dlpack_torch_to_warp(test, device):
    """Wrap a PyTorch tensor as a Warp array via a legacy DLPack capsule and
    verify the two views share storage."""
    import torch
    import torch.utils.dlpack

    t = torch.arange(N, dtype=torch.float32, device=wp.device_to_torch(device))
    a = wp.from_dlpack(torch.utils.dlpack.to_dlpack(t))

    # torch strides are in elements; warp strides are in bytes
    elem_bytes = wp.types.type_size_in_bytes(a.dtype)

    test.assertEqual(a.ptr, t.data_ptr())
    test.assertEqual(a.device, wp.device_from_torch(t.device))
    test.assertEqual(a.dtype, wp.dtype_from_torch(t.dtype))
    test.assertEqual(a.shape, tuple(t.shape))
    test.assertEqual(a.strides, tuple(elem_bytes * s for s in t.stride()))

    assert_np_equal(a.numpy(), t.cpu().numpy())

    # warp-side mutation is visible through the torch tensor
    wp.launch(inc, dim=a.size, inputs=[a], device=device)
    assert_np_equal(a.numpy(), t.cpu().numpy())

    # torch-side mutation is visible through the warp array
    t += 1
    assert_np_equal(a.numpy(), t.cpu().numpy())
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def test_dlpack_torch_to_warp_v2(test, device):
    """Import a Torch tensor into Warp via the newer __dlpack__() protocol (no explicit capsule)."""
    import torch

    tensor = torch.arange(N, dtype=torch.float32, device=wp.device_to_torch(device))

    # hand the tensor to Warp directly; wp.from_dlpack() drives __dlpack__()
    arr = wp.from_dlpack(tensor)

    elem_bytes = wp.types.type_size_in_bytes(arr.dtype)

    # metadata must describe the very same buffer torch owns
    test.assertEqual(arr.ptr, tensor.data_ptr())
    test.assertEqual(arr.device, wp.device_from_torch(tensor.device))
    test.assertEqual(arr.dtype, wp.dtype_from_torch(tensor.dtype))
    test.assertEqual(arr.shape, tuple(tensor.shape))
    test.assertEqual(arr.strides, tuple(elem_bytes * s for s in tensor.stride()))

    assert_np_equal(arr.numpy(), tensor.cpu().numpy())

    # mutate through warp, then through torch; both views must stay in sync
    wp.launch(inc, dim=arr.size, inputs=[arr], device=device)
    assert_np_equal(arr.numpy(), tensor.cpu().numpy())

    tensor += 1
    assert_np_equal(arr.numpy(), tensor.cpu().numpy())
def test_dlpack_paddle_to_warp(test, device):
    """Import a Paddle tensor into Warp through a legacy DLPack capsule."""
    import paddle
    import paddle.utils.dlpack

    tensor = paddle.arange(N, dtype=paddle.float32).to(device=wp.device_to_paddle(device))

    # paddle does not implement __dlpack__ yet, so only the capsule path is exercised
    arr = wp.from_dlpack(paddle.utils.dlpack.to_dlpack(tensor))

    elem_bytes = wp.types.type_size_in_bytes(arr.dtype)

    # the warp array must alias the paddle storage, not copy it
    test.assertEqual(arr.ptr, tensor.data_ptr())
    test.assertEqual(arr.device, wp.device_from_paddle(tensor.place))
    test.assertEqual(arr.dtype, wp.dtype_from_paddle(tensor.dtype))
    test.assertEqual(arr.shape, tuple(tensor.shape))
    test.assertEqual(arr.strides, tuple(elem_bytes * s for s in tensor.strides))

    assert_np_equal(arr.numpy(), tensor.numpy())

    # write through warp, then through paddle; both views must agree each time
    wp.launch(inc, dim=arr.size, inputs=[arr], device=device)
    assert_np_equal(arr.numpy(), tensor.numpy())

    paddle.assign(tensor + 1, tensor)
    assert_np_equal(arr.numpy(), tensor.numpy())
def test_dlpack_warp_to_jax(test, device):
    # Share a Warp array with Jax two ways (generic dlpack conversion and the
    # wp.to_jax() wrapper) and verify both results alias the same buffer.
    import jax
    import jax.dlpack
    import jax.numpy as jnp

    cpu_device = jax.devices("cpu")[0]

    # Create a numpy array from a JAX array to respect XLA alignment needs
    with jax.default_device(cpu_device):
        x_jax = jnp.arange(N, dtype=jnp.float32)
        x_numpy = np.asarray(x_jax)
        # sanity: the numpy view starts at the same address as the jax buffer
        test.assertEqual(x_jax.unsafe_buffer_pointer(), np.lib.array_utils.byte_bounds(x_numpy)[0])

    a = wp.array(x_numpy, device=device, dtype=wp.float32, copy=False)

    # on CPU the warp array can share the numpy memory directly; on GPU it cannot
    if device.is_cpu:
        test.assertEqual(a.ptr, np.lib.array_utils.byte_bounds(x_numpy)[0])

    # use generic dlpack conversion
    j1 = jax.dlpack.from_dlpack(a, copy=False)

    # use jax wrapper
    j2 = wp.to_jax(a)

    # both jax arrays must alias the warp buffer and agree on placement/shape
    test.assertEqual(a.ptr, j1.unsafe_buffer_pointer())
    test.assertEqual(a.ptr, j2.unsafe_buffer_pointer())
    test.assertEqual(a.device, wp.device_from_jax(list(j1.devices())[0]))
    test.assertEqual(a.device, wp.device_from_jax(list(j2.devices())[0]))
    test.assertEqual(a.shape, j1.shape)
    test.assertEqual(a.shape, j2.shape)

    assert_np_equal(a.numpy(), np.asarray(j1))
    assert_np_equal(a.numpy(), np.asarray(j2))

    wp.launch(inc, dim=a.size, inputs=[a], device=device)
    wp.synchronize_device(device)

    # HACK? Run a no-op operation so that Jax flags the arrays as dirty
    # and gets the latest values, which were modified by Warp.
    j1 += 0
    j2 += 0

    assert_np_equal(a.numpy(), np.asarray(j1))
    assert_np_equal(a.numpy(), np.asarray(j2))
@unittest.skipUnless(_jax_version() >= (0, 4, 15), "Jax version too old")
def test_dlpack_warp_to_jax_v2(test, device):
    """Share a Warp array with Jax via the newer __dlpack__() protocol and the wp.to_jax() wrapper."""
    import jax
    import jax.dlpack
    import jax.numpy as jnp

    host = jax.devices("cpu")[0]

    # Build the source buffer through Jax so it satisfies XLA's alignment
    # needs, then view it as numpy without copying.
    with jax.default_device(host):
        jax_src = jnp.arange(N, dtype=jnp.float32)
        np_src = np.asarray(jax_src)
        test.assertEqual(jax_src.unsafe_buffer_pointer(), np.lib.array_utils.byte_bounds(np_src)[0])

    a = wp.array(np_src, device=device, dtype=wp.float32, copy=False)

    # on CPU the warp array can share the numpy memory directly
    if device.is_cpu:
        test.assertEqual(a.ptr, np.lib.array_utils.byte_bounds(np_src)[0])

    # pass the warp array directly (__dlpack__ path) ...
    j1 = jax.dlpack.from_dlpack(a, copy=False)

    # ... and through the jax convenience wrapper
    j2 = wp.to_jax(a)

    # both jax views must alias the warp buffer and agree on placement/shape
    for j in (j1, j2):
        test.assertEqual(a.ptr, j.unsafe_buffer_pointer())
        test.assertEqual(a.device, wp.device_from_jax(list(j.devices())[0]))
        test.assertEqual(a.shape, j.shape)
        assert_np_equal(a.numpy(), np.asarray(j))

    wp.launch(inc, dim=a.size, inputs=[a], device=device)
    wp.synchronize_device(device)

    # HACK? Run a no-op update so Jax flags the arrays as dirty and rereads
    # the values that Warp just modified.
    j1 += 0
    j2 += 0

    assert_np_equal(a.numpy(), np.asarray(j1))
    assert_np_equal(a.numpy(), np.asarray(j2))
def test_dlpack_warp_to_paddle(test, device):
    """Export a Warp array to Paddle through a legacy DLPack capsule."""
    import paddle.utils.dlpack

    arr = wp.array(data=np.arange(N, dtype=np.float32), device=device)

    # legacy path: explicit capsule produced by warp, consumed by paddle
    tensor = paddle.utils.dlpack.from_dlpack(wp.to_dlpack(arr))

    elem_bytes = wp.types.type_size_in_bytes(arr.dtype)

    # the paddle tensor must alias the warp storage, not copy it
    test.assertEqual(arr.ptr, tensor.data_ptr())
    test.assertEqual(arr.device, wp.device_from_paddle(tensor.place))
    test.assertEqual(arr.dtype, wp.dtype_from_paddle(tensor.dtype))
    test.assertEqual(arr.shape, tuple(tensor.shape))
    test.assertEqual(arr.strides, tuple(elem_bytes * s for s in tensor.strides))

    assert_np_equal(arr.numpy(), tensor.cpu().numpy())

    # write through warp, then through paddle; both views must agree each time
    wp.launch(inc, dim=arr.size, inputs=[arr], device=device)
    assert_np_equal(arr.numpy(), tensor.cpu().numpy())

    paddle.assign(tensor + 1, tensor)
    assert_np_equal(arr.numpy(), tensor.cpu().numpy())
def test_dlpack_warp_to_paddle_v2(test, device):
    """Export a Warp array to Paddle via the newer __dlpack__() protocol (no explicit capsule)."""
    import paddle.utils.dlpack

    arr = wp.array(data=np.arange(N, dtype=np.float32), device=device)

    # hand the warp array to paddle directly; from_dlpack() drives __dlpack__()
    tensor = paddle.utils.dlpack.from_dlpack(arr)

    elem_bytes = wp.types.type_size_in_bytes(arr.dtype)

    # the paddle tensor must alias the warp storage, not copy it
    test.assertEqual(arr.ptr, tensor.data_ptr())
    test.assertEqual(arr.device, wp.device_from_paddle(tensor.place))
    test.assertEqual(arr.dtype, wp.dtype_from_paddle(tensor.dtype))
    test.assertEqual(arr.shape, tuple(tensor.shape))
    test.assertEqual(arr.strides, tuple(elem_bytes * s for s in tensor.strides))

    assert_np_equal(arr.numpy(), tensor.numpy())

    # write through warp, then through paddle; both views must agree each time
    wp.launch(inc, dim=arr.size, inputs=[arr], device=device)
    assert_np_equal(arr.numpy(), tensor.numpy())

    paddle.assign(tensor + 1, tensor)
    assert_np_equal(arr.numpy(), tensor.numpy())
def test_dlpack_jax_to_warp(test, device):
    """Import a Jax array into Warp via a legacy DLPack capsule and via the wp.from_jax() wrapper."""
    import jax
    import jax.dlpack

    with jax.default_device(wp.device_to_jax(device)):
        j = jax.numpy.arange(N, dtype=jax.numpy.float32)

        # legacy capsule path
        a1 = wp.from_dlpack(jax.dlpack.to_dlpack(j))

        # convenience wrapper
        a2 = wp.from_jax(j)

        # both warp arrays must alias the jax buffer and agree on placement/shape
        for a in (a1, a2):
            test.assertEqual(a.ptr, j.unsafe_buffer_pointer())
            test.assertEqual(a.device, wp.device_from_jax(list(j.devices())[0]))
            test.assertEqual(a.shape, j.shape)
            assert_np_equal(a.numpy(), np.asarray(j))

        wp.launch(inc, dim=a1.size, inputs=[a1], device=device)
        wp.synchronize_device(device)

        # HACK? Run a no-op update so Jax flags the array as dirty and rereads
        # the values that Warp just modified.
        j += 0

        assert_np_equal(a1.numpy(), np.asarray(j))
        assert_np_equal(a2.numpy(), np.asarray(j))
@unittest.skipUnless(_jax_version() >= (0, 4, 15), "Jax version too old")
def test_dlpack_jax_to_warp_v2(test, device):
    """Import a Jax array into Warp via the newer __dlpack__() protocol and the wp.from_jax() wrapper."""
    import jax

    with jax.default_device(wp.device_to_jax(device)):
        j = jax.numpy.arange(N, dtype=jax.numpy.float32)

        # hand the jax array to Warp directly (__dlpack__ path)
        a1 = wp.from_dlpack(j)

        # convenience wrapper
        a2 = wp.from_jax(j)

        # both warp arrays must alias the jax buffer and agree on placement/shape
        for a in (a1, a2):
            test.assertEqual(a.ptr, j.unsafe_buffer_pointer())
            test.assertEqual(a.device, wp.device_from_jax(list(j.devices())[0]))
            test.assertEqual(a.shape, j.shape)
            assert_np_equal(a.numpy(), np.asarray(j))

        wp.launch(inc, dim=a1.size, inputs=[a1], device=device)
        wp.synchronize_device(device)

        # HACK? Run a no-op update so Jax flags the array as dirty and rereads
        # the values that Warp just modified.
        j += 0

        assert_np_equal(a1.numpy(), np.asarray(j))
        assert_np_equal(a2.numpy(), np.asarray(j))
class TestDLPack(unittest.TestCase):
    """Empty TestCase shell; the DLPack interop tests are attached below via add_function_test()."""
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
devices = get_test_devices()

# Warp-only DLPack tests run on every available device.
add_function_test(TestDLPack, "test_dlpack_warp_to_warp", test_dlpack_warp_to_warp, devices=devices)
add_function_test(TestDLPack, "test_dlpack_dtypes_and_shapes", test_dlpack_dtypes_and_shapes, devices=devices)
add_function_test(TestDLPack, "test_dlpack_stream_arg", test_dlpack_stream_arg, devices=devices)
|
617
|
+
|
|
618
|
+
# torch interop via dlpack
try:
    import torch
    import torch.utils.dlpack

    # check which Warp devices work with Torch
    # CUDA devices may fail if Torch was not compiled with CUDA support
    test_devices = get_test_devices()
    torch_compatible_devices = []
    for d in test_devices:
        try:
            # a tiny allocation plus an in-place op is enough to prove the device works
            t = torch.arange(10, device=wp.device_to_torch(d))
            t += 1
            torch_compatible_devices.append(d)
        except Exception as e:
            print(f"Skipping Torch DLPack tests on device '{d}' due to exception: {e}")

    if torch_compatible_devices:
        add_function_test(
            TestDLPack, "test_dlpack_warp_to_torch", test_dlpack_warp_to_torch, devices=torch_compatible_devices
        )
        add_function_test(
            TestDLPack, "test_dlpack_warp_to_torch_v2", test_dlpack_warp_to_torch_v2, devices=torch_compatible_devices
        )
        add_function_test(
            TestDLPack, "test_dlpack_torch_to_warp", test_dlpack_torch_to_warp, devices=torch_compatible_devices
        )
        add_function_test(
            TestDLPack, "test_dlpack_torch_to_warp_v2", test_dlpack_torch_to_warp_v2, devices=torch_compatible_devices
        )

except Exception as e:
    # torch is not installed or failed to import; skip the whole group
    print(f"Skipping Torch DLPack tests due to exception: {e}")
|
+
|
|
652
|
+
# jax interop via dlpack
try:
    # prevent Jax from gobbling up GPU memory
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

    import jax
    import jax.dlpack

    # check which Warp devices work with Jax
    # CUDA devices may fail if Jax cannot find a CUDA Toolkit
    test_devices = get_test_devices()
    jax_compatible_devices = []
    for d in test_devices:
        try:
            # a tiny allocation plus an in-place op is enough to prove the device works
            with jax.default_device(wp.device_to_jax(d)):
                j = jax.numpy.arange(10, dtype=jax.numpy.float32)
                j += 1
            jax_compatible_devices.append(d)
        except Exception as e:
            print(f"Skipping Jax DLPack tests on device '{d}' due to exception: {e}")

    if jax_compatible_devices:
        add_function_test(
            TestDLPack, "test_dlpack_warp_to_jax", test_dlpack_warp_to_jax, devices=jax_compatible_devices
        )
        add_function_test(
            TestDLPack, "test_dlpack_warp_to_jax_v2", test_dlpack_warp_to_jax_v2, devices=jax_compatible_devices
        )
        add_function_test(
            TestDLPack, "test_dlpack_jax_to_warp", test_dlpack_jax_to_warp, devices=jax_compatible_devices
        )
        add_function_test(
            TestDLPack, "test_dlpack_jax_to_warp_v2", test_dlpack_jax_to_warp_v2, devices=jax_compatible_devices
        )

except Exception as e:
    # jax is not installed or failed to import; skip the whole group
    print(f"Skipping Jax DLPack tests due to exception: {e}")
|
|
691
|
+
|
|
692
|
+
# paddle interop via dlpack
try:
    import paddle
    import paddle.utils.dlpack

    # check which Warp devices work with paddle
    # CUDA devices may fail if paddle was not compiled with CUDA support
    test_devices = get_test_devices()
    paddle_compatible_devices = []
    for d in test_devices:
        try:
            # a tiny allocation plus an in-place update is enough to prove the device works
            t = paddle.arange(10).to(device=wp.device_to_paddle(d))
            paddle.assign(t + 1, t)
            paddle_compatible_devices.append(d)
        except Exception as e:
            print(f"Skipping paddle DLPack tests on device '{d}' due to exception: {e}")

    if paddle_compatible_devices:
        add_function_test(
            TestDLPack, "test_dlpack_warp_to_paddle", test_dlpack_warp_to_paddle, devices=paddle_compatible_devices
        )
        add_function_test(
            TestDLPack,
            "test_dlpack_warp_to_paddle_v2",
            test_dlpack_warp_to_paddle_v2,
            devices=paddle_compatible_devices,
        )
        add_function_test(
            TestDLPack, "test_dlpack_paddle_to_warp", test_dlpack_paddle_to_warp, devices=paddle_compatible_devices
        )

except Exception as e:
    # paddle is not installed or failed to import; skip the whole group
    print(f"Skipping Paddle DLPack tests due to exception: {e}")
|
|
726
|
+
|
|
727
|
+
if __name__ == "__main__":
    # Start from a clean kernel cache so every kernel is rebuilt for this run.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)