warp-lang 1.9.1-py3-none-manylinux_2_34_aarch64.whl → 1.10.0rc2-py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +301 -287
- warp/__init__.pyi +794 -305
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1075 -0
- warp/_src/build.py +618 -0
- warp/_src/build_dll.py +640 -0
- warp/{builtins.py → _src/builtins.py} +1382 -377
- warp/_src/codegen.py +4359 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +57 -0
- warp/_src/context.py +8294 -0
- warp/_src/dlpack.py +462 -0
- warp/_src/fabric.py +355 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +508 -0
- warp/_src/fem/cache.py +687 -0
- warp/_src/fem/dirichlet.py +188 -0
- warp/{fem → _src/fem}/domain.py +40 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +701 -0
- warp/{fem → _src/fem}/field/nodal_field.py +30 -15
- warp/{fem → _src/fem}/field/restriction.py +1 -1
- warp/{fem → _src/fem}/field/virtual.py +53 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
- warp/_src/fem/geometry/closest_point.py +97 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
- warp/{fem → _src/fem}/geometry/element.py +32 -10
- warp/{fem → _src/fem}/geometry/geometry.py +48 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
- warp/{fem → _src/fem}/geometry/partition.py +121 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
- warp/{fem → _src/fem}/integrate.py +164 -158
- warp/_src/fem/linalg.py +383 -0
- warp/_src/fem/operator.py +396 -0
- warp/_src/fem/polynomial.py +229 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
- warp/_src/fem/space/basis_space.py +679 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
- warp/{fem → _src/fem}/space/function_space.py +14 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
- warp/{fem → _src/fem}/space/partition.py +117 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/restriction.py +66 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
- warp/_src/fem/space/topology.py +459 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
- warp/_src/fem/types.py +112 -0
- warp/_src/fem/utils.py +486 -0
- warp/_src/jax.py +186 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +387 -0
- warp/_src/jax_experimental/ffi.py +1284 -0
- warp/_src/jax_experimental/xla_ffi.py +656 -0
- warp/_src/marching_cubes.py +708 -0
- warp/_src/math.py +414 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +163 -0
- warp/_src/optim/linear.py +1606 -0
- warp/_src/optim/sgd.py +112 -0
- warp/_src/paddle.py +406 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +289 -0
- warp/_src/render/render_opengl.py +3636 -0
- warp/_src/render/render_usd.py +937 -0
- warp/_src/render/utils.py +160 -0
- warp/_src/sparse.py +2716 -0
- warp/_src/tape.py +1206 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +391 -0
- warp/_src/types.py +5870 -0
- warp/_src/utils.py +1693 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -721
- warp/codegen.py +6 -4251
- warp/constants.py +6 -39
- warp/context.py +12 -8062
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +1 -1
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -365
- warp/jax_experimental/ffi.py +17 -873
- warp/jax_experimental/xla_ffi.py +5 -605
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +314 -37
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +681 -89
- warp/native/tile_radix_sort.h +1 -1
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +57 -29
- warp/native/warp.cu +253 -171
- warp/native/warp.h +11 -8
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3618
- warp/render/render_usd.py +6 -919
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +14 -14
- warp/tests/interop/test_jax.py +772 -49
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +527 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +33 -14
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +56 -10
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +35 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_types.py +27 -15
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +178 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5764
- warp/utils.py +10 -1659
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
- warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.1.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
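The headline change in this release is structural: Warp's implementation moves into a private `warp/_src/` package, while the old top-level modules shrink to thin forwarding shims (compare `warp/autograd.py +12 -1054` with the new `warp/_src/autograd.py +1075 -0`), and the deprecated `warp.sim` package, its tests, examples, and assets are removed outright. A minimal sketch of what the move means for downstream imports, assuming the remaining shims keep re-exporting the public names:

    # 1.9.1-style deep import; in 1.10 this resolves through a compatibility shim:
    from warp.jax import get_jax_device

    # 1.10.0rc2 implementation location (what the updated tests import directly):
    from warp._src.jax import get_jax_device

    # application code is better served by the stable top-level API:
    import warp as wp
    jax_dev = wp.device_to_jax(wp.get_device("cuda:0"))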
warp/tests/interop/test_jax.py
CHANGED
@@ -21,9 +21,12 @@ from typing import Any
 import numpy as np

 import warp as wp
-from warp.jax import get_jax_device
+from warp._src.jax import get_jax_device
 from warp.tests.unittest_utils import *

+# default array size for tests
+ARRAY_SIZE = 1024 * 1024
+

 # basic kernel with one input and output
 @wp.kernel
@@ -46,6 +49,18 @@ def triple_kernel_vecmat(input: wp.array(dtype=Any), output: wp.array(dtype=Any)):
     output[tid] = input.dtype.dtype(3) * input[tid]


+@wp.kernel
+def inc_1d_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
+    tid = wp.tid()
+    y[tid] = x[tid] + 1.0
+
+
+@wp.kernel
+def inc_2d_kernel(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
+    i, j = wp.tid()
+    y[i, j] = x[i, j] + 1.0
+
+
 # kernel with multiple inputs and outputs
 @wp.kernel
 def multiarg_kernel(
@@ -63,7 +78,7 @@ def multiarg_kernel(


 # various types for testing
-scalar_types = wp.types.scalar_types
+scalar_types = wp._src.types.scalar_types
 vector_types = []
 matrix_types = []
 for dim in [2, 3, 4]:
@@ -146,7 +161,7 @@ def test_jax_kernel_basic(test, device, use_ffi=False):

     jax_triple = jax_kernel(triple_kernel, quiet=True)  # suppress deprecation warnings

-    n =
+    n = ARRAY_SIZE

     @jax.jit
     def f():
@@ -157,6 +172,8 @@ def test_jax_kernel_basic(test, device, use_ffi=False):
     with jax.default_device(wp.device_to_jax(device)):
         y = f()

+    wp.synchronize_device(device)
+
     result = np.asarray(y).reshape((n,))
     expected = 3 * np.arange(n, dtype=np.float32)

@@ -175,6 +192,7 @@ def test_jax_kernel_scalar(test, device, use_ffi=False):

     kwargs = {"quiet": True}

+    # use a smallish size to ensure arange * 3 doesn't overflow
     n = 64

     for T in scalar_types:
@@ -196,6 +214,8 @@ def test_jax_kernel_scalar(test, device, use_ffi=False):
         with jax.default_device(wp.device_to_jax(device)):
             y = f()

+        wp.synchronize_device(device)
+
         result = np.asarray(y).reshape((n,))
         expected = 3 * np.arange(n, dtype=np_dtype)

@@ -218,6 +238,7 @@ def test_jax_kernel_vecmat(test, device, use_ffi=False):
         jp_dtype = wp.dtype_to_jax(T._wp_scalar_type_)
         np_dtype = wp.dtype_to_numpy(T._wp_scalar_type_)

+        # use a smallish size to ensure arange * 3 doesn't overflow
         n = 64 // T._length_
         scalar_shape = (n, *T._shape_)
         scalar_len = n * T._length_
@@ -237,6 +258,8 @@ def test_jax_kernel_vecmat(test, device, use_ffi=False):
         with jax.default_device(wp.device_to_jax(device)):
             y = f()

+        wp.synchronize_device(device)
+
         result = np.asarray(y).reshape(scalar_shape)
         expected = 3 * np.arange(scalar_len, dtype=np_dtype).reshape(scalar_shape)

@@ -255,7 +278,7 @@ def test_jax_kernel_multiarg(test, device, use_ffi=False):

     jax_multiarg = jax_kernel(multiarg_kernel, quiet=True)

-    n =
+    n = ARRAY_SIZE

     @jax.jit
     def f():
@@ -268,6 +291,8 @@ def test_jax_kernel_multiarg(test, device, use_ffi=False):
     with jax.default_device(wp.device_to_jax(device)):
         x, y = f()

+    wp.synchronize_device(device)
+
     result_x, result_y = np.asarray(x), np.asarray(y)
     expected_x = np.full(n, 3, dtype=np.float32)
     expected_y = np.full(n, 5, dtype=np.float32)
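A pattern repeated through the hunks above: each test now calls `wp.synchronize_device(device)` between the jitted JAX call and the host-side NumPy check, and the default workload grows to `ARRAY_SIZE = 1024 * 1024` elements, presumably so the Warp launch is still in flight when the test reads results back and the explicit sync is actually exercised. A minimal restatement of the pattern, with `f`, `n`, and `device` as in the tests above:

    with jax.default_device(wp.device_to_jax(device)):
        y = f()  # jitted function that launches the Warp kernel

    # block until Warp's work on this device has completed
    wp.synchronize_device(device)

    result = np.asarray(y).reshape((n,))
    expected = 3 * np.arange(n, dtype=np.float32)
    assert_np_equal(result, expected)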
@@ -292,40 +317,32 @@ def test_jax_kernel_launch_dims(test, device, use_ffi=False):
     m = 32

     # Test with 1D launch dims
-    @wp.kernel
-    def add_one_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
-        tid = wp.tid()
-        y[tid] = x[tid] + 1.0
-
-    jax_add_one = jax_kernel(
-        add_one_kernel, launch_dims=(n - 2,), **kwargs
+    jax_inc_1d = jax_kernel(
+        inc_1d_kernel, launch_dims=(n - 2,), **kwargs
     )  # Intentionally not the same as the first dimension of the input

     @jax.jit
     def f_1d():
         x = jp.arange(n, dtype=jp.float32)
-        return jax_add_one(x)
+        return jax_inc_1d(x)

     # Test with 2D launch dims
-    @wp.kernel
-    def add_one_2d_kernel(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
-        i, j = wp.tid()
-        y[i, j] = x[i, j] + 1.0
-
-    jax_add_one_2d = jax_kernel(
-        add_one_2d_kernel, launch_dims=(n - 2, m - 2), **kwargs
+    jax_inc_2d = jax_kernel(
+        inc_2d_kernel, launch_dims=(n - 2, m - 2), **kwargs
     )  # Intentionally not the same as the first dimension of the input

     @jax.jit
     def f_2d():
         x = jp.zeros((n, m), dtype=jp.float32) + 3.0
-        return jax_add_one_2d(x)
+        return jax_inc_2d(x)

     # run on the given device
     with jax.default_device(wp.device_to_jax(device)):
         y_1d = f_1d()
         y_2d = f_2d()

+    wp.synchronize_device(device)
+
     result_1d = np.asarray(y_1d).reshape((n - 2,))
     expected_1d = np.arange(n - 2, dtype=np.float32) + 1.0

@@ -342,11 +359,17 @@ def test_jax_kernel_launch_dims(test, device, use_ffi=False):


 @wp.kernel
-def add_kernel(a: wp.array(dtype=
+def add_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), output: wp.array(dtype=float)):
     tid = wp.tid()
     output[tid] = a[tid] + b[tid]


+@wp.kernel
+def axpy_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float), alpha: float, out: wp.array(dtype=float)):
+    tid = wp.tid()
+    out[tid] = alpha * x[tid] + y[tid]
+
+
 @wp.kernel
 def sincos_kernel(angle: wp.array(dtype=float), sin_out: wp.array(dtype=float), cos_out: wp.array(dtype=float)):
     tid = wp.tid()
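The launch-dims hunks swap the test-local nested kernels for the module-level `inc_1d_kernel`/`inc_2d_kernel` added at the top of the file, keeping the deliberate grid/input mismatch: launching over `(n - 2,)` threads for an input of length `n` yields an output sized by the launch dims, which the assertions then check against `np.arange(n - 2, ...) + 1.0`. A condensed sketch of the 1D case (the `jax_kernel` import path depends on the `use_ffi` variant; the FFI path is shown):

    from warp.jax_experimental.ffi import jax_kernel

    jax_inc_1d = jax_kernel(inc_1d_kernel, launch_dims=(n - 2,))  # grid smaller than input

    @jax.jit
    def f_1d():
        x = jp.arange(n, dtype=jp.float32)
        return jax_inc_1d(x)  # one output element per launched thread: n - 2 total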
@@ -408,6 +431,39 @@ def in_out_kernel(
     c[tid] = 2.0 * a[tid]


+@wp.kernel
+def multi_out_kernel(
+    a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float), d: wp.array(dtype=float)
+):
+    tid = wp.tid()
+    c[tid] = a[tid] + b[tid]
+    d[tid] = s * a[tid]
+
+
+@wp.kernel
+def multi_out_kernel_v2(
+    a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float), d: wp.array(dtype=float)
+):
+    tid = wp.tid()
+    c[tid] = a[tid] * a[tid]
+    d[tid] = a[tid] * b[tid] * s
+
+
+@wp.kernel
+def multi_out_kernel_v3(
+    a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float), d: wp.array(dtype=float)
+):
+    tid = wp.tid()
+    c[tid] = a[tid] ** 2.0
+    d[tid] = a[tid] * b[tid] * s
+
+
+@wp.kernel
+def scale_sum_square_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float)):
+    tid = wp.tid()
+    c[tid] = (a[tid] * s + b[tid]) ** 2.0
+
+
 # The Python function to call.
 # Note the argument annotations, just like Warp kernels.
 def scale_func(
@@ -432,6 +488,15 @@ def in_out_func(
     wp.launch(accum_kernel, dim=a.size, inputs=[a, b])  # modifies `b`


+def double_func(
+    # inputs
+    a: wp.array(dtype=float),
+    # outputs
+    b: wp.array(dtype=float),
+):
+    wp.launch(scale_kernel, dim=a.shape, inputs=[a, 2.0], outputs=[b])
+
+
 @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
 def test_ffi_jax_kernel_add(test, device):
     # two inputs and one output
@@ -443,16 +508,18 @@ def test_ffi_jax_kernel_add(test, device):

     @jax.jit
     def f():
-        n =
-        a = jp.arange(n, dtype=jp.
-        b = jp.ones(n, dtype=jp.
+        n = ARRAY_SIZE
+        a = jp.arange(n, dtype=jp.float32)
+        b = jp.ones(n, dtype=jp.float32)
         return jax_add(a, b)

     with jax.default_device(wp.device_to_jax(device)):
         (y,) = f()

+    wp.synchronize_device(device)
+
     result = np.asarray(y)
-    expected = np.arange(1,
+    expected = np.arange(1, ARRAY_SIZE + 1, dtype=np.float32)

     assert_np_equal(result, expected)

@@ -465,7 +532,8 @@ def test_ffi_jax_kernel_sincos(test, device):
     from warp.jax_experimental.ffi import jax_kernel

     jax_sincos = jax_kernel(sincos_kernel, num_outputs=2)
-
+
+    n = ARRAY_SIZE

     @jax.jit
     def f():
@@ -475,6 +543,8 @@ def test_ffi_jax_kernel_sincos(test, device):
     with jax.default_device(wp.device_to_jax(device)):
         s, c = f()

+    wp.synchronize_device(device)
+
     result_s = np.asarray(s)
     result_c = np.asarray(c)

@@ -498,6 +568,8 @@ def test_ffi_jax_kernel_diagonal(test, device):
         # launch dimensions determine output size
         return jax_diagonal(launch_dims=4)

+    wp.synchronize_device(device)
+
     with jax.default_device(wp.device_to_jax(device)):
         (d,) = f()

@@ -527,12 +599,14 @@ def test_ffi_jax_kernel_in_out(test, device):
     f = jax.jit(jax_func)

     with jax.default_device(wp.device_to_jax(device)):
-        a = jp.ones(
-        b = jp.arange(
+        a = jp.ones(ARRAY_SIZE, dtype=jp.float32)
+        b = jp.arange(ARRAY_SIZE, dtype=jp.float32)
         b, c = f(a, b)

-
-
+    wp.synchronize_device(device)
+
+    assert_np_equal(b, np.arange(1, ARRAY_SIZE + 1, dtype=np.float32))
+    assert_np_equal(c, np.full(ARRAY_SIZE, 2, dtype=np.float32))


 @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
@@ -546,14 +620,16 @@ def test_ffi_jax_kernel_scale_vec_constant(test, device):

     @jax.jit
     def f():
-        a = jp.arange(
+        a = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape((ARRAY_SIZE // 2, 2))  # array of vec2
         s = 2.0
         return jax_scale_vec(a, s)

     with jax.default_device(wp.device_to_jax(device)):
         (b,) = f()

-
+    wp.synchronize_device(device)
+
+    expected = 2 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2))

     assert_np_equal(b, expected)

@@ -572,13 +648,15 @@ def test_ffi_jax_kernel_scale_vec_static(test, device):
     def f(a, s):
         return jax_scale_vec(a, s)

-    a = jp.arange(
+    a = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape((ARRAY_SIZE // 2, 2))  # array of vec2
     s = 3.0

     with jax.default_device(wp.device_to_jax(device)):
         (b,) = f(a, s)

-
+    wp.synchronize_device(device)
+
+    expected = 3 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2))

     assert_np_equal(b, expected)

@@ -605,6 +683,8 @@ def test_ffi_jax_kernel_launch_dims_default(test, device):
     with jax.default_device(wp.device_to_jax(device)):
         (result,) = f()

+    wp.synchronize_device(device)
+
     expected = np.full((3, 4), 12, dtype=np.float32)

     test.assertEqual(result.shape, expected.shape)
@@ -641,6 +721,8 @@ def test_ffi_jax_kernel_launch_dims_custom(test, device):
     with jax.default_device(wp.device_to_jax(device)):
         result1, result2 = f()

+    wp.synchronize_device(device)
+
     expected1 = np.full((3, 4), 12, dtype=np.float32)
     expected2 = np.full((4, 3), 12, dtype=np.float32)

@@ -662,8 +744,8 @@ def test_ffi_jax_callable_scale_constant(test, device):
     @jax.jit
     def f():
         # inputs
-        a = jp.arange(
-        b = jp.arange(
+        a = jp.arange(ARRAY_SIZE, dtype=jp.float32)
+        b = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape((ARRAY_SIZE // 2, 2))  # wp.vec2
         s = 2.0

         # output shapes
@@ -676,8 +758,10 @@ def test_ffi_jax_callable_scale_constant(test, device):
     with jax.default_device(wp.device_to_jax(device)):
         result1, result2 = f()

-
-
+    wp.synchronize_device(device)
+
+    expected1 = 2 * np.arange(ARRAY_SIZE, dtype=np.float32)
+    expected2 = 2 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2))

     assert_np_equal(result1, expected1)
     assert_np_equal(result2, expected2)
@@ -704,13 +788,15 @@ def test_ffi_jax_callable_scale_static(test, device):

     with jax.default_device(wp.device_to_jax(device)):
         # inputs
-        a = jp.arange(
-        b = jp.arange(
+        a = jp.arange(ARRAY_SIZE, dtype=jp.float32)
+        b = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape((ARRAY_SIZE // 2, 2))  # wp.vec2
         s = 3.0
         result1, result2 = f(a, b, s)

-
-
+    wp.synchronize_device(device)
+
+    expected1 = 3 * np.arange(ARRAY_SIZE, dtype=np.float32)
+    expected2 = 3 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2))

     assert_np_equal(result1, expected1)
     assert_np_equal(result2, expected2)
@@ -728,12 +814,224 @@ def test_ffi_jax_callable_in_out(test, device):
     f = jax.jit(jax_func)

     with jax.default_device(wp.device_to_jax(device)):
-        a = jp.ones(
-        b = jp.arange(
+        a = jp.ones(ARRAY_SIZE, dtype=jp.float32)
+        b = jp.arange(ARRAY_SIZE, dtype=jp.float32)
         b, c = f(a, b)

-
-
+    wp.synchronize_device(device)
+
+    assert_np_equal(b, np.arange(1, ARRAY_SIZE + 1, dtype=np.float32))
+    assert_np_equal(c, np.full(ARRAY_SIZE, 2, dtype=np.float32))
+
+
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+def test_ffi_jax_callable_graph_cache(test, device):
+    # test graph caching limits
+    import jax
+    import jax.numpy as jp
+
+    from warp.jax_experimental.ffi import (
+        GraphMode,
+        clear_jax_callable_graph_cache,
+        get_jax_callable_default_graph_cache_max,
+        jax_callable,
+        set_jax_callable_default_graph_cache_max,
+    )
+
+    # --- test with default cache settings ---
+
+    jax_double = jax_callable(double_func, graph_mode=GraphMode.WARP)
+    f = jax.jit(jax_double)
+    arrays = []
+
+    test.assertEqual(jax_double.graph_cache_max, get_jax_callable_default_graph_cache_max())
+
+    with jax.default_device(wp.device_to_jax(device)):
+        for i in range(10):
+            n = 10 + i
+            a = jp.arange(n, dtype=jp.float32)
+            (b,) = f(a)
+
+            assert_np_equal(b, 2 * np.arange(n, dtype=np.float32))
+
+            # ensure graph cache is always growing
+            test.assertEqual(jax_double.graph_cache_size, i + 1)
+
+            # keep JAX array alive to prevent the memory from being reused, thus forcing a new graph capture each time
+            arrays.append(a)
+
+    # --- test clearing one callable's cache ---
+
+    clear_jax_callable_graph_cache(jax_double)
+
+    test.assertEqual(jax_double.graph_cache_size, 0)
+
+    # --- test with a custom cache limit ---
+
+    graph_cache_max = 5
+    jax_double = jax_callable(double_func, graph_mode=GraphMode.WARP, graph_cache_max=graph_cache_max)
+    f = jax.jit(jax_double)
+    arrays = []
+
+    test.assertEqual(jax_double.graph_cache_max, graph_cache_max)
+
+    with jax.default_device(wp.device_to_jax(device)):
+        for i in range(10):
+            n = 10 + i
+            a = jp.arange(n, dtype=jp.float32)
+            (b,) = f(a)
+
+            assert_np_equal(b, 2 * np.arange(n, dtype=np.float32))
+
+            # ensure graph cache size is capped
+            test.assertEqual(jax_double.graph_cache_size, min(i + 1, graph_cache_max))
+
+            # keep JAX array alive to prevent the memory from being reused, thus forcing a new graph capture
+            arrays.append(a)
+
+    # --- test clearing all callables' caches ---
+
+    clear_jax_callable_graph_cache()
+
+    with wp.jax_experimental.ffi._FFI_REGISTRY_LOCK:
+        for c in wp.jax_experimental.ffi._FFI_CALLABLE_REGISTRY.values():
+            test.assertEqual(c.graph_cache_size, 0)
+
+    # --- test with a custom default cache limit ---
+
+    saved_max = get_jax_callable_default_graph_cache_max()
+    try:
+        set_jax_callable_default_graph_cache_max(5)
+        jax_double = jax_callable(double_func, graph_mode=GraphMode.WARP)
+        f = jax.jit(jax_double)
+        arrays = []
+
+        test.assertEqual(jax_double.graph_cache_max, get_jax_callable_default_graph_cache_max())
+
+        with jax.default_device(wp.device_to_jax(device)):
+            for i in range(10):
+                n = 10 + i
+                a = jp.arange(n, dtype=jp.float32)
+                (b,) = f(a)
+
+                assert_np_equal(b, 2 * np.arange(n, dtype=np.float32))
+
+                # ensure graph cache size is capped
+                test.assertEqual(
+                    jax_double.graph_cache_size,
+                    min(i + 1, get_jax_callable_default_graph_cache_max()),
+                )
+
+                # keep JAX array alive to prevent the memory from being reused, thus forcing a new graph capture
+                arrays.append(a)
+
+        clear_jax_callable_graph_cache()
+
+    finally:
+        set_jax_callable_default_graph_cache_max(saved_max)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+def test_ffi_jax_callable_pmap_mul(test, device):
+    import jax
+    import jax.numpy as jp
+
+    from warp.jax_experimental.ffi import jax_callable
+
+    j = jax_callable(double_func, num_outputs=1)
+
+    ndev = jax.local_device_count()
+    per_device = max(ARRAY_SIZE // ndev, 64)
+    x = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
+
+    def per_device_func(v):
+        (y,) = j(v)
+        return y
+
+    y = jax.pmap(per_device_func)(x)
+
+    wp.synchronize()
+
+    assert_np_equal(np.asarray(y), 2 * np.asarray(x))
+
+
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+def test_ffi_jax_callable_pmap_multi_output(test, device):
+    import jax
+    import jax.numpy as jp
+
+    from warp.jax_experimental.ffi import jax_callable
+
+    def multi_out_py(
+        a: wp.array(dtype=float),
+        b: wp.array(dtype=float),
+        s: float,
+        c: wp.array(dtype=float),
+        d: wp.array(dtype=float),
+    ):
+        wp.launch(multi_out_kernel, dim=a.shape, inputs=[a, b, s], outputs=[c, d])
+
+    j = jax_callable(multi_out_py, num_outputs=2)
+
+    ndev = jax.local_device_count()
+    per_device = max(ARRAY_SIZE // ndev, 64)
+    a = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
+    b = jp.ones((ndev, per_device), dtype=jp.float32)
+    s = 3.0
+
+    def per_device_func(aa, bb):
+        c, d = j(aa, bb, s)
+        return c + d  # simple combine to exercise both outputs
+
+    out = jax.pmap(per_device_func)(a, b)
+
+    wp.synchronize()
+
+    a_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device))
+    b_np = np.ones((ndev, per_device), dtype=np.float32)
+    ref = (a_np + b_np) + s * a_np
+    assert_np_equal(np.asarray(out), ref)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+def test_ffi_jax_callable_pmap_multi_stage(test, device):
+    import jax
+    import jax.numpy as jp
+
+    from warp.jax_experimental.ffi import jax_callable
+
+    def multi_stage_py(
+        a: wp.array(dtype=float),
+        b: wp.array(dtype=float),
+        alpha: float,
+        tmp: wp.array(dtype=float),
+        out: wp.array(dtype=float),
+    ):
+        wp.launch(add_kernel, dim=a.shape, inputs=[a, b], outputs=[tmp])
+        wp.launch(axpy_kernel, dim=a.shape, inputs=[tmp, b, alpha], outputs=[out])
+
+    j = jax_callable(multi_stage_py, num_outputs=2)
+
+    ndev = jax.local_device_count()
+    per_device = max(ARRAY_SIZE // ndev, 64)
+    a = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
+    b = jp.ones((ndev, per_device), dtype=jp.float32)
+    alpha = 2.5
+
+    def per_device_func(aa, bb):
+        tmp, out = j(aa, bb, alpha)
+        return tmp + out
+
+    combined = jax.pmap(per_device_func)(a, b)
+
+    wp.synchronize()
+
+    a_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device))
+    b_np = np.ones((ndev, per_device), dtype=np.float32)
+    tmp_ref = a_np + b_np
+    out_ref = alpha * (a_np + b_np) + b_np
+    ref = tmp_ref + out_ref
+    assert_np_equal(np.asarray(combined), ref)


 @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
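The largest addition here is `test_ffi_jax_callable_graph_cache`, which pins down the CUDA-graph caching behavior of `jax_callable`: under `GraphMode.WARP` each new combination of argument buffers forces a fresh graph capture (the test keeps its JAX arrays alive precisely to prevent buffer reuse), captures are cached per callable up to `graph_cache_max`, and both a per-callable limit and a process-wide default are exposed. A condensed usage sketch built from the names the test imports:

    from warp.jax_experimental.ffi import (
        GraphMode,
        clear_jax_callable_graph_cache,
        jax_callable,
        set_jax_callable_default_graph_cache_max,
    )

    set_jax_callable_default_graph_cache_max(8)  # default cap for new callables

    fn = jax_callable(double_func, graph_mode=GraphMode.WARP, graph_cache_max=5)
    f = jax.jit(fn)
    # repeated calls: fn.graph_cache_size grows, capped at fn.graph_cache_max

    clear_jax_callable_graph_cache(fn)  # clear one callable's cache
    clear_jax_callable_graph_cache()    # or clear every registered callable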
@@ -770,7 +1068,7 @@ def test_ffi_callback(test, device):
     # register callback
     register_ffi_callback("warp_func", warp_func)

-    n =
+    n = ARRAY_SIZE

     with jax.default_device(wp.device_to_jax(device)):
         # inputs
@@ -788,8 +1086,344 @@ def test_ffi_callback(test, device):
         # call it
         c, d = call(a, b, scale=s)

-
-
+    wp.synchronize_device(device)
+
+    assert_np_equal(c, 2 * np.arange(ARRAY_SIZE, dtype=np.float32))
+    assert_np_equal(d, 2 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2)))
+
+
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+def test_ffi_jax_kernel_autodiff_simple(test, device):
+    import jax
+    import jax.numpy as jp
+
+    from warp.jax_experimental.ffi import jax_kernel
+
+    jax_func = jax_kernel(
+        scale_sum_square_kernel,
+        num_outputs=1,
+        enable_backward=True,
+    )
+
+    from functools import partial
+
+    @partial(jax.jit, static_argnames=["s"])
+    def loss(a, b, s):
+        out = jax_func(a, b, s)[0]
+        return jp.sum(out)
+
+    n = ARRAY_SIZE
+    a = jp.arange(n, dtype=jp.float32)
+    b = jp.ones(n, dtype=jp.float32)
+    s = 2.0
+
+    with jax.default_device(wp.device_to_jax(device)):
+        da, db = jax.grad(loss, argnums=(0, 1))(a, b, s)
+
+    wp.synchronize_device(device)
+
+    # reference gradients
+    # d/da sum((a*s + b)^2) = sum(2*(a*s + b) * s)
+    # d/db sum((a*s + b)^2) = sum(2*(a*s + b))
+    a_np = np.arange(n, dtype=np.float32)
+    b_np = np.ones(n, dtype=np.float32)
+    ref_da = 2.0 * (a_np * s + b_np) * s
+    ref_db = 2.0 * (a_np * s + b_np)
+
+    assert_np_equal(np.asarray(da), ref_da)
+    assert_np_equal(np.asarray(db), ref_db)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+def test_ffi_jax_kernel_autodiff_jit_of_grad_simple(test, device):
+    import jax
+    import jax.numpy as jp
+
+    from warp.jax_experimental.ffi import jax_kernel
+
+    jax_func = jax_kernel(scale_sum_square_kernel, num_outputs=1, enable_backward=True)
+
+    def loss(a, b, s):
+        out = jax_func(a, b, s)[0]
+        return jp.sum(out)
+
+    grad_fn = jax.grad(loss, argnums=(0, 1))
+
+    # more typical: jit(grad(...)) with static scalar
+    jitted_grad = jax.jit(lambda a, b, s: grad_fn(a, b, s), static_argnames=("s",))
+
+    n = ARRAY_SIZE
+    a = jp.arange(n, dtype=jp.float32)
+    b = jp.ones(n, dtype=jp.float32)
+    s = 2.0
+
+    with jax.default_device(wp.device_to_jax(device)):
+        da, db = jitted_grad(a, b, s)
+
+    wp.synchronize_device(device)
+
+    a_np = np.arange(n, dtype=np.float32)
+    b_np = np.ones(n, dtype=np.float32)
+    ref_da = 2.0 * (a_np * s + b_np) * s
+    ref_db = 2.0 * (a_np * s + b_np)
+
+    assert_np_equal(np.asarray(da), ref_da)
+    assert_np_equal(np.asarray(db), ref_db)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+def test_ffi_jax_kernel_autodiff_multi_output(test, device):
+    import jax
+    import jax.numpy as jp
+
+    from warp.jax_experimental.ffi import jax_kernel
+
+    jax_func = jax_kernel(multi_out_kernel_v3, num_outputs=2, enable_backward=True)
+
+    def caller(fn, a, b, s):
+        c, d = fn(a, b, s)
+        return jp.sum(c + d)
+
+    @jax.jit
+    def grads(a, b, s):
+        # mark s as static in the inner call via partial to avoid hashing
+        def _inner(a, b, s):
+            return caller(jax_func, a, b, s)
+
+        return jax.grad(lambda a, b: _inner(a, b, 2.0), argnums=(0, 1))(a, b)
+
+    n = ARRAY_SIZE
+    a = jp.arange(n, dtype=jp.float32)
+    b = jp.ones(n, dtype=jp.float32)
+    s = 2.0
+
+    with jax.default_device(wp.device_to_jax(device)):
+        da, db = grads(a, b, s)
+
+    wp.synchronize_device(device)
+
+    a_np = np.arange(n, dtype=np.float32)
+    b_np = np.ones(n, dtype=np.float32)
+    # d/da sum(c+d) = 2*a + b*s
+    ref_da = 2.0 * a_np + b_np * s
+    # d/db sum(c+d) = a*s
+    ref_db = a_np * s
+
+    assert_np_equal(np.asarray(da), ref_da)
+    assert_np_equal(np.asarray(db), ref_db)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+def test_ffi_jax_kernel_autodiff_jit_of_grad_multi_output(test, device):
+    import jax
+    import jax.numpy as jp
+
+    from warp.jax_experimental.ffi import jax_kernel
+
+    jax_func = jax_kernel(multi_out_kernel_v3, num_outputs=2, enable_backward=True)
+
+    def loss(a, b, s):
+        c, d = jax_func(a, b, s)
+        return jp.sum(c + d)
+
+    grad_fn = jax.grad(loss, argnums=(0, 1))
+    jitted_grad = jax.jit(lambda a, b, s: grad_fn(a, b, s), static_argnames=("s",))
+
+    n = ARRAY_SIZE
+    a = jp.arange(n, dtype=jp.float32)
+    b = jp.ones(n, dtype=jp.float32)
+    s = 2.0
+
+    with jax.default_device(wp.device_to_jax(device)):
+        da, db = jitted_grad(a, b, s)
+
+    wp.synchronize_device(device)
+
+    a_np = np.arange(n, dtype=np.float32)
+    b_np = np.ones(n, dtype=np.float32)
+    ref_da = 2.0 * a_np + b_np * s
+    ref_db = a_np * s
+
+    assert_np_equal(np.asarray(da), ref_da)
+    assert_np_equal(np.asarray(db), ref_db)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+def test_ffi_jax_kernel_autodiff_2d(test, device):
+    import jax
+    import jax.numpy as jp
+
+    from warp.jax_experimental.ffi import jax_kernel
+
+    jax_func = jax_kernel(inc_2d_kernel, num_outputs=1, enable_backward=True)
+
+    @jax.jit
+    def loss(a):
+        out = jax_func(a)[0]
+        return jp.sum(out)
+
+    n, m = 8, 6
+    a = jp.arange(n * m, dtype=jp.float32).reshape((n, m))
+
+    with jax.default_device(wp.device_to_jax(device)):
+        (da,) = jax.grad(loss, argnums=(0,))(a)
+
+    wp.synchronize_device(device)
+
+    ref = np.ones((n, m), dtype=np.float32)
+    assert_np_equal(np.asarray(da), ref)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+def test_ffi_jax_kernel_autodiff_vec2(test, device):
+    import jax
+    import jax.numpy as jp
+
+    from warp.jax_experimental.ffi import jax_kernel
+
+    jax_func = jax_kernel(scale_vec_kernel, num_outputs=1, enable_backward=True)
+
+    from functools import partial
+
+    @partial(jax.jit, static_argnames=("s",))
+    def loss(a, s):
+        out = jax_func(a, s)[0]
+        return jp.sum(out)
+
+    n = ARRAY_SIZE
+    a = jp.arange(n, dtype=jp.float32).reshape((n // 2, 2))
+    s = 3.0
+
+    with jax.default_device(wp.device_to_jax(device)):
+        (da,) = jax.grad(loss, argnums=(0,))(a, s)
+
+    wp.synchronize_device(device)
+
+    # d/da sum(a*s) = s
+    ref = np.full_like(np.asarray(a), s)
+    assert_np_equal(np.asarray(da), ref)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+def test_ffi_jax_kernel_autodiff_mat22(test, device):
+    import jax
+    import jax.numpy as jp
+
+    from warp.jax_experimental.ffi import jax_kernel
+
+    @wp.kernel
+    def scale_mat_kernel(a: wp.array(dtype=wp.mat22), s: float, out: wp.array(dtype=wp.mat22)):
+        tid = wp.tid()
+        out[tid] = a[tid] * s
+
+    jax_func = jax_kernel(scale_mat_kernel, num_outputs=1, enable_backward=True)
+
+    from functools import partial
+
+    @partial(jax.jit, static_argnames=("s",))
+    def loss(a, s):
+        out = jax_func(a, s)[0]
+        return jp.sum(out)
+
+    n = 12  # must be divisible by 4 for 2x2 matrices
+    a = jp.arange(n, dtype=jp.float32).reshape((n // 4, 2, 2))
+    s = 2.5
+
+    with jax.default_device(wp.device_to_jax(device)):
+        (da,) = jax.grad(loss, argnums=(0,))(a, s)
+
+    wp.synchronize_device(device)
+
+    ref = np.full((n // 4, 2, 2), s, dtype=np.float32)
+    assert_np_equal(np.asarray(da), ref)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+def test_ffi_jax_kernel_autodiff_static_required(test, device):
+    import jax
+    import jax.numpy as jp
+
+    from warp.jax_experimental.ffi import jax_kernel
+
+    # Require explicit static_argnames for scalar s
+    jax_func = jax_kernel(scale_sum_square_kernel, num_outputs=1, enable_backward=True)
+
+    def loss(a, b, s):
+        out = jax_func(a, b, s)[0]
+        return jp.sum(out)
+
+    n = ARRAY_SIZE
+    a = jp.arange(n, dtype=jp.float32)
+    b = jp.ones(n, dtype=jp.float32)
+    s = 1.5
+
+    with jax.default_device(wp.device_to_jax(device)):
+        da, db = jax.grad(loss, argnums=(0, 1))(a, b, s)
+
+    wp.synchronize_device(device)
+
+    a_np = np.arange(n, dtype=np.float32)
+    b_np = np.ones(n, dtype=np.float32)
+    ref_da = 2.0 * (a_np * s + b_np) * s
+    ref_db = 2.0 * (a_np * s + b_np)
+
+    assert_np_equal(np.asarray(da), ref_da)
+    assert_np_equal(np.asarray(db), ref_db)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+def test_ffi_jax_kernel_autodiff_pmap_triple(test, device):
+    import jax
+    import jax.numpy as jp
+
+    from warp.jax_experimental.ffi import jax_kernel
+
+    jax_mul = jax_kernel(triple_kernel, num_outputs=1, enable_backward=True)
+
+    ndev = jax.local_device_count()
+    per_device = ARRAY_SIZE // ndev
+    x = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
+
+    def per_device_loss(x):
+        y = jax_mul(x)[0]
+        return jp.sum(y)
+
+    grads = jax.pmap(jax.grad(per_device_loss))(x)
+
+    wp.synchronize()
+
+    assert_np_equal(np.asarray(grads), np.full((ndev, per_device), 3.0, dtype=np.float32))
+
+
+@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+def test_ffi_jax_kernel_autodiff_pmap_multi_output(test, device):
+    import jax
+    import jax.numpy as jp
+
+    from warp.jax_experimental.ffi import jax_kernel
+
+    jax_mo = jax_kernel(multi_out_kernel_v2, num_outputs=2, enable_backward=True)
+
+    ndev = jax.local_device_count()
+    per_device = ARRAY_SIZE // ndev
+    a = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
+    b = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
+    s = 2.0
+
+    def per_dev_loss(aa, bb):
+        c, d = jax_mo(aa, bb, s)
+        return jp.sum(c + d)
+
+    da, db = jax.pmap(jax.grad(per_dev_loss, argnums=(0, 1)))(a, b)
+
+    wp.synchronize()
+
+    a_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device))
+    b_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device))
+    ref_da = 2.0 * a_np + b_np * s
+    ref_db = a_np * s
+    assert_np_equal(np.asarray(da), ref_da)
+    assert_np_equal(np.asarray(db), ref_db)


 class TestJax(unittest.TestCase):
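The other big block is FFI autodiff coverage: constructing the wrapper with `jax_kernel(..., enable_backward=True)` yields a JAX-differentiable callable, exercised here for single and multiple outputs, 2D arrays, `vec2`/`mat22` dtypes, `jit(grad(...))` composition, and `pmap`. Scalar kernel parameters like `s` are consistently marked static under `jit`. A condensed sketch of the core pattern from these tests:

    from functools import partial

    import jax
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    # scale_sum_square_kernel computes c = (a*s + b)^2 elementwise
    jax_func = jax_kernel(scale_sum_square_kernel, num_outputs=1, enable_backward=True)

    @partial(jax.jit, static_argnames=("s",))
    def loss(a, b, s):
        return jp.sum(jax_func(a, b, s)[0])

    da, db = jax.grad(loss, argnums=(0, 1))(a, b, 2.0)
    # analytically: da = 2*(a*s + b)*s, db = 2*(a*s + b)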
@@ -936,10 +1570,99 @@ try:
     add_function_test(
         TestJax, "test_ffi_jax_callable_in_out", test_ffi_jax_callable_in_out, devices=jax_compatible_cuda_devices
     )
+    add_function_test(
+        TestJax,
+        "test_ffi_jax_callable_graph_cache",
+        test_ffi_jax_callable_graph_cache,
+        devices=jax_compatible_cuda_devices,
+    )
+
+    # pmap tests
+    add_function_test(
+        TestJax,
+        "test_ffi_jax_callable_pmap_multi_output",
+        test_ffi_jax_callable_pmap_multi_output,
+        devices=None,
+    )
+    add_function_test(
+        TestJax,
+        "test_ffi_jax_callable_pmap_mul",
+        test_ffi_jax_callable_pmap_mul,
+        devices=None,
+    )
+    add_function_test(
+        TestJax,
+        "test_ffi_jax_callable_pmap_multi_stage",
+        test_ffi_jax_callable_pmap_multi_stage,
+        devices=None,
+    )

     # ffi callback tests
     add_function_test(TestJax, "test_ffi_callback", test_ffi_callback, devices=jax_compatible_cuda_devices)

+    # autodiff tests
+    add_function_test(
+        TestJax,
+        "test_ffi_jax_kernel_autodiff_simple",
+        test_ffi_jax_kernel_autodiff_simple,
+        devices=jax_compatible_cuda_devices,
+    )
+    add_function_test(
+        TestJax,
+        "test_ffi_jax_kernel_autodiff_jit_of_grad_simple",
+        test_ffi_jax_kernel_autodiff_jit_of_grad_simple,
+        devices=jax_compatible_cuda_devices,
+    )
+    add_function_test(
+        TestJax,
+        "test_ffi_jax_kernel_autodiff_multi_output",
+        test_ffi_jax_kernel_autodiff_multi_output,
+        devices=jax_compatible_cuda_devices,
+    )
+    add_function_test(
+        TestJax,
+        "test_ffi_jax_kernel_autodiff_jit_of_grad_multi_output",
+        test_ffi_jax_kernel_autodiff_jit_of_grad_multi_output,
+        devices=jax_compatible_cuda_devices,
+    )
+    add_function_test(
+        TestJax,
+        "test_ffi_jax_kernel_autodiff_2d",
+        test_ffi_jax_kernel_autodiff_2d,
+        devices=jax_compatible_cuda_devices,
+    )
+    add_function_test(
+        TestJax,
+        "test_ffi_jax_kernel_autodiff_vec2",
+        test_ffi_jax_kernel_autodiff_vec2,
+        devices=jax_compatible_cuda_devices,
+    )
+    add_function_test(
+        TestJax,
+        "test_ffi_jax_kernel_autodiff_mat22",
+        test_ffi_jax_kernel_autodiff_mat22,
+        devices=jax_compatible_cuda_devices,
+    )
+    add_function_test(
+        TestJax,
+        "test_ffi_jax_kernel_autodiff_static_required",
+        test_ffi_jax_kernel_autodiff_static_required,
+        devices=jax_compatible_cuda_devices,
+    )
+
+    # autodiff with pmap tests
+    add_function_test(
+        TestJax,
+        "test_ffi_jax_kernel_autodiff_pmap_triple",
+        test_ffi_jax_kernel_autodiff_pmap_triple,
+        devices=None,
+    )
+    add_function_test(
+        TestJax,
+        "test_ffi_jax_kernel_autodiff_pmap_multi_output",
+        test_ffi_jax_kernel_autodiff_pmap_multi_output,
+        devices=None,
+    )

 except Exception as e:
     print(f"Skipping Jax tests due to exception: {e}")