warp-lang 1.7.2rc1__py3-none-macosx_10_13_universal2.whl → 1.8.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (192)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +130 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +272 -104
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +770 -238
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_callable.py +34 -4
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/interop/example_jax_kernel.py +27 -1
  37. warp/examples/optim/example_drone.py +1 -1
  38. warp/examples/sim/example_cloth.py +1 -1
  39. warp/examples/sim/example_cloth_self_contact.py +48 -54
  40. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  41. warp/examples/tile/example_tile_cholesky.py +2 -1
  42. warp/examples/tile/example_tile_convolution.py +1 -1
  43. warp/examples/tile/example_tile_filtering.py +1 -1
  44. warp/examples/tile/example_tile_matmul.py +1 -1
  45. warp/examples/tile/example_tile_mlp.py +2 -0
  46. warp/fabric.py +7 -7
  47. warp/fem/__init__.py +5 -0
  48. warp/fem/adaptivity.py +1 -1
  49. warp/fem/cache.py +152 -63
  50. warp/fem/dirichlet.py +2 -2
  51. warp/fem/domain.py +136 -6
  52. warp/fem/field/field.py +141 -99
  53. warp/fem/field/nodal_field.py +85 -39
  54. warp/fem/field/virtual.py +99 -52
  55. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  56. warp/fem/geometry/closest_point.py +13 -0
  57. warp/fem/geometry/deformed_geometry.py +102 -40
  58. warp/fem/geometry/element.py +56 -2
  59. warp/fem/geometry/geometry.py +323 -22
  60. warp/fem/geometry/grid_2d.py +157 -62
  61. warp/fem/geometry/grid_3d.py +116 -20
  62. warp/fem/geometry/hexmesh.py +86 -20
  63. warp/fem/geometry/nanogrid.py +166 -86
  64. warp/fem/geometry/partition.py +59 -25
  65. warp/fem/geometry/quadmesh.py +86 -135
  66. warp/fem/geometry/tetmesh.py +47 -119
  67. warp/fem/geometry/trimesh.py +77 -270
  68. warp/fem/integrate.py +181 -95
  69. warp/fem/linalg.py +25 -58
  70. warp/fem/operator.py +124 -27
  71. warp/fem/quadrature/pic_quadrature.py +36 -14
  72. warp/fem/quadrature/quadrature.py +40 -16
  73. warp/fem/space/__init__.py +1 -1
  74. warp/fem/space/basis_function_space.py +66 -46
  75. warp/fem/space/basis_space.py +17 -4
  76. warp/fem/space/dof_mapper.py +1 -1
  77. warp/fem/space/function_space.py +2 -2
  78. warp/fem/space/grid_2d_function_space.py +4 -1
  79. warp/fem/space/hexmesh_function_space.py +4 -2
  80. warp/fem/space/nanogrid_function_space.py +3 -1
  81. warp/fem/space/partition.py +11 -2
  82. warp/fem/space/quadmesh_function_space.py +4 -1
  83. warp/fem/space/restriction.py +5 -2
  84. warp/fem/space/shape/__init__.py +10 -8
  85. warp/fem/space/tetmesh_function_space.py +4 -1
  86. warp/fem/space/topology.py +52 -21
  87. warp/fem/space/trimesh_function_space.py +4 -1
  88. warp/fem/utils.py +53 -8
  89. warp/jax.py +1 -2
  90. warp/jax_experimental/ffi.py +210 -67
  91. warp/jax_experimental/xla_ffi.py +37 -24
  92. warp/math.py +171 -1
  93. warp/native/array.h +103 -4
  94. warp/native/builtin.h +182 -35
  95. warp/native/coloring.cpp +6 -2
  96. warp/native/cuda_util.cpp +1 -1
  97. warp/native/exports.h +118 -63
  98. warp/native/intersect.h +5 -5
  99. warp/native/mat.h +8 -13
  100. warp/native/mathdx.cpp +11 -5
  101. warp/native/matnn.h +1 -123
  102. warp/native/mesh.h +1 -1
  103. warp/native/quat.h +34 -6
  104. warp/native/rand.h +7 -7
  105. warp/native/sparse.cpp +121 -258
  106. warp/native/sparse.cu +181 -274
  107. warp/native/spatial.h +305 -17
  108. warp/native/svd.h +23 -8
  109. warp/native/tile.h +603 -73
  110. warp/native/tile_radix_sort.h +1112 -0
  111. warp/native/tile_reduce.h +239 -13
  112. warp/native/tile_scan.h +240 -0
  113. warp/native/tuple.h +189 -0
  114. warp/native/vec.h +10 -20
  115. warp/native/warp.cpp +36 -4
  116. warp/native/warp.cu +588 -52
  117. warp/native/warp.h +47 -74
  118. warp/optim/linear.py +5 -1
  119. warp/paddle.py +7 -8
  120. warp/py.typed +0 -0
  121. warp/render/render_opengl.py +110 -80
  122. warp/render/render_usd.py +124 -62
  123. warp/sim/__init__.py +9 -0
  124. warp/sim/collide.py +253 -80
  125. warp/sim/graph_coloring.py +8 -1
  126. warp/sim/import_mjcf.py +4 -3
  127. warp/sim/import_usd.py +11 -7
  128. warp/sim/integrator.py +5 -2
  129. warp/sim/integrator_euler.py +1 -1
  130. warp/sim/integrator_featherstone.py +1 -1
  131. warp/sim/integrator_vbd.py +761 -322
  132. warp/sim/integrator_xpbd.py +1 -1
  133. warp/sim/model.py +265 -260
  134. warp/sim/utils.py +10 -7
  135. warp/sparse.py +303 -166
  136. warp/tape.py +54 -51
  137. warp/tests/cuda/test_conditional_captures.py +1046 -0
  138. warp/tests/cuda/test_streams.py +1 -1
  139. warp/tests/geometry/test_volume.py +2 -2
  140. warp/tests/interop/test_dlpack.py +9 -9
  141. warp/tests/interop/test_jax.py +0 -1
  142. warp/tests/run_coverage_serial.py +1 -1
  143. warp/tests/sim/disabled_kinematics.py +2 -2
  144. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  145. warp/tests/sim/test_collision.py +159 -51
  146. warp/tests/sim/test_coloring.py +91 -2
  147. warp/tests/test_array.py +254 -2
  148. warp/tests/test_array_reduce.py +2 -2
  149. warp/tests/test_assert.py +53 -0
  150. warp/tests/test_atomic_cas.py +312 -0
  151. warp/tests/test_codegen.py +142 -19
  152. warp/tests/test_conditional.py +47 -1
  153. warp/tests/test_ctypes.py +0 -20
  154. warp/tests/test_devices.py +8 -0
  155. warp/tests/test_fabricarray.py +4 -2
  156. warp/tests/test_fem.py +58 -25
  157. warp/tests/test_func.py +42 -1
  158. warp/tests/test_grad.py +1 -1
  159. warp/tests/test_lerp.py +1 -3
  160. warp/tests/test_map.py +481 -0
  161. warp/tests/test_mat.py +23 -24
  162. warp/tests/test_quat.py +28 -15
  163. warp/tests/test_rounding.py +10 -38
  164. warp/tests/test_runlength_encode.py +7 -7
  165. warp/tests/test_smoothstep.py +1 -1
  166. warp/tests/test_sparse.py +83 -2
  167. warp/tests/test_spatial.py +507 -1
  168. warp/tests/test_static.py +48 -0
  169. warp/tests/test_struct.py +2 -2
  170. warp/tests/test_tape.py +38 -0
  171. warp/tests/test_tuple.py +265 -0
  172. warp/tests/test_types.py +2 -2
  173. warp/tests/test_utils.py +24 -18
  174. warp/tests/test_vec.py +38 -408
  175. warp/tests/test_vec_constructors.py +325 -0
  176. warp/tests/tile/test_tile.py +438 -131
  177. warp/tests/tile/test_tile_mathdx.py +518 -14
  178. warp/tests/tile/test_tile_matmul.py +179 -0
  179. warp/tests/tile/test_tile_reduce.py +307 -5
  180. warp/tests/tile/test_tile_shared_memory.py +136 -7
  181. warp/tests/tile/test_tile_sort.py +121 -0
  182. warp/tests/unittest_suites.py +14 -6
  183. warp/types.py +462 -308
  184. warp/utils.py +647 -86
  185. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +189 -175
  187. warp/stubs.py +0 -3381
  188. warp/tests/sim/test_xpbd.py +0 -399
  189. warp/tests/test_mlp.py +0 -282
  190. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,8 @@
 import ctypes
 import threading
 import traceback
-from typing import Callable
+from enum import IntEnum
+from typing import Callable, Optional
 
 import jax
 
@@ -28,10 +29,17 @@ from warp.types import array_t, launch_bounds_t, strides_from_shape, type_to_war
 from .xla_ffi import *
 
 
+class GraphMode(IntEnum):
+    NONE = 0  # don't capture a graph
+    JAX = 1  # let JAX capture a graph
+    WARP = 2  # let Warp capture a graph
+
+
 class FfiArg:
-    def __init__(self, name, type):
+    def __init__(self, name, type, in_out=False):
         self.name = name
         self.type = type
+        self.in_out = in_out
         self.is_array = isinstance(type, wp.array)
 
         if self.is_array:
@@ -65,7 +73,7 @@ class FfiLaunchDesc:
 
 
 class FfiKernel:
-    def __init__(self, kernel, num_outputs, vmap_method, launch_dims, output_dims):
+    def __init__(self, kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames):
         self.kernel = kernel
         self.name = generate_unique_name(kernel.func)
         self.num_outputs = num_outputs
@@ -76,17 +84,28 @@
         self.launch_id = 0
         self.launch_descriptors = {}
 
+        in_out_argnames_list = in_out_argnames or []
+        in_out_argnames = set(in_out_argnames_list)
+        if len(in_out_argnames_list) != len(in_out_argnames):
+            raise AssertionError("in_out_argnames must not contain duplicate names")
+
         self.num_kernel_args = len(kernel.adj.args)
-        self.num_inputs = self.num_kernel_args - num_outputs
+        self.num_in_out = len(in_out_argnames)
+        self.num_inputs = self.num_kernel_args - num_outputs + self.num_in_out
         if self.num_outputs < 1:
             raise ValueError("At least one output is required")
         if self.num_outputs > self.num_kernel_args:
             raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
+        if self.num_outputs < self.num_in_out:
+            raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")
 
         # process input args
         self.input_args = []
         for i in range(self.num_inputs):
-            arg = FfiArg(kernel.adj.args[i].label, kernel.adj.args[i].type)
+            arg_name = kernel.adj.args[i].label
+            arg = FfiArg(arg_name, kernel.adj.args[i].type, arg_name in in_out_argnames)
+            if arg_name in in_out_argnames:
+                in_out_argnames.remove(arg_name)
             if arg.is_array:
                 # keep track of the first input array argument
                 if self.first_array_arg is None:
@@ -96,11 +115,30 @@
         # process output args
         self.output_args = []
         for i in range(self.num_inputs, self.num_kernel_args):
-            arg = FfiArg(kernel.adj.args[i].label, kernel.adj.args[i].type)
+            arg_name = kernel.adj.args[i].label
+            if arg_name in in_out_argnames:
+                raise AssertionError(
+                    f"Expected an output-only argument for argument {arg_name}."
+                    " in_out arguments should be placed before output-only arguments."
+                )
+            arg = FfiArg(arg_name, kernel.adj.args[i].type, False)
             if not arg.is_array:
                 raise TypeError("All output arguments must be arrays")
             self.output_args.append(arg)
 
+        if in_out_argnames:
+            raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
+
+        # Build input output aliases.
+        out_id = 0
+        input_output_aliases = {}
+        for in_id, arg in enumerate(self.input_args):
+            if not arg.in_out:
+                continue
+            input_output_aliases[in_id] = out_id
+            out_id += 1
+        self.input_output_aliases = input_output_aliases
+
         # register the callback
         FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
         self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
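
For reference, a small illustration of the aliasing map built above (the argument names here are hypothetical, not from the package):

# For a kernel declared as k(a, x, b, y) with in_out_argnames=["a", "b"], where a and b are
# in-out arrays, x is an input-only array, and y is an output-only array:
#   input_args           -> [a, x, b]     (in-out args are counted among the inputs)
#   FFI outputs          -> [a, b, y]     (in-out args come first, then pure outputs)
#   input_output_aliases -> {0: 0, 2: 1}
# This mapping is passed to the JAX FFI call so each aliased output reuses the buffer of its
# in-out input instead of allocating a new one.
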
@@ -121,6 +159,9 @@
         if vmap_method is None:
             vmap_method = self.vmap_method
 
+        # output types
+        out_types = []
+
         # process inputs
         static_inputs = {}
         for i in range(num_inputs):
@@ -150,6 +191,10 @@
                 # stash the value to be retrieved by callback
                 static_inputs[input_arg.name] = input_arg.type(input_value)
 
+            # append in-out arg to output types
+            if input_arg.in_out:
+                out_types.append(get_jax_output_type(input_arg, input_value.shape))
+
         # launch dimensions
         if launch_dims is None:
             # use the shape of the first input array
@@ -162,8 +207,7 @@
         else:
             launch_dims = tuple(launch_dims)
 
-        # output types
-        out_types = []
+        # output shapes
         if isinstance(output_dims, dict):
             # assume a dictionary of shapes keyed on argument name
             for output_arg in self.output_args:
@@ -185,6 +229,7 @@
             self.name,
             out_types,
             vmap_method=vmap_method,
+            input_output_aliases=self.input_output_aliases,
         )
 
         # ensure the kernel module is loaded before the callback, otherwise graph capture may fail
@@ -238,9 +283,8 @@
 
             arg_refs = []
 
-            # inputs
-            for i in range(num_inputs):
-                input_arg = self.input_args[i]
+            # input and in-out args
+            for i, input_arg in enumerate(self.input_args):
                 if input_arg.is_array:
                     buffer = inputs[i].contents
                     shape = buffer.dims[: input_arg.type.ndim]
@@ -255,10 +299,9 @@
                     kernel_params[i + 1] = ctypes.addressof(arg)
                     arg_refs.append(arg)  # keep a reference
 
-            # outputs
-            for i in range(num_outputs):
-                output_arg = self.output_args[i]
-                buffer = outputs[i].contents
+            # pure output args (skip in-out FFI buffers)
+            for i, output_arg in enumerate(self.output_args):
+                buffer = outputs[i + self.num_in_out].contents
                 shape = buffer.dims[: output_arg.type.ndim]
                 strides = strides_from_shape(shape, output_arg.type.dtype)
                 arg = array_t(buffer.data, 0, output_arg.type.ndim, shape, strides)
@@ -295,30 +338,38 @@
 class FfiCallDesc:
     def __init__(self, static_inputs):
         self.static_inputs = static_inputs
+        self.captures = {}
 
 
 class FfiCallable:
-    def __init__(self, func, num_outputs, graph_compatible, vmap_method, output_dims):
+    def __init__(self, func, num_outputs, graph_mode, vmap_method, output_dims, in_out_argnames):
         self.func = func
         self.name = generate_unique_name(func)
         self.num_outputs = num_outputs
         self.vmap_method = vmap_method
-        self.graph_compatible = graph_compatible
+        self.graph_mode = graph_mode
         self.output_dims = output_dims
         self.first_array_arg = None
-        self.has_static_args = False
         self.call_id = 0
         self.call_descriptors = {}
 
+        in_out_argnames_list = in_out_argnames or []
+        in_out_argnames = set(in_out_argnames_list)
+        if len(in_out_argnames_list) != len(in_out_argnames):
+            raise AssertionError("in_out_argnames must not contain duplicate names")
+
         # get arguments and annotations
         argspec = get_full_arg_spec(func)
 
         num_args = len(argspec.args)
-        self.num_inputs = num_args - num_outputs
+        self.num_in_out = len(in_out_argnames)
+        self.num_inputs = num_args - num_outputs + self.num_in_out
         if self.num_outputs < 1:
             raise ValueError("At least one output is required")
        if self.num_outputs > num_args:
            raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
+        if self.num_outputs < self.num_in_out:
+            raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")
 
        if len(argspec.annotations) < num_args:
            raise RuntimeError(f"Incomplete argument annotations on function {self.name}")
@@ -331,17 +382,43 @@ class FfiCallable:
                 if arg_type is not None:
                     raise TypeError("Function must not return a value")
             else:
-                arg = FfiArg(arg_name, arg_type)
+                arg = FfiArg(arg_name, arg_type, arg_name in in_out_argnames)
+                if arg_name in in_out_argnames:
+                    in_out_argnames.remove(arg_name)
                 if arg.is_array:
                     if arg_idx < self.num_inputs and self.first_array_arg is None:
                         self.first_array_arg = arg_idx
-                else:
-                    self.has_static_args = True
                 self.args.append(arg)
+
+                if arg.in_out and arg_idx >= self.num_inputs:
+                    raise AssertionError(
+                        f"Expected an output-only argument for argument {arg_name}."
+                        " in_out arguments should be placed before output-only arguments."
+                    )
+
                 arg_idx += 1
 
-        self.input_args = self.args[: self.num_inputs]
-        self.output_args = self.args[self.num_inputs :]
+        if in_out_argnames:
+            raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
+
+        self.input_args = self.args[: self.num_inputs]  # includes in-out args
+        self.output_args = self.args[self.num_inputs :]  # pure output args
+
+        # Buffer indices for array arguments in callback.
+        # In-out buffers are the same pointers in the XLA call frame,
+        # so we only include them for inputs and skip them for outputs.
+        self.array_input_indices = [i for i, arg in enumerate(self.input_args) if arg.is_array]
+        self.array_output_indices = list(range(self.num_in_out, self.num_outputs))
+
+        # Build input output aliases.
+        out_id = 0
+        input_output_aliases = {}
+        for in_id, arg in enumerate(self.input_args):
+            if not arg.in_out:
+                continue
+            input_output_aliases[in_id] = out_id
+            out_id += 1
+        self.input_output_aliases = input_output_aliases
 
         # register the callback
         FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
@@ -353,7 +430,9 @@
     def __call__(self, *args, output_dims=None, vmap_method=None):
         num_inputs = len(args)
         if num_inputs != self.num_inputs:
-            raise ValueError(f"Expected {self.num_inputs} inputs, but got {num_inputs}")
+            input_names = ", ".join(arg.name for arg in self.input_args)
+            s = "" if self.num_inputs == 1 else "s"
+            raise ValueError(f"Expected {self.num_inputs} input{s} ({input_names}), but got {num_inputs}")
 
         # default argument fallback
         if vmap_method is None:
@@ -361,6 +440,9 @@
         if output_dims is None:
             output_dims = self.output_dims
 
+        # output types
+        out_types = []
+
         # process inputs
         static_inputs = {}
         for i in range(num_inputs):
@@ -390,12 +472,11 @@
                 # stash the value to be retrieved by callback
                 static_inputs[input_arg.name] = input_arg.type(input_value)
 
-        if output_dims is None and self.first_array_arg is not None:
-            # use the shape of the first input array
-            output_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
+            # append in-out arg to output types
+            if input_arg.in_out:
+                out_types.append(get_jax_output_type(input_arg, input_value.shape))
 
-        # output types
-        out_types = []
+        # output shapes
         if isinstance(output_dims, dict):
             # assume a dictionary of shapes keyed on argument name
             for output_arg in self.output_args:
@@ -405,7 +486,9 @@
                 out_types.append(get_jax_output_type(output_arg, dims))
         else:
             if output_dims is None:
-                raise ValueError("Unable to determine output dimensions")
+                if self.first_array_arg is None:
+                    raise ValueError("Unable to determine output dimensions")
+                output_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
             elif isinstance(output_dims, int):
                 output_dims = (output_dims,)
             # assume same dimensions for all outputs
@@ -416,6 +499,7 @@
             self.name,
             out_types,
             vmap_method=vmap_method,
+            input_output_aliases=self.input_output_aliases,
             # has_side_effect=True,  # force this function to execute even if outputs aren't used
         )
 
@@ -425,22 +509,18 @@
         module = wp.get_module(self.func.__module__)
         module.load(device)
 
-        if self.has_static_args:
-            # save call data to be retrieved by callback
-            call_id = self.call_id
-            self.call_descriptors[call_id] = FfiCallDesc(static_inputs)
-            self.call_id += 1
-            return call(*args, call_id=call_id)
-        else:
-            return call(*args)
+        # save call data to be retrieved by callback
+        call_id = self.call_id
+        self.call_descriptors[call_id] = FfiCallDesc(static_inputs)
+        self.call_id += 1
+        return call(*args, call_id=call_id)
 
     def ffi_callback(self, call_frame):
         try:
-            # TODO Try-catch around the body and return XLA_FFI_Error on error.
-            extension = call_frame.contents.extension_start
             # On the first call, XLA runtime will query the API version and traits
             # metadata using the |extension| field. Let us respond to that query
             # if the metadata extension is present.
+            extension = call_frame.contents.extension_start
             if extension:
                 # Try to set the version metadata.
                 if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
@@ -448,17 +528,20 @@
                     metadata_ext.contents.metadata.contents.api_version.major_version = 0
                     metadata_ext.contents.metadata.contents.api_version.minor_version = 1
                     # Turn on CUDA graphs for this handler.
-                    if self.graph_compatible:
+                    if self.graph_mode is GraphMode.JAX:
                         metadata_ext.contents.metadata.contents.traits = (
                             XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
                         )
                 return None
 
-            if self.has_static_args:
-                # retrieve call info
-                attrs = decode_attrs(call_frame.contents.attrs)
-                call_id = int(attrs["call_id"])
-                call_desc = self.call_descriptors[call_id]
+            # retrieve call info
+            # NOTE: this assumes that there's only one attribute - call_id (int64).
+            # A more general but slower approach is this:
+            #   attrs = decode_attrs(call_frame.contents.attrs)
+            #   call_id = int(attrs["call_id"])
+            attr = ctypes.cast(call_frame.contents.attrs.attrs[0], ctypes.POINTER(XLA_FFI_Scalar)).contents
+            call_id = ctypes.cast(attr.value, ctypes.POINTER(ctypes.c_int64)).contents.value
+            call_desc = self.call_descriptors[call_id]
 
             num_inputs = call_frame.contents.args.size
             inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
@@ -469,16 +552,42 @@
             assert num_inputs == self.num_inputs
             assert num_outputs == self.num_outputs
 
-            device = wp.device_from_jax(get_jax_device())
             cuda_stream = get_stream_from_callframe(call_frame.contents)
+
+            if self.graph_mode == GraphMode.WARP:
+                # check if we already captured an identical call
+                ip = [inputs[i].contents.data for i in self.array_input_indices]
+                op = [outputs[i].contents.data for i in self.array_output_indices]
+                buffer_hash = hash((*ip, *op))
+                capture = call_desc.captures.get(buffer_hash)
+
+                # launch existing graph
+                if capture is not None:
+                    # NOTE: We use the native graph API to avoid overhead with obtaining Stream and Device objects in Python.
+                    # This code should match wp.capture_launch().
+                    graph = capture.graph
+                    if graph.graph_exec is None:
+                        g = ctypes.c_void_p()
+                        if not wp.context.runtime.core.wp_cuda_graph_create_exec(
+                            graph.device.context, cuda_stream, graph.graph, ctypes.byref(g)
+                        ):
+                            raise RuntimeError(f"Graph creation error: {wp.context.runtime.get_error_string()}")
+                        graph.graph_exec = g
+
+                    if not wp.context.runtime.core.wp_cuda_graph_launch(graph.graph_exec, cuda_stream):
+                        raise RuntimeError(f"Graph launch error: {wp.context.runtime.get_error_string()}")
+
+                    # early out
+                    return
+
+            device = wp.device_from_jax(get_jax_device())
             stream = wp.Stream(device, cuda_stream=cuda_stream)
 
             # reconstruct the argument list
             arg_list = []
 
-            # inputs
-            for i in range(num_inputs):
-                arg = self.input_args[i]
+            # input and in-out args
+            for i, arg in enumerate(self.input_args):
                 if arg.is_array:
                     buffer = inputs[i].contents
                     shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
@@ -489,10 +598,9 @@
                     value = call_desc.static_inputs[arg.name]
                     arg_list.append(value)
 
-            # outputs
-            for i in range(num_outputs):
-                arg = self.output_args[i]
-                buffer = outputs[i].contents
+            # pure output args (skip in-out FFI buffers)
+            for i, arg in enumerate(self.output_args):
+                buffer = outputs[i + self.num_in_out].contents
                 shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
                 arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
                 arg_list.append(arr)
@@ -500,9 +608,20 @@
             # call the Python function with reconstructed arguments
             with wp.ScopedStream(stream, sync_enter=False):
                 if stream.is_capturing:
-                    with wp.ScopedCapture(stream=stream, external=True):
+                    # capturing with JAX
+                    with wp.ScopedCapture(external=True) as capture:
+                        self.func(*arg_list)
+                    # keep a reference to the capture object to prevent required modules getting unloaded
+                    call_desc.capture = capture
+                elif self.graph_mode == GraphMode.WARP:
+                    # capturing with WARP
+                    with wp.ScopedCapture() as capture:
                         self.func(*arg_list)
+                    wp.capture_launch(capture.graph)
+                    # keep a reference to the capture object and reuse it with same buffers
+                    call_desc.captures[buffer_hash] = capture
                 else:
+                    # not capturing
                     self.func(*arg_list)
 
         except Exception as e:
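
The GraphMode.WARP path above caches one CUDA graph per unique set of device buffer addresses and replays it when the same buffers are passed again. A condensed sketch of that idea using Warp's public capture API (a hypothetical helper; the package itself replays cached graphs through the native graph API for lower overhead):

import warp as wp

def launch_cached(call_desc, func, arg_list, buffer_ptrs):
    # Key the cache on the device pointers of all array buffers.
    buffer_hash = hash(tuple(buffer_ptrs))
    capture = call_desc.captures.get(buffer_hash)
    if capture is None:
        # First call with these buffers: capture the callable into a CUDA graph.
        with wp.ScopedCapture() as capture:
            func(*arg_list)
        call_desc.captures[buffer_hash] = capture
    # Replay the captured graph.
    wp.capture_launch(capture.graph)
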
@@ -520,7 +639,9 @@ _FFI_KERNEL_REGISTRY: dict[str, FfiKernel] = {}
 _FFI_REGISTRY_LOCK = threading.Lock()
 
 
-def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=None, output_dims=None):
+def jax_kernel(
+    kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=None, output_dims=None, in_out_argnames=None
+):
     """Create a JAX callback from a Warp kernel.
 
     NOTE: This is an experimental feature under development.
@@ -528,6 +649,7 @@ def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=N
     Args:
         kernel: The Warp kernel to launch.
         num_outputs: Optional. Specify the number of output arguments if greater than 1.
+            This must include the number of ``in_out_arguments``.
         vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
             This argument can also be specified for individual calls.
         launch_dims: Optional. Specify the default kernel launch dimensions. If None, launch
@@ -536,12 +658,13 @@
         output_dims: Optional. Specify the default dimensions of output arrays. If None, output
             dimensions are inferred from the launch dimensions.
             This argument can also be specified for individual calls.
+        in_out_argnames: Optional. Names of input-output arguments.
 
     Limitations:
         - All kernel arguments must be contiguous arrays or scalars.
         - Scalars must be static arguments in JAX.
-        - Input arguments are followed by output arguments in the Warp kernel definition.
-        - There must be at least one output argument.
+        - Input and input-output arguments must precede the output arguments in the ``kernel`` definition.
+        - There must be at least one output or input-output argument.
         - Only the CUDA backend is supported.
     """
     key = (
@@ -554,7 +677,7 @@
 
     with _FFI_REGISTRY_LOCK:
         if key not in _FFI_KERNEL_REGISTRY:
-            new_kernel = FfiKernel(kernel, num_outputs, vmap_method, launch_dims, output_dims)
+            new_kernel = FfiKernel(kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames)
            _FFI_KERNEL_REGISTRY[key] = new_kernel
 
    return _FFI_KERNEL_REGISTRY[key]
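
For illustration, a minimal sketch of using jax_kernel with the new in_out_argnames option (the kernel and argument names are invented for this example and assume a CUDA device):

import warp as wp
from warp.jax_experimental.ffi import jax_kernel

@wp.kernel
def scale_accum(x: wp.array(dtype=float), acc: wp.array(dtype=float), out: wp.array(dtype=float)):
    tid = wp.tid()
    acc[tid] = acc[tid] + x[tid]  # 'acc' is read and written in place
    out[tid] = 2.0 * x[tid]

# num_outputs counts the in-out argument as well as the pure output;
# the call returns (acc_updated, out), with 'acc' donated back to JAX in place.
jax_scale_accum = jax_kernel(scale_accum, num_outputs=2, in_out_argnames=["acc"])
# acc_new, out = jax_scale_accum(x, acc)
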
@@ -563,9 +686,11 @@ def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=N
 def jax_callable(
     func: Callable,
     num_outputs: int = 1,
-    graph_compatible: bool = True,
-    vmap_method: str = "broadcast_all",
+    graph_compatible: Optional[bool] = None,  # deprecated
+    graph_mode: GraphMode = GraphMode.JAX,
+    vmap_method: Optional[str] = "broadcast_all",
     output_dims=None,
+    in_out_argnames=None,
 ):
     """Create a JAX callback from an annotated Python function.
 
@@ -576,31 +701,50 @@ def jax_callable(
     Args:
         func: The Python function to call.
         num_outputs: Optional. Specify the number of output arguments if greater than 1.
+            This must include the number of ``in_out_arguments``.
         graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
+            This argument is deprecated, use ``graph_mode`` instead.
+        graph_mode: Optional. CUDA graph capture mode.
+            ``GraphMode.JAX`` (default): Let JAX capture the graph, which may be used as a subgraph in an enclosing capture.
+            ``GraphMode.WARP``: Let Warp capture the graph. Use this mode when the callable cannot be used as a subraph,
+            such as when the callable uses conditional graph nodes.
+            ``GraphMode.NONE``: Disable graph capture. Use when the callable performs operations that are not legal in a graph,
+            such as host synchronization.
         vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
             This argument can also be specified for individual calls.
         output_dims: Optional. Specify the default dimensions of output arrays.
            If ``None``, output dimensions are inferred from the launch dimensions.
            This argument can also be specified for individual calls.
+        in_out_argnames: Optional. Names of input-output arguments.
 
    Limitations:
        - All kernel arguments must be contiguous arrays or scalars.
        - Scalars must be static arguments in JAX.
-        - Input arguments are followed by output arguments in the Warp kernel definition.
-        - There must be at least one output argument.
+        - Input and input-output arguments must precede the output arguments in the ``func`` definition.
+        - There must be at least one output or input-output argument.
        - Only the CUDA backend is supported.
    """
+
+    if graph_compatible is not None:
+        wp.utils.warn(
+            "The `graph_compatible` argument is deprecated, use `graph_mode` instead.",
+            DeprecationWarning,
+            stacklevel=3,
+        )
+        if graph_compatible is False:
+            graph_mode = GraphMode.NONE
+
    key = (
        func,
        num_outputs,
-        graph_compatible,
+        graph_mode,
        vmap_method,
        tuple(sorted(output_dims.items())) if output_dims else output_dims,
    )
 
    with _FFI_REGISTRY_LOCK:
        if key not in _FFI_CALLABLE_REGISTRY:
-            new_callable = FfiCallable(func, num_outputs, graph_compatible, vmap_method, output_dims)
+            new_callable = FfiCallable(func, num_outputs, graph_mode, vmap_method, output_dims, in_out_argnames)
            _FFI_CALLABLE_REGISTRY[key] = new_callable
 
    return _FFI_CALLABLE_REGISTRY[key]
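
Likewise, a hypothetical sketch of jax_callable with the new graph_mode and in_out_argnames parameters (function and argument names invented for illustration):

import warp as wp
from warp.jax_experimental.ffi import jax_callable, GraphMode

@wp.kernel
def add_inplace(state: wp.array(dtype=float), delta: wp.array(dtype=float)):
    tid = wp.tid()
    state[tid] = state[tid] + delta[tid]

def step(delta: wp.array(dtype=float), state: wp.array(dtype=float)):
    # 'state' is an in-out argument: read and updated in place
    wp.launch(add_inplace, dim=state.shape, inputs=[state, delta])

# GraphMode.WARP lets Warp capture and reuse its own CUDA graph for repeated
# calls with the same buffers; use it when the callable cannot be a subgraph.
jax_step = jax_callable(step, num_outputs=1, graph_mode=GraphMode.WARP, in_out_argnames=["state"])
# new_state = jax_step(delta, state)
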
@@ -631,7 +775,6 @@ def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = Tr
 
     def ffi_callback(call_frame):
         try:
-            # TODO Try-catch around the body and return XLA_FFI_Error on error.
             extension = call_frame.contents.extension_start
             # On the first call, XLA runtime will query the API version and traits
             # metadata using the |extension| field. Let us respond to that query