warp-lang 1.7.2rc1-py3-none-macosx_10_13_universal2.whl → 1.8.1-py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +130 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +272 -104
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +770 -238
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +99 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +181 -95
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +210 -67
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +103 -4
- warp/native/builtin.h +182 -35
- warp/native/coloring.cpp +6 -2
- warp/native/cuda_util.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +5 -5
- warp/native/mat.h +8 -13
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/mesh.h +1 -1
- warp/native/quat.h +34 -6
- warp/native/rand.h +7 -7
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/svd.h +23 -8
- warp/native/tile.h +603 -73
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +239 -13
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +10 -20
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +588 -52
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +110 -80
- warp/render/render_usd.py +124 -62
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +253 -80
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +761 -322
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +54 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +91 -2
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +23 -24
- warp/tests/test_quat.py +28 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +83 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_static.py +48 -0
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tape.py +38 -0
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/test_vec.py +38 -408
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +438 -131
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_reduce.py +307 -5
- warp/tests/tile/test_tile_shared_memory.py +136 -7
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/unittest_suites.py +14 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +189 -175
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/jax_experimental/ffi.py
CHANGED
@@ -16,7 +16,8 @@
 import ctypes
 import threading
 import traceback
-from
+from enum import IntEnum
+from typing import Callable, Optional
 
 import jax
 
@@ -28,10 +29,17 @@ from warp.types import array_t, launch_bounds_t, strides_from_shape, type_to_war
 from .xla_ffi import *
 
 
+class GraphMode(IntEnum):
+    NONE = 0  # don't capture a graph
+    JAX = 1  # let JAX capture a graph
+    WARP = 2  # let Warp capture a graph
+
+
 class FfiArg:
-    def __init__(self, name, type):
+    def __init__(self, name, type, in_out=False):
         self.name = name
         self.type = type
+        self.in_out = in_out
         self.is_array = isinstance(type, wp.array)
 
         if self.is_array:
@@ -65,7 +73,7 @@ class FfiLaunchDesc:
 
 
 class FfiKernel:
-    def __init__(self, kernel, num_outputs, vmap_method, launch_dims, output_dims):
+    def __init__(self, kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames):
         self.kernel = kernel
         self.name = generate_unique_name(kernel.func)
         self.num_outputs = num_outputs
@@ -76,17 +84,28 @@ class FfiKernel:
         self.launch_id = 0
         self.launch_descriptors = {}
 
+        in_out_argnames_list = in_out_argnames or []
+        in_out_argnames = set(in_out_argnames_list)
+        if len(in_out_argnames_list) != len(in_out_argnames):
+            raise AssertionError("in_out_argnames must not contain duplicate names")
+
         self.num_kernel_args = len(kernel.adj.args)
-        self.
+        self.num_in_out = len(in_out_argnames)
+        self.num_inputs = self.num_kernel_args - num_outputs + self.num_in_out
         if self.num_outputs < 1:
             raise ValueError("At least one output is required")
         if self.num_outputs > self.num_kernel_args:
             raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
+        if self.num_outputs < self.num_in_out:
+            raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")
 
         # process input args
         self.input_args = []
         for i in range(self.num_inputs):
-
+            arg_name = kernel.adj.args[i].label
+            arg = FfiArg(arg_name, kernel.adj.args[i].type, arg_name in in_out_argnames)
+            if arg_name in in_out_argnames:
+                in_out_argnames.remove(arg_name)
             if arg.is_array:
                 # keep track of the first input array argument
                 if self.first_array_arg is None:
@@ -96,11 +115,30 @@ class FfiKernel:
         # process output args
         self.output_args = []
         for i in range(self.num_inputs, self.num_kernel_args):
-
+            arg_name = kernel.adj.args[i].label
+            if arg_name in in_out_argnames:
+                raise AssertionError(
+                    f"Expected an output-only argument for argument {arg_name}."
+                    " in_out arguments should be placed before output-only arguments."
+                )
+            arg = FfiArg(arg_name, kernel.adj.args[i].type, False)
             if not arg.is_array:
                 raise TypeError("All output arguments must be arrays")
             self.output_args.append(arg)
 
+        if in_out_argnames:
+            raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
+
+        # Build input output aliases.
+        out_id = 0
+        input_output_aliases = {}
+        for in_id, arg in enumerate(self.input_args):
+            if not arg.in_out:
+                continue
+            input_output_aliases[in_id] = out_id
+            out_id += 1
+        self.input_output_aliases = input_output_aliases
+
         # register the callback
         FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
         self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
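
The aliasing logic above assumes a fixed argument ordering: pure inputs first, then in-out arguments, then pure outputs, with each in-out input position aliased to the corresponding leading output slot. A minimal sketch of that convention for a hypothetical kernel (the kernel and argument names below are illustrative, not from the package):

import warp as wp

@wp.kernel
def accumulate(
    x: wp.array(dtype=float),      # pure input
    state: wp.array(dtype=float),  # in-out: read and updated in place
    y: wp.array(dtype=float),      # pure output
):
    i = wp.tid()
    state[i] = state[i] + x[i]
    y[i] = 2.0 * state[i]

# With num_outputs=2 and in_out_argnames=["state"]:
#   num_inputs = 3 - 2 + 1 = 2  -> the JAX-visible inputs are (x, state)
#   input_output_aliases = {1: 0} -> input "state" shares its buffer with output slot 0
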
@@ -121,6 +159,9 @@ class FfiKernel:
         if vmap_method is None:
             vmap_method = self.vmap_method
 
+        # output types
+        out_types = []
+
         # process inputs
         static_inputs = {}
         for i in range(num_inputs):
@@ -150,6 +191,10 @@
                 # stash the value to be retrieved by callback
                 static_inputs[input_arg.name] = input_arg.type(input_value)
 
+            # append in-out arg to output types
+            if input_arg.in_out:
+                out_types.append(get_jax_output_type(input_arg, input_value.shape))
+
         # launch dimensions
         if launch_dims is None:
             # use the shape of the first input array
@@ -162,8 +207,7 @@
         else:
             launch_dims = tuple(launch_dims)
 
-        # output
-        out_types = []
+        # output shapes
         if isinstance(output_dims, dict):
             # assume a dictionary of shapes keyed on argument name
             for output_arg in self.output_args:
@@ -185,6 +229,7 @@
             self.name,
             out_types,
             vmap_method=vmap_method,
+            input_output_aliases=self.input_output_aliases,
         )
 
         # ensure the kernel module is loaded before the callback, otherwise graph capture may fail
@@ -238,9 +283,8 @@
 
         arg_refs = []
 
-        #
-        for i in
-            input_arg = self.input_args[i]
+        # input and in-out args
+        for i, input_arg in enumerate(self.input_args):
             if input_arg.is_array:
                 buffer = inputs[i].contents
                 shape = buffer.dims[: input_arg.type.ndim]
@@ -255,10 +299,9 @@
                 kernel_params[i + 1] = ctypes.addressof(arg)
                 arg_refs.append(arg)  # keep a reference
 
-        #
-        for i in
-
-            buffer = outputs[i].contents
+        # pure output args (skip in-out FFI buffers)
+        for i, output_arg in enumerate(self.output_args):
+            buffer = outputs[i + self.num_in_out].contents
             shape = buffer.dims[: output_arg.type.ndim]
             strides = strides_from_shape(shape, output_arg.type.dtype)
             arg = array_t(buffer.data, 0, output_arg.type.ndim, shape, strides)
@@ -295,30 +338,38 @@
 class FfiCallDesc:
     def __init__(self, static_inputs):
         self.static_inputs = static_inputs
+        self.captures = {}
 
 
 class FfiCallable:
-    def __init__(self, func, num_outputs,
+    def __init__(self, func, num_outputs, graph_mode, vmap_method, output_dims, in_out_argnames):
         self.func = func
         self.name = generate_unique_name(func)
         self.num_outputs = num_outputs
         self.vmap_method = vmap_method
-        self.
+        self.graph_mode = graph_mode
         self.output_dims = output_dims
         self.first_array_arg = None
-        self.has_static_args = False
         self.call_id = 0
         self.call_descriptors = {}
 
+        in_out_argnames_list = in_out_argnames or []
+        in_out_argnames = set(in_out_argnames_list)
+        if len(in_out_argnames_list) != len(in_out_argnames):
+            raise AssertionError("in_out_argnames must not contain duplicate names")
+
         # get arguments and annotations
         argspec = get_full_arg_spec(func)
 
         num_args = len(argspec.args)
-        self.
+        self.num_in_out = len(in_out_argnames)
+        self.num_inputs = num_args - num_outputs + self.num_in_out
         if self.num_outputs < 1:
             raise ValueError("At least one output is required")
         if self.num_outputs > num_args:
             raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
+        if self.num_outputs < self.num_in_out:
+            raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")
 
         if len(argspec.annotations) < num_args:
             raise RuntimeError(f"Incomplete argument annotations on function {self.name}")
@@ -331,17 +382,43 @@
                 if arg_type is not None:
                     raise TypeError("Function must not return a value")
             else:
-                arg = FfiArg(arg_name, arg_type)
+                arg = FfiArg(arg_name, arg_type, arg_name in in_out_argnames)
+                if arg_name in in_out_argnames:
+                    in_out_argnames.remove(arg_name)
                 if arg.is_array:
                     if arg_idx < self.num_inputs and self.first_array_arg is None:
                         self.first_array_arg = arg_idx
-                else:
-                    self.has_static_args = True
                 self.args.append(arg)
+
+                if arg.in_out and arg_idx >= self.num_inputs:
+                    raise AssertionError(
+                        f"Expected an output-only argument for argument {arg_name}."
+                        " in_out arguments should be placed before output-only arguments."
+                    )
+
                 arg_idx += 1
 
-
-
+        if in_out_argnames:
+            raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
+
+        self.input_args = self.args[: self.num_inputs]  # includes in-out args
+        self.output_args = self.args[self.num_inputs :]  # pure output args
+
+        # Buffer indices for array arguments in callback.
+        # In-out buffers are the same pointers in the XLA call frame,
+        # so we only include them for inputs and skip them for outputs.
+        self.array_input_indices = [i for i, arg in enumerate(self.input_args) if arg.is_array]
+        self.array_output_indices = list(range(self.num_in_out, self.num_outputs))
+
+        # Build input output aliases.
+        out_id = 0
+        input_output_aliases = {}
+        for in_id, arg in enumerate(self.input_args):
+            if not arg.in_out:
+                continue
+            input_output_aliases[in_id] = out_id
+            out_id += 1
+        self.input_output_aliases = input_output_aliases
 
         # register the callback
         FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
@@ -353,7 +430,9 @@
     def __call__(self, *args, output_dims=None, vmap_method=None):
         num_inputs = len(args)
         if num_inputs != self.num_inputs:
-
+            input_names = ", ".join(arg.name for arg in self.input_args)
+            s = "" if self.num_inputs == 1 else "s"
+            raise ValueError(f"Expected {self.num_inputs} input{s} ({input_names}), but got {num_inputs}")
 
         # default argument fallback
         if vmap_method is None:
@@ -361,6 +440,9 @@
         if output_dims is None:
             output_dims = self.output_dims
 
+        # output types
+        out_types = []
+
         # process inputs
         static_inputs = {}
         for i in range(num_inputs):
@@ -390,12 +472,11 @@
                 # stash the value to be retrieved by callback
                 static_inputs[input_arg.name] = input_arg.type(input_value)
 
-
-
-
+            # append in-out arg to output types
+            if input_arg.in_out:
+                out_types.append(get_jax_output_type(input_arg, input_value.shape))
 
-        # output
-        out_types = []
+        # output shapes
         if isinstance(output_dims, dict):
             # assume a dictionary of shapes keyed on argument name
             for output_arg in self.output_args:
@@ -405,7 +486,9 @@
                 out_types.append(get_jax_output_type(output_arg, dims))
         else:
             if output_dims is None:
-
+                if self.first_array_arg is None:
+                    raise ValueError("Unable to determine output dimensions")
+                output_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
             elif isinstance(output_dims, int):
                 output_dims = (output_dims,)
             # assume same dimensions for all outputs
@@ -416,6 +499,7 @@
             self.name,
             out_types,
             vmap_method=vmap_method,
+            input_output_aliases=self.input_output_aliases,
             # has_side_effect=True, # force this function to execute even if outputs aren't used
         )
 
@@ -425,22 +509,18 @@
         module = wp.get_module(self.func.__module__)
         module.load(device)
 
-
-
-
-
-
-            return call(*args, call_id=call_id)
-        else:
-            return call(*args)
+        # save call data to be retrieved by callback
+        call_id = self.call_id
+        self.call_descriptors[call_id] = FfiCallDesc(static_inputs)
+        self.call_id += 1
+        return call(*args, call_id=call_id)
 
     def ffi_callback(self, call_frame):
         try:
-            # TODO Try-catch around the body and return XLA_FFI_Error on error.
-            extension = call_frame.contents.extension_start
             # On the first call, XLA runtime will query the API version and traits
             # metadata using the |extension| field. Let us respond to that query
             # if the metadata extension is present.
+            extension = call_frame.contents.extension_start
             if extension:
                 # Try to set the version metadata.
                 if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
@@ -448,17 +528,20 @@
                     metadata_ext.contents.metadata.contents.api_version.major_version = 0
                     metadata_ext.contents.metadata.contents.api_version.minor_version = 1
                     # Turn on CUDA graphs for this handler.
-                    if self.
+                    if self.graph_mode is GraphMode.JAX:
                         metadata_ext.contents.metadata.contents.traits = (
                             XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
                         )
                 return None
 
-
-
-
-
-
+            # retrieve call info
+            # NOTE: this assumes that there's only one attribute - call_id (int64).
+            # A more general but slower approach is this:
+            # attrs = decode_attrs(call_frame.contents.attrs)
+            # call_id = int(attrs["call_id"])
+            attr = ctypes.cast(call_frame.contents.attrs.attrs[0], ctypes.POINTER(XLA_FFI_Scalar)).contents
+            call_id = ctypes.cast(attr.value, ctypes.POINTER(ctypes.c_int64)).contents.value
+            call_desc = self.call_descriptors[call_id]
 
             num_inputs = call_frame.contents.args.size
             inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
@@ -469,16 +552,42 @@
             assert num_inputs == self.num_inputs
             assert num_outputs == self.num_outputs
 
-            device = wp.device_from_jax(get_jax_device())
             cuda_stream = get_stream_from_callframe(call_frame.contents)
+
+            if self.graph_mode == GraphMode.WARP:
+                # check if we already captured an identical call
+                ip = [inputs[i].contents.data for i in self.array_input_indices]
+                op = [outputs[i].contents.data for i in self.array_output_indices]
+                buffer_hash = hash((*ip, *op))
+                capture = call_desc.captures.get(buffer_hash)
+
+                # launch existing graph
+                if capture is not None:
+                    # NOTE: We use the native graph API to avoid overhead with obtaining Stream and Device objects in Python.
+                    # This code should match wp.capture_launch().
+                    graph = capture.graph
+                    if graph.graph_exec is None:
+                        g = ctypes.c_void_p()
+                        if not wp.context.runtime.core.wp_cuda_graph_create_exec(
+                            graph.device.context, cuda_stream, graph.graph, ctypes.byref(g)
+                        ):
+                            raise RuntimeError(f"Graph creation error: {wp.context.runtime.get_error_string()}")
+                        graph.graph_exec = g
+
+                    if not wp.context.runtime.core.wp_cuda_graph_launch(graph.graph_exec, cuda_stream):
+                        raise RuntimeError(f"Graph launch error: {wp.context.runtime.get_error_string()}")
+
+                    # early out
+                    return
+
+            device = wp.device_from_jax(get_jax_device())
             stream = wp.Stream(device, cuda_stream=cuda_stream)
 
             # reconstruct the argument list
             arg_list = []
 
-            #
-            for i in
-                arg = self.input_args[i]
+            # input and in-out args
+            for i, arg in enumerate(self.input_args):
                 if arg.is_array:
                     buffer = inputs[i].contents
                     shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
@@ -489,10 +598,9 @@
                     value = call_desc.static_inputs[arg.name]
                     arg_list.append(value)
 
-            #
-            for i in
-
-                buffer = outputs[i].contents
+            # pure output args (skip in-out FFI buffers)
+            for i, arg in enumerate(self.output_args):
+                buffer = outputs[i + self.num_in_out].contents
                 shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
                 arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
                 arg_list.append(arr)
@@ -500,9 +608,20 @@
             # call the Python function with reconstructed arguments
             with wp.ScopedStream(stream, sync_enter=False):
                 if stream.is_capturing:
-                    with
+                    # capturing with JAX
+                    with wp.ScopedCapture(external=True) as capture:
+                        self.func(*arg_list)
+                    # keep a reference to the capture object to prevent required modules getting unloaded
+                    call_desc.capture = capture
+                elif self.graph_mode == GraphMode.WARP:
+                    # capturing with WARP
+                    with wp.ScopedCapture() as capture:
                         self.func(*arg_list)
+                    wp.capture_launch(capture.graph)
+                    # keep a reference to the capture object and reuse it with same buffers
+                    call_desc.captures[buffer_hash] = capture
                 else:
+                    # not capturing
                     self.func(*arg_list)
 
         except Exception as e:
@@ -520,7 +639,9 @@ _FFI_KERNEL_REGISTRY: dict[str, FfiKernel] = {}
 _FFI_REGISTRY_LOCK = threading.Lock()
 
 
-def jax_kernel(
+def jax_kernel(
+    kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=None, output_dims=None, in_out_argnames=None
+):
     """Create a JAX callback from a Warp kernel.
 
     NOTE: This is an experimental feature under development.
@@ -528,6 +649,7 @@ def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=N
     Args:
         kernel: The Warp kernel to launch.
        num_outputs: Optional. Specify the number of output arguments if greater than 1.
+            This must include the number of ``in_out_arguments``.
        vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
            This argument can also be specified for individual calls.
        launch_dims: Optional. Specify the default kernel launch dimensions. If None, launch
@@ -536,12 +658,13 @@ def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=N
        output_dims: Optional. Specify the default dimensions of output arrays. If None, output
            dimensions are inferred from the launch dimensions.
            This argument can also be specified for individual calls.
+        in_out_argnames: Optional. Names of input-output arguments.
 
     Limitations:
        - All kernel arguments must be contiguous arrays or scalars.
        - Scalars must be static arguments in JAX.
-        - Input arguments
-        - There must be at least one output argument.
+        - Input and input-output arguments must precede the output arguments in the ``kernel`` definition.
+        - There must be at least one output or input-output argument.
        - Only the CUDA backend is supported.
     """
     key = (
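
Based on the updated ``jax_kernel`` signature and docstring above, a call with an in-out argument might look like the following sketch (the ``accumulate`` kernel and the array names are illustrative, not from the package):

import jax.numpy as jnp
from warp.jax_experimental.ffi import jax_kernel

# num_outputs counts the in-out argument as well: the outputs are (state, y)
jax_accumulate = jax_kernel(accumulate, num_outputs=2, in_out_argnames=["state"])

x = jnp.ones(1024, dtype=jnp.float32)
state = jnp.zeros(1024, dtype=jnp.float32)

# the in-out argument is passed as an input and returned as the first output
state, y = jax_accumulate(x, state)
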
@@ -554,7 +677,7 @@ def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=N
 
     with _FFI_REGISTRY_LOCK:
         if key not in _FFI_KERNEL_REGISTRY:
-            new_kernel = FfiKernel(kernel, num_outputs, vmap_method, launch_dims, output_dims)
+            new_kernel = FfiKernel(kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames)
             _FFI_KERNEL_REGISTRY[key] = new_kernel
 
     return _FFI_KERNEL_REGISTRY[key]
@@ -563,9 +686,11 @@ def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=N
 def jax_callable(
     func: Callable,
     num_outputs: int = 1,
-    graph_compatible: bool =
-
+    graph_compatible: Optional[bool] = None,  # deprecated
+    graph_mode: GraphMode = GraphMode.JAX,
+    vmap_method: Optional[str] = "broadcast_all",
     output_dims=None,
+    in_out_argnames=None,
 ):
     """Create a JAX callback from an annotated Python function.
 
@@ -576,31 +701,50 @@
     Args:
         func: The Python function to call.
         num_outputs: Optional. Specify the number of output arguments if greater than 1.
+            This must include the number of ``in_out_arguments``.
         graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
+            This argument is deprecated, use ``graph_mode`` instead.
+        graph_mode: Optional. CUDA graph capture mode.
+            ``GraphMode.JAX`` (default): Let JAX capture the graph, which may be used as a subgraph in an enclosing capture.
+            ``GraphMode.WARP``: Let Warp capture the graph. Use this mode when the callable cannot be used as a subraph,
+            such as when the callable uses conditional graph nodes.
+            ``GraphMode.NONE``: Disable graph capture. Use when the callable performs operations that are not legal in a graph,
+            such as host synchronization.
         vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
             This argument can also be specified for individual calls.
         output_dims: Optional. Specify the default dimensions of output arrays.
             If ``None``, output dimensions are inferred from the launch dimensions.
            This argument can also be specified for individual calls.
+        in_out_argnames: Optional. Names of input-output arguments.
 
     Limitations:
        - All kernel arguments must be contiguous arrays or scalars.
        - Scalars must be static arguments in JAX.
-        - Input arguments
-        - There must be at least one output argument.
+        - Input and input-output arguments must precede the output arguments in the ``func`` definition.
+        - There must be at least one output or input-output argument.
        - Only the CUDA backend is supported.
     """
+
+    if graph_compatible is not None:
+        wp.utils.warn(
+            "The `graph_compatible` argument is deprecated, use `graph_mode` instead.",
+            DeprecationWarning,
+            stacklevel=3,
+        )
+        if graph_compatible is False:
+            graph_mode = GraphMode.NONE
+
     key = (
         func,
         num_outputs,
-
+        graph_mode,
         vmap_method,
         tuple(sorted(output_dims.items())) if output_dims else output_dims,
     )
 
     with _FFI_REGISTRY_LOCK:
         if key not in _FFI_CALLABLE_REGISTRY:
-            new_callable = FfiCallable(func, num_outputs,
+            new_callable = FfiCallable(func, num_outputs, graph_mode, vmap_method, output_dims, in_out_argnames)
             _FFI_CALLABLE_REGISTRY[key] = new_callable
 
     return _FFI_CALLABLE_REGISTRY[key]
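
Putting the new options together, a ``jax_callable`` wrapper that lets Warp own the CUDA graph capture might look like this sketch (the ``step`` function reuses the illustrative ``accumulate`` kernel from the earlier sketch; all names are examples, not from the package):

import warp as wp
import jax.numpy as jnp
from warp.jax_experimental.ffi import GraphMode, jax_callable

def step(x: wp.array(dtype=float), state: wp.array(dtype=float), y: wp.array(dtype=float)):
    # any sequence of Warp launches; with GraphMode.WARP the body is captured
    # into a CUDA graph by Warp and replayed on later calls that see the same
    # input/output buffers (per the buffer-hash caching logic in the diff above)
    wp.launch(accumulate, dim=state.shape, inputs=[x, state, y])

jax_step = jax_callable(step, num_outputs=2, graph_mode=GraphMode.WARP, in_out_argnames=["state"])

x = jnp.ones(1024, dtype=jnp.float32)
state = jnp.zeros(1024, dtype=jnp.float32)
state, y = jax_step(x, state)
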
@@ -631,7 +775,6 @@ def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = Tr
 
     def ffi_callback(call_frame):
         try:
-            # TODO Try-catch around the body and return XLA_FFI_Error on error.
             extension = call_frame.contents.extension_start
             # On the first call, XLA runtime will query the API version and traits
             # metadata using the |extension| field. Let us respond to that query