warp-lang 1.9.0__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +301 -287
- warp/__init__.pyi +2220 -313
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1075 -0
- warp/_src/build.py +618 -0
- warp/_src/build_dll.py +640 -0
- warp/{builtins.py → _src/builtins.py} +1497 -226
- warp/_src/codegen.py +4359 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +57 -0
- warp/_src/context.py +8294 -0
- warp/_src/dlpack.py +462 -0
- warp/_src/fabric.py +355 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +508 -0
- warp/_src/fem/cache.py +687 -0
- warp/_src/fem/dirichlet.py +188 -0
- warp/{fem → _src/fem}/domain.py +40 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +701 -0
- warp/{fem → _src/fem}/field/nodal_field.py +30 -15
- warp/{fem → _src/fem}/field/restriction.py +1 -1
- warp/{fem → _src/fem}/field/virtual.py +53 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
- warp/_src/fem/geometry/closest_point.py +97 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
- warp/{fem → _src/fem}/geometry/element.py +32 -10
- warp/{fem → _src/fem}/geometry/geometry.py +48 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
- warp/{fem → _src/fem}/geometry/partition.py +121 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
- warp/{fem → _src/fem}/integrate.py +164 -158
- warp/_src/fem/linalg.py +383 -0
- warp/_src/fem/operator.py +396 -0
- warp/_src/fem/polynomial.py +229 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
- warp/_src/fem/space/basis_space.py +679 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
- warp/{fem → _src/fem}/space/function_space.py +14 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
- warp/{fem → _src/fem}/space/partition.py +117 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/restriction.py +66 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
- warp/_src/fem/space/topology.py +459 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
- warp/_src/fem/types.py +112 -0
- warp/_src/fem/utils.py +486 -0
- warp/_src/jax.py +186 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +387 -0
- warp/_src/jax_experimental/ffi.py +1284 -0
- warp/_src/jax_experimental/xla_ffi.py +656 -0
- warp/_src/marching_cubes.py +708 -0
- warp/_src/math.py +414 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +163 -0
- warp/_src/optim/linear.py +1606 -0
- warp/_src/optim/sgd.py +112 -0
- warp/_src/paddle.py +406 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +289 -0
- warp/_src/render/render_opengl.py +3636 -0
- warp/_src/render/render_usd.py +937 -0
- warp/_src/render/utils.py +160 -0
- warp/_src/sparse.py +2716 -0
- warp/_src/tape.py +1206 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +391 -0
- warp/_src/types.py +5870 -0
- warp/_src/utils.py +1693 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -471
- warp/codegen.py +6 -4246
- warp/constants.py +6 -39
- warp/context.py +12 -7851
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +3 -2
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -342
- warp/jax_experimental/ffi.py +17 -853
- warp/jax_experimental/xla_ffi.py +5 -596
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +316 -39
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sort.cu +22 -13
- warp/native/sort.h +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +837 -70
- warp/native/tile_radix_sort.h +1 -1
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -53
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +60 -32
- warp/native/warp.cu +313 -201
- warp/native/warp.h +14 -11
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3616
- warp/render/render_usd.py +6 -918
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_hash_grid.py +38 -0
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +14 -14
- warp/tests/interop/test_jax.py +1382 -79
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +529 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +34 -15
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +60 -14
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +49 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_tuple.py +96 -0
- warp/tests/test_types.py +82 -9
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +239 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5750
- warp/utils.py +10 -1659
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
- warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.0.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/tests/interop/test_jax.py
CHANGED
|
@@ -15,13 +15,18 @@
|
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
17
|
import unittest
|
|
18
|
+
from functools import partial
|
|
18
19
|
from typing import Any
|
|
19
20
|
|
|
20
21
|
import numpy as np
|
|
21
22
|
|
|
22
23
|
import warp as wp
|
|
24
|
+
from warp._src.jax import get_jax_device
|
|
23
25
|
from warp.tests.unittest_utils import *
|
|
24
26
|
|
|
27
|
+
# default array size for tests
|
|
28
|
+
ARRAY_SIZE = 1024 * 1024
|
|
29
|
+
|
|
25
30
|
|
|
26
31
|
# basic kernel with one input and output
|
|
27
32
|
@wp.kernel
|
|
@@ -44,6 +49,18 @@ def triple_kernel_vecmat(input: wp.array(dtype=Any), output: wp.array(dtype=Any)
|
|
|
44
49
|
output[tid] = input.dtype.dtype(3) * input[tid]
|
|
45
50
|
|
|
46
51
|
|
|
52
|
+
@wp.kernel
|
|
53
|
+
def inc_1d_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
|
|
54
|
+
tid = wp.tid()
|
|
55
|
+
y[tid] = x[tid] + 1.0
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@wp.kernel
|
|
59
|
+
def inc_2d_kernel(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
|
|
60
|
+
i, j = wp.tid()
|
|
61
|
+
y[i, j] = x[i, j] + 1.0
|
|
62
|
+
|
|
63
|
+
|
|
47
64
|
# kernel with multiple inputs and outputs
|
|
48
65
|
@wp.kernel
|
|
49
66
|
def multiarg_kernel(
|
|
@@ -61,7 +78,7 @@ def multiarg_kernel(
|
|
|
61
78
|
|
|
62
79
|
|
|
63
80
|
# various types for testing
|
|
64
|
-
scalar_types = wp.types.scalar_types
|
|
81
|
+
scalar_types = wp._src.types.scalar_types
|
|
65
82
|
vector_types = []
|
|
66
83
|
matrix_types = []
|
|
67
84
|
for dim in [2, 3, 4]:
|
|
@@ -132,15 +149,19 @@ def test_device_conversion(test, device):
|
|
|
132
149
|
test.assertEqual(warp_device, device)
|
|
133
150
|
|
|
134
151
|
|
|
135
|
-
|
|
136
|
-
def test_jax_kernel_basic(test, device):
|
|
152
|
+
def test_jax_kernel_basic(test, device, use_ffi=False):
|
|
137
153
|
import jax.numpy as jp
|
|
138
154
|
|
|
139
|
-
|
|
155
|
+
if use_ffi:
|
|
156
|
+
from warp.jax_experimental.ffi import jax_kernel
|
|
140
157
|
|
|
141
|
-
|
|
158
|
+
jax_triple = jax_kernel(triple_kernel)
|
|
159
|
+
else:
|
|
160
|
+
from warp.jax_experimental.custom_call import jax_kernel
|
|
161
|
+
|
|
162
|
+
jax_triple = jax_kernel(triple_kernel, quiet=True) # suppress deprecation warnings
|
|
142
163
|
|
|
143
|
-
|
|
164
|
+
n = ARRAY_SIZE
|
|
144
165
|
|
|
145
166
|
@jax.jit
|
|
146
167
|
def f():
|
|
@@ -151,18 +172,27 @@ def test_jax_kernel_basic(test, device):
|
|
|
151
172
|
with jax.default_device(wp.device_to_jax(device)):
|
|
152
173
|
y = f()
|
|
153
174
|
|
|
175
|
+
wp.synchronize_device(device)
|
|
176
|
+
|
|
154
177
|
result = np.asarray(y).reshape((n,))
|
|
155
178
|
expected = 3 * np.arange(n, dtype=np.float32)
|
|
156
179
|
|
|
157
180
|
assert_np_equal(result, expected)
|
|
158
181
|
|
|
159
182
|
|
|
160
|
-
|
|
161
|
-
def test_jax_kernel_scalar(test, device):
|
|
183
|
+
def test_jax_kernel_scalar(test, device, use_ffi=False):
|
|
162
184
|
import jax.numpy as jp
|
|
163
185
|
|
|
164
|
-
|
|
186
|
+
if use_ffi:
|
|
187
|
+
from warp.jax_experimental.ffi import jax_kernel
|
|
188
|
+
|
|
189
|
+
kwargs = {}
|
|
190
|
+
else:
|
|
191
|
+
from warp.jax_experimental.custom_call import jax_kernel
|
|
165
192
|
|
|
193
|
+
kwargs = {"quiet": True}
|
|
194
|
+
|
|
195
|
+
# use a smallish size to ensure arange * 3 doesn't overflow
|
|
166
196
|
n = 64
|
|
167
197
|
|
|
168
198
|
for T in scalar_types:
|
|
@@ -173,7 +203,7 @@ def test_jax_kernel_scalar(test, device):
|
|
|
173
203
|
# get the concrete overload
|
|
174
204
|
kernel_instance = triple_kernel_scalar.add_overload([wp.array(dtype=T), wp.array(dtype=T)])
|
|
175
205
|
|
|
176
|
-
jax_triple = jax_kernel(kernel_instance)
|
|
206
|
+
jax_triple = jax_kernel(kernel_instance, **kwargs)
|
|
177
207
|
|
|
178
208
|
@jax.jit
|
|
179
209
|
def f(jax_triple=jax_triple, jp_dtype=jp_dtype):
|
|
@@ -184,22 +214,31 @@ def test_jax_kernel_scalar(test, device):
|
|
|
184
214
|
with jax.default_device(wp.device_to_jax(device)):
|
|
185
215
|
y = f()
|
|
186
216
|
|
|
217
|
+
wp.synchronize_device(device)
|
|
218
|
+
|
|
187
219
|
result = np.asarray(y).reshape((n,))
|
|
188
220
|
expected = 3 * np.arange(n, dtype=np_dtype)
|
|
189
221
|
|
|
190
222
|
assert_np_equal(result, expected)
|
|
191
223
|
|
|
192
224
|
|
|
193
|
-
|
|
194
|
-
def test_jax_kernel_vecmat(test, device):
|
|
225
|
+
def test_jax_kernel_vecmat(test, device, use_ffi=False):
|
|
195
226
|
import jax.numpy as jp
|
|
196
227
|
|
|
197
|
-
|
|
228
|
+
if use_ffi:
|
|
229
|
+
from warp.jax_experimental.ffi import jax_kernel
|
|
230
|
+
|
|
231
|
+
kwargs = {}
|
|
232
|
+
else:
|
|
233
|
+
from warp.jax_experimental.custom_call import jax_kernel
|
|
234
|
+
|
|
235
|
+
kwargs = {"quiet": True}
|
|
198
236
|
|
|
199
237
|
for T in [*vector_types, *matrix_types]:
|
|
200
238
|
jp_dtype = wp.dtype_to_jax(T._wp_scalar_type_)
|
|
201
239
|
np_dtype = wp.dtype_to_numpy(T._wp_scalar_type_)
|
|
202
240
|
|
|
241
|
+
# use a smallish size to ensure arange * 3 doesn't overflow
|
|
203
242
|
n = 64 // T._length_
|
|
204
243
|
scalar_shape = (n, *T._shape_)
|
|
205
244
|
scalar_len = n * T._length_
|
|
@@ -208,7 +247,7 @@ def test_jax_kernel_vecmat(test, device):
|
|
|
208
247
|
# get the concrete overload
|
|
209
248
|
kernel_instance = triple_kernel_vecmat.add_overload([wp.array(dtype=T), wp.array(dtype=T)])
|
|
210
249
|
|
|
211
|
-
jax_triple = jax_kernel(kernel_instance)
|
|
250
|
+
jax_triple = jax_kernel(kernel_instance, **kwargs)
|
|
212
251
|
|
|
213
252
|
@jax.jit
|
|
214
253
|
def f(jax_triple=jax_triple, jp_dtype=jp_dtype, scalar_len=scalar_len, scalar_shape=scalar_shape):
|
|
@@ -219,21 +258,27 @@ def test_jax_kernel_vecmat(test, device):
|
|
|
219
258
|
with jax.default_device(wp.device_to_jax(device)):
|
|
220
259
|
y = f()
|
|
221
260
|
|
|
261
|
+
wp.synchronize_device(device)
|
|
262
|
+
|
|
222
263
|
result = np.asarray(y).reshape(scalar_shape)
|
|
223
264
|
expected = 3 * np.arange(scalar_len, dtype=np_dtype).reshape(scalar_shape)
|
|
224
265
|
|
|
225
266
|
assert_np_equal(result, expected)
|
|
226
267
|
|
|
227
268
|
|
|
228
|
-
|
|
229
|
-
def test_jax_kernel_multiarg(test, device):
|
|
269
|
+
def test_jax_kernel_multiarg(test, device, use_ffi=False):
|
|
230
270
|
import jax.numpy as jp
|
|
231
271
|
|
|
232
|
-
|
|
272
|
+
if use_ffi:
|
|
273
|
+
from warp.jax_experimental.ffi import jax_kernel
|
|
233
274
|
|
|
234
|
-
|
|
275
|
+
jax_multiarg = jax_kernel(multiarg_kernel, num_outputs=2)
|
|
276
|
+
else:
|
|
277
|
+
from warp.jax_experimental.custom_call import jax_kernel
|
|
235
278
|
|
|
236
|
-
|
|
279
|
+
jax_multiarg = jax_kernel(multiarg_kernel, quiet=True)
|
|
280
|
+
|
|
281
|
+
n = ARRAY_SIZE
|
|
237
282
|
|
|
238
283
|
@jax.jit
|
|
239
284
|
def f():
|
|
@@ -246,6 +291,8 @@ def test_jax_kernel_multiarg(test, device):
|
|
|
246
291
|
with jax.default_device(wp.device_to_jax(device)):
|
|
247
292
|
x, y = f()
|
|
248
293
|
|
|
294
|
+
wp.synchronize_device(device)
|
|
295
|
+
|
|
249
296
|
result_x, result_y = np.asarray(x), np.asarray(y)
|
|
250
297
|
expected_x = np.full(n, 3, dtype=np.float32)
|
|
251
298
|
expected_y = np.full(n, 5, dtype=np.float32)
|
|
@@ -254,50 +301,48 @@ def test_jax_kernel_multiarg(test, device):
|
|
|
254
301
|
assert_np_equal(result_y, expected_y)
|
|
255
302
|
|
|
256
303
|
|
|
257
|
-
|
|
258
|
-
def test_jax_kernel_launch_dims(test, device):
|
|
304
|
+
def test_jax_kernel_launch_dims(test, device, use_ffi=False):
|
|
259
305
|
import jax.numpy as jp
|
|
260
306
|
|
|
261
|
-
|
|
307
|
+
if use_ffi:
|
|
308
|
+
from warp.jax_experimental.ffi import jax_kernel
|
|
309
|
+
|
|
310
|
+
kwargs = {}
|
|
311
|
+
else:
|
|
312
|
+
from warp.jax_experimental.custom_call import jax_kernel
|
|
313
|
+
|
|
314
|
+
kwargs = {"quiet": True}
|
|
262
315
|
|
|
263
316
|
n = 64
|
|
264
317
|
m = 32
|
|
265
318
|
|
|
266
319
|
# Test with 1D launch dims
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
tid = wp.tid()
|
|
270
|
-
y[tid] = x[tid] + 1.0
|
|
271
|
-
|
|
272
|
-
jax_add_one = jax_kernel(
|
|
273
|
-
add_one_kernel, launch_dims=(n - 2,)
|
|
320
|
+
jax_inc_1d = jax_kernel(
|
|
321
|
+
inc_1d_kernel, launch_dims=(n - 2,), **kwargs
|
|
274
322
|
) # Intentionally not the same as the first dimension of the input
|
|
275
323
|
|
|
276
324
|
@jax.jit
|
|
277
325
|
def f_1d():
|
|
278
326
|
x = jp.arange(n, dtype=jp.float32)
|
|
279
|
-
return
|
|
327
|
+
return jax_inc_1d(x)
|
|
280
328
|
|
|
281
329
|
# Test with 2D launch dims
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
i, j = wp.tid()
|
|
285
|
-
y[i, j] = x[i, j] + 1.0
|
|
286
|
-
|
|
287
|
-
jax_add_one_2d = jax_kernel(
|
|
288
|
-
add_one_2d_kernel, launch_dims=(n - 2, m - 2)
|
|
330
|
+
jax_inc_2d = jax_kernel(
|
|
331
|
+
inc_2d_kernel, launch_dims=(n - 2, m - 2), **kwargs
|
|
289
332
|
) # Intentionally not the same as the first dimension of the input
|
|
290
333
|
|
|
291
334
|
@jax.jit
|
|
292
335
|
def f_2d():
|
|
293
336
|
x = jp.zeros((n, m), dtype=jp.float32) + 3.0
|
|
294
|
-
return
|
|
337
|
+
return jax_inc_2d(x)
|
|
295
338
|
|
|
296
339
|
# run on the given device
|
|
297
340
|
with jax.default_device(wp.device_to_jax(device)):
|
|
298
341
|
y_1d = f_1d()
|
|
299
342
|
y_2d = f_2d()
|
|
300
343
|
|
|
344
|
+
wp.synchronize_device(device)
|
|
345
|
+
|
|
301
346
|
result_1d = np.asarray(y_1d).reshape((n - 2,))
|
|
302
347
|
expected_1d = np.arange(n - 2, dtype=np.float32) + 1.0
|
|
303
348
|
|
|
@@ -308,57 +353,1315 @@ def test_jax_kernel_launch_dims(test, device):
|
|
|
308
353
|
assert_np_equal(result_2d, expected_2d)
|
|
309
354
|
|
|
310
355
|
|
|
311
|
-
|
|
312
|
-
|
|
356
|
+
# =========================================================================================================
|
|
357
|
+
# JAX FFI
|
|
358
|
+
# =========================================================================================================
|
|
313
359
|
|
|
314
360
|
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
|
|
361
|
+
@wp.kernel
|
|
362
|
+
def add_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), output: wp.array(dtype=float)):
|
|
363
|
+
tid = wp.tid()
|
|
364
|
+
output[tid] = a[tid] + b[tid]
|
|
320
365
|
|
|
321
|
-
import jax
|
|
322
366
|
|
|
323
|
-
|
|
324
|
-
|
|
367
|
+
@wp.kernel
|
|
368
|
+
def axpy_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float), alpha: float, out: wp.array(dtype=float)):
|
|
369
|
+
tid = wp.tid()
|
|
370
|
+
out[tid] = alpha * x[tid] + y[tid]
|
|
325
371
|
|
|
326
|
-
# check which Warp devices work with Jax
|
|
327
|
-
# CUDA devices may fail if Jax cannot find a CUDA Toolkit
|
|
328
|
-
test_devices = get_test_devices()
|
|
329
|
-
jax_compatible_devices = []
|
|
330
|
-
jax_compatible_cuda_devices = []
|
|
331
|
-
for d in test_devices:
|
|
332
|
-
try:
|
|
333
|
-
with jax.default_device(wp.device_to_jax(d)):
|
|
334
|
-
j = jax.numpy.arange(10, dtype=jax.numpy.float32)
|
|
335
|
-
j += 1
|
|
336
|
-
jax_compatible_devices.append(d)
|
|
337
|
-
if d.is_cuda:
|
|
338
|
-
jax_compatible_cuda_devices.append(d)
|
|
339
|
-
except Exception as e:
|
|
340
|
-
print(f"Skipping Jax DLPack tests on device '{d}' due to exception: {e}")
|
|
341
372
|
|
|
342
|
-
|
|
343
|
-
|
|
373
|
+
@wp.kernel
|
|
374
|
+
def sincos_kernel(angle: wp.array(dtype=float), sin_out: wp.array(dtype=float), cos_out: wp.array(dtype=float)):
|
|
375
|
+
tid = wp.tid()
|
|
376
|
+
sin_out[tid] = wp.sin(angle[tid])
|
|
377
|
+
cos_out[tid] = wp.cos(angle[tid])
|
|
344
378
|
|
|
345
|
-
if jax_compatible_devices:
|
|
346
|
-
add_function_test(TestJax, "test_device_conversion", test_device_conversion, devices=jax_compatible_devices)
|
|
347
379
|
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
380
|
+
@wp.kernel
|
|
381
|
+
def diagonal_kernel(output: wp.array(dtype=wp.mat33)):
|
|
382
|
+
tid = wp.tid()
|
|
383
|
+
d = float(tid + 1)
|
|
384
|
+
output[tid] = wp.mat33(d, 0.0, 0.0, 0.0, d * 2.0, 0.0, 0.0, 0.0, d * 3.0)
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
@wp.kernel
|
|
388
|
+
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
|
|
389
|
+
tid = wp.tid()
|
|
390
|
+
output[tid] = a[tid] * s
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
@wp.kernel
|
|
394
|
+
def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
|
|
395
|
+
tid = wp.tid()
|
|
396
|
+
output[tid] = a[tid] * s
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
@wp.kernel
|
|
400
|
+
def accum_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)):
|
|
401
|
+
tid = wp.tid()
|
|
402
|
+
b[tid] += a[tid]
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
@wp.kernel
|
|
406
|
+
def matmul_kernel(
|
|
407
|
+
a: wp.array2d(dtype=float), # NxK
|
|
408
|
+
b: wp.array2d(dtype=float), # KxM
|
|
409
|
+
c: wp.array2d(dtype=float), # NxM
|
|
410
|
+
):
|
|
411
|
+
# launch dims should be (N, M)
|
|
412
|
+
i, j = wp.tid()
|
|
413
|
+
N = a.shape[0]
|
|
414
|
+
K = a.shape[1]
|
|
415
|
+
M = b.shape[1]
|
|
416
|
+
if i < N and j < M:
|
|
417
|
+
s = wp.float32(0)
|
|
418
|
+
for k in range(K):
|
|
419
|
+
s += a[i, k] * b[k, j]
|
|
420
|
+
c[i, j] = s
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
@wp.kernel
|
|
424
|
+
def in_out_kernel(
|
|
425
|
+
a: wp.array(dtype=float), # input only
|
|
426
|
+
b: wp.array(dtype=float), # input and output
|
|
427
|
+
c: wp.array(dtype=float), # output only
|
|
428
|
+
):
|
|
429
|
+
tid = wp.tid()
|
|
430
|
+
b[tid] += a[tid]
|
|
431
|
+
c[tid] = 2.0 * a[tid]
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
@wp.kernel
|
|
435
|
+
def multi_out_kernel(
|
|
436
|
+
a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float), d: wp.array(dtype=float)
|
|
437
|
+
):
|
|
438
|
+
tid = wp.tid()
|
|
439
|
+
c[tid] = a[tid] + b[tid]
|
|
440
|
+
d[tid] = s * a[tid]
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
@wp.kernel
|
|
444
|
+
def multi_out_kernel_v2(
|
|
445
|
+
a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float), d: wp.array(dtype=float)
|
|
446
|
+
):
|
|
447
|
+
tid = wp.tid()
|
|
448
|
+
c[tid] = a[tid] * a[tid]
|
|
449
|
+
d[tid] = a[tid] * b[tid] * s
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
@wp.kernel
|
|
453
|
+
def multi_out_kernel_v3(
|
|
454
|
+
a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float), d: wp.array(dtype=float)
|
|
455
|
+
):
|
|
456
|
+
tid = wp.tid()
|
|
457
|
+
c[tid] = a[tid] ** 2.0
|
|
458
|
+
d[tid] = a[tid] * b[tid] * s
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
@wp.kernel
|
|
462
|
+
def scale_sum_square_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float)):
|
|
463
|
+
tid = wp.tid()
|
|
464
|
+
c[tid] = (a[tid] * s + b[tid]) ** 2.0
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
# The Python function to call.
|
|
468
|
+
# Note the argument annotations, just like Warp kernels.
|
|
469
|
+
def scale_func(
|
|
470
|
+
# inputs
|
|
471
|
+
a: wp.array(dtype=float),
|
|
472
|
+
b: wp.array(dtype=wp.vec2),
|
|
473
|
+
s: float,
|
|
474
|
+
# outputs
|
|
475
|
+
c: wp.array(dtype=float),
|
|
476
|
+
d: wp.array(dtype=wp.vec2),
|
|
477
|
+
):
|
|
478
|
+
wp.launch(scale_kernel, dim=a.shape, inputs=[a, s], outputs=[c])
|
|
479
|
+
wp.launch(scale_vec_kernel, dim=b.shape, inputs=[b, s], outputs=[d])
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
def in_out_func(
|
|
483
|
+
a: wp.array(dtype=float), # input only
|
|
484
|
+
b: wp.array(dtype=float), # input and output
|
|
485
|
+
c: wp.array(dtype=float), # output only
|
|
486
|
+
):
|
|
487
|
+
wp.launch(scale_kernel, dim=a.size, inputs=[a, 2.0], outputs=[c])
|
|
488
|
+
wp.launch(accum_kernel, dim=a.size, inputs=[a, b]) # modifies `b`
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def double_func(
|
|
492
|
+
# inputs
|
|
493
|
+
a: wp.array(dtype=float),
|
|
494
|
+
# outputs
|
|
495
|
+
b: wp.array(dtype=float),
|
|
496
|
+
):
|
|
497
|
+
wp.launch(scale_kernel, dim=a.shape, inputs=[a, 2.0], outputs=[b])
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_add(test, device):
    """jax_kernel with two inputs and one output: element-wise ``a + b``."""
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    add = jax_kernel(add_kernel)

    @jax.jit
    def run():
        lhs = jp.arange(ARRAY_SIZE, dtype=jp.float32)
        rhs = jp.ones(ARRAY_SIZE, dtype=jp.float32)
        return add(lhs, rhs)

    with jax.default_device(wp.device_to_jax(device)):
        (total,) = run()

    wp.synchronize_device(device)

    # arange(n) + 1 == arange(1, n + 1)
    assert_np_equal(np.asarray(total), np.arange(1, ARRAY_SIZE + 1, dtype=np.float32))
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_sincos(test, device):
    """jax_kernel with one input and two outputs (sine and cosine of the input)."""
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    sincos = jax_kernel(sincos_kernel, num_outputs=2)
    count = ARRAY_SIZE

    @jax.jit
    def run():
        angles = jp.linspace(0, 2 * jp.pi, count, dtype=jp.float32)
        return sincos(angles)

    with jax.default_device(wp.device_to_jax(device)):
        sin_out, cos_out = run()

    wp.synchronize_device(device)

    # host reference built from the same linspace parameters as the device input
    host_angles = np.linspace(0, 2 * np.pi, count, dtype=np.float32)
    for computed, reference in ((sin_out, np.sin(host_angles)), (cos_out, np.cos(host_angles))):
        assert_np_equal(np.asarray(computed), reference, tol=1e-4)
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_diagonal(test, device):
    """jax_kernel with no inputs and one output whose size comes from ``launch_dims``.

    The kernel is launched with ``launch_dims=4``; the expected result is four
    3x3 diagonal matrices where row ``i`` holds ``(i + 1) * diag(1, 2, 3)``.
    """
    from warp.jax_experimental.ffi import jax_kernel

    jax_diagonal = jax_kernel(diagonal_kernel)

    @jax.jit
    def f():
        # launch dimensions determine output size
        return jax_diagonal(launch_dims=4)

    with jax.default_device(wp.device_to_jax(device)):
        (d,) = f()

    # Synchronize after the launch (the original synced before calling f(), so it
    # did not wait on the work under test).
    wp.synchronize_device(device)

    result = np.asarray(d)
    expected = np.array(
        [
            [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]],
            [[2.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 6.0]],
            [[3.0, 0.0, 0.0], [0.0, 6.0, 0.0], [0.0, 0.0, 9.0]],
            [[4.0, 0.0, 0.0], [0.0, 8.0, 0.0], [0.0, 0.0, 12.0]],
        ],
        dtype=np.float32,
    )

    # The output shape is derived from launch_dims, so check it explicitly
    # in addition to the element values.
    test.assertEqual(result.shape, expected.shape)
    assert_np_equal(result, expected)
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_in_out(test, device):
    """jax_kernel where ``b`` is declared as an in-out argument.

    The kernel reads and updates ``b`` and writes ``c``; JAX receives both as outputs.
    """
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    jitted = jax.jit(jax_kernel(in_out_kernel, num_outputs=2, in_out_argnames=["b"]))

    with jax.default_device(wp.device_to_jax(device)):
        ones = jp.ones(ARRAY_SIZE, dtype=jp.float32)
        ramp = jp.arange(ARRAY_SIZE, dtype=jp.float32)
        updated_b, c = jitted(ones, ramp)

    wp.synchronize_device(device)

    expected_b = np.arange(1, ARRAY_SIZE + 1, dtype=np.float32)
    expected_c = np.full(ARRAY_SIZE, 2, dtype=np.float32)
    assert_np_equal(updated_b, expected_b)
    assert_np_equal(c, expected_c)
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_scale_vec_constant(test, device):
    """Scale an array of ``wp.vec2`` by a scalar literal embedded in the jitted function."""
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    scale_vec = jax_kernel(scale_vec_kernel)
    vec2_shape = (ARRAY_SIZE // 2, 2)

    @jax.jit
    def run():
        vectors = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape(vec2_shape)  # array of vec2
        return scale_vec(vectors, 2.0)

    with jax.default_device(wp.device_to_jax(device)):
        (scaled,) = run()

    wp.synchronize_device(device)

    assert_np_equal(scaled, 2 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape(vec2_shape))
|
|
635
|
+
|
|
636
|
+
|
|
637
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_scale_vec_static(test, device):
    """Scale an array of ``wp.vec2`` by a scalar passed as a static jit argument."""
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    scale_vec = jax_kernel(scale_vec_kernel)

    # NOTE: scalar arguments must be static compile-time constants
    @partial(jax.jit, static_argnames=["s"])
    def run(a, s):
        return scale_vec(a, s)

    vec2_shape = (ARRAY_SIZE // 2, 2)
    vectors = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape(vec2_shape)  # array of vec2
    factor = 3.0

    with jax.default_device(wp.device_to_jax(device)):
        (scaled,) = run(vectors, factor)

    wp.synchronize_device(device)

    assert_np_equal(scaled, 3 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape(vec2_shape))
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_launch_dims_default(test, device):
    """jax_kernel created with default ``launch_dims`` and called without overriding them."""
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    rows, cols, inner = 3, 4, 2

    matmul = jax_kernel(matmul_kernel, launch_dims=(rows, cols))

    @jax.jit
    def run():
        lhs = jp.full((rows, inner), 2, dtype=jp.float32)
        rhs = jp.full((inner, cols), 3, dtype=jp.float32)
        # no launch_dims here: the default given to jax_kernel applies
        return matmul(lhs, rhs)

    with jax.default_device(wp.device_to_jax(device)):
        (product,) = run()

    wp.synchronize_device(device)

    # (rows x inner) of 2s times (inner x cols) of 3s -> every entry is 2 * 3 * inner
    expected = np.full((rows, cols), 2 * 3 * inner, dtype=np.float32)

    test.assertEqual(product.shape, expected.shape)
    assert_np_equal(product, expected)
|
|
692
|
+
|
|
693
|
+
|
|
694
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_launch_dims_custom(test, device):
    """Call one jax_kernel twice in a jitted function with different per-call ``launch_dims``."""
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    matmul = jax_kernel(matmul_kernel)

    def filled_matmul(rows, cols, inner):
        # Multiply a (rows x inner) matrix of 2s by an (inner x cols) matrix of 3s,
        # passing explicit launch dims for this call only.
        lhs = jp.full((rows, inner), 2, dtype=jp.float32)
        rhs = jp.full((inner, cols), 3, dtype=jp.float32)
        return matmul(lhs, rhs, launch_dims=(rows, cols))[0]

    @jax.jit
    def run():
        return filled_matmul(3, 4, 2), filled_matmul(4, 3, 2)

    with jax.default_device(wp.device_to_jax(device)):
        first, second = run()

    wp.synchronize_device(device)

    expected_first = np.full((3, 4), 12, dtype=np.float32)
    expected_second = np.full((4, 3), 12, dtype=np.float32)

    test.assertEqual(first.shape, expected_first.shape)
    test.assertEqual(second.shape, expected_second.shape)
    assert_np_equal(first, expected_first)
    assert_np_equal(second, expected_second)
|
|
733
|
+
|
|
734
|
+
|
|
735
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_callable_scale_constant(test, device):
    """jax_callable on ``scale_func`` with the scale factor written as a literal inside jit."""
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    scale_both = jax_callable(scale_func, num_outputs=2)
    vec2_shape = (ARRAY_SIZE // 2, 2)

    @jax.jit
    def run():
        scalars = jp.arange(ARRAY_SIZE, dtype=jp.float32)
        vectors = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape(vec2_shape)  # wp.vec2
        # output shapes are given explicitly, keyed by the output argument names of scale_func
        shapes = {"c": scalars.shape, "d": vectors.shape}
        scaled_scalars, scaled_vectors = scale_both(scalars, vectors, 2.0, output_dims=shapes)
        return scaled_scalars, scaled_vectors

    with jax.default_device(wp.device_to_jax(device)):
        got_scalars, got_vectors = run()

    wp.synchronize_device(device)

    base = np.arange(ARRAY_SIZE, dtype=np.float32)
    assert_np_equal(got_scalars, 2 * base)
    assert_np_equal(got_vectors, 2 * base.reshape(vec2_shape))
|
|
768
|
+
|
|
769
|
+
|
|
770
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_callable_scale_static(test, device):
    """jax_callable on ``scale_func`` with the scale factor passed as a static jit argument."""
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    scale_both = jax_callable(scale_func, num_outputs=2)

    # NOTE: scalar arguments must be static compile-time constants
    @partial(jax.jit, static_argnames=["s"])
    def run(a, b, s):
        # output shapes are given explicitly, keyed by the output argument names of scale_func
        shapes = {"c": a.shape, "d": b.shape}
        out_c, out_d = scale_both(a, b, s, output_dims=shapes)
        return out_c, out_d

    vec2_shape = (ARRAY_SIZE // 2, 2)

    with jax.default_device(wp.device_to_jax(device)):
        scalars = jp.arange(ARRAY_SIZE, dtype=jp.float32)
        vectors = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape(vec2_shape)  # wp.vec2
        got_scalars, got_vectors = run(scalars, vectors, 3.0)

    wp.synchronize_device(device)

    base = np.arange(ARRAY_SIZE, dtype=np.float32)
    assert_np_equal(got_scalars, 3 * base)
    assert_np_equal(got_vectors, 3 * base.reshape(vec2_shape))
|
|
803
|
+
|
|
804
|
+
|
|
805
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_callable_in_out(test, device):
    """jax_callable on ``in_out_func`` with ``b`` declared as an in-out argument."""
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    jitted = jax.jit(jax_callable(in_out_func, num_outputs=2, in_out_argnames=["b"]))

    with jax.default_device(wp.device_to_jax(device)):
        ones = jp.ones(ARRAY_SIZE, dtype=jp.float32)
        ramp = jp.arange(ARRAY_SIZE, dtype=jp.float32)
        updated_b, c = jitted(ones, ramp)

    wp.synchronize_device(device)

    expected_b = np.arange(1, ARRAY_SIZE + 1, dtype=np.float32)
    expected_c = np.full(ARRAY_SIZE, 2, dtype=np.float32)
    assert_np_equal(updated_b, expected_b)
    assert_np_equal(c, expected_c)
|
|
825
|
+
|
|
826
|
+
|
|
827
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_callable_graph_cache(test, device):
    """Exercise the graph cache of ``jax_callable`` in ``GraphMode.WARP``.

    Phases: default cap, clearing one callable, an explicit ``graph_cache_max``,
    clearing every registered callable, and changing the global default cap.
    """
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import (
        GraphMode,
        clear_jax_callable_graph_cache,
        get_jax_callable_default_graph_cache_max,
        jax_callable,
        set_jax_callable_default_graph_cache_max,
    )

    # Phase 1: default cap -- each call on a fresh, still-alive input adds one cache entry.
    doubler = jax_callable(double_func, graph_mode=GraphMode.WARP)
    jitted = jax.jit(doubler)
    keep_alive = []

    test.assertEqual(doubler.graph_cache_max, get_jax_callable_default_graph_cache_max())

    with jax.default_device(wp.device_to_jax(device)):
        for step in range(10):
            size = 10 + step
            src = jp.arange(size, dtype=jp.float32)
            (dst,) = jitted(src)

            assert_np_equal(dst, 2 * np.arange(size, dtype=np.float32))

            test.assertEqual(doubler.graph_cache_size, step + 1)

            # Holding every input keeps its memory from being recycled, so each
            # call forces a new graph capture instead of reusing an old one.
            keep_alive.append(src)

    # Phase 2: clearing a single callable empties that callable's cache.
    clear_jax_callable_graph_cache(doubler)

    test.assertEqual(doubler.graph_cache_size, 0)

    # Phase 3: explicit per-callable cap.
    cap = 5
    doubler = jax_callable(double_func, graph_mode=GraphMode.WARP, graph_cache_max=cap)
    jitted = jax.jit(doubler)
    keep_alive = []

    test.assertEqual(doubler.graph_cache_max, cap)

    with jax.default_device(wp.device_to_jax(device)):
        for step in range(10):
            size = 10 + step
            src = jp.arange(size, dtype=jp.float32)
            (dst,) = jitted(src)

            assert_np_equal(dst, 2 * np.arange(size, dtype=np.float32))

            # the cache grows until it reaches the cap, then stays there
            test.assertEqual(doubler.graph_cache_size, min(step + 1, cap))

            keep_alive.append(src)

    # Phase 4: clearing with no argument empties every registered callable's cache.
    clear_jax_callable_graph_cache()

    with wp.jax_experimental.ffi._FFI_REGISTRY_LOCK:
        for registered in wp.jax_experimental.ffi._FFI_CALLABLE_REGISTRY.values():
            test.assertEqual(registered.graph_cache_size, 0)

    # Phase 5: change the global default cap; always restore the previous value.
    previous_default = get_jax_callable_default_graph_cache_max()
    try:
        set_jax_callable_default_graph_cache_max(5)
        doubler = jax_callable(double_func, graph_mode=GraphMode.WARP)
        jitted = jax.jit(doubler)
        keep_alive = []

        test.assertEqual(doubler.graph_cache_max, get_jax_callable_default_graph_cache_max())

        with jax.default_device(wp.device_to_jax(device)):
            for step in range(10):
                size = 10 + step
                src = jp.arange(size, dtype=jp.float32)
                (dst,) = jitted(src)

                assert_np_equal(dst, 2 * np.arange(size, dtype=np.float32))

                test.assertEqual(
                    doubler.graph_cache_size,
                    min(step + 1, get_jax_callable_default_graph_cache_max()),
                )

                keep_alive.append(src)

        clear_jax_callable_graph_cache()

    finally:
        set_jax_callable_default_graph_cache_max(previous_default)
|
|
932
|
+
|
|
933
|
+
|
|
934
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_callable_pmap_mul(test, device):
    """Run the ``double_func`` jax_callable under ``jax.pmap`` across all local JAX devices."""
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    doubler = jax_callable(double_func, num_outputs=1)

    num_devices = jax.local_device_count()
    chunk = max(ARRAY_SIZE // num_devices, 64)
    inputs = jp.arange(num_devices * chunk, dtype=jp.float32).reshape((num_devices, chunk))

    def per_device_func(v):
        (doubled,) = doubler(v)
        return doubled

    outputs = jax.pmap(per_device_func)(inputs)

    wp.synchronize()

    assert_np_equal(np.asarray(outputs), 2 * np.asarray(inputs))
|
|
956
|
+
|
|
957
|
+
|
|
958
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_callable_pmap_multi_output(test, device):
    """Run a two-output jax_callable under ``jax.pmap`` across all local JAX devices."""
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    def multi_out_py(
        a: wp.array(dtype=float),
        b: wp.array(dtype=float),
        s: float,
        c: wp.array(dtype=float),
        d: wp.array(dtype=float),
    ):
        wp.launch(multi_out_kernel, dim=a.shape, inputs=[a, b, s], outputs=[c, d])

    two_out = jax_callable(multi_out_py, num_outputs=2)

    num_devices = jax.local_device_count()
    chunk = max(ARRAY_SIZE // num_devices, 64)
    grid = (num_devices, chunk)
    a = jp.arange(num_devices * chunk, dtype=jp.float32).reshape(grid)
    b = jp.ones(grid, dtype=jp.float32)
    s = 3.0

    def per_device_func(aa, bb):
        c, d = two_out(aa, bb, s)
        # summing the outputs lets one comparison cover both of them
        return c + d

    combined = jax.pmap(per_device_func)(a, b)

    wp.synchronize()

    a_host = np.arange(num_devices * chunk, dtype=np.float32).reshape(grid)
    b_host = np.ones(grid, dtype=np.float32)
    # reference assumes multi_out_kernel yields c = a + b and d = s * a
    expected = (a_host + b_host) + s * a_host
    assert_np_equal(np.asarray(combined), expected)
|
|
994
|
+
|
|
995
|
+
|
|
996
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_callable_pmap_multi_stage(test, device):
    """Run a jax_callable that launches two dependent kernels, under ``jax.pmap``.

    ``multi_stage_py`` writes ``tmp`` with ``add_kernel`` and then feeds ``tmp``
    into ``axpy_kernel`` to produce ``out``.
    """
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    def multi_stage_py(
        a: wp.array(dtype=float),
        b: wp.array(dtype=float),
        alpha: float,
        tmp: wp.array(dtype=float),
        out: wp.array(dtype=float),
    ):
        # stage 2 consumes the stage 1 result, so the launch order matters
        wp.launch(add_kernel, dim=a.shape, inputs=[a, b], outputs=[tmp])
        wp.launch(axpy_kernel, dim=a.shape, inputs=[tmp, b, alpha], outputs=[out])

    two_stage = jax_callable(multi_stage_py, num_outputs=2)

    num_devices = jax.local_device_count()
    chunk = max(ARRAY_SIZE // num_devices, 64)
    grid = (num_devices, chunk)
    a = jp.arange(num_devices * chunk, dtype=jp.float32).reshape(grid)
    b = jp.ones(grid, dtype=jp.float32)
    alpha = 2.5

    def per_device_func(aa, bb):
        tmp, out = two_stage(aa, bb, alpha)
        return tmp + out

    combined = jax.pmap(per_device_func)(a, b)

    wp.synchronize()

    a_host = np.arange(num_devices * chunk, dtype=np.float32).reshape(grid)
    b_host = np.ones(grid, dtype=np.float32)
    tmp_ref = a_host + b_host
    # reference assumes axpy_kernel computes alpha * x + y with x = tmp, y = b
    out_ref = alpha * tmp_ref + b_host
    assert_np_equal(np.asarray(combined), tmp_ref + out_ref)
|
|
1035
|
+
|
|
1036
|
+
|
|
1037
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_callback(test, device):
    """Register a raw FFI callback with ``register_ffi_callback`` and invoke it via ``jax.ffi.ffi_call``.

    The callback scales a float array and an array of ``wp.vec2`` by the ``scale`` attribute.
    """
    import jax.numpy as jp

    from warp.jax_experimental.ffi import register_ffi_callback

    def warp_func(inputs, outputs, attrs, ctx):
        # Buffers are positional: inputs/outputs are (float array, vec2 array).
        src_scalars = inputs[0]
        src_vectors = inputs[1]
        factor = attrs["scale"]  # scalar attribute given as a keyword to the ffi_call
        dst_scalars = outputs[0]
        dst_vectors = outputs[1]

        # Launch on the stream provided by the FFI call context.
        target = wp.device_from_jax(get_jax_device())
        ffi_stream = wp.Stream(target, cuda_stream=ctx.stream)

        with wp.ScopedStream(ffi_stream):
            wp.launch(scale_kernel, dim=src_scalars.shape, inputs=[src_scalars, factor], outputs=[dst_scalars])

            # Buffer shapes follow the JAX arrays, i.e. (n // 2, 2) for the vec2 data;
            # drop the inner dimension so there is one thread per vec2.
            wp.launch(
                scale_vec_kernel,
                dim=src_vectors.shape[0],
                inputs=[src_vectors, factor],
                outputs=[dst_vectors],
            )

    register_ffi_callback("warp_func", warp_func)

    count = ARRAY_SIZE

    with jax.default_device(wp.device_to_jax(device)):
        scalars = jp.arange(count, dtype=jp.float32)
        vectors = jp.arange(count, dtype=jp.float32).reshape((count // 2, 2))  # array of wp.vec2

        result_types = [
            jax.ShapeDtypeStruct(scalars.shape, jp.float32),
            jax.ShapeDtypeStruct(vectors.shape, jp.float32),  # array of wp.vec2
        ]
        call = jax.ffi.ffi_call("warp_func", result_types)

        out_scalars, out_vectors = call(scalars, vectors, scale=2.0)

    wp.synchronize_device(device)

    base = np.arange(ARRAY_SIZE, dtype=np.float32)
    assert_np_equal(out_scalars, 2 * base)
    assert_np_equal(out_vectors, 2 * base.reshape((ARRAY_SIZE // 2, 2)))
|
|
1093
|
+
|
|
1094
|
+
|
|
1095
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_autodiff_simple(test, device):
    """``jax.grad`` of a jitted loss built on a jax_kernel with ``enable_backward=True``.

    The loss is ``sum((a * s + b) ** 2)`` via ``scale_sum_square_kernel``; ``s`` is static.
    """
    from functools import partial

    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    square = jax_kernel(
        scale_sum_square_kernel,
        num_outputs=1,
        enable_backward=True,
    )

    @partial(jax.jit, static_argnames=["s"])
    def loss(a, b, s):
        return jp.sum(square(a, b, s)[0])

    count = ARRAY_SIZE
    lhs = jp.arange(count, dtype=jp.float32)
    rhs = jp.ones(count, dtype=jp.float32)
    scale = 2.0

    with jax.default_device(wp.device_to_jax(device)):
        grad_a, grad_b = jax.grad(loss, argnums=(0, 1))(lhs, rhs, scale)

    wp.synchronize_device(device)

    # Element-wise analytic gradients of sum((a*s + b)^2):
    #   d/db = 2 * (a*s + b)
    #   d/da = 2 * (a*s + b) * s
    lhs_host = np.arange(count, dtype=np.float32)
    rhs_host = np.ones(count, dtype=np.float32)
    expected_b = 2.0 * (lhs_host * scale + rhs_host)
    expected_a = expected_b * scale

    assert_np_equal(np.asarray(grad_a), expected_a)
    assert_np_equal(np.asarray(grad_b), expected_b)
|
|
1135
|
+
|
|
1136
|
+
|
|
1137
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_autodiff_jit_of_grad_simple(test, device):
    """``jax.jit(jax.grad(...))`` over a jax_kernel loss ``sum((a * s + b) ** 2)``."""
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    square = jax_kernel(scale_sum_square_kernel, num_outputs=1, enable_backward=True)

    def loss(a, b, s):
        return jp.sum(square(a, b, s)[0])

    grad_fn = jax.grad(loss, argnums=(0, 1))

    # The usual composition is jit(grad(...)); `s` is marked static because scalar
    # kernel arguments must be compile-time constants.
    def grad_wrapper(a, b, s):
        return grad_fn(a, b, s)

    jitted_grad = jax.jit(grad_wrapper, static_argnames=("s",))

    count = ARRAY_SIZE
    lhs = jp.arange(count, dtype=jp.float32)
    rhs = jp.ones(count, dtype=jp.float32)
    scale = 2.0

    with jax.default_device(wp.device_to_jax(device)):
        grad_a, grad_b = jitted_grad(lhs, rhs, scale)

    wp.synchronize_device(device)

    lhs_host = np.arange(count, dtype=np.float32)
    rhs_host = np.ones(count, dtype=np.float32)
    expected_b = 2.0 * (lhs_host * scale + rhs_host)
    expected_a = expected_b * scale

    assert_np_equal(np.asarray(grad_a), expected_a)
    assert_np_equal(np.asarray(grad_b), expected_b)
|
|
1172
|
+
|
|
1173
|
+
|
|
1174
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_autodiff_multi_output(test, device):
    """``jax.grad`` inside ``jax.jit`` through a two-output jax_kernel (``multi_out_kernel_v3``).

    The loss is ``sum(c + d)``. The scalar ``s`` is a static argument of the jitted
    ``grads`` function and is forwarded to the kernel. The reference gradients
    below therefore use the same value the kernel sees.
    """
    from functools import partial

    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    jax_func = jax_kernel(multi_out_kernel_v3, num_outputs=2, enable_backward=True)

    def caller(fn, a, b, s):
        c, d = fn(a, b, s)
        return jp.sum(c + d)

    # Scalar kernel arguments must be compile-time constants, so `s` is static and
    # closed over by the differentiated function (only `a` and `b` are differentiated).
    @partial(jax.jit, static_argnames=("s",))
    def grads(a, b, s):
        return jax.grad(lambda a_, b_: caller(jax_func, a_, b_, s), argnums=(0, 1))(a, b)

    n = ARRAY_SIZE
    a = jp.arange(n, dtype=jp.float32)
    b = jp.ones(n, dtype=jp.float32)
    s = 2.0

    with jax.default_device(wp.device_to_jax(device)):
        da, db = grads(a, b, s)

    wp.synchronize_device(device)

    a_np = np.arange(n, dtype=np.float32)
    b_np = np.ones(n, dtype=np.float32)
    # reference assumes multi_out_kernel_v3 computes c = a**2 and d = a * b * s
    # d/da sum(c+d) = 2*a + b*s
    ref_da = 2.0 * a_np + b_np * s
    # d/db sum(c+d) = a*s
    ref_db = a_np * s

    assert_np_equal(np.asarray(da), ref_da)
    assert_np_equal(np.asarray(db), ref_db)
|
|
1214
|
+
|
|
1215
|
+
|
|
1216
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_autodiff_jit_of_grad_multi_output(test, device):
    """``jax.jit(jax.grad(...))`` over a two-output jax_kernel with loss ``sum(c + d)``."""
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    two_out = jax_kernel(multi_out_kernel_v3, num_outputs=2, enable_backward=True)

    def loss(a, b, s):
        c, d = two_out(a, b, s)
        return jp.sum(c + d)

    grad_fn = jax.grad(loss, argnums=(0, 1))

    def grad_wrapper(a, b, s):
        return grad_fn(a, b, s)

    jitted_grad = jax.jit(grad_wrapper, static_argnames=("s",))

    count = ARRAY_SIZE
    lhs = jp.arange(count, dtype=jp.float32)
    rhs = jp.ones(count, dtype=jp.float32)
    scale = 2.0

    with jax.default_device(wp.device_to_jax(device)):
        grad_a, grad_b = jitted_grad(lhs, rhs, scale)

    wp.synchronize_device(device)

    lhs_host = np.arange(count, dtype=np.float32)
    rhs_host = np.ones(count, dtype=np.float32)

    # d/da sum(c+d) = 2*a + b*s ; d/db sum(c+d) = a*s
    assert_np_equal(np.asarray(grad_a), 2.0 * lhs_host + rhs_host * scale)
    assert_np_equal(np.asarray(grad_b), lhs_host * scale)
|
|
1249
|
+
|
|
1250
|
+
|
|
1251
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_autodiff_2d(test, device):
    """``jax.grad`` of a jitted loss over a 2D jax_kernel (``inc_2d_kernel``)."""
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    increment = jax_kernel(inc_2d_kernel, num_outputs=1, enable_backward=True)

    @jax.jit
    def loss(a):
        return jp.sum(increment(a)[0])

    rows, cols = 8, 6
    grid = jp.arange(rows * cols, dtype=jp.float32).reshape((rows, cols))

    with jax.default_device(wp.device_to_jax(device)):
        (grad_grid,) = jax.grad(loss, argnums=(0,))(grid)

    wp.synchronize_device(device)

    # expected gradient is all ones
    assert_np_equal(np.asarray(grad_grid), np.ones((rows, cols), dtype=np.float32))
|
|
1275
|
+
|
|
1276
|
+
|
|
1277
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_autodiff_vec2(test, device):
    """``jax.grad`` through a jax_kernel on arrays of ``wp.vec2`` (loss ``sum(a * s)``)."""
    from functools import partial

    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    scale_vec = jax_kernel(scale_vec_kernel, num_outputs=1, enable_backward=True)

    @partial(jax.jit, static_argnames=("s",))
    def loss(a, s):
        return jp.sum(scale_vec(a, s)[0])

    vectors = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape((ARRAY_SIZE // 2, 2))
    scale = 3.0

    with jax.default_device(wp.device_to_jax(device)):
        (grad_vectors,) = jax.grad(loss, argnums=(0,))(vectors, scale)

    wp.synchronize_device(device)

    # d/da sum(a*s) = s for every component
    assert_np_equal(np.asarray(grad_vectors), np.full_like(np.asarray(vectors), scale))
|
|
1305
|
+
|
|
1306
|
+
|
|
1307
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_autodiff_mat22(test, device):
    """``jax.grad`` through a jax_kernel on arrays of ``wp.mat22`` (loss ``sum(a * s)``).

    Every component of the expected gradient equals ``s``.
    """
    from functools import partial

    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    @wp.kernel
    def scale_mat_kernel(a: wp.array(dtype=wp.mat22), s: float, out: wp.array(dtype=wp.mat22)):
        i = wp.tid()
        out[i] = a[i] * s

    scale_mat = jax_kernel(scale_mat_kernel, num_outputs=1, enable_backward=True)

    @partial(jax.jit, static_argnames=("s",))
    def loss(a, s):
        return jp.sum(scale_mat(a, s)[0])

    count = 12  # total number of scalars; must be divisible by 4 to form 2x2 matrices
    mat_shape = (count // 4, 2, 2)
    mats = jp.arange(count, dtype=jp.float32).reshape(mat_shape)
    scale = 2.5

    with jax.default_device(wp.device_to_jax(device)):
        (grad_mats,) = jax.grad(loss, argnums=(0,))(mats, scale)

    wp.synchronize_device(device)

    assert_np_equal(np.asarray(grad_mats), np.full(mat_shape, scale, dtype=np.float32))
|
|
1339
|
+
|
|
1340
|
+
|
|
1341
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_autodiff_static_required(test, device):
    """``jax.grad`` of an un-jitted loss over a jax_kernel, with a Python-float scalar ``s``.

    No ``jax.jit`` / ``static_argnames`` is involved here; ``s`` is simply not
    among the differentiated ``argnums``.
    NOTE(review): the test name mentions a static requirement, but no static
    argument is declared -- confirm the intended scenario.
    """
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    square = jax_kernel(scale_sum_square_kernel, num_outputs=1, enable_backward=True)

    def loss(a, b, s):
        return jp.sum(square(a, b, s)[0])

    count = ARRAY_SIZE
    lhs = jp.arange(count, dtype=jp.float32)
    rhs = jp.ones(count, dtype=jp.float32)
    scale = 1.5

    with jax.default_device(wp.device_to_jax(device)):
        grad_a, grad_b = jax.grad(loss, argnums=(0, 1))(lhs, rhs, scale)

    wp.synchronize_device(device)

    # Element-wise gradients of sum((a*s + b)^2): d/db = 2*(a*s + b), d/da = that * s
    lhs_host = np.arange(count, dtype=np.float32)
    rhs_host = np.ones(count, dtype=np.float32)
    expected_b = 2.0 * (lhs_host * scale + rhs_host)
    expected_a = expected_b * scale

    assert_np_equal(np.asarray(grad_a), expected_a)
    assert_np_equal(np.asarray(grad_b), expected_b)
|
|
1372
|
+
|
|
1373
|
+
|
|
1374
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_autodiff_pmap_triple(test, device):
    """Take per-device gradients of ``sum(triple_kernel(x))`` under ``jax.pmap``.

    Every gradient entry is expected to equal 3.0. ``device`` is unused because
    ``pmap`` spreads the work across all local JAX devices.
    """
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    triple = jax_kernel(triple_kernel, num_outputs=1, enable_backward=True)

    num_devices = jax.local_device_count()
    chunk = ARRAY_SIZE // num_devices
    shape = (num_devices, chunk)
    x = jp.arange(num_devices * chunk, dtype=jp.float32).reshape(shape)

    def device_loss(xs):
        return jp.sum(triple(xs)[0])

    grads = jax.pmap(jax.grad(device_loss))(x)

    wp.synchronize()

    expected = np.full(shape, 3.0, dtype=np.float32)
    assert_np_equal(np.asarray(grads), expected)
|
|
1396
|
+
|
|
1397
|
+
|
|
1398
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_autodiff_pmap_multi_output(test, device):
    """Differentiate a two-output FFI kernel under ``jax.pmap``.

    The loss is ``sum(c + d)``, where ``c, d`` are the two kernel outputs. The
    scalar ``scale`` is captured by closure as a Python float.

    NOTE(review): ``a`` and ``b`` hold identical values, so this test cannot tell a
    gradient computed from ``a`` apart from one computed from ``b``. Consider distinct
    inputs once the reference formulas are confirmed against multi_out_kernel_v2.
    """
    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    multi_out = jax_kernel(multi_out_kernel_v2, num_outputs=2, enable_backward=True)

    num_devices = jax.local_device_count()
    chunk = ARRAY_SIZE // num_devices
    shape = (num_devices, chunk)
    total = num_devices * chunk
    scale = 2.0

    a = jp.arange(total, dtype=jp.float32).reshape(shape)
    b = jp.arange(total, dtype=jp.float32).reshape(shape)

    def device_loss(aa, bb):
        c, d = multi_out(aa, bb, scale)
        return jp.sum(c + d)

    grad_a, grad_b = jax.pmap(jax.grad(device_loss, argnums=(0, 1)))(a, b)

    wp.synchronize()

    a_host = np.arange(total, dtype=np.float32).reshape(shape)
    b_host = np.arange(total, dtype=np.float32).reshape(shape)
    assert_np_equal(np.asarray(grad_a), 2.0 * a_host + b_host * scale)
    assert_np_equal(np.asarray(grad_b), a_host * scale)
|
|
1427
|
+
|
|
1428
|
+
|
|
1429
|
+
class TestJax(unittest.TestCase):
    """Container for the Jax interop tests.

    The class defines no test methods itself. Test cases are attached at module
    level through ``add_function_test(TestJax, ...)``.
    """
|
|
1431
|
+
|
|
1432
|
+
|
|
1433
|
+
# try adding Jax tests if Jax is installed correctly
|
|
1434
|
+
try:
|
|
1435
|
+
# prevent Jax from gobbling up GPU memory
|
|
1436
|
+
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
|
|
1437
|
+
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
|
|
1438
|
+
|
|
1439
|
+
import jax
|
|
1440
|
+
|
|
1441
|
+
# NOTE: we must enable 64-bit types in Jax to test the full gamut of types
|
|
1442
|
+
jax.config.update("jax_enable_x64", True)
|
|
1443
|
+
|
|
1444
|
+
# check which Warp devices work with Jax
|
|
1445
|
+
# CUDA devices may fail if Jax cannot find a CUDA Toolkit
|
|
1446
|
+
test_devices = get_test_devices()
|
|
1447
|
+
jax_compatible_devices = []
|
|
1448
|
+
jax_compatible_cuda_devices = []
|
|
1449
|
+
for d in test_devices:
|
|
1450
|
+
try:
|
|
1451
|
+
with jax.default_device(wp.device_to_jax(d)):
|
|
1452
|
+
j = jax.numpy.arange(10, dtype=jax.numpy.float32)
|
|
1453
|
+
j += 1
|
|
1454
|
+
jax_compatible_devices.append(d)
|
|
1455
|
+
if d.is_cuda:
|
|
1456
|
+
jax_compatible_cuda_devices.append(d)
|
|
1457
|
+
except Exception as e:
|
|
1458
|
+
print(f"Skipping Jax DLPack tests on device '{d}' due to exception: {e}")
|
|
1459
|
+
|
|
1460
|
+
add_function_test(TestJax, "test_dtype_from_jax", test_dtype_from_jax, devices=None)
|
|
1461
|
+
add_function_test(TestJax, "test_dtype_to_jax", test_dtype_to_jax, devices=None)
|
|
1462
|
+
|
|
1463
|
+
if jax_compatible_devices:
|
|
1464
|
+
add_function_test(TestJax, "test_device_conversion", test_device_conversion, devices=jax_compatible_devices)
|
|
1465
|
+
|
|
1466
|
+
if jax_compatible_cuda_devices:
|
|
1467
|
+
# tests for both custom_call and ffi variants of jax_kernel(), selected by installed JAX version
|
|
1468
|
+
if jax.__version_info__ < (0, 4, 25):
|
|
1469
|
+
# no interop supported
|
|
1470
|
+
ffi_opts = []
|
|
1471
|
+
elif jax.__version_info__ < (0, 5, 0):
|
|
1472
|
+
# only custom_call supported
|
|
1473
|
+
ffi_opts = [False]
|
|
1474
|
+
elif jax.__version_info__ < (0, 8, 0):
|
|
1475
|
+
# both custom_call and ffi supported
|
|
1476
|
+
ffi_opts = [False, True]
|
|
1477
|
+
else:
|
|
1478
|
+
# only ffi supported
|
|
1479
|
+
ffi_opts = [True]
|
|
1480
|
+
|
|
1481
|
+
for use_ffi in ffi_opts:
|
|
1482
|
+
suffix = "ffi" if use_ffi else "cc"
|
|
1483
|
+
add_function_test(
|
|
1484
|
+
TestJax,
|
|
1485
|
+
f"test_jax_kernel_basic_{suffix}",
|
|
1486
|
+
test_jax_kernel_basic,
|
|
1487
|
+
devices=jax_compatible_cuda_devices,
|
|
1488
|
+
use_ffi=use_ffi,
|
|
1489
|
+
)
|
|
1490
|
+
add_function_test(
|
|
1491
|
+
TestJax,
|
|
1492
|
+
f"test_jax_kernel_scalar_{suffix}",
|
|
1493
|
+
test_jax_kernel_scalar,
|
|
1494
|
+
devices=jax_compatible_cuda_devices,
|
|
1495
|
+
use_ffi=use_ffi,
|
|
1496
|
+
)
|
|
1497
|
+
add_function_test(
|
|
1498
|
+
TestJax,
|
|
1499
|
+
f"test_jax_kernel_vecmat_{suffix}",
|
|
1500
|
+
test_jax_kernel_vecmat,
|
|
1501
|
+
devices=jax_compatible_cuda_devices,
|
|
1502
|
+
use_ffi=use_ffi,
|
|
1503
|
+
)
|
|
1504
|
+
add_function_test(
|
|
1505
|
+
TestJax,
|
|
1506
|
+
f"test_jax_kernel_multiarg_{suffix}",
|
|
1507
|
+
test_jax_kernel_multiarg,
|
|
1508
|
+
devices=jax_compatible_cuda_devices,
|
|
1509
|
+
use_ffi=use_ffi,
|
|
1510
|
+
)
|
|
1511
|
+
add_function_test(
|
|
1512
|
+
TestJax,
|
|
1513
|
+
f"test_jax_kernel_launch_dims_{suffix}",
|
|
1514
|
+
test_jax_kernel_launch_dims,
|
|
1515
|
+
devices=jax_compatible_cuda_devices,
|
|
1516
|
+
use_ffi=use_ffi,
|
|
1517
|
+
)
|
|
1518
|
+
|
|
1519
|
+
# ffi.jax_kernel() tests
|
|
1520
|
+
add_function_test(
|
|
1521
|
+
TestJax, "test_ffi_jax_kernel_add", test_ffi_jax_kernel_add, devices=jax_compatible_cuda_devices
|
|
1522
|
+
)
|
|
1523
|
+
add_function_test(
|
|
1524
|
+
TestJax, "test_ffi_jax_kernel_sincos", test_ffi_jax_kernel_sincos, devices=jax_compatible_cuda_devices
|
|
1525
|
+
)
|
|
1526
|
+
add_function_test(
|
|
1527
|
+
TestJax, "test_ffi_jax_kernel_diagonal", test_ffi_jax_kernel_diagonal, devices=jax_compatible_cuda_devices
|
|
1528
|
+
)
|
|
1529
|
+
add_function_test(
|
|
1530
|
+
TestJax, "test_ffi_jax_kernel_in_out", test_ffi_jax_kernel_in_out, devices=jax_compatible_cuda_devices
|
|
1531
|
+
)
|
|
1532
|
+
add_function_test(
|
|
1533
|
+
TestJax,
|
|
1534
|
+
"test_ffi_jax_kernel_scale_vec_constant",
|
|
1535
|
+
test_ffi_jax_kernel_scale_vec_constant,
|
|
1536
|
+
devices=jax_compatible_cuda_devices,
|
|
1537
|
+
)
|
|
1538
|
+
add_function_test(
|
|
1539
|
+
TestJax,
|
|
1540
|
+
"test_ffi_jax_kernel_scale_vec_static",
|
|
1541
|
+
test_ffi_jax_kernel_scale_vec_static,
|
|
1542
|
+
devices=jax_compatible_cuda_devices,
|
|
1543
|
+
)
|
|
1544
|
+
add_function_test(
|
|
1545
|
+
TestJax,
|
|
1546
|
+
"test_ffi_jax_kernel_launch_dims_default",
|
|
1547
|
+
test_ffi_jax_kernel_launch_dims_default,
|
|
1548
|
+
devices=jax_compatible_cuda_devices,
|
|
1549
|
+
)
|
|
1550
|
+
add_function_test(
|
|
1551
|
+
TestJax,
|
|
1552
|
+
"test_ffi_jax_kernel_launch_dims_custom",
|
|
1553
|
+
test_ffi_jax_kernel_launch_dims_custom,
|
|
1554
|
+
devices=jax_compatible_cuda_devices,
|
|
1555
|
+
)
|
|
1556
|
+
|
|
1557
|
+
# ffi.jax_callable() tests
|
|
1558
|
+
add_function_test(
|
|
1559
|
+
TestJax,
|
|
1560
|
+
"test_ffi_jax_callable_scale_constant",
|
|
1561
|
+
test_ffi_jax_callable_scale_constant,
|
|
1562
|
+
devices=jax_compatible_cuda_devices,
|
|
1563
|
+
)
|
|
1564
|
+
add_function_test(
|
|
1565
|
+
TestJax,
|
|
1566
|
+
"test_ffi_jax_callable_scale_static",
|
|
1567
|
+
test_ffi_jax_callable_scale_static,
|
|
1568
|
+
devices=jax_compatible_cuda_devices,
|
|
1569
|
+
)
|
|
1570
|
+
add_function_test(
|
|
1571
|
+
TestJax, "test_ffi_jax_callable_in_out", test_ffi_jax_callable_in_out, devices=jax_compatible_cuda_devices
|
|
1572
|
+
)
|
|
1573
|
+
add_function_test(
|
|
1574
|
+
TestJax,
|
|
1575
|
+
"test_ffi_jax_callable_graph_cache",
|
|
1576
|
+
test_ffi_jax_callable_graph_cache,
|
|
1577
|
+
devices=jax_compatible_cuda_devices,
|
|
1578
|
+
)
|
|
1579
|
+
|
|
1580
|
+
# pmap tests
|
|
1581
|
+
add_function_test(
|
|
1582
|
+
TestJax,
|
|
1583
|
+
"test_ffi_jax_callable_pmap_multi_output",
|
|
1584
|
+
test_ffi_jax_callable_pmap_multi_output,
|
|
1585
|
+
devices=None,
|
|
1586
|
+
)
|
|
1587
|
+
add_function_test(
|
|
1588
|
+
TestJax,
|
|
1589
|
+
"test_ffi_jax_callable_pmap_mul",
|
|
1590
|
+
test_ffi_jax_callable_pmap_mul,
|
|
1591
|
+
devices=None,
|
|
1592
|
+
)
|
|
1593
|
+
add_function_test(
|
|
1594
|
+
TestJax,
|
|
1595
|
+
"test_ffi_jax_callable_pmap_multi_stage",
|
|
1596
|
+
test_ffi_jax_callable_pmap_multi_stage,
|
|
1597
|
+
devices=None,
|
|
1598
|
+
)
|
|
1599
|
+
|
|
1600
|
+
# ffi callback tests
|
|
1601
|
+
add_function_test(TestJax, "test_ffi_callback", test_ffi_callback, devices=jax_compatible_cuda_devices)
|
|
1602
|
+
|
|
1603
|
+
# autodiff tests
|
|
1604
|
+
add_function_test(
|
|
1605
|
+
TestJax,
|
|
1606
|
+
"test_ffi_jax_kernel_autodiff_simple",
|
|
1607
|
+
test_ffi_jax_kernel_autodiff_simple,
|
|
1608
|
+
devices=jax_compatible_cuda_devices,
|
|
1609
|
+
)
|
|
1610
|
+
add_function_test(
|
|
1611
|
+
TestJax,
|
|
1612
|
+
"test_ffi_jax_kernel_autodiff_jit_of_grad_simple",
|
|
1613
|
+
test_ffi_jax_kernel_autodiff_jit_of_grad_simple,
|
|
1614
|
+
devices=jax_compatible_cuda_devices,
|
|
1615
|
+
)
|
|
1616
|
+
add_function_test(
|
|
1617
|
+
TestJax,
|
|
1618
|
+
"test_ffi_jax_kernel_autodiff_multi_output",
|
|
1619
|
+
test_ffi_jax_kernel_autodiff_multi_output,
|
|
1620
|
+
devices=jax_compatible_cuda_devices,
|
|
1621
|
+
)
|
|
1622
|
+
add_function_test(
|
|
1623
|
+
TestJax,
|
|
1624
|
+
"test_ffi_jax_kernel_autodiff_jit_of_grad_multi_output",
|
|
1625
|
+
test_ffi_jax_kernel_autodiff_jit_of_grad_multi_output,
|
|
1626
|
+
devices=jax_compatible_cuda_devices,
|
|
1627
|
+
)
|
|
1628
|
+
add_function_test(
|
|
1629
|
+
TestJax,
|
|
1630
|
+
"test_ffi_jax_kernel_autodiff_2d",
|
|
1631
|
+
test_ffi_jax_kernel_autodiff_2d,
|
|
1632
|
+
devices=jax_compatible_cuda_devices,
|
|
1633
|
+
)
|
|
1634
|
+
add_function_test(
|
|
1635
|
+
TestJax,
|
|
1636
|
+
"test_ffi_jax_kernel_autodiff_vec2",
|
|
1637
|
+
test_ffi_jax_kernel_autodiff_vec2,
|
|
1638
|
+
devices=jax_compatible_cuda_devices,
|
|
1639
|
+
)
|
|
1640
|
+
add_function_test(
|
|
1641
|
+
TestJax,
|
|
1642
|
+
"test_ffi_jax_kernel_autodiff_mat22",
|
|
1643
|
+
test_ffi_jax_kernel_autodiff_mat22,
|
|
1644
|
+
devices=jax_compatible_cuda_devices,
|
|
1645
|
+
)
|
|
1646
|
+
add_function_test(
|
|
1647
|
+
TestJax,
|
|
1648
|
+
"test_ffi_jax_kernel_autodiff_static_required",
|
|
1649
|
+
test_ffi_jax_kernel_autodiff_static_required,
|
|
1650
|
+
devices=jax_compatible_cuda_devices,
|
|
1651
|
+
)
|
|
1652
|
+
|
|
1653
|
+
# autodiff with pmap tests
|
|
1654
|
+
add_function_test(
|
|
1655
|
+
TestJax,
|
|
1656
|
+
"test_ffi_jax_kernel_autodiff_pmap_triple",
|
|
1657
|
+
test_ffi_jax_kernel_autodiff_pmap_triple,
|
|
1658
|
+
devices=None,
|
|
1659
|
+
)
|
|
360
1660
|
add_function_test(
|
|
361
|
-
TestJax,
|
|
1661
|
+
TestJax,
|
|
1662
|
+
"test_ffi_jax_kernel_autodiff_pmap_multi_output",
|
|
1663
|
+
test_ffi_jax_kernel_autodiff_pmap_multi_output,
|
|
1664
|
+
devices=None,
|
|
362
1665
|
)
|
|
363
1666
|
|
|
364
1667
|
except Exception as e:
|