warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (191)
  1. warp/__init__.py +7 -1
  2. warp/autograd.py +12 -2
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +410 -0
  6. warp/build_dll.py +6 -14
  7. warp/builtins.py +463 -372
  8. warp/codegen.py +196 -124
  9. warp/config.py +42 -6
  10. warp/context.py +496 -271
  11. warp/dlpack.py +8 -6
  12. warp/examples/assets/nonuniform.usd +0 -0
  13. warp/examples/assets/nvidia_logo.png +0 -0
  14. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  15. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  16. warp/examples/core/example_sample_mesh.py +300 -0
  17. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  18. warp/examples/fem/example_apic_fluid.py +1 -1
  19. warp/examples/fem/example_burgers.py +2 -2
  20. warp/examples/fem/example_deformed_geometry.py +1 -1
  21. warp/examples/fem/example_distortion_energy.py +1 -1
  22. warp/examples/fem/example_magnetostatics.py +6 -6
  23. warp/examples/fem/utils.py +9 -3
  24. warp/examples/interop/example_jax_callable.py +116 -0
  25. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  26. warp/examples/interop/example_jax_kernel.py +205 -0
  27. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  28. warp/examples/tile/example_tile_matmul.py +2 -4
  29. warp/fem/__init__.py +11 -1
  30. warp/fem/adaptivity.py +4 -4
  31. warp/fem/field/field.py +11 -1
  32. warp/fem/field/nodal_field.py +56 -88
  33. warp/fem/field/virtual.py +62 -23
  34. warp/fem/geometry/adaptive_nanogrid.py +16 -13
  35. warp/fem/geometry/closest_point.py +1 -1
  36. warp/fem/geometry/deformed_geometry.py +5 -2
  37. warp/fem/geometry/geometry.py +5 -0
  38. warp/fem/geometry/grid_2d.py +12 -12
  39. warp/fem/geometry/grid_3d.py +12 -15
  40. warp/fem/geometry/hexmesh.py +5 -7
  41. warp/fem/geometry/nanogrid.py +9 -11
  42. warp/fem/geometry/quadmesh.py +13 -13
  43. warp/fem/geometry/tetmesh.py +3 -4
  44. warp/fem/geometry/trimesh.py +7 -20
  45. warp/fem/integrate.py +262 -93
  46. warp/fem/linalg.py +5 -5
  47. warp/fem/quadrature/pic_quadrature.py +37 -22
  48. warp/fem/quadrature/quadrature.py +194 -25
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_function_space.py +4 -2
  51. warp/fem/space/basis_space.py +25 -18
  52. warp/fem/space/hexmesh_function_space.py +2 -2
  53. warp/fem/space/partition.py +6 -2
  54. warp/fem/space/quadmesh_function_space.py +8 -8
  55. warp/fem/space/shape/cube_shape_function.py +23 -23
  56. warp/fem/space/shape/square_shape_function.py +12 -12
  57. warp/fem/space/shape/triangle_shape_function.py +1 -1
  58. warp/fem/space/tetmesh_function_space.py +3 -3
  59. warp/fem/space/trimesh_function_space.py +2 -2
  60. warp/fem/utils.py +12 -6
  61. warp/jax.py +14 -1
  62. warp/jax_experimental/__init__.py +16 -0
  63. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
  64. warp/jax_experimental/ffi.py +702 -0
  65. warp/jax_experimental/xla_ffi.py +602 -0
  66. warp/math.py +89 -0
  67. warp/native/array.h +13 -0
  68. warp/native/builtin.h +29 -3
  69. warp/native/bvh.cpp +3 -1
  70. warp/native/bvh.cu +42 -14
  71. warp/native/bvh.h +2 -1
  72. warp/native/clang/clang.cpp +30 -3
  73. warp/native/cuda_util.cpp +14 -0
  74. warp/native/cuda_util.h +2 -0
  75. warp/native/exports.h +68 -63
  76. warp/native/intersect.h +26 -26
  77. warp/native/intersect_adj.h +33 -33
  78. warp/native/marching.cu +1 -1
  79. warp/native/mat.h +513 -9
  80. warp/native/mesh.h +10 -10
  81. warp/native/quat.h +99 -11
  82. warp/native/rand.h +6 -0
  83. warp/native/sort.cpp +122 -59
  84. warp/native/sort.cu +152 -15
  85. warp/native/sort.h +8 -1
  86. warp/native/sparse.cpp +43 -22
  87. warp/native/sparse.cu +52 -17
  88. warp/native/svd.h +116 -0
  89. warp/native/tile.h +312 -116
  90. warp/native/tile_reduce.h +46 -3
  91. warp/native/vec.h +68 -7
  92. warp/native/volume.cpp +85 -113
  93. warp/native/volume_builder.cu +25 -10
  94. warp/native/volume_builder.h +6 -0
  95. warp/native/warp.cpp +5 -6
  96. warp/native/warp.cu +100 -11
  97. warp/native/warp.h +19 -10
  98. warp/optim/linear.py +10 -10
  99. warp/render/render_opengl.py +19 -17
  100. warp/render/render_usd.py +93 -3
  101. warp/sim/articulation.py +4 -4
  102. warp/sim/collide.py +32 -19
  103. warp/sim/import_mjcf.py +449 -155
  104. warp/sim/import_urdf.py +32 -12
  105. warp/sim/inertia.py +189 -156
  106. warp/sim/integrator_euler.py +8 -5
  107. warp/sim/integrator_featherstone.py +3 -10
  108. warp/sim/integrator_vbd.py +207 -2
  109. warp/sim/integrator_xpbd.py +8 -5
  110. warp/sim/model.py +71 -25
  111. warp/sim/render.py +4 -0
  112. warp/sim/utils.py +2 -2
  113. warp/sparse.py +642 -555
  114. warp/stubs.py +217 -20
  115. warp/tests/__main__.py +0 -15
  116. warp/tests/assets/torus.usda +1 -1
  117. warp/tests/cuda/__init__.py +0 -0
  118. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  119. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  120. warp/tests/geometry/__init__.py +0 -0
  121. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  122. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  123. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  124. warp/tests/interop/__init__.py +0 -0
  125. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  126. warp/tests/sim/__init__.py +0 -0
  127. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  128. warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
  129. warp/tests/sim/test_inertia.py +161 -0
  130. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  131. warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
  132. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  133. warp/tests/sim/test_vbd.py +597 -0
  134. warp/tests/sim/test_xpbd.py +399 -0
  135. warp/tests/test_bool.py +1 -1
  136. warp/tests/test_codegen.py +24 -3
  137. warp/tests/test_examples.py +40 -38
  138. warp/tests/test_fem.py +98 -14
  139. warp/tests/test_linear_solvers.py +0 -11
  140. warp/tests/test_mat.py +577 -156
  141. warp/tests/test_mat_scalar_ops.py +4 -4
  142. warp/tests/test_overwrite.py +0 -60
  143. warp/tests/test_quat.py +356 -151
  144. warp/tests/test_rand.py +44 -37
  145. warp/tests/test_sparse.py +47 -6
  146. warp/tests/test_spatial.py +75 -0
  147. warp/tests/test_static.py +1 -1
  148. warp/tests/test_utils.py +84 -4
  149. warp/tests/test_vec.py +336 -178
  150. warp/tests/tile/__init__.py +0 -0
  151. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  152. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
  153. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  154. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  155. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  156. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  157. warp/tests/unittest_serial.py +1 -0
  158. warp/tests/unittest_suites.py +45 -62
  159. warp/tests/unittest_utils.py +2 -1
  160. warp/thirdparty/unittest_parallel.py +3 -1
  161. warp/types.py +175 -666
  162. warp/utils.py +137 -72
  163. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
  164. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
  165. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
  166. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
  167. warp/examples/optim/example_walker.py +0 -317
  168. warp/native/cutlass_gemm.cpp +0 -43
  169. warp/native/cutlass_gemm.cu +0 -382
  170. warp/tests/test_matmul.py +0 -511
  171. warp/tests/test_matmul_lite.py +0 -411
  172. warp/tests/test_vbd.py +0 -386
  173. warp/tests/unused_test_misc.py +0 -77
  174. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  175. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  176. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  177. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  178. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  179. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  180. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  181. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  182. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  183. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  184. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  185. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  186. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  187. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  188. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  189. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  190. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  191. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -17,7 +17,6 @@ from __future__ import annotations
 
 import ast
 import ctypes
-import errno
 import functools
 import hashlib
 import inspect
@@ -28,13 +27,27 @@ import operator
 import os
 import platform
 import sys
-import time
 import types
 import typing
 import weakref
 from copy import copy as shallowcopy
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+    get_args,
+    get_origin,
+)
 
 import numpy as np
 
@@ -42,7 +55,7 @@ import warp
 import warp.build
 import warp.codegen
 import warp.config
-from warp.types import launch_bounds_t
+from warp.types import Array, launch_bounds_t
 
 # represents either a built-in or user-defined function
 
@@ -71,10 +84,10 @@ def get_function_args(func):
 complex_type_hints = (Any, Callable, Tuple)
 sequence_types = (list, tuple)
 
-function_key_counts = {}
+function_key_counts: Dict[str, int] = {}
 
 
-def generate_unique_function_identifier(key):
+def generate_unique_function_identifier(key: str) -> str:
     # Generate unique identifiers for user-defined functions in native code.
     # - Prevents conflicts when a function is redefined and old versions are still in use.
     # - Prevents conflicts between multiple closures returned from the same function.
@@ -107,40 +120,40 @@ def generate_unique_function_identifier(key):
107
120
  class Function:
108
121
  def __init__(
109
122
  self,
110
- func,
111
- key,
112
- namespace,
113
- input_types=None,
114
- value_type=None,
115
- value_func=None,
116
- export_func=None,
117
- dispatch_func=None,
118
- lto_dispatch_func=None,
119
- module=None,
120
- variadic=False,
121
- initializer_list_func=None,
122
- export=False,
123
- doc="",
124
- group="",
125
- hidden=False,
126
- skip_replay=False,
127
- missing_grad=False,
128
- generic=False,
129
- native_func=None,
130
- defaults=None,
131
- custom_replay_func=None,
132
- native_snippet=None,
133
- adj_native_snippet=None,
134
- replay_snippet=None,
135
- skip_forward_codegen=False,
136
- skip_reverse_codegen=False,
137
- custom_reverse_num_input_args=-1,
138
- custom_reverse_mode=False,
139
- overloaded_annotations=None,
140
- code_transformers=None,
141
- skip_adding_overload=False,
142
- require_original_output_arg=False,
143
- scope_locals=None, # the locals() where the function is defined, used for overload management
123
+ func: Optional[Callable],
124
+ key: str,
125
+ namespace: str,
126
+ input_types: Optional[Dict[str, Union[type, TypeVar]]] = None,
127
+ value_type: Optional[type] = None,
128
+ value_func: Optional[Callable[[Mapping[str, type], Mapping[str, Any]], type]] = None,
129
+ export_func: Optional[Callable[[Dict[str, type]], Dict[str, type]]] = None,
130
+ dispatch_func: Optional[Callable] = None,
131
+ lto_dispatch_func: Optional[Callable] = None,
132
+ module: Optional[Module] = None,
133
+ variadic: bool = False,
134
+ initializer_list_func: Optional[Callable[[Dict[str, Any], type], bool]] = None,
135
+ export: bool = False,
136
+ doc: str = "",
137
+ group: str = "",
138
+ hidden: bool = False,
139
+ skip_replay: bool = False,
140
+ missing_grad: bool = False,
141
+ generic: bool = False,
142
+ native_func: Optional[str] = None,
143
+ defaults: Optional[Dict[str, Any]] = None,
144
+ custom_replay_func: Optional[Function] = None,
145
+ native_snippet: Optional[str] = None,
146
+ adj_native_snippet: Optional[str] = None,
147
+ replay_snippet: Optional[str] = None,
148
+ skip_forward_codegen: bool = False,
149
+ skip_reverse_codegen: bool = False,
150
+ custom_reverse_num_input_args: int = -1,
151
+ custom_reverse_mode: bool = False,
152
+ overloaded_annotations: Optional[Dict[str, type]] = None,
153
+ code_transformers: Optional[List[ast.NodeTransformer]] = None,
154
+ skip_adding_overload: bool = False,
155
+ require_original_output_arg: bool = False,
156
+ scope_locals: Optional[Dict[str, Any]] = None,
144
157
  ):
145
158
  if code_transformers is None:
146
159
  code_transformers = []
@@ -165,7 +178,7 @@ class Function:
165
178
  self.native_snippet = native_snippet
166
179
  self.adj_native_snippet = adj_native_snippet
167
180
  self.replay_snippet = replay_snippet
168
- self.custom_grad_func = None
181
+ self.custom_grad_func: Optional[Function] = None
169
182
  self.require_original_output_arg = require_original_output_arg
170
183
  self.generic_parent = None # generic function that was used to instantiate this overload
171
184
 
@@ -181,6 +194,7 @@ class Function:
181
194
  )
182
195
  self.missing_grad = missing_grad # whether builtin is missing a corresponding adjoint
183
196
  self.generic = generic
197
+ self.mangled_name: Optional[str] = None
184
198
 
185
199
  # allow registering functions with a different name in Python and native code
186
200
  if native_func is None:
@@ -197,8 +211,8 @@ class Function:
197
211
  # user-defined function
198
212
 
199
213
  # generic and concrete overload lookups by type signature
200
- self.user_templates = {}
201
- self.user_overloads = {}
214
+ self.user_templates: Dict[str, Function] = {}
215
+ self.user_overloads: Dict[str, Function] = {}
202
216
 
203
217
  # user defined (Python) function
204
218
  self.adj = warp.codegen.Adjoint(
@@ -229,19 +243,17 @@ class Function:
229
243
  # builtin function
230
244
 
231
245
  # embedded linked list of all overloads
232
- # the builtin_functions dictionary holds
233
- # the list head for a given key (func name)
234
- self.overloads = []
246
+ # the builtin_functions dictionary holds the list head for a given key (func name)
247
+ self.overloads: List[Function] = []
235
248
 
236
249
  # builtin (native) function, canonicalize argument types
237
- for k, v in input_types.items():
238
- self.input_types[k] = warp.types.type_to_warp(v)
250
+ if input_types is not None:
251
+ for k, v in input_types.items():
252
+ self.input_types[k] = warp.types.type_to_warp(v)
239
253
 
240
254
  # cache mangled name
241
255
  if self.export and self.is_simple():
242
256
  self.mangled_name = self.mangle()
243
- else:
244
- self.mangled_name = None
245
257
 
246
258
  if not skip_adding_overload:
247
259
  self.add_overload(self)
@@ -272,7 +284,7 @@ class Function:
272
284
  signature_params.append(param)
273
285
  self.signature = inspect.Signature(signature_params)
274
286
 
275
- # scope for resolving overloads
287
+ # scope for resolving overloads, the locals() where the function is defined
276
288
  if scope_locals is None:
277
289
  scope_locals = inspect.currentframe().f_back.f_locals
278
290
 
@@ -334,10 +346,10 @@ class Function:
334
346
  # this function has no overloads, call it like a plain Python function
335
347
  return self.func(*args, **kwargs)
336
348
 
337
- def is_builtin(self):
349
+ def is_builtin(self) -> bool:
338
350
  return self.func is None
339
351
 
340
- def is_simple(self):
352
+ def is_simple(self) -> bool:
341
353
  if self.variadic:
342
354
  return False
343
355
 
@@ -351,9 +363,8 @@ class Function:
351
363
 
352
364
  return True
353
365
 
354
- def mangle(self):
355
- # builds a mangled name for the C-exported
356
- # function, e.g.: builtin_normalize_vec3()
366
+ def mangle(self) -> str:
367
+ """Build a mangled name for the C-exported function, e.g.: `builtin_normalize_vec3()`."""
357
368
 
358
369
  name = "builtin_" + self.key
359
370
 
@@ -369,7 +380,7 @@ class Function:
369
380
 
370
381
  return "_".join([name, *types])
371
382
 
372
- def add_overload(self, f):
383
+ def add_overload(self, f: Function) -> None:
373
384
  if self.is_builtin():
374
385
  # todo: note that it is an error to add two functions
375
386
  # with the exact same signature as this would cause compile
@@ -384,7 +395,7 @@ class Function:
384
395
  else:
385
396
  # get function signature based on the input types
386
397
  sig = warp.types.get_signature(
387
- f.input_types.values(), func_name=f.key, arg_names=list(f.input_types.keys())
398
+ list(f.input_types.values()), func_name=f.key, arg_names=list(f.input_types.keys())
388
399
  )
389
400
 
390
401
  # check if generic
@@ -393,7 +404,7 @@ class Function:
393
404
  else:
394
405
  self.user_overloads[sig] = f
395
406
 
396
- def get_overload(self, arg_types, kwarg_types):
407
+ def get_overload(self, arg_types: List[type], kwarg_types: Mapping[str, type]) -> Optional[Function]:
397
408
  assert not self.is_builtin()
398
409
 
399
410
  for f in self.user_overloads.values():
@@ -446,7 +457,7 @@ class Function:
446
457
  return f"<Function {self.key}({inputs_str})>"
447
458
 
448
459
 
449
- def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
460
+ def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
450
461
  uses_non_warp_array_type = False
451
462
 
452
463
  init()
@@ -758,42 +769,62 @@ class Kernel:
758
769
 
759
770
  return f"{self.key}_{hash_suffix}"
760
771
 
772
+ def __call__(self, *args, **kwargs):
773
+ # we implement this function only to ensure Kernel is a callable object
774
+ # so that we can document Warp kernels in the same way as Python functions
775
+ # annotated by @wp.kernel (see functools.update_wrapper())
776
+ raise NotImplementedError("Kernel.__call__() is not implemented, please use wp.launch() instead")
777
+
761
778
 
762
779
  # ----------------------
763
780
 
764
781
 
765
782
  # decorator to register function, @func
766
- def func(f):
767
- name = warp.codegen.make_full_qualified_name(f)
768
-
769
- scope_locals = inspect.currentframe().f_back.f_locals
770
-
771
- m = get_module(f.__module__)
772
- doc = getattr(f, "__doc__", "") or ""
773
- Function(
774
- func=f,
775
- key=name,
776
- namespace="",
777
- module=m,
778
- value_func=None,
779
- scope_locals=scope_locals,
780
- doc=doc.strip(),
781
- ) # value_type not known yet, will be inferred during Adjoint.build()
782
-
783
- # use the top of the list of overloads for this key
784
- g = m.functions[name]
785
- # copy over the function attributes, including docstring
786
- return functools.update_wrapper(g, f)
787
-
788
-
789
- def func_native(snippet, adj_snippet=None, replay_snippet=None):
783
+ def func(f: Optional[Callable] = None, *, name: Optional[str] = None):
784
+ def wrapper(f, *args, **kwargs):
785
+ if name is None:
786
+ key = warp.codegen.make_full_qualified_name(f)
787
+ else:
788
+ key = name
789
+
790
+ scope_locals = inspect.currentframe().f_back.f_back.f_locals
791
+
792
+ m = get_module(f.__module__)
793
+ doc = getattr(f, "__doc__", "") or ""
794
+ Function(
795
+ func=f,
796
+ key=key,
797
+ namespace="",
798
+ module=m,
799
+ value_func=None,
800
+ scope_locals=scope_locals,
801
+ doc=doc.strip(),
802
+ ) # value_type not known yet, will be inferred during Adjoint.build()
803
+
804
+ # use the top of the list of overloads for this key
805
+ g = m.functions[key]
806
+ # copy over the function attributes, including docstring
807
+ return functools.update_wrapper(g, f)
808
+
809
+ if f is None:
810
+ # Arguments were passed to the decorator.
811
+ return wrapper
812
+
813
+ return wrapper(f)
814
+
815
+
816
+ def func_native(snippet: str, adj_snippet: Optional[str] = None, replay_snippet: Optional[str] = None):
790
817
  """
791
818
  Decorator to register native code snippet, @func_native
792
819
  """
793
820
 
794
- scope_locals = inspect.currentframe().f_back.f_locals
821
+ frame = inspect.currentframe()
822
+ if frame is None or frame.f_back is None:
823
+ scope_locals = {}
824
+ else:
825
+ scope_locals = frame.f_back.f_locals
795
826
 
796
- def snippet_func(f):
827
+ def snippet_func(f: Callable) -> Callable:
797
828
  name = warp.codegen.make_full_qualified_name(f)
798
829
 
799
830
  m = get_module(f.__module__)
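
The hunk above reworks the @wp.func decorator so it can be used with or without arguments, adding an optional name keyword that overrides the key under which the function is registered. A minimal sketch of the two forms, assuming the new signature shown above (the function bodies are illustrative, not taken from the package):

    import warp as wp

    @wp.func
    def square(x: float) -> float:
        return x * x

    # hypothetical: register the same logic under an explicit key
    @wp.func(name="square_v2")
    def square_alias(x: float) -> float:
        return x * x
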
@@ -965,22 +996,71 @@ def func_replay(forward_fn):
965
996
  return wrapper
966
997
 
967
998
 
968
- # decorator to register kernel, @kernel, custom_name may be a string
969
- # that creates a kernel with a different name from the actual function
970
- def kernel(f=None, *, enable_backward=None):
999
+ def kernel(
1000
+ f: Optional[Callable] = None,
1001
+ *,
1002
+ enable_backward: Optional[bool] = None,
1003
+ module: Optional[Union[Module, Literal["unique"]]] = None,
1004
+ ):
1005
+ """
1006
+ Decorator to register a Warp kernel from a Python function.
1007
+ The function must be defined with type annotations for all arguments.
1008
+ The function must not return anything.
1009
+
1010
+ Example::
1011
+
1012
+ @wp.kernel
1013
+ def my_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)):
1014
+ tid = wp.tid()
1015
+ b[tid] = a[tid] + 1.0
1016
+
1017
+
1018
+ @wp.kernel(enable_backward=False)
1019
+ def my_kernel_no_backward(a: wp.array(dtype=float, ndim=2), x: float):
1020
+ # the backward pass will not be generated
1021
+ i, j = wp.tid()
1022
+ a[i, j] = x
1023
+
1024
+
1025
+ @wp.kernel(module="unique")
1026
+ def my_kernel_unique_module(a: wp.array(dtype=float), b: wp.array(dtype=float)):
1027
+ # the kernel will be registered in new unique module created just for this
1028
+ # kernel and its dependent functions and structs
1029
+ tid = wp.tid()
1030
+ b[tid] = a[tid] + 1.0
1031
+
1032
+ Args:
1033
+ f: The function to be registered as a kernel.
1034
+ enable_backward: If False, the backward pass will not be generated.
1035
+ module: The :class:`warp.context.Module` to which the kernel belongs. Alternatively, if a string `"unique"` is provided, the kernel is assigned to a new module named after the kernel name and hash. If None, the module is inferred from the function's module.
1036
+
1037
+ Returns:
1038
+ The registered kernel.
1039
+ """
1040
+
971
1041
  def wrapper(f, *args, **kwargs):
972
1042
  options = {}
973
1043
 
974
1044
  if enable_backward is not None:
975
1045
  options["enable_backward"] = enable_backward
976
1046
 
977
- m = get_module(f.__module__)
1047
+ if module is None:
1048
+ m = get_module(f.__module__)
1049
+ elif module == "unique":
1050
+ m = Module(f.__name__, None)
1051
+ else:
1052
+ m = module
978
1053
  k = Kernel(
979
1054
  func=f,
980
1055
  key=warp.codegen.make_full_qualified_name(f),
981
1056
  module=m,
982
1057
  options=options,
983
1058
  )
1059
+ if module == "unique":
1060
+ # add the hash to the module name
1061
+ hasher = warp.context.ModuleHasher(m)
1062
+ k.module.name = f"{k.key}_{hasher.module_hash.hex()[:8]}"
1063
+
984
1064
  k = functools.update_wrapper(k, f)
985
1065
  return k
986
1066
 
@@ -992,7 +1072,7 @@ def kernel(f=None, *, enable_backward=None):
992
1072
 
993
1073
 
994
1074
  # decorator to register struct, @struct
995
- def struct(c):
1075
+ def struct(c: type):
996
1076
  m = get_module(c.__module__)
997
1077
  s = warp.codegen.Struct(cls=c, key=warp.codegen.make_full_qualified_name(c), module=m)
998
1078
  s = functools.update_wrapper(s, c)
@@ -1105,47 +1185,47 @@ scalar_types.update({x: x._wp_scalar_type_ for x in warp.types.vector_types})
1105
1185
 
1106
1186
 
1107
1187
  def add_builtin(
1108
- key,
1109
- input_types=None,
1110
- constraint=None,
1111
- value_type=None,
1112
- value_func=None,
1113
- export_func=None,
1114
- dispatch_func=None,
1115
- lto_dispatch_func=None,
1116
- doc="",
1117
- namespace="wp::",
1118
- variadic=False,
1188
+ key: str,
1189
+ input_types: Optional[Dict[str, Union[type, TypeVar]]] = None,
1190
+ constraint: Optional[Callable[[Mapping[str, type]], bool]] = None,
1191
+ value_type: Optional[type] = None,
1192
+ value_func: Optional[Callable] = None,
1193
+ export_func: Optional[Callable] = None,
1194
+ dispatch_func: Optional[Callable] = None,
1195
+ lto_dispatch_func: Optional[Callable] = None,
1196
+ doc: str = "",
1197
+ namespace: str = "wp::",
1198
+ variadic: bool = False,
1119
1199
  initializer_list_func=None,
1120
- export=True,
1121
- group="Other",
1122
- hidden=False,
1123
- skip_replay=False,
1124
- missing_grad=False,
1125
- native_func=None,
1126
- defaults=None,
1127
- require_original_output_arg=False,
1200
+ export: bool = True,
1201
+ group: str = "Other",
1202
+ hidden: bool = False,
1203
+ skip_replay: bool = False,
1204
+ missing_grad: bool = False,
1205
+ native_func: Optional[str] = None,
1206
+ defaults: Optional[Dict[str, Any]] = None,
1207
+ require_original_output_arg: bool = False,
1128
1208
  ):
1129
1209
  """Main entry point to register a new built-in function.
1130
1210
 
1131
1211
  Args:
1132
- key (str): Function name. Multiple overloaded functions can be registered
1212
+ key: Function name. Multiple overloaded functions can be registered
1133
1213
  under the same name as long as their signature differ.
1134
- input_types (Mapping[str, Any]): Signature of the user-facing function.
1214
+ input_types: Signature of the user-facing function.
1135
1215
  Variadic arguments are supported by prefixing the parameter names
1136
1216
  with asterisks as in `*args` and `**kwargs`. Generic arguments are
1137
1217
  supported with types such as `Any`, `Float`, `Scalar`, etc.
1138
- constraint (Callable): For functions that define generic arguments and
1218
+ constraint: For functions that define generic arguments and
1139
1219
  are to be exported, this callback is used to specify whether some
1140
1220
  combination of inferred arguments are valid or not.
1141
- value_type (Any): Type returned by the function.
1142
- value_func (Callable): Callback used to specify the return type when
1221
+ value_type: Type returned by the function.
1222
+ value_func: Callback used to specify the return type when
1143
1223
  `value_type` isn't enough.
1144
- export_func (Callable): Callback used during the context stage to specify
1224
+ export_func: Callback used during the context stage to specify
1145
1225
  the signature of the underlying C++ function, not accounting for
1146
1226
  the template parameters.
1147
1227
  If not provided, `input_types` is used.
1148
- dispatch_func (Callable): Callback used during the codegen stage to specify
1228
+ dispatch_func: Callback used during the codegen stage to specify
1149
1229
  the runtime and template arguments to be passed to the underlying C++
1150
1230
  function. In other words, this allows defining a mapping between
1151
1231
  the signatures of the user-facing and the C++ functions, and even to
@@ -1153,27 +1233,26 @@ def add_builtin(
1153
1233
  The arguments returned must be of type `codegen.Var`.
1154
1234
  If not provided, all arguments passed by the users when calling
1155
1235
  the built-in are passed as-is as runtime arguments to the C++ function.
1156
- lto_dispatch_func (Callable): Same as dispatch_func, but takes an 'option' dict
1236
+ lto_dispatch_func: Same as dispatch_func, but takes an 'option' dict
1157
1237
  as extra argument (indicating tile_size and target architecture) and returns
1158
1238
  an LTO-IR buffer as extra return value
1159
- doc (str): Used to generate the Python's docstring and the HTML documentation.
1239
+ doc: Used to generate the Python's docstring and the HTML documentation.
1160
1240
  namespace: Namespace for the underlying C++ function.
1161
- variadic (bool): Whether the function declares variadic arguments.
1162
- initializer_list_func (bool): Whether to use the initializer list syntax
1163
- when passing the arguments to the underlying C++ function.
1164
- export (bool): Whether the function is to be exposed to the Python
1241
+ variadic: Whether the function declares variadic arguments.
1242
+ initializer_list_func: Callback to determine whether to use the
1243
+ initializer list syntax when passing the arguments to the underlying
1244
+ C++ function.
1245
+ export: Whether the function is to be exposed to the Python
1165
1246
  interpreter so that it becomes available from within the `warp`
1166
1247
  module.
1167
- group (str): Classification used for the documentation.
1168
- hidden (bool): Whether to add that function into the documentation.
1169
- skip_replay (bool): Whether operation will be performed during
1248
+ group: Classification used for the documentation.
1249
+ hidden: Whether to add that function into the documentation.
1250
+ skip_replay: Whether operation will be performed during
1170
1251
  the forward replay in the backward pass.
1171
- missing_grad (bool): Whether the function is missing a corresponding
1172
- adjoint.
1173
- native_func (str): Name of the underlying C++ function.
1174
- defaults (Mapping[str, Any]): Default values for the parameters defined
1175
- in `input_types`.
1176
- require_original_output_arg (bool): Used during the codegen stage to
1252
+ missing_grad: Whether the function is missing a corresponding adjoint.
1253
+ native_func: Name of the underlying C++ function.
1254
+ defaults: Default values for the parameters defined in `input_types`.
1255
+ require_original_output_arg: Used during the codegen stage to
1177
1256
  specify whether an adjoint parameter corresponding to the return
1178
1257
  value should be included in the signature of the backward function.
1179
1258
  """
@@ -1355,19 +1434,14 @@ def add_builtin(
1355
1434
  def register_api_function(
1356
1435
  function: Function,
1357
1436
  group: str = "Other",
1358
- hidden=False,
1437
+ hidden: bool = False,
1359
1438
  ):
1360
1439
  """Main entry point to register a Warp Python function to be part of the Warp API and appear in the documentation.
1361
1440
 
1362
1441
  Args:
1363
- function (Function): Warp function to be registered.
1364
- group (str): Classification used for the documentation.
1365
- input_types (Mapping[str, Any]): Signature of the user-facing function.
1366
- Variadic arguments are supported by prefixing the parameter names
1367
- with asterisks as in `*args` and `**kwargs`. Generic arguments are
1368
- supported with types such as `Any`, `Float`, `Scalar`, etc.
1369
- value_type (Any): Type returned by the function.
1370
- hidden (bool): Whether to add that function into the documentation.
1442
+ function: Warp function to be registered.
1443
+ group: Classification used for the documentation.
1444
+ hidden: Whether to add that function into the documentation.
1371
1445
  """
1372
1446
  function.group = group
1373
1447
  function.hidden = hidden
@@ -1375,10 +1449,10 @@ def register_api_function(
1375
1449
 
1376
1450
 
1377
1451
  # global dictionary of modules
1378
- user_modules = {}
1452
+ user_modules: Dict[str, Module] = {}
1379
1453
 
1380
1454
 
1381
- def get_module(name):
1455
+ def get_module(name: str) -> Module:
1382
1456
  # some modules might be manually imported using `importlib` without being
1383
1457
  # registered into `sys.modules`
1384
1458
  parent = sys.modules.get(name, None)
@@ -1460,13 +1534,16 @@ class ModuleHasher:
1460
1534
  if warp.config.verify_fp:
1461
1535
  ch.update(bytes("verify_fp", "utf-8"))
1462
1536
 
1537
+ # line directives, e.g. for Nsight Compute
1538
+ ch.update(bytes(ctypes.c_int(warp.config.line_directives)))
1539
+
1463
1540
  # build config
1464
1541
  ch.update(bytes(warp.config.mode, "utf-8"))
1465
1542
 
1466
1543
  # save the module hash
1467
1544
  self.module_hash = ch.digest()
1468
1545
 
1469
- def hash_kernel(self, kernel):
1546
+ def hash_kernel(self, kernel: Kernel) -> bytes:
1470
1547
  # NOTE: We only hash non-generic kernels, so we don't traverse kernel overloads here.
1471
1548
 
1472
1549
  ch = hashlib.sha256()
@@ -1480,7 +1557,7 @@ class ModuleHasher:
1480
1557
 
1481
1558
  return h
1482
1559
 
1483
- def hash_function(self, func):
1560
+ def hash_function(self, func: Function) -> bytes:
1484
1561
  # NOTE: This method hashes all possible overloads that a function call could resolve to.
1485
1562
  # The exact overload will be resolved at build time, when the argument types are known.
1486
1563
 
@@ -1495,7 +1572,7 @@ class ModuleHasher:
1495
1572
  ch.update(bytes(func.key, "utf-8"))
1496
1573
 
1497
1574
  # include all concrete and generic overloads
1498
- overloads = {**func.user_overloads, **func.user_templates}
1575
+ overloads: Dict[str, Function] = {**func.user_overloads, **func.user_templates}
1499
1576
  for sig in sorted(overloads.keys()):
1500
1577
  ovl = overloads[sig]
1501
1578
 
@@ -1526,7 +1603,7 @@ class ModuleHasher:
1526
1603
 
1527
1604
  return h
1528
1605
 
1529
- def hash_adjoint(self, adj):
1606
+ def hash_adjoint(self, adj: warp.codegen.Adjoint) -> bytes:
1530
1607
  # NOTE: We don't cache adjoint hashes, because adjoints are always unique.
1531
1608
  # Even instances of generic kernels and functions have unique adjoints with
1532
1609
  # different argument types.
@@ -1575,7 +1652,7 @@ class ModuleHasher:
1575
1652
 
1576
1653
  return ch.digest()
1577
1654
 
1578
- def get_constant_bytes(self, value):
1655
+ def get_constant_bytes(self, value) -> bytes:
1579
1656
  if isinstance(value, int):
1580
1657
  # this also handles builtins.bool
1581
1658
  return bytes(ctypes.c_int(value))
@@ -1593,7 +1670,7 @@ class ModuleHasher:
1593
1670
  else:
1594
1671
  raise TypeError(f"Invalid constant type: {type(value)}")
1595
1672
 
1596
- def get_module_hash(self):
1673
+ def get_module_hash(self) -> bytes:
1597
1674
  return self.module_hash
1598
1675
 
1599
1676
  def get_unique_kernels(self):
@@ -1610,6 +1687,7 @@ class ModuleBuilder:
1610
1687
  self.fatbins = {} # map from <some identifier> to fatbins, to add at link time
1611
1688
  self.ltoirs = {} # map from lto symbol to lto binary
1612
1689
  self.ltoirs_decl = {} # map from lto symbol to lto forward declaration
1690
+ self.shared_memory_bytes = {} # map from lto symbol to shared memory requirements
1613
1691
 
1614
1692
  if hasher is None:
1615
1693
  hasher = ModuleHasher(module)
@@ -1726,9 +1804,9 @@ class ModuleBuilder:
1726
1804
 
1727
1805
  # add headers
1728
1806
  if device == "cpu":
1729
- source = warp.codegen.cpu_module_header.format(tile_size=self.options["block_dim"]) + source
1807
+ source = warp.codegen.cpu_module_header.format(block_dim=self.options["block_dim"]) + source
1730
1808
  else:
1731
- source = warp.codegen.cuda_module_header.format(tile_size=self.options["block_dim"]) + source
1809
+ source = warp.codegen.cuda_module_header.format(block_dim=self.options["block_dim"]) + source
1732
1810
 
1733
1811
  return source
1734
1812
 
@@ -1765,7 +1843,7 @@ class ModuleExec:
1765
1843
  runtime.llvm.unload_obj(self.handle.encode("utf-8"))
1766
1844
 
1767
1845
  # lookup and cache kernel entry points
1768
- def get_kernel_hooks(self, kernel):
1846
+ def get_kernel_hooks(self, kernel) -> KernelHooks:
1769
1847
  # Use kernel.adj as a unique key for cache lookups instead of the kernel itself.
1770
1848
  # This avoids holding a reference to the kernel and is faster than using
1771
1849
  # a WeakKeyDictionary with kernels as keys.
@@ -1838,7 +1916,7 @@ class ModuleExec:
1838
1916
  # creates a hash of the function to use for checking
1839
1917
  # build cache
1840
1918
  class Module:
1841
- def __init__(self, name, loader):
1919
+ def __init__(self, name: Optional[str], loader=None):
1842
1920
  self.name = name if name is not None else "None"
1843
1921
 
1844
1922
  self.loader = loader
@@ -1878,7 +1956,7 @@ class Module:
             "enable_backward": warp.config.enable_backward,
             "fast_math": False,
             "fuse_fp": True,
-            "lineinfo": False,
+            "lineinfo": warp.config.lineinfo,
             "cuda_output": None,  # supported values: "ptx", "cubin", or None (automatic)
             "mode": warp.config.mode,
             "block_dim": 256,
@@ -2081,7 +2159,11 @@ class Module:
            use_ptx = True
 
        if use_ptx:
-            output_arch = min(device.arch, warp.config.ptx_target_arch)
+            # use the default PTX arch if the device supports it
+            if warp.config.ptx_target_arch is not None:
+                output_arch = min(device.arch, warp.config.ptx_target_arch)
+            else:
+                output_arch = min(device.arch, runtime.default_ptx_arch)
            output_name = f"{module_name_short}.sm{output_arch}.ptx"
        else:
            output_arch = device.arch
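
The two hunks above change how build options are resolved: the module-level "lineinfo" option now defaults to warp.config.lineinfo instead of a hard-coded False, and warp.config.ptx_target_arch may be left as None so that the runtime's default PTX architecture (derived from the devices present) is used. A minimal sketch of overriding both globals before initialization; the values are illustrative:

    import warp as wp

    # emit line information for compiled kernels, e.g. for profilers
    wp.config.lineinfo = True

    # pin PTX generation to a specific architecture instead of the default
    wp.config.ptx_target_arch = 75

    wp.init()
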
@@ -2194,34 +2276,8 @@ class Module:
2194
2276
  # -----------------------------------------------------------
2195
2277
  # update cache
2196
2278
 
2197
- def safe_rename(src, dst, attempts=5, delay=0.1):
2198
- for i in range(attempts):
2199
- try:
2200
- os.rename(src, dst)
2201
- return
2202
- except FileExistsError:
2203
- return
2204
- except OSError as e:
2205
- if e.errno == errno.ENOTEMPTY:
2206
- # if directory exists we assume another process
2207
- # got there first, in which case we will copy
2208
- # our output to the directory manually in second step
2209
- return
2210
- else:
2211
- # otherwise assume directory creation failed e.g.: access denied
2212
- # on Windows we see occasional failures to rename directories due to
2213
- # some process holding a lock on a file to be moved to workaround
2214
- # this we make multiple attempts to rename with some delay
2215
- if i < attempts - 1:
2216
- time.sleep(delay)
2217
- else:
2218
- print(
2219
- f"Could not update Warp cache with module binaries, trying to rename {build_dir} to {module_dir}, error {e}"
2220
- )
2221
- raise e
2222
-
2223
2279
  # try to move process outputs to cache
2224
- safe_rename(build_dir, module_dir)
2280
+ warp.build.safe_rename(build_dir, module_dir)
2225
2281
 
2226
2282
  if os.path.exists(module_dir):
2227
2283
  if not os.path.exists(binary_path):
@@ -2294,7 +2350,7 @@ class Module:
2294
2350
  self.failed_builds = set()
2295
2351
 
2296
2352
  # lookup kernel entry points based on name, called after compilation / module load
2297
- def get_kernel_hooks(self, kernel, device):
2353
+ def get_kernel_hooks(self, kernel, device: Device) -> KernelHooks:
2298
2354
  module_exec = self.execs.get((device.context, self.options["block_dim"]))
2299
2355
  if module_exec is not None:
2300
2356
  return module_exec.get_kernel_hooks(kernel)
@@ -2449,6 +2505,7 @@ class Event:
             raise RuntimeError(f"Device {device} is not a CUDA device")
 
         self.device = device
+        self.enable_timing = enable_timing
 
         if cuda_event is not None:
             self.cuda_event = cuda_event
@@ -2498,6 +2555,17 @@ class Event:
         else:
             raise RuntimeError(f"Device {self.device} does not support IPC.")
 
+    @property
+    def is_complete(self) -> bool:
+        """A boolean indicating whether all work on the stream when the event was recorded has completed.
+
+        This property may not be accessed during a graph capture on any stream.
+        """
+
+        result_code = runtime.core.cuda_event_query(self.cuda_event)
+
+        return result_code == 0
+
     def __del__(self):
         if not self.owner:
             return
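
Together with the enable_timing flag stored on Event above (and forwarded to cuda_event_record in the Stream.record_event change below), the new Event.is_complete and Stream.is_complete properties allow non-blocking completion queries. A hedged sketch of timing and polling with these APIs, assuming a CUDA device; my_kernel, a, b and n stand in for an existing kernel and its launch arguments:

    import time
    import warp as wp

    with wp.ScopedDevice("cuda:0"):
        start = wp.Event(enable_timing=True)
        stop = wp.Event(enable_timing=True)

        wp.record_event(start)
        wp.launch(my_kernel, dim=n, inputs=[a, b])
        wp.record_event(stop)

        # poll instead of blocking, then read back the elapsed time in milliseconds
        while not stop.is_complete:
            time.sleep(0.001)
        elapsed_ms = wp.get_event_elapsed_time(start, stop)
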
@@ -2512,7 +2580,7 @@ class Stream:
         instance.owner = False
         return instance
 
-    def __init__(self, device: Optional[Union["Device", str]] = None, priority: int = 0, **kwargs):
+    def __init__(self, device: Union["Device", str, None] = None, priority: int = 0, **kwargs):
         """Initialize the stream on a device with an optional specified priority.
 
         Args:
@@ -2528,7 +2596,7 @@ class Stream:
         Raises:
             RuntimeError: If function is called before Warp has completed
                 initialization with a ``device`` that is not an instance of
-                :class:`Device``.
+                :class:`Device <warp.context.Device>`.
             RuntimeError: ``device`` is not a CUDA Device.
             RuntimeError: The stream could not be created on the device.
             TypeError: The requested stream priority is not an integer.
@@ -2596,7 +2664,7 @@ class Stream:
                     f"Event from device {event.device} cannot be recorded on stream from device {self.device}"
                 )
 
-        runtime.core.cuda_event_record(event.cuda_event, self.cuda_stream)
+        runtime.core.cuda_event_record(event.cuda_event, self.cuda_stream, event.enable_timing)
 
         return event
 
@@ -2630,6 +2698,17 @@ class Stream:
 
         runtime.core.cuda_stream_wait_stream(self.cuda_stream, other_stream.cuda_stream, event.cuda_event)
 
+    @property
+    def is_complete(self) -> bool:
+        """A boolean indicating whether all work on the stream has completed.
+
+        This property may not be accessed during a graph capture on any stream.
+        """
+
+        result_code = runtime.core.cuda_stream_query(self.cuda_stream)
+
+        return result_code == 0
+
     @property
     def is_capturing(self) -> bool:
         """A boolean indicating whether a graph capture is currently ongoing on this stream."""
@@ -2952,18 +3031,14 @@ Devicelike = Union[Device, str, None]
2952
3031
 
2953
3032
 
2954
3033
  class Graph:
2955
- def __new__(cls, *args, **kwargs):
2956
- instance = super(Graph, cls).__new__(cls)
2957
- instance.graph_exec = None
2958
- return instance
2959
-
2960
3034
  def __init__(self, device: Device, capture_id: int):
2961
3035
  self.device = device
2962
3036
  self.capture_id = capture_id
2963
- self.module_execs = set()
3037
+ self.module_execs: Set[ModuleExec] = set()
3038
+ self.graph_exec: Optional[ctypes.c_void_p] = None
2964
3039
 
2965
3040
  def __del__(self):
2966
- if not self.graph_exec:
3041
+ if not hasattr(self, "graph_exec") or not hasattr(self, "device") or not self.graph_exec:
2967
3042
  return
2968
3043
 
2969
3044
  # use CUDA context guard to avoid side effects during garbage collection
@@ -3205,6 +3280,43 @@ class Runtime:
3205
3280
  self.core.radix_sort_pairs_float_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3206
3281
  self.core.radix_sort_pairs_float_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3207
3282
 
3283
+ self.core.radix_sort_pairs_int64_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3284
+ self.core.radix_sort_pairs_int64_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
3285
+
3286
+ self.core.segmented_sort_pairs_int_host.argtypes = [
3287
+ ctypes.c_uint64,
3288
+ ctypes.c_uint64,
3289
+ ctypes.c_int,
3290
+ ctypes.c_uint64,
3291
+ ctypes.c_uint64,
3292
+ ctypes.c_int,
3293
+ ]
3294
+ self.core.segmented_sort_pairs_int_device.argtypes = [
3295
+ ctypes.c_uint64,
3296
+ ctypes.c_uint64,
3297
+ ctypes.c_int,
3298
+ ctypes.c_uint64,
3299
+ ctypes.c_uint64,
3300
+ ctypes.c_int,
3301
+ ]
3302
+
3303
+ self.core.segmented_sort_pairs_float_host.argtypes = [
3304
+ ctypes.c_uint64,
3305
+ ctypes.c_uint64,
3306
+ ctypes.c_int,
3307
+ ctypes.c_uint64,
3308
+ ctypes.c_uint64,
3309
+ ctypes.c_int,
3310
+ ]
3311
+ self.core.segmented_sort_pairs_float_device.argtypes = [
3312
+ ctypes.c_uint64,
3313
+ ctypes.c_uint64,
3314
+ ctypes.c_int,
3315
+ ctypes.c_uint64,
3316
+ ctypes.c_uint64,
3317
+ ctypes.c_int,
3318
+ ]
3319
+
3208
3320
  self.core.runlength_encode_int_host.argtypes = [
3209
3321
  ctypes.c_uint64,
3210
3322
  ctypes.c_uint64,
@@ -3285,26 +3397,6 @@ class Runtime:
3285
3397
  self.core.hash_grid_update_device.argtypes = [ctypes.c_uint64, ctypes.c_float, ctypes.c_void_p]
3286
3398
  self.core.hash_grid_reserve_device.argtypes = [ctypes.c_uint64, ctypes.c_int]
3287
3399
 
3288
- self.core.cutlass_gemm.argtypes = [
3289
- ctypes.c_void_p,
3290
- ctypes.c_int,
3291
- ctypes.c_int,
3292
- ctypes.c_int,
3293
- ctypes.c_int,
3294
- ctypes.c_char_p,
3295
- ctypes.c_void_p,
3296
- ctypes.c_void_p,
3297
- ctypes.c_void_p,
3298
- ctypes.c_void_p,
3299
- ctypes.c_float,
3300
- ctypes.c_float,
3301
- ctypes.c_bool,
3302
- ctypes.c_bool,
3303
- ctypes.c_bool,
3304
- ctypes.c_int,
3305
- ]
3306
- self.core.cutlass_gemm.restype = ctypes.c_bool
3307
-
3308
3400
  self.core.volume_create_host.argtypes = [ctypes.c_void_p, ctypes.c_uint64, ctypes.c_bool, ctypes.c_bool]
3309
3401
  self.core.volume_create_host.restype = ctypes.c_uint64
3310
3402
  self.core.volume_get_tiles_host.argtypes = [
@@ -3335,36 +3427,18 @@ class Runtime:
3335
3427
  ]
3336
3428
  self.core.volume_destroy_device.argtypes = [ctypes.c_uint64]
3337
3429
 
3338
- self.core.volume_f_from_tiles_device.argtypes = [
3430
+ self.core.volume_from_tiles_device.argtypes = [
3339
3431
  ctypes.c_void_p,
3340
3432
  ctypes.c_void_p,
3341
3433
  ctypes.c_int,
3342
3434
  ctypes.c_float * 9,
3343
3435
  ctypes.c_float * 3,
3344
3436
  ctypes.c_bool,
3345
- ctypes.c_float,
3346
- ]
3347
- self.core.volume_f_from_tiles_device.restype = ctypes.c_uint64
3348
- self.core.volume_v_from_tiles_device.argtypes = [
3349
3437
  ctypes.c_void_p,
3350
- ctypes.c_void_p,
3351
- ctypes.c_int,
3352
- ctypes.c_float * 9,
3353
- ctypes.c_float * 3,
3354
- ctypes.c_bool,
3355
- ctypes.c_float * 3,
3356
- ]
3357
- self.core.volume_v_from_tiles_device.restype = ctypes.c_uint64
3358
- self.core.volume_i_from_tiles_device.argtypes = [
3359
- ctypes.c_void_p,
3360
- ctypes.c_void_p,
3361
- ctypes.c_int,
3362
- ctypes.c_float * 9,
3363
- ctypes.c_float * 3,
3364
- ctypes.c_bool,
3365
- ctypes.c_int,
3438
+ ctypes.c_uint32,
3439
+ ctypes.c_char_p,
3366
3440
  ]
3367
- self.core.volume_i_from_tiles_device.restype = ctypes.c_uint64
3441
+ self.core.volume_from_tiles_device.restype = ctypes.c_uint64
3368
3442
  self.core.volume_index_from_tiles_device.argtypes = [
3369
3443
  ctypes.c_void_p,
3370
3444
  ctypes.c_void_p,
@@ -3433,6 +3507,7 @@ class Runtime:
3433
3507
  ctypes.POINTER(ctypes.c_int), # tpl_cols
3434
3508
  ctypes.c_void_p, # tpl_values
3435
3509
  ctypes.c_bool, # prune_numerical_zeros
3510
+ ctypes.c_bool, # masked
3436
3511
  ctypes.POINTER(ctypes.c_int), # bsr_offsets
3437
3512
  ctypes.POINTER(ctypes.c_int), # bsr_columns
3438
3513
  ctypes.c_void_p, # bsr_values
@@ -3467,8 +3542,6 @@ class Runtime:
3467
3542
  self.core.is_cuda_enabled.restype = ctypes.c_int
3468
3543
  self.core.is_cuda_compatibility_enabled.argtypes = None
3469
3544
  self.core.is_cuda_compatibility_enabled.restype = ctypes.c_int
3470
- self.core.is_cutlass_enabled.argtypes = None
3471
- self.core.is_cutlass_enabled.restype = ctypes.c_int
3472
3545
  self.core.is_mathdx_enabled.argtypes = None
3473
3546
  self.core.is_mathdx_enabled.restype = ctypes.c_int
3474
3547
 
@@ -3502,6 +3575,10 @@ class Runtime:
3502
3575
  self.core.cuda_device_set_mempool_release_threshold.restype = ctypes.c_int
3503
3576
  self.core.cuda_device_get_mempool_release_threshold.argtypes = [ctypes.c_int]
3504
3577
  self.core.cuda_device_get_mempool_release_threshold.restype = ctypes.c_uint64
3578
+ self.core.cuda_device_get_mempool_used_mem_current.argtypes = [ctypes.c_int]
3579
+ self.core.cuda_device_get_mempool_used_mem_current.restype = ctypes.c_uint64
3580
+ self.core.cuda_device_get_mempool_used_mem_high.argtypes = [ctypes.c_int]
3581
+ self.core.cuda_device_get_mempool_used_mem_high.restype = ctypes.c_uint64
3505
3582
  self.core.cuda_device_get_memory_info.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p]
3506
3583
  self.core.cuda_device_get_memory_info.restype = None
3507
3584
  self.core.cuda_device_get_uuid.argtypes = [ctypes.c_int, ctypes.c_char * 16]
@@ -3571,6 +3648,8 @@ class Runtime:
3571
3648
  self.core.cuda_stream_create.restype = ctypes.c_void_p
3572
3649
  self.core.cuda_stream_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3573
3650
  self.core.cuda_stream_destroy.restype = None
3651
+ self.core.cuda_stream_query.argtypes = [ctypes.c_void_p]
3652
+ self.core.cuda_stream_query.restype = ctypes.c_int
3574
3653
  self.core.cuda_stream_register.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3575
3654
  self.core.cuda_stream_register.restype = None
3576
3655
  self.core.cuda_stream_unregister.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
@@ -3592,7 +3671,9 @@ class Runtime:
3592
3671
  self.core.cuda_event_create.restype = ctypes.c_void_p
3593
3672
  self.core.cuda_event_destroy.argtypes = [ctypes.c_void_p]
3594
3673
  self.core.cuda_event_destroy.restype = None
3595
- self.core.cuda_event_record.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3674
+ self.core.cuda_event_query.argtypes = [ctypes.c_void_p]
3675
+ self.core.cuda_event_query.restype = ctypes.c_int
3676
+ self.core.cuda_event_record.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_bool]
3596
3677
  self.core.cuda_event_record.restype = None
3597
3678
  self.core.cuda_event_synchronize.argtypes = [ctypes.c_void_p]
3598
3679
  self.core.cuda_event_synchronize.restype = None
@@ -3841,9 +3922,20 @@ class Runtime:
3841
3922
  cuda_device_count = len(self.cuda_devices)
3842
3923
  else:
3843
3924
  self.set_default_device("cuda:0")
3925
+
3926
+ # the minimum PTX architecture that supports all of Warp's features
3927
+ self.default_ptx_arch = 75
3928
+
3929
+ # Update the default PTX architecture based on devices present in the system.
3930
+ # Use the lowest architecture among devices that meet the minimum architecture requirement.
3931
+ # Devices below the required minimum will use the highest architecture they support.
3932
+ eligible_archs = [d.arch for d in self.cuda_devices if d.arch >= self.default_ptx_arch]
3933
+ if eligible_archs:
3934
+ self.default_ptx_arch = min(eligible_archs)
3844
3935
  else:
3845
3936
  # CUDA not available
3846
3937
  self.set_default_device("cpu")
3938
+ self.default_ptx_arch = None
3847
3939
 
3848
3940
  # initialize kernel cache
3849
3941
  warp.build.init_kernel_cache(warp.config.kernel_cache_dir)
@@ -3856,6 +3948,11 @@ class Runtime:
3856
3948
  greeting = []
3857
3949
 
3858
3950
  greeting.append(f"Warp {warp.config.version} initialized:")
3951
+
3952
+ # Add git commit hash to greeting if available
3953
+ if warp.config._git_commit_hash is not None:
3954
+ greeting.append(f" Git commit: {warp.config._git_commit_hash}")
3955
+
3859
3956
  if cuda_device_count > 0:
3860
3957
  # print CUDA version info
3861
3958
  greeting.append(
@@ -4208,7 +4305,7 @@ def set_device(ident: Devicelike) -> None:
4208
4305
  device.make_current()
4209
4306
 
4210
4307
 
4211
- def map_cuda_device(alias: str, context: ctypes.c_void_p = None) -> Device:
4308
+ def map_cuda_device(alias: str, context: Optional[ctypes.c_void_p] = None) -> Device:
4212
4309
  """Assign a device alias to a CUDA context.
4213
4310
 
4214
4311
  This function can be used to create a wp.Device for an external CUDA context.
@@ -4236,7 +4333,13 @@ def unmap_cuda_device(alias: str) -> None:
4236
4333
 
4237
4334
 
4238
4335
  def is_mempool_supported(device: Devicelike) -> bool:
4239
- """Check if CUDA memory pool allocators are available on the device."""
4336
+ """Check if CUDA memory pool allocators are available on the device.
4337
+
4338
+ Parameters:
4339
+ device: The :class:`Device <warp.context.Device>` or device identifier
4340
+ for which the query is to be performed.
4341
+ If ``None``, the default device will be used.
4342
+ """
4240
4343
 
4241
4344
  init()
4242
4345
 
@@ -4246,7 +4349,13 @@ def is_mempool_supported(device: Devicelike) -> bool:
4246
4349
 
4247
4350
 
4248
4351
  def is_mempool_enabled(device: Devicelike) -> bool:
4249
- """Check if CUDA memory pool allocators are enabled on the device."""
4352
+ """Check if CUDA memory pool allocators are enabled on the device.
4353
+
4354
+ Parameters:
4355
+ device: The :class:`Device <warp.context.Device>` or device identifier
4356
+ for which the query is to be performed.
4357
+ If ``None``, the default device will be used.
4358
+ """
4250
4359
 
4251
4360
  init()
4252
4361
 
@@ -4266,6 +4375,11 @@ def set_mempool_enabled(device: Devicelike, enable: bool) -> None:
4266
4375
  to Warp. The preferred solution is to enable memory pool access using :func:`set_mempool_access_enabled`.
4267
4376
  If peer access is not supported, then the default CUDA allocators must be used to pre-allocate the memory
4268
4377
  prior to graph capture.
4378
+
4379
+ Parameters:
4380
+ device: The :class:`Device <warp.context.Device>` or device identifier
4381
+ for which the operation is to be performed.
4382
+ If ``None``, the default device will be used.
4269
4383
  """
4270
4384
 
4271
4385
  init()
@@ -4296,6 +4410,18 @@ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, floa
4296
4410
  Values between 0 and 1 are interpreted as fractions of available memory. For example, 0.5 means
4297
4411
  half of the device's physical memory. Greater values are interpreted as an absolute number of bytes.
4298
4412
  For example, 1024**3 means one GiB of memory.
4413
+
4414
+ Parameters:
4415
+ device: The :class:`Device <warp.context.Device>` or device identifier
4416
+ for which the operation is to be performed.
4417
+ If ``None``, the default device will be used.
4418
+ threshold: An integer representing a number of bytes, or a ``float`` between 0 and 1,
4419
+ specifying the desired release threshold.
4420
+
4421
+ Raises:
4422
+ ValueError: If ``device`` is not a CUDA device.
4423
+ RuntimeError: If ``device`` is a CUDA device, but does not support memory pools.
4424
+ RuntimeError: Failed to set the memory pool release threshold.
4299
4425
  """
4300
4426
 
4301
4427
  init()
@@ -4317,8 +4443,21 @@ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, floa
4317
4443
  raise RuntimeError(f"Failed to set memory pool release threshold for device {device}")
4318
4444
 
4319
4445
 
4320
- def get_mempool_release_threshold(device: Devicelike) -> int:
4321
- """Get the CUDA memory pool release threshold on the device in bytes."""
4446
+ def get_mempool_release_threshold(device: Devicelike = None) -> int:
4447
+ """Get the CUDA memory pool release threshold on the device.
4448
+
4449
+ Parameters:
4450
+ device: The :class:`Device <warp.context.Device>` or device identifier
4451
+ for which the query is to be performed.
4452
+ If ``None``, the default device will be used.
4453
+
4454
+ Returns:
4455
+ The memory pool release threshold in bytes.
4456
+
4457
+ Raises:
4458
+ ValueError: If ``device`` is not a CUDA device.
4459
+ RuntimeError: If ``device`` is a CUDA device, but does not support memory pools.
4460
+ """
4322
4461
 
4323
4462
  init()

@@ -4333,6 +4472,64 @@ def get_mempool_release_threshold(device: Devicelike) -> int:
  return runtime.core.cuda_device_get_mempool_release_threshold(device.ordinal)


+ def get_mempool_used_mem_current(device: Devicelike = None) -> int:
+ """Get the amount of memory from the device's memory pool that is currently in use by the application.
+
+ Parameters:
+ device: The :class:`Device <warp.context.Device>` or device identifier
+ for which the query is to be performed.
+ If ``None``, the default device will be used.
+
+ Returns:
+ The amount of memory used in bytes.
+
+ Raises:
+ ValueError: If ``device`` is not a CUDA device.
+ RuntimeError: If ``device`` is a CUDA device, but does not support memory pools.
+ """
+
+ init()
+
+ device = runtime.get_device(device)
+
+ if not device.is_cuda:
+ raise ValueError("Memory pools are only supported on CUDA devices")
+
+ if not device.is_mempool_supported:
+ raise RuntimeError(f"Device {device} does not support memory pools")
+
+ return runtime.core.cuda_device_get_mempool_used_mem_current(device.ordinal)
+
+
+ def get_mempool_used_mem_high(device: Devicelike = None) -> int:
+ """Get the application's memory usage high-water mark from the device's CUDA memory pool.
+
+ Parameters:
+ device: The :class:`Device <warp.context.Device>` or device identifier
+ for which the query is to be performed.
+ If ``None``, the default device will be used.
+
+ Returns:
+ The high-water mark of memory used from the memory pool in bytes.
+
+ Raises:
+ ValueError: If ``device`` is not a CUDA device.
+ RuntimeError: If ``device`` is a CUDA device, but does not support memory pools.
+ """
+
+ init()
+
+ device = runtime.get_device(device)
+
+ if not device.is_cuda:
+ raise ValueError("Memory pools are only supported on CUDA devices")
+
+ if not device.is_mempool_supported:
+ raise RuntimeError(f"Device {device} does not support memory pools")
+
+ return runtime.core.cuda_device_get_mempool_used_mem_high(device.ordinal)
+
+
  def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike) -> bool:
  """Check if `peer_device` can directly access the memory of `target_device` on this system.

@@ -4535,7 +4732,7 @@ def wait_event(event: Event):
  get_stream().wait_event(event)


- def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: Optional[bool] = True):
+ def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: bool = True):
  """Get the elapsed time between two recorded events.

  Both events must have been previously recorded with
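Editor's note, not part of the diff: a sketch of timing a launch through the tightened ``synchronize: bool = True`` signature, assuming "cuda:0" exists and timing-enabled events (``enable_timing=True``), which CUDA event timing requires.

    import warp as wp

    wp.init()

    @wp.kernel
    def scale(a: wp.array(dtype=float), s: float):
        i = wp.tid()
        a[i] = a[i] * s

    with wp.ScopedDevice("cuda:0"):
        a = wp.ones(1_000_000, dtype=float)
        start = wp.Event(enable_timing=True)
        end = wp.Event(enable_timing=True)

        wp.record_event(start)
        wp.launch(scale, dim=a.shape, inputs=[a, 2.0])
        wp.record_event(end)

        # synchronize=True (the default) waits on `end` before reading the timer
        print("elapsed ms:", wp.get_event_elapsed_time(start, end))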
@@ -4560,7 +4757,7 @@ def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: Op
  return runtime.core.cuda_event_elapsed_time(start_event.cuda_event, end_event.cuda_event)


- def wait_stream(other_stream: Stream, event: Event = None):
+ def wait_stream(other_stream: Stream, event: Optional[Event] = None):
  """Convenience function for calling :meth:`Stream.wait_stream` on the current stream.

  Args:
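Editor's note, not part of the diff: a sketch of ``wp.wait_stream`` with the now-``Optional`` event argument omitted, assuming "cuda:0" exists.

    import warp as wp

    wp.init()

    @wp.kernel
    def fill(a: wp.array(dtype=float), value: float):
        i = wp.tid()
        a[i] = value

    with wp.ScopedDevice("cuda:0"):
        a = wp.zeros(1_000_000, dtype=float)
        side_stream = wp.Stream()

        # queue work on a secondary stream...
        with wp.ScopedStream(side_stream):
            wp.launch(fill, dim=a.shape, inputs=[a, 1.0])

        # ...then make the current stream wait for it before continuing
        wp.wait_stream(side_stream)
        wp.launch(fill, dim=a.shape, inputs=[a, 2.0])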
@@ -4727,7 +4924,7 @@ class RegisteredGLBuffer:


  def zeros(
- shape: Tuple = None,
+ shape: Union[int, Tuple[int, ...], List[int], None] = None,
  dtype=float,
  device: Devicelike = None,
  requires_grad: bool = False,
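Editor's note, not part of the diff: the widened ``shape`` annotation documents that an ``int``, tuple, or list is accepted; a quick sketch:

    import warp as wp

    wp.init()

    a = wp.zeros(16, dtype=float)          # int -> 1-D array
    b = wp.zeros((4, 4), dtype=wp.vec3)    # tuple -> 2-D array
    c = wp.zeros([2, 3, 4], dtype=int)     # list is accepted as well
    print(a.shape, b.shape, c.shape)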
@@ -4755,7 +4952,7 @@ def zeros(


  def zeros_like(
- src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
+ src: Array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
  ) -> warp.array:
  """Return a zero-initialized array with the same type and dimension of another array

@@ -4777,7 +4974,7 @@ def zeros_like(


  def ones(
- shape: Tuple = None,
+ shape: Union[int, Tuple[int, ...], List[int], None] = None,
  dtype=float,
  device: Devicelike = None,
  requires_grad: bool = False,
@@ -4801,7 +4998,7 @@ def ones(


  def ones_like(
- src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
+ src: Array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
  ) -> warp.array:
  """Return a one-initialized array with the same type and dimension of another array

@@ -4819,7 +5016,7 @@ def ones_like(


  def full(
- shape: Tuple = None,
+ shape: Union[int, Tuple[int, ...], List[int], None] = None,
  value=0,
  dtype=Any,
  device: Devicelike = None,
@@ -4885,7 +5082,11 @@ def full(


  def full_like(
- src: warp.array, value: Any, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
+ src: Array,
+ value: Any,
+ device: Devicelike = None,
+ requires_grad: Optional[bool] = None,
+ pinned: Optional[bool] = None,
  ) -> warp.array:
  """Return an array with all elements initialized to the given value with the same type and dimension of another array

@@ -4907,7 +5108,9 @@ def full_like(
  return arr


- def clone(src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None) -> warp.array:
+ def clone(
+ src: warp.array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
+ ) -> warp.array:
  """Clone an existing array, allocates a copy of the src memory

  Args:
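Editor's note, not part of the diff: a sketch of the ``*_like``/``clone`` overrides whose annotations were tightened above, assuming both a CPU and a "cuda:0" device are present.

    import warp as wp

    wp.init()

    src = wp.full(8, 3.0, dtype=float, device="cuda:0")

    # the optional overrides change device, gradient tracking, or pinning
    # without modifying `src`
    grad_copy = wp.clone(src, requires_grad=True)
    host_copy = wp.clone(src, device="cpu", pinned=True)
    filled = wp.full_like(src, -1.0)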
@@ -4928,7 +5131,7 @@ def clone(src: warp.array, device: Devicelike = None, requires_grad: bool = None


  def empty(
- shape: Tuple = None,
+ shape: Union[int, Tuple[int, ...], List[int], None] = None,
  dtype=float,
  device: Devicelike = None,
  requires_grad: bool = False,
@@ -4961,7 +5164,7 @@ def empty(


  def empty_like(
- src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
+ src: Array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
  ) -> warp.array:
  """Return an uninitialized array with the same type and dimension of another array

@@ -5193,8 +5396,6 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
  ) from e


- # represents all data required for a kernel launch
- # so that launches can be replayed quickly, use `wp.launch(..., record_cmd=True)`
  class Launch:
  """Represents all data required for a kernel launch so that launches can be replayed quickly.

@@ -5465,7 +5666,7 @@ def launch(
  max_blocks: The maximum number of CUDA thread blocks to use.
  Only has an effect for CUDA kernel launches.
  If negative or zero, the maximum hardware value will be used.
- block_dim: The number of threads per block.
+ block_dim: The number of threads per block (always 1 for "cpu" devices).
  """

  init()
@@ -5476,6 +5677,9 @@ def launch(
  else:
  device = runtime.get_device(device)

+ if device == "cpu":
+ block_dim = 1
+
  # check function is a Kernel
  if not isinstance(kernel, Kernel):
  raise RuntimeError("Error launching kernel, can only launch functions decorated with @wp.kernel.")
@@ -5708,6 +5912,18 @@ def launch_tiled(*args, **kwargs):
  "Launch block dimension 'block_dim' argument should be passed via. keyword args for wp.launch_tiled()"
  )

+ if "device" in kwargs:
+ device = kwargs["device"]
+ else:
+ # todo: this doesn't consider the case where device
+ # is passed through positional args
+ device = None
+
+ # force the block_dim to 1 if running on "cpu"
+ device = runtime.get_device(device)
+ if device.is_cpu:
+ kwargs["block_dim"] = 1
+
  dim = kwargs["dim"]
  if not isinstance(dim, list):
  dim = list(dim) if isinstance(dim, tuple) else [dim]
@@ -5876,6 +6092,7 @@ def set_module_options(options: Dict[str, Any], module: Optional[Any] = None):

  * **mode**: The compilation mode to use, can be "debug", or "release", defaults to the value of ``warp.config.mode``.
  * **max_unroll**: The maximum fixed-size loop to unroll, defaults to the value of ``warp.config.max_unroll``.
+ * **block_dim**: The default number of threads to assign to each block

  Args:
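Editor's note, not part of the diff: a sketch of the new per-module ``block_dim`` option.

    import warp as wp

    wp.init()

    # default threads per block for kernels defined in the calling module
    wp.set_module_options({"block_dim": 128})
    print(wp.get_module_options()["block_dim"])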

@@ -5901,7 +6118,12 @@ def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
  return get_module(m.__name__).options


- def capture_begin(device: Devicelike = None, stream=None, force_module_load=None, external=False):
+ def capture_begin(
+ device: Devicelike = None,
+ stream: Optional[Stream] = None,
+ force_module_load: Optional[bool] = None,
+ external: bool = False,
+ ):
  """Begin capture of a CUDA graph

  Captures all subsequent kernel launches and memory operations on CUDA devices.
@@ -5968,16 +6190,15 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=None
  runtime.captures[capture_id] = graph


- def capture_end(device: Devicelike = None, stream: Stream = None) -> Graph:
- """Ends the capture of a CUDA graph
+ def capture_end(device: Devicelike = None, stream: Optional[Stream] = None) -> Graph:
+ """End the capture of a CUDA graph.

  Args:
-
  device: The CUDA device where capture began
  stream: The CUDA stream where capture began

  Returns:
- A Graph object that can be launched with :func:`~warp.capture_launch()`
+ A :class:`Graph` object that can be launched with :func:`~warp.capture_launch()`
  """

  if stream is not None:
@@ -6011,12 +6232,12 @@ def capture_end(device: Devicelike = None, stream: Stream = None) -> Graph:
  return graph


- def capture_launch(graph: Graph, stream: Stream = None):
+ def capture_launch(graph: Graph, stream: Optional[Stream] = None):
  """Launch a previously captured CUDA graph

  Args:
- graph: A Graph as returned by :func:`~warp.capture_end()`
- stream: A Stream to launch the graph on (optional)
+ graph: A :class:`Graph` as returned by :func:`~warp.capture_end()`
+ stream: A :class:`Stream` to launch the graph on
  """

  if stream is not None:
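Editor's note, not part of the diff: a sketch of the capture workflow the refreshed docstrings describe, assuming "cuda:0" exists.

    import warp as wp

    wp.init()

    @wp.kernel
    def inc(a: wp.array(dtype=float)):
        i = wp.tid()
        a[i] = a[i] + 1.0

    with wp.ScopedDevice("cuda:0"):
        a = wp.zeros(1024, dtype=float)

        wp.capture_begin()
        try:
            wp.launch(inc, dim=a.shape, inputs=[a])
            wp.launch(inc, dim=a.shape, inputs=[a])
        finally:
            graph = wp.capture_end()

        # replay the captured graph as many times as needed
        for _ in range(10):
            wp.capture_launch(graph)
        wp.synchronize_device()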
@@ -6032,24 +6253,28 @@ def capture_launch(graph: Graph, stream: Stream = None):


  def copy(
- dest: warp.array, src: warp.array, dest_offset: int = 0, src_offset: int = 0, count: int = 0, stream: Stream = None
+ dest: warp.array,
+ src: warp.array,
+ dest_offset: int = 0,
+ src_offset: int = 0,
+ count: int = 0,
+ stream: Optional[Stream] = None,
  ):
  """Copy array contents from `src` to `dest`.

  Args:
- dest: Destination array, must be at least as big as source buffer
+ dest: Destination array, must be at least as large as source buffer
  src: Source array
  dest_offset: Element offset in the destination array
  src_offset: Element offset in the source array
  count: Number of array elements to copy (will copy all elements if set to 0)
- stream: The stream on which to perform the copy (optional)
+ stream: The stream on which to perform the copy

  The stream, if specified, can be from any device. If the stream is omitted, then Warp selects a stream based on the following rules:
  (1) If the destination array is on a CUDA device, use the current stream on the destination device.
  (2) Otherwise, if the source array is on a CUDA device, use the current stream on the source device.

  If neither source nor destination are on a CUDA device, no stream is used for the copy.
-
  """

  from warp.context import runtime
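Editor's note, not part of the diff: a sketch of the offset/count semantics and the automatic stream selection described above, assuming a "cuda:0" device.

    import warp as wp

    wp.init()

    src = wp.full(8, 7.0, dtype=float, device="cpu")
    dst = wp.zeros(16, dtype=float, device="cuda:0")

    # copy 4 elements from src[0:4] into dst[2:6]; with no stream given,
    # the current stream on the destination (CUDA) device is used
    wp.copy(dst, src, dest_offset=2, src_offset=0, count=4)
    wp.synchronize_device("cuda:0")
    print(dst.numpy())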
@@ -6274,8 +6499,8 @@ def type_str(t):
  return f"Transformation[{type_str(t._wp_scalar_type_)}]"

  raise TypeError("Invalid vector or matrix dimensions")
- elif warp.codegen.get_type_origin(t) in (list, tuple):
- args_repr = ", ".join(type_str(x) for x in warp.codegen.get_type_args(t))
+ elif get_origin(t) in (list, tuple):
+ args_repr = ", ".join(type_str(x) for x in get_args(t))
  return f"{t._name}[{args_repr}]"
  elif t is Ellipsis:
  return "..."