warp-lang 1.9.1__py3-none-manylinux_2_34_aarch64.whl → 1.10.0rc2__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +301 -287
- warp/__init__.pyi +794 -305
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1075 -0
- warp/_src/build.py +618 -0
- warp/_src/build_dll.py +640 -0
- warp/{builtins.py → _src/builtins.py} +1382 -377
- warp/_src/codegen.py +4359 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +57 -0
- warp/_src/context.py +8294 -0
- warp/_src/dlpack.py +462 -0
- warp/_src/fabric.py +355 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +508 -0
- warp/_src/fem/cache.py +687 -0
- warp/_src/fem/dirichlet.py +188 -0
- warp/{fem → _src/fem}/domain.py +40 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +701 -0
- warp/{fem → _src/fem}/field/nodal_field.py +30 -15
- warp/{fem → _src/fem}/field/restriction.py +1 -1
- warp/{fem → _src/fem}/field/virtual.py +53 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
- warp/_src/fem/geometry/closest_point.py +97 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
- warp/{fem → _src/fem}/geometry/element.py +32 -10
- warp/{fem → _src/fem}/geometry/geometry.py +48 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
- warp/{fem → _src/fem}/geometry/partition.py +121 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
- warp/{fem → _src/fem}/integrate.py +164 -158
- warp/_src/fem/linalg.py +383 -0
- warp/_src/fem/operator.py +396 -0
- warp/_src/fem/polynomial.py +229 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
- warp/_src/fem/space/basis_space.py +679 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
- warp/{fem → _src/fem}/space/function_space.py +14 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
- warp/{fem → _src/fem}/space/partition.py +117 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/restriction.py +66 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
- warp/_src/fem/space/topology.py +459 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
- warp/_src/fem/types.py +112 -0
- warp/_src/fem/utils.py +486 -0
- warp/_src/jax.py +186 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +387 -0
- warp/_src/jax_experimental/ffi.py +1284 -0
- warp/_src/jax_experimental/xla_ffi.py +656 -0
- warp/_src/marching_cubes.py +708 -0
- warp/_src/math.py +414 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +163 -0
- warp/_src/optim/linear.py +1606 -0
- warp/_src/optim/sgd.py +112 -0
- warp/_src/paddle.py +406 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +289 -0
- warp/_src/render/render_opengl.py +3636 -0
- warp/_src/render/render_usd.py +937 -0
- warp/_src/render/utils.py +160 -0
- warp/_src/sparse.py +2716 -0
- warp/_src/tape.py +1206 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +391 -0
- warp/_src/types.py +5870 -0
- warp/_src/utils.py +1693 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -721
- warp/codegen.py +6 -4251
- warp/constants.py +6 -39
- warp/context.py +12 -8062
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +1 -1
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -365
- warp/jax_experimental/ffi.py +17 -873
- warp/jax_experimental/xla_ffi.py +5 -605
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +314 -37
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +681 -89
- warp/native/tile_radix_sort.h +1 -1
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +57 -29
- warp/native/warp.cu +253 -171
- warp/native/warp.h +11 -8
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3618
- warp/render/render_usd.py +6 -919
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +14 -14
- warp/tests/interop/test_jax.py +772 -49
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +527 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +33 -14
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +56 -10
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +35 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_types.py +27 -15
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +178 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5764
- warp/utils.py +10 -1659
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
- warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.1.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1284 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import collections
|
|
17
|
+
import ctypes
|
|
18
|
+
import inspect
|
|
19
|
+
import threading
|
|
20
|
+
import traceback
|
|
21
|
+
from enum import IntEnum
|
|
22
|
+
from typing import Callable, Optional
|
|
23
|
+
|
|
24
|
+
import jax
|
|
25
|
+
|
|
26
|
+
import warp as wp
|
|
27
|
+
from warp._src.codegen import get_full_arg_spec, make_full_qualified_name
|
|
28
|
+
from warp._src.jax import get_jax_device
|
|
29
|
+
from warp._src.types import array_t, launch_bounds_t, strides_from_shape, type_to_warp
|
|
30
|
+
|
|
31
|
+
from .xla_ffi import *
|
|
32
|
+
|
|
33
|
+
# Type alias for differentiable kernel cache key.
# NOTE(review): the meaning of the individual tuple fields is not visible in
# this part of the file -- confirm against the code that builds the key.
DiffKernelCacheKey = tuple[Callable, tuple, int, str, tuple[str, ...]]

# Holders for the custom callbacks to keep them alive.
# Module-level dictionaries, so entries live for the lifetime of the module
# unless they are explicitly removed.
_FFI_KERNEL_REGISTRY: dict[str, "FfiKernel"] = {}
_FFI_DIFF_KERNEL_REGISTRY: dict[DiffKernelCacheKey, Callable] = {}
_FFI_CALLABLE_REGISTRY: dict[str, "FfiCallable"] = {}
_FFI_CALLBACK_REGISTRY: dict[str, ctypes.CFUNCTYPE] = {}
# presumably guards mutation of the registries above -- TODO confirm at the use sites
_FFI_REGISTRY_LOCK = threading.Lock()

