warp-lang 1.7.2rc1__py3-none-macosx_10_13_universal2.whl → 1.8.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (192) hide show
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +130 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +272 -104
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +770 -238
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_callable.py +34 -4
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/interop/example_jax_kernel.py +27 -1
  37. warp/examples/optim/example_drone.py +1 -1
  38. warp/examples/sim/example_cloth.py +1 -1
  39. warp/examples/sim/example_cloth_self_contact.py +48 -54
  40. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  41. warp/examples/tile/example_tile_cholesky.py +2 -1
  42. warp/examples/tile/example_tile_convolution.py +1 -1
  43. warp/examples/tile/example_tile_filtering.py +1 -1
  44. warp/examples/tile/example_tile_matmul.py +1 -1
  45. warp/examples/tile/example_tile_mlp.py +2 -0
  46. warp/fabric.py +7 -7
  47. warp/fem/__init__.py +5 -0
  48. warp/fem/adaptivity.py +1 -1
  49. warp/fem/cache.py +152 -63
  50. warp/fem/dirichlet.py +2 -2
  51. warp/fem/domain.py +136 -6
  52. warp/fem/field/field.py +141 -99
  53. warp/fem/field/nodal_field.py +85 -39
  54. warp/fem/field/virtual.py +99 -52
  55. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  56. warp/fem/geometry/closest_point.py +13 -0
  57. warp/fem/geometry/deformed_geometry.py +102 -40
  58. warp/fem/geometry/element.py +56 -2
  59. warp/fem/geometry/geometry.py +323 -22
  60. warp/fem/geometry/grid_2d.py +157 -62
  61. warp/fem/geometry/grid_3d.py +116 -20
  62. warp/fem/geometry/hexmesh.py +86 -20
  63. warp/fem/geometry/nanogrid.py +166 -86
  64. warp/fem/geometry/partition.py +59 -25
  65. warp/fem/geometry/quadmesh.py +86 -135
  66. warp/fem/geometry/tetmesh.py +47 -119
  67. warp/fem/geometry/trimesh.py +77 -270
  68. warp/fem/integrate.py +181 -95
  69. warp/fem/linalg.py +25 -58
  70. warp/fem/operator.py +124 -27
  71. warp/fem/quadrature/pic_quadrature.py +36 -14
  72. warp/fem/quadrature/quadrature.py +40 -16
  73. warp/fem/space/__init__.py +1 -1
  74. warp/fem/space/basis_function_space.py +66 -46
  75. warp/fem/space/basis_space.py +17 -4
  76. warp/fem/space/dof_mapper.py +1 -1
  77. warp/fem/space/function_space.py +2 -2
  78. warp/fem/space/grid_2d_function_space.py +4 -1
  79. warp/fem/space/hexmesh_function_space.py +4 -2
  80. warp/fem/space/nanogrid_function_space.py +3 -1
  81. warp/fem/space/partition.py +11 -2
  82. warp/fem/space/quadmesh_function_space.py +4 -1
  83. warp/fem/space/restriction.py +5 -2
  84. warp/fem/space/shape/__init__.py +10 -8
  85. warp/fem/space/tetmesh_function_space.py +4 -1
  86. warp/fem/space/topology.py +52 -21
  87. warp/fem/space/trimesh_function_space.py +4 -1
  88. warp/fem/utils.py +53 -8
  89. warp/jax.py +1 -2
  90. warp/jax_experimental/ffi.py +210 -67
  91. warp/jax_experimental/xla_ffi.py +37 -24
  92. warp/math.py +171 -1
  93. warp/native/array.h +103 -4
  94. warp/native/builtin.h +182 -35
  95. warp/native/coloring.cpp +6 -2
  96. warp/native/cuda_util.cpp +1 -1
  97. warp/native/exports.h +118 -63
  98. warp/native/intersect.h +5 -5
  99. warp/native/mat.h +8 -13
  100. warp/native/mathdx.cpp +11 -5
  101. warp/native/matnn.h +1 -123
  102. warp/native/mesh.h +1 -1
  103. warp/native/quat.h +34 -6
  104. warp/native/rand.h +7 -7
  105. warp/native/sparse.cpp +121 -258
  106. warp/native/sparse.cu +181 -274
  107. warp/native/spatial.h +305 -17
  108. warp/native/svd.h +23 -8
  109. warp/native/tile.h +603 -73
  110. warp/native/tile_radix_sort.h +1112 -0
  111. warp/native/tile_reduce.h +239 -13
  112. warp/native/tile_scan.h +240 -0
  113. warp/native/tuple.h +189 -0
  114. warp/native/vec.h +10 -20
  115. warp/native/warp.cpp +36 -4
  116. warp/native/warp.cu +588 -52
  117. warp/native/warp.h +47 -74
  118. warp/optim/linear.py +5 -1
  119. warp/paddle.py +7 -8
  120. warp/py.typed +0 -0
  121. warp/render/render_opengl.py +110 -80
  122. warp/render/render_usd.py +124 -62
  123. warp/sim/__init__.py +9 -0
  124. warp/sim/collide.py +253 -80
  125. warp/sim/graph_coloring.py +8 -1
  126. warp/sim/import_mjcf.py +4 -3
  127. warp/sim/import_usd.py +11 -7
  128. warp/sim/integrator.py +5 -2
  129. warp/sim/integrator_euler.py +1 -1
  130. warp/sim/integrator_featherstone.py +1 -1
  131. warp/sim/integrator_vbd.py +761 -322
  132. warp/sim/integrator_xpbd.py +1 -1
  133. warp/sim/model.py +265 -260
  134. warp/sim/utils.py +10 -7
  135. warp/sparse.py +303 -166
  136. warp/tape.py +54 -51
  137. warp/tests/cuda/test_conditional_captures.py +1046 -0
  138. warp/tests/cuda/test_streams.py +1 -1
  139. warp/tests/geometry/test_volume.py +2 -2
  140. warp/tests/interop/test_dlpack.py +9 -9
  141. warp/tests/interop/test_jax.py +0 -1
  142. warp/tests/run_coverage_serial.py +1 -1
  143. warp/tests/sim/disabled_kinematics.py +2 -2
  144. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  145. warp/tests/sim/test_collision.py +159 -51
  146. warp/tests/sim/test_coloring.py +91 -2
  147. warp/tests/test_array.py +254 -2
  148. warp/tests/test_array_reduce.py +2 -2
  149. warp/tests/test_assert.py +53 -0
  150. warp/tests/test_atomic_cas.py +312 -0
  151. warp/tests/test_codegen.py +142 -19
  152. warp/tests/test_conditional.py +47 -1
  153. warp/tests/test_ctypes.py +0 -20
  154. warp/tests/test_devices.py +8 -0
  155. warp/tests/test_fabricarray.py +4 -2
  156. warp/tests/test_fem.py +58 -25
  157. warp/tests/test_func.py +42 -1
  158. warp/tests/test_grad.py +1 -1
  159. warp/tests/test_lerp.py +1 -3
  160. warp/tests/test_map.py +481 -0
  161. warp/tests/test_mat.py +23 -24
  162. warp/tests/test_quat.py +28 -15
  163. warp/tests/test_rounding.py +10 -38
  164. warp/tests/test_runlength_encode.py +7 -7
  165. warp/tests/test_smoothstep.py +1 -1
  166. warp/tests/test_sparse.py +83 -2
  167. warp/tests/test_spatial.py +507 -1
  168. warp/tests/test_static.py +48 -0
  169. warp/tests/test_struct.py +2 -2
  170. warp/tests/test_tape.py +38 -0
  171. warp/tests/test_tuple.py +265 -0
  172. warp/tests/test_types.py +2 -2
  173. warp/tests/test_utils.py +24 -18
  174. warp/tests/test_vec.py +38 -408
  175. warp/tests/test_vec_constructors.py +325 -0
  176. warp/tests/tile/test_tile.py +438 -131
  177. warp/tests/tile/test_tile_mathdx.py +518 -14
  178. warp/tests/tile/test_tile_matmul.py +179 -0
  179. warp/tests/tile/test_tile_reduce.py +307 -5
  180. warp/tests/tile/test_tile_shared_memory.py +136 -7
  181. warp/tests/tile/test_tile_sort.py +121 -0
  182. warp/tests/unittest_suites.py +14 -6
  183. warp/types.py +462 -308
  184. warp/utils.py +647 -86
  185. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +189 -175
  187. warp/stubs.py +0 -3381
  188. warp/tests/sim/test_xpbd.py +0 -399
  189. warp/tests/test_mlp.py +0 -282
  190. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -32,22 +32,7 @@ import typing
32
32
  import weakref
33
33
  from copy import copy as shallowcopy
34
34
  from pathlib import Path
35
- from typing import (
36
- Any,
37
- Callable,
38
- Dict,
39
- List,
40
- Literal,
41
- Mapping,
42
- Optional,
43
- Sequence,
44
- Set,
45
- Tuple,
46
- TypeVar,
47
- Union,
48
- get_args,
49
- get_origin,
50
- )
35
+ from typing import Any, Callable, Dict, List, Literal, Mapping, Sequence, Tuple, TypeVar, Union, get_args, get_origin
51
36
 
52
37
  import numpy as np
53
38
 
@@ -84,7 +69,7 @@ def get_function_args(func):
84
69
  complex_type_hints = (Any, Callable, Tuple)
85
70
  sequence_types = (list, tuple)
86
71
 
87
- function_key_counts: Dict[str, int] = {}
72
+ function_key_counts: dict[str, int] = {}
88
73
 
89
74
 
90
75
  def generate_unique_function_identifier(key: str) -> str:
@@ -120,40 +105,41 @@ def generate_unique_function_identifier(key: str) -> str:
120
105
  class Function:
121
106
  def __init__(
122
107
  self,
123
- func: Optional[Callable],
108
+ func: Callable | None,
124
109
  key: str,
125
110
  namespace: str,
126
- input_types: Optional[Dict[str, Union[type, TypeVar]]] = None,
127
- value_type: Optional[type] = None,
128
- value_func: Optional[Callable[[Mapping[str, type], Mapping[str, Any]], type]] = None,
129
- export_func: Optional[Callable[[Dict[str, type]], Dict[str, type]]] = None,
130
- dispatch_func: Optional[Callable] = None,
131
- lto_dispatch_func: Optional[Callable] = None,
132
- module: Optional[Module] = None,
111
+ input_types: dict[str, type | TypeVar] | None = None,
112
+ value_type: type | None = None,
113
+ value_func: Callable[[Mapping[str, type], Mapping[str, Any]], type] | None = None,
114
+ export_func: Callable[[dict[str, type]], dict[str, type]] | None = None,
115
+ dispatch_func: Callable | None = None,
116
+ lto_dispatch_func: Callable | None = None,
117
+ module: Module | None = None,
133
118
  variadic: bool = False,
134
- initializer_list_func: Optional[Callable[[Dict[str, Any], type], bool]] = None,
119
+ initializer_list_func: Callable[[dict[str, Any], type], bool] | None = None,
135
120
  export: bool = False,
121
+ source: str | None = None,
136
122
  doc: str = "",
137
123
  group: str = "",
138
124
  hidden: bool = False,
139
125
  skip_replay: bool = False,
140
126
  missing_grad: bool = False,
141
127
  generic: bool = False,
142
- native_func: Optional[str] = None,
143
- defaults: Optional[Dict[str, Any]] = None,
144
- custom_replay_func: Optional[Function] = None,
145
- native_snippet: Optional[str] = None,
146
- adj_native_snippet: Optional[str] = None,
147
- replay_snippet: Optional[str] = None,
128
+ native_func: str | None = None,
129
+ defaults: dict[str, Any] | None = None,
130
+ custom_replay_func: Function | None = None,
131
+ native_snippet: str | None = None,
132
+ adj_native_snippet: str | None = None,
133
+ replay_snippet: str | None = None,
148
134
  skip_forward_codegen: bool = False,
149
135
  skip_reverse_codegen: bool = False,
150
136
  custom_reverse_num_input_args: int = -1,
151
137
  custom_reverse_mode: bool = False,
152
- overloaded_annotations: Optional[Dict[str, type]] = None,
153
- code_transformers: Optional[List[ast.NodeTransformer]] = None,
138
+ overloaded_annotations: dict[str, type] | None = None,
139
+ code_transformers: list[ast.NodeTransformer] | None = None,
154
140
  skip_adding_overload: bool = False,
155
141
  require_original_output_arg: bool = False,
156
- scope_locals: Optional[Dict[str, Any]] = None,
142
+ scope_locals: dict[str, Any] | None = None,
157
143
  ):
158
144
  if code_transformers is None:
159
145
  code_transformers = []
@@ -178,7 +164,7 @@ class Function:
178
164
  self.native_snippet = native_snippet
179
165
  self.adj_native_snippet = adj_native_snippet
180
166
  self.replay_snippet = replay_snippet
181
- self.custom_grad_func: Optional[Function] = None
167
+ self.custom_grad_func: Function | None = None
182
168
  self.require_original_output_arg = require_original_output_arg
183
169
  self.generic_parent = None # generic function that was used to instantiate this overload
184
170
 
@@ -194,7 +180,7 @@ class Function:
194
180
  )
195
181
  self.missing_grad = missing_grad # whether builtin is missing a corresponding adjoint
196
182
  self.generic = generic
197
- self.mangled_name: Optional[str] = None
183
+ self.mangled_name: str | None = None
198
184
 
199
185
  # allow registering functions with a different name in Python and native code
200
186
  if native_func is None:
@@ -211,12 +197,13 @@ class Function:
211
197
  # user-defined function
212
198
 
213
199
  # generic and concrete overload lookups by type signature
214
- self.user_templates: Dict[str, Function] = {}
215
- self.user_overloads: Dict[str, Function] = {}
200
+ self.user_templates: dict[str, Function] = {}
201
+ self.user_overloads: dict[str, Function] = {}
216
202
 
217
203
  # user defined (Python) function
218
204
  self.adj = warp.codegen.Adjoint(
219
205
  func,
206
+ source=source,
220
207
  is_user_function=True,
221
208
  skip_forward_codegen=skip_forward_codegen,
222
209
  skip_reverse_codegen=skip_reverse_codegen,
@@ -244,7 +231,7 @@ class Function:
244
231
 
245
232
  # embedded linked list of all overloads
246
233
  # the builtin_functions dictionary holds the list head for a given key (func name)
247
- self.overloads: List[Function] = []
234
+ self.overloads: list[Function] = []
248
235
 
249
236
  # builtin (native) function, canonicalize argument types
250
237
  if input_types is not None:
@@ -293,10 +280,11 @@ class Function:
293
280
  module.register_function(self, scope_locals, skip_adding_overload)
294
281
 
295
282
  def __call__(self, *args, **kwargs):
296
- # handles calling a builtin (native) function
297
- # as if it was a Python function, i.e.: from
298
- # within the CPython interpreter rather than
299
- # from within a kernel (experimental).
283
+ """Call this function from the CPython interpreter.
284
+
285
+ This is used to call built-in or user functions from the CPython
286
+ interpreter, rather than from within a kernel.
287
+ """
300
288
 
301
289
  if self.is_builtin() and self.mangled_name:
302
290
  # For each of this function's existing overloads, we attempt to pack
@@ -306,7 +294,23 @@ class Function:
306
294
  if overload.generic:
307
295
  continue
308
296
 
309
- success, return_value = call_builtin(overload, *args)
297
+ try:
298
+ # Try to bind the given arguments to the function's signature.
299
+ # This is not checking whether the argument types are matching,
300
+ # rather it's just assigning each argument to the corresponding
301
+ # function parameter.
302
+ bound_args = self.signature.bind(*args, **kwargs)
303
+ except TypeError:
304
+ continue
305
+
306
+ if self.defaults:
307
+ # Populate the bound arguments with any default values.
308
+ default_args = {k: v for k, v in self.defaults.items() if k not in bound_args.arguments}
309
+ warp.codegen.apply_defaults(bound_args, default_args)
310
+
311
+ bound_args = tuple(bound_args.arguments.values())
312
+
313
+ success, return_value = call_builtin(overload, bound_args)
310
314
  if success:
311
315
  return return_value
312
316
 
@@ -324,6 +328,9 @@ class Function:
324
328
 
325
329
  arguments = tuple(bound_args.arguments.values())
326
330
 
331
+ # Store the last runtime error we encountered from a function execution
332
+ last_execution_error = None
333
+
327
334
  # try and find a matching overload
328
335
  for overload in self.user_overloads.values():
329
336
  if len(overload.input_types) != len(arguments):
@@ -334,10 +341,25 @@ class Function:
334
341
  # attempt to unify argument types with function template types
335
342
  warp.types.infer_argument_types(arguments, template_types, arg_names)
336
343
  return overload.func(*arguments)
337
- except Exception:
344
+ except Exception as e:
345
+ # The function was callable but threw an error during its execution.
346
+ # This might be the intended overload, but it failed, or it might be the wrong overload.
347
+ # We save this specific error and continue, just in case another overload later in the
348
+ # list is a better match and doesn't fail.
349
+ last_execution_error = e
338
350
  continue
339
351
 
340
- raise RuntimeError(f"Error calling function '{self.key}', no overload found for arguments {args}")
352
+ if last_execution_error:
353
+ # Raise a new, more contextual RuntimeError, but link it to the
354
+ # original error that was caught. This preserves the original
355
+ # traceback and error type for easier debugging.
356
+ raise RuntimeError(
357
+ f"Error calling function '{self.key}'. No version succeeded. "
358
+ f"See above for the error from the last version that was tried."
359
+ ) from last_execution_error
360
+ else:
361
+ # We got here without ever calling an overload.func
362
+ raise RuntimeError(f"Error calling function '{self.key}', no overload found for arguments {args}")
341
363
 
342
364
  # user-defined function with no overloads
343
365
  if self.func is None:
@@ -358,9 +380,6 @@ class Function:
358
380
  if warp.types.is_array(v) or v in complex_type_hints:
359
381
  return False
360
382
 
361
- if type(self.value_type) in sequence_types:
362
- return False
363
-
364
383
  return True
365
384
 
366
385
  def mangle(self) -> str:
@@ -404,8 +423,12 @@ class Function:
404
423
  else:
405
424
  self.user_overloads[sig] = f
406
425
 
407
- def get_overload(self, arg_types: List[type], kwarg_types: Mapping[str, type]) -> Optional[Function]:
408
- assert not self.is_builtin()
426
+ def get_overload(self, arg_types: list[type], kwarg_types: Mapping[str, type]) -> Function | None:
427
+ if self.is_builtin():
428
+ for f in self.overloads:
429
+ if warp.codegen.func_match_args(f, arg_types, kwarg_types):
430
+ return f
431
+ return None
409
432
 
410
433
  for f in self.user_overloads.values():
411
434
  if warp.codegen.func_match_args(f, arg_types, kwarg_types):
@@ -439,7 +462,7 @@ class Function:
439
462
  overload_annotations[k] = warp.codegen.strip_reference(warp.codegen.get_arg_type(d))
440
463
 
441
464
  ovl = shallowcopy(f)
442
- ovl.adj = warp.codegen.Adjoint(f.func, overload_annotations)
465
+ ovl.adj = warp.codegen.Adjoint(f.func, overload_annotations, source=f.adj.source)
443
466
  ovl.input_types = overload_annotations
444
467
  ovl.value_func = None
445
468
  ovl.generic_parent = f
@@ -475,11 +498,25 @@ def get_builtin_type(return_type: type) -> type:
475
498
  return return_type
476
499
 
477
500
 
478
- def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
501
+ def extract_return_value(value_type: type, value_ctype: type, ret: Any) -> Any:
502
+ if issubclass(value_ctype, ctypes.Array) or issubclass(value_ctype, ctypes.Structure):
503
+ # return vector types as ctypes
504
+ return ret
505
+
506
+ if value_type is warp.types.float16:
507
+ return warp.types.half_bits_to_float(ret.value)
508
+
509
+ return ret.value
510
+
511
+
512
+ def call_builtin(func: Function, params: tuple) -> tuple[bool, Any]:
479
513
  uses_non_warp_array_type = False
480
514
 
481
515
  init()
482
516
 
517
+ if func.mangled_name is None:
518
+ return (False, None)
519
+
483
520
  # Retrieve the built-in function from Warp's dll.
484
521
  c_func = getattr(warp.context.runtime.core, func.mangled_name)
485
522
 
@@ -489,6 +526,8 @@ def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
489
526
  else:
490
527
  func_args = func.input_types
491
528
 
529
+ value_type = func.value_func(None, None)
530
+
492
531
  # Try gathering the parameters that the function expects and pack them
493
532
  # into their corresponding C types.
494
533
  c_params = []
@@ -604,9 +643,9 @@ def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
604
643
 
605
644
  if not (
606
645
  isinstance(param, arg_type)
607
- or (type(param) is float and arg_type is warp.types.float32) # noqa: E721
608
- or (type(param) is int and arg_type is warp.types.int32) # noqa: E721
609
- or (type(param) is bool and arg_type is warp.types.bool) # noqa: E721
646
+ or (type(param) is float and arg_type is warp.types.float32)
647
+ or (type(param) is int and arg_type is warp.types.int32)
648
+ or (type(param) is bool and arg_type is warp.types.bool)
610
649
  or warp.types.np_dtype_to_warp_type.get(getattr(param, "dtype", None)) is arg_type
611
650
  ):
612
651
  return (False, None)
@@ -620,25 +659,18 @@ def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
620
659
  else:
621
660
  c_params.append(arg_type._type_(param))
622
661
 
623
- # returns the corresponding ctype for a scalar or vector warp type
624
- value_type = func.value_func(None, None)
662
+ # Retrieve the return type.
663
+ value_type = func.value_func(func_args, None)
625
664
 
626
- if value_type == float:
627
- value_ctype = ctypes.c_float
628
- elif value_type == int:
629
- value_ctype = ctypes.c_int32
630
- elif value_type == bool:
631
- value_ctype = ctypes.c_bool
632
- elif issubclass(value_type, (ctypes.Array, ctypes.Structure)):
633
- value_ctype = value_type
634
- else:
635
- # scalar type
636
- value_ctype = value_type._type_
665
+ if value_type is not None:
666
+ if not isinstance(value_type, Sequence):
667
+ value_type = (value_type,)
668
+
669
+ value_ctype = tuple(warp.types.type_ctype(x) for x in value_type)
670
+ ret = tuple(x() for x in value_ctype)
671
+ ret_addr = tuple(ctypes.c_void_p(ctypes.addressof(x)) for x in ret)
637
672
 
638
- # construct return value (passed by address)
639
- ret = value_ctype()
640
- ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
641
- c_params.append(ret_addr)
673
+ c_params.extend(ret_addr)
642
674
 
643
675
  # Call the built-in function from Warp's dll.
644
676
  c_func(*c_params)
@@ -653,17 +685,14 @@ def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
653
685
  stacklevel=3,
654
686
  )
655
687
 
656
- if issubclass(value_ctype, ctypes.Array) or issubclass(value_ctype, ctypes.Structure):
657
- # return vector types as ctypes
658
- return (True, ret)
688
+ if value_type is None:
689
+ return (True, None)
659
690
 
660
- if value_type == warp.types.float16:
661
- value = warp.types.half_bits_to_float(ret.value)
662
- else:
663
- value = ret.value
691
+ return_value = tuple(extract_return_value(x, y, z) for x, y, z in zip(value_type, value_ctype, ret))
692
+ if len(return_value) == 1:
693
+ return_value = return_value[0]
664
694
 
665
- # return scalar types as int/float
666
- return (True, value)
695
+ return (True, return_value)
667
696
 
668
697
 
669
698
  class KernelHooks:
@@ -677,7 +706,7 @@ class KernelHooks:
677
706
 
678
707
  # caches source and compiled entry points for a kernel (will be populated after module loads)
679
708
  class Kernel:
680
- def __init__(self, func, key=None, module=None, options=None, code_transformers=None):
709
+ def __init__(self, func, key=None, module=None, options=None, code_transformers=None, source=None):
681
710
  self.func = func
682
711
 
683
712
  if module is None:
@@ -695,7 +724,7 @@ class Kernel:
695
724
  if code_transformers is None:
696
725
  code_transformers = []
697
726
 
698
- self.adj = warp.codegen.Adjoint(func, transformers=code_transformers)
727
+ self.adj = warp.codegen.Adjoint(func, transformers=code_transformers, source=source)
699
728
 
700
729
  # check if generic
701
730
  self.is_generic = False
@@ -762,7 +791,7 @@ class Kernel:
762
791
 
763
792
  # instantiate this kernel with the given argument types
764
793
  ovl = shallowcopy(self)
765
- ovl.adj = warp.codegen.Adjoint(self.func, overload_annotations)
794
+ ovl.adj = warp.codegen.Adjoint(self.func, overload_annotations, source=self.adj.source)
766
795
  ovl.is_generic = False
767
796
  ovl.overloads = {}
768
797
  ovl.sig = sig
@@ -798,7 +827,7 @@ class Kernel:
798
827
 
799
828
 
800
829
  # decorator to register function, @func
801
- def func(f: Optional[Callable] = None, *, name: Optional[str] = None):
830
+ def func(f: Callable | None = None, *, name: str | None = None):
802
831
  def wrapper(f, *args, **kwargs):
803
832
  if name is None:
804
833
  key = warp.codegen.make_full_qualified_name(f)
@@ -831,7 +860,7 @@ def func(f: Optional[Callable] = None, *, name: Optional[str] = None):
831
860
  return wrapper(f)
832
861
 
833
862
 
834
- def func_native(snippet: str, adj_snippet: Optional[str] = None, replay_snippet: Optional[str] = None):
863
+ def func_native(snippet: str, adj_snippet: str | None = None, replay_snippet: str | None = None):
835
864
  """
836
865
  Decorator to register native code snippet, @func_native
837
866
  """
@@ -1015,10 +1044,10 @@ def func_replay(forward_fn):
1015
1044
 
1016
1045
 
1017
1046
  def kernel(
1018
- f: Optional[Callable] = None,
1047
+ f: Callable | None = None,
1019
1048
  *,
1020
- enable_backward: Optional[bool] = None,
1021
- module: Optional[Union[Module, Literal["unique"]]] = None,
1049
+ enable_backward: bool | None = None,
1050
+ module: Module | Literal["unique"] | None = None,
1022
1051
  ):
1023
1052
  """
1024
1053
  Decorator to register a Warp kernel from a Python function.
@@ -1181,7 +1210,7 @@ def overload(kernel, arg_types=Union[None, Dict[str, Any], List[Any]]):
1181
1210
 
1182
1211
 
1183
1212
  # native functions that are part of the Warp API
1184
- builtin_functions: Dict[str, Function] = {}
1213
+ builtin_functions: dict[str, Function] = {}
1185
1214
 
1186
1215
 
1187
1216
  def get_generic_vtypes():
@@ -1204,13 +1233,13 @@ scalar_types.update({x: x._wp_scalar_type_ for x in warp.types.vector_types})
1204
1233
 
1205
1234
  def add_builtin(
1206
1235
  key: str,
1207
- input_types: Optional[Dict[str, Union[type, TypeVar]]] = None,
1208
- constraint: Optional[Callable[[Mapping[str, type]], bool]] = None,
1209
- value_type: Optional[type] = None,
1210
- value_func: Optional[Callable] = None,
1211
- export_func: Optional[Callable] = None,
1212
- dispatch_func: Optional[Callable] = None,
1213
- lto_dispatch_func: Optional[Callable] = None,
1236
+ input_types: dict[str, type | TypeVar] | None = None,
1237
+ constraint: Callable[[Mapping[str, type]], bool] | None = None,
1238
+ value_type: type | None = None,
1239
+ value_func: Callable | None = None,
1240
+ export_func: Callable | None = None,
1241
+ dispatch_func: Callable | None = None,
1242
+ lto_dispatch_func: Callable | None = None,
1214
1243
  doc: str = "",
1215
1244
  namespace: str = "wp::",
1216
1245
  variadic: bool = False,
@@ -1220,8 +1249,8 @@ def add_builtin(
1220
1249
  hidden: bool = False,
1221
1250
  skip_replay: bool = False,
1222
1251
  missing_grad: bool = False,
1223
- native_func: Optional[str] = None,
1224
- defaults: Optional[Dict[str, Any]] = None,
1252
+ native_func: str | None = None,
1253
+ defaults: dict[str, Any] | None = None,
1225
1254
  require_original_output_arg: bool = False,
1226
1255
  ):
1227
1256
  """Main entry point to register a new built-in function.
@@ -1371,18 +1400,13 @@ def add_builtin(
1371
1400
 
1372
1401
  return_type = value_func(concrete_arg_types, None)
1373
1402
 
1374
- # The return_type might just be vector_t(length=3,dtype=wp.float32), so we've got to match that
1375
- # in the list of hard coded types so it knows it's returning one of them:
1376
- if hasattr(return_type, "_wp_generic_type_hint_"):
1377
- return_type_match = tuple(
1378
- x
1379
- for x in generic_vtypes
1380
- if x._wp_generic_type_hint_ == return_type._wp_generic_type_hint_
1381
- and x._wp_type_params_ == return_type._wp_type_params_
1382
- )
1383
- if not return_type_match:
1384
- continue
1385
- return_type = return_type_match[0]
1403
+ try:
1404
+ if isinstance(return_type, Sequence):
1405
+ return_type = tuple(get_builtin_type(x) for x in return_type)
1406
+ else:
1407
+ return_type = get_builtin_type(return_type)
1408
+ except RuntimeError:
1409
+ continue
1386
1410
 
1387
1411
  # finally we can generate a function call for these concrete types:
1388
1412
  add_builtin(
@@ -1485,7 +1509,7 @@ def register_api_function(
1485
1509
 
1486
1510
 
1487
1511
  # global dictionary of modules
1488
- user_modules: Dict[str, Module] = {}
1512
+ user_modules: dict[str, Module] = {}
1489
1513
 
1490
1514
 
1491
1515
  def get_module(name: str) -> Module:
@@ -1608,7 +1632,7 @@ class ModuleHasher:
1608
1632
  ch.update(bytes(func.key, "utf-8"))
1609
1633
 
1610
1634
  # include all concrete and generic overloads
1611
- overloads: Dict[str, Function] = {**func.user_overloads, **func.user_templates}
1635
+ overloads: dict[str, Function] = {**func.user_overloads, **func.user_templates}
1612
1636
  for sig in sorted(overloads.keys()):
1613
1637
  ovl = overloads[sig]
1614
1638
 
@@ -1668,7 +1692,7 @@ class ModuleHasher:
1668
1692
  ch.update(bytes(name, "utf-8"))
1669
1693
  ch.update(self.get_constant_bytes(value))
1670
1694
 
1671
- # hash wp.static() expressions that were evaluated at declaration time
1695
+ # hash wp.static() expressions
1672
1696
  for k, v in adj.static_expressions.items():
1673
1697
  ch.update(bytes(k, "utf-8"))
1674
1698
  if isinstance(v, Function):
@@ -1857,7 +1881,7 @@ class ModuleBuilder:
1857
1881
  # the original Modules get reloaded.
1858
1882
  class ModuleExec:
1859
1883
  def __new__(cls, *args, **kwargs):
1860
- instance = super(ModuleExec, cls).__new__(cls)
1884
+ instance = super().__new__(cls)
1861
1885
  instance.handle = None
1862
1886
  return instance
1863
1887
 
@@ -1952,7 +1976,7 @@ class ModuleExec:
1952
1976
  # creates a hash of the function to use for checking
1953
1977
  # build cache
1954
1978
  class Module:
1955
- def __init__(self, name: Optional[str], loader=None):
1979
+ def __init__(self, name: str | None, loader=None):
1956
1980
  self.name = name if name is not None else "None"
1957
1981
 
1958
1982
  self.loader = loader
@@ -1987,6 +2011,9 @@ class Module:
1987
2011
  # is retained and later reloaded with the same hash.
1988
2012
  self.cpu_exec_id = 0
1989
2013
 
2014
+ # Indicates whether the module has functions or kernels with unresolved static expressions.
2015
+ self.has_unresolved_static_expressions = False
2016
+
1990
2017
  self.options = {
1991
2018
  "max_unroll": warp.config.max_unroll,
1992
2019
  "enable_backward": warp.config.enable_backward,
@@ -1994,8 +2021,9 @@ class Module:
1994
2021
  "fuse_fp": True,
1995
2022
  "lineinfo": warp.config.lineinfo,
1996
2023
  "cuda_output": None, # supported values: "ptx", "cubin", or None (automatic)
1997
- "mode": warp.config.mode,
2024
+ "mode": None,
1998
2025
  "block_dim": 256,
2026
+ "compile_time_trace": warp.config.compile_time_trace,
1999
2027
  }
2000
2028
 
2001
2029
  # Module dependencies are determined by scanning each function
@@ -2022,6 +2050,10 @@ class Module:
2022
2050
  # track all kernel objects, even if they are duplicates
2023
2051
  self._live_kernels.add(kernel)
2024
2052
 
2053
+ # Check for unresolved static expressions in the kernel.
2054
+ if kernel.adj.has_unresolved_static_expressions:
2055
+ self.has_unresolved_static_expressions = True
2056
+
2025
2057
  self.find_references(kernel.adj)
2026
2058
 
2027
2059
  # for a reload of module on next launch
@@ -2081,6 +2113,10 @@ class Module:
2081
2113
  del func_existing.user_overloads[k]
2082
2114
  func_existing.add_overload(func)
2083
2115
 
2116
+ # Check for unresolved static expressions in the function.
2117
+ if func.adj.has_unresolved_static_expressions:
2118
+ self.has_unresolved_static_expressions = True
2119
+
2084
2120
  self.find_references(func.adj)
2085
2121
 
2086
2122
  # for a reload of module on next launch
@@ -2140,7 +2176,7 @@ class Module:
2140
2176
  self.hashers[block_dim] = ModuleHasher(self)
2141
2177
  return self.hashers[block_dim].get_module_hash()
2142
2178
 
2143
- def load(self, device, block_dim=None) -> ModuleExec:
2179
+ def load(self, device, block_dim=None) -> ModuleExec | None:
2144
2180
  device = runtime.get_device(device)
2145
2181
 
2146
2182
  # update module options if launching with a new block dim
@@ -2149,6 +2185,20 @@ class Module:
2149
2185
 
2150
2186
  active_block_dim = self.options["block_dim"]
2151
2187
 
2188
+ if self.has_unresolved_static_expressions:
2189
+ # The module hash currently does not account for unresolved static expressions
2190
+ # (only static expressions evaluated at declaration time so far).
2191
+ # We need to generate the code for the functions and kernels that have
2192
+ # unresolved static expressions and then compute the module hash again.
2193
+ builder_options = {
2194
+ **self.options,
2195
+ "output_arch": None,
2196
+ }
2197
+ # build functions, kernels to resolve static expressions
2198
+ _ = ModuleBuilder(self, builder_options)
2199
+
2200
+ self.has_unresolved_static_expressions = False
2201
+
2152
2202
  # compute the hash if needed
2153
2203
  if active_block_dim not in self.hashers:
2154
2204
  self.hashers[active_block_dim] = ModuleHasher(self)
@@ -2222,7 +2272,7 @@ class Module:
2222
2272
  ):
2223
2273
  builder_options = {
2224
2274
  **self.options,
2225
- # Some of the Tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
2275
+ # Some of the tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
2226
2276
  "output_arch": output_arch,
2227
2277
  }
2228
2278
  builder = ModuleBuilder(self, builder_options, hasher=self.hashers[active_block_dim])
@@ -2237,6 +2287,8 @@ class Module:
2237
2287
 
2238
2288
  module_load_timer.extra_msg = " (compiled)" # For wp.ScopedTimer informational purposes
2239
2289
 
2290
+ mode = self.options["mode"] if self.options["mode"] is not None else warp.config.mode
2291
+
2240
2292
  # build CPU
2241
2293
  if device.is_cpu:
2242
2294
  # build
@@ -2256,7 +2308,7 @@ class Module:
2256
2308
  warp.build.build_cpu(
2257
2309
  output_path,
2258
2310
  source_code_path,
2259
- mode=self.options["mode"],
2311
+ mode=mode,
2260
2312
  fast_math=self.options["fast_math"],
2261
2313
  verify_fp=warp.config.verify_fp,
2262
2314
  fuse_fp=self.options["fuse_fp"],
@@ -2286,11 +2338,12 @@ class Module:
2286
2338
  source_code_path,
2287
2339
  output_arch,
2288
2340
  output_path,
2289
- config=self.options["mode"],
2341
+ config=mode,
2290
2342
  verify_fp=warp.config.verify_fp,
2291
2343
  fast_math=self.options["fast_math"],
2292
2344
  fuse_fp=self.options["fuse_fp"],
2293
2345
  lineinfo=self.options["lineinfo"],
2346
+ compile_time_trace=self.options["compile_time_trace"],
2294
2347
  ltoirs=builder.ltoirs.values(),
2295
2348
  fatbins=builder.fatbins.values(),
2296
2349
  )
@@ -2343,7 +2396,7 @@ class Module:
2343
2396
  # Load CPU or CUDA binary
2344
2397
 
2345
2398
  meta_path = os.path.join(module_dir, f"{module_name_short}.meta")
2346
- with open(meta_path, "r") as meta_file:
2399
+ with open(meta_path) as meta_file:
2347
2400
  meta = json.load(meta_file)
2348
2401
 
2349
2402
  if device.is_cpu:
@@ -2406,7 +2459,7 @@ class CpuDefaultAllocator:
2406
2459
  def alloc(self, size_in_bytes):
2407
2460
  ptr = runtime.core.alloc_host(size_in_bytes)
2408
2461
  if not ptr:
2409
- raise RuntimeError(f"Failed to allocate {size_in_bytes} bytes on device '{self.device}'")
2462
+ raise RuntimeError(f"Failed to allocate {size_in_bytes} bytes on device 'cpu'")
2410
2463
  return ptr
2411
2464
 
2412
2465
  def free(self, ptr, size_in_bytes):
@@ -2510,12 +2563,12 @@ class Event:
2510
2563
 
2511
2564
  def __new__(cls, *args, **kwargs):
2512
2565
  """Creates a new event instance."""
2513
- instance = super(Event, cls).__new__(cls)
2566
+ instance = super().__new__(cls)
2514
2567
  instance.owner = False
2515
2568
  return instance
2516
2569
 
2517
2570
  def __init__(
2518
- self, device: "Devicelike" = None, cuda_event=None, enable_timing: bool = False, interprocess: bool = False
2571
+ self, device: Devicelike = None, cuda_event=None, enable_timing: bool = False, interprocess: bool = False
2519
2572
  ):
2520
2573
  """Initializes the event on a CUDA device.
2521
2574
 
@@ -2611,12 +2664,12 @@ class Event:
2611
2664
 
2612
2665
  class Stream:
2613
2666
  def __new__(cls, *args, **kwargs):
2614
- instance = super(Stream, cls).__new__(cls)
2667
+ instance = super().__new__(cls)
2615
2668
  instance.cuda_stream = None
2616
2669
  instance.owner = False
2617
2670
  return instance
2618
2671
 
2619
- def __init__(self, device: Union["Device", str, None] = None, priority: int = 0, **kwargs):
2672
+ def __init__(self, device: Device | str | None = None, priority: int = 0, **kwargs):
2620
2673
  """Initialize the stream on a device with an optional specified priority.
2621
2674
 
2622
2675
  Args:
@@ -2682,7 +2735,7 @@ class Stream:
2682
2735
  self._cached_event = Event(self.device)
2683
2736
  return self._cached_event
2684
2737
 
2685
- def record_event(self, event: Optional[Event] = None) -> Event:
2738
+ def record_event(self, event: Event | None = None) -> Event:
2686
2739
  """Record an event onto the stream.
2687
2740
 
2688
2741
  Args:
@@ -2711,7 +2764,7 @@ class Stream:
2711
2764
  """
2712
2765
  runtime.core.cuda_stream_wait_event(self.cuda_stream, event.cuda_event)
2713
2766
 
2714
- def wait_stream(self, other_stream: "Stream", event: Optional[Event] = None):
2767
+ def wait_stream(self, other_stream: Stream, event: Event | None = None):
2715
2768
  """Records an event on `other_stream` and makes this stream wait on it.
2716
2769
 
2717
2770
  All work added to this stream after this function has been called will
@@ -2765,6 +2818,8 @@ class Device:
2765
2818
  or ``"CPU"`` if the processor name cannot be determined.
2766
2819
  arch (int): The compute capability version number calculated as ``10 * major + minor``.
2767
2820
  ``0`` for CPU devices.
2821
+ sm_count (int): The number of streaming multiprocessors on the CUDA device.
2822
+ ``0`` for CPU devices.
2768
2823
  is_uva (bool): Indicates whether the device supports unified addressing.
2769
2824
  ``False`` for CPU devices.
2770
2825
  is_cubin_supported (bool): Indicates whether Warp's version of NVRTC can directly
@@ -2810,6 +2865,7 @@ class Device:
2810
2865
  # CPU device
2811
2866
  self.name = platform.processor() or "CPU"
2812
2867
  self.arch = 0
2868
+ self.sm_count = 0
2813
2869
  self.is_uva = False
2814
2870
  self.is_mempool_supported = False
2815
2871
  self.is_mempool_enabled = False
@@ -2829,6 +2885,7 @@ class Device:
2829
2885
  # CUDA device
2830
2886
  self.name = runtime.core.cuda_device_get_name(ordinal).decode()
2831
2887
  self.arch = runtime.core.cuda_device_get_arch(ordinal)
2888
+ self.sm_count = runtime.core.cuda_device_get_sm_count(ordinal)
2832
2889
  self.is_uva = runtime.core.cuda_device_is_uva(ordinal) > 0
2833
2890
  self.is_mempool_supported = runtime.core.cuda_device_is_mempool_supported(ordinal) > 0
2834
2891
  if platform.system() == "Linux":
@@ -3070,16 +3127,23 @@ class Graph:
3070
3127
  def __init__(self, device: Device, capture_id: int):
3071
3128
  self.device = device
3072
3129
  self.capture_id = capture_id
3073
- self.module_execs: Set[ModuleExec] = set()
3074
- self.graph_exec: Optional[ctypes.c_void_p] = None
3130
+ self.module_execs: set[ModuleExec] = set()
3131
+ self.graph_exec: ctypes.c_void_p | None = None
3132
+
3133
+ self.graph: ctypes.c_void_p | None = None
3134
+ self.has_conditional = (
3135
+ False # Track if there are conditional nodes in the graph since they are not allowed in child graphs
3136
+ )
3075
3137
 
3076
3138
  def __del__(self):
3077
- if not hasattr(self, "graph_exec") or not hasattr(self, "device") or not self.graph_exec:
3139
+ if not hasattr(self, "graph") or not hasattr(self, "device") or not self.graph:
3078
3140
  return
3079
3141
 
3080
3142
  # use CUDA context guard to avoid side effects during garbage collection
3081
3143
  with self.device.context_guard:
3082
- runtime.core.cuda_graph_destroy(self.device.context, self.graph_exec)
3144
+ runtime.core.cuda_graph_destroy(self.device.context, self.graph)
3145
+ if hasattr(self, "graph_exec") and self.graph_exec is not None:
3146
+ runtime.core.cuda_graph_exec_destroy(self.device.context, self.graph_exec)
3083
3147
 
3084
3148
  # retain executable CUDA modules used by this graph, which prevents them from being unloaded
3085
3149
  def retain_module_exec(self, module_exec: ModuleExec):
@@ -3088,8 +3152,6 @@ class Graph:
3088
3152
 
3089
3153
  class Runtime:
3090
3154
  def __init__(self):
3091
- if sys.version_info < (3, 8):
3092
- raise RuntimeError("Warp requires Python 3.8 as a minimum")
3093
3155
  if sys.version_info < (3, 9):
3094
3156
  warp.utils.warn(f"Python 3.9 or newer is recommended for running Warp, detected {sys.version_info}")
3095
3157
 
@@ -3535,44 +3597,40 @@ class Runtime:
3535
3597
  self.core.volume_get_blind_data_info.restype = ctypes.c_char_p
3536
3598
 
3537
3599
  bsr_matrix_from_triplets_argtypes = [
3538
- ctypes.c_int, # rows_per_bock
3539
- ctypes.c_int, # cols_per_blocks
3600
+ ctypes.c_int, # block_size
3601
+ ctypes.c_int, # scalar size in bytes
3540
3602
  ctypes.c_int, # row_count
3541
- ctypes.c_int, # tpl_nnz
3603
+ ctypes.c_int, # col_count
3604
+ ctypes.c_int, # nnz_upper_bound
3605
+ ctypes.POINTER(ctypes.c_int), # tpl_nnz
3542
3606
  ctypes.POINTER(ctypes.c_int), # tpl_rows
3543
3607
  ctypes.POINTER(ctypes.c_int), # tpl_cols
3544
3608
  ctypes.c_void_p, # tpl_values
3545
- ctypes.c_bool, # prune_numerical_zeros
3609
+ ctypes.c_uint64, # zero_value_mask
3546
3610
  ctypes.c_bool, # masked
3547
3611
  ctypes.POINTER(ctypes.c_int), # bsr_offsets
3548
3612
  ctypes.POINTER(ctypes.c_int), # bsr_columns
3549
- ctypes.c_void_p, # bsr_values
3613
+ ctypes.POINTER(ctypes.c_int), # prefix sum of block count to sum for each bsr block
3614
+ ctypes.POINTER(ctypes.c_int), # indices to ptriplet blocks to sum for each bsr block
3550
3615
  ctypes.POINTER(ctypes.c_int), # bsr_nnz
3551
3616
  ctypes.c_void_p, # bsr_nnz_event
3552
3617
  ]
3553
3618
 
3554
- self.core.bsr_matrix_from_triplets_float_host.argtypes = bsr_matrix_from_triplets_argtypes
3555
- self.core.bsr_matrix_from_triplets_double_host.argtypes = bsr_matrix_from_triplets_argtypes
3556
- self.core.bsr_matrix_from_triplets_float_device.argtypes = bsr_matrix_from_triplets_argtypes
3557
- self.core.bsr_matrix_from_triplets_double_device.argtypes = bsr_matrix_from_triplets_argtypes
3619
+ self.core.bsr_matrix_from_triplets_host.argtypes = bsr_matrix_from_triplets_argtypes
3620
+ self.core.bsr_matrix_from_triplets_device.argtypes = bsr_matrix_from_triplets_argtypes
3558
3621
 
3559
3622
  bsr_transpose_argtypes = [
3560
- ctypes.c_int, # rows_per_bock
3561
- ctypes.c_int, # cols_per_blocks
3562
3623
  ctypes.c_int, # row_count
3563
3624
  ctypes.c_int, # col count
3564
3625
  ctypes.c_int, # nnz
3565
3626
  ctypes.POINTER(ctypes.c_int), # transposed_bsr_offsets
3566
3627
  ctypes.POINTER(ctypes.c_int), # transposed_bsr_columns
3567
- ctypes.c_void_p, # bsr_values
3568
3628
  ctypes.POINTER(ctypes.c_int), # transposed_bsr_offsets
3569
3629
  ctypes.POINTER(ctypes.c_int), # transposed_bsr_columns
3570
- ctypes.c_void_p, # transposed_bsr_values
3630
+ ctypes.POINTER(ctypes.c_int), # src to dest block map
3571
3631
  ]
3572
- self.core.bsr_transpose_float_host.argtypes = bsr_transpose_argtypes
3573
- self.core.bsr_transpose_double_host.argtypes = bsr_transpose_argtypes
3574
- self.core.bsr_transpose_float_device.argtypes = bsr_transpose_argtypes
3575
- self.core.bsr_transpose_double_device.argtypes = bsr_transpose_argtypes
3632
+ self.core.bsr_transpose_host.argtypes = bsr_transpose_argtypes
3633
+ self.core.bsr_transpose_device.argtypes = bsr_transpose_argtypes
3576
3634
 
3577
3635
  self.core.is_cuda_enabled.argtypes = None
3578
3636
  self.core.is_cuda_enabled.restype = ctypes.c_int
@@ -3601,6 +3659,8 @@ class Runtime:
3601
3659
  self.core.cuda_device_get_name.restype = ctypes.c_char_p
3602
3660
  self.core.cuda_device_get_arch.argtypes = [ctypes.c_int]
3603
3661
  self.core.cuda_device_get_arch.restype = ctypes.c_int
3662
+ self.core.cuda_device_get_sm_count.argtypes = [ctypes.c_int]
3663
+ self.core.cuda_device_get_sm_count.restype = ctypes.c_int
3604
3664
  self.core.cuda_device_is_uva.argtypes = [ctypes.c_int]
3605
3665
  self.core.cuda_device_is_uva.restype = ctypes.c_int
3606
3666
  self.core.cuda_device_is_mempool_supported.argtypes = [ctypes.c_int]
@@ -3724,11 +3784,73 @@ class Runtime:
3724
3784
  ctypes.POINTER(ctypes.c_void_p),
3725
3785
  ]
3726
3786
  self.core.cuda_graph_end_capture.restype = ctypes.c_bool
3787
+
3788
+ self.core.cuda_graph_create_exec.argtypes = [
3789
+ ctypes.c_void_p,
3790
+ ctypes.c_void_p,
3791
+ ctypes.c_void_p,
3792
+ ctypes.POINTER(ctypes.c_void_p),
3793
+ ]
3794
+ self.core.cuda_graph_create_exec.restype = ctypes.c_bool
3795
+
3796
+ self.core.capture_debug_dot_print.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_uint32]
3797
+ self.core.capture_debug_dot_print.restype = ctypes.c_bool
3798
+
3727
3799
  self.core.cuda_graph_launch.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3728
3800
  self.core.cuda_graph_launch.restype = ctypes.c_bool
3801
+ self.core.cuda_graph_exec_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3802
+ self.core.cuda_graph_exec_destroy.restype = ctypes.c_bool
3803
+
3729
3804
  self.core.cuda_graph_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3730
3805
  self.core.cuda_graph_destroy.restype = ctypes.c_bool
3731
3806
 
3807
+ self.core.cuda_graph_insert_if_else.argtypes = [
3808
+ ctypes.c_void_p,
3809
+ ctypes.c_void_p,
3810
+ ctypes.POINTER(ctypes.c_int),
3811
+ ctypes.POINTER(ctypes.c_void_p),
3812
+ ctypes.POINTER(ctypes.c_void_p),
3813
+ ]
3814
+ self.core.cuda_graph_insert_if_else.restype = ctypes.c_bool
3815
+
3816
+ self.core.cuda_graph_insert_while.argtypes = [
3817
+ ctypes.c_void_p,
3818
+ ctypes.c_void_p,
3819
+ ctypes.POINTER(ctypes.c_int),
3820
+ ctypes.POINTER(ctypes.c_void_p),
3821
+ ctypes.POINTER(ctypes.c_uint64),
3822
+ ]
3823
+ self.core.cuda_graph_insert_while.restype = ctypes.c_bool
3824
+
3825
+ self.core.cuda_graph_set_condition.argtypes = [
3826
+ ctypes.c_void_p,
3827
+ ctypes.c_void_p,
3828
+ ctypes.POINTER(ctypes.c_int),
3829
+ ctypes.c_uint64,
3830
+ ]
3831
+ self.core.cuda_graph_set_condition.restype = ctypes.c_bool
3832
+
3833
+ self.core.cuda_graph_pause_capture.argtypes = [
3834
+ ctypes.c_void_p,
3835
+ ctypes.c_void_p,
3836
+ ctypes.POINTER(ctypes.c_void_p),
3837
+ ]
3838
+ self.core.cuda_graph_pause_capture.restype = ctypes.c_bool
3839
+
3840
+ self.core.cuda_graph_resume_capture.argtypes = [
3841
+ ctypes.c_void_p,
3842
+ ctypes.c_void_p,
3843
+ ctypes.c_void_p,
3844
+ ]
3845
+ self.core.cuda_graph_resume_capture.restype = ctypes.c_bool
3846
+
3847
+ self.core.cuda_graph_insert_child_graph.argtypes = [
3848
+ ctypes.c_void_p,
3849
+ ctypes.c_void_p,
3850
+ ctypes.c_void_p,
3851
+ ]
3852
+ self.core.cuda_graph_insert_child_graph.restype = ctypes.c_bool
3853
+
3732
3854
  self.core.cuda_compile_program.argtypes = [
3733
3855
  ctypes.c_char_p, # cuda_src
3734
3856
  ctypes.c_char_p, # program name
@@ -3742,6 +3864,7 @@ class Runtime:
3742
3864
  ctypes.c_bool, # fast_math
3743
3865
  ctypes.c_bool, # fuse_fp
3744
3866
  ctypes.c_bool, # lineinfo
3867
+ ctypes.c_bool, # compile_time_trace
3745
3868
  ctypes.c_char_p, # output_path
3746
3869
  ctypes.c_size_t, # num_ltoirs
3747
3870
  ctypes.POINTER(ctypes.c_char_p), # ltoirs
@@ -3796,11 +3919,17 @@ class Runtime:
3796
3919
  ctypes.c_int, # arch
3797
3920
  ctypes.c_int, # M
3798
3921
  ctypes.c_int, # N
3922
+ ctypes.c_int, # NRHS
3923
+ ctypes.c_int, # function
3924
+ ctypes.c_int, # side
3925
+ ctypes.c_int, # diag
3799
3926
  ctypes.c_int, # precision
3927
+ ctypes.c_int, # a_arrangement
3928
+ ctypes.c_int, # b_arrangement
3800
3929
  ctypes.c_int, # fill_mode
3801
3930
  ctypes.c_int, # num threads
3802
3931
  ]
3803
- self.core.cuda_compile_fft.restype = ctypes.c_bool
3932
+ self.core.cuda_compile_solver.restype = ctypes.c_bool
3804
3933
 
3805
3934
  self.core.cuda_load_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
3806
3935
  self.core.cuda_load_module.restype = ctypes.c_void_p
@@ -3965,9 +4094,14 @@ class Runtime:
3965
4094
  # Update the default PTX architecture based on devices present in the system.
3966
4095
  # Use the lowest architecture among devices that meet the minimum architecture requirement.
3967
4096
  # Devices below the required minimum will use the highest architecture they support.
3968
- eligible_archs = [d.arch for d in self.cuda_devices if d.arch >= self.default_ptx_arch]
3969
- if eligible_archs:
3970
- self.default_ptx_arch = min(eligible_archs)
4097
+ try:
4098
+ self.default_ptx_arch = min(
4099
+ d.arch
4100
+ for d in self.cuda_devices
4101
+ if d.arch >= self.default_ptx_arch and d.arch in self.nvrtc_supported_archs
4102
+ )
4103
+ except ValueError:
4104
+ pass # no eligible NVRTC-supported arch ≥ default, retain existing
3971
4105
  else:
3972
4106
  # CUDA not available
3973
4107
  self.set_default_device("cpu")
@@ -4270,7 +4404,7 @@ def is_cuda_driver_initialized() -> bool:
4270
4404
  return runtime.core.cuda_driver_is_initialized()
4271
4405
 
4272
4406
 
4273
- def get_devices() -> List[Device]:
4407
+ def get_devices() -> list[Device]:
4274
4408
  """Returns a list of devices supported in this environment."""
4275
4409
 
4276
4410
  init()
@@ -4291,7 +4425,7 @@ def get_cuda_device_count() -> int:
4291
4425
  return len(runtime.cuda_devices)
4292
4426
 
4293
4427
 
4294
- def get_cuda_device(ordinal: Union[int, None] = None) -> Device:
4428
+ def get_cuda_device(ordinal: int | None = None) -> Device:
4295
4429
  """Returns the CUDA device with the given ordinal or the current CUDA device if ordinal is None."""
4296
4430
 
4297
4431
  init()
@@ -4302,7 +4436,7 @@ def get_cuda_device(ordinal: Union[int, None] = None) -> Device:
4302
4436
  return runtime.cuda_devices[ordinal]
4303
4437
 
4304
4438
 
4305
- def get_cuda_devices() -> List[Device]:
4439
+ def get_cuda_devices() -> list[Device]:
4306
4440
  """Returns a list of CUDA devices supported in this environment."""
4307
4441
 
4308
4442
  init()
@@ -4341,7 +4475,7 @@ def set_device(ident: Devicelike) -> None:
4341
4475
  device.make_current()
4342
4476
 
4343
4477
 
4344
- def map_cuda_device(alias: str, context: Optional[ctypes.c_void_p] = None) -> Device:
4478
+ def map_cuda_device(alias: str, context: ctypes.c_void_p | None = None) -> Device:
4345
4479
  """Assign a device alias to a CUDA context.
4346
4480
 
4347
4481
  This function can be used to create a wp.Device for an external CUDA context.
@@ -4436,7 +4570,7 @@ def set_mempool_enabled(device: Devicelike, enable: bool) -> None:
4436
4570
  raise ValueError("Memory pools are only supported on CUDA devices")
4437
4571
 
4438
4572
 
4439
- def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, float]) -> None:
4573
+ def set_mempool_release_threshold(device: Devicelike, threshold: int | float) -> None:
4440
4574
  """Set the CUDA memory pool release threshold on the device.
4441
4575
 
4442
4576
  This is the amount of reserved memory to hold onto before trying to release memory back to the OS.
@@ -4744,7 +4878,7 @@ def set_stream(stream: Stream, device: Devicelike = None, sync: bool = False) ->
4744
4878
  get_device(device).set_stream(stream, sync=sync)
4745
4879
 
4746
4880
 
4747
- def record_event(event: Optional[Event] = None):
4881
+ def record_event(event: Event | None = None):
4748
4882
  """Convenience function for calling :meth:`Stream.record_event` on the current stream.
4749
4883
 
4750
4884
  Args:
@@ -4793,7 +4927,7 @@ def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: bo
4793
4927
  return runtime.core.cuda_event_elapsed_time(start_event.cuda_event, end_event.cuda_event)
4794
4928
 
4795
4929
 
4796
- def wait_stream(other_stream: Stream, event: Optional[Event] = None):
4930
+ def wait_stream(other_stream: Stream, event: Event | None = None):
4797
4931
  """Convenience function for calling :meth:`Stream.wait_stream` on the current stream.
4798
4932
 
4799
4933
  Args:
@@ -4863,7 +4997,7 @@ class RegisteredGLBuffer:
4863
4997
  __fallback_warning_shown = False
4864
4998
 
4865
4999
  def __new__(cls, *args, **kwargs):
4866
- instance = super(RegisteredGLBuffer, cls).__new__(cls)
5000
+ instance = super().__new__(cls)
4867
5001
  instance.resource = None
4868
5002
  return instance
4869
5003
 
@@ -4960,8 +5094,8 @@ class RegisteredGLBuffer:
4960
5094
 
4961
5095
 
4962
5096
  def zeros(
4963
- shape: Union[int, Tuple[int, ...], List[int], None] = None,
4964
- dtype=float,
5097
+ shape: int | tuple[int, ...] | list[int] | None = None,
5098
+ dtype: type = float,
4965
5099
  device: Devicelike = None,
4966
5100
  requires_grad: bool = False,
4967
5101
  pinned: bool = False,
@@ -4988,7 +5122,7 @@ def zeros(
4988
5122
 
4989
5123
 
4990
5124
  def zeros_like(
4991
- src: Array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
5125
+ src: Array, device: Devicelike = None, requires_grad: bool | None = None, pinned: bool | None = None
4992
5126
  ) -> warp.array:
4993
5127
  """Return a zero-initialized array with the same type and dimension of another array
4994
5128
 
@@ -5010,8 +5144,8 @@ def zeros_like(
5010
5144
 
5011
5145
 
5012
5146
  def ones(
5013
- shape: Union[int, Tuple[int, ...], List[int], None] = None,
5014
- dtype=float,
5147
+ shape: int | tuple[int, ...] | list[int] | None = None,
5148
+ dtype: type = float,
5015
5149
  device: Devicelike = None,
5016
5150
  requires_grad: bool = False,
5017
5151
  pinned: bool = False,
@@ -5034,7 +5168,7 @@ def ones(
5034
5168
 
5035
5169
 
5036
5170
  def ones_like(
5037
- src: Array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
5171
+ src: Array, device: Devicelike = None, requires_grad: bool | None = None, pinned: bool | None = None
5038
5172
  ) -> warp.array:
5039
5173
  """Return a one-initialized array with the same type and dimension of another array
5040
5174
 
@@ -5052,7 +5186,7 @@ def ones_like(
5052
5186
 
5053
5187
 
5054
5188
  def full(
5055
- shape: Union[int, Tuple[int, ...], List[int], None] = None,
5189
+ shape: int | tuple[int, ...] | list[int] | None = None,
5056
5190
  value=0,
5057
5191
  dtype=Any,
5058
5192
  device: Devicelike = None,
@@ -5121,8 +5255,8 @@ def full_like(
5121
5255
  src: Array,
5122
5256
  value: Any,
5123
5257
  device: Devicelike = None,
5124
- requires_grad: Optional[bool] = None,
5125
- pinned: Optional[bool] = None,
5258
+ requires_grad: bool | None = None,
5259
+ pinned: bool | None = None,
5126
5260
  ) -> warp.array:
5127
5261
  """Return an array with all elements initialized to the given value with the same type and dimension of another array
5128
5262
 
@@ -5145,7 +5279,7 @@ def full_like(
5145
5279
 
5146
5280
 
5147
5281
  def clone(
5148
- src: warp.array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
5282
+ src: warp.array, device: Devicelike = None, requires_grad: bool | None = None, pinned: bool | None = None
5149
5283
  ) -> warp.array:
5150
5284
  """Clone an existing array, allocates a copy of the src memory
5151
5285
 
@@ -5167,7 +5301,7 @@ def clone(
5167
5301
 
5168
5302
 
5169
5303
  def empty(
5170
- shape: Union[int, Tuple[int, ...], List[int], None] = None,
5304
+ shape: int | tuple[int, ...] | list[int] | None = None,
5171
5305
  dtype=float,
5172
5306
  device: Devicelike = None,
5173
5307
  requires_grad: bool = False,
@@ -5200,7 +5334,7 @@ def empty(
5200
5334
 
5201
5335
 
5202
5336
  def empty_like(
5203
- src: Array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
5337
+ src: Array, device: Devicelike = None, requires_grad: bool | None = None, pinned: bool | None = None
5204
5338
  ) -> warp.array:
5205
5339
  """Return an uninitialized array with the same type and dimension of another array
5206
5340
 
@@ -5235,9 +5369,9 @@ def empty_like(
5235
5369
 
5236
5370
  def from_numpy(
5237
5371
  arr: np.ndarray,
5238
- dtype: Optional[type] = None,
5239
- shape: Optional[Sequence[int]] = None,
5240
- device: Optional[Devicelike] = None,
5372
+ dtype: type | None = None,
5373
+ shape: Sequence[int] | None = None,
5374
+ device: Devicelike | None = None,
5241
5375
  requires_grad: bool = False,
5242
5376
  ) -> warp.array:
5243
5377
  """Returns a Warp array created from a NumPy array.
@@ -5255,7 +5389,7 @@ def from_numpy(
5255
5389
  if dtype is None:
5256
5390
  base_type = warp.types.np_dtype_to_warp_type.get(arr.dtype)
5257
5391
  if base_type is None:
5258
- raise RuntimeError("Unsupported NumPy data type '{}'.".format(arr.dtype))
5392
+ raise RuntimeError(f"Unsupported NumPy data type '{arr.dtype}'.")
5259
5393
 
5260
5394
  dim_count = len(arr.shape)
5261
5395
  if dim_count == 2:
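
The from_numpy() change is cosmetic (an f-string replaces str.format). For reference, a hedged sketch of the two conversion paths it supports (array contents are illustrative):

import numpy as np
import warp as wp

flat = np.arange(12, dtype=np.float32)
a = wp.from_numpy(flat)                  # dtype inferred from NumPy: float32 -> wp.float32

pts = np.ones((4, 3), dtype=np.float32)
b = wp.from_numpy(pts, dtype=wp.vec3, device="cpu", requires_grad=True)  # each row becomes one vec3
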
@@ -5274,7 +5408,7 @@ def from_numpy(
5274
5408
  )
5275
5409
 
5276
5410
 
5277
- def event_from_ipc_handle(handle, device: "Devicelike" = None) -> Event:
5411
+ def event_from_ipc_handle(handle, device: Devicelike = None) -> Event:
5278
5412
  """Create an event from an IPC handle.
5279
5413
 
5280
5414
  Args:
@@ -5443,10 +5577,10 @@ class Launch:
5443
5577
  self,
5444
5578
  kernel,
5445
5579
  device: Device,
5446
- hooks: Optional[KernelHooks] = None,
5447
- params: Optional[Sequence[Any]] = None,
5448
- params_addr: Optional[Sequence[ctypes.c_void_p]] = None,
5449
- bounds: Optional[launch_bounds_t] = None,
5580
+ hooks: KernelHooks | None = None,
5581
+ params: Sequence[Any] | None = None,
5582
+ params_addr: Sequence[ctypes.c_void_p] | None = None,
5583
+ bounds: launch_bounds_t | None = None,
5450
5584
  max_blocks: int = 0,
5451
5585
  block_dim: int = 256,
5452
5586
  adjoint: bool = False,
@@ -5516,7 +5650,7 @@ class Launch:
5516
5650
  self.adjoint: bool = adjoint
5517
5651
  """Whether to run the adjoint kernel instead of the forward kernel."""
5518
5652
 
5519
- def set_dim(self, dim: Union[int, List[int], Tuple[int, ...]]):
5653
+ def set_dim(self, dim: int | list[int] | tuple[int, ...]):
5520
5654
  """Set the launch dimensions.
5521
5655
 
5522
5656
  Args:
@@ -5554,7 +5688,7 @@ class Launch:
5554
5688
  if self.params_addr:
5555
5689
  self.params_addr[params_index] = ctypes.c_void_p(ctypes.addressof(carg))
5556
5690
 
5557
- def set_param_at_index_from_ctype(self, index: int, value: Union[ctypes.Structure, int, float]):
5691
+ def set_param_at_index_from_ctype(self, index: int, value: ctypes.Structure | int | float):
5558
5692
  """Set a kernel parameter at an index without any type conversion.
5559
5693
 
5560
5694
  Args:
@@ -5617,7 +5751,7 @@ class Launch:
5617
5751
  for i, v in enumerate(values):
5618
5752
  self.set_param_at_index_from_ctype(i, v)
5619
5753
 
5620
- def launch(self, stream: Optional[Stream] = None) -> None:
5754
+ def launch(self, stream: Stream | None = None) -> None:
5621
5755
  """Launch the kernel.
5622
5756
 
5623
5757
  Args:
@@ -5634,7 +5768,7 @@ class Launch:
5634
5768
 
5635
5769
  # If the stream is capturing, we retain the CUDA module so that it doesn't get unloaded
5636
5770
  # before the captured graph is released.
5637
- if runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
5771
+ if len(runtime.captures) > 0 and runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
5638
5772
  capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
5639
5773
  graph = runtime.captures.get(capture_id)
5640
5774
  if graph is not None:
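
The Launch hunks above only touch annotations and the capture check, but they show where the record-and-replay workflow lives: record a launch once, then mutate and replay it. A sketch assuming record_cmd=True returns the Launch object and that set_param_at_index()/set_dim() behave as in current releases (kernel and sizes are illustrative):

import warp as wp

@wp.kernel
def scale(x: wp.array(dtype=float), s: float):
    i = wp.tid()
    x[i] = x[i] * s

x = wp.ones(1024, dtype=float, device="cuda:0")

# record the launch instead of dispatching it immediately
cmd = wp.launch(scale, dim=x.shape, inputs=[x, 2.0], record_cmd=True)

cmd.launch()                                  # replay on the device's current stream
cmd.set_param_at_index(1, 0.5)                # update the scalar argument
cmd.set_dim(512)                              # shrink the launch grid
cmd.launch(stream=wp.get_stream("cuda:0"))    # or target an explicit stream
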
@@ -5666,13 +5800,13 @@ class Launch:
5666
5800
 
5667
5801
  def launch(
5668
5802
  kernel,
5669
- dim: Union[int, Sequence[int]],
5803
+ dim: int | Sequence[int],
5670
5804
  inputs: Sequence = [],
5671
5805
  outputs: Sequence = [],
5672
5806
  adj_inputs: Sequence = [],
5673
5807
  adj_outputs: Sequence = [],
5674
5808
  device: Devicelike = None,
5675
- stream: Optional[Stream] = None,
5809
+ stream: Stream | None = None,
5676
5810
  adjoint: bool = False,
5677
5811
  record_tape: bool = True,
5678
5812
  record_cmd: bool = False,
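
For the module-level launch(), the only semantic change in these hunks is skipping the capture lookup when no captures exist; the signature cleanup is mechanical. A plain usage sketch with an explicit stream (names are illustrative):

import warp as wp

@wp.kernel
def saxpy(x: wp.array(dtype=float), y: wp.array(dtype=float), a: float):
    i = wp.tid()
    y[i] = a * x[i] + y[i]

device = "cuda:0"
x = wp.ones(1 << 20, dtype=float, device=device)
y = wp.zeros_like(x)

s = wp.Stream(device)
wp.launch(saxpy, dim=x.shape[0], inputs=[x, y, 2.0], stream=s)   # dim as an int
wp.launch(saxpy, dim=x.shape, inputs=[x, y, 2.0], stream=s)      # or as a sequence
wp.synchronize_stream(s)
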
@@ -5824,7 +5958,7 @@ def launch(
5824
5958
 
5825
5959
  # If the stream is capturing, we retain the CUDA module so that it doesn't get unloaded
5826
5960
  # before the captured graph is released.
5827
- if runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
5961
+ if len(runtime.captures) > 0 and runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
5828
5962
  capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
5829
5963
  graph = runtime.captures.get(capture_id)
5830
5964
  if graph is not None:
@@ -5968,7 +6102,7 @@ def launch_tiled(*args, **kwargs):
5968
6102
  raise RuntimeError("wp.launch_tiled() requires a grid with fewer than 4 dimensions")
5969
6103
 
5970
6104
  # add trailing dimension
5971
- kwargs["dim"] = dim + [kwargs["block_dim"]]
6105
+ kwargs["dim"] = [*dim, kwargs["block_dim"]]
5972
6106
 
5973
6107
  # forward to original launch method
5974
6108
  return launch(*args, **kwargs)
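
launch_tiled() now builds the dimension list by unpacking instead of list concatenation; the behavior (appending block_dim as a trailing dimension) is the same. A sketch of typical use with the tile builtins, assuming the shape=/offset= keyword form of wp.tile_load/wp.tile_store and that wp.tid() indexes the logical grid as in the bundled tile examples (sizes are illustrative):

import warp as wp

TILE = wp.constant(256)
BLOCK = 64

@wp.kernel
def partial_sums(a: wp.array(dtype=float), out: wp.array(dtype=float)):
    i = wp.tid()                                      # logical block (tile) index
    t = wp.tile_load(a, shape=TILE, offset=i * TILE)  # cooperative load by the whole block
    s = wp.tile_sum(t)                                # block-wide reduction
    wp.tile_store(out, s, offset=i)

n_tiles = 16
a = wp.ones(n_tiles * 256, dtype=float, device="cuda:0")
out = wp.zeros(n_tiles, dtype=float, device="cuda:0")

# dim counts logical tiles; block_dim is appended internally as the trailing dimension
wp.launch_tiled(partial_sums, dim=[n_tiles], inputs=[a, out], block_dim=BLOCK)
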
@@ -6016,7 +6150,7 @@ def synchronize_device(device: Devicelike = None):
6016
6150
  runtime.core.cuda_context_synchronize(device.context)
6017
6151
 
6018
6152
 
6019
- def synchronize_stream(stream_or_device: Union[Stream, Devicelike, None] = None):
6153
+ def synchronize_stream(stream_or_device: Stream | Devicelike | None = None):
6020
6154
  """Synchronize the calling CPU thread with any outstanding CUDA work on the specified stream.
6021
6155
 
6022
6156
  This function allows the host application code to ensure that all kernel launches
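
synchronize_stream() gains the same union-style annotation; it still accepts a stream, a device, or nothing. A short sketch of stream- and event-level synchronization (the device string is illustrative):

import warp as wp

device = "cuda:0"
x = wp.zeros(1024, dtype=float, device=device)

# wait only for work queued on one stream rather than the whole device
wp.synchronize_stream(wp.get_stream(device))

# or record an event on the current stream and wait on it from the host
with wp.ScopedDevice(device):
    e = wp.Event()
    wp.record_event(e)
    wp.synchronize_event(e)
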
@@ -6046,7 +6180,7 @@ def synchronize_event(event: Event):
6046
6180
  runtime.core.cuda_event_synchronize(event.cuda_event)
6047
6181
 
6048
6182
 
6049
- def force_load(device: Union[Device, str, List[Device], List[str]] = None, modules: List[Module] = None):
6183
+ def force_load(device: Device | str | list[Device] | list[str] | None = None, modules: list[Module] | None = None):
6050
6184
  """Force user-defined kernels to be compiled and loaded
6051
6185
 
6052
6186
  Args:
@@ -6078,7 +6212,7 @@ def force_load(device: Union[Device, str, List[Device], List[str]] = None, modul
6078
6212
 
6079
6213
 
6080
6214
  def load_module(
6081
- module: Union[Module, types.ModuleType, str] = None, device: Union[Device, str] = None, recursive: bool = False
6215
+ module: Module | types.ModuleType | str | None = None, device: Device | str | None = None, recursive: bool = False
6082
6216
  ):
6083
6217
  """Force user-defined module to be compiled and loaded
6084
6218
 
@@ -6120,7 +6254,7 @@ def load_module(
6120
6254
  force_load(device=device, modules=modules)
6121
6255
 
6122
6256
 
6123
- def set_module_options(options: Dict[str, Any], module: Optional[Any] = None):
6257
+ def set_module_options(options: dict[str, Any], module: Any = None):
6124
6258
  """Set options for the current module.
6125
6259
 
6126
6260
  Options can be used to control runtime compilation and code-generation
@@ -6144,7 +6278,7 @@ def set_module_options(options: Dict[str, Any], module: Optional[Any] = None):
6144
6278
  get_module(m.__name__).mark_modified()
6145
6279
 
6146
6280
 
6147
- def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
6281
+ def get_module_options(module: Any = None) -> dict[str, Any]:
6148
6282
  """Returns a list of options for the current module."""
6149
6283
  if module is None:
6150
6284
  m = inspect.getmodule(inspect.stack()[1][0])
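
force_load(), load_module(), and the module-option helpers likewise move to built-in generics and "| None" defaults. A usage sketch; the option keys shown are the commonly documented ones and should be read as examples rather than an exhaustive list:

import warp as wp

# compile and load all registered modules on the listed devices up front
wp.force_load(device=["cuda:0", "cpu"])

# or just the calling module (recursive=True also loads imported modules)
wp.load_module(device="cuda:0")

# per-module code-generation options
wp.set_module_options({"enable_backward": False, "fast_math": True})
assert wp.get_module_options()["enable_backward"] is False
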
@@ -6154,10 +6288,44 @@ def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
6154
6288
  return get_module(m.__name__).options
6155
6289
 
6156
6290
 
6291
+ def _unregister_capture(device: Device, stream: Stream, graph: Graph):
6292
+ """Unregister a graph capture from the device and runtime.
6293
+
6294
+ This should be called when a graph capture is no longer active, either because it completed or was paused.
6295
+ The graph should only be registered while it is actively capturing.
6296
+
6297
+ Args:
6298
+ device: The CUDA device the graph was being captured on
6299
+ stream: The CUDA stream the graph was being captured on
6300
+ graph: The Graph object that was being captured
6301
+ """
6302
+ del device.captures[stream]
6303
+ del runtime.captures[graph.capture_id]
6304
+
6305
+
6306
+ def _register_capture(device: Device, stream: Stream, graph: Graph, capture_id: int):
6307
+ """Register a graph capture with the device and runtime.
6308
+
6309
+ Makes the graph discoverable through its capture_id so that retain_module_exec() can be called
6310
+ when launching kernels during graph capture. This ensures modules are retained until graph execution completes.
6311
+
6312
+ Args:
6313
+ device: The CUDA device the graph is being captured on
6314
+ stream: The CUDA stream the graph is being captured on
6315
+ graph: The Graph object being captured
6316
+ capture_id: Unique identifier for this graph capture
6317
+ """
6318
+ # add to ongoing captures on the device
6319
+ device.captures[stream] = graph
6320
+
6321
+ # add to lookup table by globally unique capture id
6322
+ runtime.captures[capture_id] = graph
6323
+
6324
+
6157
6325
  def capture_begin(
6158
6326
  device: Devicelike = None,
6159
- stream: Optional[Stream] = None,
6160
- force_module_load: Optional[bool] = None,
6327
+ stream: Stream | None = None,
6328
+ force_module_load: bool | None = None,
6161
6329
  external: bool = False,
6162
6330
  ):
6163
6331
  """Begin capture of a CUDA graph
@@ -6219,14 +6387,10 @@ def capture_begin(
6219
6387
  capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
6220
6388
  graph = Graph(device, capture_id)
6221
6389
 
6222
- # add to ongoing captures on the device
6223
- device.captures[stream] = graph
6224
-
6225
- # add to lookup table by globally unique capture id
6226
- runtime.captures[capture_id] = graph
6390
+ _register_capture(device, stream, graph, capture_id)
6227
6391
 
6228
6392
 
6229
- def capture_end(device: Devicelike = None, stream: Optional[Stream] = None) -> Graph:
6393
+ def capture_end(device: Devicelike = None, stream: Stream | None = None) -> Graph:
6230
6394
  """End the capture of a CUDA graph.
6231
6395
 
6232
6396
  Args:
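
capture_begin()/capture_end() now route their bookkeeping through the _register_capture/_unregister_capture helpers added above; the public workflow is unchanged. A standard capture/replay sketch (kernel and sizes are illustrative):

import warp as wp

@wp.kernel
def inc(x: wp.array(dtype=float)):
    i = wp.tid()
    x[i] = x[i] + 1.0

device = "cuda:0"
x = wp.zeros(1024, dtype=float, device=device)

wp.capture_begin(device)
try:
    for _ in range(10):
        wp.launch(inc, dim=x.shape, inputs=[x], device=device)
finally:
    graph = wp.capture_end(device)

wp.capture_launch(graph)        # the executable is now instantiated lazily on first launch
wp.synchronize_device(device)
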
@@ -6251,24 +6415,361 @@ def capture_end(device: Devicelike = None, stream: Optional[Stream] = None) -> G
6251
6415
  if graph is None:
6252
6416
  raise RuntimeError("Graph capture is not active on this stream")
6253
6417
 
6254
- del device.captures[stream]
6255
- del runtime.captures[graph.capture_id]
6418
+ _unregister_capture(device, stream, graph)
6256
6419
 
6257
6420
  # get the graph executable
6258
- graph_exec = ctypes.c_void_p()
6259
- result = runtime.core.cuda_graph_end_capture(device.context, stream.cuda_stream, ctypes.byref(graph_exec))
6421
+ g = ctypes.c_void_p()
6422
+ result = runtime.core.cuda_graph_end_capture(device.context, stream.cuda_stream, ctypes.byref(g))
6260
6423
 
6261
6424
  if not result:
6262
6425
  # A concrete error should've already been reported, so we don't need to go into details here
6263
6426
  raise RuntimeError(f"CUDA graph capture failed. {runtime.get_error_string()}")
6264
6427
 
6265
6428
  # set the graph executable
6266
- graph.graph_exec = graph_exec
6429
+ graph.graph = g
6430
+ graph.graph_exec = None # Lazy initialization
6431
+
6432
+ return graph
6433
+
6434
+
6435
+ def capture_debug_dot_print(graph: Graph, path: str, verbose: bool = False):
6436
+ """Export a CUDA graph to a DOT file for visualization
6437
+
6438
+ Args:
6439
+ graph: A :class:`Graph` as returned by :func:`~warp.capture_end()`
6440
+ path: Path to save the DOT file
6441
+ verbose: Whether to include additional debug information in the output
6442
+ """
6443
+ if not runtime.core.capture_debug_dot_print(graph.graph, path.encode(), 0 if verbose else 1):
6444
+ raise RuntimeError(f"Graph debug dot print error: {runtime.get_error_string()}")
6445
+
6446
+
6447
+ def assert_conditional_graph_support():
6448
+ if runtime is None:
6449
+ init()
6450
+
6451
+ if runtime.toolkit_version < (12, 4):
6452
+ raise RuntimeError("Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes")
6453
+
6454
+ if runtime.driver_version < (12, 4):
6455
+ raise RuntimeError("Conditional graph nodes require CUDA driver 12.4+")
6456
+
6457
+
6458
+ def capture_pause(device: Devicelike = None, stream: Stream | None = None) -> Graph:
6459
+ if stream is not None:
6460
+ device = stream.device
6461
+ else:
6462
+ device = runtime.get_device(device)
6463
+ if not device.is_cuda:
6464
+ raise RuntimeError("Must be a CUDA device")
6465
+ stream = device.stream
6466
+
6467
+ # get the graph being captured
6468
+ graph = device.captures.get(stream)
6469
+
6470
+ if graph is None:
6471
+ raise RuntimeError("Graph capture is not active on this stream")
6472
+
6473
+ _unregister_capture(device, stream, graph)
6474
+
6475
+ g = ctypes.c_void_p()
6476
+ if not runtime.core.cuda_graph_pause_capture(device.context, stream.cuda_stream, ctypes.byref(g)):
6477
+ raise RuntimeError(runtime.get_error_string())
6478
+
6479
+ graph.graph = g
6267
6480
 
6268
6481
  return graph
6269
6482
 
6270
6483
 
6271
- def capture_launch(graph: Graph, stream: Optional[Stream] = None):
6484
+ def capture_resume(graph: Graph, device: Devicelike = None, stream: Stream | None = None):
6485
+ if stream is not None:
6486
+ device = stream.device
6487
+ else:
6488
+ device = runtime.get_device(device)
6489
+ if not device.is_cuda:
6490
+ raise RuntimeError("Must be a CUDA device")
6491
+ stream = device.stream
6492
+
6493
+ if not runtime.core.cuda_graph_resume_capture(device.context, stream.cuda_stream, graph.graph):
6494
+ raise RuntimeError(runtime.get_error_string())
6495
+
6496
+ capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
6497
+ graph.capture_id = capture_id
6498
+
6499
+ _register_capture(device, stream, graph, capture_id)
6500
+
6501
+
6502
+ # reusable pinned readback buffer for conditions
6503
+ condition_host = None
6504
+
6505
+
6506
+ def capture_if(
6507
+ condition: warp.array(dtype=int),
6508
+ on_true: Callable | Graph | None = None,
6509
+ on_false: Callable | Graph | None = None,
6510
+ stream: Stream = None,
6511
+ **kwargs,
6512
+ ):
6513
+ """Create a dynamic branch based on a condition.
6514
+
6515
+ The condition value is retrieved from the first element of the ``condition`` array.
6516
+
6517
+ This function is particularly useful with CUDA graphs, but can be used without graph capture as well.
6518
+ CUDA 12.4+ is required to take advantage of conditional graph nodes for dynamic control flow.
6519
+
6520
+ Args:
6521
+ condition: Warp array holding the condition value.
6522
+ on_true: A callback function or :class:`Graph` to execute if the condition is True.
6523
+ on_false: A callback function or :class:`Graph` to execute if the condition is False.
6524
+ stream: The CUDA stream where the condition was written. If None, use the current stream on the device where ``condition`` resides.
6525
+
6526
+ Any additional keyword arguments are forwarded to the callback functions.
6527
+ """
6528
+
6529
+ # if neither the IF branch nor the ELSE branch is specified, it's a no-op
6530
+ if on_true is None and on_false is None:
6531
+ return
6532
+
6533
+ # check condition data type
6534
+ if not isinstance(condition, warp.array) or condition.dtype is not warp.int32:
6535
+ raise TypeError("Condition must be a Warp array of int32 with a single element")
6536
+
6537
+ device = condition.device
6538
+
6539
+ # determine the stream and whether a graph capture is active
6540
+ if device.is_cuda:
6541
+ if stream is None:
6542
+ stream = device.stream
6543
+ graph = device.captures.get(stream)
6544
+ else:
6545
+ graph = None
6546
+
6547
+ if graph is None:
6548
+ # if no graph is active, just execute the correct branch directly
6549
+ if device.is_cuda:
6550
+ # use a pinned buffer for condition readback to host
6551
+ global condition_host
6552
+ if condition_host is None:
6553
+ condition_host = warp.empty(1, dtype=int, device="cpu", pinned=True)
6554
+ warp.copy(condition_host, condition, stream=stream)
6555
+ warp.synchronize_stream(stream)
6556
+ condition_value = bool(ctypes.cast(condition_host.ptr, ctypes.POINTER(ctypes.c_int32)).contents)
6557
+ else:
6558
+ condition_value = bool(ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)).contents)
6559
+
6560
+ if condition_value:
6561
+ if on_true is not None:
6562
+ if isinstance(on_true, Callable):
6563
+ on_true(**kwargs)
6564
+ elif isinstance(on_true, Graph):
6565
+ capture_launch(on_true, stream=stream)
6566
+ else:
6567
+ raise TypeError("on_true must be a Callable or a Graph")
6568
+ else:
6569
+ if on_false is not None:
6570
+ if isinstance(on_false, Callable):
6571
+ on_false(**kwargs)
6572
+ elif isinstance(on_false, Graph):
6573
+ capture_launch(on_false, stream=stream)
6574
+ else:
6575
+ raise TypeError("on_false must be a Callable or a Graph")
6576
+
6577
+ return
6578
+
6579
+ graph.has_conditional = True
6580
+
6581
+ # ensure conditional graph nodes are supported
6582
+ assert_conditional_graph_support()
6583
+
6584
+ # insert conditional node
6585
+ graph_on_true = ctypes.c_void_p()
6586
+ graph_on_false = ctypes.c_void_p()
6587
+ if not runtime.core.cuda_graph_insert_if_else(
6588
+ device.context,
6589
+ stream.cuda_stream,
6590
+ ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
6591
+ None if on_true is None else ctypes.byref(graph_on_true),
6592
+ None if on_false is None else ctypes.byref(graph_on_false),
6593
+ ):
6594
+ raise RuntimeError(runtime.get_error_string())
6595
+
6596
+ # pause capturing parent graph
6597
+ main_graph = capture_pause(stream=stream)
6598
+ # store the pointer to the cuda graph to restore it later
6599
+ main_graph_ptr = main_graph.graph
6600
+
6601
+ # capture if-graph
6602
+ if on_true is not None:
6603
+ # temporarily repurpose the main_graph python object such that all dependencies
6604
+ # added through retain_module_exec() end up in the correct python graph object
6605
+ main_graph.graph = graph_on_true
6606
+ capture_resume(main_graph, stream=stream)
6607
+ if isinstance(on_true, Callable):
6608
+ on_true(**kwargs)
6609
+ elif isinstance(on_true, Graph):
6610
+ if on_true.has_conditional:
6611
+ raise RuntimeError(
6612
+ "The on_true graph contains conditional nodes, which are not allowed in child graphs"
6613
+ )
6614
+ if not runtime.core.cuda_graph_insert_child_graph(
6615
+ device.context,
6616
+ stream.cuda_stream,
6617
+ on_true.graph,
6618
+ ):
6619
+ raise RuntimeError(runtime.get_error_string())
6620
+ else:
6621
+ raise TypeError("on_true must be a Callable or a Graph")
6622
+ capture_pause(stream=stream)
6623
+
6624
+ # capture else-graph
6625
+ if on_false is not None:
6626
+ # temporarily repurpose the main_graph python object such that all dependencies
6627
+ # added through retain_module_exec() end up in the correct python graph object
6628
+ main_graph.graph = graph_on_false
6629
+ capture_resume(main_graph, stream=stream)
6630
+ if isinstance(on_false, Callable):
6631
+ on_false(**kwargs)
6632
+ elif isinstance(on_false, Graph):
6633
+ if on_false.has_conditional:
6634
+ raise RuntimeError(
6635
+ "The on_false graph contains conditional nodes, which are not allowed in child graphs"
6636
+ )
6637
+ if not runtime.core.cuda_graph_insert_child_graph(
6638
+ device.context,
6639
+ stream.cuda_stream,
6640
+ on_false.graph,
6641
+ ):
6642
+ raise RuntimeError(runtime.get_error_string())
6643
+ else:
6644
+ raise TypeError("on_false must be a Callable or a Graph")
6645
+ capture_pause(stream=stream)
6646
+
6647
+ # restore the main graph to its original state
6648
+ main_graph.graph = main_graph_ptr
6649
+
6650
+ # resume capturing parent graph
6651
+ capture_resume(main_graph, stream=stream)
6652
+
6653
+
6654
+ def capture_while(condition: warp.array(dtype=int), while_body: Callable | Graph, stream: Stream = None, **kwargs):
6655
+ """Create a dynamic loop based on a condition.
6656
+
6657
+ The condition value is retrieved from the first element of the ``condition`` array.
6658
+
6659
+ The ``while_body`` callback is responsible for updating the condition value so the loop can terminate.
6660
+
6661
+ This function is particularly useful with CUDA graphs, but can be used without graph capture as well.
6662
+ CUDA 12.4+ is required to take advantage of conditional graph nodes for dynamic control flow.
6663
+
6664
+ Args:
6665
+ condition: Warp array holding the condition value.
6666
+ while_body: A callback function or :class:`Graph` to execute while the loop condition is True.
6667
+ stream: The CUDA stream where the condition was written. If None, use the current stream on the device where ``condition`` resides.
6668
+
6669
+ Any additional keyword arguments are forwarded to the callback function.
6670
+ """
6671
+
6672
+ # check condition data type
6673
+ if not isinstance(condition, warp.array) or condition.dtype is not warp.int32:
6674
+ raise TypeError("Condition must be a Warp array of int32 with a single element")
6675
+
6676
+ device = condition.device
6677
+
6678
+ # determine the stream and whether a graph capture is active
6679
+ if device.is_cuda:
6680
+ if stream is None:
6681
+ stream = device.stream
6682
+ graph = device.captures.get(stream)
6683
+ else:
6684
+ graph = None
6685
+
6686
+ if graph is None:
6687
+ # since no graph is active, just execute the kernels directly
6688
+ while True:
6689
+ if device.is_cuda:
6690
+ # use a pinned buffer for condition readback to host
6691
+ global condition_host
6692
+ if condition_host is None:
6693
+ condition_host = warp.empty(1, dtype=int, device="cpu", pinned=True)
6694
+ warp.copy(condition_host, condition, stream=stream)
6695
+ warp.synchronize_stream(stream)
6696
+ condition_value = bool(ctypes.cast(condition_host.ptr, ctypes.POINTER(ctypes.c_int32)).contents)
6697
+ else:
6698
+ condition_value = bool(ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)).contents)
6699
+
6700
+ if condition_value:
6701
+ if isinstance(while_body, Callable):
6702
+ while_body(**kwargs)
6703
+ elif isinstance(while_body, Graph):
6704
+ capture_launch(while_body, stream=stream)
6705
+ else:
6706
+ raise TypeError("while_body must be a callable or a graph")
6707
+
6708
+ else:
6709
+ break
6710
+
6711
+ return
6712
+
6713
+ graph.has_conditional = True
6714
+
6715
+ # ensure conditional graph nodes are supported
6716
+ assert_conditional_graph_support()
6717
+
6718
+ # insert conditional while-node
6719
+ body_graph = ctypes.c_void_p()
6720
+ cond_handle = ctypes.c_uint64()
6721
+ if not runtime.core.cuda_graph_insert_while(
6722
+ device.context,
6723
+ stream.cuda_stream,
6724
+ ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
6725
+ ctypes.byref(body_graph),
6726
+ ctypes.byref(cond_handle),
6727
+ ):
6728
+ raise RuntimeError(runtime.get_error_string())
6729
+
6730
+ # pause capturing parent graph and start capturing child graph
6731
+ main_graph = capture_pause(stream=stream)
6732
+ # store the pointer to the cuda graph to restore it later
6733
+ main_graph_ptr = main_graph.graph
6734
+
6735
+ # temporarily repurpose the main_graph python object such that all dependencies
6736
+ # added through retain_module_exec() end up in the correct python graph object
6737
+ main_graph.graph = body_graph
6738
+ capture_resume(main_graph, stream=stream)
6739
+
6740
+ # capture while-body
6741
+ if isinstance(while_body, Callable):
6742
+ while_body(**kwargs)
6743
+ elif isinstance(while_body, Graph):
6744
+ if while_body.has_conditional:
6745
+ raise RuntimeError("The body graph contains conditional nodes, which are not allowed in child graphs")
6746
+
6747
+ if not runtime.core.cuda_graph_insert_child_graph(
6748
+ device.context,
6749
+ stream.cuda_stream,
6750
+ while_body.graph,
6751
+ ):
6752
+ raise RuntimeError(runtime.get_error_string())
6753
+ else:
6754
+ raise RuntimeError(runtime.get_error_string())
6755
+
6756
+ # update condition
6757
+ if not runtime.core.cuda_graph_set_condition(
6758
+ device.context,
6759
+ stream.cuda_stream,
6760
+ ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
6761
+ cond_handle,
6762
+ ):
6763
+ raise RuntimeError(runtime.get_error_string())
6764
+
6765
+ # stop capturing child graph and resume capturing parent graph
6766
+ capture_pause(stream=stream)
6767
+ # restore the main graph to its original state
6768
+ main_graph.graph = main_graph_ptr
6769
+ capture_resume(main_graph, stream=stream)
6770
+
6771
+
6772
+ def capture_launch(graph: Graph, stream: Stream | None = None):
6272
6773
  """Launch a previously captured CUDA graph
6273
6774
 
6274
6775
  Args:
@@ -6284,6 +6785,15 @@ def capture_launch(graph: Graph, stream: Optional[Stream] = None):
6284
6785
  device = graph.device
6285
6786
  stream = device.stream
6286
6787
 
6788
+ if graph.graph_exec is None:
6789
+ g = ctypes.c_void_p()
6790
+ result = runtime.core.cuda_graph_create_exec(
6791
+ graph.device.context, stream.cuda_stream, graph.graph, ctypes.byref(g)
6792
+ )
6793
+ if not result:
6794
+ raise RuntimeError(f"Graph creation error: {runtime.get_error_string()}")
6795
+ graph.graph_exec = g
6796
+
6287
6797
  if not runtime.core.cuda_graph_launch(graph.graph_exec, stream.cuda_stream):
6288
6798
  raise RuntimeError(f"Graph launch error: {runtime.get_error_string()}")
6289
6799
 
@@ -6294,7 +6804,7 @@ def copy(
6294
6804
  dest_offset: int = 0,
6295
6805
  src_offset: int = 0,
6296
6806
  count: int = 0,
6297
- stream: Optional[Stream] = None,
6807
+ stream: Stream | None = None,
6298
6808
  ):
6299
6809
  """Copy array contents from `src` to `dest`.
6300
6810
 
@@ -6431,11 +6941,8 @@ def copy(
6431
6941
 
6432
6942
  # can't copy to/from fabric arrays of arrays, because they are jagged arrays of arbitrary lengths
6433
6943
  # TODO?
6434
- if (
6435
- isinstance(src, (warp.fabricarray, warp.indexedfabricarray))
6436
- and src.ndim > 1
6437
- or isinstance(dest, (warp.fabricarray, warp.indexedfabricarray))
6438
- and dest.ndim > 1
6944
+ if (isinstance(src, (warp.fabricarray, warp.indexedfabricarray)) and src.ndim > 1) or (
6945
+ isinstance(dest, (warp.fabricarray, warp.indexedfabricarray)) and dest.ndim > 1
6439
6946
  ):
6440
6947
  raise RuntimeError("Copying to/from Fabric arrays of arrays is not supported")
6441
6948
 
@@ -6503,7 +7010,7 @@ def type_str(t):
6503
7010
  return "Callable"
6504
7011
  elif isinstance(t, int):
6505
7012
  return str(t)
6506
- elif isinstance(t, List):
7013
+ elif isinstance(t, (List, tuple)):
6507
7014
  return "Tuple[" + ", ".join(map(type_str, t)) + "]"
6508
7015
  elif isinstance(t, warp.array):
6509
7016
  return f"Array[{type_str(t.dtype)}]"
@@ -6536,12 +7043,16 @@ def type_str(t):
6536
7043
 
6537
7044
  raise TypeError("Invalid vector or matrix dimensions")
6538
7045
  elif get_origin(t) in (list, tuple):
6539
- args_repr = ", ".join(type_str(x) for x in get_args(t))
6540
- return f"{t._name}[{args_repr}]"
7046
+ args = get_args(t)
7047
+ if args:
7048
+ args_repr = ", ".join(type_str(x) for x in get_args(t))
7049
+ return f"{t._name}[{args_repr}]"
7050
+ else:
7051
+ return f"{t._name}"
6541
7052
  elif t is Ellipsis:
6542
7053
  return "..."
6543
7054
  elif warp.types.is_tile(t):
6544
- return "Tile"
7055
+ return f"Tile[{type_str(t.dtype)},{type_str(t.shape)}]"
6545
7056
 
6546
7057
  return t.__name__
6547
7058
 
@@ -6568,14 +7079,14 @@ def resolve_exported_function_sig(f):
6568
7079
  # so we can generate the return type for overloaded functions
6569
7080
  return_type = f.value_func(func_args, None)
6570
7081
 
7082
+ if return_type is None or (isinstance(return_type, tuple) and len(return_type) > 1):
7083
+ return (func_args, return_type)
7084
+
6571
7085
  try:
6572
- return_type_str = ctype_ret_str(return_type)
7086
+ ctype_ret_str(return_type)
6573
7087
  except Exception:
6574
7088
  return None
6575
7089
 
6576
- if return_type_str.startswith("Tuple"):
6577
- return None
6578
-
6579
7090
  return (func_args, return_type)
6580
7091
 
6581
7092
 
@@ -6716,13 +7227,18 @@ def export_functions_rst(file): # pragma: no cover
6716
7227
  print("---------------", file=file)
6717
7228
 
6718
7229
  for f, is_exported in g:
7230
+ if not isinstance(f, Function) and callable(f):
7231
+ # f is a plain Python function
7232
+ print(f".. autofunction:: {f.__module__}.{f.__name__}", file=file)
7233
+ continue
6719
7234
  if f.func:
6720
7235
  # f is a Warp function written in Python, we can use autofunction
6721
7236
  print(f".. autofunction:: {f.func.__module__}.{f.key}", file=file)
6722
7237
  continue
6723
7238
  for f_prefix, query_type in query_types:
6724
7239
  if f.key.startswith(f_prefix) and query_type not in written_query_types:
6725
- print(f".. autoclass:: {query_type}", file=file)
7240
+ print(f".. autoclass:: warp.{query_type}", file=file)
7241
+ print(" :exclude-members: Var, vars", file=file)
6726
7242
  written_query_types.add(query_type)
6727
7243
  break
6728
7244
 
@@ -6775,6 +7291,7 @@ def export_stubs(file): # pragma: no cover
6775
7291
  print('Rows = TypeVar("Rows", bound=int)', file=file)
6776
7292
  print('Cols = TypeVar("Cols", bound=int)', file=file)
6777
7293
  print('DType = TypeVar("DType")', file=file)
7294
+ print('Shape = TypeVar("Shape")', file=file)
6778
7295
 
6779
7296
  print("Vector = Generic[Length, Scalar]", file=file)
6780
7297
  print("Matrix = Generic[Rows, Cols, Scalar]", file=file)
@@ -6783,6 +7300,7 @@ def export_stubs(file): # pragma: no cover
6783
7300
  print("Array = Generic[DType]", file=file)
6784
7301
  print("FabricArray = Generic[DType]", file=file)
6785
7302
  print("IndexedFabricArray = Generic[DType]", file=file)
7303
+ print("Tile = Generic[DType, Shape]", file=file)
6786
7304
 
6787
7305
  # prepend __init__.py
6788
7306
  with open(os.path.join(os.path.dirname(file.name), "__init__.py")) as header_file:
@@ -6817,7 +7335,7 @@ def export_stubs(file): # pragma: no cover
6817
7335
  if hasattr(g, "overloads"):
6818
7336
  for f in g.overloads:
6819
7337
  add_stub(f)
6820
- else:
7338
+ elif isinstance(g, Function):
6821
7339
  add_stub(g)
6822
7340
 
6823
7341
 
@@ -6848,16 +7366,30 @@ def export_builtins(file: io.TextIOBase): # pragma: no cover
6848
7366
  args = ", ".join(f"{ctype_arg_str(v)} {k}" for k, v in func_args.items())
6849
7367
  params = ", ".join(func_args.keys())
6850
7368
 
6851
- return_str = ctype_ret_str(return_type)
6852
-
6853
- if args == "":
6854
- file.write(f"WP_API void {f.mangled_name}({return_str}* ret) {{ *ret = wp::{f.key}({params}); }}\n")
6855
- elif return_type is None:
7369
+ if return_type is None:
7370
+ # void function
6856
7371
  file.write(f"WP_API void {f.mangled_name}({args}) {{ wp::{f.key}({params}); }}\n")
7372
+ elif isinstance(return_type, tuple) and len(return_type) > 1:
7373
+ # multiple return value function using output parameters
7374
+ outputs = tuple(f"{ctype_ret_str(x)}& ret_{i}" for i, x in enumerate(return_type))
7375
+ output_params = ", ".join(f"ret_{i}" for i in range(len(outputs)))
7376
+ if args:
7377
+ file.write(
7378
+ f"WP_API void {f.mangled_name}({args}, {', '.join(outputs)}) {{ wp::{f.key}({params}, {output_params}); }}\n"
7379
+ )
7380
+ else:
7381
+ file.write(
7382
+ f"WP_API void {f.mangled_name}({', '.join(outputs)}) {{ wp::{f.key}({params}, {output_params}); }}\n"
7383
+ )
6857
7384
  else:
6858
- file.write(
6859
- f"WP_API void {f.mangled_name}({args}, {return_str}* ret) {{ *ret = wp::{f.key}({params}); }}\n"
6860
- )
7385
+ # single return value function
7386
+ return_str = ctype_ret_str(return_type)
7387
+ if args:
7388
+ file.write(
7389
+ f"WP_API void {f.mangled_name}({args}, {return_str}* ret) {{ *ret = wp::{f.key}({params}); }}\n"
7390
+ )
7391
+ else:
7392
+ file.write(f"WP_API void {f.mangled_name}({return_str}* ret) {{ *ret = wp::{f.key}({params}); }}\n")
6861
7393
 
6862
7394
  file.write('\n} // extern "C"\n\n')
6863
7395
  file.write("} // namespace wp\n")