warp-lang 1.8.0__py3-none-manylinux_2_34_aarch64.whl → 1.8.1__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (59)
  1. warp/bin/warp-clang.so +0 -0
  2. warp/bin/warp.so +0 -0
  3. warp/build_dll.py +5 -0
  4. warp/codegen.py +15 -3
  5. warp/config.py +1 -1
  6. warp/context.py +122 -24
  7. warp/examples/interop/example_jax_callable.py +34 -4
  8. warp/examples/interop/example_jax_kernel.py +27 -1
  9. warp/fem/field/virtual.py +2 -0
  10. warp/fem/integrate.py +78 -47
  11. warp/jax_experimental/ffi.py +201 -53
  12. warp/native/array.h +4 -4
  13. warp/native/builtin.h +8 -4
  14. warp/native/coloring.cpp +5 -1
  15. warp/native/cuda_util.cpp +1 -1
  16. warp/native/intersect.h +2 -2
  17. warp/native/mat.h +3 -3
  18. warp/native/mesh.h +1 -1
  19. warp/native/quat.h +6 -2
  20. warp/native/rand.h +7 -7
  21. warp/native/sparse.cu +1 -1
  22. warp/native/svd.h +23 -8
  23. warp/native/tile.h +20 -1
  24. warp/native/tile_radix_sort.h +5 -1
  25. warp/native/tile_reduce.h +16 -25
  26. warp/native/tuple.h +2 -2
  27. warp/native/vec.h +4 -4
  28. warp/native/warp.cpp +1 -1
  29. warp/native/warp.cu +15 -2
  30. warp/native/warp.h +1 -1
  31. warp/render/render_opengl.py +52 -51
  32. warp/render/render_usd.py +0 -1
  33. warp/sim/collide.py +1 -2
  34. warp/sim/integrator_vbd.py +10 -2
  35. warp/sparse.py +1 -1
  36. warp/tape.py +2 -0
  37. warp/tests/sim/test_cloth.py +89 -6
  38. warp/tests/sim/test_coloring.py +76 -1
  39. warp/tests/test_assert.py +53 -0
  40. warp/tests/test_atomic_cas.py +127 -114
  41. warp/tests/test_mat.py +22 -0
  42. warp/tests/test_quat.py +22 -0
  43. warp/tests/test_sparse.py +32 -0
  44. warp/tests/test_static.py +48 -0
  45. warp/tests/test_tape.py +38 -0
  46. warp/tests/test_vec.py +38 -408
  47. warp/tests/test_vec_constructors.py +325 -0
  48. warp/tests/tile/test_tile.py +31 -143
  49. warp/tests/tile/test_tile_mathdx.py +2 -2
  50. warp/tests/tile/test_tile_matmul.py +179 -0
  51. warp/tests/tile/test_tile_reduce.py +100 -11
  52. warp/tests/tile/test_tile_shared_memory.py +12 -12
  53. warp/tests/tile/test_tile_sort.py +59 -55
  54. warp/tests/unittest_suites.py +10 -0
  55. {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/METADATA +4 -4
  56. {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/RECORD +59 -57
  57. {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  58. {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  59. {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/jax_experimental/ffi.py CHANGED
@@ -16,7 +16,8 @@
 import ctypes
 import threading
 import traceback
-from typing import Callable
+from enum import IntEnum
+from typing import Callable, Optional
 
 import jax
 
@@ -28,10 +29,17 @@ from warp.types import array_t, launch_bounds_t, strides_from_shape, type_to_war
 from .xla_ffi import *
 
 
+class GraphMode(IntEnum):
+    NONE = 0  # don't capture a graph
+    JAX = 1  # let JAX capture a graph
+    WARP = 2  # let Warp capture a graph
+
+
 class FfiArg:
-    def __init__(self, name, type):
+    def __init__(self, name, type, in_out=False):
         self.name = name
         self.type = type
+        self.in_out = in_out
         self.is_array = isinstance(type, wp.array)
 
         if self.is_array:
@@ -65,7 +73,7 @@ class FfiLaunchDesc:
 
 
 class FfiKernel:
-    def __init__(self, kernel, num_outputs, vmap_method, launch_dims, output_dims):
+    def __init__(self, kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames):
         self.kernel = kernel
         self.name = generate_unique_name(kernel.func)
         self.num_outputs = num_outputs
@@ -76,17 +84,28 @@ class FfiKernel:
         self.launch_id = 0
         self.launch_descriptors = {}
 
+        in_out_argnames_list = in_out_argnames or []
+        in_out_argnames = set(in_out_argnames_list)
+        if len(in_out_argnames_list) != len(in_out_argnames):
+            raise AssertionError("in_out_argnames must not contain duplicate names")
+
         self.num_kernel_args = len(kernel.adj.args)
-        self.num_inputs = self.num_kernel_args - num_outputs
+        self.num_in_out = len(in_out_argnames)
+        self.num_inputs = self.num_kernel_args - num_outputs + self.num_in_out
         if self.num_outputs < 1:
             raise ValueError("At least one output is required")
         if self.num_outputs > self.num_kernel_args:
             raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
+        if self.num_outputs < self.num_in_out:
+            raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")
 
         # process input args
         self.input_args = []
         for i in range(self.num_inputs):
-            arg = FfiArg(kernel.adj.args[i].label, kernel.adj.args[i].type)
+            arg_name = kernel.adj.args[i].label
+            arg = FfiArg(arg_name, kernel.adj.args[i].type, arg_name in in_out_argnames)
+            if arg_name in in_out_argnames:
+                in_out_argnames.remove(arg_name)
             if arg.is_array:
                 # keep track of the first input array argument
                 if self.first_array_arg is None:
@@ -96,11 +115,30 @@ class FfiKernel:
         # process output args
         self.output_args = []
         for i in range(self.num_inputs, self.num_kernel_args):
-            arg = FfiArg(kernel.adj.args[i].label, kernel.adj.args[i].type)
+            arg_name = kernel.adj.args[i].label
+            if arg_name in in_out_argnames:
+                raise AssertionError(
+                    f"Expected an output-only argument for argument {arg_name}."
+                    " in_out arguments should be placed before output-only arguments."
+                )
+            arg = FfiArg(arg_name, kernel.adj.args[i].type, False)
             if not arg.is_array:
                 raise TypeError("All output arguments must be arrays")
             self.output_args.append(arg)
 
+        if in_out_argnames:
+            raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
+
+        # Build input output aliases.
+        out_id = 0
+        input_output_aliases = {}
+        for in_id, arg in enumerate(self.input_args):
+            if not arg.in_out:
+                continue
+            input_output_aliases[in_id] = out_id
+            out_id += 1
+        self.input_output_aliases = input_output_aliases
+
         # register the callback
         FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
         self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
@@ -121,6 +159,9 @@ class FfiKernel:
         if vmap_method is None:
             vmap_method = self.vmap_method
 
+        # output types
+        out_types = []
+
         # process inputs
         static_inputs = {}
         for i in range(num_inputs):
@@ -150,6 +191,10 @@ class FfiKernel:
                 # stash the value to be retrieved by callback
                 static_inputs[input_arg.name] = input_arg.type(input_value)
 
+            # append in-out arg to output types
+            if input_arg.in_out:
+                out_types.append(get_jax_output_type(input_arg, input_value.shape))
+
         # launch dimensions
         if launch_dims is None:
             # use the shape of the first input array
@@ -162,8 +207,7 @@ class FfiKernel:
         else:
             launch_dims = tuple(launch_dims)
 
-        # output types
-        out_types = []
+        # output shapes
         if isinstance(output_dims, dict):
             # assume a dictionary of shapes keyed on argument name
             for output_arg in self.output_args:
@@ -185,6 +229,7 @@ class FfiKernel:
             self.name,
             out_types,
             vmap_method=vmap_method,
+            input_output_aliases=self.input_output_aliases,
         )
 
         # ensure the kernel module is loaded before the callback, otherwise graph capture may fail
@@ -238,9 +283,8 @@ class FfiKernel:
 
             arg_refs = []
 
-            # inputs
-            for i in range(num_inputs):
-                input_arg = self.input_args[i]
+            # input and in-out args
+            for i, input_arg in enumerate(self.input_args):
                 if input_arg.is_array:
                     buffer = inputs[i].contents
                     shape = buffer.dims[: input_arg.type.ndim]
@@ -255,10 +299,9 @@ class FfiKernel:
                 kernel_params[i + 1] = ctypes.addressof(arg)
                 arg_refs.append(arg)  # keep a reference
 
-            # outputs
-            for i in range(num_outputs):
-                output_arg = self.output_args[i]
-                buffer = outputs[i].contents
+            # pure output args (skip in-out FFI buffers)
+            for i, output_arg in enumerate(self.output_args):
+                buffer = outputs[i + self.num_in_out].contents
                 shape = buffer.dims[: output_arg.type.ndim]
                 strides = strides_from_shape(shape, output_arg.type.dtype)
                 arg = array_t(buffer.data, 0, output_arg.type.ndim, shape, strides)
@@ -295,29 +338,38 @@ class FfiKernel:
 class FfiCallDesc:
     def __init__(self, static_inputs):
         self.static_inputs = static_inputs
+        self.captures = {}
 
 
 class FfiCallable:
-    def __init__(self, func, num_outputs, graph_compatible, vmap_method, output_dims):
+    def __init__(self, func, num_outputs, graph_mode, vmap_method, output_dims, in_out_argnames):
         self.func = func
         self.name = generate_unique_name(func)
         self.num_outputs = num_outputs
         self.vmap_method = vmap_method
-        self.graph_compatible = graph_compatible
+        self.graph_mode = graph_mode
         self.output_dims = output_dims
         self.first_array_arg = None
         self.call_id = 0
         self.call_descriptors = {}
 
+        in_out_argnames_list = in_out_argnames or []
+        in_out_argnames = set(in_out_argnames_list)
+        if len(in_out_argnames_list) != len(in_out_argnames):
+            raise AssertionError("in_out_argnames must not contain duplicate names")
+
         # get arguments and annotations
         argspec = get_full_arg_spec(func)
 
         num_args = len(argspec.args)
-        self.num_inputs = num_args - num_outputs
+        self.num_in_out = len(in_out_argnames)
+        self.num_inputs = num_args - num_outputs + self.num_in_out
         if self.num_outputs < 1:
             raise ValueError("At least one output is required")
         if self.num_outputs > num_args:
             raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
+        if self.num_outputs < self.num_in_out:
+            raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")
 
         if len(argspec.annotations) < num_args:
            raise RuntimeError(f"Incomplete argument annotations on function {self.name}")
@@ -330,15 +382,43 @@ class FfiCallable:
                 if arg_type is not None:
                     raise TypeError("Function must not return a value")
             else:
-                arg = FfiArg(arg_name, arg_type)
+                arg = FfiArg(arg_name, arg_type, arg_name in in_out_argnames)
+                if arg_name in in_out_argnames:
+                    in_out_argnames.remove(arg_name)
                 if arg.is_array:
                     if arg_idx < self.num_inputs and self.first_array_arg is None:
                         self.first_array_arg = arg_idx
                 self.args.append(arg)
+
+                if arg.in_out and arg_idx >= self.num_inputs:
+                    raise AssertionError(
+                        f"Expected an output-only argument for argument {arg_name}."
+                        " in_out arguments should be placed before output-only arguments."
+                    )
+
                 arg_idx += 1
 
-        self.input_args = self.args[: self.num_inputs]
-        self.output_args = self.args[self.num_inputs :]
+        if in_out_argnames:
+            raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
+
+        self.input_args = self.args[: self.num_inputs]  # includes in-out args
+        self.output_args = self.args[self.num_inputs :]  # pure output args
+
+        # Buffer indices for array arguments in callback.
+        # In-out buffers are the same pointers in the XLA call frame,
+        # so we only include them for inputs and skip them for outputs.
+        self.array_input_indices = [i for i, arg in enumerate(self.input_args) if arg.is_array]
+        self.array_output_indices = list(range(self.num_in_out, self.num_outputs))
+
+        # Build input output aliases.
+        out_id = 0
+        input_output_aliases = {}
+        for in_id, arg in enumerate(self.input_args):
+            if not arg.in_out:
+                continue
+            input_output_aliases[in_id] = out_id
+            out_id += 1
+        self.input_output_aliases = input_output_aliases
 
         # register the callback
         FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
@@ -350,7 +430,9 @@ class FfiCallable:
     def __call__(self, *args, output_dims=None, vmap_method=None):
         num_inputs = len(args)
         if num_inputs != self.num_inputs:
-            raise ValueError(f"Expected {self.num_inputs} inputs, but got {num_inputs}")
+            input_names = ", ".join(arg.name for arg in self.input_args)
+            s = "" if self.num_inputs == 1 else "s"
+            raise ValueError(f"Expected {self.num_inputs} input{s} ({input_names}), but got {num_inputs}")
 
         # default argument fallback
         if vmap_method is None:
@@ -358,6 +440,9 @@ class FfiCallable:
         if output_dims is None:
             output_dims = self.output_dims
 
+        # output types
+        out_types = []
+
         # process inputs
         static_inputs = {}
         for i in range(num_inputs):
@@ -387,12 +472,11 @@ class FfiCallable:
                 # stash the value to be retrieved by callback
                 static_inputs[input_arg.name] = input_arg.type(input_value)
 
-        if output_dims is None and self.first_array_arg is not None:
-            # use the shape of the first input array
-            output_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
+            # append in-out arg to output types
+            if input_arg.in_out:
+                out_types.append(get_jax_output_type(input_arg, input_value.shape))
 
-        # output types
-        out_types = []
+        # output shapes
         if isinstance(output_dims, dict):
             # assume a dictionary of shapes keyed on argument name
             for output_arg in self.output_args:
@@ -402,7 +486,9 @@ class FfiCallable:
                 out_types.append(get_jax_output_type(output_arg, dims))
         else:
             if output_dims is None:
-                raise ValueError("Unable to determine output dimensions")
+                if self.first_array_arg is None:
+                    raise ValueError("Unable to determine output dimensions")
+                output_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
             elif isinstance(output_dims, int):
                 output_dims = (output_dims,)
             # assume same dimensions for all outputs
@@ -413,6 +499,7 @@ class FfiCallable:
             self.name,
             out_types,
             vmap_method=vmap_method,
+            input_output_aliases=self.input_output_aliases,
             # has_side_effect=True,  # force this function to execute even if outputs aren't used
         )
 
@@ -430,11 +517,10 @@ class FfiCallable:
 
     def ffi_callback(self, call_frame):
         try:
-            # TODO Try-catch around the body and return XLA_FFI_Error on error.
-            extension = call_frame.contents.extension_start
             # On the first call, XLA runtime will query the API version and traits
             # metadata using the |extension| field. Let us respond to that query
             # if the metadata extension is present.
+            extension = call_frame.contents.extension_start
            if extension:
                 # Try to set the version metadata.
                 if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
@@ -442,15 +528,19 @@ class FfiCallable:
                     metadata_ext.contents.metadata.contents.api_version.major_version = 0
                     metadata_ext.contents.metadata.contents.api_version.minor_version = 1
                     # Turn on CUDA graphs for this handler.
-                    if self.graph_compatible:
+                    if self.graph_mode is GraphMode.JAX:
                         metadata_ext.contents.metadata.contents.traits = (
                             XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
                         )
                 return None
 
             # retrieve call info
-            attrs = decode_attrs(call_frame.contents.attrs)
-            call_id = int(attrs["call_id"])
+            # NOTE: this assumes that there's only one attribute - call_id (int64).
+            # A more general but slower approach is this:
+            #   attrs = decode_attrs(call_frame.contents.attrs)
+            #   call_id = int(attrs["call_id"])
+            attr = ctypes.cast(call_frame.contents.attrs.attrs[0], ctypes.POINTER(XLA_FFI_Scalar)).contents
+            call_id = ctypes.cast(attr.value, ctypes.POINTER(ctypes.c_int64)).contents.value
             call_desc = self.call_descriptors[call_id]
 
             num_inputs = call_frame.contents.args.size
@@ -462,16 +552,42 @@ class FfiCallable:
             assert num_inputs == self.num_inputs
             assert num_outputs == self.num_outputs
 
-            device = wp.device_from_jax(get_jax_device())
             cuda_stream = get_stream_from_callframe(call_frame.contents)
+
+            if self.graph_mode == GraphMode.WARP:
+                # check if we already captured an identical call
+                ip = [inputs[i].contents.data for i in self.array_input_indices]
+                op = [outputs[i].contents.data for i in self.array_output_indices]
+                buffer_hash = hash((*ip, *op))
+                capture = call_desc.captures.get(buffer_hash)
+
+                # launch existing graph
+                if capture is not None:
+                    # NOTE: We use the native graph API to avoid overhead with obtaining Stream and Device objects in Python.
+                    # This code should match wp.capture_launch().
+                    graph = capture.graph
+                    if graph.graph_exec is None:
+                        g = ctypes.c_void_p()
+                        if not wp.context.runtime.core.wp_cuda_graph_create_exec(
+                            graph.device.context, cuda_stream, graph.graph, ctypes.byref(g)
+                        ):
+                            raise RuntimeError(f"Graph creation error: {wp.context.runtime.get_error_string()}")
+                        graph.graph_exec = g
+
+                    if not wp.context.runtime.core.wp_cuda_graph_launch(graph.graph_exec, cuda_stream):
+                        raise RuntimeError(f"Graph launch error: {wp.context.runtime.get_error_string()}")
+
+                    # early out
+                    return
+
+            device = wp.device_from_jax(get_jax_device())
             stream = wp.Stream(device, cuda_stream=cuda_stream)
 
             # reconstruct the argument list
             arg_list = []
 
-            # inputs
-            for i in range(num_inputs):
-                arg = self.input_args[i]
+            # input and in-out args
+            for i, arg in enumerate(self.input_args):
                 if arg.is_array:
                     buffer = inputs[i].contents
                     shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
@@ -482,10 +598,9 @@ class FfiCallable:
                     value = call_desc.static_inputs[arg.name]
                     arg_list.append(value)
 
-            # outputs
-            for i in range(num_outputs):
-                arg = self.output_args[i]
-                buffer = outputs[i].contents
+            # pure output args (skip in-out FFI buffers)
+            for i, arg in enumerate(self.output_args):
+                buffer = outputs[i + self.num_in_out].contents
                 shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
                 arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
                 arg_list.append(arr)
@@ -493,11 +608,20 @@ class FfiCallable:
             # call the Python function with reconstructed arguments
             with wp.ScopedStream(stream, sync_enter=False):
                 if stream.is_capturing:
-                    with wp.ScopedCapture(stream=stream, external=True) as capture:
+                    # capturing with JAX
+                    with wp.ScopedCapture(external=True) as capture:
                         self.func(*arg_list)
                     # keep a reference to the capture object to prevent required modules getting unloaded
                     call_desc.capture = capture
+                elif self.graph_mode == GraphMode.WARP:
+                    # capturing with WARP
+                    with wp.ScopedCapture() as capture:
+                        self.func(*arg_list)
+                    wp.capture_launch(capture.graph)
+                    # keep a reference to the capture object and reuse it with same buffers
+                    call_desc.captures[buffer_hash] = capture
                 else:
+                    # not capturing
                     self.func(*arg_list)
 
         except Exception as e:
@@ -515,7 +639,9 @@ _FFI_KERNEL_REGISTRY: dict[str, FfiKernel] = {}
 _FFI_REGISTRY_LOCK = threading.Lock()
 
 
-def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=None, output_dims=None):
+def jax_kernel(
+    kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=None, output_dims=None, in_out_argnames=None
+):
     """Create a JAX callback from a Warp kernel.
 
     NOTE: This is an experimental feature under development.
@@ -523,6 +649,7 @@ def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=N
     Args:
         kernel: The Warp kernel to launch.
         num_outputs: Optional. Specify the number of output arguments if greater than 1.
+            This must include the number of ``in_out_arguments``.
         vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
             This argument can also be specified for individual calls.
         launch_dims: Optional. Specify the default kernel launch dimensions. If None, launch
@@ -531,12 +658,13 @@ def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=N
         output_dims: Optional. Specify the default dimensions of output arrays. If None, output
             dimensions are inferred from the launch dimensions.
             This argument can also be specified for individual calls.
+        in_out_argnames: Optional. Names of input-output arguments.
 
     Limitations:
        - All kernel arguments must be contiguous arrays or scalars.
        - Scalars must be static arguments in JAX.
-        - Input arguments are followed by output arguments in the Warp kernel definition.
-        - There must be at least one output argument.
+        - Input and input-output arguments must precede the output arguments in the ``kernel`` definition.
+        - There must be at least one output or input-output argument.
        - Only the CUDA backend is supported.
     """
     key = (
@@ -549,7 +677,7 @@ def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=N
 
     with _FFI_REGISTRY_LOCK:
         if key not in _FFI_KERNEL_REGISTRY:
-            new_kernel = FfiKernel(kernel, num_outputs, vmap_method, launch_dims, output_dims)
+            new_kernel = FfiKernel(kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames)
             _FFI_KERNEL_REGISTRY[key] = new_kernel
 
     return _FFI_KERNEL_REGISTRY[key]
@@ -558,9 +686,11 @@ def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=N
 def jax_callable(
     func: Callable,
     num_outputs: int = 1,
-    graph_compatible: bool = True,
-    vmap_method: str = "broadcast_all",
+    graph_compatible: Optional[bool] = None,  # deprecated
+    graph_mode: GraphMode = GraphMode.JAX,
+    vmap_method: Optional[str] = "broadcast_all",
     output_dims=None,
+    in_out_argnames=None,
 ):
     """Create a JAX callback from an annotated Python function.
 
@@ -571,31 +701,50 @@ def jax_callable(
     Args:
         func: The Python function to call.
         num_outputs: Optional. Specify the number of output arguments if greater than 1.
+            This must include the number of ``in_out_arguments``.
         graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
+            This argument is deprecated, use ``graph_mode`` instead.
+        graph_mode: Optional. CUDA graph capture mode.
+            ``GraphMode.JAX`` (default): Let JAX capture the graph, which may be used as a subgraph in an enclosing capture.
+            ``GraphMode.WARP``: Let Warp capture the graph. Use this mode when the callable cannot be used as a subraph,
+                such as when the callable uses conditional graph nodes.
+            ``GraphMode.NONE``: Disable graph capture. Use when the callable performs operations that are not legal in a graph,
+                such as host synchronization.
         vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
             This argument can also be specified for individual calls.
         output_dims: Optional. Specify the default dimensions of output arrays.
            If ``None``, output dimensions are inferred from the launch dimensions.
            This argument can also be specified for individual calls.
+        in_out_argnames: Optional. Names of input-output arguments.
 
     Limitations:
        - All kernel arguments must be contiguous arrays or scalars.
       - Scalars must be static arguments in JAX.
-        - Input arguments are followed by output arguments in the Warp kernel definition.
-        - There must be at least one output argument.
+        - Input and input-output arguments must precede the output arguments in the ``func`` definition.
+        - There must be at least one output or input-output argument.
       - Only the CUDA backend is supported.
     """
+
+    if graph_compatible is not None:
+        wp.utils.warn(
+            "The `graph_compatible` argument is deprecated, use `graph_mode` instead.",
+            DeprecationWarning,
+            stacklevel=3,
+        )
+        if graph_compatible is False:
+            graph_mode = GraphMode.NONE
+
     key = (
         func,
         num_outputs,
-        graph_compatible,
+        graph_mode,
         vmap_method,
         tuple(sorted(output_dims.items())) if output_dims else output_dims,
     )
 
     with _FFI_REGISTRY_LOCK:
         if key not in _FFI_CALLABLE_REGISTRY:
-            new_callable = FfiCallable(func, num_outputs, graph_compatible, vmap_method, output_dims)
+            new_callable = FfiCallable(func, num_outputs, graph_mode, vmap_method, output_dims, in_out_argnames)
             _FFI_CALLABLE_REGISTRY[key] = new_callable
 
     return _FFI_CALLABLE_REGISTRY[key]
@@ -626,7 +775,6 @@ def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = Tr
 
     def ffi_callback(call_frame):
         try:
-            # TODO Try-catch around the body and return XLA_FFI_Error on error.
            extension = call_frame.contents.extension_start
             # On the first call, XLA runtime will query the API version and traits
             # metadata using the |extension| field. Let us respond to that query
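The ffi.py changes above add two user-facing features: in_out_argnames, which marks arguments whose buffers are aliased between inputs and outputs (wired through XLA's input_output_aliases), and a GraphMode enum that controls CUDA graph capture for jax_callable. Below is a minimal usage sketch, assuming a CUDA device and the module layout shown in the diff; the kernel and function names (scale_accumulate, apply_step) are hypothetical and only illustrate the calling convention.

import jax.numpy as jnp
import warp as wp
from warp.jax_experimental.ffi import GraphMode, jax_callable, jax_kernel

@wp.kernel
def scale_accumulate(x: wp.array(dtype=float), state: wp.array(dtype=float)):
    # "state" is both read and written, so it is declared as an in-out argument below
    i = wp.tid()
    state[i] = state[i] + 2.0 * x[i]

# num_outputs counts the in-out argument; in_out_argnames enables the buffer aliasing
jax_scale = jax_kernel(scale_accumulate, num_outputs=1, in_out_argnames=["state"])

def apply_step(x: wp.array(dtype=float), state: wp.array(dtype=float)):
    # annotated Python function wrapped by jax_callable; it may launch any number of kernels
    wp.launch(scale_accumulate, dim=state.shape, inputs=[x, state])

# GraphMode.WARP asks Warp (rather than JAX) to capture and replay a CUDA graph per buffer set
jax_step = jax_callable(apply_step, num_outputs=1, in_out_argnames=["state"], graph_mode=GraphMode.WARP)

x = jnp.ones(16, dtype=jnp.float32)
state = jnp.zeros(16, dtype=jnp.float32)
new_state = jax_scale(x, state)  # the aliased in-out buffer comes back as the output
stepped = jax_step(x, state)     # (the result may be wrapped in a one-element sequence)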
warp/native/array.h CHANGED
@@ -161,7 +161,7 @@ inline CUDA_CALLABLE void print(shape_t s)
     // should probably store ndim with shape
     printf("(%d, %d, %d, %d)\n", s.dims[0], s.dims[1], s.dims[2], s.dims[3]);
 }
-inline CUDA_CALLABLE void adj_print(shape_t s, shape_t& shape_t) {}
+inline CUDA_CALLABLE void adj_print(shape_t s, shape_t& adj_s) {}
 
 
 template <typename T>
@@ -665,11 +665,11 @@ CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i, int j
 }
 
 template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
-inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, A2<T>& adj_src, int adj_i, A3<T> adj_ret) {}
+inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, A2<T>& adj_src, int adj_i, A3<T>& adj_ret) {}
 template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
-inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, A2<T>& adj_src, int adj_i, int adj_j, A3<T> adj_ret) {}
+inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, A2<T>& adj_src, int adj_i, int adj_j, A3<T>& adj_ret) {}
 template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
-inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, int k, A2<T>& adj_src, int adj_i, int adj_j, int adj_k, A3<T> adj_ret) {}
+inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, int k, A2<T>& adj_src, int adj_i, int adj_j, int adj_k, A3<T>& adj_ret) {}
 
 // TODO: lower_bound() for indexed arrays?
 
warp/native/builtin.h CHANGED
@@ -268,16 +268,20 @@ inline CUDA_CALLABLE half operator / (half a,half b)
 
 
 template <typename T>
-CUDA_CALLABLE float cast_float(T x) { return (float)(x); }
+CUDA_CALLABLE inline float cast_float(T x) { return (float)(x); }
 
 template <typename T>
-CUDA_CALLABLE int cast_int(T x) { return (int)(x); }
+CUDA_CALLABLE inline int cast_int(T x) { return (int)(x); }
 
 template <typename T>
-CUDA_CALLABLE void adj_cast_float(T x, T& adj_x, float adj_ret) { adj_x += T(adj_ret); }
+CUDA_CALLABLE inline void adj_cast_float(T x, T& adj_x, float adj_ret) {}
+
+CUDA_CALLABLE inline void adj_cast_float(float16 x, float16& adj_x, float adj_ret) { adj_x += float16(adj_ret); }
+CUDA_CALLABLE inline void adj_cast_float(float32 x, float32& adj_x, float adj_ret) { adj_x += float32(adj_ret); }
+CUDA_CALLABLE inline void adj_cast_float(float64 x, float64& adj_x, float adj_ret) { adj_x += float64(adj_ret); }
 
 template <typename T>
-CUDA_CALLABLE void adj_cast_int(T x, T& adj_x, int adj_ret) { adj_x += adj_ret; }
+CUDA_CALLABLE inline void adj_cast_int(T x, T& adj_x, int adj_ret) {}
 
 template <typename T>
 CUDA_CALLABLE inline void adj_int8(T, T&, int8) {}
warp/native/coloring.cpp CHANGED
@@ -209,9 +209,13 @@ float balance_color_groups(float target_max_min_ratio,
     do
     {
         int biggest_group = -1, smallest_group = -1;
-
+        float prev_max_min_ratio = max_min_ratio;
         max_min_ratio = find_largest_smallest_groups(color_groups, biggest_group, smallest_group);
 
+        if (prev_max_min_ratio > 0 && prev_max_min_ratio < max_min_ratio) {
+            return max_min_ratio;
+        }
+
         // graph is not optimizable anymore or target ratio reached
         if (color_groups[biggest_group].size() - color_groups[smallest_group].size() <= 2
             || max_min_ratio < target_max_min_ratio)
warp/native/cuda_util.cpp CHANGED
@@ -212,7 +212,7 @@ bool init_cuda_driver()
     get_driver_entry_point("cuDeviceGetCount", 2000, &(void*&)pfn_cuDeviceGetCount);
     get_driver_entry_point("cuDeviceGetName", 2000, &(void*&)pfn_cuDeviceGetName);
     get_driver_entry_point("cuDeviceGetAttribute", 2000, &(void*&)pfn_cuDeviceGetAttribute);
-    get_driver_entry_point("cuDeviceGetUuid", 110400, &(void*&)pfn_cuDeviceGetUuid);
+    get_driver_entry_point("cuDeviceGetUuid", 11040, &(void*&)pfn_cuDeviceGetUuid);
     get_driver_entry_point("cuDevicePrimaryCtxRetain", 7000, &(void*&)pfn_cuDevicePrimaryCtxRetain);
     get_driver_entry_point("cuDevicePrimaryCtxRelease", 11000, &(void*&)pfn_cuDevicePrimaryCtxRelease);
     get_driver_entry_point("cuDeviceCanAccessPeer", 4000, &(void*&)pfn_cuDeviceCanAccessPeer);
warp/native/intersect.h CHANGED
@@ -316,7 +316,7 @@ CUDA_CALLABLE inline bool intersect_ray_tri_woop(const vec3& p, const vec3& dir,
 
     if (dir[kz] < 0.0f)
     {
-        float tmp = kx;
+        int tmp = kx;
         kx = ky;
         ky = tmp;
     }
@@ -410,7 +410,7 @@ CUDA_CALLABLE inline void adj_intersect_ray_tri_woop(
 
     if (dir[kz] < 0.0f)
     {
-        float tmp = kx;
+        int tmp = kx;
         kx = ky;
         ky = tmp;
     }
warp/native/mat.h CHANGED
@@ -1533,13 +1533,13 @@ inline CUDA_CALLABLE void adj_div(const mat_t<Rows,Cols,Type>& a, Type s, mat_t<
 template<unsigned Rows, unsigned Cols, typename Type>
 inline CUDA_CALLABLE void adj_div(Type s, const mat_t<Rows,Cols,Type>& a, Type& adj_s, mat_t<Rows,Cols,Type>& adj_a, const mat_t<Rows,Cols,Type>& adj_ret)
 {
-    adj_s -= tensordot(a , adj_ret)/ (s * s); // - a / s^2
-
     for (unsigned i=0; i < Rows; ++i)
     {
         for (unsigned j=0; j < Cols; ++j)
         {
-            adj_a.data[i][j] += s / adj_ret.data[i][j];
+            Type inv = Type(1) / a.data[i][j];
+            adj_a.data[i][j] -= s * adj_ret.data[i][j] * inv * inv;
+            adj_s += adj_ret.data[i][j] * inv;
         }
     }
 }
warp/native/mesh.h CHANGED
@@ -1357,7 +1357,7 @@ CUDA_CALLABLE inline void adj_mesh_query_point_sign_normal(uint64_t id, const ve
     uint64_t adj_id, vec3& adj_point, float& adj_max_dist, float& adj_epsilon, mesh_query_point_t& adj_ret)
 {
     adj_mesh_query_point_sign_normal(id, point, max_dist, ret.sign, ret.face, ret.u, ret.v, epsilon,
-        adj_id, adj_point, adj_max_dist, adj_ret.sign, adj_ret.face, adj_ret.u, adj_ret.v, epsilon, adj_ret.result);
+        adj_id, adj_point, adj_max_dist, adj_ret.sign, adj_ret.face, adj_ret.u, adj_ret.v, adj_epsilon, adj_ret.result);
 }
 
 CUDA_CALLABLE inline void adj_mesh_query_point_sign_winding_number(uint64_t id, const vec3& point, float max_dist, float accuracy, float winding_number_threshold, const mesh_query_point_t& ret,