# Lock when XLA invokes callbacks from multiple threads.
# Held by FfiKernel.ffi_callback() while it decodes the call frame and
# launches the kernel.
_FFI_CALLBACK_LOCK = threading.Lock()
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def check_jax_version():
    """Raise if the installed JAX is too old for this FFI-based implementation.

    Returns ``None`` when ``jax.__version_info__`` is at least ``(0, 5, 0)``.

    Raises:
        RuntimeError: If the installed JAX version is older than 0.5.0. When
            the installed version is at least 0.4.25, the message additionally
            points to the ``custom_call`` based ``jax_kernel``.
    """
    installed = jax.__version_info__

    # Supported version: nothing to report.
    if not installed < (0, 5, 0):
        return

    fragments = [
        "This version of jax_kernel() requires JAX version 0.5.0 or higher, ",
        f"but installed JAX version is {installed}.",
    ]
    if installed >= (0, 4, 25):
        fragments.append(" Please use warp.jax_experimental.custom_call.jax_kernel instead.")

    raise RuntimeError("".join(fragments))
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class GraphMode(IntEnum):
    """Selects who, if anyone, captures a graph for a call."""

    # don't capture a graph
    NONE = 0
    # let JAX capture a graph
    JAX = 1
    # let Warp capture a graph
    WARP = 2
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class ModulePreloadMode(IntEnum):
    """Selects the devices on which a kernel's module is loaded ahead of time."""

    # don't preload modules
    NONE = 0
    # preload on currently active device
    CURRENT_DEVICE = 1
    # preload on all supported devices
    ALL_DEVICES = 2
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class FfiArg:
    """Describes one kernel argument as seen from both the Warp and JAX sides.

    Attributes set by the constructor:
        name, type, in_out: The constructor arguments, stored as given.
        is_array: ``True`` when ``type`` is an instance of ``wp.array``.
        dtype_shape / dtype_ndim: Shape and rank of the array element type
            (``()`` and ``0`` for scalar element types and scalar arguments).
        jax_scalar_type: Result of ``wp.dtype_to_jax`` for the scalar type.
        jax_ndim: ``type.ndim + dtype_ndim`` for arrays, ``0`` for scalars.
        warp_ndim: ``type.ndim`` for arrays, ``0`` for scalars.

    Raises:
        TypeError: If an array's element type is neither a type carrying
            ``_wp_scalar_type_`` nor one of the Warp value types, or if a
            non-array ``type`` is not one of the Warp value types.
    """

    def __init__(self, name, type, in_out=False):
        self.name = name
        self.type = type
        self.in_out = in_out
        self.is_array = isinstance(type, wp.array)

        # Scalar arguments are handled first so the array case reads flat.
        if not self.is_array:
            if type not in wp._src.types.value_types:
                raise TypeError(f"Invalid type for argument '{name}', expected array or scalar, got {type}")
            self.dtype_ndim = 0
            self.dtype_shape = ()
            self.jax_scalar_type = wp.dtype_to_jax(type_to_warp(type))
            self.jax_ndim = 0
            self.warp_ndim = 0
            return

        element_type = type.dtype
        if hasattr(element_type, "_wp_scalar_type_"):
            # element type exposes a shape and an underlying scalar type
            inner_shape = element_type._shape_
            self.dtype_shape = inner_shape
            self.dtype_ndim = len(inner_shape)
            self.jax_scalar_type = wp.dtype_to_jax(element_type._wp_scalar_type_)
            self.jax_ndim = type.ndim + self.dtype_ndim
        elif element_type in wp._src.types.value_types:
            # plain scalar element type: no inner dimensions
            self.dtype_ndim = 0
            self.dtype_shape = ()
            self.jax_scalar_type = wp.dtype_to_jax(element_type)
            self.jax_ndim = type.ndim
        else:
            raise TypeError(f"Invalid data type for array argument '{name}', expected scalar, vector, or matrix")

        self.warp_ndim = type.ndim
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class FfiLaunchDesc:
    """Launch data saved at call time and looked up again by the FFI callback.

    Attributes:
        static_inputs: Stored as given; FfiKernel passes a dict of scalar
            argument values keyed by argument name.
        launch_dims: Stored as given; FfiKernel passes the launch dimensions.
    """

    def __init__(self, static_inputs, launch_dims):
        self.static_inputs, self.launch_dims = static_inputs, launch_dims
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class FfiKernel:
    """Wraps a Warp kernel as a JAX FFI target registered for the CUDA platform.

    The constructor validates the argument layout (inputs first, then in-out
    arguments, then output-only arguments), builds the input/output alias map
    and registers a ctypes callback with JAX under a unique name.
    Calling the instance builds a ``jax.ffi.ffi_call`` and records the launch
    data that the callback later retrieves through the ``launch_id`` attribute.
    """

    def __init__(
        self, kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames, module_preload_mode
    ):
        """Validate the kernel signature and register the FFI callback.

        Args:
            kernel: Kernel object; ``kernel.func``, ``kernel.adj.args`` and
                ``kernel.module`` are used.
            num_outputs: Number of output arguments, including in-out ones.
            vmap_method: Default ``vmap_method`` forwarded to ``jax.ffi.ffi_call``.
            launch_dims: Default launch dimensions (may be ``None``).
            output_dims: Default output dimensions (may be ``None``).
            in_out_argnames: Names of arguments that are both input and output
                (may be ``None`` or empty).
            module_preload_mode: A ``ModulePreloadMode`` value checked in ``__call__``.

        Raises:
            AssertionError: If ``in_out_argnames`` has duplicates, or an in-out
                name appears among the output-only arguments.
            ValueError: If the output count is inconsistent with the kernel
                arguments, or an in-out name matches no argument.
            TypeError: If an output argument is not an array.
        """
        self.kernel = kernel
        self.name = generate_unique_name(kernel.func)
        self.num_outputs = num_outputs
        self.vmap_method = vmap_method
        self.launch_dims = launch_dims
        self.output_dims = output_dims
        self.module_preload_mode = module_preload_mode
        self.first_array_arg = None
        self.launch_id = 0
        self.launch_descriptors = {}

        in_out_argnames_list = in_out_argnames or []
        in_out_argnames = set(in_out_argnames_list)
        if len(in_out_argnames_list) != len(in_out_argnames):
            raise AssertionError("in_out_argnames must not contain duplicate names")

        self.num_kernel_args = len(kernel.adj.args)
        self.num_in_out = len(in_out_argnames)
        # in-out arguments count both as inputs and as outputs
        self.num_inputs = self.num_kernel_args - num_outputs + self.num_in_out
        if self.num_outputs < 1:
            raise ValueError("At least one output is required")
        if self.num_outputs > self.num_kernel_args:
            raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
        if self.num_outputs < self.num_in_out:
            raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")

        # process input args
        self.input_args = []
        for i in range(self.num_inputs):
            arg_name = kernel.adj.args[i].label
            arg = FfiArg(arg_name, kernel.adj.args[i].type, arg_name in in_out_argnames)
            if arg_name in in_out_argnames:
                # consume the name so leftovers can be reported below
                in_out_argnames.remove(arg_name)
            if arg.is_array:
                # keep track of the first input array argument
                if self.first_array_arg is None:
                    self.first_array_arg = i
            self.input_args.append(arg)

        # process output args
        self.output_args = []
        for i in range(self.num_inputs, self.num_kernel_args):
            arg_name = kernel.adj.args[i].label
            if arg_name in in_out_argnames:
                raise AssertionError(
                    f"Expected an output-only argument for argument {arg_name}."
                    " in_out arguments should be placed before output-only arguments."
                )
            arg = FfiArg(arg_name, kernel.adj.args[i].type, False)
            if not arg.is_array:
                raise TypeError("All output arguments must be arrays")
            self.output_args.append(arg)

        if in_out_argnames:
            raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")

        # Build input output aliases.
        # The k-th in-out input (in argument order) aliases output number k.
        out_id = 0
        input_output_aliases = {}
        for in_id, arg in enumerate(self.input_args):
            if not arg.in_out:
                continue
            input_output_aliases[in_id] = out_id
            out_id += 1
        self.input_output_aliases = input_output_aliases

        # register the callback
        FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
        # keep the ctypes function object on self so it stays referenced
        self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
        ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p)
        ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
        jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA")

    def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None):
        """Build and invoke the FFI call for the given inputs.

        Args:
            *args: One value per input argument (including in-out arguments).
            output_dims: Output shape(s); an int, a sequence, or a dict keyed
                by output argument name. Falls back to the constructor value,
                then to the launch dimensions.
            launch_dims: Launch dimensions; an int or a sequence. Falls back to
                the constructor value, then to the shape of the first array input.
            vmap_method: Falls back to the constructor value.

        Returns:
            The result of the ``jax.ffi.ffi_call`` callable applied to ``args``.

        Raises:
            ValueError: On wrong input count, a traced scalar argument, or a
                missing entry in an ``output_dims`` dict.
            TypeError: On dtype, rank or inner-dimension mismatch of an array input.
            RuntimeError: If launch dimensions cannot be determined.
        """
        num_inputs = len(args)
        if num_inputs != self.num_inputs:
            raise ValueError(f"Expected {self.num_inputs} inputs, but got {num_inputs}")

        # default argument fallback
        if launch_dims is None:
            launch_dims = self.launch_dims
        if output_dims is None:
            output_dims = self.output_dims
        if vmap_method is None:
            vmap_method = self.vmap_method

        # output types
        out_types = []

        # process inputs (lengths were checked above, so zip drops nothing)
        static_inputs = {}
        for input_arg, input_value in zip(self.input_args, args):
            if input_arg.is_array:
                # check dtype
                if input_value.dtype != input_arg.jax_scalar_type:
                    raise TypeError(
                        f"Invalid data type for array argument '{input_arg.name}', expected {input_arg.jax_scalar_type}, got {input_value.dtype}"
                    )
                # check ndim
                if input_value.ndim != input_arg.jax_ndim:
                    raise TypeError(
                        f"Invalid dimensionality for array argument '{input_arg.name}', expected {input_arg.jax_ndim} dimensions, got {input_value.ndim}"
                    )
                # check inner dims
                for d in range(input_arg.dtype_ndim):
                    if input_value.shape[input_arg.type.ndim + d] != input_arg.dtype_shape[d]:
                        raise TypeError(
                            f"Invalid inner dimensions for array argument '{input_arg.name}', expected {input_arg.dtype_shape}, got {input_value.shape[-input_arg.dtype_ndim :]}"
                        )
            else:
                # make sure scalar is not a traced variable, should be static
                if isinstance(input_value, jax.core.Tracer):
                    raise ValueError(f"Argument '{input_arg.name}' must be a static value")
                # stash the value to be retrieved by callback
                static_inputs[input_arg.name] = input_arg.type(input_value)

            # append in-out arg to output types
            if input_arg.in_out:
                out_types.append(get_jax_output_type(input_arg, input_value.shape))

        # launch dimensions
        if launch_dims is None:
            # use the shape of the first input array
            if self.first_array_arg is not None:
                launch_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
            else:
                raise RuntimeError("Failed to determine launch dimensions")
        elif isinstance(launch_dims, int):
            launch_dims = (launch_dims,)
        else:
            launch_dims = tuple(launch_dims)

        # output shapes
        if isinstance(output_dims, dict):
            # assume a dictionary of shapes keyed on argument name
            for output_arg in self.output_args:
                dims = output_dims.get(output_arg.name)
                if dims is None:
                    raise ValueError(f"Missing output dimensions for argument '{output_arg.name}'")
                out_types.append(get_jax_output_type(output_arg, dims))
        else:
            if output_dims is None:
                # use launch dimensions
                output_dims = launch_dims
            elif isinstance(output_dims, int):
                output_dims = (output_dims,)
            # assume same dimensions for all outputs
            for output_arg in self.output_args:
                out_types.append(get_jax_output_type(output_arg, output_dims))

        call = jax.ffi.ffi_call(
            self.name,
            out_types,
            vmap_method=vmap_method,
            input_output_aliases=self.input_output_aliases,
        )

        # preload on the specified devices
        if self.module_preload_mode == ModulePreloadMode.CURRENT_DEVICE:
            device = wp.device_from_jax(get_jax_device())
            self.kernel.module.load(device)
        elif self.module_preload_mode == ModulePreloadMode.ALL_DEVICES:
            for d in jax.local_devices():
                try:
                    dev = wp.device_from_jax(d)
                except Exception:
                    # ignore unsupported devices like TPUs; skip to the next
                    # device so that `dev` is never read while unbound or
                    # while still holding the previous iteration's device
                    continue
                # we only support CUDA devices for now
                if dev.is_cuda:
                    self.kernel.module.load(dev)

        # save launch data to be retrieved by callback
        launch_id = self.launch_id
        self.launch_descriptors[launch_id] = FfiLaunchDesc(static_inputs, launch_dims)
        self.launch_id += 1

        return call(*args, launch_id=launch_id)

    def ffi_callback(self, call_frame):
        """Callback invoked through the registered ctypes function.

        Answers the metadata-extension query when present; otherwise decodes
        the call frame, assembles the kernel parameter array and launches the
        kernel's forward entry point on the stream taken from the call frame.

        Args:
            call_frame: Pointer to an ``XLA_FFI_CallFrame``.

        Returns:
            ``None`` on success, or the value of ``create_ffi_error(...)`` if
            any exception was raised while handling the call.
        """
        try:
            # On the first call, XLA runtime will query the API version and traits
            # metadata using the |extension| field. Let us respond to that query
            # if the metadata extension is present.
            extension = call_frame.contents.extension_start
            if extension:
                # Try to set the version metadata.
                if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
                    metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
                    metadata_ext.contents.metadata.contents.api_version.major_version = 0
                    metadata_ext.contents.metadata.contents.api_version.minor_version = 1
                    # Turn on CUDA graphs for this handler.
                    metadata_ext.contents.metadata.contents.traits = (
                        XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
                    )
                    return None

            # Lock is required to prevent race conditions when callback is invoked
            # from multiple threads, like with pmap.
            with _FFI_CALLBACK_LOCK:
                # retrieve call info
                attrs = decode_attrs(call_frame.contents.attrs)
                launch_id = int(attrs["launch_id"])
                launch_desc = self.launch_descriptors[launch_id]

                num_inputs = call_frame.contents.args.size
                inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))

                num_outputs = call_frame.contents.rets.size
                outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))

                assert num_inputs == self.num_inputs
                assert num_outputs == self.num_outputs

                launch_bounds = launch_bounds_t(launch_desc.launch_dims)

                # first kernel param is the launch bounds
                kernel_params = (ctypes.c_void_p * (1 + self.num_kernel_args))()
                kernel_params[0] = ctypes.addressof(launch_bounds)

                # kernel_params only stores raw addresses, so the ctypes
                # objects themselves are kept referenced in this list
                arg_refs = []

                # input and in-out args
                for i, input_arg in enumerate(self.input_args):
                    if input_arg.is_array:
                        buffer = inputs[i].contents
                        # only the leading (Warp array) dimensions are used
                        shape = buffer.dims[: input_arg.type.ndim]
                        strides = strides_from_shape(shape, input_arg.type.dtype)
                        arg = array_t(buffer.data, 0, input_arg.type.ndim, shape, strides)
                        kernel_params[i + 1] = ctypes.addressof(arg)
                        arg_refs.append(arg)  # keep a reference
                    else:
                        # scalar argument, get stashed value
                        value = launch_desc.static_inputs[input_arg.name]
                        arg = input_arg.type._type_(value)
                        kernel_params[i + 1] = ctypes.addressof(arg)
                        arg_refs.append(arg)  # keep a reference

                # pure output args (skip in-out FFI buffers)
                for i, output_arg in enumerate(self.output_args):
                    buffer = outputs[i + self.num_in_out].contents
                    shape = buffer.dims[: output_arg.type.ndim]
                    strides = strides_from_shape(shape, output_arg.type.dtype)
                    arg = array_t(buffer.data, 0, output_arg.type.ndim, shape, strides)
                    kernel_params[num_inputs + i + 1] = ctypes.addressof(arg)
                    arg_refs.append(arg)  # keep a reference

                # get device and stream
                device = wp.get_cuda_device(get_device_ordinal_from_callframe(call_frame.contents))
                stream = get_stream_from_callframe(call_frame.contents)

                # get kernel hooks
                hooks = self.kernel.module.get_kernel_hooks(self.kernel, device)
                assert hooks.forward, "Failed to find kernel entry point"

                # launch the kernel
                wp._src.context.runtime.core.wp_cuda_launch_kernel(
                    device.context,
                    hooks.forward,
                    launch_bounds.size,
                    0,
                    256,
                    hooks.forward_smem_bytes,
                    kernel_params,
                    stream,
                )

        except Exception as e:
            # report the failure both on stdout and through the FFI error object
            print(traceback.format_exc())
            return create_ffi_error(
                call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
            )
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
class FfiCallDesc:
    """Per-call data stashed at trace time and looked up again by the FFI callback.

    Attributes are plain instance attributes so that the callback can attach
    additional state to the descriptor.
    """

    def __init__(self, static_inputs):
        """Store the static (non-array) inputs of one call.

        Args:
            static_inputs: Mapping from argument name to the static value that
                was passed for that argument.
        """
        self.static_inputs = static_inputs
        # Set by the callback when the call is captured into a graph; declared
        # here so the attribute always exists on a descriptor.
        self.capture = None
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
class FfiCallable:
|
|
394
|
+
default_graph_cache_max: int | None = 32
|
|
395
|
+
|
|
396
|
+
def __init__(
|
|
397
|
+
self,
|
|
398
|
+
func,
|
|
399
|
+
num_outputs,
|
|
400
|
+
graph_mode,
|
|
401
|
+
vmap_method,
|
|
402
|
+
output_dims,
|
|
403
|
+
in_out_argnames,
|
|
404
|
+
graph_cache_max,
|
|
405
|
+
module_preload_mode,
|
|
406
|
+
):
|
|
407
|
+
self.func = func
|
|
408
|
+
self.name = generate_unique_name(func)
|
|
409
|
+
self.num_outputs = num_outputs
|
|
410
|
+
self.vmap_method = vmap_method
|
|
411
|
+
self.graph_mode = graph_mode
|
|
412
|
+
self.output_dims = output_dims
|
|
413
|
+
self.module_preload_mode = module_preload_mode
|
|
414
|
+
self.first_array_arg = None
|
|
415
|
+
self.call_id = 0
|
|
416
|
+
self.call_descriptors = {}
|
|
417
|
+
|
|
418
|
+
# LRU cache of graphs captured by Warp
|
|
419
|
+
self._graph_cache_max = graph_cache_max
|
|
420
|
+
self.captures = collections.OrderedDict()
|
|
421
|
+
|
|
422
|
+
in_out_argnames_list = in_out_argnames or []
|
|
423
|
+
in_out_argnames = set(in_out_argnames_list)
|
|
424
|
+
if len(in_out_argnames_list) != len(in_out_argnames):
|
|
425
|
+
raise AssertionError("in_out_argnames must not contain duplicate names")
|
|
426
|
+
|
|
427
|
+
# get arguments and annotations
|
|
428
|
+
argspec = get_full_arg_spec(func)
|
|
429
|
+
|
|
430
|
+
num_args = len(argspec.args)
|
|
431
|
+
self.num_in_out = len(in_out_argnames)
|
|
432
|
+
self.num_inputs = num_args - num_outputs + self.num_in_out
|
|
433
|
+
if self.num_outputs < 1:
|
|
434
|
+
raise ValueError("At least one output is required")
|
|
435
|
+
if self.num_outputs > num_args:
|
|
436
|
+
raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
|
|
437
|
+
if self.num_outputs < self.num_in_out:
|
|
438
|
+
raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")
|
|
439
|
+
|
|
440
|
+
if len(argspec.annotations) < num_args:
|
|
441
|
+
raise RuntimeError(f"Incomplete argument annotations on function {self.name}")
|
|
442
|
+
|
|
443
|
+
# parse type annotations
|
|
444
|
+
self.args = []
|
|
445
|
+
arg_idx = 0
|
|
446
|
+
for arg_name, arg_type in argspec.annotations.items():
|
|
447
|
+
if arg_name == "return":
|
|
448
|
+
if arg_type is not None:
|
|
449
|
+
raise TypeError("Function must not return a value")
|
|
450
|
+
continue
|
|
451
|
+
else:
|
|
452
|
+
arg = FfiArg(arg_name, arg_type, arg_name in in_out_argnames)
|
|
453
|
+
if arg_name in in_out_argnames:
|
|
454
|
+
in_out_argnames.remove(arg_name)
|
|
455
|
+
if arg.is_array:
|
|
456
|
+
if arg_idx < self.num_inputs and self.first_array_arg is None:
|
|
457
|
+
self.first_array_arg = arg_idx
|
|
458
|
+
self.args.append(arg)
|
|
459
|
+
|
|
460
|
+
if arg.in_out and arg_idx >= self.num_inputs:
|
|
461
|
+
raise AssertionError(
|
|
462
|
+
f"Expected an output-only argument for argument {arg_name}."
|
|
463
|
+
" in_out arguments should be placed before output-only arguments."
|
|
464
|
+
)
|
|
465
|
+
|
|
466
|
+
arg_idx += 1
|
|
467
|
+
|
|
468
|
+
if in_out_argnames:
|
|
469
|
+
raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
|
|
470
|
+
|
|
471
|
+
self.input_args = self.args[: self.num_inputs] # includes in-out args
|
|
472
|
+
self.output_args = self.args[self.num_inputs :] # pure output args
|
|
473
|
+
|
|
474
|
+
# Buffer indices for array arguments in callback.
|
|
475
|
+
# In-out buffers are the same pointers in the XLA call frame,
|
|
476
|
+
# so we only include them for inputs and skip them for outputs.
|
|
477
|
+
self.array_input_indices = [i for i, arg in enumerate(self.input_args) if arg.is_array]
|
|
478
|
+
self.array_output_indices = list(range(self.num_in_out, self.num_outputs))
|
|
479
|
+
|
|
480
|
+
# Build input output aliases.
|
|
481
|
+
out_id = 0
|
|
482
|
+
input_output_aliases = {}
|
|
483
|
+
for in_id, arg in enumerate(self.input_args):
|
|
484
|
+
if not arg.in_out:
|
|
485
|
+
continue
|
|
486
|
+
input_output_aliases[in_id] = out_id
|
|
487
|
+
out_id += 1
|
|
488
|
+
self.input_output_aliases = input_output_aliases
|
|
489
|
+
|
|
490
|
+
# register the callback
|
|
491
|
+
FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
|
|
492
|
+
self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
|
|
493
|
+
ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p)
|
|
494
|
+
ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
|
|
495
|
+
jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA")
|
|
496
|
+
|
|
497
|
+
def __call__(self, *args, output_dims=None, vmap_method=None):
|
|
498
|
+
num_inputs = len(args)
|
|
499
|
+
if num_inputs != self.num_inputs:
|
|
500
|
+
input_names = ", ".join(arg.name for arg in self.input_args)
|
|
501
|
+
s = "" if self.num_inputs == 1 else "s"
|
|
502
|
+
raise ValueError(f"Expected {self.num_inputs} input{s} ({input_names}), but got {num_inputs}")
|
|
503
|
+
|
|
504
|
+
# default argument fallback
|
|
505
|
+
if vmap_method is None:
|
|
506
|
+
vmap_method = self.vmap_method
|
|
507
|
+
if output_dims is None:
|
|
508
|
+
output_dims = self.output_dims
|
|
509
|
+
|
|
510
|
+
# output types
|
|
511
|
+
out_types = []
|
|
512
|
+
|
|
513
|
+
# process inputs
|
|
514
|
+
static_inputs = {}
|
|
515
|
+
for i in range(num_inputs):
|
|
516
|
+
input_arg = self.input_args[i]
|
|
517
|
+
input_value = args[i]
|
|
518
|
+
if input_arg.is_array:
|
|
519
|
+
# check dtype
|
|
520
|
+
if input_value.dtype != input_arg.jax_scalar_type:
|
|
521
|
+
raise TypeError(
|
|
522
|
+
f"Invalid data type for array argument '{input_arg.name}', expected {input_arg.jax_scalar_type}, got {input_value.dtype}"
|
|
523
|
+
)
|
|
524
|
+
# check ndim
|
|
525
|
+
if input_value.ndim != input_arg.jax_ndim:
|
|
526
|
+
raise TypeError(
|
|
527
|
+
f"Invalid dimensionality for array argument '{input_arg.name}', expected {input_arg.jax_ndim} dimensions, got {input_value.ndim}"
|
|
528
|
+
)
|
|
529
|
+
# check inner dims
|
|
530
|
+
for d in range(input_arg.dtype_ndim):
|
|
531
|
+
if input_value.shape[input_arg.type.ndim + d] != input_arg.dtype_shape[d]:
|
|
532
|
+
raise TypeError(
|
|
533
|
+
f"Invalid inner dimensions for array argument '{input_arg.name}', expected {input_arg.dtype_shape}, got {input_value.shape[-input_arg.dtype_ndim :]}"
|
|
534
|
+
)
|
|
535
|
+
else:
|
|
536
|
+
# make sure scalar is not a traced variable, should be static
|
|
537
|
+
if isinstance(input_value, jax.core.Tracer):
|
|
538
|
+
raise ValueError(f"Argument '{input_arg.name}' must be a static value")
|
|
539
|
+
# stash the value to be retrieved by callback
|
|
540
|
+
static_inputs[input_arg.name] = input_arg.type(input_value)
|
|
541
|
+
|
|
542
|
+
# append in-out arg to output types
|
|
543
|
+
if input_arg.in_out:
|
|
544
|
+
out_types.append(get_jax_output_type(input_arg, input_value.shape))
|
|
545
|
+
|
|
546
|
+
# output shapes
|
|
547
|
+
if isinstance(output_dims, dict):
|
|
548
|
+
# assume a dictionary of shapes keyed on argument name
|
|
549
|
+
for output_arg in self.output_args:
|
|
550
|
+
dims = output_dims.get(output_arg.name)
|
|
551
|
+
if dims is None:
|
|
552
|
+
raise ValueError(f"Missing output dimensions for argument '{output_arg.name}'")
|
|
553
|
+
out_types.append(get_jax_output_type(output_arg, dims))
|
|
554
|
+
else:
|
|
555
|
+
if output_dims is None:
|
|
556
|
+
if self.first_array_arg is None:
|
|
557
|
+
raise ValueError("Unable to determine output dimensions")
|
|
558
|
+
output_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
|
|
559
|
+
elif isinstance(output_dims, int):
|
|
560
|
+
output_dims = (output_dims,)
|
|
561
|
+
# assume same dimensions for all outputs
|
|
562
|
+
for output_arg in self.output_args:
|
|
563
|
+
out_types.append(get_jax_output_type(output_arg, output_dims))
|
|
564
|
+
|
|
565
|
+
call = jax.ffi.ffi_call(
|
|
566
|
+
self.name,
|
|
567
|
+
out_types,
|
|
568
|
+
vmap_method=vmap_method,
|
|
569
|
+
input_output_aliases=self.input_output_aliases,
|
|
570
|
+
# has_side_effect=True, # force this function to execute even if outputs aren't used
|
|
571
|
+
)
|
|
572
|
+
|
|
573
|
+
# preload on the specified devices
|
|
574
|
+
# NOTE: if the target function uses kernels from different modules, they will not be loaded here
|
|
575
|
+
module = wp.get_module(self.func.__module__)
|
|
576
|
+
if self.module_preload_mode == ModulePreloadMode.CURRENT_DEVICE:
|
|
577
|
+
device = wp.device_from_jax(get_jax_device())
|
|
578
|
+
module.load(device)
|
|
579
|
+
elif self.module_preload_mode == ModulePreloadMode.ALL_DEVICES:
|
|
580
|
+
for d in jax.local_devices():
|
|
581
|
+
try:
|
|
582
|
+
dev = wp.device_from_jax(d)
|
|
583
|
+
except Exception:
|
|
584
|
+
# ignore unsupported devices like TPUs
|
|
585
|
+
pass
|
|
586
|
+
# we only support CUDA devices for now
|
|
587
|
+
if dev.is_cuda:
|
|
588
|
+
module.load(dev)
|
|
589
|
+
|
|
590
|
+
# save call data to be retrieved by callback
|
|
591
|
+
call_id = self.call_id
|
|
592
|
+
self.call_descriptors[call_id] = FfiCallDesc(static_inputs)
|
|
593
|
+
self.call_id += 1
|
|
594
|
+
return call(*args, call_id=call_id)
|
|
595
|
+
|
|
596
|
+
def ffi_callback(self, call_frame):
|
|
597
|
+
try:
|
|
598
|
+
# On the first call, XLA runtime will query the API version and traits
|
|
599
|
+
# metadata using the |extension| field. Let us respond to that query
|
|
600
|
+
# if the metadata extension is present.
|
|
601
|
+
extension = call_frame.contents.extension_start
|
|
602
|
+
if extension:
|
|
603
|
+
# Try to set the version metadata.
|
|
604
|
+
if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
|
|
605
|
+
metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
|
|
606
|
+
metadata_ext.contents.metadata.contents.api_version.major_version = 0
|
|
607
|
+
metadata_ext.contents.metadata.contents.api_version.minor_version = 1
|
|
608
|
+
# Turn on CUDA graphs for this handler.
|
|
609
|
+
if self.graph_mode is GraphMode.JAX:
|
|
610
|
+
metadata_ext.contents.metadata.contents.traits = (
|
|
611
|
+
XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
|
|
612
|
+
)
|
|
613
|
+
return None
|
|
614
|
+
|
|
615
|
+
# Lock is required to prevent race conditions when callback is invoked
|
|
616
|
+
# from multiple threads, like with pmap.
|
|
617
|
+
with _FFI_CALLBACK_LOCK:
|
|
618
|
+
# retrieve call info
|
|
619
|
+
# NOTE: this assumes that there's only one attribute - call_id (int64).
|
|
620
|
+
# A more general but slower approach is this:
|
|
621
|
+
# attrs = decode_attrs(call_frame.contents.attrs)
|
|
622
|
+
# call_id = int(attrs["call_id"])
|
|
623
|
+
attr = ctypes.cast(call_frame.contents.attrs.attrs[0], ctypes.POINTER(XLA_FFI_Scalar)).contents
|
|
624
|
+
call_id = ctypes.cast(attr.value, ctypes.POINTER(ctypes.c_int64)).contents.value
|
|
625
|
+
call_desc = self.call_descriptors[call_id]
|
|
626
|
+
|
|
627
|
+
num_inputs = call_frame.contents.args.size
|
|
628
|
+
inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
|
|
629
|
+
|
|
630
|
+
num_outputs = call_frame.contents.rets.size
|
|
631
|
+
outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
|
|
632
|
+
|
|
633
|
+
assert num_inputs == self.num_inputs
|
|
634
|
+
assert num_outputs == self.num_outputs
|
|
635
|
+
|
|
636
|
+
cuda_stream = get_stream_from_callframe(call_frame.contents)
|
|
637
|
+
|
|
638
|
+
if self.graph_mode == GraphMode.WARP:
|
|
639
|
+
# check if we already captured an identical call
|
|
640
|
+
ip = [inputs[i].contents.data for i in self.array_input_indices]
|
|
641
|
+
op = [outputs[i].contents.data for i in self.array_output_indices]
|
|
642
|
+
capture_key = hash((call_id, *ip, *op))
|
|
643
|
+
capture = self.captures.get(capture_key)
|
|
644
|
+
|
|
645
|
+
# launch existing graph
|
|
646
|
+
if capture is not None:
|
|
647
|
+
# NOTE: We use the native graph API to avoid overhead with obtaining Stream and Device objects in Python.
|
|
648
|
+
# This code should match wp.capture_launch().
|
|
649
|
+
graph = capture.graph
|
|
650
|
+
if graph.graph_exec is None:
|
|
651
|
+
g = ctypes.c_void_p()
|
|
652
|
+
if not wp._src.context.runtime.core.wp_cuda_graph_create_exec(
|
|
653
|
+
graph.device.context, cuda_stream, graph.graph, ctypes.byref(g)
|
|
654
|
+
):
|
|
655
|
+
raise RuntimeError(f"Graph creation error: {wp.context.runtime.get_error_string()}")
|
|
656
|
+
graph.graph_exec = g
|
|
657
|
+
|
|
658
|
+
if not wp._src.context.runtime.core.wp_cuda_graph_launch(graph.graph_exec, cuda_stream):
|
|
659
|
+
raise RuntimeError(f"Graph launch error: {wp.context.runtime.get_error_string()}")
|
|
660
|
+
|
|
661
|
+
# update the graph cache to keep recently used graphs alive
|
|
662
|
+
self.captures.move_to_end(capture_key)
|
|
663
|
+
|
|
664
|
+
# early out
|
|
665
|
+
return
|
|
666
|
+
|
|
667
|
+
device_ordinal = get_device_ordinal_from_callframe(call_frame.contents)
|
|
668
|
+
device = wp.get_cuda_device(device_ordinal)
|
|
669
|
+
stream = wp.Stream(device, cuda_stream=cuda_stream)
|
|
670
|
+
|
|
671
|
+
# reconstruct the argument list
|
|
672
|
+
arg_list = []
|
|
673
|
+
|
|
674
|
+
# input and in-out args
|
|
675
|
+
for i, arg in enumerate(self.input_args):
|
|
676
|
+
if arg.is_array:
|
|
677
|
+
buffer = inputs[i].contents
|
|
678
|
+
shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
|
|
679
|
+
arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
|
|
680
|
+
arg_list.append(arr)
|
|
681
|
+
else:
|
|
682
|
+
# scalar argument, get stashed value
|
|
683
|
+
value = call_desc.static_inputs[arg.name]
|
|
684
|
+
arg_list.append(value)
|
|
685
|
+
|
|
686
|
+
# pure output args (skip in-out FFI buffers)
|
|
687
|
+
for i, arg in enumerate(self.output_args):
|
|
688
|
+
buffer = outputs[i + self.num_in_out].contents
|
|
689
|
+
shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
|
|
690
|
+
arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
|
|
691
|
+
arg_list.append(arr)
|
|
692
|
+
|
|
693
|
+
# call the Python function with reconstructed arguments
|
|
694
|
+
with wp.ScopedStream(stream, sync_enter=True):
|
|
695
|
+
if stream.is_capturing:
|
|
696
|
+
# capturing with JAX
|
|
697
|
+
with wp.ScopedCapture(external=True) as capture:
|
|
698
|
+
self.func(*arg_list)
|
|
699
|
+
# keep a reference to the capture object to prevent required modules getting unloaded
|
|
700
|
+
call_desc.capture = capture
|
|
701
|
+
elif self.graph_mode == GraphMode.WARP:
|
|
702
|
+
# capturing with WARP
|
|
703
|
+
with wp.ScopedCapture() as capture:
|
|
704
|
+
self.func(*arg_list)
|
|
705
|
+
wp.capture_launch(capture.graph)
|
|
706
|
+
# keep a reference to the capture object and reuse it with same buffers
|
|
707
|
+
self.captures[capture_key] = capture
|
|
708
|
+
# respect the cache size limit if specified
|
|
709
|
+
if self._graph_cache_max is not None and len(self.captures) > self._graph_cache_max:
|
|
710
|
+
self.captures.popitem(last=False)
|
|
711
|
+
else:
|
|
712
|
+
# not capturing
|
|
713
|
+
self.func(*arg_list)
|
|
714
|
+
|
|
715
|
+
except Exception as e:
|
|
716
|
+
print(traceback.format_exc())
|
|
717
|
+
return create_ffi_error(
|
|
718
|
+
call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
|
|
719
|
+
)
|
|
720
|
+
|
|
721
|
+
return None
|
|
722
|
+
|
|
723
|
+
@property
|
|
724
|
+
def graph_cache_max(self) -> int | None:
|
|
725
|
+
return self._graph_cache_max
|
|
726
|
+
|
|
727
|
+
@graph_cache_max.setter
|
|
728
|
+
def graph_cache_max(self, value: int | None):
|
|
729
|
+
if value != self._graph_cache_max:
|
|
730
|
+
if value is not None and (self._graph_cache_max is None or value < self._graph_cache_max):
|
|
731
|
+
# trim the cache if needed
|
|
732
|
+
while len(self.captures) > value:
|
|
733
|
+
self.captures.popitem(last=False)
|
|
734
|
+
self._graph_cache_max = value
|
|
735
|
+
|
|
736
|
+
@property
|
|
737
|
+
def graph_cache_size(self) -> int:
|
|
738
|
+
return len(self.captures)
|
|
739
|
+
|
|
740
|
+
|
|
741
|
+
def jax_kernel(
    kernel,
    num_outputs=1,
    vmap_method="broadcast_all",
    launch_dims=None,
    output_dims=None,
    in_out_argnames=None,
    module_preload_mode=ModulePreloadMode.CURRENT_DEVICE,
    enable_backward: bool = False,
):
    """Create a JAX callback from a Warp kernel.

    NOTE: This is an experimental feature under development.

    Args:
        kernel: The Warp kernel to launch.
        num_outputs: Specify the number of output arguments if greater than 1.
            This must include the number of ``in_out_argnames``.
        vmap_method: String specifying how the callback transforms under ``vmap()``.
            This argument can also be specified for individual calls.
        launch_dims: Specify the default kernel launch dimensions. If None, launch
            dimensions are inferred from the shape of the first array argument.
            This argument can also be specified for individual calls.
        output_dims: Specify the default dimensions of output arrays. If None, output
            dimensions are inferred from the launch dimensions.
            This argument can also be specified for individual calls.
        in_out_argnames: Names of arguments that are both inputs and outputs (aliased buffers).
            These must be array arguments that appear before any pure output arguments in the
            kernel signature. The number of in-out arguments is included in ``num_outputs``.
            Not supported when ``enable_backward=True``.
        module_preload_mode: Specify the devices where the module should be preloaded.
        enable_backward: Enable automatic differentiation for this kernel.

    Returns:
        A cached ``FfiKernel`` when ``enable_backward`` is False. Otherwise a cached
        ``jax.custom_vjp`` function, wrapped in ``jax.jit`` with the scalar arguments
        declared static when the kernel has scalar inputs.

    Raises:
        NotImplementedError: If ``enable_backward=True`` is combined with
            ``in_out_argnames``, ``launch_dims`` or ``output_dims``.
        TypeError: If ``enable_backward=True`` and an input is neither an array nor
            a Warp value type.

    Limitations:
        - All kernel arguments must be contiguous arrays or scalars.
        - Scalars must be static arguments in JAX.
        - Input and input-output arguments must precede the output arguments in the ``kernel`` definition.
        - There must be at least one output or input-output argument.
        - Only the CUDA backend is supported.
    """

    check_jax_version()

    def _hashable_dims(dims):
        # Normalize a dimension spec into something usable inside a cache key.
        # Specs may be None, an int, a sequence of ints, or a dict keyed on
        # argument name; calling `.items()` / `tuple()` on them directly fails
        # for ints and non-dict values, and an empty dict is unhashable.
        if dims is None:
            return None
        if isinstance(dims, dict):
            return tuple(sorted((name, _hashable_dims(shape)) for name, shape in dims.items()))
        if isinstance(dims, int):
            return (dims,)
        return tuple(dims)

    if not enable_backward:
        key = (
            kernel.func,
            kernel.sig,
            num_outputs,
            vmap_method,
            _hashable_dims(launch_dims),
            _hashable_dims(output_dims),
            # in-out aliasing changes the generated FFI call, so it must be part of the key
            tuple(sorted(in_out_argnames)) if in_out_argnames else None,
            module_preload_mode,
        )

        with _FFI_REGISTRY_LOCK:
            if key not in _FFI_KERNEL_REGISTRY:
                new_kernel = FfiKernel(
                    kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames, module_preload_mode
                )
                _FFI_KERNEL_REGISTRY[key] = new_kernel

            return _FFI_KERNEL_REGISTRY[key]

    # make sure the arguments are compatible with autodiff
    if in_out_argnames:
        raise NotImplementedError(
            "jax_kernel(): Input-output arguments (in_out_argnames) are not supported when enable_backward=True."
        )

    # TODO: we should support passing these to the forward and backward callables
    if launch_dims is not None or output_dims is not None:
        raise NotImplementedError(
            "jax_kernel(): Custom dimensions (launch_dims, output_dims) are not supported when enable_backward=True."
        )

    # Differentiable path: build a custom VJP wrapper inline.
    # Infer the original kernel signature (names and annotations)
    signature = inspect.signature(kernel.func)

    parameters = [p for p in signature.parameters.values() if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD]
    parameter_count = len(parameters)
    num_inputs = parameter_count - num_outputs

    # determine static argument indices
    static_args = []
    for i, p in enumerate(parameters[:num_inputs]):
        param_type = p.annotation
        if not isinstance(param_type, wp.array):
            if param_type in wp._src.types.value_types:
                static_args.append(i)
            else:
                raise TypeError(
                    f"Invalid type for argument '{p.name}', expected array or scalar, got {param_type}"
                )

    static_names = [parameters[i].name for i in static_args]

    # Look up the cached differentiable wrapper before building anything, so that
    # repeated calls do not create and register new forward/backward callables.
    key = (kernel.func, kernel.sig, num_outputs, vmap_method, tuple(sorted(static_names)))
    with _FFI_REGISTRY_LOCK:
        cached = _FFI_DIFF_KERNEL_REGISTRY.get(key)
    if cached is not None:
        return cached

    def _resolve_launch_dims(call_args):
        # determine launch dimensions from the shape of the first input array
        for i, p in enumerate(parameters[:num_inputs]):
            param_type = p.annotation
            if isinstance(param_type, wp.array):
                arg = call_args[i]
                arg_shape = tuple(arg.shape)
                if hasattr(param_type.dtype, "_wp_scalar_type_"):
                    # vector/matrix array, trim trailing dimensions of JAX input array
                    return arg_shape[: param_type.ndim]
                else:
                    # scalar array
                    return arg_shape
        raise RuntimeError("Unable to determine launch dimensions, at least one input array is required")

    # Forward kernel wrapper: simply launches the kernel
    def fwd_kernel_wrapper(*args):
        wp.launch(kernel, dim=_resolve_launch_dims(args), inputs=args[:num_inputs], outputs=args[num_inputs:])

    # update forward signature and annotations so jax_callable() sees a fully annotated function
    fwd_kernel_wrapper.__signature__ = signature
    fwd_kernel_wrapper.__annotations__ = {p.name: p.annotation for p in parameters}
    fwd_kernel_wrapper.__annotations__["return"] = None

    jax_fwd_kernel = jax_callable(fwd_kernel_wrapper, num_outputs=num_outputs, vmap_method=vmap_method)

    # backward arguments only include static args once
    bwd_arg_count = 2 * parameter_count - len(static_args)

    # Backward wrapper: launches adjoint with provided output gradients
    def bwd_kernel_wrapper(*args):
        if len(args) != bwd_arg_count:
            raise RuntimeError(f"Invalid backward argument count, expected {bwd_arg_count} but got {len(args)}")

        inputs = list(args[:num_inputs])
        outputs = list(args[num_inputs:parameter_count])
        grad_out = list(args[parameter_count : parameter_count + num_outputs])
        grad_in = list(args[parameter_count + num_outputs :])

        # static inputs have no gradient buffer; re-insert their values so that
        # grad_in lines up positionally with inputs
        for i in static_args:
            grad_in.insert(i, inputs[i])

        for gi in grad_in:
            if isinstance(gi, wp.array):
                try:
                    gi.zero_()
                except Exception as e:
                    wp._src.utils.warn(f"Failed to zero gradient array: {e}", stacklevel=2)
                    raise

        # NOTE: We cannot use a passed launch_dims here, the backward rule doesn't receive it (and it could be wrong under pmap/vmap).
        # We need to infer from the inputs.
        wp.launch(
            kernel,
            dim=_resolve_launch_dims(inputs),
            inputs=inputs,
            outputs=outputs,
            adj_inputs=grad_in,
            adj_outputs=grad_out,
            adjoint=True,
        )

    # Build the backward wrapper signature expected by jax_callable
    bwd_input_params = parameters[:num_inputs]
    bwd_output_params = parameters[num_inputs:parameter_count]
    bwd_grad_output_params = [
        inspect.Parameter(
            f"adj_{p.name}",
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            default=p.default,
            annotation=p.annotation,
        )
        for p in bwd_output_params
    ]

    bwd_grad_input_params = [
        inspect.Parameter(
            f"adj_{p.name}",
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            default=p.default,
            annotation=p.annotation,
        )
        for p in bwd_input_params
    ]
    # delete from the back so earlier indices stay valid
    for i in reversed(static_args):
        del bwd_grad_input_params[i]

    # update backward signature and annotations so jax_callable() sees a fully annotated function
    bwd_signature = bwd_input_params + bwd_output_params + bwd_grad_output_params + bwd_grad_input_params
    bwd_kernel_wrapper.__signature__ = inspect.Signature(bwd_signature)
    bwd_annotations = {p.name: p.annotation for p in bwd_signature}
    bwd_annotations["return"] = None
    bwd_kernel_wrapper.__annotations__ = bwd_annotations

    jax_bwd_kernel = jax_callable(
        bwd_kernel_wrapper,
        num_outputs=len(bwd_input_params) - len(static_args),
        vmap_method=vmap_method,
    )

    differentiable_input_indices = [i for i in range(num_inputs) if i not in static_args]
    differentiable_input_names = [parameters[i].name for i in differentiable_input_indices]

    def fwd_function(*args):
        outputs = jax_fwd_kernel(*args)
        non_static_inputs = list(args)
        for i in reversed(static_args):
            del non_static_inputs[i]
        # Normalize to tuple for consistent handling
        if num_outputs == 1:
            outputs_tuple = (outputs,) if not isinstance(outputs, (list, tuple)) else (outputs[0],)
        else:
            outputs_tuple = outputs if isinstance(outputs, tuple) else tuple(outputs)
        return outputs, (tuple(non_static_inputs), outputs_tuple)

    def bwd_function(*bwd_args):
        # layout: static (non-differentiable) values, then residuals, then output gradients
        nondiff_vals = list(bwd_args[: len(static_args)])
        residuals = bwd_args[len(static_args)]
        grad_out_args = bwd_args[len(static_args) + 1 :]

        non_static_inputs, output_vals_tuple = residuals

        input_vals = list(non_static_inputs)
        for i, v in zip(static_args, nondiff_vals):
            input_vals.insert(i, v)

        # Normalize grad outputs and handle nested containers (e.g., single tuple for multi-output)
        if num_outputs == 1:
            go = grad_out_args[0]
            grad_out_tuple = tuple(go) if isinstance(go, (list, tuple)) else (go,)
        else:
            if len(grad_out_args) == 1 and isinstance(grad_out_args[0], (list, tuple)):
                grad_out_tuple = tuple(grad_out_args[0])
            else:
                grad_out_tuple = tuple(grad_out_args)
        bwd_call_args = list(input_vals) + list(output_vals_tuple) + list(grad_out_tuple)

        # one gradient output per differentiable array input, shaped like that input
        out_dims_map = {}
        param_ann = {p.name: p.annotation for p in parameters[:num_inputs]}
        for name, val in zip(differentiable_input_names, non_static_inputs):
            ann = param_ann.get(name)
            if ann is None:
                continue
            # Check if annotation is a warp array type (annotation is an instance of wp.array)
            is_array_ann = isinstance(ann, wp.array)
            if not is_array_ann:
                continue
            dtype_ndim = 0
            # Extract dtype ndim if it's a vector/matrix type
            if hasattr(ann, "dtype") and hasattr(ann.dtype, "_wp_scalar_type_"):
                dtype_ndim = len(ann.dtype._shape_)
            warp_ndim = getattr(ann, "ndim", 0)
            vshape = tuple(val.shape)
            if warp_ndim == 0:
                continue
            if dtype_ndim > 0:
                core_rank = max(0, len(vshape) - dtype_ndim)
                warp_dims = vshape[max(0, core_rank - warp_ndim) : core_rank]
            else:
                warp_dims = vshape[-warp_ndim:]
            out_dims_map[f"adj_{name}"] = tuple(warp_dims)

        non_static_input_grads = jax_bwd_kernel(*bwd_call_args, output_dims=out_dims_map)
        return tuple(non_static_input_grads)

    jax_func = jax.custom_vjp(jax_fwd_kernel, nondiff_argnums=tuple(static_args))
    jax_func.defvjp(fwd_function, bwd_function)

    if static_args:

        def _user_callable(*args):
            return jax_func(*args)

        _user_callable.__signature__ = signature

        # scalar arguments must be static under jit
        diff_kernel = jax.jit(_user_callable, static_argnames=tuple(static_names))
    else:
        diff_kernel = jax_func

    # Cache differentiable wrapper; if another thread stored one in the meantime, reuse it
    with _FFI_REGISTRY_LOCK:
        return _FFI_DIFF_KERNEL_REGISTRY.setdefault(key, diff_kernel)
