warp-lang 1.8.0-py3-none-macosx_10_13_universal2.whl → 1.8.1-py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic.
- warp/bin/libwarp.dylib +0 -0
- warp/build_dll.py +5 -0
- warp/codegen.py +15 -3
- warp/config.py +1 -1
- warp/context.py +122 -24
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/fem/field/virtual.py +2 -0
- warp/fem/integrate.py +78 -47
- warp/jax_experimental/ffi.py +201 -53
- warp/native/array.h +4 -4
- warp/native/builtin.h +8 -4
- warp/native/coloring.cpp +5 -1
- warp/native/cuda_util.cpp +1 -1
- warp/native/intersect.h +2 -2
- warp/native/mat.h +3 -3
- warp/native/mesh.h +1 -1
- warp/native/quat.h +6 -2
- warp/native/rand.h +7 -7
- warp/native/sparse.cu +1 -1
- warp/native/svd.h +23 -8
- warp/native/tile.h +20 -1
- warp/native/tile_radix_sort.h +5 -1
- warp/native/tile_reduce.h +16 -25
- warp/native/tuple.h +2 -2
- warp/native/vec.h +4 -4
- warp/native/warp.cpp +1 -1
- warp/native/warp.cu +15 -2
- warp/native/warp.h +1 -1
- warp/render/render_opengl.py +52 -51
- warp/render/render_usd.py +0 -1
- warp/sim/collide.py +1 -2
- warp/sim/integrator_vbd.py +10 -2
- warp/sparse.py +1 -1
- warp/tape.py +2 -0
- warp/tests/sim/test_cloth.py +89 -6
- warp/tests/sim/test_coloring.py +76 -1
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +127 -114
- warp/tests/test_mat.py +22 -0
- warp/tests/test_quat.py +22 -0
- warp/tests/test_sparse.py +32 -0
- warp/tests/test_static.py +48 -0
- warp/tests/test_tape.py +38 -0
- warp/tests/test_vec.py +38 -408
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +31 -143
- warp/tests/tile/test_tile_mathdx.py +2 -2
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_reduce.py +100 -11
- warp/tests/tile/test_tile_shared_memory.py +12 -12
- warp/tests/tile/test_tile_sort.py +59 -55
- warp/tests/unittest_suites.py +10 -0
- {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/METADATA +4 -4
- {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/RECORD +58 -56
- {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/jax_experimental/ffi.py
CHANGED
@@ -16,7 +16,8 @@
 import ctypes
 import threading
 import traceback
-from
+from enum import IntEnum
+from typing import Callable, Optional
 
 import jax
 
@@ -28,10 +29,17 @@ from warp.types import array_t, launch_bounds_t, strides_from_shape, type_to_war
 from .xla_ffi import *
 
 
+class GraphMode(IntEnum):
+    NONE = 0  # don't capture a graph
+    JAX = 1  # let JAX capture a graph
+    WARP = 2  # let Warp capture a graph
+
+
 class FfiArg:
-    def __init__(self, name, type):
+    def __init__(self, name, type, in_out=False):
         self.name = name
         self.type = type
+        self.in_out = in_out
         self.is_array = isinstance(type, wp.array)
 
         if self.is_array:
@@ -65,7 +73,7 @@ class FfiLaunchDesc:
 
 
 class FfiKernel:
-    def __init__(self, kernel, num_outputs, vmap_method, launch_dims, output_dims):
+    def __init__(self, kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames):
         self.kernel = kernel
         self.name = generate_unique_name(kernel.func)
         self.num_outputs = num_outputs
@@ -76,17 +84,28 @@ class FfiKernel:
         self.launch_id = 0
         self.launch_descriptors = {}
 
+        in_out_argnames_list = in_out_argnames or []
+        in_out_argnames = set(in_out_argnames_list)
+        if len(in_out_argnames_list) != len(in_out_argnames):
+            raise AssertionError("in_out_argnames must not contain duplicate names")
+
         self.num_kernel_args = len(kernel.adj.args)
-        self.
+        self.num_in_out = len(in_out_argnames)
+        self.num_inputs = self.num_kernel_args - num_outputs + self.num_in_out
         if self.num_outputs < 1:
             raise ValueError("At least one output is required")
         if self.num_outputs > self.num_kernel_args:
             raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
+        if self.num_outputs < self.num_in_out:
+            raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")
 
         # process input args
         self.input_args = []
         for i in range(self.num_inputs):
-
+            arg_name = kernel.adj.args[i].label
+            arg = FfiArg(arg_name, kernel.adj.args[i].type, arg_name in in_out_argnames)
+            if arg_name in in_out_argnames:
+                in_out_argnames.remove(arg_name)
             if arg.is_array:
                 # keep track of the first input array argument
                 if self.first_array_arg is None:
@@ -96,11 +115,30 @@
         # process output args
         self.output_args = []
         for i in range(self.num_inputs, self.num_kernel_args):
-
+            arg_name = kernel.adj.args[i].label
+            if arg_name in in_out_argnames:
+                raise AssertionError(
+                    f"Expected an output-only argument for argument {arg_name}."
+                    " in_out arguments should be placed before output-only arguments."
+                )
+            arg = FfiArg(arg_name, kernel.adj.args[i].type, False)
             if not arg.is_array:
                 raise TypeError("All output arguments must be arrays")
             self.output_args.append(arg)
 
+        if in_out_argnames:
+            raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
+
+        # Build input output aliases.
+        out_id = 0
+        input_output_aliases = {}
+        for in_id, arg in enumerate(self.input_args):
+            if not arg.in_out:
+                continue
+            input_output_aliases[in_id] = out_id
+            out_id += 1
+        self.input_output_aliases = input_output_aliases
+
         # register the callback
         FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
         self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
@@ -121,6 +159,9 @@ class FfiKernel:
         if vmap_method is None:
            vmap_method = self.vmap_method
 
+        # output types
+        out_types = []
+
         # process inputs
         static_inputs = {}
         for i in range(num_inputs):
@@ -150,6 +191,10 @@
                 # stash the value to be retrieved by callback
                 static_inputs[input_arg.name] = input_arg.type(input_value)
 
+            # append in-out arg to output types
+            if input_arg.in_out:
+                out_types.append(get_jax_output_type(input_arg, input_value.shape))
+
         # launch dimensions
         if launch_dims is None:
             # use the shape of the first input array
@@ -162,8 +207,7 @@
         else:
             launch_dims = tuple(launch_dims)
 
-        # output
-        out_types = []
+        # output shapes
         if isinstance(output_dims, dict):
             # assume a dictionary of shapes keyed on argument name
             for output_arg in self.output_args:
@@ -185,6 +229,7 @@
             self.name,
             out_types,
             vmap_method=vmap_method,
+            input_output_aliases=self.input_output_aliases,
         )
 
         # ensure the kernel module is loaded before the callback, otherwise graph capture may fail
@@ -238,9 +283,8 @@
 
             arg_refs = []
 
-            #
-            for i in
-                input_arg = self.input_args[i]
+            # input and in-out args
+            for i, input_arg in enumerate(self.input_args):
                 if input_arg.is_array:
                     buffer = inputs[i].contents
                     shape = buffer.dims[: input_arg.type.ndim]
@@ -255,10 +299,9 @@
                     kernel_params[i + 1] = ctypes.addressof(arg)
                     arg_refs.append(arg)  # keep a reference
 
-            #
-            for i in
-
-                buffer = outputs[i].contents
+            # pure output args (skip in-out FFI buffers)
+            for i, output_arg in enumerate(self.output_args):
+                buffer = outputs[i + self.num_in_out].contents
                 shape = buffer.dims[: output_arg.type.ndim]
                 strides = strides_from_shape(shape, output_arg.type.dtype)
                 arg = array_t(buffer.data, 0, output_arg.type.ndim, shape, strides)
@@ -295,29 +338,38 @@
 class FfiCallDesc:
     def __init__(self, static_inputs):
         self.static_inputs = static_inputs
+        self.captures = {}
 
 
 class FfiCallable:
-    def __init__(self, func, num_outputs,
+    def __init__(self, func, num_outputs, graph_mode, vmap_method, output_dims, in_out_argnames):
         self.func = func
         self.name = generate_unique_name(func)
         self.num_outputs = num_outputs
         self.vmap_method = vmap_method
-        self.
+        self.graph_mode = graph_mode
         self.output_dims = output_dims
         self.first_array_arg = None
         self.call_id = 0
         self.call_descriptors = {}
 
+        in_out_argnames_list = in_out_argnames or []
+        in_out_argnames = set(in_out_argnames_list)
+        if len(in_out_argnames_list) != len(in_out_argnames):
+            raise AssertionError("in_out_argnames must not contain duplicate names")
+
         # get arguments and annotations
         argspec = get_full_arg_spec(func)
 
         num_args = len(argspec.args)
-        self.
+        self.num_in_out = len(in_out_argnames)
+        self.num_inputs = num_args - num_outputs + self.num_in_out
         if self.num_outputs < 1:
             raise ValueError("At least one output is required")
         if self.num_outputs > num_args:
             raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
+        if self.num_outputs < self.num_in_out:
+            raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")
 
         if len(argspec.annotations) < num_args:
             raise RuntimeError(f"Incomplete argument annotations on function {self.name}")
@@ -330,15 +382,43 @@ class FfiCallable:
             if arg_type is not None:
                 raise TypeError("Function must not return a value")
             else:
-                arg = FfiArg(arg_name, arg_type)
+                arg = FfiArg(arg_name, arg_type, arg_name in in_out_argnames)
+                if arg_name in in_out_argnames:
+                    in_out_argnames.remove(arg_name)
                 if arg.is_array:
                     if arg_idx < self.num_inputs and self.first_array_arg is None:
                         self.first_array_arg = arg_idx
                 self.args.append(arg)
+
+                if arg.in_out and arg_idx >= self.num_inputs:
+                    raise AssertionError(
+                        f"Expected an output-only argument for argument {arg_name}."
+                        " in_out arguments should be placed before output-only arguments."
+                    )
+
             arg_idx += 1
 
-
-
+        if in_out_argnames:
+            raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
+
+        self.input_args = self.args[: self.num_inputs]  # includes in-out args
+        self.output_args = self.args[self.num_inputs :]  # pure output args
+
+        # Buffer indices for array arguments in callback.
+        # In-out buffers are the same pointers in the XLA call frame,
+        # so we only include them for inputs and skip them for outputs.
+        self.array_input_indices = [i for i, arg in enumerate(self.input_args) if arg.is_array]
+        self.array_output_indices = list(range(self.num_in_out, self.num_outputs))
+
+        # Build input output aliases.
+        out_id = 0
+        input_output_aliases = {}
+        for in_id, arg in enumerate(self.input_args):
+            if not arg.in_out:
+                continue
+            input_output_aliases[in_id] = out_id
+            out_id += 1
+        self.input_output_aliases = input_output_aliases
 
         # register the callback
         FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
@@ -350,7 +430,9 @@ class FfiCallable:
     def __call__(self, *args, output_dims=None, vmap_method=None):
         num_inputs = len(args)
         if num_inputs != self.num_inputs:
-
+            input_names = ", ".join(arg.name for arg in self.input_args)
+            s = "" if self.num_inputs == 1 else "s"
+            raise ValueError(f"Expected {self.num_inputs} input{s} ({input_names}), but got {num_inputs}")
 
         # default argument fallback
         if vmap_method is None:
@@ -358,6 +440,9 @@
         if output_dims is None:
             output_dims = self.output_dims
 
+        # output types
+        out_types = []
+
         # process inputs
         static_inputs = {}
         for i in range(num_inputs):
@@ -387,12 +472,11 @@
                 # stash the value to be retrieved by callback
                 static_inputs[input_arg.name] = input_arg.type(input_value)
 
-
-
-
+            # append in-out arg to output types
+            if input_arg.in_out:
+                out_types.append(get_jax_output_type(input_arg, input_value.shape))
 
-        # output
-        out_types = []
+        # output shapes
         if isinstance(output_dims, dict):
             # assume a dictionary of shapes keyed on argument name
             for output_arg in self.output_args:
@@ -402,7 +486,9 @@
                 out_types.append(get_jax_output_type(output_arg, dims))
         else:
             if output_dims is None:
-
+                if self.first_array_arg is None:
+                    raise ValueError("Unable to determine output dimensions")
+                output_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
             elif isinstance(output_dims, int):
                 output_dims = (output_dims,)
             # assume same dimensions for all outputs
@@ -413,6 +499,7 @@
             self.name,
             out_types,
             vmap_method=vmap_method,
+            input_output_aliases=self.input_output_aliases,
             # has_side_effect=True, # force this function to execute even if outputs aren't used
         )
 
@@ -430,11 +517,10 @@
 
     def ffi_callback(self, call_frame):
         try:
-            # TODO Try-catch around the body and return XLA_FFI_Error on error.
-            extension = call_frame.contents.extension_start
             # On the first call, XLA runtime will query the API version and traits
             # metadata using the |extension| field. Let us respond to that query
             # if the metadata extension is present.
+            extension = call_frame.contents.extension_start
             if extension:
                 # Try to set the version metadata.
                 if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
@@ -442,15 +528,19 @@
                     metadata_ext.contents.metadata.contents.api_version.major_version = 0
                     metadata_ext.contents.metadata.contents.api_version.minor_version = 1
                     # Turn on CUDA graphs for this handler.
-                    if self.
+                    if self.graph_mode is GraphMode.JAX:
                         metadata_ext.contents.metadata.contents.traits = (
                             XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
                         )
                     return None
 
             # retrieve call info
-
-
+            # NOTE: this assumes that there's only one attribute - call_id (int64).
+            # A more general but slower approach is this:
+            # attrs = decode_attrs(call_frame.contents.attrs)
+            # call_id = int(attrs["call_id"])
+            attr = ctypes.cast(call_frame.contents.attrs.attrs[0], ctypes.POINTER(XLA_FFI_Scalar)).contents
+            call_id = ctypes.cast(attr.value, ctypes.POINTER(ctypes.c_int64)).contents.value
             call_desc = self.call_descriptors[call_id]
 
             num_inputs = call_frame.contents.args.size
@@ -462,16 +552,42 @@
             assert num_inputs == self.num_inputs
             assert num_outputs == self.num_outputs
 
-            device = wp.device_from_jax(get_jax_device())
             cuda_stream = get_stream_from_callframe(call_frame.contents)
+
+            if self.graph_mode == GraphMode.WARP:
+                # check if we already captured an identical call
+                ip = [inputs[i].contents.data for i in self.array_input_indices]
+                op = [outputs[i].contents.data for i in self.array_output_indices]
+                buffer_hash = hash((*ip, *op))
+                capture = call_desc.captures.get(buffer_hash)
+
+                # launch existing graph
+                if capture is not None:
+                    # NOTE: We use the native graph API to avoid overhead with obtaining Stream and Device objects in Python.
+                    # This code should match wp.capture_launch().
+                    graph = capture.graph
+                    if graph.graph_exec is None:
+                        g = ctypes.c_void_p()
+                        if not wp.context.runtime.core.wp_cuda_graph_create_exec(
+                            graph.device.context, cuda_stream, graph.graph, ctypes.byref(g)
+                        ):
+                            raise RuntimeError(f"Graph creation error: {wp.context.runtime.get_error_string()}")
+                        graph.graph_exec = g
+
+                    if not wp.context.runtime.core.wp_cuda_graph_launch(graph.graph_exec, cuda_stream):
+                        raise RuntimeError(f"Graph launch error: {wp.context.runtime.get_error_string()}")
+
+                    # early out
+                    return
+
+            device = wp.device_from_jax(get_jax_device())
             stream = wp.Stream(device, cuda_stream=cuda_stream)
 
             # reconstruct the argument list
             arg_list = []
 
-            #
-            for i in
-                arg = self.input_args[i]
+            # input and in-out args
+            for i, arg in enumerate(self.input_args):
                 if arg.is_array:
                     buffer = inputs[i].contents
                     shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
@@ -482,10 +598,9 @@
                     value = call_desc.static_inputs[arg.name]
                     arg_list.append(value)
 
-            #
-            for i in
-
-                buffer = outputs[i].contents
+            # pure output args (skip in-out FFI buffers)
+            for i, arg in enumerate(self.output_args):
+                buffer = outputs[i + self.num_in_out].contents
                 shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
                 arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
                 arg_list.append(arr)
@@ -493,11 +608,20 @@
             # call the Python function with reconstructed arguments
             with wp.ScopedStream(stream, sync_enter=False):
                 if stream.is_capturing:
-
+                    # capturing with JAX
+                    with wp.ScopedCapture(external=True) as capture:
                         self.func(*arg_list)
                     # keep a reference to the capture object to prevent required modules getting unloaded
                     call_desc.capture = capture
+                elif self.graph_mode == GraphMode.WARP:
+                    # capturing with WARP
+                    with wp.ScopedCapture() as capture:
+                        self.func(*arg_list)
+                    wp.capture_launch(capture.graph)
+                    # keep a reference to the capture object and reuse it with same buffers
+                    call_desc.captures[buffer_hash] = capture
                 else:
+                    # not capturing
                     self.func(*arg_list)
 
         except Exception as e:
@@ -515,7 +639,9 @@ _FFI_KERNEL_REGISTRY: dict[str, FfiKernel] = {}
 _FFI_REGISTRY_LOCK = threading.Lock()
 
 
-def jax_kernel(
+def jax_kernel(
+    kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=None, output_dims=None, in_out_argnames=None
+):
     """Create a JAX callback from a Warp kernel.
 
     NOTE: This is an experimental feature under development.
@@ -523,6 +649,7 @@ def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=N
     Args:
         kernel: The Warp kernel to launch.
         num_outputs: Optional. Specify the number of output arguments if greater than 1.
+            This must include the number of ``in_out_arguments``.
         vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
             This argument can also be specified for individual calls.
         launch_dims: Optional. Specify the default kernel launch dimensions. If None, launch
@@ -531,12 +658,13 @@ def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=N
         output_dims: Optional. Specify the default dimensions of output arrays. If None, output
             dimensions are inferred from the launch dimensions.
             This argument can also be specified for individual calls.
+        in_out_argnames: Optional. Names of input-output arguments.
 
     Limitations:
         - All kernel arguments must be contiguous arrays or scalars.
         - Scalars must be static arguments in JAX.
-        - Input arguments
-        - There must be at least one output argument.
+        - Input and input-output arguments must precede the output arguments in the ``kernel`` definition.
+        - There must be at least one output or input-output argument.
         - Only the CUDA backend is supported.
     """
     key = (
@@ -549,7 +677,7 @@ def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=N
 
     with _FFI_REGISTRY_LOCK:
         if key not in _FFI_KERNEL_REGISTRY:
-            new_kernel = FfiKernel(kernel, num_outputs, vmap_method, launch_dims, output_dims)
+            new_kernel = FfiKernel(kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames)
            _FFI_KERNEL_REGISTRY[key] = new_kernel
 
    return _FFI_KERNEL_REGISTRY[key]
@@ -558,9 +686,11 @@ def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=N
 def jax_callable(
     func: Callable,
     num_outputs: int = 1,
-    graph_compatible: bool =
-
+    graph_compatible: Optional[bool] = None,  # deprecated
+    graph_mode: GraphMode = GraphMode.JAX,
+    vmap_method: Optional[str] = "broadcast_all",
     output_dims=None,
+    in_out_argnames=None,
 ):
     """Create a JAX callback from an annotated Python function.
 
@@ -571,31 +701,50 @@ def jax_callable(
     Args:
         func: The Python function to call.
         num_outputs: Optional. Specify the number of output arguments if greater than 1.
+            This must include the number of ``in_out_arguments``.
         graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
+            This argument is deprecated, use ``graph_mode`` instead.
+        graph_mode: Optional. CUDA graph capture mode.
+            ``GraphMode.JAX`` (default): Let JAX capture the graph, which may be used as a subgraph in an enclosing capture.
+            ``GraphMode.WARP``: Let Warp capture the graph. Use this mode when the callable cannot be used as a subraph,
+            such as when the callable uses conditional graph nodes.
+            ``GraphMode.NONE``: Disable graph capture. Use when the callable performs operations that are not legal in a graph,
+            such as host synchronization.
         vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
             This argument can also be specified for individual calls.
         output_dims: Optional. Specify the default dimensions of output arrays.
            If ``None``, output dimensions are inferred from the launch dimensions.
            This argument can also be specified for individual calls.
+        in_out_argnames: Optional. Names of input-output arguments.
 
     Limitations:
         - All kernel arguments must be contiguous arrays or scalars.
         - Scalars must be static arguments in JAX.
-        - Input arguments
-        - There must be at least one output argument.
+        - Input and input-output arguments must precede the output arguments in the ``func`` definition.
+        - There must be at least one output or input-output argument.
        - Only the CUDA backend is supported.
    """
+
+    if graph_compatible is not None:
+        wp.utils.warn(
+            "The `graph_compatible` argument is deprecated, use `graph_mode` instead.",
+            DeprecationWarning,
+            stacklevel=3,
+        )
+        if graph_compatible is False:
+            graph_mode = GraphMode.NONE
+
     key = (
         func,
         num_outputs,
-
+        graph_mode,
        vmap_method,
        tuple(sorted(output_dims.items())) if output_dims else output_dims,
    )
 
    with _FFI_REGISTRY_LOCK:
        if key not in _FFI_CALLABLE_REGISTRY:
-            new_callable = FfiCallable(func, num_outputs,
+            new_callable = FfiCallable(func, num_outputs, graph_mode, vmap_method, output_dims, in_out_argnames)
            _FFI_CALLABLE_REGISTRY[key] = new_callable
 
    return _FFI_CALLABLE_REGISTRY[key]
@@ -626,7 +775,6 @@ def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = Tr
 
     def ffi_callback(call_frame):
         try:
-            # TODO Try-catch around the body and return XLA_FFI_Error on error.
             extension = call_frame.contents.extension_start
             # On the first call, XLA runtime will query the API version and traits
             # metadata using the |extension| field. Let us respond to that query
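Taken together, the ffi.py changes add in-place (in-out) arguments and a configurable CUDA graph capture mode to the experimental JAX interop layer. Below is a minimal usage sketch based only on the signatures and docstrings shown in this diff; the kernel, function, and array names (accum_kernel, accum_func, state) are hypothetical, and the import path assumes the module layout of warp/jax_experimental/ffi.py as diffed above.

    import warp as wp
    from warp.jax_experimental.ffi import GraphMode, jax_callable, jax_kernel


    @wp.kernel
    def accum_kernel(x: wp.array(dtype=float), state: wp.array(dtype=float)):
        tid = wp.tid()
        # 'state' is read and written in place
        state[tid] = state[tid] + x[tid]


    # in-out argument: listed in in_out_argnames, and counted in num_outputs
    jax_accum = jax_kernel(accum_kernel, num_outputs=1, in_out_argnames=["state"])


    def accum_func(x: wp.array(dtype=float), state: wp.array(dtype=float)):
        # annotated Python function that launches Warp work
        wp.launch(accum_kernel, dim=state.shape, inputs=[x, state])


    # Let Warp capture and replay the CUDA graph itself instead of JAX
    # (GraphMode.WARP), e.g. when the callable cannot be used as a subgraph.
    jax_accum_fn = jax_callable(
        accum_func,
        num_outputs=1,
        graph_mode=GraphMode.WARP,
        in_out_argnames=["state"],
    )

Under jax.jit, the in-out buffer is passed as an input and comes back as an FFI output that XLA aliases to the input via input_output_aliases, so the update happens in place rather than through an extra output copy.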
warp/native/array.h
CHANGED
@@ -161,7 +161,7 @@ inline CUDA_CALLABLE void print(shape_t s)
     // should probably store ndim with shape
     printf("(%d, %d, %d, %d)\n", s.dims[0], s.dims[1], s.dims[2], s.dims[3]);
 }
-inline CUDA_CALLABLE void adj_print(shape_t s, shape_t&
+inline CUDA_CALLABLE void adj_print(shape_t s, shape_t& adj_s) {}
 
 
 template <typename T>
@@ -665,11 +665,11 @@ CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i, int j
 }
 
 template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
-inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, A2<T>& adj_src, int adj_i, A3<T
+inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, A2<T>& adj_src, int adj_i, A3<T>& adj_ret) {}
 template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
-inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, A2<T>& adj_src, int adj_i, int adj_j, A3<T
+inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, A2<T>& adj_src, int adj_i, int adj_j, A3<T>& adj_ret) {}
 template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
-inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, int k, A2<T>& adj_src, int adj_i, int adj_j, int adj_k, A3<T
+inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, int k, A2<T>& adj_src, int adj_i, int adj_j, int adj_k, A3<T>& adj_ret) {}
 
 // TODO: lower_bound() for indexed arrays?
 
warp/native/builtin.h
CHANGED
@@ -268,16 +268,20 @@ inline CUDA_CALLABLE half operator / (half a,half b)
 
 
 template <typename T>
-CUDA_CALLABLE float cast_float(T x) { return (float)(x); }
+CUDA_CALLABLE inline float cast_float(T x) { return (float)(x); }
 
 template <typename T>
-CUDA_CALLABLE int cast_int(T x) { return (int)(x); }
+CUDA_CALLABLE inline int cast_int(T x) { return (int)(x); }
 
 template <typename T>
-CUDA_CALLABLE void adj_cast_float(T x, T& adj_x, float adj_ret) {
+CUDA_CALLABLE inline void adj_cast_float(T x, T& adj_x, float adj_ret) {}
+
+CUDA_CALLABLE inline void adj_cast_float(float16 x, float16& adj_x, float adj_ret) { adj_x += float16(adj_ret); }
+CUDA_CALLABLE inline void adj_cast_float(float32 x, float32& adj_x, float adj_ret) { adj_x += float32(adj_ret); }
+CUDA_CALLABLE inline void adj_cast_float(float64 x, float64& adj_x, float adj_ret) { adj_x += float64(adj_ret); }
 
 template <typename T>
-CUDA_CALLABLE void adj_cast_int(T x, T& adj_x, int adj_ret) {
+CUDA_CALLABLE inline void adj_cast_int(T x, T& adj_x, int adj_ret) {}
 
 template <typename T>
 CUDA_CALLABLE inline void adj_int8(T, T&, int8) {}
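The new adj_cast_float overloads accumulate a gradient for floating-point inputs, while the generic template remains a no-op, so gradients should now flow through float casts of floating-point values inside kernels. A hedged sketch of what that looks like from the Python side, assuming that a float(x) cast in a kernel lowers to the cast_float() builtin shown above; the kernel and array names are illustrative and the expected values follow from the chain rule, not from a statement in this diff:

    import warp as wp

    wp.init()


    @wp.kernel
    def cast_loss(x: wp.array(dtype=wp.float64), loss: wp.array(dtype=float)):
        tid = wp.tid()
        # cast the float64 input down to float32 before using it
        loss[tid] = float(x[tid]) * 2.0


    x = wp.array([1.0, 2.0, 3.0], dtype=wp.float64, requires_grad=True)
    loss = wp.zeros(3, dtype=float, requires_grad=True)

    tape = wp.Tape()
    with tape:
        wp.launch(cast_loss, dim=3, inputs=[x, loss])

    # seed the adjoint of 'loss' and pull gradients back through the cast
    tape.backward(grads={loss: wp.ones(3, dtype=float)})
    print(x.grad)  # expected [2.0, 2.0, 2.0] via the float64 adj_cast_float overload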
warp/native/coloring.cpp
CHANGED
@@ -209,9 +209,13 @@ float balance_color_groups(float target_max_min_ratio,
     do
     {
         int biggest_group = -1, smallest_group = -1;
-
+        float prev_max_min_ratio = max_min_ratio;
         max_min_ratio = find_largest_smallest_groups(color_groups, biggest_group, smallest_group);
 
+        if (prev_max_min_ratio > 0 && prev_max_min_ratio < max_min_ratio) {
+            return max_min_ratio;
+        }
+
         // graph is not optimizable anymore or target ratio reached
         if (color_groups[biggest_group].size() - color_groups[smallest_group].size() <= 2
             || max_min_ratio < target_max_min_ratio)
warp/native/cuda_util.cpp
CHANGED
@@ -212,7 +212,7 @@ bool init_cuda_driver()
     get_driver_entry_point("cuDeviceGetCount", 2000, &(void*&)pfn_cuDeviceGetCount);
     get_driver_entry_point("cuDeviceGetName", 2000, &(void*&)pfn_cuDeviceGetName);
     get_driver_entry_point("cuDeviceGetAttribute", 2000, &(void*&)pfn_cuDeviceGetAttribute);
-    get_driver_entry_point("cuDeviceGetUuid",
+    get_driver_entry_point("cuDeviceGetUuid", 11040, &(void*&)pfn_cuDeviceGetUuid);
     get_driver_entry_point("cuDevicePrimaryCtxRetain", 7000, &(void*&)pfn_cuDevicePrimaryCtxRetain);
     get_driver_entry_point("cuDevicePrimaryCtxRelease", 11000, &(void*&)pfn_cuDevicePrimaryCtxRelease);
     get_driver_entry_point("cuDeviceCanAccessPeer", 4000, &(void*&)pfn_cuDeviceCanAccessPeer);
warp/native/intersect.h
CHANGED
@@ -316,7 +316,7 @@ CUDA_CALLABLE inline bool intersect_ray_tri_woop(const vec3& p, const vec3& dir,
 
     if (dir[kz] < 0.0f)
     {
-
+        int tmp = kx;
         kx = ky;
         ky = tmp;
     }
@@ -410,7 +410,7 @@ CUDA_CALLABLE inline void adj_intersect_ray_tri_woop(
 
     if (dir[kz] < 0.0f)
     {
-
+        int tmp = kx;
         kx = ky;
         ky = tmp;
     }
warp/native/mat.h
CHANGED
@@ -1533,13 +1533,13 @@ inline CUDA_CALLABLE void adj_div(const mat_t<Rows,Cols,Type>& a, Type s, mat_t<
 template<unsigned Rows, unsigned Cols, typename Type>
 inline CUDA_CALLABLE void adj_div(Type s, const mat_t<Rows,Cols,Type>& a, Type& adj_s, mat_t<Rows,Cols,Type>& adj_a, const mat_t<Rows,Cols,Type>& adj_ret)
 {
-    adj_s -= tensordot(a , adj_ret)/ (s * s); // - a / s^2
-
     for (unsigned i=0; i < Rows; ++i)
     {
         for (unsigned j=0; j < Cols; ++j)
         {
-
+            Type inv = Type(1) / a.data[i][j];
+            adj_a.data[i][j] -= s * adj_ret.data[i][j] * inv * inv;
+            adj_s += adj_ret.data[i][j] * inv;
         }
     }
 }
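For reference, this overload is the adjoint of dividing a scalar by a matrix element-wise, r_{ij} = s / a_{ij}. Differentiating gives

    \frac{\partial r_{ij}}{\partial s} = \frac{1}{a_{ij}}, \qquad
    \frac{\partial r_{ij}}{\partial a_{ij}} = -\frac{s}{a_{ij}^{2}},

which is exactly what the loop now accumulates into adj_s and adj_a.data[i][j] using inv = 1 / a_{ij}. The removed line accumulated -tensordot(a, adj_ret) / s^2, which matches the s-sensitivity of matrix-by-scalar division a / s rather than of s / a.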
warp/native/mesh.h
CHANGED
@@ -1357,7 +1357,7 @@ CUDA_CALLABLE inline void adj_mesh_query_point_sign_normal(uint64_t id, const ve
     uint64_t adj_id, vec3& adj_point, float& adj_max_dist, float& adj_epsilon, mesh_query_point_t& adj_ret)
 {
     adj_mesh_query_point_sign_normal(id, point, max_dist, ret.sign, ret.face, ret.u, ret.v, epsilon,
-        adj_id, adj_point, adj_max_dist, adj_ret.sign, adj_ret.face, adj_ret.u, adj_ret.v,
+        adj_id, adj_point, adj_max_dist, adj_ret.sign, adj_ret.face, adj_ret.u, adj_ret.v, adj_epsilon, adj_ret.result);
 }
 
 CUDA_CALLABLE inline void adj_mesh_query_point_sign_winding_number(uint64_t id, const vec3& point, float max_dist, float accuracy, float winding_number_threshold, const mesh_query_point_t& ret,
|