warp-lang 1.3.3__py3-none-manylinux2014_aarch64.whl → 1.4.1__py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (110) hide show
  1. warp/__init__.py +6 -0
  2. warp/autograd.py +59 -6
  3. warp/bin/warp.so +0 -0
  4. warp/build_dll.py +8 -10
  5. warp/builtins.py +103 -3
  6. warp/codegen.py +447 -53
  7. warp/config.py +1 -1
  8. warp/context.py +682 -405
  9. warp/dlpack.py +2 -0
  10. warp/examples/benchmarks/benchmark_cloth.py +10 -0
  11. warp/examples/core/example_render_opengl.py +12 -10
  12. warp/examples/fem/example_adaptive_grid.py +251 -0
  13. warp/examples/fem/example_apic_fluid.py +1 -1
  14. warp/examples/fem/example_diffusion_3d.py +2 -2
  15. warp/examples/fem/example_magnetostatics.py +1 -1
  16. warp/examples/fem/example_streamlines.py +1 -0
  17. warp/examples/fem/utils.py +25 -5
  18. warp/examples/sim/example_cloth.py +50 -6
  19. warp/fem/__init__.py +2 -0
  20. warp/fem/adaptivity.py +493 -0
  21. warp/fem/field/field.py +2 -1
  22. warp/fem/field/nodal_field.py +18 -26
  23. warp/fem/field/test.py +4 -4
  24. warp/fem/field/trial.py +4 -4
  25. warp/fem/geometry/__init__.py +1 -0
  26. warp/fem/geometry/adaptive_nanogrid.py +843 -0
  27. warp/fem/geometry/nanogrid.py +55 -28
  28. warp/fem/space/__init__.py +1 -1
  29. warp/fem/space/nanogrid_function_space.py +69 -35
  30. warp/fem/utils.py +118 -107
  31. warp/jax_experimental.py +28 -15
  32. warp/native/array.h +0 -1
  33. warp/native/builtin.h +103 -6
  34. warp/native/bvh.cu +4 -2
  35. warp/native/cuda_util.cpp +14 -0
  36. warp/native/cuda_util.h +2 -0
  37. warp/native/error.cpp +4 -2
  38. warp/native/exports.h +99 -0
  39. warp/native/mat.h +97 -0
  40. warp/native/mesh.cpp +36 -0
  41. warp/native/mesh.cu +52 -1
  42. warp/native/mesh.h +1 -0
  43. warp/native/quat.h +43 -0
  44. warp/native/range.h +11 -2
  45. warp/native/spatial.h +6 -0
  46. warp/native/vec.h +74 -0
  47. warp/native/warp.cpp +2 -1
  48. warp/native/warp.cu +10 -3
  49. warp/native/warp.h +8 -1
  50. warp/paddle.py +382 -0
  51. warp/sim/__init__.py +1 -0
  52. warp/sim/collide.py +519 -0
  53. warp/sim/integrator_euler.py +18 -5
  54. warp/sim/integrator_featherstone.py +5 -5
  55. warp/sim/integrator_vbd.py +1026 -0
  56. warp/sim/integrator_xpbd.py +2 -6
  57. warp/sim/model.py +50 -25
  58. warp/sparse.py +9 -7
  59. warp/stubs.py +459 -0
  60. warp/tape.py +2 -0
  61. warp/tests/aux_test_dependent.py +1 -0
  62. warp/tests/aux_test_name_clash1.py +32 -0
  63. warp/tests/aux_test_name_clash2.py +32 -0
  64. warp/tests/aux_test_square.py +1 -0
  65. warp/tests/test_array.py +188 -0
  66. warp/tests/test_async.py +3 -3
  67. warp/tests/test_atomic.py +6 -0
  68. warp/tests/test_closest_point_edge_edge.py +93 -1
  69. warp/tests/test_codegen.py +93 -15
  70. warp/tests/test_codegen_instancing.py +1457 -0
  71. warp/tests/test_collision.py +486 -0
  72. warp/tests/test_compile_consts.py +3 -28
  73. warp/tests/test_dlpack.py +170 -0
  74. warp/tests/test_examples.py +22 -8
  75. warp/tests/test_fast_math.py +10 -4
  76. warp/tests/test_fem.py +81 -1
  77. warp/tests/test_func.py +46 -0
  78. warp/tests/test_implicit_init.py +49 -0
  79. warp/tests/test_jax.py +58 -0
  80. warp/tests/test_mat.py +84 -0
  81. warp/tests/test_mesh_query_point.py +188 -0
  82. warp/tests/test_model.py +13 -0
  83. warp/tests/test_module_hashing.py +40 -0
  84. warp/tests/test_multigpu.py +3 -3
  85. warp/tests/test_overwrite.py +8 -0
  86. warp/tests/test_paddle.py +852 -0
  87. warp/tests/test_print.py +89 -0
  88. warp/tests/test_quat.py +111 -0
  89. warp/tests/test_reload.py +31 -1
  90. warp/tests/test_scalar_ops.py +2 -0
  91. warp/tests/test_static.py +568 -0
  92. warp/tests/test_streams.py +64 -3
  93. warp/tests/test_struct.py +4 -4
  94. warp/tests/test_torch.py +24 -0
  95. warp/tests/test_triangle_closest_point.py +137 -0
  96. warp/tests/test_types.py +1 -1
  97. warp/tests/test_vbd.py +386 -0
  98. warp/tests/test_vec.py +143 -0
  99. warp/tests/test_vec_scalar_ops.py +139 -0
  100. warp/tests/unittest_suites.py +12 -0
  101. warp/tests/unittest_utils.py +9 -5
  102. warp/thirdparty/dlpack.py +3 -1
  103. warp/types.py +167 -36
  104. warp/utils.py +37 -14
  105. {warp_lang-1.3.3.dist-info → warp_lang-1.4.1.dist-info}/METADATA +10 -8
  106. {warp_lang-1.3.3.dist-info → warp_lang-1.4.1.dist-info}/RECORD +109 -97
  107. warp/tests/test_point_triangle_closest_point.py +0 -143
  108. {warp_lang-1.3.3.dist-info → warp_lang-1.4.1.dist-info}/LICENSE.md +0 -0
  109. {warp_lang-1.3.3.dist-info → warp_lang-1.4.1.dist-info}/WHEEL +0 -0
  110. {warp_lang-1.3.3.dist-info → warp_lang-1.4.1.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -6,7 +6,6 @@
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
8
  import ast
9
- import builtins
10
9
  import ctypes
11
10
  import functools
12
11
  import hashlib
@@ -19,9 +18,9 @@ import platform
19
18
  import sys
20
19
  import types
21
20
  import typing
21
+ import weakref
22
22
  from copy import copy as shallowcopy
23
23
  from pathlib import Path
24
- from struct import pack as struct_pack
25
24
  from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
26
25
 
27
26
  import numpy as np
@@ -36,6 +35,10 @@ import warp.config
36
35
 
37
36
  def create_value_func(type):
38
37
  def value_func(arg_types, arg_values):
38
+ hint_origin = getattr(type, "__origin__", None)
39
+ if hint_origin is not None and issubclass(hint_origin, typing.Tuple):
40
+ return type.__args__
41
+
39
42
  return type
40
43
 
41
44
  return value_func
@@ -54,6 +57,38 @@ def get_function_args(func):
54
57
  complex_type_hints = (Any, Callable, Tuple)
55
58
  sequence_types = (list, tuple)
56
59
 
60
+ function_key_counts = {}
61
+
62
+
63
+ def generate_unique_function_identifier(key):
64
+ # Generate unique identifiers for user-defined functions in native code.
65
+ # - Prevents conflicts when a function is redefined and old versions are still in use.
66
+ # - Prevents conflicts between multiple closures returned from the same function.
67
+ # - Prevents conflicts between identically named functions from different modules.
68
+ #
69
+ # Currently, we generate a unique id when a new Function is created, which produces
70
+ # globally unique identifiers.
71
+ #
72
+ # NOTE:
73
+ # We could move this to the Module class for generating unique identifiers at module scope,
74
+ # but then we need another solution for preventing conflicts across modules (e.g., different namespaces).
75
+ # That would require more Python code, generate more native code, and would be slightly slower
76
+ # with no clear advantages over globally-unique identifiers (non-global shared state is still shared state).
77
+ #
78
+ # TODO:
79
+ # Kernels and structs use unique identifiers based on their hash. Using hash-based identifiers
80
+ # for functions would allow filtering out duplicate identical functions during codegen,
81
+ # like we do with kernels and structs. This is worth investigating further, but might require
82
+ # additional refactoring. For example, the code that deals with custom gradient and replay functions
83
+ # requires matching function names, but these special functions get created before the hash
84
+ # for the parent function can be computed. In addition to these complications, computing hashes
85
+ # for all function instances would increase the cost of module hashing when generic functions
86
+ # are involved (currently we only hash the generic templates, which is sufficient).
87
+
88
+ unique_id = function_key_counts.get(key, 0)
89
+ function_key_counts[key] = unique_id + 1
90
+ return f"{key}_{unique_id}"
91
+
57
92
 
58
93
  class Function:
59
94
  def __init__(
@@ -90,6 +125,7 @@ class Function:
90
125
  code_transformers=None,
91
126
  skip_adding_overload=False,
92
127
  require_original_output_arg=False,
128
+ scope_locals=None, # the locals() where the function is defined, used for overload management
93
129
  ):
94
130
  if code_transformers is None:
95
131
  code_transformers = []
@@ -115,6 +151,7 @@ class Function:
115
151
  self.replay_snippet = replay_snippet
116
152
  self.custom_grad_func = None
117
153
  self.require_original_output_arg = require_original_output_arg
154
+ self.generic_parent = None # generic function that was used to instantiate this overload
118
155
 
119
156
  if initializer_list_func is None:
120
157
  self.initializer_list_func = lambda x, y: False
@@ -124,14 +161,19 @@ class Function:
124
161
  )
125
162
  self.hidden = hidden # function will not be listed in docs
126
163
  self.skip_replay = (
127
- skip_replay # whether or not operation will be performed during the forward replay in the backward pass
164
+ skip_replay # whether operation will be performed during the forward replay in the backward pass
128
165
  )
129
- self.missing_grad = missing_grad # whether or not builtin is missing a corresponding adjoint
166
+ self.missing_grad = missing_grad # whether builtin is missing a corresponding adjoint
130
167
  self.generic = generic
131
168
 
132
- # allow registering builtin functions with a different name in Python from the native code
169
+ # allow registering functions with a different name in Python and native code
133
170
  if native_func is None:
134
- self.native_func = key
171
+ if func is None:
172
+ # builtin function
173
+ self.native_func = key
174
+ else:
175
+ # user functions need unique identifiers to avoid conflicts
176
+ self.native_func = generate_unique_function_identifier(key)
135
177
  else:
136
178
  self.native_func = native_func
137
179
 
@@ -162,6 +204,11 @@ class Function:
162
204
  else:
163
205
  self.input_types[name] = type
164
206
 