|
|
1035
|
+
|
|
1036
|
+
|
|
1037
|
+
def jax_callable(
|
|
1038
|
+
func: Callable,
|
|
1039
|
+
num_outputs: int = 1,
|
|
1040
|
+
graph_compatible: Optional[bool] = None, # deprecated
|
|
1041
|
+
graph_mode: GraphMode = GraphMode.JAX,
|
|
1042
|
+
vmap_method: Optional[str] = "broadcast_all",
|
|
1043
|
+
output_dims=None,
|
|
1044
|
+
in_out_argnames=None,
|
|
1045
|
+
graph_cache_max: int | None = None,
|
|
1046
|
+
module_preload_mode: ModulePreloadMode = ModulePreloadMode.CURRENT_DEVICE,
|
|
1047
|
+
):
|
|
1048
|
+
"""Create a JAX callback from an annotated Python function.
|
|
1049
|
+
|
|
1050
|
+
The Python function arguments must have type annotations like Warp kernels.
|
|
1051
|
+
|
|
1052
|
+
NOTE: This is an experimental feature under development.
|
|
1053
|
+
|
|
1054
|
+
Args:
|
|
1055
|
+
func: The Python function to call.
|
|
1056
|
+
num_outputs: Specify the number of output arguments if greater than 1.
|
|
1057
|
+
This must include the number of ``in_out_arguments``.
|
|
1058
|
+
graph_compatible: Whether the function can be called during CUDA graph capture.
|
|
1059
|
+
This argument is deprecated, use ``graph_mode`` instead.
|
|
1060
|
+
graph_mode: CUDA graph capture mode.
|
|
1061
|
+
``GraphMode.JAX`` (default): Let JAX capture the graph, which may be used as a subgraph in an enclosing JAX capture.
|
|
1062
|
+
``GraphMode.WARP``: Let Warp capture the graph. Use this mode when the callable cannot be used as a subgraph,
|
|
1063
|
+
such as when the callable uses conditional graph nodes.
|
|
1064
|
+
``GraphMode.NONE``: Disable graph capture. Use when the callable performs operations that are not legal in a graph,
|
|
1065
|
+
such as host synchronization.
|
|
1066
|
+
vmap_method: String specifying how the callback transforms under ``vmap()``.
|
|
1067
|
+
This argument can also be specified for individual calls.
|
|
1068
|
+
output_dims: Specify the default dimensions of output arrays.
|
|
1069
|
+
If ``None``, output dimensions are inferred from the launch dimensions.
|
|
1070
|
+
This argument can also be specified for individual calls.
|
|
1071
|
+
in_out_argnames: Names of arguments that are both inputs and outputs (aliased buffers).
|
|
1072
|
+
These must be array arguments that appear before any pure output arguments in the
|
|
1073
|
+
function signature. The number of in-out arguments is included in ``num_outputs``.
|
|
1074
|
+
graph_cache_max: Maximum number of cached graphs captured using ``GraphMode.WARP``.
|
|
1075
|
+
If ``None``, use ``warp.jax_experimental.get_jax_callable_default_graph_cache_max()``.
|
|
1076
|
+
module_preload_mode: Specify the devices where the module should be preloaded.
|
|
1077
|
+
|
|
1078
|
+
Limitations:
|
|
1079
|
+
- All kernel arguments must be contiguous arrays or scalars.
|
|
1080
|
+
- Scalars must be static arguments in JAX.
|
|
1081
|
+
- Input and input-output arguments must precede the output arguments in the ``func`` definition.
|
|
1082
|
+
- There must be at least one output or input-output argument.
|
|
1083
|
+
- Only the CUDA backend is supported.
|
|
1084
|
+
"""
|
|
1085
|
+
|
|
1086
|
+
check_jax_version()
|
|
1087
|
+
|
|
1088
|
+
if graph_compatible is not None:
|
|
1089
|
+
wp._src.utils.warn(
|
|
1090
|
+
"The `graph_compatible` argument is deprecated, use `graph_mode` instead.",
|
|
1091
|
+
DeprecationWarning,
|
|
1092
|
+
stacklevel=3,
|
|
1093
|
+
)
|
|
1094
|
+
if graph_compatible is False:
|
|
1095
|
+
graph_mode = GraphMode.NONE
|
|
1096
|
+
|
|
1097
|
+
if graph_cache_max is None:
|
|
1098
|
+
graph_cache_max = FfiCallable.default_graph_cache_max
|
|
1099
|
+
|
|
1100
|
+
# Note: we don't include graph_cache_max in the key, it is applied below.
|
|
1101
|
+
key = (
|
|
1102
|
+
func,
|
|
1103
|
+
num_outputs,
|
|
1104
|
+
graph_mode,
|
|
1105
|
+
vmap_method,
|
|
1106
|
+
tuple(sorted(output_dims.items())) if output_dims else output_dims,
|
|
1107
|
+
module_preload_mode,
|
|
1108
|
+
)
|
|
1109
|
+
|
|
1110
|
+
with _FFI_REGISTRY_LOCK:
|
|
1111
|
+
callable = _FFI_CALLABLE_REGISTRY.get(key)
|
|
1112
|
+
if callable is None:
|
|
1113
|
+
callable = FfiCallable(
|
|
1114
|
+
func,
|
|
1115
|
+
num_outputs,
|
|
1116
|
+
graph_mode,
|
|
1117
|
+
vmap_method,
|
|
1118
|
+
output_dims,
|
|
1119
|
+
in_out_argnames,
|
|
1120
|
+
graph_cache_max,
|
|
1121
|
+
module_preload_mode,
|
|
1122
|
+
)
|
|
1123
|
+
_FFI_CALLABLE_REGISTRY[key] = callable
|
|
1124
|
+
else:
|
|
1125
|
+
# make sure we're using the latest graph cache max
|
|
1126
|
+
callable.graph_cache_max = graph_cache_max
|
|
1127
|
+
|
|
1128
|
+
return callable
|
|
1129
|
+
|
|
1130
|
+
|
|
1131
|
+
def get_jax_callable_default_graph_cache_max():
    """Return the default maximum size of the graph cache.

    The limit applies to graphs captured using ``GraphMode.WARP`` and is read from
    ``FfiCallable.default_graph_cache_max``. ``None`` means the cache is unlimited.
    """
    return FfiCallable.default_graph_cache_max
