warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +301 -287
- warp/__init__.pyi +794 -305
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1075 -0
- warp/_src/build.py +618 -0
- warp/_src/build_dll.py +640 -0
- warp/{builtins.py → _src/builtins.py} +1382 -377
- warp/_src/codegen.py +4359 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +57 -0
- warp/_src/context.py +8294 -0
- warp/_src/dlpack.py +462 -0
- warp/_src/fabric.py +355 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +508 -0
- warp/_src/fem/cache.py +687 -0
- warp/_src/fem/dirichlet.py +188 -0
- warp/{fem → _src/fem}/domain.py +40 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +701 -0
- warp/{fem → _src/fem}/field/nodal_field.py +30 -15
- warp/{fem → _src/fem}/field/restriction.py +1 -1
- warp/{fem → _src/fem}/field/virtual.py +53 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
- warp/_src/fem/geometry/closest_point.py +97 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
- warp/{fem → _src/fem}/geometry/element.py +32 -10
- warp/{fem → _src/fem}/geometry/geometry.py +48 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
- warp/{fem → _src/fem}/geometry/partition.py +121 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
- warp/{fem → _src/fem}/integrate.py +164 -158
- warp/_src/fem/linalg.py +383 -0
- warp/_src/fem/operator.py +396 -0
- warp/_src/fem/polynomial.py +229 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
- warp/_src/fem/space/basis_space.py +679 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
- warp/{fem → _src/fem}/space/function_space.py +14 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
- warp/{fem → _src/fem}/space/partition.py +117 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/restriction.py +66 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
- warp/_src/fem/space/topology.py +459 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
- warp/_src/fem/types.py +112 -0
- warp/_src/fem/utils.py +486 -0
- warp/_src/jax.py +186 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +387 -0
- warp/_src/jax_experimental/ffi.py +1284 -0
- warp/_src/jax_experimental/xla_ffi.py +656 -0
- warp/_src/marching_cubes.py +708 -0
- warp/_src/math.py +414 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +163 -0
- warp/_src/optim/linear.py +1606 -0
- warp/_src/optim/sgd.py +112 -0
- warp/_src/paddle.py +406 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +289 -0
- warp/_src/render/render_opengl.py +3636 -0
- warp/_src/render/render_usd.py +937 -0
- warp/_src/render/utils.py +160 -0
- warp/_src/sparse.py +2716 -0
- warp/_src/tape.py +1206 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +391 -0
- warp/_src/types.py +5870 -0
- warp/_src/utils.py +1693 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -721
- warp/codegen.py +6 -4251
- warp/constants.py +6 -39
- warp/context.py +12 -8062
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +1 -1
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -365
- warp/jax_experimental/ffi.py +17 -873
- warp/jax_experimental/xla_ffi.py +5 -605
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +314 -37
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +681 -89
- warp/native/tile_radix_sort.h +1 -1
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +57 -29
- warp/native/warp.cu +253 -171
- warp/native/warp.h +11 -8
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3618
- warp/render/render_usd.py +6 -919
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +14 -14
- warp/tests/interop/test_jax.py +772 -49
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +527 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +33 -14
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +56 -10
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +35 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_types.py +27 -15
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +178 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5764
- warp/utils.py +10 -1659
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
- warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.1.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/jax_experimental/ffi.py
CHANGED
|
@@ -13,883 +13,27 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
|
|
17
|
-
import threading
|
|
18
|
-
import traceback
|
|
19
|
-
from enum import IntEnum
|
|
20
|
-
from typing import Callable, Optional
|
|
16
|
+
# isort: skip_file
|
|
21
17
|
|
|
22
|
-
import
|
|
18
|
+
from warp._src.jax_experimental.ffi import GraphMode as GraphMode
|
|
19
|
+
from warp._src.jax_experimental.ffi import jax_kernel as jax_kernel
|
|
20
|
+
from warp._src.jax_experimental.ffi import jax_callable as jax_callable
|
|
21
|
+
from warp._src.jax_experimental.ffi import register_ffi_callback as register_ffi_callback
|
|
23
22
|
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
from warp.
|
|
23
|
+
from warp._src.jax_experimental.ffi import (
|
|
24
|
+
get_jax_callable_default_graph_cache_max as get_jax_callable_default_graph_cache_max,
|
|
25
|
+
)
|
|
26
|
+
from warp._src.jax_experimental.ffi import (
|
|
27
|
+
set_jax_callable_default_graph_cache_max as set_jax_callable_default_graph_cache_max,
|
|
28
|
+
)
|
|
29
|
+
from warp._src.jax_experimental.ffi import clear_jax_callable_graph_cache as clear_jax_callable_graph_cache
|
|
28
30
|
|
|
29
|
-
|
|
31
|
+
# TODO: Remove after cleaning up the public API.
|
|
30
32
|
|
|
33
|
+
from warp._src.jax_experimental import ffi as _ffi
|
|
31
34
|
|
|
32
|
-
def check_jax_version():
|
|
33
|
-
# check if JAX version supports this
|
|
34
|
-
if jax.__version_info__ < (0, 5, 0):
|
|
35
|
-
msg = (
|
|
36
|
-
"This version of jax_kernel() requires JAX version 0.5.0 or higher, "
|
|
37
|
-
f"but installed JAX version is {jax.__version_info__}."
|
|
38
|
-
)
|
|
39
|
-
if jax.__version_info__ >= (0, 4, 25):
|
|
40
|
-
msg += " Please use warp.jax_experimental.custom_call.jax_kernel instead."
|
|
41
|
-
raise RuntimeError(msg)
|
|
42
35
|
|
|
36
|
+
def __getattr__(name):
|
|
37
|
+
from warp._src.utils import get_deprecated_api
|
|
43
38
|
|
|
44
|
-
|
|
45
|
-
NONE = 0 # don't capture a graph
|
|
46
|
-
JAX = 1 # let JAX capture a graph
|
|
47
|
-
WARP = 2 # let Warp capture a graph
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
class FfiArg:
|
|
51
|
-
def __init__(self, name, type, in_out=False):
|
|
52
|
-
self.name = name
|
|
53
|
-
self.type = type
|
|
54
|
-
self.in_out = in_out
|
|
55
|
-
self.is_array = isinstance(type, wp.array)
|
|
56
|
-
|
|
57
|
-
if self.is_array:
|
|
58
|
-
if hasattr(type.dtype, "_wp_scalar_type_"):
|
|
59
|
-
self.dtype_shape = type.dtype._shape_
|
|
60
|
-
self.dtype_ndim = len(self.dtype_shape)
|
|
61
|
-
self.jax_scalar_type = wp.dtype_to_jax(type.dtype._wp_scalar_type_)
|
|
62
|
-
self.jax_ndim = type.ndim + self.dtype_ndim
|
|
63
|
-
elif type.dtype in wp.types.value_types:
|
|
64
|
-
self.dtype_ndim = 0
|
|
65
|
-
self.dtype_shape = ()
|
|
66
|
-
self.jax_scalar_type = wp.dtype_to_jax(type.dtype)
|
|
67
|
-
self.jax_ndim = type.ndim
|
|
68
|
-
else:
|
|
69
|
-
raise TypeError(f"Invalid data type for array argument '{name}', expected scalar, vector, or matrix")
|
|
70
|
-
self.warp_ndim = type.ndim
|
|
71
|
-
elif type in wp.types.value_types:
|
|
72
|
-
self.dtype_ndim = 0
|
|
73
|
-
self.dtype_shape = ()
|
|
74
|
-
self.jax_scalar_type = wp.dtype_to_jax(type_to_warp(type))
|
|
75
|
-
self.jax_ndim = 0
|
|
76
|
-
self.warp_ndim = 0
|
|
77
|
-
else:
|
|
78
|
-
raise TypeError(f"Invalid type for argument '{name}', expected array or scalar, got {type}")
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
class FfiLaunchDesc:
|
|
82
|
-
def __init__(self, static_inputs, launch_dims):
|
|
83
|
-
self.static_inputs = static_inputs
|
|
84
|
-
self.launch_dims = launch_dims
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
class FfiKernel:
|
|
88
|
-
def __init__(self, kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames):
|
|
89
|
-
self.kernel = kernel
|
|
90
|
-
self.name = generate_unique_name(kernel.func)
|
|
91
|
-
self.num_outputs = num_outputs
|
|
92
|
-
self.vmap_method = vmap_method
|
|
93
|
-
self.launch_dims = launch_dims
|
|
94
|
-
self.output_dims = output_dims
|
|
95
|
-
self.first_array_arg = None
|
|
96
|
-
self.launch_id = 0
|
|
97
|
-
self.launch_descriptors = {}
|
|
98
|
-
|
|
99
|
-
in_out_argnames_list = in_out_argnames or []
|
|
100
|
-
in_out_argnames = set(in_out_argnames_list)
|
|
101
|
-
if len(in_out_argnames_list) != len(in_out_argnames):
|
|
102
|
-
raise AssertionError("in_out_argnames must not contain duplicate names")
|
|
103
|
-
|
|
104
|
-
self.num_kernel_args = len(kernel.adj.args)
|
|
105
|
-
self.num_in_out = len(in_out_argnames)
|
|
106
|
-
self.num_inputs = self.num_kernel_args - num_outputs + self.num_in_out
|
|
107
|
-
if self.num_outputs < 1:
|
|
108
|
-
raise ValueError("At least one output is required")
|
|
109
|
-
if self.num_outputs > self.num_kernel_args:
|
|
110
|
-
raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
|
|
111
|
-
if self.num_outputs < self.num_in_out:
|
|
112
|
-
raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")
|
|
113
|
-
|
|
114
|
-
# process input args
|
|
115
|
-
self.input_args = []
|
|
116
|
-
for i in range(self.num_inputs):
|
|
117
|
-
arg_name = kernel.adj.args[i].label
|
|
118
|
-
arg = FfiArg(arg_name, kernel.adj.args[i].type, arg_name in in_out_argnames)
|
|
119
|
-
if arg_name in in_out_argnames:
|
|
120
|
-
in_out_argnames.remove(arg_name)
|
|
121
|
-
if arg.is_array:
|
|
122
|
-
# keep track of the first input array argument
|
|
123
|
-
if self.first_array_arg is None:
|
|
124
|
-
self.first_array_arg = i
|
|
125
|
-
self.input_args.append(arg)
|
|
126
|
-
|
|
127
|
-
# process output args
|
|
128
|
-
self.output_args = []
|
|
129
|
-
for i in range(self.num_inputs, self.num_kernel_args):
|
|
130
|
-
arg_name = kernel.adj.args[i].label
|
|
131
|
-
if arg_name in in_out_argnames:
|
|
132
|
-
raise AssertionError(
|
|
133
|
-
f"Expected an output-only argument for argument {arg_name}."
|
|
134
|
-
" in_out arguments should be placed before output-only arguments."
|
|
135
|
-
)
|
|
136
|
-
arg = FfiArg(arg_name, kernel.adj.args[i].type, False)
|
|
137
|
-
if not arg.is_array:
|
|
138
|
-
raise TypeError("All output arguments must be arrays")
|
|
139
|
-
self.output_args.append(arg)
|
|
140
|
-
|
|
141
|
-
if in_out_argnames:
|
|
142
|
-
raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
|
|
143
|
-
|
|
144
|
-
# Build input output aliases.
|
|
145
|
-
out_id = 0
|
|
146
|
-
input_output_aliases = {}
|
|
147
|
-
for in_id, arg in enumerate(self.input_args):
|
|
148
|
-
if not arg.in_out:
|
|
149
|
-
continue
|
|
150
|
-
input_output_aliases[in_id] = out_id
|
|
151
|
-
out_id += 1
|
|
152
|
-
self.input_output_aliases = input_output_aliases
|
|
153
|
-
|
|
154
|
-
# register the callback
|
|
155
|
-
FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
|
|
156
|
-
self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
|
|
157
|
-
ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p)
|
|
158
|
-
ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
|
|
159
|
-
jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA")
|
|
160
|
-
|
|
161
|
-
def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None):
|
|
162
|
-
num_inputs = len(args)
|
|
163
|
-
if num_inputs != self.num_inputs:
|
|
164
|
-
raise ValueError(f"Expected {self.num_inputs} inputs, but got {num_inputs}")
|
|
165
|
-
|
|
166
|
-
# default argument fallback
|
|
167
|
-
if launch_dims is None:
|
|
168
|
-
launch_dims = self.launch_dims
|
|
169
|
-
if output_dims is None:
|
|
170
|
-
output_dims = self.output_dims
|
|
171
|
-
if vmap_method is None:
|
|
172
|
-
vmap_method = self.vmap_method
|
|
173
|
-
|
|
174
|
-
# output types
|
|
175
|
-
out_types = []
|
|
176
|
-
|
|
177
|
-
# process inputs
|
|
178
|
-
static_inputs = {}
|
|
179
|
-
for i in range(num_inputs):
|
|
180
|
-
input_arg = self.input_args[i]
|
|
181
|
-
input_value = args[i]
|
|
182
|
-
if input_arg.is_array:
|
|
183
|
-
# check dtype
|
|
184
|
-
if input_value.dtype != input_arg.jax_scalar_type:
|
|
185
|
-
raise TypeError(
|
|
186
|
-
f"Invalid data type for array argument '{input_arg.name}', expected {input_arg.jax_scalar_type}, got {input_value.dtype}"
|
|
187
|
-
)
|
|
188
|
-
# check ndim
|
|
189
|
-
if input_value.ndim != input_arg.jax_ndim:
|
|
190
|
-
raise TypeError(
|
|
191
|
-
f"Invalid dimensionality for array argument '{input_arg.name}', expected {input_arg.jax_ndim} dimensions, got {input_value.ndim}"
|
|
192
|
-
)
|
|
193
|
-
# check inner dims
|
|
194
|
-
for d in range(input_arg.dtype_ndim):
|
|
195
|
-
if input_value.shape[input_arg.type.ndim + d] != input_arg.dtype_shape[d]:
|
|
196
|
-
raise TypeError(
|
|
197
|
-
f"Invalid inner dimensions for array argument '{input_arg.name}', expected {input_arg.dtype_shape}, got {input_value.shape[-input_arg.dtype_ndim :]}"
|
|
198
|
-
)
|
|
199
|
-
else:
|
|
200
|
-
# make sure scalar is not a traced variable, should be static
|
|
201
|
-
if isinstance(input_value, jax.core.Tracer):
|
|
202
|
-
raise ValueError(f"Argument '{input_arg.name}' must be a static value")
|
|
203
|
-
# stash the value to be retrieved by callback
|
|
204
|
-
static_inputs[input_arg.name] = input_arg.type(input_value)
|
|
205
|
-
|
|
206
|
-
# append in-out arg to output types
|
|
207
|
-
if input_arg.in_out:
|
|
208
|
-
out_types.append(get_jax_output_type(input_arg, input_value.shape))
|
|
209
|
-
|
|
210
|
-
# launch dimensions
|
|
211
|
-
if launch_dims is None:
|
|
212
|
-
# use the shape of the first input array
|
|
213
|
-
if self.first_array_arg is not None:
|
|
214
|
-
launch_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
|
|
215
|
-
else:
|
|
216
|
-
raise RuntimeError("Failed to determine launch dimensions")
|
|
217
|
-
elif isinstance(launch_dims, int):
|
|
218
|
-
launch_dims = (launch_dims,)
|
|
219
|
-
else:
|
|
220
|
-
launch_dims = tuple(launch_dims)
|
|
221
|
-
|
|
222
|
-
# output shapes
|
|
223
|
-
if isinstance(output_dims, dict):
|
|
224
|
-
# assume a dictionary of shapes keyed on argument name
|
|
225
|
-
for output_arg in self.output_args:
|
|
226
|
-
dims = output_dims.get(output_arg.name)
|
|
227
|
-
if dims is None:
|
|
228
|
-
raise ValueError(f"Missing output dimensions for argument '{output_arg.name}'")
|
|
229
|
-
out_types.append(get_jax_output_type(output_arg, dims))
|
|
230
|
-
else:
|
|
231
|
-
if output_dims is None:
|
|
232
|
-
# use launch dimensions
|
|
233
|
-
output_dims = launch_dims
|
|
234
|
-
elif isinstance(output_dims, int):
|
|
235
|
-
output_dims = (output_dims,)
|
|
236
|
-
# assume same dimensions for all outputs
|
|
237
|
-
for output_arg in self.output_args:
|
|
238
|
-
out_types.append(get_jax_output_type(output_arg, output_dims))
|
|
239
|
-
|
|
240
|
-
call = jax.ffi.ffi_call(
|
|
241
|
-
self.name,
|
|
242
|
-
out_types,
|
|
243
|
-
vmap_method=vmap_method,
|
|
244
|
-
input_output_aliases=self.input_output_aliases,
|
|
245
|
-
)
|
|
246
|
-
|
|
247
|
-
# ensure the kernel module is loaded before the callback, otherwise graph capture may fail
|
|
248
|
-
device = wp.device_from_jax(get_jax_device())
|
|
249
|
-
self.kernel.module.load(device)
|
|
250
|
-
|
|
251
|
-
# save launch data to be retrieved by callback
|
|
252
|
-
launch_id = self.launch_id
|
|
253
|
-
self.launch_descriptors[launch_id] = FfiLaunchDesc(static_inputs, launch_dims)
|
|
254
|
-
self.launch_id += 1
|
|
255
|
-
|
|
256
|
-
return call(*args, launch_id=launch_id)
|
|
257
|
-
|
|
258
|
-
def ffi_callback(self, call_frame):
|
|
259
|
-
try:
|
|
260
|
-
# On the first call, XLA runtime will query the API version and traits
|
|
261
|
-
# metadata using the |extension| field. Let us respond to that query
|
|
262
|
-
# if the metadata extension is present.
|
|
263
|
-
extension = call_frame.contents.extension_start
|
|
264
|
-
if extension:
|
|
265
|
-
# Try to set the version metadata.
|
|
266
|
-
if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
|
|
267
|
-
metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
|
|
268
|
-
metadata_ext.contents.metadata.contents.api_version.major_version = 0
|
|
269
|
-
metadata_ext.contents.metadata.contents.api_version.minor_version = 1
|
|
270
|
-
# Turn on CUDA graphs for this handler.
|
|
271
|
-
metadata_ext.contents.metadata.contents.traits = (
|
|
272
|
-
XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
|
|
273
|
-
)
|
|
274
|
-
return None
|
|
275
|
-
|
|
276
|
-
# retrieve call info
|
|
277
|
-
attrs = decode_attrs(call_frame.contents.attrs)
|
|
278
|
-
launch_id = int(attrs["launch_id"])
|
|
279
|
-
launch_desc = self.launch_descriptors[launch_id]
|
|
280
|
-
|
|
281
|
-
num_inputs = call_frame.contents.args.size
|
|
282
|
-
inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
|
|
283
|
-
|
|
284
|
-
num_outputs = call_frame.contents.rets.size
|
|
285
|
-
outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
|
|
286
|
-
|
|
287
|
-
assert num_inputs == self.num_inputs
|
|
288
|
-
assert num_outputs == self.num_outputs
|
|
289
|
-
|
|
290
|
-
launch_bounds = launch_bounds_t(launch_desc.launch_dims)
|
|
291
|
-
|
|
292
|
-
# first kernel param is the launch bounds
|
|
293
|
-
kernel_params = (ctypes.c_void_p * (1 + self.num_kernel_args))()
|
|
294
|
-
kernel_params[0] = ctypes.addressof(launch_bounds)
|
|
295
|
-
|
|
296
|
-
arg_refs = []
|
|
297
|
-
|
|
298
|
-
# input and in-out args
|
|
299
|
-
for i, input_arg in enumerate(self.input_args):
|
|
300
|
-
if input_arg.is_array:
|
|
301
|
-
buffer = inputs[i].contents
|
|
302
|
-
shape = buffer.dims[: input_arg.type.ndim]
|
|
303
|
-
strides = strides_from_shape(shape, input_arg.type.dtype)
|
|
304
|
-
arg = array_t(buffer.data, 0, input_arg.type.ndim, shape, strides)
|
|
305
|
-
kernel_params[i + 1] = ctypes.addressof(arg)
|
|
306
|
-
arg_refs.append(arg) # keep a reference
|
|
307
|
-
else:
|
|
308
|
-
# scalar argument, get stashed value
|
|
309
|
-
value = launch_desc.static_inputs[input_arg.name]
|
|
310
|
-
arg = input_arg.type._type_(value)
|
|
311
|
-
kernel_params[i + 1] = ctypes.addressof(arg)
|
|
312
|
-
arg_refs.append(arg) # keep a reference
|
|
313
|
-
|
|
314
|
-
# pure output args (skip in-out FFI buffers)
|
|
315
|
-
for i, output_arg in enumerate(self.output_args):
|
|
316
|
-
buffer = outputs[i + self.num_in_out].contents
|
|
317
|
-
shape = buffer.dims[: output_arg.type.ndim]
|
|
318
|
-
strides = strides_from_shape(shape, output_arg.type.dtype)
|
|
319
|
-
arg = array_t(buffer.data, 0, output_arg.type.ndim, shape, strides)
|
|
320
|
-
kernel_params[num_inputs + i + 1] = ctypes.addressof(arg)
|
|
321
|
-
arg_refs.append(arg) # keep a reference
|
|
322
|
-
|
|
323
|
-
# get device and stream
|
|
324
|
-
device = wp.device_from_jax(get_jax_device())
|
|
325
|
-
stream = get_stream_from_callframe(call_frame.contents)
|
|
326
|
-
|
|
327
|
-
# get kernel hooks
|
|
328
|
-
hooks = self.kernel.module.get_kernel_hooks(self.kernel, device)
|
|
329
|
-
assert hooks.forward, "Failed to find kernel entry point"
|
|
330
|
-
|
|
331
|
-
# launch the kernel
|
|
332
|
-
wp.context.runtime.core.wp_cuda_launch_kernel(
|
|
333
|
-
device.context,
|
|
334
|
-
hooks.forward,
|
|
335
|
-
launch_bounds.size,
|
|
336
|
-
0,
|
|
337
|
-
256,
|
|
338
|
-
hooks.forward_smem_bytes,
|
|
339
|
-
kernel_params,
|
|
340
|
-
stream,
|
|
341
|
-
)
|
|
342
|
-
|
|
343
|
-
except Exception as e:
|
|
344
|
-
print(traceback.format_exc())
|
|
345
|
-
return create_ffi_error(
|
|
346
|
-
call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
|
|
347
|
-
)
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
class FfiCallDesc:
|
|
351
|
-
def __init__(self, static_inputs):
|
|
352
|
-
self.static_inputs = static_inputs
|
|
353
|
-
self.captures = {}
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
class FfiCallable:
    """JAX FFI target that calls an annotated Python function with Warp arrays.

    Positional arguments of ``func`` are split into inputs (including in-out
    arguments) followed by pure outputs. At trace time ``__call__`` builds the JAX
    output types and records static scalar inputs. At run time ``ffi_callback``
    rebuilds Warp arrays from the XLA buffers and invokes ``func``, optionally
    under CUDA graph capture depending on ``graph_mode``.
    """

    def __init__(self, func, num_outputs, graph_mode, vmap_method, output_dims, in_out_argnames):
        self.func = func
        self.name = generate_unique_name(func)
        self.num_outputs = num_outputs
        self.vmap_method = vmap_method
        self.graph_mode = graph_mode
        self.output_dims = output_dims
        self.first_array_arg = None
        self.call_id = 0
        self.call_descriptors = {}

        in_out_argnames_list = in_out_argnames or []
        in_out_argnames = set(in_out_argnames_list)
        if len(in_out_argnames_list) != len(in_out_argnames):
            raise AssertionError("in_out_argnames must not contain duplicate names")

        # get arguments and annotations
        argspec = get_full_arg_spec(func)

        num_args = len(argspec.args)
        self.num_in_out = len(in_out_argnames)
        self.num_inputs = num_args - num_outputs + self.num_in_out
        if self.num_outputs < 1:
            raise ValueError("At least one output is required")
        if self.num_outputs > num_args:
            raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
        if self.num_outputs < self.num_in_out:
            raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")

        # Every positional argument needs an annotation. Comparing the number of
        # annotations with the number of arguments is not enough, because the
        # annotations mapping may also contain a "return" entry.
        if any(arg_name not in argspec.annotations for arg_name in argspec.args):
            raise RuntimeError(f"Incomplete argument annotations on function {self.name}")

        if argspec.annotations.get("return") is not None:
            raise TypeError("Function must not return a value")

        # parse type annotations, in argument order
        self.args = []
        for arg_idx, arg_name in enumerate(argspec.args):
            arg = FfiArg(arg_name, argspec.annotations[arg_name], arg_name in in_out_argnames)
            if arg_name in in_out_argnames:
                in_out_argnames.remove(arg_name)
            if arg.is_array and arg_idx < self.num_inputs and self.first_array_arg is None:
                self.first_array_arg = arg_idx
            self.args.append(arg)

            if arg.in_out and arg_idx >= self.num_inputs:
                raise AssertionError(
                    f"Expected an output-only argument for argument {arg_name}."
                    " in_out arguments should be placed before output-only arguments."
                )

        if in_out_argnames:
            raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")

        self.input_args = self.args[: self.num_inputs]  # includes in-out args
        self.output_args = self.args[self.num_inputs :]  # pure output args

        # Buffer indices for array arguments in callback.
        # In-out buffers are the same pointers in the XLA call frame,
        # so we only include them for inputs and skip them for outputs.
        self.array_input_indices = [i for i, arg in enumerate(self.input_args) if arg.is_array]
        self.array_output_indices = list(range(self.num_in_out, self.num_outputs))

        # Build input output aliases: the k-th in-out input aliases the k-th output.
        in_out_input_ids = [in_id for in_id, arg in enumerate(self.input_args) if arg.in_out]
        self.input_output_aliases = {in_id: out_id for out_id, in_id in enumerate(in_out_input_ids)}

        # register the callback
        FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
        self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
        ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p)
        ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
        jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA")

    def __call__(self, *args, output_dims=None, vmap_method=None):
        """Validate ``args``, build output types and issue the JAX FFI call.

        Args:
            *args: One value per input argument (in-out arguments included). Array
                arguments are checked for dtype, rank and inner dimensions; scalar
                arguments must be static (non-traced) values.
            output_dims: Optional per-call override of the default output dimensions.
                Either a dict keyed by output argument name, an int, or a shape.
                When unset, the Warp shape of the first array input is used.
            vmap_method: Optional per-call override of the ``vmap`` method.

        Returns:
            The result of the underlying ``jax.ffi.ffi_call``.
        """
        num_inputs = len(args)
        if num_inputs != self.num_inputs:
            input_names = ", ".join(arg.name for arg in self.input_args)
            s = "" if self.num_inputs == 1 else "s"
            raise ValueError(f"Expected {self.num_inputs} input{s} ({input_names}), but got {num_inputs}")

        # default argument fallback
        if vmap_method is None:
            vmap_method = self.vmap_method
        if output_dims is None:
            output_dims = self.output_dims

        # output types
        out_types = []

        # process inputs
        static_inputs = {}
        for i in range(num_inputs):
            input_arg = self.input_args[i]
            input_value = args[i]
            if input_arg.is_array:
                # check dtype
                if input_value.dtype != input_arg.jax_scalar_type:
                    raise TypeError(
                        f"Invalid data type for array argument '{input_arg.name}', expected {input_arg.jax_scalar_type}, got {input_value.dtype}"
                    )
                # check ndim
                if input_value.ndim != input_arg.jax_ndim:
                    raise TypeError(
                        f"Invalid dimensionality for array argument '{input_arg.name}', expected {input_arg.jax_ndim} dimensions, got {input_value.ndim}"
                    )
                # check inner dims
                for d in range(input_arg.dtype_ndim):
                    if input_value.shape[input_arg.type.ndim + d] != input_arg.dtype_shape[d]:
                        raise TypeError(
                            f"Invalid inner dimensions for array argument '{input_arg.name}', expected {input_arg.dtype_shape}, got {input_value.shape[-input_arg.dtype_ndim :]}"
                        )
            else:
                # make sure scalar is not a traced variable, should be static
                if isinstance(input_value, jax.core.Tracer):
                    raise ValueError(f"Argument '{input_arg.name}' must be a static value")
                # stash the value to be retrieved by callback
                static_inputs[input_arg.name] = input_arg.type(input_value)

            # append in-out arg to output types
            if input_arg.in_out:
                out_types.append(get_jax_output_type(input_arg, input_value.shape))

        # output shapes
        if isinstance(output_dims, dict):
            # assume a dictionary of shapes keyed on argument name
            for output_arg in self.output_args:
                dims = output_dims.get(output_arg.name)
                if dims is None:
                    raise ValueError(f"Missing output dimensions for argument '{output_arg.name}'")
                out_types.append(get_jax_output_type(output_arg, dims))
        else:
            if output_dims is None:
                if self.first_array_arg is None:
                    raise ValueError("Unable to determine output dimensions")
                output_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
            elif isinstance(output_dims, int):
                output_dims = (output_dims,)
            # assume same dimensions for all outputs
            for output_arg in self.output_args:
                out_types.append(get_jax_output_type(output_arg, output_dims))

        call = jax.ffi.ffi_call(
            self.name,
            out_types,
            vmap_method=vmap_method,
            input_output_aliases=self.input_output_aliases,
            # has_side_effect=True,  # force this function to execute even if outputs aren't used
        )

        # load the module
        # NOTE: if the target function uses kernels from different modules, they will not be loaded here
        device = wp.device_from_jax(get_jax_device())
        module = wp.get_module(self.func.__module__)
        module.load(device)

        # save call data to be retrieved by callback
        call_id = self.call_id
        self.call_descriptors[call_id] = FfiCallDesc(static_inputs)
        self.call_id += 1
        return call(*args, call_id=call_id)

    def ffi_callback(self, call_frame):
        """XLA FFI entry point: rebuild Warp arguments from the call frame and run ``self.func``.

        Returns ``None`` on success, or the error object produced by
        ``create_ffi_error`` if any exception is raised.
        """
        try:
            # On the first call, XLA runtime will query the API version and traits
            # metadata using the |extension| field. Let us respond to that query
            # if the metadata extension is present.
            extension = call_frame.contents.extension_start
            if extension:
                # Try to set the version metadata.
                if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
                    metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
                    metadata_ext.contents.metadata.contents.api_version.major_version = 0
                    metadata_ext.contents.metadata.contents.api_version.minor_version = 1
                    # Turn on CUDA graphs for this handler.
                    if self.graph_mode is GraphMode.JAX:
                        metadata_ext.contents.metadata.contents.traits = (
                            XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
                        )
                return None

            # retrieve call info
            # NOTE: this assumes that there's only one attribute - call_id (int64).
            # A more general but slower approach is this:
            # attrs = decode_attrs(call_frame.contents.attrs)
            # call_id = int(attrs["call_id"])
            attr = ctypes.cast(call_frame.contents.attrs.attrs[0], ctypes.POINTER(XLA_FFI_Scalar)).contents
            call_id = ctypes.cast(attr.value, ctypes.POINTER(ctypes.c_int64)).contents.value
            call_desc = self.call_descriptors[call_id]

            num_inputs = call_frame.contents.args.size
            inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))

            num_outputs = call_frame.contents.rets.size
            outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))

            # Explicit checks instead of `assert`, which is stripped under `python -O`.
            if num_inputs != self.num_inputs:
                raise RuntimeError(f"Expected {self.num_inputs} FFI input buffers, got {num_inputs}")
            if num_outputs != self.num_outputs:
                raise RuntimeError(f"Expected {self.num_outputs} FFI output buffers, got {num_outputs}")

            cuda_stream = get_stream_from_callframe(call_frame.contents)

            if self.graph_mode == GraphMode.WARP:
                # check if we already captured an identical call
                ip = [inputs[i].contents.data for i in self.array_input_indices]
                op = [outputs[i].contents.data for i in self.array_output_indices]
                # Key on the pointer tuple itself (not its hash) so that a hash collision
                # can never replay a graph recorded for different buffers.
                buffer_key = (*ip, *op)
                capture = call_desc.captures.get(buffer_key)

                # launch existing graph
                if capture is not None:
                    # NOTE: We use the native graph API to avoid overhead with obtaining Stream and Device objects in Python.
                    # This code should match wp.capture_launch().
                    graph = capture.graph
                    if graph.graph_exec is None:
                        g = ctypes.c_void_p()
                        if not wp.context.runtime.core.wp_cuda_graph_create_exec(
                            graph.device.context, cuda_stream, graph.graph, ctypes.byref(g)
                        ):
                            raise RuntimeError(f"Graph creation error: {wp.context.runtime.get_error_string()}")
                        graph.graph_exec = g

                    if not wp.context.runtime.core.wp_cuda_graph_launch(graph.graph_exec, cuda_stream):
                        raise RuntimeError(f"Graph launch error: {wp.context.runtime.get_error_string()}")

                    # early out
                    return None

            device = wp.device_from_jax(get_jax_device())
            stream = wp.Stream(device, cuda_stream=cuda_stream)

            # reconstruct the argument list
            arg_list = []

            # input and in-out args
            for i, arg in enumerate(self.input_args):
                if arg.is_array:
                    buffer = inputs[i].contents
                    shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
                    arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
                    arg_list.append(arr)
                else:
                    # scalar argument, get stashed value
                    value = call_desc.static_inputs[arg.name]
                    arg_list.append(value)

            # pure output args (skip in-out FFI buffers)
            for i, arg in enumerate(self.output_args):
                buffer = outputs[i + self.num_in_out].contents
                shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
                arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
                arg_list.append(arr)

            # call the Python function with reconstructed arguments
            with wp.ScopedStream(stream, sync_enter=False):
                if stream.is_capturing:
                    # capturing with JAX
                    with wp.ScopedCapture(external=True) as capture:
                        self.func(*arg_list)
                    # keep a reference to the capture object to prevent required modules getting unloaded
                    call_desc.capture = capture
                elif self.graph_mode == GraphMode.WARP:
                    # capturing with WARP
                    with wp.ScopedCapture() as capture:
                        self.func(*arg_list)
                    wp.capture_launch(capture.graph)
                    # keep a reference to the capture object and reuse it with same buffers
                    call_desc.captures[buffer_key] = capture
                else:
                    # not capturing
                    self.func(*arg_list)

        except Exception as e:
            print(traceback.format_exc())
            return create_ffi_error(
                call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
            )

        return None
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
# Holders for the custom callbacks to keep them alive.
# `_FFI_CALLABLE_REGISTRY` holds `FfiCallable` objects keyed by the tuple built in
# `jax_callable()`, and raw ctypes callbacks keyed by the target name passed to
# `register_ffi_callback()`.
_FFI_CALLABLE_REGISTRY: dict[object, Callable] = {}
# `FfiKernel` objects keyed by the tuple built in `jax_kernel()`.
_FFI_KERNEL_REGISTRY: dict[tuple, FfiKernel] = {}
# Guards access to both registries.
_FFI_REGISTRY_LOCK = threading.Lock()
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
def jax_kernel(
    kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=None, output_dims=None, in_out_argnames=None
):
    """Create a JAX callback from a Warp kernel.

    NOTE: This is an experimental feature under development.

    Wrappers are cached: calling this again with an equivalent configuration
    returns the same object.

    Args:
        kernel: The Warp kernel to launch.
        num_outputs: Optional. Specify the number of output arguments if greater than 1.
            This must include the number of ``in_out_argnames``.
        vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
            This argument can also be specified for individual calls.
        launch_dims: Optional. Specify the default kernel launch dimensions. If None, launch
            dimensions are inferred from the shape of the first array argument.
            This argument can also be specified for individual calls.
        output_dims: Optional. Specify the default dimensions of output arrays. If None, output
            dimensions are inferred from the launch dimensions.
            This argument can also be specified for individual calls.
        in_out_argnames: Optional. Names of input-output arguments.

    Limitations:
        - All kernel arguments must be contiguous arrays or scalars.
        - Scalars must be static arguments in JAX.
        - Input and input-output arguments must precede the output arguments in the ``kernel`` definition.
        - There must be at least one output or input-output argument.
        - Only the CUDA backend is supported.
    """

    check_jax_version()

    def _freeze(value):
        # Turn (nested) dicts/sets/lists/tuples into hashable values so that any accepted
        # form of launch_dims/output_dims/in_out_argnames (int, shape, list, dict) can
        # be part of the registry key.
        if isinstance(value, dict):
            items = sorted(value.items(), key=lambda item: str(item[0]))
            return (dict, tuple((k, _freeze(v)) for k, v in items))
        if isinstance(value, (set, frozenset)):
            return frozenset(_freeze(v) for v in value)
        if isinstance(value, (list, tuple)):
            return tuple(_freeze(v) for v in value)
        return value

    # in_out_argnames must be part of the key: it changes how arguments are treated.
    key = (
        kernel.func,
        kernel.sig,
        num_outputs,
        vmap_method,
        _freeze(launch_dims),
        _freeze(output_dims),
        _freeze(in_out_argnames),
    )

    with _FFI_REGISTRY_LOCK:
        if key not in _FFI_KERNEL_REGISTRY:
            new_kernel = FfiKernel(kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames)
            _FFI_KERNEL_REGISTRY[key] = new_kernel

        return _FFI_KERNEL_REGISTRY[key]
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
def jax_callable(
    func: Callable,
    num_outputs: int = 1,
    graph_compatible: Optional[bool] = None,  # deprecated
    graph_mode: GraphMode = GraphMode.JAX,
    vmap_method: Optional[str] = "broadcast_all",
    output_dims=None,
    in_out_argnames=None,
):
    """Create a JAX callback from an annotated Python function.

    The Python function arguments must have type annotations like Warp kernels.

    NOTE: This is an experimental feature under development.

    Wrappers are cached: calling this again with an equivalent configuration
    returns the same object.

    Args:
        func: The Python function to call.
        num_outputs: Optional. Specify the number of output arguments if greater than 1.
            This must include the number of ``in_out_argnames``.
        graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
            This argument is deprecated, use ``graph_mode`` instead.
        graph_mode: Optional. CUDA graph capture mode.
            ``GraphMode.JAX`` (default): Let JAX capture the graph, which may be used as a subgraph in an enclosing capture.
            ``GraphMode.WARP``: Let Warp capture the graph. Use this mode when the callable cannot be used as a subgraph,
            such as when the callable uses conditional graph nodes.
            ``GraphMode.NONE``: Disable graph capture. Use when the callable performs operations that are not legal in a graph,
            such as host synchronization.
        vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
            This argument can also be specified for individual calls.
        output_dims: Optional. Specify the default dimensions of output arrays.
            If ``None``, output dimensions are inferred from the shape of the first array input argument.
            This argument can also be specified for individual calls.
        in_out_argnames: Optional. Names of input-output arguments.

    Limitations:
        - All kernel arguments must be contiguous arrays or scalars.
        - Scalars must be static arguments in JAX.
        - Input and input-output arguments must precede the output arguments in the ``func`` definition.
        - There must be at least one output or input-output argument.
        - Only the CUDA backend is supported.
    """

    check_jax_version()

    if graph_compatible is not None:
        wp.utils.warn(
            "The `graph_compatible` argument is deprecated, use `graph_mode` instead.",
            DeprecationWarning,
            stacklevel=3,
        )
        if graph_compatible is False:
            graph_mode = GraphMode.NONE

    def _freeze(value):
        # Turn (nested) dicts/sets/lists/tuples into hashable values so that any accepted
        # form of output_dims/in_out_argnames (int, shape, list, dict) can be part of
        # the registry key.
        if isinstance(value, dict):
            items = sorted(value.items(), key=lambda item: str(item[0]))
            return (dict, tuple((k, _freeze(v)) for k, v in items))
        if isinstance(value, (set, frozenset)):
            return frozenset(_freeze(v) for v in value)
        if isinstance(value, (list, tuple)):
            return tuple(_freeze(v) for v in value)
        return value

    # in_out_argnames must be part of the key: it changes how arguments are treated.
    key = (
        func,
        num_outputs,
        graph_mode,
        vmap_method,
        _freeze(output_dims),
        _freeze(in_out_argnames),
    )

    with _FFI_REGISTRY_LOCK:
        if key not in _FFI_CALLABLE_REGISTRY:
            new_callable = FfiCallable(func, num_outputs, graph_mode, vmap_method, output_dims, in_out_argnames)
            _FFI_CALLABLE_REGISTRY[key] = new_callable

        return _FFI_CALLABLE_REGISTRY[key]
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
###############################################################################
|
|
773
|
-
#
|
|
774
|
-
# Generic FFI callbacks for Python functions of the form
|
|
775
|
-
# func(inputs, outputs, attrs, ctx)
|
|
776
|
-
#
|
|
777
|
-
###############################################################################
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = True) -> None:
    """Register a Python function as the CUDA JAX FFI target ``name``.

    At run time the function is called as ``func(inputs, outputs, attrs, ctx)``:
    ``inputs``/``outputs`` are lists of ``FfiBuffer`` wrappers around the XLA
    buffers, ``attrs`` holds the decoded call attributes, and ``ctx`` is an
    ``ExecutionContext`` built from the call frame.

    NOTE: This is an experimental feature under development.

    Args:
        name: A unique FFI callback name.
        func: The Python function to call.
        graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
    """

    check_jax_version()

    # TODO check that the name is not already registered

    def ffi_callback(call_frame):
        try:
            ext = call_frame.contents.extension_start
            if ext:
                # First call: XLA queries API version and handler traits through the
                # metadata extension; answer it and return without running `func`.
                if ext.contents.type == XLA_FFI_Extension_Type.Metadata:
                    md = ctypes.cast(ext, ctypes.POINTER(XLA_FFI_Metadata_Extension)).contents.metadata.contents
                    md.api_version.major_version = 0
                    md.api_version.minor_version = 1
                    if graph_compatible:
                        # Turn on CUDA graphs for this handler.
                        md.traits = XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
                return None

            frame = call_frame.contents
            decoded_attrs = decode_attrs(frame.attrs)

            buffer_array_type = ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer))
            raw_in = ctypes.cast(frame.args.args, buffer_array_type)
            in_buffers = [FfiBuffer(raw_in[k].contents) for k in range(frame.args.size)]
            raw_out = ctypes.cast(frame.rets.rets, buffer_array_type)
            out_buffers = [FfiBuffer(raw_out[k].contents) for k in range(frame.rets.size)]

            func(in_buffers, out_buffers, decoded_attrs, ExecutionContext(frame))
        except Exception as e:
            print(traceback.format_exc())
            return create_ffi_error(
                call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
            )

        return None

    callback_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
    callback_func = callback_type(ffi_callback)
    # The registry holds the ctypes callback so it is not garbage-collected while XLA uses it.
    with _FFI_REGISTRY_LOCK:
        _FFI_CALLABLE_REGISTRY[name] = callback_func
    callback_address = ctypes.cast(callback_func, ctypes.c_void_p).value
    jax.ffi.register_ffi_target(name, jax.ffi.pycapsule(callback_address), platform="CUDA")
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
###############################################################################
|
|
847
|
-
#
|
|
848
|
-
# Utilities
|
|
849
|
-
#
|
|
850
|
-
###############################################################################
|
|
851
|
-
|
|
852
|
-
# ensure unique FFI callback names
ffi_name_counts = {}


