warp-lang 1.6.2__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +7 -1
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +410 -0
- warp/build_dll.py +6 -14
- warp/builtins.py +452 -362
- warp/codegen.py +179 -119
- warp/config.py +42 -6
- warp/context.py +490 -271
- warp/dlpack.py +8 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +2 -2
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_magnetostatics.py +6 -6
- warp/examples/fem/utils.py +9 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_matmul.py +2 -4
- warp/fem/__init__.py +11 -1
- warp/fem/adaptivity.py +4 -4
- warp/fem/field/nodal_field.py +22 -68
- warp/fem/field/virtual.py +62 -23
- warp/fem/geometry/adaptive_nanogrid.py +9 -10
- warp/fem/geometry/closest_point.py +1 -1
- warp/fem/geometry/deformed_geometry.py +5 -2
- warp/fem/geometry/geometry.py +5 -0
- warp/fem/geometry/grid_2d.py +12 -12
- warp/fem/geometry/grid_3d.py +12 -15
- warp/fem/geometry/hexmesh.py +5 -7
- warp/fem/geometry/nanogrid.py +9 -11
- warp/fem/geometry/quadmesh.py +13 -13
- warp/fem/geometry/tetmesh.py +3 -4
- warp/fem/geometry/trimesh.py +3 -8
- warp/fem/integrate.py +262 -93
- warp/fem/linalg.py +5 -5
- warp/fem/quadrature/pic_quadrature.py +37 -22
- warp/fem/quadrature/quadrature.py +194 -25
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +4 -2
- warp/fem/space/basis_space.py +25 -18
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +6 -2
- warp/fem/space/quadmesh_function_space.py +8 -8
- warp/fem/space/shape/cube_shape_function.py +23 -23
- warp/fem/space/shape/square_shape_function.py +12 -12
- warp/fem/space/shape/triangle_shape_function.py +1 -1
- warp/fem/space/tetmesh_function_space.py +3 -3
- warp/fem/space/trimesh_function_space.py +2 -2
- warp/fem/utils.py +12 -6
- warp/jax.py +14 -1
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +89 -0
- warp/native/array.h +13 -0
- warp/native/builtin.h +29 -3
- warp/native/bvh.cpp +3 -1
- warp/native/bvh.cu +42 -14
- warp/native/bvh.h +2 -1
- warp/native/clang/clang.cpp +30 -3
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/exports.h +68 -63
- warp/native/intersect.h +26 -26
- warp/native/intersect_adj.h +33 -33
- warp/native/marching.cu +1 -1
- warp/native/mat.h +513 -9
- warp/native/mesh.h +10 -10
- warp/native/quat.h +99 -11
- warp/native/rand.h +6 -0
- warp/native/sort.cpp +122 -59
- warp/native/sort.cu +152 -15
- warp/native/sort.h +8 -1
- warp/native/sparse.cpp +43 -22
- warp/native/sparse.cu +52 -17
- warp/native/svd.h +116 -0
- warp/native/tile.h +301 -105
- warp/native/tile_reduce.h +46 -3
- warp/native/vec.h +68 -7
- warp/native/volume.cpp +85 -113
- warp/native/volume_builder.cu +25 -10
- warp/native/volume_builder.h +6 -0
- warp/native/warp.cpp +5 -6
- warp/native/warp.cu +99 -10
- warp/native/warp.h +19 -10
- warp/optim/linear.py +10 -10
- warp/sim/articulation.py +4 -4
- warp/sim/collide.py +21 -10
- warp/sim/import_mjcf.py +449 -155
- warp/sim/import_urdf.py +32 -12
- warp/sim/integrator_euler.py +5 -5
- warp/sim/integrator_featherstone.py +3 -10
- warp/sim/integrator_vbd.py +207 -2
- warp/sim/integrator_xpbd.py +5 -5
- warp/sim/model.py +42 -13
- warp/sim/utils.py +2 -2
- warp/sparse.py +642 -555
- warp/stubs.py +216 -19
- warp/tests/__main__.py +0 -15
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
- warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
- warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
- warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
- warp/tests/{test_model.py → sim/test_model.py} +40 -0
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_bool.py +1 -1
- warp/tests/test_examples.py +28 -36
- warp/tests/test_fem.py +23 -4
- warp/tests/test_linear_solvers.py +0 -11
- warp/tests/test_mat.py +233 -79
- warp/tests/test_mat_scalar_ops.py +4 -4
- warp/tests/test_overwrite.py +0 -60
- warp/tests/test_quat.py +67 -46
- warp/tests/test_rand.py +44 -37
- warp/tests/test_sparse.py +47 -6
- warp/tests/test_spatial.py +75 -0
- warp/tests/test_static.py +1 -1
- warp/tests/test_utils.py +84 -4
- warp/tests/test_vec.py +46 -34
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
- warp/tests/unittest_serial.py +1 -0
- warp/tests/unittest_suites.py +45 -59
- warp/tests/unittest_utils.py +2 -1
- warp/thirdparty/unittest_parallel.py +3 -1
- warp/types.py +110 -658
- warp/utils.py +137 -72
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp/examples/optim/example_walker.py +0 -317
- warp/native/cutlass_gemm.cpp +0 -43
- warp/native/cutlass_gemm.cu +0 -382
- warp/tests/test_matmul.py +0 -511
- warp/tests/test_matmul_lite.py +0 -411
- warp/tests/test_vbd.py +0 -386
- warp/tests/unused_test_misc.py +0 -77
- /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
- /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
- /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
- /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
- /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
- /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
- /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
- /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
- /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
- /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
- /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
- /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
- /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
- /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
- /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
- /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
- /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
- /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,698 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import ctypes
|
|
17
|
+
import traceback
|
|
18
|
+
from typing import Callable
|
|
19
|
+
|
|
20
|
+
import jax
|
|
21
|
+
|
|
22
|
+
import warp as wp
|
|
23
|
+
from warp.codegen import get_full_arg_spec, make_full_qualified_name
|
|
24
|
+
from warp.jax import get_jax_device
|
|
25
|
+
from warp.types import array_t, launch_bounds_t, strides_from_shape, type_to_warp
|
|
26
|
+
|
|
27
|
+
from .xla_ffi import *
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=None, output_dims=None):
    """Wrap a Warp kernel so it can be invoked from JAX as an XLA FFI call.

    NOTE: This is an experimental feature under development.

    Args:
        kernel: The Warp kernel to launch.
        num_outputs: Optional. Number of trailing kernel arguments treated as outputs
            (default 1). Specify it when the kernel has more than one output.
        vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
            This argument can also be specified for individual calls.
        launch_dims: Optional. Default kernel launch dimensions (an int or a sequence of ints).
            If None, launch dimensions are inferred from the shape of the first array argument.
            This argument can also be specified for individual calls.
        output_dims: Optional. Default dimensions of the output arrays: either a single shape
            shared by every output, or a dict mapping output argument names to shapes.
            If None, output dimensions are inferred from the launch dimensions.
            This argument can also be specified for individual calls.

    Returns:
        An ``FfiKernel`` object; call it with the kernel's input arguments.

    Limitations:
        - All kernel arguments must be contiguous arrays or scalars.
        - Scalars must be static arguments in JAX.
        - Input arguments are followed by output arguments in the Warp kernel definition.
        - There must be at least one output argument.
        - Only the CUDA backend is supported.
    """
    return FfiKernel(
        kernel=kernel,
        num_outputs=num_outputs,
        vmap_method=vmap_method,
        launch_dims=launch_dims,
        output_dims=output_dims,
    )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def jax_callable(
    func: Callable,
    num_outputs: int = 1,
    graph_compatible: bool = True,
    vmap_method: str = "broadcast_all",
    output_dims=None,
):
    """Create a JAX callback from an annotated Python function.

    The Python function arguments must have type annotations like Warp kernels.

    NOTE: This is an experimental feature under development.

    Args:
        func: The Python function to call.
        num_outputs: Optional. Number of trailing function arguments treated as outputs
            (default 1). Specify it when the function has more than one output.
        graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
            When True, the handler advertises itself to XLA as command-buffer compatible.
        vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
            This argument can also be specified for individual calls.
        output_dims: Optional. Default dimensions of the output arrays: either a single shape
            shared by every output, or a dict mapping output argument names to shapes.
            If ``None``, output dimensions are inferred from the shape of the first input
            array argument; if there is no input array, they must be given explicitly.
            This argument can also be specified for individual calls.

    Returns:
        An ``FfiCallable`` object; call it with the function's input arguments.

    Limitations:
        - All function arguments must be contiguous arrays or scalars.
        - Scalars must be static arguments in JAX.
        - Input arguments are followed by output arguments in the function definition.
        - There must be at least one output argument.
        - Only the CUDA backend is supported.
    """

    return FfiCallable(func, num_outputs, graph_compatible, vmap_method, output_dims)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class FfiArg:
    """Description of one argument of a Warp kernel or annotated function, as seen by JAX.

    Attributes set by the constructor:
        name: argument name.
        type: the Warp annotation (an ``wp.array`` instance or a scalar value type).
        is_array: True when ``type`` is a ``wp.array``.
        dtype_shape / dtype_ndim: shape and rank of the array element (``()`` / 0 for scalars).
        jax_scalar_type: JAX dtype of the underlying scalar.
        jax_ndim: rank of the JAX array (array dims plus element dims).
        warp_ndim: rank of the Warp array (0 for scalar arguments).
    """

    def __init__(self, name, type):
        self.name = name
        self.type = type
        self.is_array = isinstance(type, wp.array)

        if not self.is_array:
            # scalar argument: must be one of Warp's value types
            if type not in wp.types.value_types:
                raise TypeError(f"Invalid type for argument '{name}', expected array or scalar, got {type}")
            self.dtype_ndim = 0
            self.dtype_shape = ()
            self.jax_scalar_type = wp.dtype_to_jax(type_to_warp(type))
            self.jax_ndim = 0
            self.warp_ndim = 0
            return

        elem = type.dtype
        if hasattr(elem, "_wp_scalar_type_"):
            # vector/matrix element: its shape becomes trailing JAX dimensions
            self.dtype_shape = elem._shape_
            self.dtype_ndim = len(self.dtype_shape)
            self.jax_scalar_type = wp.dtype_to_jax(elem._wp_scalar_type_)
        elif elem in wp.types.value_types:
            self.dtype_ndim = 0
            self.dtype_shape = ()
            self.jax_scalar_type = wp.dtype_to_jax(elem)
        else:
            raise TypeError(f"Invalid data type for array argument '{name}', expected scalar, vector, or matrix")

        self.jax_ndim = type.ndim + self.dtype_ndim
        self.warp_ndim = type.ndim
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class FfiLaunchDesc:
    """Per-call launch data recorded by ``FfiKernel.__call__`` and read back in its FFI callback."""

    def __init__(self, static_inputs, launch_dims):
        # static_inputs: static scalar values keyed on argument name
        # launch_dims: kernel launch dimensions for this call
        self.static_inputs, self.launch_dims = static_inputs, launch_dims
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class FfiKernel:
    """A Warp kernel exposed to JAX as a CUDA XLA FFI target (created by ``jax_kernel``).

    The first ``num_inputs`` kernel arguments are inputs and the remaining ``num_outputs``
    arguments are outputs. Calling an instance issues a ``jax.ffi.ffi_call``. XLA later invokes
    ``ffi_callback``, which launches the kernel on the stream taken from the XLA call frame.
    """

    def __init__(self, kernel, num_outputs, vmap_method, launch_dims, output_dims):
        """Inspect the kernel signature and register the FFI target with JAX.

        Raises:
            ValueError: if ``num_outputs`` is less than 1 or exceeds the number of kernel arguments.
            TypeError: if an argument type is unsupported or an output argument is not an array.
        """
        self.kernel = kernel
        self.name = generate_unique_name(kernel.func)
        self.num_outputs = num_outputs
        self.vmap_method = vmap_method
        self.launch_dims = launch_dims
        self.output_dims = output_dims
        # index of the first input array argument, used to infer launch dimensions
        self.first_array_arg = None
        # counter and table used to hand per-call data to the callback (keyed on launch_id)
        self.launch_id = 0
        self.launch_descriptors = {}

        self.num_kernel_args = len(kernel.adj.args)
        self.num_inputs = self.num_kernel_args - num_outputs
        if self.num_outputs < 1:
            raise ValueError("At least one output is required")
        if self.num_outputs > self.num_kernel_args:
            raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")

        # process input args
        self.input_args = []
        for i in range(self.num_inputs):
            arg = FfiArg(kernel.adj.args[i].label, kernel.adj.args[i].type)
            if arg.is_array:
                # keep track of the first input array argument
                if self.first_array_arg is None:
                    self.first_array_arg = i
            self.input_args.append(arg)

        # process output args
        self.output_args = []
        for i in range(self.num_inputs, self.num_kernel_args):
            arg = FfiArg(kernel.adj.args[i].label, kernel.adj.args[i].type)
            if not arg.is_array:
                raise TypeError("All output arguments must be arrays")
            self.output_args.append(arg)

        # register the callback; the ctypes trampoline is kept on self so it is not garbage collected
        FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
        self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
        ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p)
        ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
        jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA")

    def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None):
        """Validate inputs and issue the JAX FFI call for this kernel.

        Args:
            *args: The kernel input arguments (JAX arrays for array arguments, static values for scalars).
            output_dims: Optional per-call override of the output dimensions (int, shape, or dict
                keyed on output argument name).
            launch_dims: Optional per-call override of the launch dimensions (int or shape).
            vmap_method: Optional per-call override of the ``vmap()`` behavior.

        Returns:
            The result of the ``jax.ffi.ffi_call`` invocation (the kernel outputs).

        Raises:
            ValueError: on a wrong number of inputs, a traced scalar, or missing output dimensions.
            TypeError: on array dtype/rank/inner-shape mismatches.
            RuntimeError: if launch dimensions cannot be inferred.
        """
        num_inputs = len(args)
        if num_inputs != self.num_inputs:
            raise ValueError(f"Expected {self.num_inputs} inputs, but got {num_inputs}")

        # default argument fallback
        if launch_dims is None:
            launch_dims = self.launch_dims
        if output_dims is None:
            output_dims = self.output_dims
        if vmap_method is None:
            vmap_method = self.vmap_method

        # process inputs
        static_inputs = {}
        for i in range(num_inputs):
            input_arg = self.input_args[i]
            input_value = args[i]
            if input_arg.is_array:
                # check dtype
                if input_value.dtype != input_arg.jax_scalar_type:
                    raise TypeError(
                        f"Invalid data type for array argument '{input_arg.name}', expected {input_arg.jax_scalar_type}, got {input_value.dtype}"
                    )
                # check ndim
                if input_value.ndim != input_arg.jax_ndim:
                    raise TypeError(
                        f"Invalid dimensionality for array argument '{input_arg.name}', expected {input_arg.jax_ndim} dimensions, got {input_value.ndim}"
                    )
                # check inner dims (vector/matrix element shape)
                for d in range(input_arg.dtype_ndim):
                    if input_value.shape[input_arg.type.ndim + d] != input_arg.dtype_shape[d]:
                        raise TypeError(
                            f"Invalid inner dimensions for array argument '{input_arg.name}', expected {input_arg.dtype_shape}, got {input_value.shape[-input_arg.dtype_ndim :]}"
                        )
            else:
                # make sure scalar is not a traced variable, should be static
                if isinstance(input_value, jax.core.Tracer):
                    raise ValueError(f"Argument '{input_arg.name}' must be a static value")
                # stash the value to be retrieved by callback
                static_inputs[input_arg.name] = input_arg.type(input_value)

        # launch dimensions
        if launch_dims is None:
            # use the shape of the first input array
            if self.first_array_arg is not None:
                launch_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
            else:
                raise RuntimeError("Failed to determine launch dimensions")
        elif isinstance(launch_dims, int):
            launch_dims = (launch_dims,)
        else:
            launch_dims = tuple(launch_dims)

        # output types
        out_types = []
        if isinstance(output_dims, dict):
            # assume a dictionary of shapes keyed on argument name
            for output_arg in self.output_args:
                dims = output_dims.get(output_arg.name)
                if dims is None:
                    raise ValueError(f"Missing output dimensions for argument '{output_arg.name}'")
                out_types.append(get_jax_output_type(output_arg, dims))
        else:
            if output_dims is None:
                # use launch dimensions
                output_dims = launch_dims
            elif isinstance(output_dims, int):
                output_dims = (output_dims,)
            # assume same dimensions for all outputs
            for output_arg in self.output_args:
                out_types.append(get_jax_output_type(output_arg, output_dims))

        call = jax.ffi.ffi_call(
            self.name,
            out_types,
            vmap_method=vmap_method,
        )

        # ensure the kernel module is loaded before the callback, otherwise graph capture may fail
        device = wp.device_from_jax(get_jax_device())
        self.kernel.module.load(device)

        # save launch data to be retrieved by callback
        launch_id = self.launch_id
        self.launch_descriptors[launch_id] = FfiLaunchDesc(static_inputs, launch_dims)
        self.launch_id += 1

        return call(*args, launch_id=launch_id)

    def ffi_callback(self, call_frame):
        """Entry point invoked by XLA through the ctypes trampoline registered in ``__init__``.

        Either answers the XLA metadata query (API version and traits) or launches the kernel
        on the CUDA stream from the call frame.

        Returns:
            None on success, otherwise the error object produced by ``create_ffi_error``.
        """
        try:
            # On the first call, XLA runtime will query the API version and traits
            # metadata using the |extension| field. Let us respond to that query
            # if the metadata extension is present.
            extension = call_frame.contents.extension_start
            if extension:
                # Try to set the version metadata.
                if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
                    metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
                    metadata_ext.contents.metadata.contents.api_version.major_version = 0
                    metadata_ext.contents.metadata.contents.api_version.minor_version = 1
                    # Turn on CUDA graphs for this handler.
                    metadata_ext.contents.metadata.contents.traits = (
                        XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
                    )
                    return None

            # retrieve call info
            attrs = decode_attrs(call_frame.contents.attrs)
            launch_id = int(attrs["launch_id"])
            launch_desc = self.launch_descriptors[launch_id]

            num_inputs = call_frame.contents.args.size
            inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))

            num_outputs = call_frame.contents.rets.size
            outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))

            # explicit checks (not asserts) so they still run under `python -O`; any exception
            # raised here is reported to XLA through the handler below
            if num_inputs != self.num_inputs:
                raise RuntimeError(f"Expected {self.num_inputs} input buffers, but got {num_inputs}")
            if num_outputs != self.num_outputs:
                raise RuntimeError(f"Expected {self.num_outputs} output buffers, but got {num_outputs}")

            launch_bounds = launch_bounds_t(launch_desc.launch_dims)

            # first kernel param is the launch bounds
            kernel_params = (ctypes.c_void_p * (1 + self.num_kernel_args))()
            kernel_params[0] = ctypes.addressof(launch_bounds)

            # ctypes objects whose addresses are stored in kernel_params must stay alive until launch
            arg_refs = []

            # inputs
            for i in range(num_inputs):
                input_arg = self.input_args[i]
                if input_arg.is_array:
                    buffer = inputs[i].contents
                    shape = buffer.dims[: input_arg.type.ndim]
                    strides = strides_from_shape(shape, input_arg.type.dtype)
                    arg = array_t(buffer.data, 0, input_arg.type.ndim, shape, strides)
                    kernel_params[i + 1] = ctypes.addressof(arg)
                    arg_refs.append(arg)  # keep a reference
                else:
                    # scalar argument, get stashed value
                    value = launch_desc.static_inputs[input_arg.name]
                    arg = input_arg.type._type_(value)
                    kernel_params[i + 1] = ctypes.addressof(arg)
                    arg_refs.append(arg)  # keep a reference

            # outputs
            for i in range(num_outputs):
                output_arg = self.output_args[i]
                buffer = outputs[i].contents
                shape = buffer.dims[: output_arg.type.ndim]
                strides = strides_from_shape(shape, output_arg.type.dtype)
                arg = array_t(buffer.data, 0, output_arg.type.ndim, shape, strides)
                kernel_params[num_inputs + i + 1] = ctypes.addressof(arg)
                arg_refs.append(arg)  # keep a reference

            # get device and stream
            device = wp.device_from_jax(get_jax_device())
            stream = get_stream_from_callframe(call_frame.contents)

            # get kernel hooks
            hooks = self.kernel.module.get_kernel_hooks(self.kernel, device)
            if not hooks.forward:
                raise RuntimeError("Failed to find kernel entry point")

            # launch the kernel
            wp.context.runtime.core.cuda_launch_kernel(
                device.context,
                hooks.forward,
                launch_bounds.size,
                0,
                256,
                hooks.forward_smem_bytes,
                kernel_params,
                stream,
            )

        except Exception as e:
            print(traceback.format_exc())
            return create_ffi_error(
                call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
            )
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
class FfiCallDesc:
    """Per-call data recorded by ``FfiCallable.__call__`` and read back in its FFI callback."""

    def __init__(self, static_inputs):
        # static scalar values for this call, keyed on argument name
        self.static_inputs = static_inputs
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
class FfiCallable:
|
|
362
|
+
def __init__(self, func, num_outputs, graph_compatible, vmap_method, output_dims):
    """Parse the annotated signature of ``func`` and register it as a CUDA XLA FFI target.

    Args:
        func: Annotated Python function; inputs first, then ``num_outputs`` output arrays.
        num_outputs: Number of trailing arguments treated as outputs.
        graph_compatible: Whether the handler is advertised as CUDA-graph compatible.
        vmap_method: Default ``vmap()`` behavior for calls.
        output_dims: Default output dimensions for calls (or None).

    Raises:
        ValueError: if ``num_outputs`` is less than 1 or exceeds the number of arguments.
        RuntimeError: if any positional argument lacks a type annotation.
        TypeError: if an argument type is unsupported, an output is not an array,
            or the function is annotated to return a value.
    """
    self.func = func
    self.name = generate_unique_name(func)
    self.num_outputs = num_outputs
    self.vmap_method = vmap_method
    self.graph_compatible = graph_compatible
    self.output_dims = output_dims
    # index of the first input array argument, used to infer output dimensions
    self.first_array_arg = None
    # True when some input is a static scalar that must be handed to the callback
    self.has_static_args = False
    self.call_id = 0
    self.call_descriptors = {}

    # get arguments and annotations
    argspec = get_full_arg_spec(func)

    num_args = len(argspec.args)
    self.num_inputs = num_args - num_outputs
    if self.num_outputs < 1:
        raise ValueError("At least one output is required")
    if self.num_outputs > num_args:
        raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")

    annotations = argspec.annotations
    # Every positional argument needs its own annotation. Counting entries is not enough,
    # because the dict may also contain "return" and keyword-only annotations.
    if any(arg_name not in annotations for arg_name in argspec.args):
        raise RuntimeError(f"Incomplete argument annotations on function {self.name}")

    # parse type annotations in positional-argument order
    self.args = []
    for arg_idx, arg_name in enumerate(argspec.args):
        arg = FfiArg(arg_name, annotations[arg_name])
        if arg_idx < self.num_inputs:
            if arg.is_array:
                if self.first_array_arg is None:
                    self.first_array_arg = arg_idx
            else:
                self.has_static_args = True
        elif not arg.is_array:
            # outputs are XLA result buffers, so they must be arrays (same rule as FfiKernel)
            raise TypeError("All output arguments must be arrays")
        self.args.append(arg)

    if annotations.get("return") is not None:
        raise TypeError("Function must not return a value")

    self.input_args = self.args[: self.num_inputs]
    self.output_args = self.args[self.num_inputs :]

    # register the callback; the ctypes trampoline is kept on self so it is not garbage collected
    FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
    self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
    ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p)
    ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
    jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA")