|
|
1136
|
+
|
|
1137
|
+
|
|
1138
|
+
def set_jax_callable_default_graph_cache_max(cache_max: int | None):
    """Set the default maximum size of the graph cache.

    The limit applies to graphs captured using ``GraphMode.WARP`` and is stored in
    ``FfiCallable.default_graph_cache_max``.

    Args:
        cache_max: New default limit, or ``None`` for an unlimited cache.
    """
    FfiCallable.default_graph_cache_max = cache_max
|
|
1143
|
+
|
|
1144
|
+
|
|
1145
|
+
def clear_jax_callable_graph_cache(callable: FfiCallable | None = None):
    """Empty the graph cache of one callable, or of every registered callable.

    Args:
        callable: The callable whose ``captures`` cache is cleared. If ``None``,
            the ``captures`` cache of every entry in the callable registry is cleared.
    """
    if callable is not None:
        # a specific callable was requested, nothing else is touched
        callable.captures.clear()
        return

    # no target given: walk the whole registry while holding its lock
    with _FFI_REGISTRY_LOCK:
        for registered in _FFI_CALLABLE_REGISTRY.values():
            registered.captures.clear()
|
|
1155
|
+
|
|
1156
|
+
|
|
1157
|
+
###############################################################################
|
|
1158
|
+
#
|
|
1159
|
+
# Generic FFI callbacks for Python functions of the form
|
|
1160
|
+
# func(inputs, outputs, attrs, ctx)
|
|
1161
|
+
#
|
|
1162
|
+
###############################################################################
|
|
1163
|
+
|
|
1164
|
+
|
|
1165
|
+
def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = True) -> None:
    """Register a Python function as a JAX FFI target.

    The Python function must have the form ``func(inputs, outputs, attrs, ctx)``.
    It is wrapped in a C-callable trampoline and registered with
    ``jax.ffi.register_ffi_target()`` for the ``"CUDA"`` platform only.

    Exceptions raised while handling a call do not propagate: the traceback is
    printed and an FFI error object with code ``XLA_FFI_Error_Code.UNKNOWN`` is
    returned to the caller of the trampoline instead.

    NOTE: This is an experimental feature under development.

    Args:
        name: A unique FFI callback name. Uniqueness is not verified here (see the
            TODO below); the callback registry entry for an existing name is overwritten.
        func: The Python function to call.
        graph_compatible: Whether the function can be called during CUDA graph capture.
            When true, the handler reports the ``COMMAND_BUFFER_COMPATIBLE`` trait
            in response to the metadata query.
    """

    check_jax_version()

    # TODO check that the name is not already registered

    def ffi_callback(call_frame):
        # Trampoline invoked through ctypes with a pointer to the XLA FFI call frame.
        # Returns None on success or an FFI error object on failure.
        try:
            extension = call_frame.contents.extension_start
            # On the first call, XLA runtime will query the API version and traits
            # metadata using the |extension| field. Let us respond to that query
            # if the metadata extension is present.
            if extension:
                # Try to set the version metadata.
                if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
                    metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
                    metadata_ext.contents.metadata.contents.api_version.major_version = 0
                    metadata_ext.contents.metadata.contents.api_version.minor_version = 1
                    if graph_compatible:
                        # Turn on CUDA graphs for this handler.
                        metadata_ext.contents.metadata.contents.traits = (
                            XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
                        )
                    # Metadata query answered; the user function is not invoked for it.
                    return None

            # Lock is required to prevent race conditions when callback is invoked
            # from multiple threads, like with pmap.
            with _FFI_CALLBACK_LOCK:
                attrs = decode_attrs(call_frame.contents.attrs)

                # Wrap each input buffer of the call frame in an FfiBuffer.
                input_count = call_frame.contents.args.size
                inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
                inputs = [FfiBuffer(inputs[i].contents) for i in range(input_count)]

                # Same for the output (return) buffers.
                output_count = call_frame.contents.rets.size
                outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
                outputs = [FfiBuffer(outputs[i].contents) for i in range(output_count)]

                ctx = ExecutionContext(call_frame.contents)

                func(inputs, outputs, attrs, ctx)

        except Exception as e:
            # Convert the Python exception into an FFI error object rather than
            # letting it escape the trampoline; the traceback is printed for diagnosis.
            print(traceback.format_exc())
            return create_ffi_error(
                call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
            )

        return None

    # C function type of the trampoline: takes a call frame pointer, returns a void pointer.
    FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
    callback_func = FFI_CCALLFUNC(ffi_callback)
    with _FFI_REGISTRY_LOCK:
        # The registry holds a reference to the ctypes callback object
        # (presumably so the native function pointer stays valid after this function returns).
        _FFI_CALLBACK_REGISTRY[name] = callback_func
    # Hand the raw function address to JAX wrapped in a PyCapsule.
    ffi_ccall_address = ctypes.cast(callback_func, ctypes.c_void_p)
    ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
    jax.ffi.register_ffi_target(name, ffi_capsule, platform="CUDA")