207
+ # Record any default parameter values.
208
+ if not self.defaults:
209
+ signature = inspect.signature(func)
210
+ self.defaults = {k: v.default for k, v in signature.parameters.items() if v.default is not v.empty}
211
+
165
212
  else:
166
213
  # builtin function
167
214
 
@@ -210,9 +257,13 @@ class Function:
210
257
  signature_params.append(param)
211
258
  self.signature = inspect.Signature(signature_params)
212
259
 
260
+ # scope for resolving overloads
261
+ if scope_locals is None:
262
+ scope_locals = inspect.currentframe().f_back.f_locals
263
+
213
264
  # add to current module
214
265
  if module:
215
- module.register_function(self, skip_adding_overload)
266
+ module.register_function(self, scope_locals, skip_adding_overload)
216
267
 
217
268
  def __call__(self, *args, **kwargs):
218
269
  # handles calling a builtin (native) function
@@ -323,57 +374,52 @@ class Function:
323
374
 
324
375
  # check if generic
325
376
  if warp.types.is_generic_signature(sig):
326
- if sig in self.user_templates:
327
- raise RuntimeError(
328
- f"Duplicate generic function overload {self.key} with arguments {f.input_types.values()}"
329
- )
330
377
  self.user_templates[sig] = f
331
378
  else:
332
- if sig in self.user_overloads:
333
- raise RuntimeError(
334
- f"Duplicate function overload {self.key} with arguments {f.input_types.values()}"
335
- )
336
379
  self.user_overloads[sig] = f
337
380
 
338
381
  def get_overload(self, arg_types, kwarg_types):
339
382
  assert not self.is_builtin()
340
383
 
341
- sig = warp.types.get_signature(arg_types, func_name=self.key)
384
+ for f in self.user_overloads.values():
385
+ if warp.codegen.func_match_args(f, arg_types, kwarg_types):
386
+ return f
342
387
 
343
- f = self.user_overloads.get(sig)
344
- if f is not None:
345
- return f
346
- else:
347
- for f in self.user_templates.values():
348
- if len(f.input_types) != len(arg_types):
349
- continue
388
+ for f in self.user_templates.values():
389
+ if not warp.codegen.func_match_args(f, arg_types, kwarg_types):
390
+ continue
391
+
392
+ if len(f.input_types) != len(arg_types):
393
+ continue
350
394
 
351
- # try to match the given types to the function template types
352
- template_types = list(f.input_types.values())
353
- args_matched = True
395
+ # try to match the given types to the function template types
396
+ template_types = list(f.input_types.values())
397
+ args_matched = True
354
398
 
355
- for i in range(len(arg_types)):
356
- if not warp.types.type_matches_template(arg_types[i], template_types[i]):
357
- args_matched = False
358
- break
399
+ for i in range(len(arg_types)):
400
+ if not warp.types.type_matches_template(arg_types[i], template_types[i]):
401
+ args_matched = False
402
+ break
359
403
 
360
- if args_matched:
361
- # instantiate this function with the specified argument types
404
+ if args_matched:
405
+ # instantiate this function with the specified argument types
362
406
 
363
- arg_names = f.input_types.keys()
364
- overload_annotations = dict(zip(arg_names, arg_types))
407
+ arg_names = f.input_types.keys()
408
+ overload_annotations = dict(zip(arg_names, arg_types))
365
409
 
366
- ovl = shallowcopy(f)
367
- ovl.adj = warp.codegen.Adjoint(f.func, overload_annotations)
368
- ovl.input_types = overload_annotations
369
- ovl.value_func = None
410
+ ovl = shallowcopy(f)
411
+ ovl.adj = warp.codegen.Adjoint(f.func, overload_annotations)
412
+ ovl.input_types = overload_annotations
413
+ ovl.value_func = None
414
+ ovl.generic_parent = f
370
415
 
371
- self.user_overloads[sig] = ovl
416
+ sig = warp.types.get_signature(arg_types, func_name=self.key)
417
+ self.user_overloads[sig] = ovl
372
418
 
373
- return ovl
419
+ return ovl
374
420
 
375
- # failed to find overload
376
- return None
421
+ # failed to find overload
422
+ return None
377
423
 
378
424
  def __repr__(self):
379
425
  inputs_str = ", ".join([f"{k}: {warp.types.type_repr(v)}" for k, v in self.input_types.items()])
@@ -589,8 +635,7 @@ class Kernel:
589
635
  self.module = module
590
636
 
591
637
  if key is None:
592
- unique_key = self.module.generate_unique_kernel_key(func.__name__)
593
- self.key = unique_key
638
+ self.key = warp.codegen.make_full_qualified_name(func)
594
639
  else:
595
640
  self.key = key
596
641
 
@@ -614,9 +659,15 @@ class Kernel:
614
659
  # known overloads for generic kernels, indexed by type signature
615
660
  self.overloads = {}
616
661
 
662
+ # generic kernel that was used to instantiate this overload
663
+ self.generic_parent = None
664
+
617
665
  # argument indices by name
618
666
  self.arg_indices = {a.label: i for i, a in enumerate(self.adj.args)}
619
667
 
668
+ # hash will be computed when the module is built
669
+ self.hash = None
670
+
620
671
  if self.module:
621
672
  self.module.register_kernel(self)
622
673
 
@@ -664,10 +715,11 @@ class Kernel:
664
715
  ovl.is_generic = False
665
716
  ovl.overloads = {}
666
717
  ovl.sig = sig
718
+ ovl.generic_parent = self
667
719
 
668
720
  self.overloads[sig] = ovl
669
721
 
670
- self.module.unload()
722
+ self.module.mark_modified()
671
723
 
672
724
  return ovl
673
725
 
@@ -676,10 +728,13 @@ class Kernel:
676
728
  return self.overloads.get(sig)
677
729
 
678
730
  def get_mangled_name(self):
679
- if self.sig:
680
- return f"{self.key}_{self.sig}"
681
- else:
682
- return self.key
731
+ if self.hash is None:
732
+ raise RuntimeError(f"Missing hash for kernel {self.key} in module {self.module.name}")
733
+
734
+ # TODO: allow customizing the number of hash characters used
735
+ hash_suffix = self.hash.hex()[:8]
736
+
737
+ return f"{self.key}_{hash_suffix}"
683
738
 
684
739
 
685
740
  # ----------------------
@@ -689,9 +744,11 @@ class Kernel:
689
744
  def func(f):
690
745
  name = warp.codegen.make_full_qualified_name(f)
691
746
 
747
+ scope_locals = inspect.currentframe().f_back.f_locals
748
+
692
749
  m = get_module(f.__module__)
693
750
  Function(
694
- func=f, key=name, namespace="", module=m, value_func=None
751
+ func=f, key=name, namespace="", module=m, value_func=None, scope_locals=scope_locals
695
752
  ) # value_type not known yet, will be inferred during Adjoint.build()
696
753
 
697
754
  # use the top of the list of overloads for this key
@@ -705,6 +762,8 @@ def func_native(snippet, adj_snippet=None, replay_snippet=None):
705
762
  Decorator to register native code snippet, @func_native