|
|
413
|
+
|
|
414
|
+
def __call__(self, *args, output_dims=None, vmap_method=None):
    """Validate the inputs and invoke the wrapped function through ``jax.ffi.ffi_call``.

    Args:
        *args: The function's input arguments (JAX arrays for array arguments,
            static values for scalar arguments).
        output_dims: Optional per-call override of the output dimensions (int, shape,
            or dict keyed on output argument name).
        vmap_method: Optional per-call override of the ``vmap()`` behavior.

    Returns:
        The result of the ``jax.ffi.ffi_call`` invocation (the function outputs).

    Raises:
        ValueError: on a wrong number of inputs, a traced scalar, or undeterminable/missing
            output dimensions.
        TypeError: on array dtype/rank/inner-shape mismatches.
    """
    count = len(args)
    if count != self.num_inputs:
        raise ValueError(f"Expected {self.num_inputs} inputs, but got {count}")

    vmap_method = self.vmap_method if vmap_method is None else vmap_method
    output_dims = self.output_dims if output_dims is None else output_dims

    static_inputs = {}
    for idx, value in enumerate(args):
        spec = self.input_args[idx]
        if not spec.is_array:
            # scalars must be static (not traced); stash them for the callback
            if isinstance(value, jax.core.Tracer):
                raise ValueError(f"Argument '{spec.name}' must be a static value")
            static_inputs[spec.name] = spec.type(value)
            continue

        if value.dtype != spec.jax_scalar_type:
            raise TypeError(
                f"Invalid data type for array argument '{spec.name}', expected {spec.jax_scalar_type}, got {value.dtype}"
            )
        if value.ndim != spec.jax_ndim:
            raise TypeError(
                f"Invalid dimensionality for array argument '{spec.name}', expected {spec.jax_ndim} dimensions, got {value.ndim}"
            )
        # trailing JAX dimensions must match the vector/matrix element shape
        outer_ndim = spec.type.ndim
        for d in range(spec.dtype_ndim):
            if value.shape[outer_ndim + d] != spec.dtype_shape[d]:
                raise TypeError(
                    f"Invalid inner dimensions for array argument '{spec.name}', expected {spec.dtype_shape}, got {value.shape[-spec.dtype_ndim :]}"
                )

    if output_dims is None and self.first_array_arg is not None:
        # fall back to the shape of the first input array
        first = self.first_array_arg
        output_dims = get_warp_shape(self.input_args[first], args[first].shape)

    if isinstance(output_dims, dict):
        # per-output shapes keyed on argument name
        out_types = []
        for output_arg in self.output_args:
            dims = output_dims.get(output_arg.name)
            if dims is None:
                raise ValueError(f"Missing output dimensions for argument '{output_arg.name}'")
            out_types.append(get_jax_output_type(output_arg, dims))
    else:
        if output_dims is None:
            raise ValueError("Unable to determine output dimensions")
        if isinstance(output_dims, int):
            output_dims = (output_dims,)
        # one shape shared by every output
        out_types = [get_jax_output_type(output_arg, output_dims) for output_arg in self.output_args]

    # NOTE: has_side_effect=True could force execution even when outputs are unused
    call = jax.ffi.ffi_call(self.name, out_types, vmap_method=vmap_method)

    # load the module of the target function
    # NOTE: kernels defined in other modules are not loaded here
    device = wp.device_from_jax(get_jax_device())
    wp.get_module(self.func.__module__).load(device)

    if not self.has_static_args:
        return call(*args)

    # save call data to be retrieved by callback
    call_id = self.call_id
    self.call_descriptors[call_id] = FfiCallDesc(static_inputs)
    self.call_id += 1
    return call(*args, call_id=call_id)