|
|
1233
|
+
|
|
1234
|
+
|
|
1235
|
+
###############################################################################
|
|
1236
|
+
#
|
|
1237
|
+
# Utilities
|
|
1238
|
+
#
|
|
1239
|
+
###############################################################################
|
|
1240
|
+
|
|
1241
|
+
# Number of names handed out so far for each fully qualified function name;
# used to keep FFI callback names unique.
ffi_name_counts = {}


def generate_unique_name(func) -> str:
    """Return a name for ``func`` that differs on every call.

    The name is the fully qualified name of ``func`` followed by an underscore and
    a per-name counter that starts at 0 and is incremented on each call.
    """
    qualified = make_full_qualified_name(func)
    serial = ffi_name_counts.setdefault(qualified, 0)
    ffi_name_counts[qualified] = serial + 1
    return f"{qualified}_{serial}"
|
|
1250
|
+
|
|
1251
|
+
|
|
1252
|
+
def get_warp_shape(arg, dims):
    """Return the shape to use on the Warp side for the given dimensions.

    Args:
        arg: Argument descriptor; ``dtype_ndim`` and ``warp_ndim`` are read.
        dims: Dimensions of the array.

    Returns:
        For vector/matrix arrays (``arg.dtype_ndim > 0``), the leading
        ``arg.warp_ndim`` entries of ``dims``. For scalar arrays, ``dims`` unchanged.
    """
    return dims[: arg.warp_ndim] if arg.dtype_ndim > 0 else dims
