warp-lang 1.8.0-py3-none-manylinux_2_34_aarch64.whl → 1.9.0-py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic.
- warp/__init__.py +282 -103
- warp/__init__.pyi +482 -110
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +93 -30
- warp/build_dll.py +48 -63
- warp/builtins.py +955 -137
- warp/codegen.py +327 -209
- warp/config.py +1 -1
- warp/context.py +1363 -800
- warp/examples/core/example_marching_cubes.py +1 -0
- warp/examples/core/example_render_opengl.py +100 -3
- warp/examples/fem/example_apic_fluid.py +98 -52
- warp/examples/fem/example_convection_diffusion_dg.py +25 -4
- warp/examples/fem/example_diffusion_mgpu.py +8 -3
- warp/examples/fem/utils.py +68 -22
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/fabric.py +1 -1
- warp/fem/cache.py +27 -19
- warp/fem/domain.py +2 -2
- warp/fem/field/nodal_field.py +2 -2
- warp/fem/field/virtual.py +266 -166
- warp/fem/geometry/geometry.py +5 -5
- warp/fem/integrate.py +200 -91
- warp/fem/space/restriction.py +4 -0
- warp/fem/space/shape/tet_shape_function.py +3 -10
- warp/jax_experimental/custom_call.py +1 -1
- warp/jax_experimental/ffi.py +203 -54
- warp/marching_cubes.py +708 -0
- warp/native/array.h +103 -8
- warp/native/builtin.h +90 -9
- warp/native/bvh.cpp +64 -28
- warp/native/bvh.cu +58 -58
- warp/native/bvh.h +2 -2
- warp/native/clang/clang.cpp +7 -7
- warp/native/coloring.cpp +13 -3
- warp/native/crt.cpp +2 -2
- warp/native/crt.h +3 -5
- warp/native/cuda_util.cpp +42 -11
- warp/native/cuda_util.h +10 -4
- warp/native/exports.h +1842 -1908
- warp/native/fabric.h +2 -1
- warp/native/hashgrid.cpp +37 -37
- warp/native/hashgrid.cu +2 -2
- warp/native/initializer_array.h +1 -1
- warp/native/intersect.h +4 -4
- warp/native/mat.h +1913 -119
- warp/native/mathdx.cpp +43 -43
- warp/native/mesh.cpp +24 -24
- warp/native/mesh.cu +26 -26
- warp/native/mesh.h +5 -3
- warp/native/nanovdb/GridHandle.h +179 -12
- warp/native/nanovdb/HostBuffer.h +8 -7
- warp/native/nanovdb/NanoVDB.h +517 -895
- warp/native/nanovdb/NodeManager.h +323 -0
- warp/native/nanovdb/PNanoVDB.h +2 -2
- warp/native/quat.h +337 -16
- warp/native/rand.h +7 -7
- warp/native/range.h +7 -1
- warp/native/reduce.cpp +10 -10
- warp/native/reduce.cu +13 -14
- warp/native/runlength_encode.cpp +2 -2
- warp/native/runlength_encode.cu +5 -5
- warp/native/scan.cpp +3 -3
- warp/native/scan.cu +4 -4
- warp/native/sort.cpp +10 -10
- warp/native/sort.cu +22 -22
- warp/native/sparse.cpp +8 -8
- warp/native/sparse.cu +14 -14
- warp/native/spatial.h +366 -17
- warp/native/svd.h +23 -8
- warp/native/temp_buffer.h +2 -2
- warp/native/tile.h +303 -70
- warp/native/tile_radix_sort.h +5 -1
- warp/native/tile_reduce.h +16 -25
- warp/native/tuple.h +2 -2
- warp/native/vec.h +385 -18
- warp/native/volume.cpp +54 -54
- warp/native/volume.cu +1 -1
- warp/native/volume.h +2 -1
- warp/native/volume_builder.cu +30 -37
- warp/native/warp.cpp +150 -149
- warp/native/warp.cu +337 -193
- warp/native/warp.h +227 -226
- warp/optim/linear.py +736 -271
- warp/render/imgui_manager.py +289 -0
- warp/render/render_opengl.py +137 -57
- warp/render/render_usd.py +0 -1
- warp/sim/collide.py +1 -2
- warp/sim/graph_coloring.py +2 -2
- warp/sim/integrator_vbd.py +10 -2
- warp/sparse.py +559 -176
- warp/tape.py +2 -0
- warp/tests/aux_test_module_aot.py +7 -0
- warp/tests/cuda/test_async.py +3 -3
- warp/tests/cuda/test_conditional_captures.py +101 -0
- warp/tests/geometry/test_marching_cubes.py +233 -12
- warp/tests/sim/test_cloth.py +89 -6
- warp/tests/sim/test_coloring.py +82 -7
- warp/tests/test_array.py +56 -5
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +127 -114
- warp/tests/test_codegen.py +3 -2
- warp/tests/test_context.py +8 -15
- warp/tests/test_enum.py +136 -0
- warp/tests/test_examples.py +2 -2
- warp/tests/test_fem.py +45 -2
- warp/tests/test_fixedarray.py +229 -0
- warp/tests/test_func.py +18 -15
- warp/tests/test_future_annotations.py +7 -5
- warp/tests/test_linear_solvers.py +30 -0
- warp/tests/test_map.py +1 -1
- warp/tests/test_mat.py +1540 -378
- warp/tests/test_mat_assign_copy.py +178 -0
- warp/tests/test_mat_constructors.py +574 -0
- warp/tests/test_module_aot.py +287 -0
- warp/tests/test_print.py +69 -0
- warp/tests/test_quat.py +162 -34
- warp/tests/test_quat_assign_copy.py +145 -0
- warp/tests/test_reload.py +2 -1
- warp/tests/test_sparse.py +103 -0
- warp/tests/test_spatial.py +140 -34
- warp/tests/test_spatial_assign_copy.py +160 -0
- warp/tests/test_static.py +48 -0
- warp/tests/test_struct.py +43 -3
- warp/tests/test_tape.py +38 -0
- warp/tests/test_types.py +0 -20
- warp/tests/test_vec.py +216 -441
- warp/tests/test_vec_assign_copy.py +143 -0
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +206 -152
- warp/tests/tile/test_tile_cholesky.py +605 -0
- warp/tests/tile/test_tile_load.py +169 -0
- warp/tests/tile/test_tile_mathdx.py +2 -558
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_mlp.py +1 -1
- warp/tests/tile/test_tile_reduce.py +100 -11
- warp/tests/tile/test_tile_shared_memory.py +16 -16
- warp/tests/tile/test_tile_sort.py +59 -55
- warp/tests/unittest_suites.py +16 -0
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +108 -9
- warp/types.py +554 -264
- warp/utils.py +68 -86
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/RECORD +150 -138
- warp/native/marching.cpp +0 -19
- warp/native/marching.cu +0 -514
- warp/native/marching.h +0 -19
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/jax_experimental/ffi.py
CHANGED
(Diff lines ending in "…" are truncated in the source view.)

@@ -16,7 +16,8 @@
 import ctypes
 import threading
 import traceback
-from …
+from enum import IntEnum
+from typing import Callable, Optional

 import jax

@@ -28,10 +29,17 @@ from warp.types import array_t, launch_bounds_t, strides_from_shape, type_to_war…
 from .xla_ffi import *


+class GraphMode(IntEnum):
+    NONE = 0  # don't capture a graph
+    JAX = 1  # let JAX capture a graph
+    WARP = 2  # let Warp capture a graph
+
+
 class FfiArg:
-    def __init__(self, name, type):
+    def __init__(self, name, type, in_out=False):
         self.name = name
         self.type = type
+        self.in_out = in_out
         self.is_array = isinstance(type, wp.array)

         if self.is_array:
@@ -65,7 +73,7 @@ class FfiLaunchDesc:


 class FfiKernel:
-    def __init__(self, kernel, num_outputs, vmap_method, launch_dims, output_dims):
+    def __init__(self, kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames):
         self.kernel = kernel
         self.name = generate_unique_name(kernel.func)
         self.num_outputs = num_outputs
@@ -76,17 +84,28 @@ class FfiKernel:
         self.launch_id = 0
         self.launch_descriptors = {}

+        in_out_argnames_list = in_out_argnames or []
+        in_out_argnames = set(in_out_argnames_list)
+        if len(in_out_argnames_list) != len(in_out_argnames):
+            raise AssertionError("in_out_argnames must not contain duplicate names")
+
         self.num_kernel_args = len(kernel.adj.args)
-        self.…
+        self.num_in_out = len(in_out_argnames)
+        self.num_inputs = self.num_kernel_args - num_outputs + self.num_in_out
         if self.num_outputs < 1:
             raise ValueError("At least one output is required")
         if self.num_outputs > self.num_kernel_args:
             raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
+        if self.num_outputs < self.num_in_out:
+            raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")

         # process input args
         self.input_args = []
         for i in range(self.num_inputs):
-            …
+            arg_name = kernel.adj.args[i].label
+            arg = FfiArg(arg_name, kernel.adj.args[i].type, arg_name in in_out_argnames)
+            if arg_name in in_out_argnames:
+                in_out_argnames.remove(arg_name)
             if arg.is_array:
                 # keep track of the first input array argument
                 if self.first_array_arg is None:
@@ -96,11 +115,30 @@ class FfiKernel:
         # process output args
         self.output_args = []
         for i in range(self.num_inputs, self.num_kernel_args):
-            …
+            arg_name = kernel.adj.args[i].label
+            if arg_name in in_out_argnames:
+                raise AssertionError(
+                    f"Expected an output-only argument for argument {arg_name}."
+                    " in_out arguments should be placed before output-only arguments."
+                )
+            arg = FfiArg(arg_name, kernel.adj.args[i].type, False)
             if not arg.is_array:
                 raise TypeError("All output arguments must be arrays")
             self.output_args.append(arg)

+        if in_out_argnames:
+            raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
+
+        # Build input output aliases.
+        out_id = 0
+        input_output_aliases = {}
+        for in_id, arg in enumerate(self.input_args):
+            if not arg.in_out:
+                continue
+            input_output_aliases[in_id] = out_id
+            out_id += 1
+        self.input_output_aliases = input_output_aliases
+
         # register the callback
         FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
         self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
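The input_output_aliases map built above tells JAX which input buffer each in-out argument shares with an output, so XLA can reuse the donated input buffer instead of allocating a new one. A standalone sketch of how the alias map follows from argument order (build_input_output_aliases is a hypothetical helper for illustration, not part of the package):

# Hypothetical illustration: in_out_flags[i] is True when input i is also written in place.
def build_input_output_aliases(in_out_flags):
    aliases = {}
    out_id = 0
    for in_id, is_in_out in enumerate(in_out_flags):
        if is_in_out:
            aliases[in_id] = out_id  # input in_id aliases output out_id
            out_id += 1
    return aliases

# Two in-out inputs among three inputs map to outputs 0 and 1.
assert build_input_output_aliases([True, False, True]) == {0: 0, 2: 1}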
@@ -121,6 +159,9 @@ class FfiKernel:
         if vmap_method is None:
             vmap_method = self.vmap_method

+        # output types
+        out_types = []
+
         # process inputs
         static_inputs = {}
         for i in range(num_inputs):
@@ -150,6 +191,10 @@ class FfiKernel:
                 # stash the value to be retrieved by callback
                 static_inputs[input_arg.name] = input_arg.type(input_value)

+            # append in-out arg to output types
+            if input_arg.in_out:
+                out_types.append(get_jax_output_type(input_arg, input_value.shape))
+
         # launch dimensions
         if launch_dims is None:
             # use the shape of the first input array
@@ -162,8 +207,7 @@ class FfiKernel:
         else:
             launch_dims = tuple(launch_dims)

-        # output
-        out_types = []
+        # output shapes
         if isinstance(output_dims, dict):
             # assume a dictionary of shapes keyed on argument name
             for output_arg in self.output_args:
@@ -185,6 +229,7 @@ class FfiKernel:
             self.name,
             out_types,
             vmap_method=vmap_method,
+            input_output_aliases=self.input_output_aliases,
         )

         # ensure the kernel module is loaded before the callback, otherwise graph capture may fail
@@ -238,9 +283,8 @@ class FfiKernel:

         arg_refs = []

-        # …
-        for i in …
-            input_arg = self.input_args[i]
+        # input and in-out args
+        for i, input_arg in enumerate(self.input_args):
             if input_arg.is_array:
                 buffer = inputs[i].contents
                 shape = buffer.dims[: input_arg.type.ndim]
@@ -255,10 +299,9 @@ class FfiKernel:
             kernel_params[i + 1] = ctypes.addressof(arg)
             arg_refs.append(arg)  # keep a reference

-        # …
-        for i in …
-            …
-            buffer = outputs[i].contents
+        # pure output args (skip in-out FFI buffers)
+        for i, output_arg in enumerate(self.output_args):
+            buffer = outputs[i + self.num_in_out].contents
             shape = buffer.dims[: output_arg.type.ndim]
             strides = strides_from_shape(shape, output_arg.type.dtype)
             arg = array_t(buffer.data, 0, output_arg.type.ndim, shape, strides)
@@ -274,7 +317,7 @@ class FfiKernel:
         assert hooks.forward, "Failed to find kernel entry point"

         # launch the kernel
-        wp.context.runtime.core.…
+        wp.context.runtime.core.wp_cuda_launch_kernel(
             device.context,
             hooks.forward,
             launch_bounds.size,
@@ -295,29 +338,38 @@ class FfiKernel:
 class FfiCallDesc:
     def __init__(self, static_inputs):
         self.static_inputs = static_inputs
+        self.captures = {}


 class FfiCallable:
-    def __init__(self, func, num_outputs, …
+    def __init__(self, func, num_outputs, graph_mode, vmap_method, output_dims, in_out_argnames):
         self.func = func
         self.name = generate_unique_name(func)
         self.num_outputs = num_outputs
         self.vmap_method = vmap_method
-        self.…
+        self.graph_mode = graph_mode
         self.output_dims = output_dims
         self.first_array_arg = None
         self.call_id = 0
         self.call_descriptors = {}

+        in_out_argnames_list = in_out_argnames or []
+        in_out_argnames = set(in_out_argnames_list)
+        if len(in_out_argnames_list) != len(in_out_argnames):
+            raise AssertionError("in_out_argnames must not contain duplicate names")
+
         # get arguments and annotations
         argspec = get_full_arg_spec(func)

         num_args = len(argspec.args)
-        self.…
+        self.num_in_out = len(in_out_argnames)
+        self.num_inputs = num_args - num_outputs + self.num_in_out
         if self.num_outputs < 1:
             raise ValueError("At least one output is required")
         if self.num_outputs > num_args:
             raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
+        if self.num_outputs < self.num_in_out:
+            raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")

         if len(argspec.annotations) < num_args:
             raise RuntimeError(f"Incomplete argument annotations on function {self.name}")
@@ -329,16 +381,45 @@ class FfiCallable:
             if arg_name == "return":
                 if arg_type is not None:
                     raise TypeError("Function must not return a value")
+                continue
             else:
-                arg = FfiArg(arg_name, arg_type)
+                arg = FfiArg(arg_name, arg_type, arg_name in in_out_argnames)
+                if arg_name in in_out_argnames:
+                    in_out_argnames.remove(arg_name)
                 if arg.is_array:
                     if arg_idx < self.num_inputs and self.first_array_arg is None:
                         self.first_array_arg = arg_idx
                 self.args.append(arg)
+
+                if arg.in_out and arg_idx >= self.num_inputs:
+                    raise AssertionError(
+                        f"Expected an output-only argument for argument {arg_name}."
+                        " in_out arguments should be placed before output-only arguments."
+                    )
+
                 arg_idx += 1

-        …
-        …
+        if in_out_argnames:
+            raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
+
+        self.input_args = self.args[: self.num_inputs]  # includes in-out args
+        self.output_args = self.args[self.num_inputs :]  # pure output args
+
+        # Buffer indices for array arguments in callback.
+        # In-out buffers are the same pointers in the XLA call frame,
+        # so we only include them for inputs and skip them for outputs.
+        self.array_input_indices = [i for i, arg in enumerate(self.input_args) if arg.is_array]
+        self.array_output_indices = list(range(self.num_in_out, self.num_outputs))
+
+        # Build input output aliases.
+        out_id = 0
+        input_output_aliases = {}
+        for in_id, arg in enumerate(self.input_args):
+            if not arg.in_out:
+                continue
+            input_output_aliases[in_id] = out_id
+            out_id += 1
+        self.input_output_aliases = input_output_aliases

         # register the callback
         FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
@@ -350,7 +431,9 @@ class FfiCallable:
     def __call__(self, *args, output_dims=None, vmap_method=None):
         num_inputs = len(args)
         if num_inputs != self.num_inputs:
-            …
+            input_names = ", ".join(arg.name for arg in self.input_args)
+            s = "" if self.num_inputs == 1 else "s"
+            raise ValueError(f"Expected {self.num_inputs} input{s} ({input_names}), but got {num_inputs}")

         # default argument fallback
         if vmap_method is None:
@@ -358,6 +441,9 @@ class FfiCallable:
         if output_dims is None:
             output_dims = self.output_dims

+        # output types
+        out_types = []
+
         # process inputs
         static_inputs = {}
         for i in range(num_inputs):
@@ -387,12 +473,11 @@ class FfiCallable:
                 # stash the value to be retrieved by callback
                 static_inputs[input_arg.name] = input_arg.type(input_value)

-            …
-            …
-            …
+            # append in-out arg to output types
+            if input_arg.in_out:
+                out_types.append(get_jax_output_type(input_arg, input_value.shape))

-        # output
-        out_types = []
+        # output shapes
         if isinstance(output_dims, dict):
             # assume a dictionary of shapes keyed on argument name
             for output_arg in self.output_args:
@@ -402,7 +487,9 @@ class FfiCallable:
                 out_types.append(get_jax_output_type(output_arg, dims))
         else:
             if output_dims is None:
-                …
+                if self.first_array_arg is None:
+                    raise ValueError("Unable to determine output dimensions")
+                output_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
             elif isinstance(output_dims, int):
                 output_dims = (output_dims,)
             # assume same dimensions for all outputs
@@ -413,6 +500,7 @@ class FfiCallable:
             self.name,
             out_types,
             vmap_method=vmap_method,
+            input_output_aliases=self.input_output_aliases,
             # has_side_effect=True,  # force this function to execute even if outputs aren't used
         )

@@ -430,11 +518,10 @@ class FfiCallable:

     def ffi_callback(self, call_frame):
         try:
-            # TODO Try-catch around the body and return XLA_FFI_Error on error.
-            extension = call_frame.contents.extension_start
             # On the first call, XLA runtime will query the API version and traits
             # metadata using the |extension| field. Let us respond to that query
             # if the metadata extension is present.
+            extension = call_frame.contents.extension_start
             if extension:
                 # Try to set the version metadata.
                 if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
@@ -442,15 +529,19 @@ class FfiCallable:
                     metadata_ext.contents.metadata.contents.api_version.major_version = 0
                     metadata_ext.contents.metadata.contents.api_version.minor_version = 1
                     # Turn on CUDA graphs for this handler.
-                    if self.…
+                    if self.graph_mode is GraphMode.JAX:
                         metadata_ext.contents.metadata.contents.traits = (
                             XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
                         )
                     return None

             # retrieve call info
-            …
-            …
+            # NOTE: this assumes that there's only one attribute - call_id (int64).
+            # A more general but slower approach is this:
+            #     attrs = decode_attrs(call_frame.contents.attrs)
+            #     call_id = int(attrs["call_id"])
+            attr = ctypes.cast(call_frame.contents.attrs.attrs[0], ctypes.POINTER(XLA_FFI_Scalar)).contents
+            call_id = ctypes.cast(attr.value, ctypes.POINTER(ctypes.c_int64)).contents.value
             call_desc = self.call_descriptors[call_id]

             num_inputs = call_frame.contents.args.size
@@ -462,16 +553,42 @@ class FfiCallable:
             assert num_inputs == self.num_inputs
             assert num_outputs == self.num_outputs

-            device = wp.device_from_jax(get_jax_device())
             cuda_stream = get_stream_from_callframe(call_frame.contents)
+
+            if self.graph_mode == GraphMode.WARP:
+                # check if we already captured an identical call
+                ip = [inputs[i].contents.data for i in self.array_input_indices]
+                op = [outputs[i].contents.data for i in self.array_output_indices]
+                buffer_hash = hash((*ip, *op))
+                capture = call_desc.captures.get(buffer_hash)
+
+                # launch existing graph
+                if capture is not None:
+                    # NOTE: We use the native graph API to avoid overhead with obtaining Stream and Device objects in Python.
+                    # This code should match wp.capture_launch().
+                    graph = capture.graph
+                    if graph.graph_exec is None:
+                        g = ctypes.c_void_p()
+                        if not wp.context.runtime.core.wp_cuda_graph_create_exec(
+                            graph.device.context, cuda_stream, graph.graph, ctypes.byref(g)
+                        ):
+                            raise RuntimeError(f"Graph creation error: {wp.context.runtime.get_error_string()}")
+                        graph.graph_exec = g
+
+                    if not wp.context.runtime.core.wp_cuda_graph_launch(graph.graph_exec, cuda_stream):
+                        raise RuntimeError(f"Graph launch error: {wp.context.runtime.get_error_string()}")
+
+                    # early out
+                    return
+
+            device = wp.device_from_jax(get_jax_device())
             stream = wp.Stream(device, cuda_stream=cuda_stream)

             # reconstruct the argument list
             arg_list = []

-            # …
-            for i in …
-                arg = self.input_args[i]
+            # input and in-out args
+            for i, arg in enumerate(self.input_args):
                 if arg.is_array:
                     buffer = inputs[i].contents
                     shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
@@ -482,10 +599,9 @@ class FfiCallable:
                     value = call_desc.static_inputs[arg.name]
                     arg_list.append(value)

-            # …
-            for i in …
-                …
-                buffer = outputs[i].contents
+            # pure output args (skip in-out FFI buffers)
+            for i, arg in enumerate(self.output_args):
+                buffer = outputs[i + self.num_in_out].contents
                 shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
                 arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
                 arg_list.append(arr)
@@ -493,11 +609,20 @@ class FfiCallable:
             # call the Python function with reconstructed arguments
             with wp.ScopedStream(stream, sync_enter=False):
                 if stream.is_capturing:
-                    …
+                    # capturing with JAX
+                    with wp.ScopedCapture(external=True) as capture:
                         self.func(*arg_list)
                     # keep a reference to the capture object to prevent required modules getting unloaded
                     call_desc.capture = capture
+                elif self.graph_mode == GraphMode.WARP:
+                    # capturing with WARP
+                    with wp.ScopedCapture() as capture:
+                        self.func(*arg_list)
+                    wp.capture_launch(capture.graph)
+                    # keep a reference to the capture object and reuse it with same buffers
+                    call_desc.captures[buffer_hash] = capture
                 else:
+                    # not capturing
                     self.func(*arg_list)

         except Exception as e:
@@ -515,7 +640,9 @@ _FFI_KERNEL_REGISTRY: dict[str, FfiKernel] = {}
 _FFI_REGISTRY_LOCK = threading.Lock()


-def jax_kernel(…
+def jax_kernel(
+    kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=None, output_dims=None, in_out_argnames=None
+):
     """Create a JAX callback from a Warp kernel.

     NOTE: This is an experimental feature under development.
@@ -523,6 +650,7 @@ def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=N…
     Args:
         kernel: The Warp kernel to launch.
         num_outputs: Optional. Specify the number of output arguments if greater than 1.
+            This must include the number of ``in_out_arguments``.
         vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
             This argument can also be specified for individual calls.
         launch_dims: Optional. Specify the default kernel launch dimensions. If None, launch
@@ -531,12 +659,13 @@ def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=N…
         output_dims: Optional. Specify the default dimensions of output arrays. If None, output
             dimensions are inferred from the launch dimensions.
             This argument can also be specified for individual calls.
+        in_out_argnames: Optional. Names of input-output arguments.

     Limitations:
         - All kernel arguments must be contiguous arrays or scalars.
         - Scalars must be static arguments in JAX.
-        - Input arguments …
-        - There must be at least one output argument.
+        - Input and input-output arguments must precede the output arguments in the ``kernel`` definition.
+        - There must be at least one output or input-output argument.
         - Only the CUDA backend is supported.
     """
     key = (
@@ -549,7 +678,7 @@ def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=N…

     with _FFI_REGISTRY_LOCK:
         if key not in _FFI_KERNEL_REGISTRY:
-            new_kernel = FfiKernel(kernel, num_outputs, vmap_method, launch_dims, output_dims)
+            new_kernel = FfiKernel(kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames)
             _FFI_KERNEL_REGISTRY[key] = new_kernel

     return _FFI_KERNEL_REGISTRY[key]
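A minimal usage sketch of the new in_out_argnames option on jax_kernel (the kernel and data below are illustrative, not taken from the package's examples). The argument named in in_out_argnames is passed as an input, counts toward num_outputs, and its updated buffer is returned as an output:

import jax.numpy as jnp
import warp as wp
from warp.jax_experimental.ffi import jax_kernel

@wp.kernel
def scale_inplace(state: wp.array(dtype=float), factor: wp.array(dtype=float)):
    i = wp.tid()
    state[i] = state[i] * factor[i]

# "state" is both read and written; it is an input and also the single output.
jax_scale = jax_kernel(scale_inplace, num_outputs=1, in_out_argnames=["state"])

# The call can also appear inside a jax.jit-compiled function; the result is the updated "state" buffer.
result = jax_scale(jnp.ones(16, dtype=jnp.float32), jnp.full(16, 2.0, dtype=jnp.float32))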
@@ -558,9 +687,11 @@
 def jax_callable(
     func: Callable,
     num_outputs: int = 1,
-    graph_compatible: bool = …
-    …
+    graph_compatible: Optional[bool] = None,  # deprecated
+    graph_mode: GraphMode = GraphMode.JAX,
+    vmap_method: Optional[str] = "broadcast_all",
     output_dims=None,
+    in_out_argnames=None,
 ):
     """Create a JAX callback from an annotated Python function.

@@ -571,31 +702,50 @@ def jax_callable(
     Args:
         func: The Python function to call.
         num_outputs: Optional. Specify the number of output arguments if greater than 1.
+            This must include the number of ``in_out_arguments``.
         graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
+            This argument is deprecated, use ``graph_mode`` instead.
+        graph_mode: Optional. CUDA graph capture mode.
+            ``GraphMode.JAX`` (default): Let JAX capture the graph, which may be used as a subgraph in an enclosing capture.
+            ``GraphMode.WARP``: Let Warp capture the graph. Use this mode when the callable cannot be used as a subraph,
+                such as when the callable uses conditional graph nodes.
+            ``GraphMode.NONE``: Disable graph capture. Use when the callable performs operations that are not legal in a graph,
+                such as host synchronization.
         vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
             This argument can also be specified for individual calls.
         output_dims: Optional. Specify the default dimensions of output arrays.
             If ``None``, output dimensions are inferred from the launch dimensions.
             This argument can also be specified for individual calls.
+        in_out_argnames: Optional. Names of input-output arguments.

     Limitations:
         - All kernel arguments must be contiguous arrays or scalars.
         - Scalars must be static arguments in JAX.
-        - Input arguments …
-        - There must be at least one output argument.
+        - Input and input-output arguments must precede the output arguments in the ``func`` definition.
+        - There must be at least one output or input-output argument.
         - Only the CUDA backend is supported.
     """
+
+    if graph_compatible is not None:
+        wp.utils.warn(
+            "The `graph_compatible` argument is deprecated, use `graph_mode` instead.",
+            DeprecationWarning,
+            stacklevel=3,
+        )
+        if graph_compatible is False:
+            graph_mode = GraphMode.NONE
+
     key = (
         func,
         num_outputs,
-        …
+        graph_mode,
         vmap_method,
         tuple(sorted(output_dims.items())) if output_dims else output_dims,
     )

     with _FFI_REGISTRY_LOCK:
         if key not in _FFI_CALLABLE_REGISTRY:
-            new_callable = FfiCallable(func, num_outputs, …
+            new_callable = FfiCallable(func, num_outputs, graph_mode, vmap_method, output_dims, in_out_argnames)
             _FFI_CALLABLE_REGISTRY[key] = new_callable

     return _FFI_CALLABLE_REGISTRY[key]
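A hedged sketch of the new graph_mode option on jax_callable (the function, kernel, and argument names are illustrative, not from the package). GraphMode.WARP asks Warp to capture the callable into its own CUDA graph, keyed on the buffer addresses, and replay that graph on later calls with the same buffers:

import warp as wp
from warp.jax_experimental.ffi import GraphMode, jax_callable

@wp.kernel
def axpy(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    i = wp.tid()
    y[i] = 2.0 * x[i] + y[i]

def step(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    # any number of Warp launches; "y" is updated in place
    wp.launch(axpy, dim=x.shape, inputs=[x, y])

# "y" is an in-out argument; Warp captures and replays the CUDA graph.
jax_step = jax_callable(step, num_outputs=1, graph_mode=GraphMode.WARP, in_out_argnames=["y"])

Passing graph_compatible=False still works, but it now emits a deprecation warning and is mapped to GraphMode.NONE.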
@@ -626,7 +776,6 @@ def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = Tr…

     def ffi_callback(call_frame):
         try:
-            # TODO Try-catch around the body and return XLA_FFI_Error on error.
             extension = call_frame.contents.extension_start
             # On the first call, XLA runtime will query the API version and traits
             # metadata using the |extension| field. Let us respond to that query