|
|
497
|
+
|
|
498
|
+
def ffi_callback(self, call_frame):
    """XLA FFI handler that runs ``self.func`` on Warp views of the XLA buffers.

    When XLA passes the metadata extension (its initial query), this only reports
    the FFI API version (0.1) and, if ``self.graph_compatible`` is set, the
    command-buffer (CUDA graph) compatibility trait. Otherwise it wraps every
    array input and every output buffer in a ``wp.array`` by pointer, substitutes
    the stashed static value for each non-array input, and calls ``self.func``
    inside a ``wp.ScopedStream`` bound to the CUDA stream taken from the call frame.

    Args:
        call_frame: ctypes pointer to an ``XLA_FFI_CallFrame``.

    Returns:
        ``None`` on success, or the result of ``create_ffi_error`` if any exception
        was raised (the traceback is also printed).
    """
    try:
        extension = call_frame.contents.extension_start
        # On the first call, XLA runtime will query the API version and traits
        # metadata using the |extension| field. Respond to that query if the
        # metadata extension is present.
        if extension:
            # Try to set the version metadata.
            if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
                metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
                metadata_ext.contents.metadata.contents.api_version.major_version = 0
                metadata_ext.contents.metadata.contents.api_version.minor_version = 1
                # Turn on CUDA graphs for this handler.
                if self.graph_compatible:
                    metadata_ext.contents.metadata.contents.traits = (
                        XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
                    )
                return None

        if self.has_static_args:
            # Retrieve the static (non-array) input values stored in
            # self.call_descriptors under the call_id attribute of this call.
            attrs = decode_attrs(call_frame.contents.attrs)
            call_id = int(attrs["call_id"])
            call_desc = self.call_descriptors[call_id]

        num_inputs = call_frame.contents.args.size
        inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))

        num_outputs = call_frame.contents.rets.size
        outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))

        # Explicit checks rather than ``assert`` so they still run under ``python -O``.
        if num_inputs != self.num_inputs:
            raise RuntimeError(f"Expected {self.num_inputs} FFI inputs, got {num_inputs}")
        if num_outputs != self.num_outputs:
            raise RuntimeError(f"Expected {self.num_outputs} FFI outputs, got {num_outputs}")

        device = wp.device_from_jax(get_jax_device())
        cuda_stream = get_stream_from_callframe(call_frame.contents)
        stream = wp.Stream(device, cuda_stream=cuda_stream)

        def wrap_buffer(buffer, arg):
            # Drop the trailing dtype (vector/matrix) dimensions to obtain the Warp
            # array shape, then wrap the XLA buffer memory by pointer.
            shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
            return wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)

        # Reconstruct the argument list: all inputs first, then all outputs.
        arg_list = []

        for i in range(num_inputs):
            arg = self.input_args[i]
            if arg.is_array:
                arg_list.append(wrap_buffer(inputs[i].contents, arg))
            else:
                # scalar argument, get stashed value
                arg_list.append(call_desc.static_inputs[arg.name])

        for i in range(num_outputs):
            arg_list.append(wrap_buffer(outputs[i].contents, self.output_args[i]))

        # Call the Python function with the reconstructed arguments on XLA's stream.
        with wp.ScopedStream(stream, sync_enter=False):
            self.func(*arg_list)

    except Exception as e:
        print(traceback.format_exc())
        return create_ffi_error(
            call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
        )

    return None
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
###############################################################################
|
|
575
|
+
#
|
|
576
|
+
# Generic FFI callbacks for Python functions of the form
|
|
577
|
+
# func(inputs, outputs, attrs, ctx)
|
|
578
|
+
#
|
|
579
|
+
###############################################################################
|
|
580
|
+
|
|
581
|
+
# Registry of ctypes callback objects keyed by FFI target name. Holding a
# reference here keeps each callback alive after its raw address is handed to JAX.
ffi_callbacks = dict()
|
|
583
|
+
|
|
584
|
+
|
|
585
|
+
def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = True) -> None:
    """Create a JAX callback from a Python function.

    The Python function must have the form ``func(inputs, outputs, attrs, ctx)``, where
    ``inputs``/``outputs`` are lists of ``FfiBuffer`` wrappers, ``attrs`` are the
    decoded call attributes and ``ctx`` is an ``ExecutionContext`` for the call frame.

    NOTE: This is an experimental feature under development.

    Args:
        name: A unique FFI callback name.
        func: The Python function to call.
        graph_compatible: Optional. Whether the function can be called during CUDA graph capture.

    Raises:
        ValueError: If a callback with the same ``name`` has already been registered.
    """

    # Reject duplicates up front. Overwriting ``ffi_callbacks[name]`` would drop the
    # only Python reference to the previously registered ctypes callback, whose raw
    # address was already handed to JAX/XLA.
    if name in ffi_callbacks:
        raise ValueError(f"FFI callback '{name}' is already registered")

    def ffi_callback(call_frame):
        try:
            extension = call_frame.contents.extension_start
            # On the first call, XLA runtime will query the API version and traits
            # metadata using the |extension| field. Respond to that query if the
            # metadata extension is present.
            if extension:
                # Try to set the version metadata.
                if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
                    metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
                    metadata_ext.contents.metadata.contents.api_version.major_version = 0
                    metadata_ext.contents.metadata.contents.api_version.minor_version = 1
                    if graph_compatible:
                        # Turn on CUDA graphs for this handler.
                        metadata_ext.contents.metadata.contents.traits = (
                            XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
                        )
                    return None

            attrs = decode_attrs(call_frame.contents.attrs)

            input_count = call_frame.contents.args.size
            inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
            inputs = [FfiBuffer(inputs[i].contents) for i in range(input_count)]

            output_count = call_frame.contents.rets.size
            outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
            outputs = [FfiBuffer(outputs[i].contents) for i in range(output_count)]

            ctx = ExecutionContext(call_frame.contents)

            func(inputs, outputs, attrs, ctx)
        except Exception as e:
            print(traceback.format_exc())
            return create_ffi_error(
                call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
            )

        return None

    # C signature of the handler: takes an XLA_FFI_CallFrame* and returns a pointer
    # (None/NULL on success, the create_ffi_error result on failure).
    FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
    callback_func = FFI_CCALLFUNC(ffi_callback)
    # Keep the ctypes callback object alive for as long as JAX may call it.
    ffi_callbacks[name] = callback_func
    ffi_ccall_address = ctypes.cast(callback_func, ctypes.c_void_p)
    ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
    jax.ffi.register_ffi_target(name, ffi_capsule, platform="CUDA")
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
###############################################################################
|
|
650
|
+
#
|
|
651
|
+
# Utilities
|
|
652
|
+
#
|
|
653
|
+
###############################################################################
|
|
654
|
+
|
|
655
|
+
# Number of names already generated per fully qualified function name; used to
# give each generated FFI callback name a distinct numeric suffix.
ffi_name_counts = dict()
|
|
657
|
+
|
|
658
|
+
|
|
659
|
+
def generate_unique_name(func) -> str:
    """Return an FFI target name for *func* that this function has not returned before.

    The result is ``make_full_qualified_name(func)`` followed by ``_<n>``, where ``n``
    is the number of names previously generated for that same qualified name
    (starting at 0). The counter lives in the module-level ``ffi_name_counts``.
    """
    qualified_name = make_full_qualified_name(func)
    ordinal = ffi_name_counts.setdefault(qualified_name, 0)
    ffi_name_counts[qualified_name] = ordinal + 1
    return f"{qualified_name}_{ordinal}"
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
def get_warp_shape(arg, dims):
    """Map JAX buffer dimensions *dims* to the shape of the corresponding Warp array.

    For vector/matrix dtypes (``arg.dtype_ndim > 0``) the JAX shape carries the
    dtype's inner dimensions as trailing axes, so only the leading ``arg.warp_ndim``
    entries are kept. For scalar dtypes *dims* is returned unchanged.
    """
    is_scalar_dtype = not arg.dtype_ndim > 0
    return dims if is_scalar_dtype else dims[: arg.warp_ndim]