def generate_unique_name(func) -> str:
    """Return ``"<qualified name of func>_<n>"``, where ``n`` counts earlier names issued for that qualified name."""
    base_name = make_full_qualified_name(func)
    suffix = ffi_name_counts.setdefault(base_name, 0)
    ffi_name_counts[base_name] = suffix + 1
    return f"{base_name}_{suffix}"
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
def get_warp_shape(arg, dims):
    """Convert a JAX shape into the corresponding Warp array shape for ``arg``.

    Scalar dtypes: ``dims`` is returned unchanged. Vector/matrix dtypes: only the
    first ``arg.warp_ndim`` entries are kept, dropping the inner dtype dimensions.
    """
    # scalar array: JAX and Warp shapes are the same
    if arg.dtype_ndim <= 0:
        return dims
    # vector/matrix array: strip the trailing dtype dimensions
    return dims[: arg.warp_ndim]
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
def get_jax_output_type(arg, dims):
    """Build the ``jax.ShapeDtypeStruct`` describing an FFI output buffer for ``arg``.

    Args:
        arg: Argument descriptor; reads ``name``, ``dtype_ndim``, ``dtype_shape``,
            ``warp_ndim``, ``jax_ndim`` and ``jax_scalar_type``.
        dims: Output dimensions: a single integer (any integral scalar, e.g.
            ``np.int64``) or a sequence. For vector/matrix dtypes this may be either
            the Warp shape (outer dimensions only) or the full JAX shape that also
            includes the inner dtype dimensions.

    Returns:
        A ``jax.ShapeDtypeStruct`` with the full JAX shape and ``arg.jax_scalar_type``.

    Raises:
        ValueError: If the number of dimensions does not fit ``arg`` or the inner
            dimensions do not match ``arg.dtype_shape``.
    """
    import numbers

    # Accept any integral scalar (NumPy integers from shape arithmetic included),
    # not only Python `int`; normalize sequences to a tuple.
    if isinstance(dims, numbers.Integral):
        dims = (int(dims),)
    else:
        dims = tuple(dims)

    ndim = len(dims)

    if arg.dtype_ndim > 0:
        # vector/matrix array
        if ndim == arg.warp_ndim:
            # Warp shape given: append the inner dtype dimensions
            return jax.ShapeDtypeStruct((*dims, *arg.dtype_shape), arg.jax_scalar_type)
        elif ndim == arg.jax_ndim:
            # full JAX shape given: make sure inner dimensions match
            inner_dims = dims[-arg.dtype_ndim :]
            for i in range(arg.dtype_ndim):
                if inner_dims[i] != arg.dtype_shape[i]:
                    raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
            return jax.ShapeDtypeStruct(dims, arg.jax_scalar_type)
        else:
            raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
    else:
        # scalar array
        if ndim != arg.warp_ndim:
            raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
        return jax.ShapeDtypeStruct(dims, arg.jax_scalar_type)
|
|
39
|
+
return get_deprecated_api(_ffi, "wp.jax_experimental", name)
|