warp-lang 1.8.0-py3-none-win_amd64.whl → 1.9.0-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (153)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +48 -63
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +327 -209
  9. warp/config.py +1 -1
  10. warp/context.py +1363 -800
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_callable.py +34 -4
  18. warp/examples/interop/example_jax_kernel.py +27 -1
  19. warp/fabric.py +1 -1
  20. warp/fem/cache.py +27 -19
  21. warp/fem/domain.py +2 -2
  22. warp/fem/field/nodal_field.py +2 -2
  23. warp/fem/field/virtual.py +266 -166
  24. warp/fem/geometry/geometry.py +5 -5
  25. warp/fem/integrate.py +200 -91
  26. warp/fem/space/restriction.py +4 -0
  27. warp/fem/space/shape/tet_shape_function.py +3 -10
  28. warp/jax_experimental/custom_call.py +1 -1
  29. warp/jax_experimental/ffi.py +203 -54
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +103 -8
  32. warp/native/builtin.h +90 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +13 -3
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +42 -11
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +4 -4
  48. warp/native/mat.h +1913 -119
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +5 -3
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +337 -16
  59. warp/native/rand.h +7 -7
  60. warp/native/range.h +7 -1
  61. warp/native/reduce.cpp +10 -10
  62. warp/native/reduce.cu +13 -14
  63. warp/native/runlength_encode.cpp +2 -2
  64. warp/native/runlength_encode.cu +5 -5
  65. warp/native/scan.cpp +3 -3
  66. warp/native/scan.cu +4 -4
  67. warp/native/sort.cpp +10 -10
  68. warp/native/sort.cu +22 -22
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +14 -14
  71. warp/native/spatial.h +366 -17
  72. warp/native/svd.h +23 -8
  73. warp/native/temp_buffer.h +2 -2
  74. warp/native/tile.h +303 -70
  75. warp/native/tile_radix_sort.h +5 -1
  76. warp/native/tile_reduce.h +16 -25
  77. warp/native/tuple.h +2 -2
  78. warp/native/vec.h +385 -18
  79. warp/native/volume.cpp +54 -54
  80. warp/native/volume.cu +1 -1
  81. warp/native/volume.h +2 -1
  82. warp/native/volume_builder.cu +30 -37
  83. warp/native/warp.cpp +150 -149
  84. warp/native/warp.cu +337 -193
  85. warp/native/warp.h +227 -226
  86. warp/optim/linear.py +736 -271
  87. warp/render/imgui_manager.py +289 -0
  88. warp/render/render_opengl.py +137 -57
  89. warp/render/render_usd.py +0 -1
  90. warp/sim/collide.py +1 -2
  91. warp/sim/graph_coloring.py +2 -2
  92. warp/sim/integrator_vbd.py +10 -2
  93. warp/sparse.py +559 -176
  94. warp/tape.py +2 -0
  95. warp/tests/aux_test_module_aot.py +7 -0
  96. warp/tests/cuda/test_async.py +3 -3
  97. warp/tests/cuda/test_conditional_captures.py +101 -0
  98. warp/tests/geometry/test_marching_cubes.py +233 -12
  99. warp/tests/sim/test_cloth.py +89 -6
  100. warp/tests/sim/test_coloring.py +82 -7
  101. warp/tests/test_array.py +56 -5
  102. warp/tests/test_assert.py +53 -0
  103. warp/tests/test_atomic_cas.py +127 -114
  104. warp/tests/test_codegen.py +3 -2
  105. warp/tests/test_context.py +8 -15
  106. warp/tests/test_enum.py +136 -0
  107. warp/tests/test_examples.py +2 -2
  108. warp/tests/test_fem.py +45 -2
  109. warp/tests/test_fixedarray.py +229 -0
  110. warp/tests/test_func.py +18 -15
  111. warp/tests/test_future_annotations.py +7 -5
  112. warp/tests/test_linear_solvers.py +30 -0
  113. warp/tests/test_map.py +1 -1
  114. warp/tests/test_mat.py +1540 -378
  115. warp/tests/test_mat_assign_copy.py +178 -0
  116. warp/tests/test_mat_constructors.py +574 -0
  117. warp/tests/test_module_aot.py +287 -0
  118. warp/tests/test_print.py +69 -0
  119. warp/tests/test_quat.py +162 -34
  120. warp/tests/test_quat_assign_copy.py +145 -0
  121. warp/tests/test_reload.py +2 -1
  122. warp/tests/test_sparse.py +103 -0
  123. warp/tests/test_spatial.py +140 -34
  124. warp/tests/test_spatial_assign_copy.py +160 -0
  125. warp/tests/test_static.py +48 -0
  126. warp/tests/test_struct.py +43 -3
  127. warp/tests/test_tape.py +38 -0
  128. warp/tests/test_types.py +0 -20
  129. warp/tests/test_vec.py +216 -441
  130. warp/tests/test_vec_assign_copy.py +143 -0
  131. warp/tests/test_vec_constructors.py +325 -0
  132. warp/tests/tile/test_tile.py +206 -152
  133. warp/tests/tile/test_tile_cholesky.py +605 -0
  134. warp/tests/tile/test_tile_load.py +169 -0
  135. warp/tests/tile/test_tile_mathdx.py +2 -558
  136. warp/tests/tile/test_tile_matmul.py +179 -0
  137. warp/tests/tile/test_tile_mlp.py +1 -1
  138. warp/tests/tile/test_tile_reduce.py +100 -11
  139. warp/tests/tile/test_tile_shared_memory.py +16 -16
  140. warp/tests/tile/test_tile_sort.py +59 -55
  141. warp/tests/unittest_suites.py +16 -0
  142. warp/tests/walkthrough_debug.py +1 -1
  143. warp/thirdparty/unittest_parallel.py +108 -9
  144. warp/types.py +554 -264
  145. warp/utils.py +68 -86
  146. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  147. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/RECORD +150 -138
  148. warp/native/marching.cpp +0 -19
  149. warp/native/marching.cu +0 -514
  150. warp/native/marching.h +0 -19
  151. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  152. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  153. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/jax_experimental/ffi.py

@@ -16,7 +16,8 @@
 import ctypes
 import threading
 import traceback
-from typing import Callable
+from enum import IntEnum
+from typing import Callable, Optional
 
 import jax
 
@@ -28,10 +29,17 @@ from warp.types import array_t, launch_bounds_t, strides_from_shape, type_to_war
 from .xla_ffi import *
 
 
+class GraphMode(IntEnum):
+    NONE = 0  # don't capture a graph
+    JAX = 1  # let JAX capture a graph
+    WARP = 2  # let Warp capture a graph
+
+
 class FfiArg:
-    def __init__(self, name, type):
+    def __init__(self, name, type, in_out=False):
         self.name = name
         self.type = type
+        self.in_out = in_out
         self.is_array = isinstance(type, wp.array)
 
         if self.is_array:
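The new GraphMode enum is consumed by jax_callable further down in this diff. Below is a minimal, standalone sketch of how a caller might choose between the three modes; the decision rules paraphrase the jax_callable docstring later in this file, and pick_mode is illustrative, not part of Warp:

    from enum import IntEnum

    class GraphMode(IntEnum):  # mirrors the enum added above
        NONE = 0  # don't capture a graph
        JAX = 1   # let JAX capture a graph
        WARP = 2  # let Warp capture a graph

    def pick_mode(uses_conditional_graph_nodes: bool, synchronizes_host: bool) -> GraphMode:
        # Capture is illegal when the callable synchronizes with the host.
        if synchronizes_host:
            return GraphMode.NONE
        # Warp must own the capture when the callable can't be embedded as a JAX subgraph.
        if uses_conditional_graph_nodes:
            return GraphMode.WARP
        return GraphMode.JAX

    assert pick_mode(False, False) is GraphMode.JAX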
@@ -65,7 +73,7 @@ class FfiLaunchDesc:
 
 
 class FfiKernel:
-    def __init__(self, kernel, num_outputs, vmap_method, launch_dims, output_dims):
+    def __init__(self, kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames):
         self.kernel = kernel
         self.name = generate_unique_name(kernel.func)
         self.num_outputs = num_outputs
@@ -76,17 +84,28 @@ class FfiKernel:
         self.launch_id = 0
         self.launch_descriptors = {}
 
+        in_out_argnames_list = in_out_argnames or []
+        in_out_argnames = set(in_out_argnames_list)
+        if len(in_out_argnames_list) != len(in_out_argnames):
+            raise AssertionError("in_out_argnames must not contain duplicate names")
+
         self.num_kernel_args = len(kernel.adj.args)
-        self.num_inputs = self.num_kernel_args - num_outputs
+        self.num_in_out = len(in_out_argnames)
+        self.num_inputs = self.num_kernel_args - num_outputs + self.num_in_out
         if self.num_outputs < 1:
             raise ValueError("At least one output is required")
         if self.num_outputs > self.num_kernel_args:
             raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
+        if self.num_outputs < self.num_in_out:
+            raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")
 
         # process input args
         self.input_args = []
         for i in range(self.num_inputs):
-            arg = FfiArg(kernel.adj.args[i].label, kernel.adj.args[i].type)
+            arg_name = kernel.adj.args[i].label
+            arg = FfiArg(arg_name, kernel.adj.args[i].type, arg_name in in_out_argnames)
+            if arg_name in in_out_argnames:
+                in_out_argnames.remove(arg_name)
             if arg.is_array:
                 # keep track of the first input array argument
                 if self.first_array_arg is None:
@@ -96,11 +115,30 @@ class FfiKernel:
         # process output args
         self.output_args = []
         for i in range(self.num_inputs, self.num_kernel_args):
-            arg = FfiArg(kernel.adj.args[i].label, kernel.adj.args[i].type)
+            arg_name = kernel.adj.args[i].label
+            if arg_name in in_out_argnames:
+                raise AssertionError(
+                    f"Expected an output-only argument for argument {arg_name}."
+                    " in_out arguments should be placed before output-only arguments."
+                )
+            arg = FfiArg(arg_name, kernel.adj.args[i].type, False)
             if not arg.is_array:
                 raise TypeError("All output arguments must be arrays")
             self.output_args.append(arg)
 
+        if in_out_argnames:
+            raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
+
+        # Build input output aliases.
+        out_id = 0
+        input_output_aliases = {}
+        for in_id, arg in enumerate(self.input_args):
+            if not arg.in_out:
+                continue
+            input_output_aliases[in_id] = out_id
+            out_id += 1
+        self.input_output_aliases = input_output_aliases
+
         # register the callback
         FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
         self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
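With in-out support, an in-out argument is counted among the inputs (num_inputs = num_kernel_args - num_outputs + num_in_out) and is also emitted as one of the leading outputs; the input_output_aliases dictionary built above tells JAX which input buffer each of those leading outputs aliases. A standalone sketch of that mapping rule (names are illustrative):

    def build_input_output_aliases(input_is_in_out):
        """Map input index -> output index for every in-out input, in order."""
        aliases = {}
        out_id = 0
        for in_id, in_out in enumerate(input_is_in_out):
            if in_out:
                aliases[in_id] = out_id
                out_id += 1
        return aliases

    # kernel(a, x, state, out) with in_out_argnames=["state"] and num_outputs=2:
    # inputs are (a, x, state); "state" aliases output 0, "out" is output 1.
    assert build_input_output_aliases([False, False, True]) == {2: 0}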
@@ -121,6 +159,9 @@ class FfiKernel:
         if vmap_method is None:
             vmap_method = self.vmap_method
 
+        # output types
+        out_types = []
+
         # process inputs
         static_inputs = {}
         for i in range(num_inputs):
@@ -150,6 +191,10 @@ class FfiKernel:
                 # stash the value to be retrieved by callback
                 static_inputs[input_arg.name] = input_arg.type(input_value)
 
+            # append in-out arg to output types
+            if input_arg.in_out:
+                out_types.append(get_jax_output_type(input_arg, input_value.shape))
+
         # launch dimensions
         if launch_dims is None:
             # use the shape of the first input array
@@ -162,8 +207,7 @@
         else:
             launch_dims = tuple(launch_dims)
 
-        # output types
-        out_types = []
+        # output shapes
         if isinstance(output_dims, dict):
             # assume a dictionary of shapes keyed on argument name
             for output_arg in self.output_args:
@@ -185,6 +229,7 @@ class FfiKernel:
             self.name,
             out_types,
             vmap_method=vmap_method,
+            input_output_aliases=self.input_output_aliases,
         )
 
         # ensure the kernel module is loaded before the callback, otherwise graph capture may fail
@@ -238,9 +283,8 @@ class FfiKernel:
 
             arg_refs = []
 
-            # inputs
-            for i in range(num_inputs):
-                input_arg = self.input_args[i]
+            # input and in-out args
+            for i, input_arg in enumerate(self.input_args):
                 if input_arg.is_array:
                     buffer = inputs[i].contents
                     shape = buffer.dims[: input_arg.type.ndim]
@@ -255,10 +299,9 @@
                 kernel_params[i + 1] = ctypes.addressof(arg)
                 arg_refs.append(arg)  # keep a reference
 
-            # outputs
-            for i in range(num_outputs):
-                output_arg = self.output_args[i]
-                buffer = outputs[i].contents
+            # pure output args (skip in-out FFI buffers)
+            for i, output_arg in enumerate(self.output_args):
+                buffer = outputs[i + self.num_in_out].contents
                 shape = buffer.dims[: output_arg.type.ndim]
                 strides = strides_from_shape(shape, output_arg.type.dtype)
                 arg = array_t(buffer.data, 0, output_arg.type.ndim, shape, strides)
@@ -274,7 +317,7 @@
             assert hooks.forward, "Failed to find kernel entry point"
 
             # launch the kernel
-            wp.context.runtime.core.cuda_launch_kernel(
+            wp.context.runtime.core.wp_cuda_launch_kernel(
                 device.context,
                 hooks.forward,
                 launch_bounds.size,
@@ -295,29 +338,38 @@
 class FfiCallDesc:
     def __init__(self, static_inputs):
         self.static_inputs = static_inputs
+        self.captures = {}
 
 
 class FfiCallable:
-    def __init__(self, func, num_outputs, graph_compatible, vmap_method, output_dims):
+    def __init__(self, func, num_outputs, graph_mode, vmap_method, output_dims, in_out_argnames):
         self.func = func
         self.name = generate_unique_name(func)
         self.num_outputs = num_outputs
         self.vmap_method = vmap_method
-        self.graph_compatible = graph_compatible
+        self.graph_mode = graph_mode
         self.output_dims = output_dims
         self.first_array_arg = None
         self.call_id = 0
         self.call_descriptors = {}
 
+        in_out_argnames_list = in_out_argnames or []
+        in_out_argnames = set(in_out_argnames_list)
+        if len(in_out_argnames_list) != len(in_out_argnames):
+            raise AssertionError("in_out_argnames must not contain duplicate names")
+
         # get arguments and annotations
         argspec = get_full_arg_spec(func)
 
         num_args = len(argspec.args)
-        self.num_inputs = num_args - num_outputs
+        self.num_in_out = len(in_out_argnames)
+        self.num_inputs = num_args - num_outputs + self.num_in_out
         if self.num_outputs < 1:
             raise ValueError("At least one output is required")
         if self.num_outputs > num_args:
             raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
+        if self.num_outputs < self.num_in_out:
+            raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")
 
         if len(argspec.annotations) < num_args:
             raise RuntimeError(f"Incomplete argument annotations on function {self.name}")
@@ -329,16 +381,45 @@ class FfiCallable:
             if arg_name == "return":
                 if arg_type is not None:
                     raise TypeError("Function must not return a value")
+                continue
             else:
-                arg = FfiArg(arg_name, arg_type)
+                arg = FfiArg(arg_name, arg_type, arg_name in in_out_argnames)
+                if arg_name in in_out_argnames:
+                    in_out_argnames.remove(arg_name)
                 if arg.is_array:
                     if arg_idx < self.num_inputs and self.first_array_arg is None:
                         self.first_array_arg = arg_idx
                 self.args.append(arg)
+
+                if arg.in_out and arg_idx >= self.num_inputs:
+                    raise AssertionError(
+                        f"Expected an output-only argument for argument {arg_name}."
+                        " in_out arguments should be placed before output-only arguments."
+                    )
+
                 arg_idx += 1
 
-        self.input_args = self.args[: self.num_inputs]
-        self.output_args = self.args[self.num_inputs :]
+        if in_out_argnames:
+            raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
+
+        self.input_args = self.args[: self.num_inputs]  # includes in-out args
+        self.output_args = self.args[self.num_inputs :]  # pure output args
+
+        # Buffer indices for array arguments in callback.
+        # In-out buffers are the same pointers in the XLA call frame,
+        # so we only include them for inputs and skip them for outputs.
+        self.array_input_indices = [i for i, arg in enumerate(self.input_args) if arg.is_array]
+        self.array_output_indices = list(range(self.num_in_out, self.num_outputs))
+
+        # Build input output aliases.
+        out_id = 0
+        input_output_aliases = {}
+        for in_id, arg in enumerate(self.input_args):
+            if not arg.in_out:
+                continue
+            input_output_aliases[in_id] = out_id
+            out_id += 1
+        self.input_output_aliases = input_output_aliases
 
         # register the callback
         FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
@@ -350,7 +431,9 @@ class FfiCallable:
     def __call__(self, *args, output_dims=None, vmap_method=None):
         num_inputs = len(args)
         if num_inputs != self.num_inputs:
-            raise ValueError(f"Expected {self.num_inputs} inputs, but got {num_inputs}")
+            input_names = ", ".join(arg.name for arg in self.input_args)
+            s = "" if self.num_inputs == 1 else "s"
+            raise ValueError(f"Expected {self.num_inputs} input{s} ({input_names}), but got {num_inputs}")
 
         # default argument fallback
         if vmap_method is None:
@@ -358,6 +441,9 @@
         if output_dims is None:
             output_dims = self.output_dims
 
+        # output types
+        out_types = []
+
         # process inputs
         static_inputs = {}
         for i in range(num_inputs):
@@ -387,12 +473,11 @@
                 # stash the value to be retrieved by callback
                 static_inputs[input_arg.name] = input_arg.type(input_value)
 
-        if output_dims is None and self.first_array_arg is not None:
-            # use the shape of the first input array
-            output_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
+            # append in-out arg to output types
+            if input_arg.in_out:
+                out_types.append(get_jax_output_type(input_arg, input_value.shape))
 
-        # output types
-        out_types = []
+        # output shapes
         if isinstance(output_dims, dict):
             # assume a dictionary of shapes keyed on argument name
             for output_arg in self.output_args:
@@ -402,7 +487,9 @@
                 out_types.append(get_jax_output_type(output_arg, dims))
         else:
             if output_dims is None:
-                raise ValueError("Unable to determine output dimensions")
+                if self.first_array_arg is None:
+                    raise ValueError("Unable to determine output dimensions")
+                output_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
             elif isinstance(output_dims, int):
                 output_dims = (output_dims,)
             # assume same dimensions for all outputs
@@ -413,6 +500,7 @@
             self.name,
             out_types,
             vmap_method=vmap_method,
+            input_output_aliases=self.input_output_aliases,
             # has_side_effect=True,  # force this function to execute even if outputs aren't used
         )
 
@@ -430,11 +518,10 @@
 
     def ffi_callback(self, call_frame):
         try:
-            # TODO Try-catch around the body and return XLA_FFI_Error on error.
-            extension = call_frame.contents.extension_start
             # On the first call, XLA runtime will query the API version and traits
             # metadata using the |extension| field. Let us respond to that query
             # if the metadata extension is present.
+            extension = call_frame.contents.extension_start
             if extension:
                 # Try to set the version metadata.
                 if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
@@ -442,15 +529,19 @@
                     metadata_ext.contents.metadata.contents.api_version.major_version = 0
                    metadata_ext.contents.metadata.contents.api_version.minor_version = 1
                     # Turn on CUDA graphs for this handler.
-                    if self.graph_compatible:
+                    if self.graph_mode is GraphMode.JAX:
                         metadata_ext.contents.metadata.contents.traits = (
                             XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
                         )
                 return None
 
             # retrieve call info
-            attrs = decode_attrs(call_frame.contents.attrs)
-            call_id = int(attrs["call_id"])
+            # NOTE: this assumes that there's only one attribute - call_id (int64).
+            # A more general but slower approach is this:
+            # attrs = decode_attrs(call_frame.contents.attrs)
+            # call_id = int(attrs["call_id"])
+            attr = ctypes.cast(call_frame.contents.attrs.attrs[0], ctypes.POINTER(XLA_FFI_Scalar)).contents
+            call_id = ctypes.cast(attr.value, ctypes.POINTER(ctypes.c_int64)).contents.value
             call_desc = self.call_descriptors[call_id]
 
             num_inputs = call_frame.contents.args.size
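The callback now reads the single call_id attribute directly through ctypes instead of decoding the full attribute dictionary with decode_attrs. A minimal, standalone illustration of that style of pointer cast using only the ctypes standard library (the real code casts an XLA_FFI_Scalar value pointer taken from the call frame):

    import ctypes

    # Pretend this c_int64 is the attribute payload sitting behind an opaque pointer.
    payload = ctypes.c_int64(7)
    opaque = ctypes.cast(ctypes.pointer(payload), ctypes.c_void_p)

    # Reinterpret the opaque pointer as int64* and dereference it.
    call_id = ctypes.cast(opaque, ctypes.POINTER(ctypes.c_int64)).contents.value
    assert call_id == 7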
@@ -462,16 +553,42 @@
             assert num_inputs == self.num_inputs
             assert num_outputs == self.num_outputs
 
-            device = wp.device_from_jax(get_jax_device())
             cuda_stream = get_stream_from_callframe(call_frame.contents)
+
+            if self.graph_mode == GraphMode.WARP:
+                # check if we already captured an identical call
+                ip = [inputs[i].contents.data for i in self.array_input_indices]
+                op = [outputs[i].contents.data for i in self.array_output_indices]
+                buffer_hash = hash((*ip, *op))
+                capture = call_desc.captures.get(buffer_hash)
+
+                # launch existing graph
+                if capture is not None:
+                    # NOTE: We use the native graph API to avoid overhead with obtaining Stream and Device objects in Python.
+                    # This code should match wp.capture_launch().
+                    graph = capture.graph
+                    if graph.graph_exec is None:
+                        g = ctypes.c_void_p()
+                        if not wp.context.runtime.core.wp_cuda_graph_create_exec(
+                            graph.device.context, cuda_stream, graph.graph, ctypes.byref(g)
+                        ):
+                            raise RuntimeError(f"Graph creation error: {wp.context.runtime.get_error_string()}")
+                        graph.graph_exec = g
+
+                    if not wp.context.runtime.core.wp_cuda_graph_launch(graph.graph_exec, cuda_stream):
+                        raise RuntimeError(f"Graph launch error: {wp.context.runtime.get_error_string()}")
+
+                    # early out
+                    return
+
+            device = wp.device_from_jax(get_jax_device())
             stream = wp.Stream(device, cuda_stream=cuda_stream)
 
             # reconstruct the argument list
             arg_list = []
 
-            # inputs
-            for i in range(num_inputs):
-                arg = self.input_args[i]
+            # input and in-out args
+            for i, arg in enumerate(self.input_args):
                 if arg.is_array:
                     buffer = inputs[i].contents
                     shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
@@ -482,10 +599,9 @@
                     value = call_desc.static_inputs[arg.name]
                     arg_list.append(value)
 
-            # outputs
-            for i in range(num_outputs):
-                arg = self.output_args[i]
-                buffer = outputs[i].contents
+            # pure output args (skip in-out FFI buffers)
+            for i, arg in enumerate(self.output_args):
+                buffer = outputs[i + self.num_in_out].contents
                 shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
                 arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
                 arg_list.append(arr)
@@ -493,11 +609,20 @@
             # call the Python function with reconstructed arguments
             with wp.ScopedStream(stream, sync_enter=False):
                 if stream.is_capturing:
-                    with wp.ScopedCapture(stream=stream, external=True) as capture:
+                    # capturing with JAX
+                    with wp.ScopedCapture(external=True) as capture:
                         self.func(*arg_list)
                     # keep a reference to the capture object to prevent required modules getting unloaded
                     call_desc.capture = capture
+                elif self.graph_mode == GraphMode.WARP:
+                    # capturing with WARP
+                    with wp.ScopedCapture() as capture:
+                        self.func(*arg_list)
+                    wp.capture_launch(capture.graph)
+                    # keep a reference to the capture object and reuse it with same buffers
+                    call_desc.captures[buffer_hash] = capture
                 else:
+                    # not capturing
                     self.func(*arg_list)
 
         except Exception as e:
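Under GraphMode.WARP, the first call with a given set of device buffers is captured with wp.ScopedCapture and launched via wp.capture_launch; later calls that see the same buffer pointers replay the cached graph through the native graph API (a captured CUDA graph bakes in the buffer addresses, which is why the cache is keyed on them). A simplified sketch of the same capture-and-replay pattern using only Warp's public API; cache_key stands in for the buffer-pointer hash used above:

    import warp as wp

    _graph_cache = {}

    def run_captured(fn, cache_key):
        # Capture once per unique cache_key, then replay the recorded graph.
        graph = _graph_cache.get(cache_key)
        if graph is None:
            with wp.ScopedCapture() as capture:
                fn()
            graph = capture.graph
            _graph_cache[cache_key] = graph
        wp.capture_launch(graph)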
@@ -515,7 +640,9 @@ _FFI_KERNEL_REGISTRY: dict[str, FfiKernel] = {}
 _FFI_REGISTRY_LOCK = threading.Lock()
 
 
-def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=None, output_dims=None):
+def jax_kernel(
+    kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=None, output_dims=None, in_out_argnames=None
+):
     """Create a JAX callback from a Warp kernel.
 
     NOTE: This is an experimental feature under development.
@@ -523,6 +650,7 @@ def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=N
     Args:
         kernel: The Warp kernel to launch.
         num_outputs: Optional. Specify the number of output arguments if greater than 1.
+            This must include the number of ``in_out_arguments``.
         vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
             This argument can also be specified for individual calls.
         launch_dims: Optional. Specify the default kernel launch dimensions. If None, launch
@@ -531,12 +659,13 @@
         output_dims: Optional. Specify the default dimensions of output arrays. If None, output
             dimensions are inferred from the launch dimensions.
             This argument can also be specified for individual calls.
+        in_out_argnames: Optional. Names of input-output arguments.
 
     Limitations:
         - All kernel arguments must be contiguous arrays or scalars.
         - Scalars must be static arguments in JAX.
-        - Input arguments are followed by output arguments in the Warp kernel definition.
-        - There must be at least one output argument.
+        - Input and input-output arguments must precede the output arguments in the ``kernel`` definition.
+        - There must be at least one output or input-output argument.
         - Only the CUDA backend is supported.
     """
     key = (
@@ -549,7 +678,7 @@ def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=N
 
     with _FFI_REGISTRY_LOCK:
         if key not in _FFI_KERNEL_REGISTRY:
-            new_kernel = FfiKernel(kernel, num_outputs, vmap_method, launch_dims, output_dims)
+            new_kernel = FfiKernel(kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames)
             _FFI_KERNEL_REGISTRY[key] = new_kernel
 
     return _FFI_KERNEL_REGISTRY[key]
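A hedged usage sketch of the new in_out_argnames parameter, based on the signature and docstring above (kernel and argument names are illustrative, a CUDA device is required per the Limitations, and the convention that the in-out result is returned before the pure outputs is inferred from the diff):

    import jax.numpy as jnp
    import warp as wp
    from warp.jax_experimental.ffi import jax_kernel

    @wp.kernel
    def accumulate(x: wp.array(dtype=float), state: wp.array(dtype=float), out: wp.array(dtype=float)):
        tid = wp.tid()
        state[tid] = state[tid] + x[tid]  # in-out: read and updated in place
        out[tid] = 2.0 * state[tid]       # pure output

    # "state" is in-out, so num_outputs = 1 in-out + 1 output-only = 2
    jax_accumulate = jax_kernel(accumulate, num_outputs=2, in_out_argnames=["state"])

    x = jnp.ones(16, dtype=jnp.float32)
    state = jnp.zeros(16, dtype=jnp.float32)
    state, out = jax_accumulate(x, state)  # "state" aliases the first output buffer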
@@ -558,9 +687,11 @@ def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=N
 def jax_callable(
     func: Callable,
     num_outputs: int = 1,
-    graph_compatible: bool = True,
-    vmap_method: str = "broadcast_all",
+    graph_compatible: Optional[bool] = None,  # deprecated
+    graph_mode: GraphMode = GraphMode.JAX,
+    vmap_method: Optional[str] = "broadcast_all",
     output_dims=None,
+    in_out_argnames=None,
 ):
     """Create a JAX callback from an annotated Python function.
 
@@ -571,31 +702,50 @@ def jax_callable(
     Args:
         func: The Python function to call.
         num_outputs: Optional. Specify the number of output arguments if greater than 1.
+            This must include the number of ``in_out_arguments``.
         graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
+            This argument is deprecated, use ``graph_mode`` instead.
+        graph_mode: Optional. CUDA graph capture mode.
+            ``GraphMode.JAX`` (default): Let JAX capture the graph, which may be used as a subgraph in an enclosing capture.
+            ``GraphMode.WARP``: Let Warp capture the graph. Use this mode when the callable cannot be used as a subraph,
+            such as when the callable uses conditional graph nodes.
+            ``GraphMode.NONE``: Disable graph capture. Use when the callable performs operations that are not legal in a graph,
+            such as host synchronization.
         vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
             This argument can also be specified for individual calls.
         output_dims: Optional. Specify the default dimensions of output arrays.
             If ``None``, output dimensions are inferred from the launch dimensions.
             This argument can also be specified for individual calls.
+        in_out_argnames: Optional. Names of input-output arguments.
 
     Limitations:
         - All kernel arguments must be contiguous arrays or scalars.
         - Scalars must be static arguments in JAX.
-        - Input arguments are followed by output arguments in the Warp kernel definition.
-        - There must be at least one output argument.
+        - Input and input-output arguments must precede the output arguments in the ``func`` definition.
+        - There must be at least one output or input-output argument.
        - Only the CUDA backend is supported.
     """
+
+    if graph_compatible is not None:
+        wp.utils.warn(
+            "The `graph_compatible` argument is deprecated, use `graph_mode` instead.",
+            DeprecationWarning,
+            stacklevel=3,
+        )
+        if graph_compatible is False:
+            graph_mode = GraphMode.NONE
+
     key = (
         func,
         num_outputs,
-        graph_compatible,
+        graph_mode,
         vmap_method,
         tuple(sorted(output_dims.items())) if output_dims else output_dims,
     )
 
     with _FFI_REGISTRY_LOCK:
         if key not in _FFI_CALLABLE_REGISTRY:
-            new_callable = FfiCallable(func, num_outputs, graph_compatible, vmap_method, output_dims)
+            new_callable = FfiCallable(func, num_outputs, graph_mode, vmap_method, output_dims, in_out_argnames)
             _FFI_CALLABLE_REGISTRY[key] = new_callable
 
     return _FFI_CALLABLE_REGISTRY[key]
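And a hedged sketch of jax_callable with the new graph_mode parameter (the function and kernel are illustrative, not from the diff); note that passing the deprecated graph_compatible=False now maps to GraphMode.NONE, as shown in the hunk above:

    import warp as wp
    from warp.jax_experimental.ffi import GraphMode, jax_callable

    @wp.kernel
    def scale_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
        tid = wp.tid()
        y[tid] = 2.0 * x[tid]

    def scale(x: wp.array(dtype=float), y: wp.array(dtype=float)):
        wp.launch(scale_kernel, dim=x.shape, inputs=[x], outputs=[y])

    # Let Warp own the CUDA graph capture instead of JAX;
    # use GraphMode.NONE if the callable must synchronize with the host.
    jax_scale = jax_callable(scale, num_outputs=1, graph_mode=GraphMode.WARP)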
@@ -626,7 +776,6 @@ def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = Tr
 
     def ffi_callback(call_frame):
         try:
-            # TODO Try-catch around the body and return XLA_FFI_Error on error.
            extension = call_frame.contents.extension_start
             # On the first call, XLA runtime will query the API version and traits
             # metadata using the |extension| field. Let us respond to that query