|
|
673
|
+
|
|
674
|
+
|
|
675
|
+
def get_jax_output_type(arg, dims):
    """Build the ``jax.ShapeDtypeStruct`` describing one FFI output buffer.

    Args:
        arg: Output argument descriptor; reads ``name``, ``dtype_ndim``,
            ``dtype_shape``, ``warp_ndim``, ``jax_ndim`` and ``jax_scalar_type``.
        dims: Output dimensions. Either a single integer (any integral type,
            including NumPy integer scalars) or a sequence of integers. For
            vector/matrix dtypes this may be the Warp shape (``arg.warp_ndim``
            entries; the dtype shape is appended) or the full JAX shape
            (``arg.jax_ndim`` entries; its trailing dimensions must equal the
            dtype shape).

    Returns:
        A ``jax.ShapeDtypeStruct`` with the JAX shape and ``arg.jax_scalar_type``.

    Raises:
        ValueError: If the number of dimensions does not fit the argument, or the
            trailing dimensions do not match the vector/matrix dtype shape.
    """
    import numbers

    # Accept any integral scalar (e.g. np.int64), not only Python int; otherwise
    # len(dims) below raises TypeError for NumPy integers.
    if isinstance(dims, numbers.Integral):
        dims = (int(dims),)

    ndim = len(dims)

    if arg.dtype_ndim > 0:
        # vector/matrix array
        if ndim == arg.warp_ndim:
            # Warp shape given: append the dtype's inner dimensions.
            return jax.ShapeDtypeStruct((*dims, *arg.dtype_shape), arg.jax_scalar_type)
        if ndim == arg.jax_ndim:
            # Full JAX shape given: make sure inner dimensions match.
            inner_dims = dims[-arg.dtype_ndim :]
            if any(inner_dims[i] != arg.dtype_shape[i] for i in range(arg.dtype_ndim)):
                raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
            return jax.ShapeDtypeStruct(dims, arg.jax_scalar_type)
        raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")

    # scalar array
    if ndim != arg.warp_ndim:
        raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
    return jax.ShapeDtypeStruct(dims, arg.jax_scalar_type)
|