706
763
  """
707
764
 
765
+ scope_locals = inspect.currentframe().f_back.f_locals
766
+
708
767
  def snippet_func(f):
709
768
  name = warp.codegen.make_full_qualified_name(f)
710
769
 
@@ -717,6 +776,7 @@ def func_native(snippet, adj_snippet=None, replay_snippet=None):
717
776
  native_snippet=snippet,
718
777
  adj_native_snippet=adj_snippet,
719
778
  replay_snippet=replay_snippet,
779
+ scope_locals=scope_locals,
720
780
  ) # value_type not known yet, will be inferred during Adjoint.build()
721
781
  g = m.functions[name]
722
782
  # copy over the function attributes, including docstring
@@ -783,6 +843,7 @@ def func_grad(forward_fn):
783
843
  f.custom_grad_func = Function(
784
844
  grad_fn,
785
845
  key=f.key,
846
+ native_func=f.native_func,
786
847
  namespace=f.namespace,
787
848
  input_types=reverse_args,
788
849
  value_func=None,
@@ -941,7 +1002,7 @@ def overload(kernel, arg_types=None):
941
1002
  # ensure this function name corresponds to a kernel
942
1003
  fn = kernel
943
1004
  module = get_module(fn.__module__)
944
- kernel = module.kernels.get(fn.__name__)
1005
+ kernel = module.find_kernel(fn)
945
1006
  if kernel is None:
946
1007
  raise RuntimeError(f"Failed to find a kernel named '{fn.__name__}' in module {fn.__module__}")
947
1008
 
@@ -1277,7 +1338,6 @@ def get_module(name):
1277
1338
  # clear out old kernels, funcs, struct definitions
1278
1339
  old_module.kernels = {}
1279
1340
  old_module.functions = {}
1280
- old_module.constants = {}
1281
1341
  old_module.structs = {}
1282
1342
  old_module.loader = parent_loader
1283
1343
 
@@ -1289,30 +1349,206 @@ def get_module(name):
1289
1349
  return user_modules[name]
1290
1350
 
1291
1351
 
1352
+ # ModuleHasher computes the module hash based on all the kernels, module options,
1353
+ # and build configuration. For each kernel, it computes a deep hash by recursively
1354
+ # hashing all referenced functions, structs, and constants, even those defined in
1355
+ # other modules. The module hash is computed in the constructor and can be retrieved
1356
+ # using get_module_hash(). In addition, the ModuleHasher takes care of filtering out
1357
+ # duplicate kernels for codegen (see get_unique_kernels()).
1358
+ class ModuleHasher:
1359
+ def __init__(self, module):
1360
+ # cache function hashes to avoid hashing multiple times
1361
+ self.function_hashes = {} # (function: hash)
1362
+
1363
+ # avoid recursive spiral of doom (e.g., function calling an overload of itself)
1364
+ self.functions_in_progress = set()
1365
+
1366
+ # all unique kernels for codegen, filtered by hash
1367
+ self.unique_kernels = {} # (hash: kernel)
1368
+
1369
+ # start hashing the module
1370
+ ch = hashlib.sha256()
1371
+
1372
+ # hash all non-generic kernels
1373
+ for kernel in module.live_kernels:
1374
+ if kernel.is_generic:
1375
+ for ovl in kernel.overloads.values():
1376
+ if not ovl.adj.skip_build:
1377
+ ovl.hash = self.hash_kernel(ovl)
1378
+ else:
1379
+ if not kernel.adj.skip_build:
1380
+ kernel.hash = self.hash_kernel(kernel)
1381
+
1382
+ # include all unique kernels in the module hash
1383
+ for kernel_hash in sorted(self.unique_kernels.keys()):
1384
+ ch.update(kernel_hash)
1385
+
1386
+ # configuration parameters
1387
+ for opt in sorted(module.options.keys()):
1388
+ s = f"{opt}:{module.options[opt]}"
1389
+ ch.update(bytes(s, "utf-8"))
1390
+
1391
+ # ensure to trigger recompilation if flags affecting kernel compilation are changed
1392
+ if warp.config.verify_fp:
1393
+ ch.update(bytes("verify_fp", "utf-8"))
1394
+
1395
+ # build config
1396
+ ch.update(bytes(warp.config.mode, "utf-8"))
1397
+
1398
+ # save the module hash
1399
+ self.module_hash = ch.digest()
1400
+
1401
+ def hash_kernel(self, kernel):
1402
+ # NOTE: We only hash non-generic kernels, so we don't traverse kernel overloads here.
1403
+
1404
+ ch = hashlib.sha256()
1405
+
1406
+ ch.update(bytes(kernel.key, "utf-8"))
1407
+ ch.update(self.hash_adjoint(kernel.adj))
1408
+
1409
+ h = ch.digest()
1410
+
1411
+ self.unique_kernels[h] = kernel
1412
+
1413
+ return h
1414
+
1415
+ def hash_function(self, func):
1416
+ # NOTE: This method hashes all possible overloads that a function call could resolve to.
1417
+ # The exact overload will be resolved at build time, when the argument types are known.
1418
+
1419
+ h = self.function_hashes.get(func)
1420
+ if h is not None:
1421
+ return h
1422
+
1423
+ self.functions_in_progress.add(func)
1424
+
1425
+ ch = hashlib.sha256()
1426
+
1427
+ ch.update(bytes(func.key, "utf-8"))
1428
+
1429
+ # include all concrete and generic overloads
1430
+ overloads = {**func.user_overloads, **func.user_templates}
1431
+ for sig in sorted(overloads.keys()):
1432
+ ovl = overloads[sig]
1433
+
1434
+ # skip instantiations of generic functions
1435
+ if ovl.generic_parent is not None:
1436
+ continue
1437
+
1438
+ # adjoint
1439
+ ch.update(self.hash_adjoint(ovl.adj))
1440
+
1441
+ # custom bits
1442
+ if ovl.custom_grad_func:
1443
+ ch.update(self.hash_adjoint(ovl.custom_grad_func.adj))
1444
+ if ovl.custom_replay_func:
1445
+ ch.update(self.hash_adjoint(ovl.custom_replay_func.adj))
1446
+ if ovl.replay_snippet:
1447
+ ch.update(bytes(ovl.replay_snippet, "utf-8"))
1448
+ if ovl.native_snippet:
1449
+ ch.update(bytes(ovl.native_snippet, "utf-8"))
1450
+ if ovl.adj_native_snippet:
1451
+ ch.update(bytes(ovl.adj_native_snippet, "utf-8"))
1452
+
1453
+ h = ch.digest()
1454
+
1455
+ self.function_hashes[func] = h
1456
+
1457
+ self.functions_in_progress.remove(func)
1458
+
1459
+ return h
1460
+
1461
+ def hash_adjoint(self, adj):
1462
+ # NOTE: We don't cache adjoint hashes, because adjoints are always unique.
1463
+ # Even instances of generic kernels and functions have unique adjoints with
1464
+ # different argument types.
1465
+
1466
+ ch = hashlib.sha256()
1467
+
1468
+ # source
1469
+ ch.update(bytes(adj.source, "utf-8"))
1470
+
1471
+ # args
1472
+ for arg, arg_type in adj.arg_types.items():
1473
+ s = f"{arg}:{warp.types.get_type_code(arg_type)}"
1474
+ ch.update(bytes(s, "utf-8"))
1475
+
1476
+ # hash struct types
1477
+ if isinstance(arg_type, warp.codegen.Struct):
1478
+ ch.update(arg_type.hash)
1479
+ elif warp.types.is_array(arg_type) and isinstance(arg_type.dtype, warp.codegen.Struct):
1480
+ ch.update(arg_type.dtype.hash)
1481
+
1482
+ # find referenced constants, types, and functions
1483
+ constants, types, functions = adj.get_references()
1484
+
1485
+ # hash referenced constants
1486
+ for name, value in constants.items():
1487
+ ch.update(bytes(name, "utf-8"))
1488
+ ch.update(self.get_constant_bytes(value))
1489
+
1490
+ # hash wp.static() expressions that were evaluated at declaration time
1491
+ for k, v in adj.static_expressions.items():
1492
+ ch.update(bytes(k, "utf-8"))
1493
+ if isinstance(v, Function):
1494
+ if v not in self.functions_in_progress:
1495
+ ch.update(self.hash_function(v))
1496
+ else:
1497
+ ch.update(self.get_constant_bytes(v))
1498
+
1499
+ # hash referenced types
1500
+ for t in types.keys():
1501
+ ch.update(bytes(warp.types.get_type_code(t), "utf-8"))
1502
+
1503
+ # hash referenced functions
1504
+ for f in functions.keys():
1505
+ if f not in self.functions_in_progress:
1506
+ ch.update(self.hash_function(f))
1507
+
1508
+ return ch.digest()
1509
+
1510
+ def get_constant_bytes(self, value):
1511
+ if isinstance(value, int):
1512
+ # this also handles builtins.bool
1513
+ return bytes(ctypes.c_int(value))
1514
+ elif isinstance(value, float):
1515
+ return bytes(ctypes.c_float(value))
1516
+ elif isinstance(value, warp.types.float16):
1517
+ # float16 is a special case
1518
+ return bytes(ctypes.c_float(value.value))
1519
+ elif isinstance(value, tuple(warp.types.scalar_and_bool_types)):
1520
+ return bytes(value._type_(value.value))
1521
+ elif hasattr(value, "_wp_scalar_type_"):
1522
+ return bytes(value)
1523
+ elif isinstance(value, warp.codegen.StructInstance):
1524
+ return bytes(value._ctype)
1525
+ else:
1526
+ raise TypeError(f"Invalid constant type: {type(value)}")
1527
+
1528
+ def get_module_hash(self):
1529
+ return self.module_hash
1530
+
1531
+ def get_unique_kernels(self):
1532
+ return self.unique_kernels.values()
1533
+
1534
+
1292
1535
  class ModuleBuilder:
1293
- def __init__(self, module, options):
1536
+ def __init__(self, module, options, hasher=None):
1294
1537
  self.functions = {}
1295
1538
  self.structs = {}
1296
1539
  self.options = options
1297
1540
  self.module = module
1298
1541
  self.deferred_functions = []
1299
1542
 
1300
- # build all functions declared in the module
1301
- for func in module.functions.values():
1302
- for f in func.user_overloads.values():
1303
- self.build_function(f)
1304
- if f.custom_replay_func is not None:
1305
- self.build_function(f.custom_replay_func)
1306
-
1307
- # build all kernel entry points
1308
- for kernel in module.kernels.values():
1309
- if not kernel.is_generic:
1310
- self.build_kernel(kernel)
1311
- else:
1312
- for k in kernel.overloads.values():
1313
- self.build_kernel(k)
1543
+ if hasher is None:
1544
+ hasher = ModuleHasher(module)
1314
1545
 
1315
- # build all functions outside this module which are called from functions or kernels in this module
1546
+ # build all unique kernels
1547
+ self.kernels = hasher.get_unique_kernels()
1548
+ for kernel in self.kernels:
1549
+ self.build_kernel(kernel)
1550
+
1551
+ # build deferred functions
1316
1552
  for func in self.deferred_functions:
1317
1553
  self.build_function(func)
1318
1554
 
@@ -1328,7 +1564,7 @@ class ModuleBuilder:
1328
1564
  for var in s.vars.values():
1329
1565
  if isinstance(var.type, warp.codegen.Struct):
1330
1566
  stack.append(var.type)
1331
- elif isinstance(var.type, warp.types.array) and isinstance(var.type.dtype, warp.codegen.Struct):
1567
+ elif warp.types.is_array(var.type) and isinstance(var.type.dtype, warp.codegen.Struct):
1332
1568
  stack.append(var.type.dtype)
1333
1569
 
1334
1570
  # Build them in reverse to generate a correct dependency order.
@@ -1374,8 +1610,12 @@ class ModuleBuilder:
1374
1610
  source = ""
1375
1611
 
1376
1612
  # code-gen structs
1613
+ visited_structs = set()
1377
1614
  for struct in self.structs.keys():
1378
- source += warp.codegen.codegen_struct(struct)
1615
+ # avoid emitting duplicates
1616
+ if struct.hash not in visited_structs:
1617
+ source += warp.codegen.codegen_struct(struct)
1618
+ visited_structs.add(struct.hash)
1379
1619
 
1380
1620
  # code-gen all imported functions
1381
1621
  for func in self.functions.keys():
@@ -1386,21 +1626,15 @@ class ModuleBuilder:
1386
1626
  else:
1387
1627
  source += warp.codegen.codegen_snippet(
1388
1628
  func.adj,
1389
- name=func.key,
1629
+ name=func.native_func,
1390
1630
  snippet=func.native_snippet,
1391
1631
  adj_snippet=func.adj_native_snippet,
1392
1632
  replay_snippet=func.replay_snippet,
1393
1633
  )
1394
1634
 
1395
- for kernel in self.module.kernels.values():
1396
- # each kernel gets an entry point in the module
1397
- if not kernel.is_generic:
1398
- source += warp.codegen.codegen_kernel(kernel, device=device, options=self.options)
1399
- source += warp.codegen.codegen_module(kernel, device=device)
1400
- else:
1401
- for k in kernel.overloads.values():
1402
- source += warp.codegen.codegen_kernel(k, device=device, options=self.options)
1403
- source += warp.codegen.codegen_module(k, device=device)
1635
+ for kernel in self.kernels:
1636
+ source += warp.codegen.codegen_kernel(kernel, device=device, options=self.options)
1637
+ source += warp.codegen.codegen_module(kernel, device=device)
1404
1638
 
1405
1639
  # add headers
1406
1640
  if device == "cpu":
@@ -1425,8 +1659,9 @@ class ModuleExec:
1425
1659
  instance.handle = None
1426
1660
  return instance
1427
1661
 
1428
- def __init__(self, handle, device):
1662
+ def __init__(self, handle, module_hash, device):
1429
1663
  self.handle = handle
1664
+ self.module_hash = module_hash
1430
1665
  self.device = device
1431
1666
  self.kernel_hooks = {}
1432
1667
 
@@ -1457,8 +1692,12 @@ class ModuleExec:
1457
1692
  )
1458
1693
  else:
1459
1694
  func = ctypes.CFUNCTYPE(None)
1460
- forward = func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_forward").encode("utf-8")))
1461
- backward = func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8")))
1695
+ forward = (
1696
+ func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_forward").encode("utf-8"))) or None
1697
+ )
1698
+ backward = (
1699
+ func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8"))) or None
1700
+ )
1462
1701
 
1463
1702
  hooks = KernelHooks(forward, backward)
1464
1703
  self.kernel_hooks[kernel] = hooks
@@ -1475,16 +1714,33 @@ class Module:
1475
1714
  self.name = name
1476
1715
  self.loader = loader
1477
1716
 
1478
- self.kernels = {}
1479
- self.functions = {}
1480
- self.constants = {} # Any constants referenced in this module including those defined in other modules
1481
- self.structs = {}
1717
+ # lookup the latest versions of kernels, functions, and structs by key
1718
+ self.kernels = {} # (key: kernel)
1719
+ self.functions = {} # (key: function)
1720
+ self.structs = {} # (key: struct)
1721
+
1722
+ # Set of all "live" kernels in this module.
1723
+ # The difference between `live_kernels` and `kernels` is that `live_kernels` may contain
1724
+ # multiple kernels with the same key (which is essential to support closures), while `kernels`
1725
+ # only holds the latest kernel for each key. When the module is built, we compute the hash
1726
+ # of each kernel in `live_kernels` and filter out duplicates for codegen.
1727
+ self.live_kernels = weakref.WeakSet()
1482
1728
 
1483
- self.cpu_exec = None # executable CPU module
1484
- self.cuda_execs = {} # executable CUDA module lookup by CUDA context
1729
+ # executable modules currently loaded
1730
+ self.execs = {} # (device.context: ModuleExec)
1485
1731
 
1486
- self.cpu_build_failed = False
1487
- self.cuda_build_failed = False
1732
+ # set of device contexts where the build has failed
1733
+ self.failed_builds = set()
1734
+
1735
+ # hash data, including the module hash
1736
+ self.hasher = None
1737
+
1738
+ # LLVM executable modules are identified using strings. Since it's possible for multiple
1739
+ # executable versions to be loaded at the same time, we need a way to ensure uniqueness.
1740
+ # A unique handle is created from the module name and this auto-incremented integer id.
1741
+ # NOTE: The module hash is not sufficient for uniqueness in rare cases where a module
1742
+ # is retained and later reloaded with the same hash.
1743
+ self.cpu_exec_id = 0
1488
1744
 
1489
1745
  self.options = {
1490
1746
  "max_unroll": warp.config.max_unroll,
@@ -1497,11 +1753,6 @@ class Module:
1497
1753
  # Module dependencies are determined by scanning each function
1498
1754
  # and kernel for references to external functions and structs.
1499
1755
  #
1500
- # When a referenced module is modified, all of its dependents need to be reloaded
1501
- # on the next launch. To detect this, a module's hash recursively includes
1502
- # all of its references.
1503
- # -> See ``Module.hash_module()``
1504
- #
1505
1756
  # The dependency mechanism works for both static and dynamic (runtime) modifications.
1506
1757
  # When a module is reloaded at runtime, we recursively unload all of its
1507
1758
  # dependents, so that they will be re-hashed and reloaded on the next launch.
@@ -1510,40 +1761,39 @@ class Module:
1510
1761
  self.references = set() # modules whose content we depend on
1511
1762
  self.dependents = set() # modules that depend on our content
1512
1763
 
1513
- # Since module hashing is recursive, we improve performance by caching the hash of the
1514
- # module contents (kernel source, function source, and struct source).
1515
- # After all kernels, functions, and structs are added to the module (usually at import time),
1516
- # the content hash doesn't change.
1517
- # -> See ``Module.hash_module_recursive()``
1518
-
1519
- self.content_hash = None
1520
-
1521
- # number of times module auto-generates kernel key for user
1522
- # used to ensure unique kernel keys
1523
- self.count = 0
1524
-
1525
1764
  def register_struct(self, struct):
1526
1765
  self.structs[struct.key] = struct
1527
1766
 
1528
1767
  # for a reload of module on next launch
1529
- self.unload()
1768
+ self.mark_modified()
1530
1769
 
1531
1770
  def register_kernel(self, kernel):
1771
+ # keep a reference to the latest version
1532
1772
  self.kernels[kernel.key] = kernel
1533
1773
 
1774
+ # track all kernel objects, even if they are duplicates
1775
+ self.live_kernels.add(kernel)
1776
+
1534
1777
  self.find_references(kernel.adj)
1535
1778
 
1536
1779
  # for a reload of module on next launch
1537
- self.unload()
1780
+ self.mark_modified()
1538
1781
 
1539
- def register_function(self, func, skip_adding_overload=False):
1540
- if func.key not in self.functions:
1541
- self.functions[func.key] = func
1782
+ def register_function(self, func, scope_locals, skip_adding_overload=False):
1783
+ # check for another Function with the same name in the same scope
1784
+ obj = scope_locals.get(func.func.__name__)
1785
+ if isinstance(obj, Function):
1786
+ func_existing = obj
1542
1787
  else:
1788
+ func_existing = None
1789
+
1790
+ # keep a reference to the latest version
1791
+ self.functions[func.key] = func_existing or func
1792
+
1793
+ if func_existing:
1543
1794
  # Check whether the new function's signature match any that has
1544
1795
  # already been registered. If so, then we simply override it, as
1545
1796
  # Python would do it, otherwise we register it as a new overload.
1546
- func_existing = self.functions[func.key]
1547
1797
  sig = warp.types.get_signature(
1548
1798
  func.input_types.values(),
1549
1799
  func_name=func.key,
@@ -1555,19 +1805,43 @@ class Module:
1555
1805
  arg_names=list(func_existing.input_types.keys()),
1556
1806
  )
1557
1807
  if sig == sig_existing:
1808
+ # replace the top-level function, but keep existing overloads
1809
+
1810
+ # copy generic overloads
1811
+ func.user_templates = func_existing.user_templates.copy()
1812
+
1813
+ # copy concrete overloads
1814
+ if warp.types.is_generic_signature(sig):
1815
+ # skip overloads that were instantiated from the function being replaced
1816
+ for k, v in func_existing.user_overloads.items():
1817
+ if v.generic_parent != func_existing:
1818
+ func.user_overloads[k] = v
1819
+ func.user_templates[sig] = func
1820
+ else:
1821
+ func.user_overloads = func_existing.user_overloads.copy()
1822
+ func.user_overloads[sig] = func
1823
+
1558
1824
  self.functions[func.key] = func
1559
1825
  elif not skip_adding_overload:
1826
+ # check if this is a generic overload that replaces an existing one
1827
+ if warp.types.is_generic_signature(sig):
1828
+ old_generic = func_existing.user_templates.get(sig)
1829
+ if old_generic is not None:
1830
+ # purge any concrete overloads that were instantiated from the old one
1831
+ for k, v in list(func_existing.user_overloads.items()):
1832
+ if v.generic_parent == old_generic:
1833
+ del func_existing.user_overloads[k]
1560
1834
  func_existing.add_overload(func)
1561
1835
 
1562
1836
  self.find_references(func.adj)
1563
1837
 
1564
1838
  # for a reload of module on next launch
1565
- self.unload()
1839
+ self.mark_modified()
1566
1840
 
1567
- def generate_unique_kernel_key(self, key):
1568
- unique_key = f"{key}_{self.count}"
1569
- self.count += 1
1570
- return unique_key
1841
+ # find kernel corresponding to a Python function
1842
+ def find_kernel(self, func):
1843
+ qualname = warp.codegen.make_full_qualified_name(func)
1844
+ return self.kernels.get(qualname)
1571
1845
 
1572
1846
  # collect all referenced functions / structs
1573
1847
  # given the AST of a function or kernel
@@ -1599,165 +1873,30 @@ class Module:
1599
1873
  if isinstance(arg.type, warp.codegen.Struct) and arg.type.module is not None:
1600
1874
  add_ref(arg.type.module)
1601
1875
 
1602
- def hash_module(self, recompute_content_hash=False):
1603
- """Recursively compute and return a hash for the module.
1604
-
1605
- If ``recompute_content_hash`` is False, each module's previously
1606
- computed ``content_hash`` will be used.
1607
- """
1608
-
1609
- def get_type_name(type_hint) -> str:
1610
- if isinstance(type_hint, warp.codegen.Struct):
1611
- return get_type_name(type_hint.cls)
1612
- elif isinstance(type_hint, warp.array) and isinstance(type_hint.dtype, warp.codegen.Struct):
1613
- return f"array{get_type_name(type_hint.dtype)}"
1614
-
1615
- return str(type_hint)
1616
-
1617
- def hash_recursive(module, visited):
1618
- # Hash this module, including all referenced modules recursively.
1619
- # The visited set tracks modules already visited to avoid circular references.
1620
-
1621
- # check if we need to update the content hash
1622
- if not module.content_hash or recompute_content_hash:
1623
- # recompute content hash
1624
- ch = hashlib.sha256()
1625
-
1626
- # Start with an empty constants dictionary in case any have been removed
1627
- module.constants = {}
1628
-
1629
- # struct source
1630
- for struct in module.structs.values():
1631
- s = ",".join(
1632
- "{}: {}".format(name, get_type_name(type_hint))
1633
- for name, type_hint in warp.codegen.get_annotations(struct.cls).items()
1634
- )
1635
- ch.update(bytes(s, "utf-8"))
1636
-
1637
- # functions source
1638
- for function in module.functions.values():
1639
- # include all concrete and generic overloads
1640
- overloads = itertools.chain(function.user_overloads.items(), function.user_templates.items())
1641
- for sig, func in overloads:
1642
- # signature
1643
- ch.update(bytes(sig, "utf-8"))
1644
-
1645
- # source
1646
- ch.update(bytes(func.adj.source, "utf-8"))
1647
-
1648
- if func.custom_grad_func:
1649
- ch.update(bytes(func.custom_grad_func.adj.source, "utf-8"))
1650
- if func.custom_replay_func:
1651
- ch.update(bytes(func.custom_replay_func.adj.source, "utf-8"))
1652
- if func.replay_snippet:
1653
- ch.update(bytes(func.replay_snippet, "utf-8"))
1654
- if func.native_snippet:
1655
- ch.update(bytes(func.native_snippet, "utf-8"))
1656
- if func.adj_native_snippet:
1657
- ch.update(bytes(func.adj_native_snippet, "utf-8"))
1658
-
1659
- # Populate constants referenced in this function
1660
- if func.adj:
1661
- module.constants.update(func.adj.get_constant_references())
1662
-
1663
- # kernel source
1664
- for kernel in module.kernels.values():
1665
- ch.update(bytes(kernel.key, "utf-8"))
1666
- ch.update(bytes(kernel.adj.source, "utf-8"))
1667
- # cache kernel arg types
1668
- for arg, arg_type in kernel.adj.arg_types.items():
1669
- s = f"{arg}: {get_type_name(arg_type)}"
1670
- ch.update(bytes(s, "utf-8"))
1671
- # for generic kernels the Python source is always the same,
1672
- # but we hash the type signatures of all the overloads
1673
- if kernel.is_generic:
1674
- for sig in sorted(kernel.overloads.keys()):
1675
- ch.update(bytes(sig, "utf-8"))
1676
-
1677
- # Populate constants referenced in this kernel
1678
- module.constants.update(kernel.adj.get_constant_references())
1679
-
1680
- # constants referenced in this module
1681
- for constant_name, constant_value in module.constants.items():
1682
- ch.update(bytes(constant_name, "utf-8"))
1683
-
1684
- # hash the constant value
1685
- if isinstance(constant_value, builtins.bool):
1686
- # This needs to come before the check for `int` since all boolean
1687
- # values are also instances of `int`.
1688
- ch.update(struct_pack("?", constant_value))
1689
- elif isinstance(constant_value, int):
1690
- ch.update(struct_pack("<q", constant_value))
1691
- elif isinstance(constant_value, float):
1692
- ch.update(struct_pack("<d", constant_value))
1693
- elif isinstance(constant_value, warp.types.float16):
1694
- # float16 is a special case
1695
- p = ctypes.pointer(ctypes.c_float(constant_value.value))
1696
- ch.update(p.contents)
1697
- elif isinstance(constant_value, tuple(warp.types.scalar_types)):
1698
- p = ctypes.pointer(constant_value._type_(constant_value.value))
1699
- ch.update(p.contents)
1700
- elif isinstance(constant_value, ctypes.Array):
1701
- ch.update(bytes(constant_value))
1702
- else:
1703
- raise RuntimeError(f"Invalid constant type: {type(constant_value)}")
1704
-
1705
- module.content_hash = ch.digest()
1706
-
1707
- h = hashlib.sha256()
1708
-
1709
- # content hash
1710
- h.update(module.content_hash)
1711
-
1712
- # configuration parameters
1713
- for k in sorted(module.options.keys()):
1714
- s = f"{k}={module.options[k]}"
1715
- h.update(bytes(s, "utf-8"))
1716
-
1717
- # ensure to trigger recompilation if flags affecting kernel compilation are changed
1718
- if warp.config.verify_fp:
1719
- h.update(bytes("verify_fp", "utf-8"))
1720
-
1721
- h.update(bytes(warp.config.mode, "utf-8"))
1722
-
1723
- # recurse on references
1724
- visited.add(module)
1725
-
1726
- sorted_deps = sorted(module.references, key=lambda m: m.name)
1727
- for dep in sorted_deps:
1728
- if dep not in visited:
1729
- dep_hash = hash_recursive(dep, visited)
1730
- h.update(dep_hash)
1731
-
1732
- return h.digest()
1733
-
1734
- return hash_recursive(self, visited=set())
1876
+ def hash_module(self):
1877
+ # compute latest hash
1878
+ self.hasher = ModuleHasher(self)
1879
+ return self.hasher.get_module_hash()
1735
1880
 
1736
1881
  def load(self, device) -> ModuleExec:
1737
1882
  device = runtime.get_device(device)
1738
1883
 
1739
- if device.is_cpu:
1740
- # check if already loaded
1741
- if self.cpu_exec:
1742
- return self.cpu_exec
1743
- # avoid repeated build attempts
1744
- if self.cpu_build_failed:
1745
- return None
1746
- if not warp.is_cpu_available():
1747
- raise RuntimeError("Failed to build CPU module because no CPU buildchain was found")
1748
- else:
1749
- # check if already loaded
1750
- cuda_exec = self.cuda_execs.get(device.context)
1751
- if cuda_exec is not None:
1752
- return cuda_exec
1753
- # avoid repeated build attempts
1754
- if self.cuda_build_failed:
1755
- return None
1756
- if not warp.is_cuda_available():
1757
- raise RuntimeError("Failed to build CUDA module because CUDA is not available")
1884
+ # compute the hash if needed
1885
+ if self.hasher is None:
1886
+ self.hasher = ModuleHasher(self)
1887
+
1888
+ # check if executable module is already loaded and not stale
1889
+ exec = self.execs.get(device.context)
1890
+ if exec is not None:
1891
+ if exec.module_hash == self.hasher.module_hash:
1892
+ return exec
1893
+
1894
+ # quietly avoid repeated build attempts to reduce error spew
1895
+ if device.context in self.failed_builds:
1896
+ return None
1758
1897
 
1759
1898
  module_name = "wp_" + self.name
1760
- module_hash = self.hash_module()
1899
+ module_hash = self.hasher.module_hash
1761
1900
 
1762
1901
  # use a unique module path using the module short hash
1763
1902
  module_dir = os.path.join(warp.config.kernel_cache_dir, f"{module_name}_{module_hash.hex()[:7]}")
@@ -1807,7 +1946,7 @@ class Module:
1807
1946
  or not warp.config.cache_kernels
1808
1947
  or warp.config.verify_autograd_array_access
1809
1948
  ):
1810
- builder = ModuleBuilder(self, self.options)
1949
+ builder = ModuleBuilder(self, self.options, hasher=self.hasher)
1811
1950
 
1812
1951
  # create a temporary (process unique) dir for build outputs before moving to the binary dir
1813
1952
  build_dir = os.path.join(
@@ -1844,7 +1983,7 @@ class Module:
1844
1983
  )
1845
1984
 
1846
1985
  except Exception as e:
1847
- self.cpu_build_failed = True
1986
+ self.failed_builds.add(None)
1848
1987
  module_load_timer.extra_msg = " (error)"
1849
1988
  raise (e)
1850
1989
 
@@ -1873,7 +2012,7 @@ class Module:
1873
2012
  )
1874
2013
 
1875
2014
  except Exception as e:
1876
- self.cuda_build_failed = True
2015
+ self.failed_builds.add(device.context)
1877
2016
  module_load_timer.extra_msg = " (error)"
1878
2017
  raise (e)
1879
2018
 
@@ -1914,15 +2053,18 @@ class Module:
1914
2053
  # -----------------------------------------------------------
1915
2054
  # Load CPU or CUDA binary
1916
2055
  if device.is_cpu:
1917
- runtime.llvm.load_obj(binary_path.encode("utf-8"), module_name.encode("utf-8"))
1918
- module_exec = ModuleExec(module_name, device)
1919
- self.cpu_exec = module_exec
2056
+ # LLVM modules are identified using strings, so we need to ensure uniqueness
2057
+ module_handle = f"{module_name}_{self.cpu_exec_id}"
2058
+ self.cpu_exec_id += 1
2059
+ runtime.llvm.load_obj(binary_path.encode("utf-8"), module_handle.encode("utf-8"))
2060
+ module_exec = ModuleExec(module_handle, module_hash, device)
2061
+ self.execs[None] = module_exec
1920
2062
 
1921
2063
  elif device.is_cuda:
1922
2064
  cuda_module = warp.build.load_cuda(binary_path, device)
1923
2065
  if cuda_module is not None:
1924
- module_exec = ModuleExec(cuda_module, device)
1925
- self.cuda_execs[device.context] = module_exec
2066
+ module_exec = ModuleExec(cuda_module, module_hash, device)
2067
+ self.execs[device.context] = module_exec
1926
2068
  else:
1927
2069
  module_load_timer.extra_msg = " (error)"
1928
2070
  raise Exception(f"Failed to load CUDA module '{self.name}'")
@@ -1936,20 +2078,22 @@ class Module:
1936
2078
  return module_exec
1937
2079
 
1938
2080
  def unload(self):
2081
+ # force rehashing on next load
2082
+ self.mark_modified()
2083
+
1939
2084
  # clear loaded modules
1940
- self.cpu_exec = None
1941
- self.cuda_execs = {}
2085
+ self.execs = {}
1942
2086
 
1943
- # clear content hash
1944
- self.content_hash = None
2087
+ def mark_modified(self):
2088
+ # clear hash data
2089
+ self.hasher = None
2090
+
2091
+ # clear build failures
2092
+ self.failed_builds = set()
1945
2093
 
1946
2094
  # lookup kernel entry points based on name, called after compilation / module load
1947
2095
  def get_kernel_hooks(self, kernel, device):
1948
- if device.is_cuda:
1949
- module_exec = self.cuda_execs.get(device.context)
1950
- else:
1951
- module_exec = self.cpu_exec
1952
-
2096
+ module_exec = self.execs.get(device.context)
1953
2097
  if module_exec is not None:
1954
2098
  return module_exec.get_kernel_hooks(kernel)
1955
2099
  else:
@@ -2056,6 +2200,63 @@ class ContextGuard:
2056
2200
  runtime.core.cuda_context_set_current(self.saved_context)
2057
2201
 
2058
2202
 
2203
+ class Event:
2204
+ """A CUDA event that can be recorded onto a stream.
2205
+
2206
+ Events can be used for device-side synchronization, which do not block
2207
+ the host thread.
2208
+ """
2209
+
2210
+ # event creation flags
2211
+ class Flags:
2212
+ DEFAULT = 0x0
2213
+ BLOCKING_SYNC = 0x1
2214
+ DISABLE_TIMING = 0x2
2215
+
2216
+ def __new__(cls, *args, **kwargs):
2217
+ """Creates a new event instance."""
2218
+ instance = super(Event, cls).__new__(cls)
2219
+ instance.owner = False
2220
+ return instance
2221
+
2222
+ def __init__(self, device: "Devicelike" = None, cuda_event=None, enable_timing: bool = False):
2223
+ """Initializes the event on a CUDA device.
2224
+
2225
+ Args:
2226
+ device: The CUDA device whose streams this event may be recorded onto.
2227
+ If ``None``, then the current default device will be used.
2228
+ cuda_event: A pointer to a previously allocated CUDA event. If
2229
+ `None`, then a new event will be allocated on the associated device.
2230
+ enable_timing: If ``True`` this event will record timing data.
2231
+ :func:`~warp.get_event_elapsed_time` can be used to measure the
2232
+ time between two events created with ``enable_timing=True`` and
2233
+ recorded onto streams.
2234
+ """
2235
+
2236
+ device = get_device(device)
2237
+ if not device.is_cuda:
2238
+ raise RuntimeError(f"Device {device} is not a CUDA device")
2239
+
2240
+ self.device = device
2241
+
2242
+ if cuda_event is not None:
2243
+ self.cuda_event = cuda_event
2244
+ else:
2245
+ flags = Event.Flags.DEFAULT
2246
+ if not enable_timing:
2247
+ flags |= Event.Flags.DISABLE_TIMING
2248
+ self.cuda_event = runtime.core.cuda_event_create(device.context, flags)
2249
+ if not self.cuda_event:
2250
+ raise RuntimeError(f"Failed to create event on device {device}")
2251
+ self.owner = True
2252
+
2253
+ def __del__(self):
2254
+ if not self.owner:
2255
+ return
2256
+
2257
+ runtime.core.cuda_event_destroy(self.cuda_event)
2258
+
2259
+
2059
2260
  class Stream:
2060
2261
  def __new__(cls, *args, **kwargs):
2061
2262
  instance = super(Stream, cls).__new__(cls)
@@ -2063,7 +2264,27 @@ class Stream:
2063
2264
  instance.owner = False
2064
2265
  return instance
2065
2266
 
2066
- def __init__(self, device=None, **kwargs):
2267
+ def __init__(self, device: Optional[Union["Device", str]] = None, priority: int = 0, **kwargs):
2268
+ """Initialize the stream on a device with an optional specified priority.
2269
+
2270
+ Args:
2271
+ device: The CUDA device on which this stream will be created.
2272
+ priority: An optional integer specifying the requested stream priority.
2273
+ Can be -1 (high priority) or 0 (low/default priority).
2274
+ Values outside this range will be clamped.
2275
+ cuda_stream (int): A optional external stream handle passed as an
2276
+ integer. The caller is responsible for ensuring that the external
2277
+ stream does not get destroyed while it is referenced by this
2278
+ object.
2279
+
2280
+ Raises:
2281
+ RuntimeError: If function is called before Warp has completed
2282
+ initialization with a ``device`` that is not an instance of
2283
+ :class:`Device``.
2284
+ RuntimeError: ``device`` is not a CUDA Device.
2285
+ RuntimeError: The stream could not be created on the device.
2286
+ TypeError: The requested stream priority is not an integer.
2287
+ """
2067
2288
  # event used internally for synchronization (cached to avoid creating temporary events)
2068
2289
  self._cached_event = None
2069
2290
 
@@ -2072,7 +2293,7 @@ class Stream:
2072
2293
  device = runtime.get_device(device)
2073
2294
  elif not isinstance(device, Device):
2074
2295
  raise RuntimeError(
2075
- "A device object is required when creating a stream before or during Warp initialization"
2296
+ "A Device object is required when creating a stream before or during Warp initialization"
2076
2297
  )
2077
2298
 
2078
2299
  if not device.is_cuda:
@@ -2085,7 +2306,11 @@ class Stream:
2085
2306
  self.cuda_stream = kwargs["cuda_stream"]
2086
2307
  device.runtime.core.cuda_stream_register(device.context, self.cuda_stream)
2087
2308
  else:
2088
- self.cuda_stream = device.runtime.core.cuda_stream_create(device.context)
2309
+ if not isinstance(priority, int):
2310
+ raise TypeError("Stream priority must be an integer.")
2311
+ clamped_priority = max(-1, min(priority, 0)) # Only support two priority levels
2312
+ self.cuda_stream = device.runtime.core.cuda_stream_create(device.context, clamped_priority)
2313
+
2089
2314
  if not self.cuda_stream:
2090
2315
  raise RuntimeError(f"Failed to create stream on device {device}")
2091
2316
  self.owner = True
@@ -2100,12 +2325,22 @@ class Stream:
2100
2325
  runtime.core.cuda_stream_unregister(self.device.context, self.cuda_stream)
2101
2326
 
2102
2327
  @property
2103
- def cached_event(self):
2328
+ def cached_event(self) -> Event:
2104
2329
  if self._cached_event is None:
2105
2330
  self._cached_event = Event(self.device)
2106
2331
  return self._cached_event
2107
2332
 
2108
- def record_event(self, event=None):
2333
+ def record_event(self, event: Optional[Event] = None) -> Event:
2334
+ """Record an event onto the stream.
2335
+
2336
+ Args:
2337
+ event: A warp.Event instance to be recorded onto the stream. If not
2338
+ provided, an :class:`~warp.Event` on the same device will be created.
2339
+
2340
+ Raises:
2341
+ RuntimeError: The provided :class:`~warp.Event` is from a different device than
2342
+ the recording stream.
2343
+ """
2109
2344
  if event is None:
2110
2345
  event = Event(self.device)
2111
2346
  elif event.device != self.device:
@@ -2117,56 +2352,45 @@ class Stream:
2117
2352
 
2118
2353
  return event
2119
2354
 
2120
- def wait_event(self, event):
2355
+ def wait_event(self, event: Event):
2356
+ """Makes all future work in this stream wait until `event` has completed.
2357
+
2358
+ This function does not block the host thread.
2359
+ """
2121
2360
  runtime.core.cuda_stream_wait_event(self.cuda_stream, event.cuda_event)
2122
2361
 
2123
- def wait_stream(self, other_stream, event=None):
2362
+ def wait_stream(self, other_stream: "Stream", event: Optional[Event] = None):
2363
+ """Records an event on `other_stream` and makes this stream wait on it.
2364
+
2365
+ All work added to this stream after this function has been called will
2366
+ delay their execution until all preceding commands in `other_stream`
2367
+ have completed.
2368
+
2369
+ This function does not block the host thread.
2370
+
2371
+ Args:
2372
+ other_stream: The stream on which the calling stream will wait for
2373
+ previously issued commands to complete before executing subsequent
2374
+ commands.
2375
+ event: An optional :class:`Event` instance that will be used to
2376
+ record an event onto ``other_stream``. If ``None``, an internally
2377
+ managed :class:`Event` instance will be used.
2378
+ """
2379
+
2124
2380
  if event is None:
2125
2381
  event = other_stream.cached_event
2126
2382
 
2127
2383
  runtime.core.cuda_stream_wait_stream(self.cuda_stream, other_stream.cuda_stream, event.cuda_event)
2128
2384
 
2129
- # whether a graph capture is currently ongoing on this stream
2130
2385
  @property
2131
- def is_capturing(self):
2386
+ def is_capturing(self) -> bool:
2387
+ """A boolean indicating whether a graph capture is currently ongoing on this stream."""
2132
2388
  return bool(runtime.core.cuda_stream_is_capturing(self.cuda_stream))
2133
2389
 
2134
-
2135
- class Event:
2136
- # event creation flags
2137
- class Flags:
2138
- DEFAULT = 0x0
2139
- BLOCKING_SYNC = 0x1
2140
- DISABLE_TIMING = 0x2
2141
-
2142
- def __new__(cls, *args, **kwargs):
2143
- instance = super(Event, cls).__new__(cls)
2144
- instance.owner = False
2145
- return instance
2146
-
2147
- def __init__(self, device=None, cuda_event=None, enable_timing=False):
2148
- device = get_device(device)
2149
- if not device.is_cuda:
2150
- raise RuntimeError(f"Device {device} is not a CUDA device")
2151
-
2152
- self.device = device
2153
-
2154
- if cuda_event is not None:
2155
- self.cuda_event = cuda_event
2156
- else:
2157
- flags = Event.Flags.DEFAULT
2158
- if not enable_timing:
2159
- flags |= Event.Flags.DISABLE_TIMING
2160
- self.cuda_event = runtime.core.cuda_event_create(device.context, flags)
2161
- if not self.cuda_event:
2162
- raise RuntimeError(f"Failed to create event on device {device}")
2163
- self.owner = True
2164
-
2165
- def __del__(self):
2166
- if not self.owner:
2167
- return
2168
-
2169
- runtime.core.cuda_event_destroy(self.cuda_event)
2390
+ @property
2391
+ def priority(self) -> int:
2392
+ """An integer representing the priority of the stream."""
2393
+ return runtime.core.cuda_stream_get_priority(self.cuda_stream)
2170
2394
 
2171
2395
 
2172
2396
  class Device:
@@ -2178,14 +2402,14 @@ class Device:
2178
2402
  or ``"CPU"`` if the processor name cannot be determined.
2179
2403
  arch: An integer representing the compute capability version number calculated as
2180
2404
  ``10 * major + minor``. ``0`` for CPU devices.
2181
- is_uva: A boolean indicating whether or not the device supports unified addressing.
2405
+ is_uva: A boolean indicating whether the device supports unified addressing.
2182
2406
  ``False`` for CPU devices.
2183
- is_cubin_supported: A boolean indicating whether or not Warp's version of NVRTC can directly
2407
+ is_cubin_supported: A boolean indicating whether Warp's version of NVRTC can directly
2184
2408
  generate CUDA binary files (cubin) for this device's architecture. ``False`` for CPU devices.
2185
- is_mempool_supported: A boolean indicating whether or not the device supports using the
2409
+ is_mempool_supported: A boolean indicating whether the device supports using the
2186
2410
  ``cuMemAllocAsync`` and ``cuMemPool`` family of APIs for stream-ordered memory allocations. ``False`` for
2187
2411
  CPU devices.
2188
- is_primary: A boolean indicating whether or not this device's CUDA context is also the
2412
+ is_primary: A boolean indicating whether this device's CUDA context is also the
2189
2413
  device's primary context.
2190
2414
  uuid: A string representing the UUID of the CUDA device. The UUID is in the same format used by
2191
2415
  ``nvidia-smi -L``. ``None`` for CPU devices.
@@ -2274,7 +2498,7 @@ class Device:
2274
2498
 
2275
2499
  # initialize streams unless context acquisition is postponed
2276
2500
  if self._context is not None:
2277
- self.init_streams()
2501
+ self._init_streams()
2278
2502
 
2279
2503
  # TODO: add more device-specific dispatch functions
2280
2504
  self.memset = lambda ptr, value, size: runtime.core.memset_device(self.context, ptr, value, size)
@@ -2285,7 +2509,13 @@ class Device:
2285
2509
  else:
2286
2510
  raise RuntimeError(f"Invalid device ordinal ({ordinal})'")
2287
2511
 
2288
- def get_allocator(self, pinned=False):
2512
+ def get_allocator(self, pinned: bool = False):
2513
+ """Get the memory allocator for this device.
2514
+
2515
+ Args:
2516
+ pinned: If ``True``, an allocator for pinned memory will be
2517
+ returned. Only applicable when this device is a CPU device.
2518
+ """
2289
2519
  if self.is_cuda:
2290
2520
  return self.current_allocator
2291
2521
  else:
@@ -2294,7 +2524,8 @@ class Device:
2294
2524
  else:
2295
2525
  return self.default_allocator
2296
2526
 
2297
- def init_streams(self):
2527
+ def _init_streams(self):
2528
+ """Initializes the device's current stream and the device's null stream."""
2298
2529
  # create a stream for asynchronous work
2299
2530
  self.set_stream(Stream(self))
2300
2531
 
@@ -2302,17 +2533,18 @@ class Device:
2302
2533
  self.null_stream = Stream(self, cuda_stream=None)
2303
2534
 
2304
2535
  @property
2305
- def is_cpu(self):
2306
- """A boolean indicating whether or not the device is a CPU device."""
2536
+ def is_cpu(self) -> bool:
2537
+ """A boolean indicating whether the device is a CPU device."""
2307
2538
  return self.ordinal < 0
2308
2539
 
2309
2540
  @property
2310
- def is_cuda(self):
2311
- """A boolean indicating whether or not the device is a CUDA device."""
2541
+ def is_cuda(self) -> bool:
2542
+ """A boolean indicating whether the device is a CUDA device."""
2312
2543
  return self.ordinal >= 0
2313
2544
 
2314
2545
  @property
2315
- def is_capturing(self):
2546
+ def is_capturing(self) -> bool:
2547
+ """A boolean indicating whether this device's default stream is currently capturing a graph."""
2316
2548
  if self.is_cuda and self.stream is not None:
2317
2549
  # There is no CUDA API to check if graph capture was started on a device, so we
2318
2550
  # can't tell if a capture was started by external code on a different stream.
@@ -2336,17 +2568,17 @@ class Device:
2336
2568
  raise RuntimeError(f"Failed to acquire primary context for device {self}")
2337
2569
  self.runtime.context_map[self._context] = self
2338
2570
  # initialize streams
2339
- self.init_streams()
2571
+ self._init_streams()
2340
2572
  runtime.core.cuda_context_set_current(prev_context)
2341
2573
  return self._context
2342
2574
 
2343
2575
  @property
2344
- def has_context(self):
2345
- """A boolean indicating whether or not the device has a CUDA context associated with it."""
2576
+ def has_context(self) -> bool:
2577
+ """A boolean indicating whether the device has a CUDA context associated with it."""
2346
2578
  return self._context is not None
2347
2579
 
2348
2580
  @property
2349
- def stream(self):
2581
+ def stream(self) -> Stream:
2350
2582
  """The stream associated with a CUDA device.
2351
2583
 
2352
2584
  Raises:
@@ -2361,7 +2593,22 @@ class Device:
2361
2593
  def stream(self, stream):
2362
2594
  self.set_stream(stream)
2363
2595
 
2364
- def set_stream(self, stream, sync=True):
2596
+ def set_stream(self, stream: Stream, sync: bool = True) -> None:
2597
+ """Set the current stream for this CUDA device.
2598
+
2599
+ The current stream will be used by default for all kernel launches and
2600
+ memory operations on this device.
2601
+
2602
+ If this is an external stream, the caller is responsible for
2603
+ guaranteeing the lifetime of the stream.
2604
+
2605
+ Consider using :class:`warp.ScopedStream` instead.
2606
+
2607
+ Args:
2608
+ stream: The stream to set as this device's current stream.
2609
+ sync: If ``True``, then ``stream`` will perform a device-side
2610
+ synchronization with the device's previous current stream.
2611
+ """
2365
2612
  if self.is_cuda:
2366
2613
  if stream.device != self:
2367
2614
  raise RuntimeError(f"Stream from device {stream.device} cannot be used on device {self}")
@@ -2372,12 +2619,12 @@ class Device:
2372
2619
  raise RuntimeError(f"Device {self} is not a CUDA device")
2373
2620
 
2374
2621
  @property
2375
- def has_stream(self):
2376
- """A boolean indicating whether or not the device has a stream associated with it."""
2622
+ def has_stream(self) -> bool:
2623
+ """A boolean indicating whether the device has a stream associated with it."""
2377
2624
  return self._stream is not None
2378
2625
 
2379
2626
  @property
2380
- def total_memory(self):
2627
+ def total_memory(self) -> int:
2381
2628
  """The total amount of device memory available in bytes.
2382
2629
 
2383
2630
  This function is currently only implemented for CUDA devices. 0 will be returned if called on a CPU device.
@@ -2391,7 +2638,7 @@ class Device:
2391
2638
  return 0
2392
2639
 
2393
2640
  @property
2394
- def free_memory(self):
2641
+ def free_memory(self) -> int:
2395
2642
  """The amount of memory on the device that is free according to the OS in bytes.
2396
2643
 
2397
2644
  This function is currently only implemented for CUDA devices. 0 will be returned if called on a CPU device.
@@ -2755,6 +3002,12 @@ class Runtime:
2755
3002
  self.core.mesh_refit_host.argtypes = [ctypes.c_uint64]
2756
3003
  self.core.mesh_refit_device.argtypes = [ctypes.c_uint64]
2757
3004
 
3005
+ self.core.mesh_set_points_host.argtypes = [ctypes.c_uint64, warp.types.array_t]
3006
+ self.core.mesh_set_points_device.argtypes = [ctypes.c_uint64, warp.types.array_t]
3007
+
3008
+ self.core.mesh_set_velocities_host.argtypes = [ctypes.c_uint64, warp.types.array_t]
3009
+ self.core.mesh_set_velocities_device.argtypes = [ctypes.c_uint64, warp.types.array_t]
3010
+
2758
3011
  self.core.hash_grid_create_host.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
2759
3012
  self.core.hash_grid_create_host.restype = ctypes.c_uint64
2760
3013
  self.core.hash_grid_destroy_host.argtypes = [ctypes.c_uint64]
@@ -3029,7 +3282,7 @@ class Runtime:
3029
3282
  self.core.cuda_set_mempool_access_enabled.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
3030
3283
  self.core.cuda_set_mempool_access_enabled.restype = ctypes.c_int
3031
3284
 
3032
- self.core.cuda_stream_create.argtypes = [ctypes.c_void_p]
3285
+ self.core.cuda_stream_create.argtypes = [ctypes.c_void_p, ctypes.c_int]
3033
3286
  self.core.cuda_stream_create.restype = ctypes.c_void_p
3034
3287
  self.core.cuda_stream_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3035
3288
  self.core.cuda_stream_destroy.restype = None
@@ -3047,6 +3300,8 @@ class Runtime:
3047
3300
  self.core.cuda_stream_is_capturing.restype = ctypes.c_int
3048
3301
  self.core.cuda_stream_get_capture_id.argtypes = [ctypes.c_void_p]
3049
3302
  self.core.cuda_stream_get_capture_id.restype = ctypes.c_uint64
3303
+ self.core.cuda_stream_get_priority.argtypes = [ctypes.c_void_p]
3304
+ self.core.cuda_stream_get_priority.restype = ctypes.c_int
3050
3305
 
3051
3306
  self.core.cuda_event_create.argtypes = [ctypes.c_void_p, ctypes.c_uint]
3052
3307
  self.core.cuda_event_create.restype = ctypes.c_void_p
@@ -3382,10 +3637,10 @@ class Runtime:
3382
3637
 
3383
3638
  raise ValueError(f"Invalid device identifier: {ident}")
3384
3639
 
3385
- def set_default_device(self, ident: Devicelike):
3640
+ def set_default_device(self, ident: Devicelike) -> None:
3386
3641
  self.default_device = self.get_device(ident)
3387
3642
 
3388
- def get_current_cuda_device(self):
3643
+ def get_current_cuda_device(self) -> Device:
3389
3644
  current_context = self.core.cuda_context_get_current()
3390
3645
  if current_context is not None:
3391
3646
  current_device = self.context_map.get(current_context)
@@ -3415,7 +3670,7 @@ class Runtime:
3415
3670
  else:
3416
3671
  raise RuntimeError('"cuda" device requested but CUDA is not supported by the hardware or driver')
3417
3672
 
3418
- def rename_device(self, device, alias):
3673
+ def rename_device(self, device, alias) -> Device:
3419
3674
  del self.device_map[device.alias]
3420
3675
  device.alias = alias
3421
3676
  self.device_map[alias] = device
@@ -3462,7 +3717,7 @@ class Runtime:
3462
3717
 
3463
3718
  return device
3464
3719
 
3465
- def unmap_cuda_device(self, alias):
3720
+ def unmap_cuda_device(self, alias) -> None:
3466
3721
  device = self.device_map.get(alias)
3467
3722
 
3468
3723
  # make sure the alias refers to a CUDA device
@@ -3473,7 +3728,7 @@ class Runtime:
3473
3728
  del self.context_map[device.context]
3474
3729
  self.cuda_devices.remove(device)
3475
3730
 
3476
- def verify_cuda_device(self, device: Devicelike = None):
3731
+ def verify_cuda_device(self, device: Devicelike = None) -> None:
3477
3732
  if warp.config.verify_cuda:
3478
3733
  device = runtime.get_device(device)
3479
3734
  if not device.is_cuda:
@@ -3485,13 +3740,13 @@ class Runtime:
3485
3740
 
3486
3741
 
3487
3742
  # global entry points
3488
- def is_cpu_available():
3743
+ def is_cpu_available() -> bool:
3489
3744
  init()
3490
3745
 
3491
- return runtime.llvm
3746
+ return runtime.llvm is not None
3492
3747
 
3493
3748
 
3494
- def is_cuda_available():
3749
+ def is_cuda_available() -> bool:
3495
3750
  return get_cuda_device_count() > 0
3496
3751
 
3497
3752
 
@@ -3575,8 +3830,8 @@ def get_device(ident: Devicelike = None) -> Device:
3575
3830
  return runtime.get_device(ident)
3576
3831
 
3577
3832
 
3578
- def set_device(ident: Devicelike):
3579
- """Sets the target device identified by the argument."""
3833
+ def set_device(ident: Devicelike) -> None:
3834
+ """Sets the default device identified by the argument."""
3580
3835
 
3581
3836
  init()
3582
3837
 
@@ -3604,7 +3859,7 @@ def map_cuda_device(alias: str, context: ctypes.c_void_p = None) -> Device:
3604
3859
  return runtime.map_cuda_device(alias, context)
3605
3860
 
3606
3861
 
3607
- def unmap_cuda_device(alias: str):
3862
+ def unmap_cuda_device(alias: str) -> None:
3608
3863
  """Remove a CUDA device with the given alias."""
3609
3864
 
3610
3865
  init()
@@ -3612,7 +3867,7 @@ def unmap_cuda_device(alias: str):
3612
3867
  runtime.unmap_cuda_device(alias)
3613
3868
 
3614
3869
 
3615
- def is_mempool_supported(device: Devicelike):
3870
+ def is_mempool_supported(device: Devicelike) -> bool:
3616
3871
  """Check if CUDA memory pool allocators are available on the device."""
3617
3872
 
3618
3873
  init()
@@ -3622,7 +3877,7 @@ def is_mempool_supported(device: Devicelike):
3622
3877
  return device.is_mempool_supported
3623
3878
 
3624
3879
 
3625
- def is_mempool_enabled(device: Devicelike):
3880
+ def is_mempool_enabled(device: Devicelike) -> bool:
3626
3881
  """Check if CUDA memory pool allocators are enabled on the device."""
3627
3882
 
3628
3883
  init()
@@ -3632,7 +3887,7 @@ def is_mempool_enabled(device: Devicelike):
3632
3887
  return device.is_mempool_enabled
3633
3888
 
3634
3889
 
3635
- def set_mempool_enabled(device: Devicelike, enable: bool):
3890
+ def set_mempool_enabled(device: Devicelike, enable: bool) -> None:
3636
3891
  """Enable or disable CUDA memory pool allocators on the device.
3637
3892
 
3638
3893
  Pooled allocators are typically faster and allow allocating memory during graph capture.
@@ -3663,7 +3918,7 @@ def set_mempool_enabled(device: Devicelike, enable: bool):
3663
3918
  raise ValueError("Memory pools are only supported on CUDA devices")
3664
3919
 
3665
3920
 
3666
- def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, float]):
3921
+ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, float]) -> None:
3667
3922
  """Set the CUDA memory pool release threshold on the device.
3668
3923
 
3669
3924
  This is the amount of reserved memory to hold onto before trying to release memory back to the OS.
@@ -3694,7 +3949,7 @@ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, floa
3694
3949
  raise RuntimeError(f"Failed to set memory pool release threshold for device {device}")
3695
3950
 
3696
3951
 
3697
- def get_mempool_release_threshold(device: Devicelike):
3952
+ def get_mempool_release_threshold(device: Devicelike) -> int:
3698
3953
  """Get the CUDA memory pool release threshold on the device."""
3699
3954
 
3700
3955
  init()
@@ -3710,7 +3965,7 @@ def get_mempool_release_threshold(device: Devicelike):
3710
3965
  return runtime.core.cuda_device_get_mempool_release_threshold(device.ordinal)
3711
3966
 
3712
3967
 
3713
- def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike):
3968
+ def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike) -> bool:
3714
3969
  """Check if `peer_device` can directly access the memory of `target_device` on this system.
3715
3970
 
3716
3971
  This applies to memory allocated using default CUDA allocators. For memory allocated using
@@ -3731,7 +3986,7 @@ def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike)
3731
3986
  return bool(runtime.core.cuda_is_peer_access_supported(target_device.ordinal, peer_device.ordinal))
3732
3987
 
3733
3988
 
3734
- def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike):
3989
+ def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike) -> bool:
3735
3990
  """Check if `peer_device` can currently access the memory of `target_device`.
3736
3991
 
3737
3992
  This applies to memory allocated using default CUDA allocators. For memory allocated using
@@ -3752,7 +4007,7 @@ def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike):
3752
4007
  return bool(runtime.core.cuda_is_peer_access_enabled(target_device.context, peer_device.context))
3753
4008
 
3754
4009
 
3755
- def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool):
4010
+ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool) -> None:
3756
4011
  """Enable or disable direct access from `peer_device` to the memory of `target_device`.
3757
4012
 
3758
4013
  Enabling peer access can improve the speed of peer-to-peer memory transfers, but can have
@@ -3784,7 +4039,7 @@ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike,
3784
4039
  raise RuntimeError(f"Failed to {action} peer access from device {peer_device} to device {target_device}")
3785
4040
 
3786
4041
 
3787
- def is_mempool_access_supported(target_device: Devicelike, peer_device: Devicelike):
4042
+ def is_mempool_access_supported(target_device: Devicelike, peer_device: Devicelike) -> bool:
3788
4043
  """Check if `peer_device` can directly access the memory pool of `target_device`.
3789
4044
 
3790
4045
  If mempool access is possible, it can be managed using `set_mempool_access_enabled()` and `is_mempool_access_enabled()`.
@@ -3801,7 +4056,7 @@ def is_mempool_access_supported(target_device: Devicelike, peer_device: Deviceli
3801
4056
  return target_device.is_mempool_supported and is_peer_access_supported(target_device, peer_device)
3802
4057
 
3803
4058
 
3804
- def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike):
4059
+ def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike) -> bool:
3805
4060
  """Check if `peer_device` can currently access the memory pool of `target_device`.
3806
4061
 
3807
4062
  This applies to memory allocated using CUDA pooled allocators. For memory allocated using
@@ -3822,7 +4077,7 @@ def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike
3822
4077
  return bool(runtime.core.cuda_is_mempool_access_enabled(target_device.ordinal, peer_device.ordinal))
3823
4078
 
3824
4079
 
3825
- def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool):
4080
+ def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool) -> None:
3826
4081
  """Enable or disable access from `peer_device` to the memory pool of `target_device`.
3827
4082
 
3828
4083
  This applies to memory allocated using CUDA pooled allocators. For memory allocated using
@@ -3858,26 +4113,41 @@ def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelik
3858
4113
 
3859
4114
 
3860
4115
  def get_stream(device: Devicelike = None) -> Stream:
3861
- """Return the stream currently used by the given device"""
4116
+ """Return the stream currently used by the given device.
4117
+
4118
+ Args:
4119
+ device: An optional :class:`Device` instance or device alias
4120
+ (e.g. "cuda:0") for which the current stream will be returned.
4121
+ If ``None``, the default device will be used.
4122
+
4123
+ Raises:
4124
+ RuntimeError: The device is not a CUDA device.
4125
+ """
3862
4126
 
3863
4127
  return get_device(device).stream
3864
4128
 
3865
4129
 
3866
- def set_stream(stream, device: Devicelike = None, sync: bool = False):
3867
- """Set the stream to be used by the given device.
4130
+ def set_stream(stream: Stream, device: Devicelike = None, sync: bool = False) -> None:
4131
+ """Convenience function for calling :meth:`Device.set_stream` on the given ``device``.
3868
4132
 
3869
- If this is an external stream, caller is responsible for guaranteeing the lifetime of the stream.
3870
- Consider using wp.ScopedStream instead.
4133
+ Args:
4134
+ device: An optional :class:`Device` instance or device alias
4135
+ (e.g. "cuda:0") for which the current stream is to be replaced with
4136
+ ``stream``. If ``None``, the default device will be used.
4137
+ stream: The stream to set as this device's current stream.
4138
+ sync: If ``True``, then ``stream`` will perform a device-side
4139
+ synchronization with the device's previous current stream.
3871
4140
  """
3872
4141
 
3873
4142
  get_device(device).set_stream(stream, sync=sync)
3874
4143
 
3875
4144
 
3876
- def record_event(event: Event = None):
3877
- """Record a CUDA event on the current stream.
4145
+ def record_event(event: Optional[Event] = None):
4146
+ """Convenience function for calling :meth:`Stream.record_event` on the current stream.
3878
4147
 
3879
4148
  Args:
3880
- event: Event to record. If None, a new Event will be created.
4149
+ event: :class:`Event` instance to record. If ``None``, a new :class:`Event`
4150
+ instance will be created.
3881
4151
 
3882
4152
  Returns:
3883
4153
  The recorded event.
@@ -3887,29 +4157,31 @@ def record_event(event: Event = None):
3887
4157
 
3888
4158
 
3889
4159
  def wait_event(event: Event):
3890
- """Make the current stream wait for a CUDA event.
4160
+ """Convenience function for calling :meth:`Stream.wait_event` on the current stream.
3891
4161
 
3892
4162
  Args:
3893
- event: Event to wait for.
4163
+ event: :class:`Event` instance to wait for.
3894
4164
  """
3895
4165
 
3896
4166
  get_stream().wait_event(event)
3897
4167
 
3898
4168
 
3899
- def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: bool = True):
4169
+ def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: Optional[bool] = True):
3900
4170
  """Get the elapsed time between two recorded events.
3901
4171
 
3902
- The result is in milliseconds with a resolution of about 0.5 microsecond.
3903
-
3904
- Both events must have been previously recorded with ``wp.record_event()`` or ``wp.Stream.record_event()``.
4172
+ Both events must have been previously recorded with
4173
+ :func:`~warp.record_event()` or :meth:`warp.Stream.record_event()`.
3905
4174
 
3906
4175
  If ``synchronize`` is False, the caller must ensure that device execution has reached ``end_event``
3907
4176
  prior to calling ``get_event_elapsed_time()``.
3908
4177
 
3909
4178
  Args:
3910
- start_event (Event): The start event.
3911
- end_event (Event): The end event.
3912
- synchronize (bool, optional): Whether Warp should synchronize on the ``end_event``.
4179
+ start_event: The start event.
4180
+ end_event: The end event.
4181
+ synchronize: Whether Warp should synchronize on the ``end_event``.
4182
+
4183
+ Returns:
4184
+ The elapsed time in milliseconds with a resolution about 0.5 ms.
3913
4185
  """
3914
4186
 
3915
4187
  # ensure the end_event is reached
@@ -3919,14 +4191,19 @@ def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: bo
3919
4191
  return runtime.core.cuda_event_elapsed_time(start_event.cuda_event, end_event.cuda_event)
3920
4192
 
3921
4193
 
3922
- def wait_stream(stream: Stream, event: Event = None):
3923
- """Make the current stream wait for another CUDA stream to complete its work.
4194
+ def wait_stream(other_stream: Stream, event: Event = None):
4195
+ """Convenience function for calling :meth:`Stream.wait_stream` on the current stream.
3924
4196
 
3925
4197
  Args:
3926
- event: Event to be used. If None, a new Event will be created.
4198
+ other_stream: The stream on which the calling stream will wait for
4199
+ previously issued commands to complete before executing subsequent
4200
+ commands.
4201
+ event: An optional :class:`Event` instance that will be used to
4202
+ record an event onto ``other_stream``. If ``None``, an internally
4203
+ managed :class:`Event` instance will be used.
3927
4204
  """
3928
4205
 
3929
- get_stream().wait_stream(stream, event=event)
4206
+ get_stream().wait_stream(other_stream, event=event)
3930
4207
 
3931
4208
 
3932
4209
  class RegisteredGLBuffer:
@@ -4362,7 +4639,7 @@ def from_numpy(
4362
4639
  dtype: The data type of the new Warp array. If this is not provided, the data type will be inferred.
4363
4640
  shape: The shape of the Warp array.
4364
4641
  device: The device on which the Warp array will be constructed.
4365
- requires_grad: Whether or not gradients will be tracked for this array.
4642
+ requires_grad: Whether gradients will be tracked for this array.
4366
4643
 
4367
4644
  Raises:
4368
4645
  RuntimeError: The data type of the NumPy array is not supported.
@@ -4398,7 +4675,7 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
4398
4675
  return arg_type.__ctype__()
4399
4676
 
4400
4677
  elif isinstance(value, warp.types.array_t):
4401
- # accept array descriptors verbatum
4678
+ # accept array descriptors verbatim
4402
4679
  return value
4403
4680
 
4404
4681
  else:
@@ -4865,7 +5142,7 @@ def synchronize_device(device: Devicelike = None):
4865
5142
  runtime.core.cuda_context_synchronize(device.context)
4866
5143
 
4867
5144
 
4868
- def synchronize_stream(stream_or_device=None):
5145
+ def synchronize_stream(stream_or_device: Union[Stream, Devicelike, None] = None):
4869
5146
  """Synchronize the calling CPU thread with any outstanding CUDA work on the specified stream.
4870
5147
 
4871
5148
  This function allows the host application code to ensure that all kernel launches
@@ -4989,7 +5266,7 @@ def set_module_options(options: Dict[str, Any], module: Optional[Any] = None):
4989
5266
  m = module
4990
5267
 
4991
5268
  get_module(m.__name__).options.update(options)
4992
- get_module(m.__name__).unload()
5269
+ get_module(m.__name__).mark_modified()
4993
5270
 
4994
5271
 
4995
5272
  def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
@@ -5016,7 +5293,7 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=None
5016
5293
  Args:
5017
5294
  device: The CUDA device to capture on
5018
5295
  stream: The CUDA stream to capture on
5019
- force_module_load: Whether or not to force loading of all kernels before capture.
5296
+ force_module_load: Whether to force loading of all kernels before capture.
5020
5297
  In general it is better to use :func:`~warp.load_module()` to selectively load kernels.
5021
5298
  When running with CUDA drivers that support CUDA 12.3 or newer, this option is not recommended to be set to
5022
5299
  ``True`` because kernels can be loaded during graph capture on more recent drivers. If this argument is
@@ -5574,7 +5851,7 @@ def export_stubs(file): # pragma: no cover
5574
5851
 
5575
5852
  return_str = ""
5576
5853
 
5577
- if not f.export or f.hidden: # or f.generic:
5854
+ if f.hidden: # or f.generic:
5578
5855
  continue
5579
5856
 
5580
5857
  return_type = f.value_func(None, None)