|
|
1259
|
+
|
|
1260
|
+
|
|
1261
|
+
def get_jax_output_type(arg, dims):
    """Build the ``jax.ShapeDtypeStruct`` describing an output array argument.

    Args:
        arg: Argument descriptor; the attributes ``name``, ``dtype_ndim``, ``dtype_shape``,
            ``warp_ndim``, ``jax_ndim`` and ``jax_scalar_type`` are read.
        dims: Output dimensions, either a single ``int`` or a sequence of dimensions.
            For vector/matrix arrays the sequence may either have ``arg.warp_ndim``
            entries (the ``arg.dtype_shape`` dimensions are appended) or
            ``arg.jax_ndim`` entries (the trailing entries must equal ``arg.dtype_shape``).
            For scalar arrays it must have ``arg.warp_ndim`` entries.

    Returns:
        A ``jax.ShapeDtypeStruct`` with element type ``arg.jax_scalar_type``.

    Raises:
        ValueError: If ``dims`` has an unsupported length or its trailing entries
            do not match ``arg.dtype_shape``.
    """

    def invalid_dims():
        # built lazily so that nothing is formatted on the success path
        return ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")

    if isinstance(dims, int):
        dims = (dims,)

    rank = len(dims)

    if not arg.dtype_ndim > 0:
        # scalar array
        if rank != arg.warp_ndim:
            raise invalid_dims()
        return jax.ShapeDtypeStruct(dims, arg.jax_scalar_type)

    # vector/matrix array
    if rank == arg.warp_ndim:
        # only the outer dimensions were given, append the element shape
        return jax.ShapeDtypeStruct((*dims, *arg.dtype_shape), arg.jax_scalar_type)

    if rank != arg.jax_ndim:
        raise invalid_dims()

    # full dimensions were given, make sure inner dimensions match
    trailing = dims[-arg.dtype_ndim :]
    if any(trailing[i] != arg.dtype_shape[i] for i in range(arg.dtype_ndim)):
        raise invalid_dims()
    return jax.ShapeDtypeStruct(dims, arg.jax_scalar_type)
|