warp-lang 1.7.2rc1__py3-none-macosx_10_13_universal2.whl → 1.8.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (180)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +125 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +257 -101
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +657 -223
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  35. warp/examples/optim/example_drone.py +1 -1
  36. warp/examples/sim/example_cloth.py +1 -1
  37. warp/examples/sim/example_cloth_self_contact.py +48 -54
  38. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  39. warp/examples/tile/example_tile_cholesky.py +2 -1
  40. warp/examples/tile/example_tile_convolution.py +1 -1
  41. warp/examples/tile/example_tile_filtering.py +1 -1
  42. warp/examples/tile/example_tile_matmul.py +1 -1
  43. warp/examples/tile/example_tile_mlp.py +2 -0
  44. warp/fabric.py +7 -7
  45. warp/fem/__init__.py +5 -0
  46. warp/fem/adaptivity.py +1 -1
  47. warp/fem/cache.py +152 -63
  48. warp/fem/dirichlet.py +2 -2
  49. warp/fem/domain.py +136 -6
  50. warp/fem/field/field.py +141 -99
  51. warp/fem/field/nodal_field.py +85 -39
  52. warp/fem/field/virtual.py +97 -52
  53. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  54. warp/fem/geometry/closest_point.py +13 -0
  55. warp/fem/geometry/deformed_geometry.py +102 -40
  56. warp/fem/geometry/element.py +56 -2
  57. warp/fem/geometry/geometry.py +323 -22
  58. warp/fem/geometry/grid_2d.py +157 -62
  59. warp/fem/geometry/grid_3d.py +116 -20
  60. warp/fem/geometry/hexmesh.py +86 -20
  61. warp/fem/geometry/nanogrid.py +166 -86
  62. warp/fem/geometry/partition.py +59 -25
  63. warp/fem/geometry/quadmesh.py +86 -135
  64. warp/fem/geometry/tetmesh.py +47 -119
  65. warp/fem/geometry/trimesh.py +77 -270
  66. warp/fem/integrate.py +107 -52
  67. warp/fem/linalg.py +25 -58
  68. warp/fem/operator.py +124 -27
  69. warp/fem/quadrature/pic_quadrature.py +36 -14
  70. warp/fem/quadrature/quadrature.py +40 -16
  71. warp/fem/space/__init__.py +1 -1
  72. warp/fem/space/basis_function_space.py +66 -46
  73. warp/fem/space/basis_space.py +17 -4
  74. warp/fem/space/dof_mapper.py +1 -1
  75. warp/fem/space/function_space.py +2 -2
  76. warp/fem/space/grid_2d_function_space.py +4 -1
  77. warp/fem/space/hexmesh_function_space.py +4 -2
  78. warp/fem/space/nanogrid_function_space.py +3 -1
  79. warp/fem/space/partition.py +11 -2
  80. warp/fem/space/quadmesh_function_space.py +4 -1
  81. warp/fem/space/restriction.py +5 -2
  82. warp/fem/space/shape/__init__.py +10 -8
  83. warp/fem/space/tetmesh_function_space.py +4 -1
  84. warp/fem/space/topology.py +52 -21
  85. warp/fem/space/trimesh_function_space.py +4 -1
  86. warp/fem/utils.py +53 -8
  87. warp/jax.py +1 -2
  88. warp/jax_experimental/ffi.py +12 -17
  89. warp/jax_experimental/xla_ffi.py +37 -24
  90. warp/math.py +171 -1
  91. warp/native/array.h +99 -0
  92. warp/native/builtin.h +174 -31
  93. warp/native/coloring.cpp +1 -1
  94. warp/native/exports.h +118 -63
  95. warp/native/intersect.h +3 -3
  96. warp/native/mat.h +5 -10
  97. warp/native/mathdx.cpp +11 -5
  98. warp/native/matnn.h +1 -123
  99. warp/native/quat.h +28 -4
  100. warp/native/sparse.cpp +121 -258
  101. warp/native/sparse.cu +181 -274
  102. warp/native/spatial.h +305 -17
  103. warp/native/tile.h +583 -72
  104. warp/native/tile_radix_sort.h +1108 -0
  105. warp/native/tile_reduce.h +237 -2
  106. warp/native/tile_scan.h +240 -0
  107. warp/native/tuple.h +189 -0
  108. warp/native/vec.h +6 -16
  109. warp/native/warp.cpp +36 -4
  110. warp/native/warp.cu +574 -51
  111. warp/native/warp.h +47 -74
  112. warp/optim/linear.py +5 -1
  113. warp/paddle.py +7 -8
  114. warp/py.typed +0 -0
  115. warp/render/render_opengl.py +58 -29
  116. warp/render/render_usd.py +124 -61
  117. warp/sim/__init__.py +9 -0
  118. warp/sim/collide.py +252 -78
  119. warp/sim/graph_coloring.py +8 -1
  120. warp/sim/import_mjcf.py +4 -3
  121. warp/sim/import_usd.py +11 -7
  122. warp/sim/integrator.py +5 -2
  123. warp/sim/integrator_euler.py +1 -1
  124. warp/sim/integrator_featherstone.py +1 -1
  125. warp/sim/integrator_vbd.py +751 -320
  126. warp/sim/integrator_xpbd.py +1 -1
  127. warp/sim/model.py +265 -260
  128. warp/sim/utils.py +10 -7
  129. warp/sparse.py +303 -166
  130. warp/tape.py +52 -51
  131. warp/tests/cuda/test_conditional_captures.py +1046 -0
  132. warp/tests/cuda/test_streams.py +1 -1
  133. warp/tests/geometry/test_volume.py +2 -2
  134. warp/tests/interop/test_dlpack.py +9 -9
  135. warp/tests/interop/test_jax.py +0 -1
  136. warp/tests/run_coverage_serial.py +1 -1
  137. warp/tests/sim/disabled_kinematics.py +2 -2
  138. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  139. warp/tests/sim/test_collision.py +159 -51
  140. warp/tests/sim/test_coloring.py +15 -1
  141. warp/tests/test_array.py +254 -2
  142. warp/tests/test_array_reduce.py +2 -2
  143. warp/tests/test_atomic_cas.py +299 -0
  144. warp/tests/test_codegen.py +142 -19
  145. warp/tests/test_conditional.py +47 -1
  146. warp/tests/test_ctypes.py +0 -20
  147. warp/tests/test_devices.py +8 -0
  148. warp/tests/test_fabricarray.py +4 -2
  149. warp/tests/test_fem.py +58 -25
  150. warp/tests/test_func.py +42 -1
  151. warp/tests/test_grad.py +1 -1
  152. warp/tests/test_lerp.py +1 -3
  153. warp/tests/test_map.py +481 -0
  154. warp/tests/test_mat.py +1 -24
  155. warp/tests/test_quat.py +6 -15
  156. warp/tests/test_rounding.py +10 -38
  157. warp/tests/test_runlength_encode.py +7 -7
  158. warp/tests/test_smoothstep.py +1 -1
  159. warp/tests/test_sparse.py +51 -2
  160. warp/tests/test_spatial.py +507 -1
  161. warp/tests/test_struct.py +2 -2
  162. warp/tests/test_tuple.py +265 -0
  163. warp/tests/test_types.py +2 -2
  164. warp/tests/test_utils.py +24 -18
  165. warp/tests/tile/test_tile.py +420 -1
  166. warp/tests/tile/test_tile_mathdx.py +518 -14
  167. warp/tests/tile/test_tile_reduce.py +213 -0
  168. warp/tests/tile/test_tile_shared_memory.py +130 -1
  169. warp/tests/tile/test_tile_sort.py +117 -0
  170. warp/tests/unittest_suites.py +4 -6
  171. warp/types.py +462 -308
  172. warp/utils.py +647 -86
  173. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  174. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/RECORD +177 -165
  175. warp/stubs.py +0 -3381
  176. warp/tests/sim/test_xpbd.py +0 -399
  177. warp/tests/test_mlp.py +0 -282
  178. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  179. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  180. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -32,22 +32,7 @@ import typing
 import weakref
 from copy import copy as shallowcopy
 from pathlib import Path
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    Literal,
-    Mapping,
-    Optional,
-    Sequence,
-    Set,
-    Tuple,
-    TypeVar,
-    Union,
-    get_args,
-    get_origin,
-)
+from typing import Any, Callable, Dict, List, Literal, Mapping, Sequence, Tuple, TypeVar, Union, get_args, get_origin

 import numpy as np

@@ -84,7 +69,7 @@ def get_function_args(func):
 complex_type_hints = (Any, Callable, Tuple)
 sequence_types = (list, tuple)

-function_key_counts: Dict[str, int] = {}
+function_key_counts: dict[str, int] = {}


 def generate_unique_function_identifier(key: str) -> str:
@@ -120,40 +105,41 @@ def generate_unique_function_identifier(key: str) -> str:
 class Function:
     def __init__(
         self,
-        func: Optional[Callable],
+        func: Callable | None,
         key: str,
         namespace: str,
-        input_types: Optional[Dict[str, Union[type, TypeVar]]] = None,
-        value_type: Optional[type] = None,
-        value_func: Optional[Callable[[Mapping[str, type], Mapping[str, Any]], type]] = None,
-        export_func: Optional[Callable[[Dict[str, type]], Dict[str, type]]] = None,
-        dispatch_func: Optional[Callable] = None,
-        lto_dispatch_func: Optional[Callable] = None,
-        module: Optional[Module] = None,
+        input_types: dict[str, type | TypeVar] | None = None,
+        value_type: type | None = None,
+        value_func: Callable[[Mapping[str, type], Mapping[str, Any]], type] | None = None,
+        export_func: Callable[[dict[str, type]], dict[str, type]] | None = None,
+        dispatch_func: Callable | None = None,
+        lto_dispatch_func: Callable | None = None,
+        module: Module | None = None,
         variadic: bool = False,
-        initializer_list_func: Optional[Callable[[Dict[str, Any], type], bool]] = None,
+        initializer_list_func: Callable[[dict[str, Any], type], bool] | None = None,
         export: bool = False,
+        source: str | None = None,
         doc: str = "",
         group: str = "",
         hidden: bool = False,
         skip_replay: bool = False,
         missing_grad: bool = False,
         generic: bool = False,
-        native_func: Optional[str] = None,
-        defaults: Optional[Dict[str, Any]] = None,
-        custom_replay_func: Optional[Function] = None,
-        native_snippet: Optional[str] = None,
-        adj_native_snippet: Optional[str] = None,
-        replay_snippet: Optional[str] = None,
+        native_func: str | None = None,
+        defaults: dict[str, Any] | None = None,
+        custom_replay_func: Function | None = None,
+        native_snippet: str | None = None,
+        adj_native_snippet: str | None = None,
+        replay_snippet: str | None = None,
         skip_forward_codegen: bool = False,
         skip_reverse_codegen: bool = False,
         custom_reverse_num_input_args: int = -1,
         custom_reverse_mode: bool = False,
-        overloaded_annotations: Optional[Dict[str, type]] = None,
-        code_transformers: Optional[List[ast.NodeTransformer]] = None,
+        overloaded_annotations: dict[str, type] | None = None,
+        code_transformers: list[ast.NodeTransformer] | None = None,
         skip_adding_overload: bool = False,
         require_original_output_arg: bool = False,
-        scope_locals: Optional[Dict[str, Any]] = None,
+        scope_locals: dict[str, Any] | None = None,
     ):
         if code_transformers is None:
             code_transformers = []
@@ -178,7 +164,7 @@ class Function:
         self.native_snippet = native_snippet
         self.adj_native_snippet = adj_native_snippet
         self.replay_snippet = replay_snippet
-        self.custom_grad_func: Optional[Function] = None
+        self.custom_grad_func: Function | None = None
         self.require_original_output_arg = require_original_output_arg
         self.generic_parent = None  # generic function that was used to instantiate this overload

@@ -194,7 +180,7 @@ class Function:
         )
         self.missing_grad = missing_grad  # whether builtin is missing a corresponding adjoint
         self.generic = generic
-        self.mangled_name: Optional[str] = None
+        self.mangled_name: str | None = None

         # allow registering functions with a different name in Python and native code
         if native_func is None:
@@ -211,12 +197,13 @@ class Function:
             # user-defined function

             # generic and concrete overload lookups by type signature
-            self.user_templates: Dict[str, Function] = {}
-            self.user_overloads: Dict[str, Function] = {}
+            self.user_templates: dict[str, Function] = {}
+            self.user_overloads: dict[str, Function] = {}

             # user defined (Python) function
             self.adj = warp.codegen.Adjoint(
                 func,
+                source=source,
                 is_user_function=True,
                 skip_forward_codegen=skip_forward_codegen,
                 skip_reverse_codegen=skip_reverse_codegen,
@@ -244,7 +231,7 @@ class Function:

             # embedded linked list of all overloads
             # the builtin_functions dictionary holds the list head for a given key (func name)
-            self.overloads: List[Function] = []
+            self.overloads: list[Function] = []

             # builtin (native) function, canonicalize argument types
             if input_types is not None:
@@ -293,10 +280,11 @@ class Function:
             module.register_function(self, scope_locals, skip_adding_overload)

     def __call__(self, *args, **kwargs):
-        # handles calling a builtin (native) function
-        # as if it was a Python function, i.e.: from
-        # within the CPython interpreter rather than
-        # from within a kernel (experimental).
+        """Call this function from the CPython interpreter.
+
+        This is used to call built-in or user functions from the CPython
+        interpreter, rather than from within a kernel.
+        """

         if self.is_builtin() and self.mangled_name:
             # For each of this function's existing overloads, we attempt to pack
@@ -306,7 +294,23 @@ class Function:
                 if overload.generic:
                     continue

-                success, return_value = call_builtin(overload, *args)
+                try:
+                    # Try to bind the given arguments to the function's signature.
+                    # This is not checking whether the argument types are matching,
+                    # rather it's just assigning each argument to the corresponding
+                    # function parameter.
+                    bound_args = self.signature.bind(*args, **kwargs)
+                except TypeError:
+                    continue
+
+                if self.defaults:
+                    # Populate the bound arguments with any default values.
+                    default_args = {k: v for k, v in self.defaults.items() if k not in bound_args.arguments}
+                    warp.codegen.apply_defaults(bound_args, default_args)
+
+                bound_args = tuple(bound_args.arguments.values())
+
+                success, return_value = call_builtin(overload, bound_args)
                 if success:
                     return return_value
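Note: the new builtin dispatch path above binds the caller's arguments to the function signature and fills in per-overload defaults before attempting the native call. The snippet below is a minimal standalone sketch of that bind-plus-defaults step using only the standard inspect module; it is not Warp's internal implementation (which goes through warp.codegen.apply_defaults), and the helper name is made up for illustration.

    import inspect

    def bind_with_defaults(func, defaults, *args, **kwargs):
        # Assign positional/keyword arguments to parameter names; no type checking happens here.
        sig = inspect.signature(func)
        try:
            bound = sig.bind(*args, **kwargs)
        except TypeError:
            return None  # the arguments do not fit this overload's signature
        # Fill in any default values the caller did not supply.
        for name, value in defaults.items():
            if name not in bound.arguments:
                bound.arguments[name] = value
        return tuple(bound.arguments.values())

    def example(a, b, c=0):
        return a + b + c

    print(bind_with_defaults(example, {"c": 10}, 1, 2))  # (1, 2, 10)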
 
@@ -324,6 +328,9 @@

             arguments = tuple(bound_args.arguments.values())

+            # Store the last runtime error we encountered from a function execution
+            last_execution_error = None
+
             # try and find a matching overload
             for overload in self.user_overloads.values():
                 if len(overload.input_types) != len(arguments):
@@ -334,10 +341,25 @@
                     # attempt to unify argument types with function template types
                     warp.types.infer_argument_types(arguments, template_types, arg_names)
                     return overload.func(*arguments)
-                except Exception:
+                except Exception as e:
+                    # The function was callable but threw an error during its execution.
+                    # This might be the intended overload, but it failed, or it might be the wrong overload.
+                    # We save this specific error and continue, just in case another overload later in the
+                    # list is a better match and doesn't fail.
+                    last_execution_error = e
                     continue

-            raise RuntimeError(f"Error calling function '{self.key}', no overload found for arguments {args}")
+            if last_execution_error:
+                # Raise a new, more contextual RuntimeError, but link it to the
+                # original error that was caught. This preserves the original
+                # traceback and error type for easier debugging.
+                raise RuntimeError(
+                    f"Error calling function '{self.key}'. No version succeeded. "
+                    f"See above for the error from the last version that was tried."
+                ) from last_execution_error
+            else:
+                # We got here without ever calling an overload.func
+                raise RuntimeError(f"Error calling function '{self.key}', no overload found for arguments {args}")

         # user-defined function with no overloads
         if self.func is None:
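Note: the error handling above is plain Python exception chaining; raising the new RuntimeError "from" the saved exception keeps the original traceback attached. A small hedged illustration of the same pattern outside Warp:

    def pick_overload(candidates, *args):
        last_error = None
        for fn in candidates:
            try:
                return fn(*args)
            except Exception as e:
                # Remember the failure but keep trying the remaining candidates.
                last_error = e
        if last_error is not None:
            # Chain the new error to the original so both tracebacks are reported.
            raise RuntimeError("no candidate succeeded") from last_error
        raise RuntimeError("no candidate matched the given arguments")

    print(pick_overload([lambda x: x + 1], 41))  # 42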
@@ -358,9 +380,6 @@ class Function:
            if warp.types.is_array(v) or v in complex_type_hints:
                return False

-        if type(self.value_type) in sequence_types:
-            return False
-
        return True

    def mangle(self) -> str:
@@ -404,8 +423,12 @@ class Function:
        else:
            self.user_overloads[sig] = f

-    def get_overload(self, arg_types: List[type], kwarg_types: Mapping[str, type]) -> Optional[Function]:
-        assert not self.is_builtin()
+    def get_overload(self, arg_types: list[type], kwarg_types: Mapping[str, type]) -> Function | None:
+        if self.is_builtin():
+            for f in self.overloads:
+                if warp.codegen.func_match_args(f, arg_types, kwarg_types):
+                    return f
+            return None

        for f in self.user_overloads.values():
            if warp.codegen.func_match_args(f, arg_types, kwarg_types):
@@ -439,7 +462,7 @@ class Function:
            overload_annotations[k] = warp.codegen.strip_reference(warp.codegen.get_arg_type(d))

        ovl = shallowcopy(f)
-        ovl.adj = warp.codegen.Adjoint(f.func, overload_annotations)
+        ovl.adj = warp.codegen.Adjoint(f.func, overload_annotations, source=f.adj.source)
        ovl.input_types = overload_annotations
        ovl.value_func = None
        ovl.generic_parent = f
@@ -475,11 +498,25 @@ def get_builtin_type(return_type: type) -> type:
    return return_type


-def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
+def extract_return_value(value_type: type, value_ctype: type, ret: Any) -> Any:
+    if issubclass(value_ctype, ctypes.Array) or issubclass(value_ctype, ctypes.Structure):
+        # return vector types as ctypes
+        return ret
+
+    if value_type is warp.types.float16:
+        return warp.types.half_bits_to_float(ret.value)
+
+    return ret.value
+
+
+def call_builtin(func: Function, params: tuple) -> tuple[bool, Any]:
    uses_non_warp_array_type = False

    init()

+    if func.mangled_name is None:
+        return (False, None)
+
    # Retrieve the built-in function from Warp's dll.
    c_func = getattr(warp.context.runtime.core, func.mangled_name)
@@ -489,6 +526,8 @@ def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
    else:
        func_args = func.input_types

+    value_type = func.value_func(None, None)
+
    # Try gathering the parameters that the function expects and pack them
    # into their corresponding C types.
    c_params = []
@@ -604,9 +643,9 @@ def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:

            if not (
                isinstance(param, arg_type)
-                or (type(param) is float and arg_type is warp.types.float32)  # noqa: E721
-                or (type(param) is int and arg_type is warp.types.int32)  # noqa: E721
-                or (type(param) is bool and arg_type is warp.types.bool)  # noqa: E721
+                or (type(param) is float and arg_type is warp.types.float32)
+                or (type(param) is int and arg_type is warp.types.int32)
+                or (type(param) is bool and arg_type is warp.types.bool)
                or warp.types.np_dtype_to_warp_type.get(getattr(param, "dtype", None)) is arg_type
            ):
                return (False, None)
@@ -620,25 +659,18 @@ def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
        else:
            c_params.append(arg_type._type_(param))

-    # returns the corresponding ctype for a scalar or vector warp type
-    value_type = func.value_func(None, None)
+    # Retrieve the return type.
+    value_type = func.value_func(func_args, None)

-    if value_type == float:
-        value_ctype = ctypes.c_float
-    elif value_type == int:
-        value_ctype = ctypes.c_int32
-    elif value_type == bool:
-        value_ctype = ctypes.c_bool
-    elif issubclass(value_type, (ctypes.Array, ctypes.Structure)):
-        value_ctype = value_type
-    else:
-        # scalar type
-        value_ctype = value_type._type_
+    if value_type is not None:
+        if not isinstance(value_type, Sequence):
+            value_type = (value_type,)
+
+        value_ctype = tuple(warp.types.type_ctype(x) for x in value_type)
+        ret = tuple(x() for x in value_ctype)
+        ret_addr = tuple(ctypes.c_void_p(ctypes.addressof(x)) for x in ret)

-    # construct return value (passed by address)
-    ret = value_ctype()
-    ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
-    c_params.append(ret_addr)
+        c_params.extend(ret_addr)

    # Call the built-in function from Warp's dll.
    c_func(*c_params)
@@ -653,17 +685,14 @@ def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
            stacklevel=3,
        )

-    if issubclass(value_ctype, ctypes.Array) or issubclass(value_ctype, ctypes.Structure):
-        # return vector types as ctypes
-        return (True, ret)
+    if value_type is None:
+        return (True, None)

-    if value_type == warp.types.float16:
-        value = warp.types.half_bits_to_float(ret.value)
-    else:
-        value = ret.value
+    return_value = tuple(extract_return_value(x, y, z) for x, y, z in zip(value_type, value_ctype, ret))
+    if len(return_value) == 1:
+        return_value = return_value[0]

-    # return scalar types as int/float
-    return (True, value)
+    return (True, return_value)


 class KernelHooks:
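Note: the reworked call_builtin() follows the usual ctypes output-parameter pattern, now generalized to a tuple of return slots: allocate one C object per return value, pass its address to the native function, then read .value back (or return the ctypes object directly for vector/struct types). Below is a self-contained sketch of that pattern against libm rather than Warp's own bindings, assuming ctypes.util.find_library can locate the math library on the host system.

    import ctypes
    import ctypes.util

    libm = ctypes.CDLL(ctypes.util.find_library("m"))
    libm.frexp.restype = ctypes.c_double
    libm.frexp.argtypes = [ctypes.c_double, ctypes.POINTER(ctypes.c_int)]

    exponent = ctypes.c_int()  # output slot, passed by address
    mantissa = libm.frexp(8.0, ctypes.byref(exponent))
    print(mantissa, exponent.value)  # 0.5 4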
@@ -677,7 +706,7 @@ class KernelHooks:

 # caches source and compiled entry points for a kernel (will be populated after module loads)
 class Kernel:
-    def __init__(self, func, key=None, module=None, options=None, code_transformers=None):
+    def __init__(self, func, key=None, module=None, options=None, code_transformers=None, source=None):
        self.func = func

        if module is None:
@@ -695,7 +724,7 @@ class Kernel:
        if code_transformers is None:
            code_transformers = []

-        self.adj = warp.codegen.Adjoint(func, transformers=code_transformers)
+        self.adj = warp.codegen.Adjoint(func, transformers=code_transformers, source=source)

        # check if generic
        self.is_generic = False
@@ -762,7 +791,7 @@ class Kernel:

        # instantiate this kernel with the given argument types
        ovl = shallowcopy(self)
-        ovl.adj = warp.codegen.Adjoint(self.func, overload_annotations)
+        ovl.adj = warp.codegen.Adjoint(self.func, overload_annotations, source=self.adj.source)
        ovl.is_generic = False
        ovl.overloads = {}
        ovl.sig = sig
@@ -798,7 +827,7 @@ class Kernel:


 # decorator to register function, @func
-def func(f: Optional[Callable] = None, *, name: Optional[str] = None):
+def func(f: Callable | None = None, *, name: str | None = None):
    def wrapper(f, *args, **kwargs):
        if name is None:
            key = warp.codegen.make_full_qualified_name(f)
@@ -831,7 +860,7 @@ def func(f: Optional[Callable] = None, *, name: Optional[str] = None):
    return wrapper(f)


-def func_native(snippet: str, adj_snippet: Optional[str] = None, replay_snippet: Optional[str] = None):
+def func_native(snippet: str, adj_snippet: str | None = None, replay_snippet: str | None = None):
    """
    Decorator to register native code snippet, @func_native
    """
@@ -1015,10 +1044,10 @@ def func_replay(forward_fn):


 def kernel(
-    f: Optional[Callable] = None,
+    f: Callable | None = None,
    *,
-    enable_backward: Optional[bool] = None,
-    module: Optional[Union[Module, Literal["unique"]]] = None,
+    enable_backward: bool | None = None,
+    module: Module | Literal["unique"] | None = None,
 ):
    """
    Decorator to register a Warp kernel from a Python function.
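Note: the kernel() decorator keeps the two keyword-only options shown above. A brief usage sketch (the kernel body is illustrative only):

    import warp as wp

    # enable_backward=False skips adjoint code generation for this kernel;
    # module="unique" places the kernel in its own dedicated module.
    @wp.kernel(enable_backward=False, module="unique")
    def saxpy(x: wp.array(dtype=float), y: wp.array(dtype=float), a: float):
        i = wp.tid()
        y[i] = a * x[i] + y[i]

    x = wp.ones(1024, dtype=float)
    y = wp.zeros(1024, dtype=float)
    wp.launch(saxpy, dim=1024, inputs=[x, y, 2.0])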
@@ -1181,7 +1210,7 @@ def overload(kernel, arg_types=Union[None, Dict[str, Any], List[Any]]):


 # native functions that are part of the Warp API
-builtin_functions: Dict[str, Function] = {}
+builtin_functions: dict[str, Function] = {}


 def get_generic_vtypes():
@@ -1204,13 +1233,13 @@ scalar_types.update({x: x._wp_scalar_type_ for x in warp.types.vector_types})

 def add_builtin(
    key: str,
-    input_types: Optional[Dict[str, Union[type, TypeVar]]] = None,
-    constraint: Optional[Callable[[Mapping[str, type]], bool]] = None,
-    value_type: Optional[type] = None,
-    value_func: Optional[Callable] = None,
-    export_func: Optional[Callable] = None,
-    dispatch_func: Optional[Callable] = None,
-    lto_dispatch_func: Optional[Callable] = None,
+    input_types: dict[str, type | TypeVar] | None = None,
+    constraint: Callable[[Mapping[str, type]], bool] | None = None,
+    value_type: type | None = None,
+    value_func: Callable | None = None,
+    export_func: Callable | None = None,
+    dispatch_func: Callable | None = None,
+    lto_dispatch_func: Callable | None = None,
    doc: str = "",
    namespace: str = "wp::",
    variadic: bool = False,
@@ -1220,8 +1249,8 @@ def add_builtin(
    hidden: bool = False,
    skip_replay: bool = False,
    missing_grad: bool = False,
-    native_func: Optional[str] = None,
-    defaults: Optional[Dict[str, Any]] = None,
+    native_func: str | None = None,
+    defaults: dict[str, Any] | None = None,
    require_original_output_arg: bool = False,
 ):
    """Main entry point to register a new built-in function.
@@ -1371,18 +1400,13 @@ def add_builtin(

        return_type = value_func(concrete_arg_types, None)

-        # The return_type might just be vector_t(length=3,dtype=wp.float32), so we've got to match that
-        # in the list of hard coded types so it knows it's returning one of them:
-        if hasattr(return_type, "_wp_generic_type_hint_"):
-            return_type_match = tuple(
-                x
-                for x in generic_vtypes
-                if x._wp_generic_type_hint_ == return_type._wp_generic_type_hint_
-                and x._wp_type_params_ == return_type._wp_type_params_
-            )
-            if not return_type_match:
-                continue
-            return_type = return_type_match[0]
+        try:
+            if isinstance(return_type, Sequence):
+                return_type = tuple(get_builtin_type(x) for x in return_type)
+            else:
+                return_type = get_builtin_type(return_type)
+        except RuntimeError:
+            continue

        # finally we can generate a function call for these concrete types:
        add_builtin(
@@ -1485,7 +1509,7 @@ def register_api_function(


 # global dictionary of modules
-user_modules: Dict[str, Module] = {}
+user_modules: dict[str, Module] = {}


 def get_module(name: str) -> Module:
@@ -1608,7 +1632,7 @@ class ModuleHasher:
        ch.update(bytes(func.key, "utf-8"))

        # include all concrete and generic overloads
-        overloads: Dict[str, Function] = {**func.user_overloads, **func.user_templates}
+        overloads: dict[str, Function] = {**func.user_overloads, **func.user_templates}
        for sig in sorted(overloads.keys()):
            ovl = overloads[sig]

@@ -1857,7 +1881,7 @@ class ModuleBuilder:
 # the original Modules get reloaded.
 class ModuleExec:
    def __new__(cls, *args, **kwargs):
-        instance = super(ModuleExec, cls).__new__(cls)
+        instance = super().__new__(cls)
        instance.handle = None
        return instance

@@ -1952,7 +1976,7 @@ class ModuleExec:
 # creates a hash of the function to use for checking
 # build cache
 class Module:
-    def __init__(self, name: Optional[str], loader=None):
+    def __init__(self, name: str | None, loader=None):
        self.name = name if name is not None else "None"

        self.loader = loader
@@ -1996,6 +2020,7 @@ class Module:
            "cuda_output": None,  # supported values: "ptx", "cubin", or None (automatic)
            "mode": warp.config.mode,
            "block_dim": 256,
+            "compile_time_trace": warp.config.compile_time_trace,
        }

        # Module dependencies are determined by scanning each function
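Note: the new "compile_time_trace" option defaults to warp.config.compile_time_trace and is forwarded to the CUDA compile call (see the extra ctypes.c_bool added to the cuda_compile_program argtypes further down). Assuming the option keeps the name shown here, a hedged sketch of enabling it, either globally or per module, before a module is built on a CUDA device:

    import warp as wp

    wp.init()

    # Globally, before any module is compiled.
    wp.config.compile_time_trace = True

    # Or per module, through the regular module options mechanism.
    wp.set_module_options({"compile_time_trace": True})

    @wp.kernel
    def scale(a: wp.array(dtype=float), s: float):
        i = wp.tid()
        a[i] = a[i] * s

    wp.load_module(device="cuda:0")  # compilation happens here, with the trace enabled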
@@ -2222,7 +2247,7 @@ class Module:
        ):
            builder_options = {
                **self.options,
-                # Some of the Tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
+                # Some of the tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
                "output_arch": output_arch,
            }
            builder = ModuleBuilder(self, builder_options, hasher=self.hashers[active_block_dim])
@@ -2291,6 +2316,7 @@ class Module:
                        fast_math=self.options["fast_math"],
                        fuse_fp=self.options["fuse_fp"],
                        lineinfo=self.options["lineinfo"],
+                        compile_time_trace=self.options["compile_time_trace"],
                        ltoirs=builder.ltoirs.values(),
                        fatbins=builder.fatbins.values(),
                    )
@@ -2343,7 +2369,7 @@ class Module:
        # Load CPU or CUDA binary

        meta_path = os.path.join(module_dir, f"{module_name_short}.meta")
-        with open(meta_path, "r") as meta_file:
+        with open(meta_path) as meta_file:
            meta = json.load(meta_file)

        if device.is_cpu:
@@ -2406,7 +2432,7 @@ class CpuDefaultAllocator:
    def alloc(self, size_in_bytes):
        ptr = runtime.core.alloc_host(size_in_bytes)
        if not ptr:
-            raise RuntimeError(f"Failed to allocate {size_in_bytes} bytes on device '{self.device}'")
+            raise RuntimeError(f"Failed to allocate {size_in_bytes} bytes on device 'cpu'")
        return ptr

    def free(self, ptr, size_in_bytes):
@@ -2510,12 +2536,12 @@ class Event:

    def __new__(cls, *args, **kwargs):
        """Creates a new event instance."""
-        instance = super(Event, cls).__new__(cls)
+        instance = super().__new__(cls)
        instance.owner = False
        return instance

    def __init__(
-        self, device: "Devicelike" = None, cuda_event=None, enable_timing: bool = False, interprocess: bool = False
+        self, device: Devicelike = None, cuda_event=None, enable_timing: bool = False, interprocess: bool = False
    ):
        """Initializes the event on a CUDA device.

@@ -2611,12 +2637,12 @@ class Event:

 class Stream:
    def __new__(cls, *args, **kwargs):
-        instance = super(Stream, cls).__new__(cls)
+        instance = super().__new__(cls)
        instance.cuda_stream = None
        instance.owner = False
        return instance

-    def __init__(self, device: Union["Device", str, None] = None, priority: int = 0, **kwargs):
+    def __init__(self, device: Device | str | None = None, priority: int = 0, **kwargs):
        """Initialize the stream on a device with an optional specified priority.

        Args:
@@ -2682,7 +2708,7 @@ class Stream:
            self._cached_event = Event(self.device)
        return self._cached_event

-    def record_event(self, event: Optional[Event] = None) -> Event:
+    def record_event(self, event: Event | None = None) -> Event:
        """Record an event onto the stream.

        Args:
@@ -2711,7 +2737,7 @@ class Stream:
        """
        runtime.core.cuda_stream_wait_event(self.cuda_stream, event.cuda_event)

-    def wait_stream(self, other_stream: "Stream", event: Optional[Event] = None):
+    def wait_stream(self, other_stream: Stream, event: Event | None = None):
        """Records an event on `other_stream` and makes this stream wait on it.

        All work added to this stream after this function has been called will
@@ -2765,6 +2791,8 @@ class Device:
            or ``"CPU"`` if the processor name cannot be determined.
        arch (int): The compute capability version number calculated as ``10 * major + minor``.
            ``0`` for CPU devices.
+        sm_count (int): The number of streaming multiprocessors on the CUDA device.
+            ``0`` for CPU devices.
        is_uva (bool): Indicates whether the device supports unified addressing.
            ``False`` for CPU devices.
        is_cubin_supported (bool): Indicates whether Warp's version of NVRTC can directly
@@ -2810,6 +2838,7 @@ class Device:
            # CPU device
            self.name = platform.processor() or "CPU"
            self.arch = 0
+            self.sm_count = 0
            self.is_uva = False
            self.is_mempool_supported = False
            self.is_mempool_enabled = False
@@ -2829,6 +2858,7 @@ class Device:
            # CUDA device
            self.name = runtime.core.cuda_device_get_name(ordinal).decode()
            self.arch = runtime.core.cuda_device_get_arch(ordinal)
+            self.sm_count = runtime.core.cuda_device_get_sm_count(ordinal)
            self.is_uva = runtime.core.cuda_device_is_uva(ordinal) > 0
            self.is_mempool_supported = runtime.core.cuda_device_is_mempool_supported(ordinal) > 0
            if platform.system() == "Linux":
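Note: the new Device.sm_count attribute mirrors the cuda_device_get_sm_count binding registered later in this diff. A small hedged sketch of reading it, assuming at least one CUDA device is available:

    import warp as wp

    wp.init()

    for d in wp.get_cuda_devices():
        # sm_count is 0 for CPU devices and the streaming-multiprocessor count for CUDA devices.
        print(f"{d.alias}: arch={d.arch}, SMs={d.sm_count}, name={d.name}")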
@@ -3070,16 +3100,23 @@ class Graph:
    def __init__(self, device: Device, capture_id: int):
        self.device = device
        self.capture_id = capture_id
-        self.module_execs: Set[ModuleExec] = set()
-        self.graph_exec: Optional[ctypes.c_void_p] = None
+        self.module_execs: set[ModuleExec] = set()
+        self.graph_exec: ctypes.c_void_p | None = None
+
+        self.graph: ctypes.c_void_p | None = None
+        self.has_conditional = (
+            False  # Track if there are conditional nodes in the graph since they are not allowed in child graphs
+        )

    def __del__(self):
-        if not hasattr(self, "graph_exec") or not hasattr(self, "device") or not self.graph_exec:
+        if not hasattr(self, "graph") or not hasattr(self, "device") or not self.graph:
            return

        # use CUDA context guard to avoid side effects during garbage collection
        with self.device.context_guard:
-            runtime.core.cuda_graph_destroy(self.device.context, self.graph_exec)
+            runtime.core.cuda_graph_destroy(self.device.context, self.graph)
+            if hasattr(self, "graph_exec") and self.graph_exec is not None:
+                runtime.core.cuda_graph_exec_destroy(self.device.context, self.graph_exec)

    # retain executable CUDA modules used by this graph, which prevents them from being unloaded
    def retain_module_exec(self, module_exec: ModuleExec):
@@ -3088,8 +3125,6 @@ class Graph:

 class Runtime:
    def __init__(self):
-        if sys.version_info < (3, 8):
-            raise RuntimeError("Warp requires Python 3.8 as a minimum")
        if sys.version_info < (3, 9):
            warp.utils.warn(f"Python 3.9 or newer is recommended for running Warp, detected {sys.version_info}")

@@ -3535,44 +3570,40 @@ class Runtime:
        self.core.volume_get_blind_data_info.restype = ctypes.c_char_p

        bsr_matrix_from_triplets_argtypes = [
-            ctypes.c_int,  # rows_per_bock
-            ctypes.c_int,  # cols_per_blocks
+            ctypes.c_int,  # block_size
+            ctypes.c_int,  # scalar size in bytes
            ctypes.c_int,  # row_count
-            ctypes.c_int,  # tpl_nnz
+            ctypes.c_int,  # col_count
+            ctypes.c_int,  # nnz_upper_bound
+            ctypes.POINTER(ctypes.c_int),  # tpl_nnz
            ctypes.POINTER(ctypes.c_int),  # tpl_rows
            ctypes.POINTER(ctypes.c_int),  # tpl_cols
            ctypes.c_void_p,  # tpl_values
-            ctypes.c_bool,  # prune_numerical_zeros
+            ctypes.c_uint64,  # zero_value_mask
            ctypes.c_bool,  # masked
            ctypes.POINTER(ctypes.c_int),  # bsr_offsets
            ctypes.POINTER(ctypes.c_int),  # bsr_columns
-            ctypes.c_void_p,  # bsr_values
+            ctypes.POINTER(ctypes.c_int),  # prefix sum of block count to sum for each bsr block
+            ctypes.POINTER(ctypes.c_int),  # indices to ptriplet blocks to sum for each bsr block
            ctypes.POINTER(ctypes.c_int),  # bsr_nnz
            ctypes.c_void_p,  # bsr_nnz_event
        ]

-        self.core.bsr_matrix_from_triplets_float_host.argtypes = bsr_matrix_from_triplets_argtypes
-        self.core.bsr_matrix_from_triplets_double_host.argtypes = bsr_matrix_from_triplets_argtypes
-        self.core.bsr_matrix_from_triplets_float_device.argtypes = bsr_matrix_from_triplets_argtypes
-        self.core.bsr_matrix_from_triplets_double_device.argtypes = bsr_matrix_from_triplets_argtypes
+        self.core.bsr_matrix_from_triplets_host.argtypes = bsr_matrix_from_triplets_argtypes
+        self.core.bsr_matrix_from_triplets_device.argtypes = bsr_matrix_from_triplets_argtypes

        bsr_transpose_argtypes = [
-            ctypes.c_int,  # rows_per_bock
-            ctypes.c_int,  # cols_per_blocks
            ctypes.c_int,  # row_count
            ctypes.c_int,  # col count
            ctypes.c_int,  # nnz
            ctypes.POINTER(ctypes.c_int),  # transposed_bsr_offsets
            ctypes.POINTER(ctypes.c_int),  # transposed_bsr_columns
-            ctypes.c_void_p,  # bsr_values
            ctypes.POINTER(ctypes.c_int),  # transposed_bsr_offsets
            ctypes.POINTER(ctypes.c_int),  # transposed_bsr_columns
-            ctypes.c_void_p,  # transposed_bsr_values
+            ctypes.POINTER(ctypes.c_int),  # src to dest block map
        ]
-        self.core.bsr_transpose_float_host.argtypes = bsr_transpose_argtypes
-        self.core.bsr_transpose_double_host.argtypes = bsr_transpose_argtypes
-        self.core.bsr_transpose_float_device.argtypes = bsr_transpose_argtypes
-        self.core.bsr_transpose_double_device.argtypes = bsr_transpose_argtypes
+        self.core.bsr_transpose_host.argtypes = bsr_transpose_argtypes
+        self.core.bsr_transpose_device.argtypes = bsr_transpose_argtypes

        self.core.is_cuda_enabled.argtypes = None
        self.core.is_cuda_enabled.restype = ctypes.c_int
@@ -3601,6 +3632,8 @@ class Runtime:
        self.core.cuda_device_get_name.restype = ctypes.c_char_p
        self.core.cuda_device_get_arch.argtypes = [ctypes.c_int]
        self.core.cuda_device_get_arch.restype = ctypes.c_int
+        self.core.cuda_device_get_sm_count.argtypes = [ctypes.c_int]
+        self.core.cuda_device_get_sm_count.restype = ctypes.c_int
        self.core.cuda_device_is_uva.argtypes = [ctypes.c_int]
        self.core.cuda_device_is_uva.restype = ctypes.c_int
        self.core.cuda_device_is_mempool_supported.argtypes = [ctypes.c_int]
@@ -3724,11 +3757,72 @@ class Runtime:
            ctypes.POINTER(ctypes.c_void_p),
        ]
        self.core.cuda_graph_end_capture.restype = ctypes.c_bool
+
+        self.core.cuda_graph_create_exec.argtypes = [
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.POINTER(ctypes.c_void_p),
+        ]
+        self.core.cuda_graph_create_exec.restype = ctypes.c_bool
+
+        self.core.capture_debug_dot_print.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_uint32]
+        self.core.capture_debug_dot_print.restype = ctypes.c_bool
+
        self.core.cuda_graph_launch.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
        self.core.cuda_graph_launch.restype = ctypes.c_bool
+        self.core.cuda_graph_exec_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+        self.core.cuda_graph_exec_destroy.restype = ctypes.c_bool
+
        self.core.cuda_graph_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
        self.core.cuda_graph_destroy.restype = ctypes.c_bool

+        self.core.cuda_graph_insert_if_else.argtypes = [
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.POINTER(ctypes.c_int),
+            ctypes.POINTER(ctypes.c_void_p),
+            ctypes.POINTER(ctypes.c_void_p),
+        ]
+        self.core.cuda_graph_insert_if_else.restype = ctypes.c_bool
+
+        self.core.cuda_graph_insert_while.argtypes = [
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.POINTER(ctypes.c_int),
+            ctypes.POINTER(ctypes.c_void_p),
+            ctypes.POINTER(ctypes.c_uint64),
+        ]
+        self.core.cuda_graph_insert_while.restype = ctypes.c_bool
+
+        self.core.cuda_graph_set_condition.argtypes = [
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.POINTER(ctypes.c_int),
+            ctypes.c_uint64,
+        ]
+        self.core.cuda_graph_set_condition.restype = ctypes.c_bool
+
+        self.core.cuda_graph_pause_capture.argtypes = [
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.POINTER(ctypes.c_void_p),
+        ]
+        self.core.cuda_graph_pause_capture.restype = ctypes.c_bool
+
+        self.core.cuda_graph_resume_capture.argtypes = [
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+        ]
+        self.core.cuda_graph_resume_capture.restype = ctypes.c_bool
+
+        self.core.cuda_graph_insert_child_graph.argtypes = [
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+        ]
+        self.core.cuda_graph_insert_child_graph.restype = ctypes.c_bool
+
        self.core.cuda_compile_program.argtypes = [
            ctypes.c_char_p,  # cuda_src
            ctypes.c_char_p,  # program name
@@ -3742,6 +3836,7 @@ class Runtime:
            ctypes.c_bool,  # fast_math
            ctypes.c_bool,  # fuse_fp
            ctypes.c_bool,  # lineinfo
+            ctypes.c_bool,  # compile_time_trace
            ctypes.c_char_p,  # output_path
            ctypes.c_size_t,  # num_ltoirs
            ctypes.POINTER(ctypes.c_char_p),  # ltoirs
@@ -3796,11 +3891,17 @@ class Runtime:
            ctypes.c_int,  # arch
            ctypes.c_int,  # M
            ctypes.c_int,  # N
+            ctypes.c_int,  # NRHS
+            ctypes.c_int,  # function
+            ctypes.c_int,  # side
+            ctypes.c_int,  # diag
            ctypes.c_int,  # precision
+            ctypes.c_int,  # a_arrangement
+            ctypes.c_int,  # b_arrangement
            ctypes.c_int,  # fill_mode
            ctypes.c_int,  # num threads
        ]
-        self.core.cuda_compile_fft.restype = ctypes.c_bool
+        self.core.cuda_compile_solver.restype = ctypes.c_bool

        self.core.cuda_load_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
        self.core.cuda_load_module.restype = ctypes.c_void_p
@@ -4270,7 +4371,7 @@ def is_cuda_driver_initialized() -> bool:
    return runtime.core.cuda_driver_is_initialized()


-def get_devices() -> List[Device]:
+def get_devices() -> list[Device]:
    """Returns a list of devices supported in this environment."""

    init()
@@ -4291,7 +4392,7 @@ def get_cuda_device_count() -> int:
    return len(runtime.cuda_devices)


-def get_cuda_device(ordinal: Union[int, None] = None) -> Device:
+def get_cuda_device(ordinal: int | None = None) -> Device:
    """Returns the CUDA device with the given ordinal or the current CUDA device if ordinal is None."""

    init()
@@ -4302,7 +4403,7 @@ def get_cuda_device(ordinal: Union[int, None] = None) -> Device:
    return runtime.cuda_devices[ordinal]


-def get_cuda_devices() -> List[Device]:
+def get_cuda_devices() -> list[Device]:
    """Returns a list of CUDA devices supported in this environment."""

    init()
@@ -4341,7 +4442,7 @@ def set_device(ident: Devicelike) -> None:
    device.make_current()


-def map_cuda_device(alias: str, context: Optional[ctypes.c_void_p] = None) -> Device:
+def map_cuda_device(alias: str, context: ctypes.c_void_p | None = None) -> Device:
    """Assign a device alias to a CUDA context.

    This function can be used to create a wp.Device for an external CUDA context.
@@ -4436,7 +4537,7 @@ def set_mempool_enabled(device: Devicelike, enable: bool) -> None:
        raise ValueError("Memory pools are only supported on CUDA devices")


-def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, float]) -> None:
+def set_mempool_release_threshold(device: Devicelike, threshold: int | float) -> None:
    """Set the CUDA memory pool release threshold on the device.

    This is the amount of reserved memory to hold onto before trying to release memory back to the OS.
@@ -4744,7 +4845,7 @@ def set_stream(stream: Stream, device: Devicelike = None, sync: bool = False) ->
    get_device(device).set_stream(stream, sync=sync)


-def record_event(event: Optional[Event] = None):
+def record_event(event: Event | None = None):
    """Convenience function for calling :meth:`Stream.record_event` on the current stream.

    Args:
@@ -4793,7 +4894,7 @@ def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: bo
    return runtime.core.cuda_event_elapsed_time(start_event.cuda_event, end_event.cuda_event)


-def wait_stream(other_stream: Stream, event: Optional[Event] = None):
+def wait_stream(other_stream: Stream, event: Event | None = None):
    """Convenience function for calling :meth:`Stream.wait_stream` on the current stream.

    Args:
@@ -4863,7 +4964,7 @@ class RegisteredGLBuffer:
    __fallback_warning_shown = False

    def __new__(cls, *args, **kwargs):
-        instance = super(RegisteredGLBuffer, cls).__new__(cls)
+        instance = super().__new__(cls)
        instance.resource = None
        return instance

@@ -4960,8 +5061,8 @@ class RegisteredGLBuffer:


 def zeros(
-    shape: Union[int, Tuple[int, ...], List[int], None] = None,
-    dtype=float,
+    shape: int | tuple[int, ...] | list[int] | None = None,
+    dtype: type = float,
    device: Devicelike = None,
    requires_grad: bool = False,
    pinned: bool = False,
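Note: the allocator signatures above and below only tighten their type annotations; usage is unchanged. A short sketch exercising the annotated parameters:

    import warp as wp

    a = wp.zeros(shape=(16, 16), dtype=wp.float32, device="cpu")
    b = wp.zeros_like(a, requires_grad=True)
    c = wp.full(shape=8, value=3.0, dtype=wp.float64, device="cpu")
    print(a.shape, b.requires_grad, c.numpy()[:3])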
@@ -4988,7 +5089,7 @@ def zeros(


 def zeros_like(
-    src: Array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
+    src: Array, device: Devicelike = None, requires_grad: bool | None = None, pinned: bool | None = None
 ) -> warp.array:
    """Return a zero-initialized array with the same type and dimension of another array

@@ -5010,8 +5111,8 @@ def zeros_like(


 def ones(
-    shape: Union[int, Tuple[int, ...], List[int], None] = None,
-    dtype=float,
+    shape: int | tuple[int, ...] | list[int] | None = None,
+    dtype: type = float,
    device: Devicelike = None,
    requires_grad: bool = False,
    pinned: bool = False,
@@ -5034,7 +5135,7 @@ def ones(


 def ones_like(
-    src: Array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
+    src: Array, device: Devicelike = None, requires_grad: bool | None = None, pinned: bool | None = None
 ) -> warp.array:
    """Return a one-initialized array with the same type and dimension of another array

@@ -5052,7 +5153,7 @@ def ones_like(


 def full(
-    shape: Union[int, Tuple[int, ...], List[int], None] = None,
+    shape: int | tuple[int, ...] | list[int] | None = None,
    value=0,
    dtype=Any,
    device: Devicelike = None,
@@ -5121,8 +5222,8 @@ def full_like(
    src: Array,
    value: Any,
    device: Devicelike = None,
-    requires_grad: Optional[bool] = None,
-    pinned: Optional[bool] = None,
+    requires_grad: bool | None = None,
+    pinned: bool | None = None,
 ) -> warp.array:
    """Return an array with all elements initialized to the given value with the same type and dimension of another array

@@ -5145,7 +5246,7 @@ def full_like(


 def clone(
-    src: warp.array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
+    src: warp.array, device: Devicelike = None, requires_grad: bool | None = None, pinned: bool | None = None
 ) -> warp.array:
    """Clone an existing array, allocates a copy of the src memory

@@ -5167,7 +5268,7 @@ def clone(


 def empty(
-    shape: Union[int, Tuple[int, ...], List[int], None] = None,
+    shape: int | tuple[int, ...] | list[int] | None = None,
    dtype=float,
    device: Devicelike = None,
    requires_grad: bool = False,
@@ -5200,7 +5301,7 @@ def empty(


 def empty_like(
-    src: Array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
+    src: Array, device: Devicelike = None, requires_grad: bool | None = None, pinned: bool | None = None
 ) -> warp.array:
    """Return an uninitialized array with the same type and dimension of another array

@@ -5235,9 +5336,9 @@ def empty_like(

 def from_numpy(
    arr: np.ndarray,
-    dtype: Optional[type] = None,
-    shape: Optional[Sequence[int]] = None,
-    device: Optional[Devicelike] = None,
+    dtype: type | None = None,
+    shape: Sequence[int] | None = None,
+    device: Devicelike | None = None,
    requires_grad: bool = False,
 ) -> warp.array:
    """Returns a Warp array created from a NumPy array.
@@ -5255,7 +5356,7 @@ def from_numpy(
    if dtype is None:
        base_type = warp.types.np_dtype_to_warp_type.get(arr.dtype)
        if base_type is None:
-            raise RuntimeError("Unsupported NumPy data type '{}'.".format(arr.dtype))
+            raise RuntimeError(f"Unsupported NumPy data type '{arr.dtype}'.")

        dim_count = len(arr.shape)
        if dim_count == 2:
@@ -5274,7 +5375,7 @@ def from_numpy(
    )


-def event_from_ipc_handle(handle, device: "Devicelike" = None) -> Event:
+def event_from_ipc_handle(handle, device: Devicelike = None) -> Event:
    """Create an event from an IPC handle.

    Args:
@@ -5443,10 +5544,10 @@ class Launch:
        self,
        kernel,
        device: Device,
-        hooks: Optional[KernelHooks] = None,
-        params: Optional[Sequence[Any]] = None,
-        params_addr: Optional[Sequence[ctypes.c_void_p]] = None,
-        bounds: Optional[launch_bounds_t] = None,
+        hooks: KernelHooks | None = None,
+        params: Sequence[Any] | None = None,
+        params_addr: Sequence[ctypes.c_void_p] | None = None,
+        bounds: launch_bounds_t | None = None,
        max_blocks: int = 0,
        block_dim: int = 256,
        adjoint: bool = False,
@@ -5516,7 +5617,7 @@ class Launch:
5516
5617
  self.adjoint: bool = adjoint
5517
5618
  """Whether to run the adjoint kernel instead of the forward kernel."""
5518
5619
 
5519
- def set_dim(self, dim: Union[int, List[int], Tuple[int, ...]]):
5620
+ def set_dim(self, dim: int | list[int] | tuple[int, ...]):
5520
5621
  """Set the launch dimensions.
5521
5622
 
5522
5623
  Args:
@@ -5554,7 +5655,7 @@ class Launch:
5554
5655
  if self.params_addr:
5555
5656
  self.params_addr[params_index] = ctypes.c_void_p(ctypes.addressof(carg))
5556
5657
 
5557
- def set_param_at_index_from_ctype(self, index: int, value: Union[ctypes.Structure, int, float]):
5658
+ def set_param_at_index_from_ctype(self, index: int, value: ctypes.Structure | int | float):
5558
5659
  """Set a kernel parameter at an index without any type conversion.
5559
5660
 
5560
5661
  Args:
@@ -5617,7 +5718,7 @@ class Launch:
5617
5718
  for i, v in enumerate(values):
5618
5719
  self.set_param_at_index_from_ctype(i, v)
5619
5720
 
5620
- def launch(self, stream: Optional[Stream] = None) -> None:
5721
+ def launch(self, stream: Stream | None = None) -> None:
5621
5722
  """Launch the kernel.
5622
5723
 
5623
5724
  Args:
@@ -5634,7 +5735,7 @@ class Launch:
5634
5735
 
5635
5736
  # If the stream is capturing, we retain the CUDA module so that it doesn't get unloaded
5636
5737
  # before the captured graph is released.
5637
- if runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
5738
+ if len(runtime.captures) > 0 and runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
5638
5739
  capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
5639
5740
  graph = runtime.captures.get(capture_id)
5640
5741
  if graph is not None:
@@ -5666,13 +5767,13 @@ class Launch:
5666
5767
 
5667
5768
  def launch(
5668
5769
  kernel,
5669
- dim: Union[int, Sequence[int]],
5770
+ dim: int | Sequence[int],
5670
5771
  inputs: Sequence = [],
5671
5772
  outputs: Sequence = [],
5672
5773
  adj_inputs: Sequence = [],
5673
5774
  adj_outputs: Sequence = [],
5674
5775
  device: Devicelike = None,
5675
- stream: Optional[Stream] = None,
5776
+ stream: Stream | None = None,
5676
5777
  adjoint: bool = False,
5677
5778
  record_tape: bool = True,
5678
5779
  record_cmd: bool = False,
@@ -5824,7 +5925,7 @@ def launch(
5824
5925
 
5825
5926
  # If the stream is capturing, we retain the CUDA module so that it doesn't get unloaded
5826
5927
  # before the captured graph is released.
5827
- if runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
5928
+ if len(runtime.captures) > 0 and runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
5828
5929
  capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
5829
5930
  graph = runtime.captures.get(capture_id)
5830
5931
  if graph is not None:
@@ -5968,7 +6069,7 @@ def launch_tiled(*args, **kwargs):
5968
6069
  raise RuntimeError("wp.launch_tiled() requires a grid with fewer than 4 dimensions")
5969
6070
 
5970
6071
  # add trailing dimension
5971
- kwargs["dim"] = dim + [kwargs["block_dim"]]
6072
+ kwargs["dim"] = [*dim, kwargs["block_dim"]]
5972
6073
 
5973
6074
  # forward to original launch method
5974
6075
  return launch(*args, **kwargs)
@@ -6016,7 +6117,7 @@ def synchronize_device(device: Devicelike = None):
6016
6117
  runtime.core.cuda_context_synchronize(device.context)
6017
6118
 
6018
6119
 
6019
- def synchronize_stream(stream_or_device: Union[Stream, Devicelike, None] = None):
6120
+ def synchronize_stream(stream_or_device: Stream | Devicelike | None = None):
6020
6121
  """Synchronize the calling CPU thread with any outstanding CUDA work on the specified stream.
6021
6122
 
6022
6123
  This function allows the host application code to ensure that all kernel launches
@@ -6046,7 +6147,7 @@ def synchronize_event(event: Event):
6046
6147
  runtime.core.cuda_event_synchronize(event.cuda_event)
6047
6148
 
6048
6149
 
6049
- def force_load(device: Union[Device, str, List[Device], List[str]] = None, modules: List[Module] = None):
6150
+ def force_load(device: Device | str | list[Device] | list[str] | None = None, modules: list[Module] | None = None):
6050
6151
  """Force user-defined kernels to be compiled and loaded
6051
6152
 
6052
6153
  Args:
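A short sketch of how these preloading helpers are commonly called under the widened annotations (`force_load()` above, `load_module()` in the next hunk); `my_kernels` is a hypothetical user module, not part of this package:

import warp as wp

# compile and load all known modules up front on selected devices,
# e.g. before graph capture or timing-sensitive code
wp.force_load(device=["cuda:0", "cpu"])

# or load a single module; a warp Module, a Python module object, or a name is accepted
import my_kernels  # hypothetical module containing @wp.kernel definitions
wp.load_module(my_kernels, device="cuda:0")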
@@ -6078,7 +6179,7 @@ def force_load(device: Union[Device, str, List[Device], List[str]] = None, modul
6078
6179
 
6079
6180
 
6080
6181
  def load_module(
6081
- module: Union[Module, types.ModuleType, str] = None, device: Union[Device, str] = None, recursive: bool = False
6182
+ module: Module | types.ModuleType | str | None = None, device: Device | str | None = None, recursive: bool = False
6082
6183
  ):
6083
6184
  """Force user-defined module to be compiled and loaded
6084
6185
 
@@ -6120,7 +6221,7 @@ def load_module(
6120
6221
  force_load(device=device, modules=modules)
6121
6222
 
6122
6223
 
6123
- def set_module_options(options: Dict[str, Any], module: Optional[Any] = None):
6224
+ def set_module_options(options: dict[str, Any], module: Any = None):
6124
6225
  """Set options for the current module.
6125
6226
 
6126
6227
  Options can be used to control runtime compilation and code-generation
@@ -6144,7 +6245,7 @@ def set_module_options(options: Dict[str, Any], module: Optional[Any] = None):
6144
6245
  get_module(m.__name__).mark_modified()
6145
6246
 
6146
6247
 
6147
- def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
6248
+ def get_module_options(module: Any = None) -> dict[str, Any]:
6148
6249
  """Returns a list of options for the current module."""
6149
6250
  if module is None:
6150
6251
  m = inspect.getmodule(inspect.stack()[1][0])
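For context, module options are set and read back as below; the option keys shown (`enable_backward`, `fast_math`) are standard Warp module options used here purely as an example:

import warp as wp

# applies to kernels defined in the calling module at compile time
wp.set_module_options({"enable_backward": False, "fast_math": True})

opts = wp.get_module_options()
print(opts["enable_backward"])  # -> False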
@@ -6156,8 +6257,8 @@ def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
6156
6257
 
6157
6258
  def capture_begin(
6158
6259
  device: Devicelike = None,
6159
- stream: Optional[Stream] = None,
6160
- force_module_load: Optional[bool] = None,
6260
+ stream: Stream | None = None,
6261
+ force_module_load: bool | None = None,
6161
6262
  external: bool = False,
6162
6263
  ):
6163
6264
  """Begin capture of a CUDA graph
@@ -6226,7 +6327,7 @@ def capture_begin(
6226
6327
  runtime.captures[capture_id] = graph
6227
6328
 
6228
6329
 
6229
- def capture_end(device: Devicelike = None, stream: Optional[Stream] = None) -> Graph:
6330
+ def capture_end(device: Devicelike = None, stream: Stream | None = None) -> Graph:
6230
6331
  """End the capture of a CUDA graph.
6231
6332
 
6232
6333
  Args:
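A minimal sketch of the explicit capture flow whose signatures change above, assuming a CUDA device; with this release the executable graph is instantiated lazily on first launch (see the `capture_end()` and `capture_launch()` hunks below). The kernel is illustrative:

import warp as wp

@wp.kernel
def inc(a: wp.array(dtype=float)):
    i = wp.tid()
    a[i] = a[i] + 1.0

device = "cuda:0"
a = wp.zeros(1024, dtype=float, device=device)

wp.capture_begin(device)   # force_module_load / external left at their defaults
wp.launch(inc, dim=a.size, inputs=[a], device=device)
graph = wp.capture_end(device)

wp.capture_launch(graph)   # the executable graph is created here, on first use
wp.synchronize_device(device)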
@@ -6255,20 +6356,324 @@ def capture_end(device: Devicelike = None, stream: Optional[Stream] = None) -> G
6255
6356
  del runtime.captures[graph.capture_id]
6256
6357
 
6257
6358
  # get the graph executable
6258
- graph_exec = ctypes.c_void_p()
6259
- result = runtime.core.cuda_graph_end_capture(device.context, stream.cuda_stream, ctypes.byref(graph_exec))
6359
+ g = ctypes.c_void_p()
6360
+ result = runtime.core.cuda_graph_end_capture(device.context, stream.cuda_stream, ctypes.byref(g))
6260
6361
 
6261
6362
  if not result:
6262
6363
  # A concrete error should've already been reported, so we don't need to go into details here
6263
6364
  raise RuntimeError(f"CUDA graph capture failed. {runtime.get_error_string()}")
6264
6365
 
6265
6366
  # set the graph executable
6266
- graph.graph_exec = graph_exec
6367
+ graph.graph = g
6368
+ graph.graph_exec = None # Lazy initialization
6267
6369
 
6268
6370
  return graph
6269
6371
 
6270
6372
 
6271
- def capture_launch(graph: Graph, stream: Optional[Stream] = None):
6373
+ def capture_debug_dot_print(graph: Graph, path: str, verbose: bool = False):
6374
+ """Export a CUDA graph to a DOT file for visualization
6375
+
6376
+ Args:
6377
+ graph: A :class:`Graph` as returned by :func:`~warp.capture_end()`
6378
+ path: Path to save the DOT file
6379
+ verbose: Whether to include additional debug information in the output
6380
+ """
6381
+ if not runtime.core.capture_debug_dot_print(graph.graph, path.encode(), 0 if verbose else 1):
6382
+ raise RuntimeError(f"Graph debug dot print error: {runtime.get_error_string()}")
6383
+
6384
+
6385
+ def assert_conditional_graph_support():
6386
+ if runtime is None:
6387
+ init()
6388
+
6389
+ if runtime.toolkit_version < (12, 4):
6390
+ raise RuntimeError("Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes")
6391
+
6392
+ if runtime.driver_version < (12, 4):
6393
+ raise RuntimeError("Conditional graph nodes require CUDA driver 12.4+")
6394
+
6395
+
6396
+ def capture_pause(device: Devicelike = None, stream: Stream | None = None) -> ctypes.c_void_p:
6397
+ if stream is not None:
6398
+ device = stream.device
6399
+ else:
6400
+ device = runtime.get_device(device)
6401
+ if not device.is_cuda:
6402
+ raise RuntimeError("Must be a CUDA device")
6403
+ stream = device.stream
6404
+
6405
+ graph = ctypes.c_void_p()
6406
+ if not runtime.core.cuda_graph_pause_capture(device.context, stream.cuda_stream, ctypes.byref(graph)):
6407
+ raise RuntimeError(runtime.get_error_string())
6408
+
6409
+ return graph
6410
+
6411
+
6412
+ def capture_resume(graph: ctypes.c_void_p, device: Devicelike = None, stream: Stream | None = None):
6413
+ if stream is not None:
6414
+ device = stream.device
6415
+ else:
6416
+ device = runtime.get_device(device)
6417
+ if not device.is_cuda:
6418
+ raise RuntimeError("Must be a CUDA device")
6419
+ stream = device.stream
6420
+
6421
+ if not runtime.core.cuda_graph_resume_capture(device.context, stream.cuda_stream, graph):
6422
+ raise RuntimeError(runtime.get_error_string())
6423
+
6424
+
6425
+ # reusable pinned readback buffer for conditions
6426
+ condition_host = None
6427
+
6428
+
6429
+ def capture_if(
6430
+ condition: warp.array(dtype=int),
6431
+ on_true: Callable | Graph | None = None,
6432
+ on_false: Callable | Graph | None = None,
6433
+ stream: Stream = None,
6434
+ **kwargs,
6435
+ ):
6436
+ """Create a dynamic branch based on a condition.
6437
+
6438
+ The condition value is retrieved from the first element of the ``condition`` array.
6439
+
6440
+ This function is particularly useful with CUDA graphs, but can be used without graph capture as well.
6441
+ CUDA 12.4+ is required to take advantage of conditional graph nodes for dynamic control flow.
6442
+
6443
+ Args:
6444
+ condition: Warp array holding the condition value.
6445
+ on_true: A callback function or :class:`Graph` to execute if the condition is True.
6446
+ on_false: A callback function or :class:`Graph` to execute if the condition is False.
6447
+ stream: The CUDA stream where the condition was written. If None, use the current stream on the device where ``condition`` resides.
6448
+
6449
+ Any additional keyword arguments are forwarded to the callback functions.
6450
+ """
6451
+
6452
+ # if neither the IF branch nor the ELSE branch is specified, it's a no-op
6453
+ if on_true is None and on_false is None:
6454
+ return
6455
+
6456
+ # check condition data type
6457
+ if not isinstance(condition, warp.array) or condition.dtype is not warp.int32:
6458
+ raise TypeError("Condition must be a Warp array of int32 with a single element")
6459
+
6460
+ device = condition.device
6461
+
6462
+ # determine the stream and whether a graph capture is active
6463
+ if device.is_cuda:
6464
+ if stream is None:
6465
+ stream = device.stream
6466
+ graph = device.captures.get(stream)
6467
+ else:
6468
+ graph = None
6469
+
6470
+ if graph is None:
6471
+ # if no graph is active, just execute the correct branch directly
6472
+ if device.is_cuda:
6473
+ # use a pinned buffer for condition readback to host
6474
+ global condition_host
6475
+ if condition_host is None:
6476
+ condition_host = warp.empty(1, dtype=int, device="cpu", pinned=True)
6477
+ warp.copy(condition_host, condition, stream=stream)
6478
+ warp.synchronize_stream(stream)
6479
+ condition_value = bool(ctypes.cast(condition_host.ptr, ctypes.POINTER(ctypes.c_int32)).contents)
6480
+ else:
6481
+ condition_value = bool(ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)).contents)
6482
+
6483
+ if condition_value:
6484
+ if on_true is not None:
6485
+ if isinstance(on_true, Callable):
6486
+ on_true(**kwargs)
6487
+ elif isinstance(on_true, Graph):
6488
+ capture_launch(on_true, stream=stream)
6489
+ else:
6490
+ raise TypeError("on_true must be a Callable or a Graph")
6491
+ else:
6492
+ if on_false is not None:
6493
+ if isinstance(on_false, Callable):
6494
+ on_false(**kwargs)
6495
+ elif isinstance(on_false, Graph):
6496
+ capture_launch(on_false, stream=stream)
6497
+ else:
6498
+ raise TypeError("on_false must be a Callable or a Graph")
6499
+
6500
+ return
6501
+
6502
+ graph.has_conditional = True
6503
+
6504
+ # ensure conditional graph nodes are supported
6505
+ assert_conditional_graph_support()
6506
+
6507
+ # insert conditional node
6508
+ graph_on_true = ctypes.c_void_p()
6509
+ graph_on_false = ctypes.c_void_p()
6510
+ if not runtime.core.cuda_graph_insert_if_else(
6511
+ device.context,
6512
+ stream.cuda_stream,
6513
+ ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
6514
+ None if on_true is None else ctypes.byref(graph_on_true),
6515
+ None if on_false is None else ctypes.byref(graph_on_false),
6516
+ ):
6517
+ raise RuntimeError(runtime.get_error_string())
6518
+
6519
+ # pause capturing parent graph
6520
+ main_graph = capture_pause(stream=stream)
6521
+
6522
+ # capture if-graph
6523
+ if on_true is not None:
6524
+ capture_resume(graph_on_true, stream=stream)
6525
+ if isinstance(on_true, Callable):
6526
+ on_true(**kwargs)
6527
+ elif isinstance(on_true, Graph):
6528
+ if on_true.has_conditional:
6529
+ raise RuntimeError(
6530
+ "The on_true graph contains conditional nodes, which are not allowed in child graphs"
6531
+ )
6532
+ if not runtime.core.cuda_graph_insert_child_graph(
6533
+ device.context,
6534
+ stream.cuda_stream,
6535
+ on_true.graph,
6536
+ ):
6537
+ raise RuntimeError(runtime.get_error_string())
6538
+ else:
6539
+ raise TypeError("on_true must be a Callable or a Graph")
6540
+ capture_pause(stream=stream)
6541
+
6542
+ # capture else-graph
6543
+ if on_false is not None:
6544
+ capture_resume(graph_on_false, stream=stream)
6545
+ if isinstance(on_false, Callable):
6546
+ on_false(**kwargs)
6547
+ elif isinstance(on_false, Graph):
6548
+ if on_false.has_conditional:
6549
+ raise RuntimeError(
6550
+ "The on_false graph contains conditional nodes, which are not allowed in child graphs"
6551
+ )
6552
+ if not runtime.core.cuda_graph_insert_child_graph(
6553
+ device.context,
6554
+ stream.cuda_stream,
6555
+ on_false.graph,
6556
+ ):
6557
+ raise RuntimeError(runtime.get_error_string())
6558
+ else:
6559
+ raise TypeError("on_false must be a Callable or a Graph")
6560
+ capture_pause(stream=stream)
6561
+
6562
+ # resume capturing parent graph
6563
+ capture_resume(main_graph, stream=stream)
6564
+
6565
+
6566
+ def capture_while(condition: warp.array(dtype=int), while_body: Callable | Graph, stream: Stream = None, **kwargs):
6567
+ """Create a dynamic loop based on a condition.
6568
+
6569
+ The condition value is retrieved from the first element of the ``condition`` array.
6570
+
6571
+ The ``while_body`` callback is responsible for updating the condition value so the loop can terminate.
6572
+
6573
+ This function is particularly useful with CUDA graphs, but can be used without graph capture as well.
6574
+ CUDA 12.4+ is required to take advantage of conditional graph nodes for dynamic control flow.
6575
+
6576
+ Args:
6577
+ condition: Warp array holding the condition value.
6578
+ while_body: A callback function or :class:`Graph` to execute while the loop condition is True.
6579
+ stream: The CUDA stream where the condition was written. If None, use the current stream on the device where ``condition`` resides.
6580
+
6581
+ Any additional keyword arguments are forwarded to the callback function.
6582
+ """
6583
+
6584
+ # check condition data type
6585
+ if not isinstance(condition, warp.array) or condition.dtype is not warp.int32:
6586
+ raise TypeError("Condition must be a Warp array of int32 with a single element")
6587
+
6588
+ device = condition.device
6589
+
6590
+ # determine the stream and whether a graph capture is active
6591
+ if device.is_cuda:
6592
+ if stream is None:
6593
+ stream = device.stream
6594
+ graph = device.captures.get(stream)
6595
+ else:
6596
+ graph = None
6597
+
6598
+ if graph is None:
6599
+ # since no graph is active, just execute the kernels directly
6600
+ while True:
6601
+ if device.is_cuda:
6602
+ # use a pinned buffer for condition readback to host
6603
+ global condition_host
6604
+ if condition_host is None:
6605
+ condition_host = warp.empty(1, dtype=int, device="cpu", pinned=True)
6606
+ warp.copy(condition_host, condition, stream=stream)
6607
+ warp.synchronize_stream(stream)
6608
+ condition_value = bool(ctypes.cast(condition_host.ptr, ctypes.POINTER(ctypes.c_int32)).contents)
6609
+ else:
6610
+ condition_value = bool(ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)).contents)
6611
+
6612
+ if condition_value:
6613
+ if isinstance(while_body, Callable):
6614
+ while_body(**kwargs)
6615
+ elif isinstance(while_body, Graph):
6616
+ capture_launch(while_body, stream=stream)
6617
+ else:
6618
+ raise TypeError("while_body must be a callable or a graph")
6619
+
6620
+ else:
6621
+ break
6622
+
6623
+ return
6624
+
6625
+ graph.has_conditional = True
6626
+
6627
+ # ensure conditional graph nodes are supported
6628
+ assert_conditional_graph_support()
6629
+
6630
+ # insert conditional while-node
6631
+ body_graph = ctypes.c_void_p()
6632
+ cond_handle = ctypes.c_uint64()
6633
+ if not runtime.core.cuda_graph_insert_while(
6634
+ device.context,
6635
+ stream.cuda_stream,
6636
+ ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
6637
+ ctypes.byref(body_graph),
6638
+ ctypes.byref(cond_handle),
6639
+ ):
6640
+ raise RuntimeError(runtime.get_error_string())
6641
+
6642
+ # pause capturing parent graph and start capturing child graph
6643
+ main_graph = capture_pause(stream=stream)
6644
+ capture_resume(body_graph, stream=stream)
6645
+
6646
+ # capture while-body
6647
+ if isinstance(while_body, Callable):
6648
+ while_body(**kwargs)
6649
+ elif isinstance(while_body, Graph):
6650
+ if while_body.has_conditional:
6651
+ raise RuntimeError("The body graph contains conditional nodes, which are not allowed in child graphs")
6652
+
6653
+ if not runtime.core.cuda_graph_insert_child_graph(
6654
+ device.context,
6655
+ stream.cuda_stream,
6656
+ while_body.graph,
6657
+ ):
6658
+ raise RuntimeError(runtime.get_error_string())
6659
+ else:
6660
+ raise RuntimeError(runtime.get_error_string())
6661
+
6662
+ # update condition
6663
+ if not runtime.core.cuda_graph_set_condition(
6664
+ device.context,
6665
+ stream.cuda_stream,
6666
+ ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
6667
+ cond_handle,
6668
+ ):
6669
+ raise RuntimeError(runtime.get_error_string())
6670
+
6671
+ # stop capturing child graph and resume capturing parent graph
6672
+ capture_pause(stream=stream)
6673
+ capture_resume(main_graph, stream=stream)
6674
+
6675
+
6676
+ def capture_launch(graph: Graph, stream: Stream | None = None):
6272
6677
  """Launch a previously captured CUDA graph
6273
6678
 
6274
6679
  Args:
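The new `capture_if()` helper introduced above also works outside of graph capture, in which case it reads the condition back and dispatches on the host; a small sketch, assuming a CUDA device is available (for CPU arrays the value is read in place):

import warp as wp

cond = wp.array([1], dtype=wp.int32, device="cuda:0")

def on_hit():
    print("condition was non-zero")

def on_miss():
    print("condition was zero")

# no capture is active here, so the matching callback simply runs eagerly
wp.capture_if(cond, on_true=on_hit, on_false=on_miss)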
@@ -6284,6 +6689,13 @@ def capture_launch(graph: Graph, stream: Optional[Stream] = None):
6284
6689
  device = graph.device
6285
6690
  stream = device.stream
6286
6691
 
6692
+ if graph.graph_exec is None:
6693
+ g = ctypes.c_void_p()
6694
+ result = runtime.core.cuda_graph_create_exec(graph.device.context, graph.graph, ctypes.byref(g))
6695
+ if not result:
6696
+ raise RuntimeError(f"Graph creation error: {runtime.get_error_string()}")
6697
+ graph.graph_exec = g
6698
+
6287
6699
  if not runtime.core.cuda_graph_launch(graph.graph_exec, stream.cuda_stream):
6288
6700
  raise RuntimeError(f"Graph launch error: {runtime.get_error_string()}")
6289
6701
 
@@ -6294,7 +6706,7 @@ def copy(
6294
6706
  dest_offset: int = 0,
6295
6707
  src_offset: int = 0,
6296
6708
  count: int = 0,
6297
- stream: Optional[Stream] = None,
6709
+ stream: Stream | None = None,
6298
6710
  ):
6299
6711
  """Copy array contents from `src` to `dest`.
6300
6712
 
@@ -6431,11 +6843,8 @@ def copy(
6431
6843
 
6432
6844
  # can't copy to/from fabric arrays of arrays, because they are jagged arrays of arbitrary lengths
6433
6845
  # TODO?
6434
- if (
6435
- isinstance(src, (warp.fabricarray, warp.indexedfabricarray))
6436
- and src.ndim > 1
6437
- or isinstance(dest, (warp.fabricarray, warp.indexedfabricarray))
6438
- and dest.ndim > 1
6846
+ if (isinstance(src, (warp.fabricarray, warp.indexedfabricarray)) and src.ndim > 1) or (
6847
+ isinstance(dest, (warp.fabricarray, warp.indexedfabricarray)) and dest.ndim > 1
6439
6848
  ):
6440
6849
  raise RuntimeError("Copying to/from Fabric arrays of arrays is not supported")
6441
6850
 
@@ -6503,7 +6912,7 @@ def type_str(t):
6503
6912
  return "Callable"
6504
6913
  elif isinstance(t, int):
6505
6914
  return str(t)
6506
- elif isinstance(t, List):
6915
+ elif isinstance(t, (List, tuple)):
6507
6916
  return "Tuple[" + ", ".join(map(type_str, t)) + "]"
6508
6917
  elif isinstance(t, warp.array):
6509
6918
  return f"Array[{type_str(t.dtype)}]"
@@ -6536,12 +6945,16 @@ def type_str(t):
6536
6945
 
6537
6946
  raise TypeError("Invalid vector or matrix dimensions")
6538
6947
  elif get_origin(t) in (list, tuple):
6539
- args_repr = ", ".join(type_str(x) for x in get_args(t))
6540
- return f"{t._name}[{args_repr}]"
6948
+ args = get_args(t)
6949
+ if args:
6950
+ args_repr = ", ".join(type_str(x) for x in get_args(t))
6951
+ return f"{t._name}[{args_repr}]"
6952
+ else:
6953
+ return f"{t._name}"
6541
6954
  elif t is Ellipsis:
6542
6955
  return "..."
6543
6956
  elif warp.types.is_tile(t):
6544
- return "Tile"
6957
+ return f"Tile[{type_str(t.dtype)},{type_str(t.shape)}]"
6545
6958
 
6546
6959
  return t.__name__
6547
6960
 
@@ -6568,14 +6981,14 @@ def resolve_exported_function_sig(f):
6568
6981
  # so we can generate the return type for overloaded functions
6569
6982
  return_type = f.value_func(func_args, None)
6570
6983
 
6984
+ if return_type is None or (isinstance(return_type, tuple) and len(return_type) > 1):
6985
+ return (func_args, return_type)
6986
+
6571
6987
  try:
6572
- return_type_str = ctype_ret_str(return_type)
6988
+ ctype_ret_str(return_type)
6573
6989
  except Exception:
6574
6990
  return None
6575
6991
 
6576
- if return_type_str.startswith("Tuple"):
6577
- return None
6578
-
6579
6992
  return (func_args, return_type)
6580
6993
 
6581
6994
 
@@ -6716,13 +7129,18 @@ def export_functions_rst(file): # pragma: no cover
6716
7129
  print("---------------", file=file)
6717
7130
 
6718
7131
  for f, is_exported in g:
7132
+ if not isinstance(f, Function) and callable(f):
7133
+ # f is a plain Python function
7134
+ print(f".. autofunction:: {f.__module__}.{f.__name__}", file=file)
7135
+ continue
6719
7136
  if f.func:
6720
7137
  # f is a Warp function written in Python, we can use autofunction
6721
7138
  print(f".. autofunction:: {f.func.__module__}.{f.key}", file=file)
6722
7139
  continue
6723
7140
  for f_prefix, query_type in query_types:
6724
7141
  if f.key.startswith(f_prefix) and query_type not in written_query_types:
6725
- print(f".. autoclass:: {query_type}", file=file)
7142
+ print(f".. autoclass:: warp.{query_type}", file=file)
7143
+ print(" :exclude-members: Var, vars", file=file)
6726
7144
  written_query_types.add(query_type)
6727
7145
  break
6728
7146
 
@@ -6775,6 +7193,7 @@ def export_stubs(file): # pragma: no cover
6775
7193
  print('Rows = TypeVar("Rows", bound=int)', file=file)
6776
7194
  print('Cols = TypeVar("Cols", bound=int)', file=file)
6777
7195
  print('DType = TypeVar("DType")', file=file)
7196
+ print('Shape = TypeVar("Shape")', file=file)
6778
7197
 
6779
7198
  print("Vector = Generic[Length, Scalar]", file=file)
6780
7199
  print("Matrix = Generic[Rows, Cols, Scalar]", file=file)
@@ -6783,6 +7202,7 @@ def export_stubs(file): # pragma: no cover
6783
7202
  print("Array = Generic[DType]", file=file)
6784
7203
  print("FabricArray = Generic[DType]", file=file)
6785
7204
  print("IndexedFabricArray = Generic[DType]", file=file)
7205
+ print("Tile = Generic[DType, Shape]", file=file)
6786
7206
 
6787
7207
  # prepend __init__.py
6788
7208
  with open(os.path.join(os.path.dirname(file.name), "__init__.py")) as header_file:
@@ -6817,7 +7237,7 @@ def export_stubs(file): # pragma: no cover
6817
7237
  if hasattr(g, "overloads"):
6818
7238
  for f in g.overloads:
6819
7239
  add_stub(f)
6820
- else:
7240
+ elif isinstance(g, Function):
6821
7241
  add_stub(g)
6822
7242
 
6823
7243
 
@@ -6848,16 +7268,30 @@ def export_builtins(file: io.TextIOBase): # pragma: no cover
6848
7268
  args = ", ".join(f"{ctype_arg_str(v)} {k}" for k, v in func_args.items())
6849
7269
  params = ", ".join(func_args.keys())
6850
7270
 
6851
- return_str = ctype_ret_str(return_type)
6852
-
6853
- if args == "":
6854
- file.write(f"WP_API void {f.mangled_name}({return_str}* ret) {{ *ret = wp::{f.key}({params}); }}\n")
6855
- elif return_type is None:
7271
+ if return_type is None:
7272
+ # void function
6856
7273
  file.write(f"WP_API void {f.mangled_name}({args}) {{ wp::{f.key}({params}); }}\n")
7274
+ elif isinstance(return_type, tuple) and len(return_type) > 1:
7275
+ # multiple return value function using output parameters
7276
+ outputs = tuple(f"{ctype_ret_str(x)}& ret_{i}" for i, x in enumerate(return_type))
7277
+ output_params = ", ".join(f"ret_{i}" for i in range(len(outputs)))
7278
+ if args:
7279
+ file.write(
7280
+ f"WP_API void {f.mangled_name}({args}, {', '.join(outputs)}) {{ wp::{f.key}({params}, {output_params}); }}\n"
7281
+ )
7282
+ else:
7283
+ file.write(
7284
+ f"WP_API void {f.mangled_name}({', '.join(outputs)}) {{ wp::{f.key}({params}, {output_params}); }}\n"
7285
+ )
6857
7286
  else:
6858
- file.write(
6859
- f"WP_API void {f.mangled_name}({args}, {return_str}* ret) {{ *ret = wp::{f.key}({params}); }}\n"
6860
- )
7287
+ # single return value function
7288
+ return_str = ctype_ret_str(return_type)
7289
+ if args:
7290
+ file.write(
7291
+ f"WP_API void {f.mangled_name}({args}, {return_str}* ret) {{ *ret = wp::{f.key}({params}); }}\n"
7292
+ )
7293
+ else:
7294
+ file.write(f"WP_API void {f.mangled_name}({return_str}* ret) {{ *ret = wp::{f.key}({params}); }}\n")
6861
7295
 
6862
7296
  file.write('\n} // extern "C"\n\n')
6863
7297
  file.write("} // namespace wp\n")