warp-lang 1.3.3__py3-none-manylinux2014_x86_64.whl → 1.4.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (106)
  1. warp/__init__.py +6 -0
  2. warp/autograd.py +59 -6
  3. warp/bin/warp.so +0 -0
  4. warp/build_dll.py +8 -10
  5. warp/builtins.py +126 -4
  6. warp/codegen.py +435 -53
  7. warp/config.py +1 -1
  8. warp/context.py +678 -403
  9. warp/dlpack.py +2 -0
  10. warp/examples/benchmarks/benchmark_cloth.py +10 -0
  11. warp/examples/core/example_render_opengl.py +12 -10
  12. warp/examples/fem/example_adaptive_grid.py +251 -0
  13. warp/examples/fem/example_apic_fluid.py +1 -1
  14. warp/examples/fem/example_diffusion_3d.py +2 -2
  15. warp/examples/fem/example_magnetostatics.py +1 -1
  16. warp/examples/fem/example_streamlines.py +1 -0
  17. warp/examples/fem/utils.py +23 -4
  18. warp/examples/sim/example_cloth.py +50 -6
  19. warp/fem/__init__.py +2 -0
  20. warp/fem/adaptivity.py +493 -0
  21. warp/fem/field/field.py +2 -1
  22. warp/fem/field/nodal_field.py +18 -26
  23. warp/fem/field/test.py +4 -4
  24. warp/fem/field/trial.py +4 -4
  25. warp/fem/geometry/__init__.py +1 -0
  26. warp/fem/geometry/adaptive_nanogrid.py +843 -0
  27. warp/fem/geometry/nanogrid.py +55 -28
  28. warp/fem/space/__init__.py +1 -1
  29. warp/fem/space/nanogrid_function_space.py +69 -35
  30. warp/fem/utils.py +113 -107
  31. warp/jax_experimental.py +28 -15
  32. warp/native/array.h +0 -1
  33. warp/native/builtin.h +103 -6
  34. warp/native/bvh.cu +2 -0
  35. warp/native/cuda_util.cpp +14 -0
  36. warp/native/cuda_util.h +2 -0
  37. warp/native/error.cpp +4 -2
  38. warp/native/exports.h +99 -17
  39. warp/native/mat.h +97 -0
  40. warp/native/mesh.cpp +36 -0
  41. warp/native/mesh.cu +51 -0
  42. warp/native/mesh.h +1 -0
  43. warp/native/quat.h +43 -0
  44. warp/native/spatial.h +6 -0
  45. warp/native/vec.h +74 -0
  46. warp/native/warp.cpp +2 -1
  47. warp/native/warp.cu +10 -3
  48. warp/native/warp.h +8 -1
  49. warp/paddle.py +382 -0
  50. warp/sim/__init__.py +1 -0
  51. warp/sim/collide.py +519 -0
  52. warp/sim/integrator_euler.py +18 -5
  53. warp/sim/integrator_featherstone.py +5 -5
  54. warp/sim/integrator_vbd.py +1026 -0
  55. warp/sim/model.py +49 -23
  56. warp/stubs.py +459 -0
  57. warp/tape.py +2 -0
  58. warp/tests/aux_test_dependent.py +1 -0
  59. warp/tests/aux_test_name_clash1.py +32 -0
  60. warp/tests/aux_test_name_clash2.py +32 -0
  61. warp/tests/aux_test_square.py +1 -0
  62. warp/tests/test_array.py +188 -0
  63. warp/tests/test_async.py +3 -3
  64. warp/tests/test_atomic.py +6 -0
  65. warp/tests/test_closest_point_edge_edge.py +93 -1
  66. warp/tests/test_codegen.py +62 -15
  67. warp/tests/test_codegen_instancing.py +1457 -0
  68. warp/tests/test_collision.py +486 -0
  69. warp/tests/test_compile_consts.py +3 -28
  70. warp/tests/test_dlpack.py +170 -0
  71. warp/tests/test_examples.py +22 -8
  72. warp/tests/test_fast_math.py +10 -4
  73. warp/tests/test_fem.py +64 -0
  74. warp/tests/test_func.py +46 -0
  75. warp/tests/test_implicit_init.py +49 -0
  76. warp/tests/test_jax.py +58 -0
  77. warp/tests/test_mat.py +84 -0
  78. warp/tests/test_mesh_query_point.py +188 -0
  79. warp/tests/test_module_hashing.py +40 -0
  80. warp/tests/test_multigpu.py +3 -3
  81. warp/tests/test_overwrite.py +8 -0
  82. warp/tests/test_paddle.py +852 -0
  83. warp/tests/test_print.py +89 -0
  84. warp/tests/test_quat.py +111 -0
  85. warp/tests/test_reload.py +31 -1
  86. warp/tests/test_scalar_ops.py +2 -0
  87. warp/tests/test_static.py +412 -0
  88. warp/tests/test_streams.py +64 -3
  89. warp/tests/test_struct.py +4 -4
  90. warp/tests/test_torch.py +24 -0
  91. warp/tests/test_triangle_closest_point.py +137 -0
  92. warp/tests/test_types.py +1 -1
  93. warp/tests/test_vbd.py +386 -0
  94. warp/tests/test_vec.py +143 -0
  95. warp/tests/test_vec_scalar_ops.py +139 -0
  96. warp/tests/unittest_suites.py +12 -0
  97. warp/tests/unittest_utils.py +9 -5
  98. warp/thirdparty/dlpack.py +3 -1
  99. warp/types.py +150 -28
  100. warp/utils.py +37 -14
  101. {warp_lang-1.3.3.dist-info → warp_lang-1.4.0.dist-info}/METADATA +10 -8
  102. {warp_lang-1.3.3.dist-info → warp_lang-1.4.0.dist-info}/RECORD +105 -93
  103. warp/tests/test_point_triangle_closest_point.py +0 -143
  104. {warp_lang-1.3.3.dist-info → warp_lang-1.4.0.dist-info}/LICENSE.md +0 -0
  105. {warp_lang-1.3.3.dist-info → warp_lang-1.4.0.dist-info}/WHEEL +0 -0
  106. {warp_lang-1.3.3.dist-info → warp_lang-1.4.0.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -19,6 +19,7 @@ import platform
19
19
  import sys
20
20
  import types
21
21
  import typing
22
+ import weakref
22
23
  from copy import copy as shallowcopy
23
24
  from pathlib import Path
24
25
  from struct import pack as struct_pack
@@ -36,6 +37,10 @@ import warp.config
36
37
 
37
38
  def create_value_func(type):
38
39
  def value_func(arg_types, arg_values):
40
+ hint_origin = getattr(type, "__origin__", None)
41
+ if hint_origin is not None and issubclass(hint_origin, typing.Tuple):
42
+ return type.__args__
43
+
39
44
  return type
40
45
 
41
46
  return value_func
@@ -54,6 +59,38 @@ def get_function_args(func):
54
59
  complex_type_hints = (Any, Callable, Tuple)
55
60
  sequence_types = (list, tuple)
56
61
 
62
+ function_key_counts = {}
63
+
64
+
65
+ def generate_unique_function_identifier(key):
66
+ # Generate unique identifiers for user-defined functions in native code.
67
+ # - Prevents conflicts when a function is redefined and old versions are still in use.
68
+ # - Prevents conflicts between multiple closures returned from the same function.
69
+ # - Prevents conflicts between identically named functions from different modules.
70
+ #
71
+ # Currently, we generate a unique id when a new Function is created, which produces
72
+ # globally unique identifiers.
73
+ #
74
+ # NOTE:
75
+ # We could move this to the Module class for generating unique identifiers at module scope,
76
+ # but then we need another solution for preventing conflicts across modules (e.g., different namespaces).
77
+ # That would require more Python code, generate more native code, and would be slightly slower
78
+ # with no clear advantages over globally-unique identifiers (non-global shared state is still shared state).
79
+ #
80
+ # TODO:
81
+ # Kernels and structs use unique identifiers based on their hash. Using hash-based identifiers
82
+ # for functions would allow filtering out duplicate identical functions during codegen,
83
+ # like we do with kernels and structs. This is worth investigating further, but might require
84
+ # additional refactoring. For example, the code that deals with custom gradient and replay functions
85
+ # requires matching function names, but these special functions get created before the hash
86
+ # for the parent function can be computed. In addition to these complications, computing hashes
87
+ # for all function instances would increase the cost of module hashing when generic functions
88
+ # are involved (currently we only hash the generic templates, which is sufficient).
89
+
90
+ unique_id = function_key_counts.get(key, 0)
91
+ function_key_counts[key] = unique_id + 1
92
+ return f"{key}_{unique_id}"
93
+
57
94
 
58
95
  class Function:
59
96
  def __init__(
@@ -90,6 +127,7 @@ class Function:
90
127
  code_transformers=None,
91
128
  skip_adding_overload=False,
92
129
  require_original_output_arg=False,
130
+ scope_locals=None, # the locals() where the function is defined, used for overload management
93
131
  ):
94
132
  if code_transformers is None:
95
133
  code_transformers = []
@@ -115,6 +153,7 @@ class Function:
115
153
  self.replay_snippet = replay_snippet
116
154
  self.custom_grad_func = None
117
155
  self.require_original_output_arg = require_original_output_arg
156
+ self.generic_parent = None # generic function that was used to instantiate this overload
118
157
 
119
158
  if initializer_list_func is None:
120
159
  self.initializer_list_func = lambda x, y: False
@@ -124,14 +163,19 @@ class Function:
124
163
  )
125
164
  self.hidden = hidden # function will not be listed in docs
126
165
  self.skip_replay = (
127
- skip_replay # whether or not operation will be performed during the forward replay in the backward pass
166
+ skip_replay # whether operation will be performed during the forward replay in the backward pass
128
167
  )
129
- self.missing_grad = missing_grad # whether or not builtin is missing a corresponding adjoint
168
+ self.missing_grad = missing_grad # whether builtin is missing a corresponding adjoint
130
169
  self.generic = generic
131
170
 
132
- # allow registering builtin functions with a different name in Python from the native code
171
+ # allow registering functions with a different name in Python and native code
133
172
  if native_func is None:
134
- self.native_func = key
173
+ if func is None:
174
+ # builtin function
175
+ self.native_func = key
176
+ else:
177
+ # user functions need unique identifiers to avoid conflicts
178
+ self.native_func = generate_unique_function_identifier(key)
135
179
  else:
136
180
  self.native_func = native_func
137
181
 
@@ -162,6 +206,11 @@ class Function:
162
206
  else:
163
207
  self.input_types[name] = type
164
208
 
209
+ # Record any default parameter values.
210
+ if not self.defaults:
211
+ signature = inspect.signature(func)
212
+ self.defaults = {k: v.default for k, v in signature.parameters.items() if v.default is not v.empty}
213
+
165
214
  else:
166
215
  # builtin function
167
216
 
@@ -210,9 +259,13 @@ class Function:
210
259
  signature_params.append(param)
211
260
  self.signature = inspect.Signature(signature_params)
212
261
 
262
+ # scope for resolving overloads
263
+ if scope_locals is None:
264
+ scope_locals = inspect.currentframe().f_back.f_locals
265
+
213
266
  # add to current module
214
267
  if module:
215
- module.register_function(self, skip_adding_overload)
268
+ module.register_function(self, scope_locals, skip_adding_overload)
216
269
 
217
270
  def __call__(self, *args, **kwargs):
218
271
  # handles calling a builtin (native) function
@@ -323,57 +376,52 @@ class Function:
323
376
 
324
377
  # check if generic
325
378
  if warp.types.is_generic_signature(sig):
326
- if sig in self.user_templates:
327
- raise RuntimeError(
328
- f"Duplicate generic function overload {self.key} with arguments {f.input_types.values()}"
329
- )
330
379
  self.user_templates[sig] = f
331
380
  else:
332
- if sig in self.user_overloads:
333
- raise RuntimeError(
334
- f"Duplicate function overload {self.key} with arguments {f.input_types.values()}"
335
- )
336
381
  self.user_overloads[sig] = f
337
382
 
338
383
  def get_overload(self, arg_types, kwarg_types):
339
384
  assert not self.is_builtin()
340
385
 
341
- sig = warp.types.get_signature(arg_types, func_name=self.key)
386
+ for f in self.user_overloads.values():
387
+ if warp.codegen.func_match_args(f, arg_types, kwarg_types):
388
+ return f
342
389
 
343
- f = self.user_overloads.get(sig)
344
- if f is not None:
345
- return f
346
- else:
347
- for f in self.user_templates.values():
348
- if len(f.input_types) != len(arg_types):
349
- continue
390
+ for f in self.user_templates.values():
391
+ if not warp.codegen.func_match_args(f, arg_types, kwarg_types):
392
+ continue
393
+
394
+ if len(f.input_types) != len(arg_types):
395
+ continue
350
396
 
351
- # try to match the given types to the function template types
352
- template_types = list(f.input_types.values())
353
- args_matched = True
397
+ # try to match the given types to the function template types
398
+ template_types = list(f.input_types.values())
399
+ args_matched = True
354
400
 
355
- for i in range(len(arg_types)):
356
- if not warp.types.type_matches_template(arg_types[i], template_types[i]):
357
- args_matched = False
358
- break
401
+ for i in range(len(arg_types)):
402
+ if not warp.types.type_matches_template(arg_types[i], template_types[i]):
403
+ args_matched = False
404
+ break
359
405
 
360
- if args_matched:
361
- # instantiate this function with the specified argument types
406
+ if args_matched:
407
+ # instantiate this function with the specified argument types
362
408
 
363
- arg_names = f.input_types.keys()
364
- overload_annotations = dict(zip(arg_names, arg_types))
409
+ arg_names = f.input_types.keys()
410
+ overload_annotations = dict(zip(arg_names, arg_types))
365
411
 
366
- ovl = shallowcopy(f)
367
- ovl.adj = warp.codegen.Adjoint(f.func, overload_annotations)
368
- ovl.input_types = overload_annotations
369
- ovl.value_func = None
412
+ ovl = shallowcopy(f)
413
+ ovl.adj = warp.codegen.Adjoint(f.func, overload_annotations)
414
+ ovl.input_types = overload_annotations
415
+ ovl.value_func = None
416
+ ovl.generic_parent = f
370
417
 
371
- self.user_overloads[sig] = ovl
418
+ sig = warp.types.get_signature(arg_types, func_name=self.key)
419
+ self.user_overloads[sig] = ovl
372
420
 
373
- return ovl
421
+ return ovl
374
422
 
375
- # failed to find overload
376
- return None
423
+ # failed to find overload
424
+ return None
377
425
 
378
426
  def __repr__(self):
379
427
  inputs_str = ", ".join([f"{k}: {warp.types.type_repr(v)}" for k, v in self.input_types.items()])
@@ -589,8 +637,7 @@ class Kernel:
589
637
  self.module = module
590
638
 
591
639
  if key is None:
592
- unique_key = self.module.generate_unique_kernel_key(func.__name__)
593
- self.key = unique_key
640
+ self.key = warp.codegen.make_full_qualified_name(func)
594
641
  else:
595
642
  self.key = key
596
643
 
@@ -614,9 +661,15 @@ class Kernel:
614
661
  # known overloads for generic kernels, indexed by type signature
615
662
  self.overloads = {}
616
663
 
664
+ # generic kernel that was used to instantiate this overload
665
+ self.generic_parent = None
666
+
617
667
  # argument indices by name
618
668
  self.arg_indices = {a.label: i for i, a in enumerate(self.adj.args)}
619
669
 
670
+ # hash will be computed when the module is built
671
+ self.hash = None
672
+
620
673
  if self.module:
621
674
  self.module.register_kernel(self)
622
675
 
@@ -664,10 +717,11 @@ class Kernel:
664
717
  ovl.is_generic = False
665
718
  ovl.overloads = {}
666
719
  ovl.sig = sig
720
+ ovl.generic_parent = self
667
721
 
668
722
  self.overloads[sig] = ovl
669
723
 
670
- self.module.unload()
724
+ self.module.mark_modified()
671
725
 
672
726
  return ovl
673
727
 
@@ -676,10 +730,13 @@ class Kernel:
676
730
  return self.overloads.get(sig)
677
731
 
678
732
  def get_mangled_name(self):
679
- if self.sig:
680
- return f"{self.key}_{self.sig}"
681
- else:
682
- return self.key
733
+ if self.hash is None:
734
+ raise RuntimeError(f"Missing hash for kernel {self.key} in module {self.module.name}")
735
+
736
+ # TODO: allow customizing the number of hash characters used
737
+ hash_suffix = self.hash.hex()[:8]
738
+
739
+ return f"{self.key}_{hash_suffix}"
683
740
 
684
741
 
685
742
  # ----------------------
@@ -689,9 +746,11 @@ class Kernel:
689
746
  def func(f):
690
747
  name = warp.codegen.make_full_qualified_name(f)
691
748
 
749
+ scope_locals = inspect.currentframe().f_back.f_locals
750
+
692
751
  m = get_module(f.__module__)
693
752
  Function(
694
- func=f, key=name, namespace="", module=m, value_func=None
753
+ func=f, key=name, namespace="", module=m, value_func=None, scope_locals=scope_locals
695
754
  ) # value_type not known yet, will be inferred during Adjoint.build()
696
755
 
697
756
  # use the top of the list of overloads for this key
@@ -705,6 +764,8 @@ def func_native(snippet, adj_snippet=None, replay_snippet=None):
705
764
  Decorator to register native code snippet, @func_native
706
765
  """
707
766
 
767
+ scope_locals = inspect.currentframe().f_back.f_locals
768
+
708
769
  def snippet_func(f):
709
770
  name = warp.codegen.make_full_qualified_name(f)
710
771
 
@@ -717,6 +778,7 @@ def func_native(snippet, adj_snippet=None, replay_snippet=None):
717
778
  native_snippet=snippet,
718
779
  adj_native_snippet=adj_snippet,
719
780
  replay_snippet=replay_snippet,
781
+ scope_locals=scope_locals,
720
782
  ) # value_type not known yet, will be inferred during Adjoint.build()
721
783
  g = m.functions[name]
722
784
  # copy over the function attributes, including docstring
@@ -783,6 +845,7 @@ def func_grad(forward_fn):
783
845
  f.custom_grad_func = Function(
784
846
  grad_fn,
785
847
  key=f.key,
848
+ native_func=f.native_func,
786
849
  namespace=f.namespace,
787
850
  input_types=reverse_args,
788
851
  value_func=None,
@@ -941,7 +1004,7 @@ def overload(kernel, arg_types=None):
941
1004
  # ensure this function name corresponds to a kernel
942
1005
  fn = kernel
943
1006
  module = get_module(fn.__module__)
944
- kernel = module.kernels.get(fn.__name__)
1007
+ kernel = module.find_kernel(fn)
945
1008
  if kernel is None:
946
1009
  raise RuntimeError(f"Failed to find a kernel named '{fn.__name__}' in module {fn.__module__}")
947
1010
 
@@ -1277,7 +1340,6 @@ def get_module(name):
1277
1340
  # clear out old kernels, funcs, struct definitions
1278
1341
  old_module.kernels = {}
1279
1342
  old_module.functions = {}
1280
- old_module.constants = {}
1281
1343
  old_module.structs = {}
1282
1344
  old_module.loader = parent_loader
1283
1345
 
@@ -1289,30 +1351,202 @@ def get_module(name):
1289
1351
  return user_modules[name]
1290
1352
 
1291
1353
 
1354
+ # ModuleHasher computes the module hash based on all the kernels, module options,
1355
+ # and build configuration. For each kernel, it computes a deep hash by recursively
1356
+ # hashing all referenced functions, structs, and constants, even those defined in
1357
+ # other modules. The module hash is computed in the constructor and can be retrieved
1358
+ # using get_module_hash(). In addition, the ModuleHasher takes care of filtering out
1359
+ # duplicate kernels for codegen (see get_unique_kernels()).
1360
+ class ModuleHasher:
1361
+ def __init__(self, module):
1362
+ # cache function hashes to avoid hashing multiple times
1363
+ self.function_hashes = {} # (function: hash)
1364
+
1365
+ # avoid recursive spiral of doom (e.g., function calling an overload of itself)
1366
+ self.functions_in_progress = set()
1367
+
1368
+ # all unique kernels for codegen, filtered by hash
1369
+ self.unique_kernels = {} # (hash: kernel)
1370
+
1371
+ # start hashing the module
1372
+ ch = hashlib.sha256()
1373
+
1374
+ # hash all non-generic kernels
1375
+ for kernel in module.live_kernels:
1376
+ if kernel.is_generic:
1377
+ for ovl in kernel.overloads.values():
1378
+ if not ovl.adj.skip_build:
1379
+ ovl.hash = self.hash_kernel(ovl)
1380
+ else:
1381
+ if not kernel.adj.skip_build:
1382
+ kernel.hash = self.hash_kernel(kernel)
1383
+
1384
+ # include all unique kernels in the module hash
1385
+ for kernel_hash in sorted(self.unique_kernels.keys()):
1386
+ ch.update(kernel_hash)
1387
+
1388
+ # configuration parameters
1389
+ for opt in sorted(module.options.keys()):
1390
+ s = f"{opt}:{module.options[opt]}"
1391
+ ch.update(bytes(s, "utf-8"))
1392
+
1393
+ # ensure to trigger recompilation if flags affecting kernel compilation are changed
1394
+ if warp.config.verify_fp:
1395
+ ch.update(bytes("verify_fp", "utf-8"))
1396
+
1397
+ # build config
1398
+ ch.update(bytes(warp.config.mode, "utf-8"))
1399
+
1400
+ # save the module hash
1401
+ self.module_hash = ch.digest()
1402
+
1403
+ def hash_kernel(self, kernel):
1404
+ # NOTE: We only hash non-generic kernels, so we don't traverse kernel overloads here.
1405
+
1406
+ ch = hashlib.sha256()
1407
+
1408
+ ch.update(bytes(kernel.key, "utf-8"))
1409
+ ch.update(self.hash_adjoint(kernel.adj))
1410
+
1411
+ h = ch.digest()
1412
+
1413
+ self.unique_kernels[h] = kernel
1414
+
1415
+ return h
1416
+
1417
+ def hash_function(self, func):
1418
+ # NOTE: This method hashes all possible overloads that a function call could resolve to.
1419
+ # The exact overload will be resolved at build time, when the argument types are known.
1420
+
1421
+ h = self.function_hashes.get(func)
1422
+ if h is not None:
1423
+ return h
1424
+
1425
+ self.functions_in_progress.add(func)
1426
+
1427
+ ch = hashlib.sha256()
1428
+
1429
+ ch.update(bytes(func.key, "utf-8"))
1430
+
1431
+ # include all concrete and generic overloads
1432
+ overloads = {**func.user_overloads, **func.user_templates}
1433
+ for sig in sorted(overloads.keys()):
1434
+ ovl = overloads[sig]
1435
+
1436
+ # skip instantiations of generic functions
1437
+ if ovl.generic_parent is not None:
1438
+ continue
1439
+
1440
+ # adjoint
1441
+ ch.update(self.hash_adjoint(ovl.adj))
1442
+
1443
+ # custom bits
1444
+ if ovl.custom_grad_func:
1445
+ ch.update(self.hash_adjoint(ovl.custom_grad_func.adj))
1446
+ if ovl.custom_replay_func:
1447
+ ch.update(self.hash_adjoint(ovl.custom_replay_func.adj))
1448
+ if ovl.replay_snippet:
1449
+ ch.update(bytes(ovl.replay_snippet, "utf-8"))
1450
+ if ovl.native_snippet:
1451
+ ch.update(bytes(ovl.native_snippet, "utf-8"))
1452
+ if ovl.adj_native_snippet:
1453
+ ch.update(bytes(ovl.adj_native_snippet, "utf-8"))
1454
+
1455
+ h = ch.digest()
1456
+
1457
+ self.function_hashes[func] = h
1458
+
1459
+ self.functions_in_progress.remove(func)
1460
+
1461
+ return h
1462
+
1463
+ def hash_adjoint(self, adj):
1464
+ # NOTE: We don't cache adjoint hashes, because adjoints are always unique.
1465
+ # Even instances of generic kernels and functions have unique adjoints with
1466
+ # different argument types.
1467
+
1468
+ ch = hashlib.sha256()
1469
+
1470
+ # source
1471
+ ch.update(bytes(adj.source, "utf-8"))
1472
+
1473
+ # args
1474
+ for arg, arg_type in adj.arg_types.items():
1475
+ s = f"{arg}:{warp.types.get_type_code(arg_type)}"
1476
+ ch.update(bytes(s, "utf-8"))
1477
+
1478
+ # hash struct types
1479
+ if isinstance(arg_type, warp.codegen.Struct):
1480
+ ch.update(arg_type.hash)
1481
+ elif warp.types.is_array(arg_type) and isinstance(arg_type.dtype, warp.codegen.Struct):
1482
+ ch.update(arg_type.dtype.hash)
1483
+
1484
+ # find referenced constants, types, and functions
1485
+ constants, types, functions = adj.get_references()
1486
+
1487
+ # hash referenced constants
1488
+ for name, value in constants.items():
1489
+ ch.update(bytes(name, "utf-8"))
1490
+ # hash the referenced object
1491
+ if isinstance(value, builtins.bool):
1492
+ # This needs to come before the check for `int` since all boolean
1493
+ # values are also instances of `int`.
1494
+ ch.update(struct_pack("?", value))
1495
+ elif isinstance(value, int):
1496
+ ch.update(struct_pack("<q", value))
1497
+ elif isinstance(value, float):
1498
+ ch.update(struct_pack("<d", value))
1499
+ elif isinstance(value, warp.types.float16):
1500
+ # float16 is a special case
1501
+ p = ctypes.pointer(ctypes.c_float(value.value))
1502
+ ch.update(p.contents)
1503
+ elif isinstance(value, tuple(warp.types.scalar_types)):
1504
+ p = ctypes.pointer(value._type_(value.value))
1505
+ ch.update(p.contents)
1506
+ elif isinstance(value, ctypes.Array):
1507
+ ch.update(bytes(value))
1508
+ else:
1509
+ raise RuntimeError(f"Invalid constant type: {type(value)}")
1510
+
1511
+ # hash wp.static() expressions that were evaluated at declaration time
1512
+ for k, v in adj.static_expressions.items():
1513
+ ch.update(bytes(f"{k} = {v}", "utf-8"))
1514
+
1515
+ # hash referenced types
1516
+ for t in types.keys():
1517
+ ch.update(bytes(warp.types.get_type_code(t), "utf-8"))
1518
+
1519
+ # hash referenced functions
1520
+ for f in functions.keys():
1521
+ if f not in self.functions_in_progress:
1522
+ ch.update(self.hash_function(f))
1523
+
1524
+ return ch.digest()
1525
+
1526
+ def get_module_hash(self):
1527
+ return self.module_hash
1528
+
1529
+ def get_unique_kernels(self):
1530
+ return self.unique_kernels.values()
1531
+
1532
+
1292
1533
  class ModuleBuilder:
1293
- def __init__(self, module, options):
1534
+ def __init__(self, module, options, hasher=None):
1294
1535
  self.functions = {}
1295
1536
  self.structs = {}
1296
1537
  self.options = options
1297
1538
  self.module = module
1298
1539
  self.deferred_functions = []
1299
1540
 
1300
- # build all functions declared in the module
1301
- for func in module.functions.values():
1302
- for f in func.user_overloads.values():
1303
- self.build_function(f)
1304
- if f.custom_replay_func is not None:
1305
- self.build_function(f.custom_replay_func)
1306
-
1307
- # build all kernel entry points
1308
- for kernel in module.kernels.values():
1309
- if not kernel.is_generic:
1310
- self.build_kernel(kernel)
1311
- else:
1312
- for k in kernel.overloads.values():
1313
- self.build_kernel(k)
1541
+ if hasher is None:
1542
+ hasher = ModuleHasher(module)
1314
1543
 
1315
- # build all functions outside this module which are called from functions or kernels in this module
1544
+ # build all unique kernels
1545
+ self.kernels = hasher.get_unique_kernels()
1546
+ for kernel in self.kernels:
1547
+ self.build_kernel(kernel)
1548
+
1549
+ # build deferred functions
1316
1550
  for func in self.deferred_functions:
1317
1551
  self.build_function(func)
1318
1552
 
@@ -1328,7 +1562,7 @@ class ModuleBuilder:
1328
1562
  for var in s.vars.values():
1329
1563
  if isinstance(var.type, warp.codegen.Struct):
1330
1564
  stack.append(var.type)
1331
- elif isinstance(var.type, warp.types.array) and isinstance(var.type.dtype, warp.codegen.Struct):
1565
+ elif warp.types.is_array(var.type) and isinstance(var.type.dtype, warp.codegen.Struct):
1332
1566
  stack.append(var.type.dtype)
1333
1567
 
1334
1568
  # Build them in reverse to generate a correct dependency order.
@@ -1374,8 +1608,12 @@ class ModuleBuilder:
1374
1608
  source = ""
1375
1609
 
1376
1610
  # code-gen structs
1611
+ visited_structs = set()
1377
1612
  for struct in self.structs.keys():
1378
- source += warp.codegen.codegen_struct(struct)
1613
+ # avoid emitting duplicates
1614
+ if struct.hash not in visited_structs:
1615
+ source += warp.codegen.codegen_struct(struct)
1616
+ visited_structs.add(struct.hash)
1379
1617
 
1380
1618
  # code-gen all imported functions
1381
1619
  for func in self.functions.keys():
@@ -1386,21 +1624,15 @@ class ModuleBuilder:
1386
1624
  else:
1387
1625
  source += warp.codegen.codegen_snippet(
1388
1626
  func.adj,
1389
- name=func.key,
1627
+ name=func.native_func,
1390
1628
  snippet=func.native_snippet,
1391
1629
  adj_snippet=func.adj_native_snippet,
1392
1630
  replay_snippet=func.replay_snippet,
1393
1631
  )
1394
1632
 
1395
- for kernel in self.module.kernels.values():
1396
- # each kernel gets an entry point in the module
1397
- if not kernel.is_generic:
1398
- source += warp.codegen.codegen_kernel(kernel, device=device, options=self.options)
1399
- source += warp.codegen.codegen_module(kernel, device=device)
1400
- else:
1401
- for k in kernel.overloads.values():
1402
- source += warp.codegen.codegen_kernel(k, device=device, options=self.options)
1403
- source += warp.codegen.codegen_module(k, device=device)
1633
+ for kernel in self.kernels:
1634
+ source += warp.codegen.codegen_kernel(kernel, device=device, options=self.options)
1635
+ source += warp.codegen.codegen_module(kernel, device=device)
1404
1636
 
1405
1637
  # add headers
1406
1638
  if device == "cpu":
@@ -1425,8 +1657,9 @@ class ModuleExec:
1425
1657
  instance.handle = None
1426
1658
  return instance
1427
1659
 
1428
- def __init__(self, handle, device):
1660
+ def __init__(self, handle, module_hash, device):
1429
1661
  self.handle = handle
1662
+ self.module_hash = module_hash
1430
1663
  self.device = device
1431
1664
  self.kernel_hooks = {}
1432
1665
 
@@ -1457,8 +1690,12 @@ class ModuleExec:
1457
1690
  )
1458
1691
  else:
1459
1692
  func = ctypes.CFUNCTYPE(None)
1460
- forward = func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_forward").encode("utf-8")))
1461
- backward = func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8")))
1693
+ forward = (
1694
+ func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_forward").encode("utf-8"))) or None
1695
+ )
1696
+ backward = (
1697
+ func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8"))) or None
1698
+ )
1462
1699
 
1463
1700
  hooks = KernelHooks(forward, backward)
1464
1701
  self.kernel_hooks[kernel] = hooks
@@ -1475,16 +1712,33 @@ class Module:
1475
1712
  self.name = name
1476
1713
  self.loader = loader
1477
1714
 
1478
- self.kernels = {}
1479
- self.functions = {}
1480
- self.constants = {} # Any constants referenced in this module including those defined in other modules
1481
- self.structs = {}
1715
+ # lookup the latest versions of kernels, functions, and structs by key
1716
+ self.kernels = {} # (key: kernel)
1717
+ self.functions = {} # (key: function)
1718
+ self.structs = {} # (key: struct)
1719
+
1720
+ # Set of all "live" kernels in this module.
1721
+ # The difference between `live_kernels` and `kernels` is that `live_kernels` may contain
1722
+ # multiple kernels with the same key (which is essential to support closures), while `kernels`
1723
+ # only holds the latest kernel for each key. When the module is built, we compute the hash
1724
+ # of each kernel in `live_kernels` and filter out duplicates for codegen.
1725
+ self.live_kernels = weakref.WeakSet()
1482
1726
 
1483
- self.cpu_exec = None # executable CPU module
1484
- self.cuda_execs = {} # executable CUDA module lookup by CUDA context
1727
+ # executable modules currently loaded
1728
+ self.execs = {} # (device.context: ModuleExec)
1485
1729
 
1486
- self.cpu_build_failed = False
1487
- self.cuda_build_failed = False
1730
+ # set of device contexts where the build has failed
1731
+ self.failed_builds = set()
1732
+
1733
+ # hash data, including the module hash
1734
+ self.hasher = None
1735
+
1736
+ # LLVM executable modules are identified using strings. Since it's possible for multiple
1737
+ # executable versions to be loaded at the same time, we need a way to ensure uniqueness.
1738
+ # A unique handle is created from the module name and this auto-incremented integer id.
1739
+ # NOTE: The module hash is not sufficient for uniqueness in rare cases where a module
1740
+ # is retained and later reloaded with the same hash.
1741
+ self.cpu_exec_id = 0
1488
1742
 
1489
1743
  self.options = {
1490
1744
  "max_unroll": warp.config.max_unroll,
@@ -1497,11 +1751,6 @@ class Module:
1497
1751
  # Module dependencies are determined by scanning each function
1498
1752
  # and kernel for references to external functions and structs.
1499
1753
  #
1500
- # When a referenced module is modified, all of its dependents need to be reloaded
1501
- # on the next launch. To detect this, a module's hash recursively includes
1502
- # all of its references.
1503
- # -> See ``Module.hash_module()``
1504
- #
1505
1754
  # The dependency mechanism works for both static and dynamic (runtime) modifications.
1506
1755
  # When a module is reloaded at runtime, we recursively unload all of its
1507
1756
  # dependents, so that they will be re-hashed and reloaded on the next launch.
@@ -1510,40 +1759,39 @@ class Module:
1510
1759
  self.references = set() # modules whose content we depend on
1511
1760
  self.dependents = set() # modules that depend on our content
1512
1761
 
1513
- # Since module hashing is recursive, we improve performance by caching the hash of the
1514
- # module contents (kernel source, function source, and struct source).
1515
- # After all kernels, functions, and structs are added to the module (usually at import time),
1516
- # the content hash doesn't change.
1517
- # -> See ``Module.hash_module_recursive()``
1518
-
1519
- self.content_hash = None
1520
-
1521
- # number of times module auto-generates kernel key for user
1522
- # used to ensure unique kernel keys
1523
- self.count = 0
1524
-
1525
1762
  def register_struct(self, struct):
1526
1763
  self.structs[struct.key] = struct
1527
1764
 
1528
1765
  # for a reload of module on next launch
1529
- self.unload()
1766
+ self.mark_modified()
1530
1767
 
1531
1768
  def register_kernel(self, kernel):
1769
+ # keep a reference to the latest version
1532
1770
  self.kernels[kernel.key] = kernel
1533
1771
 
1772
+ # track all kernel objects, even if they are duplicates
1773
+ self.live_kernels.add(kernel)
1774
+
1534
1775
  self.find_references(kernel.adj)
1535
1776
 
1536
1777
  # for a reload of module on next launch
1537
- self.unload()
1778
+ self.mark_modified()
1538
1779
 
1539
- def register_function(self, func, skip_adding_overload=False):
1540
- if func.key not in self.functions:
1541
- self.functions[func.key] = func
1780
+ def register_function(self, func, scope_locals, skip_adding_overload=False):
1781
+ # check for another Function with the same name in the same scope
1782
+ obj = scope_locals.get(func.func.__name__)
1783
+ if isinstance(obj, Function):
1784
+ func_existing = obj
1542
1785
  else:
1786
+ func_existing = None
1787
+
1788
+ # keep a reference to the latest version
1789
+ self.functions[func.key] = func_existing or func
1790
+
1791
+ if func_existing:
1543
1792
  # Check whether the new function's signature match any that has
1544
1793
  # already been registered. If so, then we simply override it, as
1545
1794
  # Python would do it, otherwise we register it as a new overload.
1546
- func_existing = self.functions[func.key]
1547
1795
  sig = warp.types.get_signature(
1548
1796
  func.input_types.values(),
1549
1797
  func_name=func.key,
@@ -1555,19 +1803,43 @@ class Module:
1555
1803
  arg_names=list(func_existing.input_types.keys()),
1556
1804
  )
1557
1805
  if sig == sig_existing:
1806
+ # replace the top-level function, but keep existing overloads
1807
+
1808
+ # copy generic overloads
1809
+ func.user_templates = func_existing.user_templates.copy()
1810
+
1811
+ # copy concrete overloads
1812
+ if warp.types.is_generic_signature(sig):
1813
+ # skip overloads that were instantiated from the function being replaced
1814
+ for k, v in func_existing.user_overloads.items():
1815
+ if v.generic_parent != func_existing:
1816
+ func.user_overloads[k] = v
1817
+ func.user_templates[sig] = func
1818
+ else:
1819
+ func.user_overloads = func_existing.user_overloads.copy()
1820
+ func.user_overloads[sig] = func
1821
+
1558
1822
  self.functions[func.key] = func
1559
1823
  elif not skip_adding_overload:
1824
+ # check if this is a generic overload that replaces an existing one
1825
+ if warp.types.is_generic_signature(sig):
1826
+ old_generic = func_existing.user_templates.get(sig)
1827
+ if old_generic is not None:
1828
+ # purge any concrete overloads that were instantiated from the old one
1829
+ for k, v in list(func_existing.user_overloads.items()):
1830
+ if v.generic_parent == old_generic:
1831
+ del func_existing.user_overloads[k]
1560
1832
  func_existing.add_overload(func)
1561
1833
 
1562
1834
  self.find_references(func.adj)
1563
1835
 
1564
1836
  # for a reload of module on next launch
1565
- self.unload()
1837
+ self.mark_modified()
1566
1838
 
1567
- def generate_unique_kernel_key(self, key):
1568
- unique_key = f"{key}_{self.count}"
1569
- self.count += 1
1570
- return unique_key
1839
+ # find kernel corresponding to a Python function
1840
+ def find_kernel(self, func):
1841
+ qualname = warp.codegen.make_full_qualified_name(func)
1842
+ return self.kernels.get(qualname)
1571
1843
 
1572
1844
  # collect all referenced functions / structs
1573
1845
  # given the AST of a function or kernel
@@ -1599,165 +1871,30 @@ class Module:
1599
1871
  if isinstance(arg.type, warp.codegen.Struct) and arg.type.module is not None:
1600
1872
  add_ref(arg.type.module)
1601
1873
 
1602
- def hash_module(self, recompute_content_hash=False):
1603
- """Recursively compute and return a hash for the module.
1604
-
1605
- If ``recompute_content_hash`` is False, each module's previously
1606
- computed ``content_hash`` will be used.
1607
- """
1608
-
1609
- def get_type_name(type_hint) -> str:
1610
- if isinstance(type_hint, warp.codegen.Struct):
1611
- return get_type_name(type_hint.cls)
1612
- elif isinstance(type_hint, warp.array) and isinstance(type_hint.dtype, warp.codegen.Struct):
1613
- return f"array{get_type_name(type_hint.dtype)}"
1614
-
1615
- return str(type_hint)
1616
-
1617
- def hash_recursive(module, visited):
1618
- # Hash this module, including all referenced modules recursively.
1619
- # The visited set tracks modules already visited to avoid circular references.
1620
-
1621
- # check if we need to update the content hash
1622
- if not module.content_hash or recompute_content_hash:
1623
- # recompute content hash
1624
- ch = hashlib.sha256()
1625
-
1626
- # Start with an empty constants dictionary in case any have been removed
1627
- module.constants = {}
1628
-
1629
- # struct source
1630
- for struct in module.structs.values():
1631
- s = ",".join(
1632
- "{}: {}".format(name, get_type_name(type_hint))
1633
- for name, type_hint in warp.codegen.get_annotations(struct.cls).items()
1634
- )
1635
- ch.update(bytes(s, "utf-8"))
1636
-
1637
- # functions source
1638
- for function in module.functions.values():
1639
- # include all concrete and generic overloads
1640
- overloads = itertools.chain(function.user_overloads.items(), function.user_templates.items())
1641
- for sig, func in overloads:
1642
- # signature
1643
- ch.update(bytes(sig, "utf-8"))
1644
-
1645
- # source
1646
- ch.update(bytes(func.adj.source, "utf-8"))
1647
-
1648
- if func.custom_grad_func:
1649
- ch.update(bytes(func.custom_grad_func.adj.source, "utf-8"))
1650
- if func.custom_replay_func:
1651
- ch.update(bytes(func.custom_replay_func.adj.source, "utf-8"))
1652
- if func.replay_snippet:
1653
- ch.update(bytes(func.replay_snippet, "utf-8"))
1654
- if func.native_snippet:
1655
- ch.update(bytes(func.native_snippet, "utf-8"))
1656
- if func.adj_native_snippet:
1657
- ch.update(bytes(func.adj_native_snippet, "utf-8"))
1658
-
1659
- # Populate constants referenced in this function
1660
- if func.adj:
1661
- module.constants.update(func.adj.get_constant_references())
1662
-
1663
- # kernel source
1664
- for kernel in module.kernels.values():
1665
- ch.update(bytes(kernel.key, "utf-8"))
1666
- ch.update(bytes(kernel.adj.source, "utf-8"))
1667
- # cache kernel arg types
1668
- for arg, arg_type in kernel.adj.arg_types.items():
1669
- s = f"{arg}: {get_type_name(arg_type)}"
1670
- ch.update(bytes(s, "utf-8"))
1671
- # for generic kernels the Python source is always the same,
1672
- # but we hash the type signatures of all the overloads
1673
- if kernel.is_generic:
1674
- for sig in sorted(kernel.overloads.keys()):
1675
- ch.update(bytes(sig, "utf-8"))
1676
-
1677
- # Populate constants referenced in this kernel
1678
- module.constants.update(kernel.adj.get_constant_references())
1679
-
1680
- # constants referenced in this module
1681
- for constant_name, constant_value in module.constants.items():
1682
- ch.update(bytes(constant_name, "utf-8"))
1683
-
1684
- # hash the constant value
1685
- if isinstance(constant_value, builtins.bool):
1686
- # This needs to come before the check for `int` since all boolean
1687
- # values are also instances of `int`.
1688
- ch.update(struct_pack("?", constant_value))
1689
- elif isinstance(constant_value, int):
1690
- ch.update(struct_pack("<q", constant_value))
1691
- elif isinstance(constant_value, float):
1692
- ch.update(struct_pack("<d", constant_value))
1693
- elif isinstance(constant_value, warp.types.float16):
1694
- # float16 is a special case
1695
- p = ctypes.pointer(ctypes.c_float(constant_value.value))
1696
- ch.update(p.contents)
1697
- elif isinstance(constant_value, tuple(warp.types.scalar_types)):
1698
- p = ctypes.pointer(constant_value._type_(constant_value.value))
1699
- ch.update(p.contents)
1700
- elif isinstance(constant_value, ctypes.Array):
1701
- ch.update(bytes(constant_value))
1702
- else:
1703
- raise RuntimeError(f"Invalid constant type: {type(constant_value)}")
1704
-
1705
- module.content_hash = ch.digest()
1706
-
1707
- h = hashlib.sha256()
1708
-
1709
- # content hash
1710
- h.update(module.content_hash)
1711
-
1712
- # configuration parameters
1713
- for k in sorted(module.options.keys()):
1714
- s = f"{k}={module.options[k]}"
1715
- h.update(bytes(s, "utf-8"))
1716
-
1717
- # ensure to trigger recompilation if flags affecting kernel compilation are changed
1718
- if warp.config.verify_fp:
1719
- h.update(bytes("verify_fp", "utf-8"))
1720
-
1721
- h.update(bytes(warp.config.mode, "utf-8"))
1722
-
1723
- # recurse on references
1724
- visited.add(module)
1725
-
1726
- sorted_deps = sorted(module.references, key=lambda m: m.name)
1727
- for dep in sorted_deps:
1728
- if dep not in visited:
1729
- dep_hash = hash_recursive(dep, visited)
1730
- h.update(dep_hash)
1731
-
1732
- return h.digest()
1733
-
1734
- return hash_recursive(self, visited=set())
1874
+ def hash_module(self):
1875
+ # compute latest hash
1876
+ self.hasher = ModuleHasher(self)
1877
+ return self.hasher.get_module_hash()
1735
1878
 
1736
1879
  def load(self, device) -> ModuleExec:
1737
1880
  device = runtime.get_device(device)
1738
1881
 
1739
- if device.is_cpu:
1740
- # check if already loaded
1741
- if self.cpu_exec:
1742
- return self.cpu_exec
1743
- # avoid repeated build attempts
1744
- if self.cpu_build_failed:
1745
- return None
1746
- if not warp.is_cpu_available():
1747
- raise RuntimeError("Failed to build CPU module because no CPU buildchain was found")
1748
- else:
1749
- # check if already loaded
1750
- cuda_exec = self.cuda_execs.get(device.context)
1751
- if cuda_exec is not None:
1752
- return cuda_exec
1753
- # avoid repeated build attempts
1754
- if self.cuda_build_failed:
1755
- return None
1756
- if not warp.is_cuda_available():
1757
- raise RuntimeError("Failed to build CUDA module because CUDA is not available")
1882
+ # compute the hash if needed
1883
+ if self.hasher is None:
1884
+ self.hasher = ModuleHasher(self)
1885
+
1886
+ # check if executable module is already loaded and not stale
1887
+ exec = self.execs.get(device.context)
1888
+ if exec is not None:
1889
+ if exec.module_hash == self.hasher.module_hash:
1890
+ return exec
1891
+
1892
+ # quietly avoid repeated build attempts to reduce error spew
1893
+ if device.context in self.failed_builds:
1894
+ return None
1758
1895
 
1759
1896
  module_name = "wp_" + self.name
1760
- module_hash = self.hash_module()
1897
+ module_hash = self.hasher.module_hash
1761
1898
 
1762
1899
  # use a unique module path using the module short hash
1763
1900
  module_dir = os.path.join(warp.config.kernel_cache_dir, f"{module_name}_{module_hash.hex()[:7]}")
@@ -1807,7 +1944,7 @@ class Module:
1807
1944
  or not warp.config.cache_kernels
1808
1945
  or warp.config.verify_autograd_array_access
1809
1946
  ):
1810
- builder = ModuleBuilder(self, self.options)
1947
+ builder = ModuleBuilder(self, self.options, hasher=self.hasher)
1811
1948
 
1812
1949
  # create a temporary (process unique) dir for build outputs before moving to the binary dir
1813
1950
  build_dir = os.path.join(
@@ -1844,7 +1981,7 @@ class Module:
1844
1981
  )
1845
1982
 
1846
1983
  except Exception as e:
1847
- self.cpu_build_failed = True
1984
+ self.failed_builds.add(None)
1848
1985
  module_load_timer.extra_msg = " (error)"
1849
1986
  raise (e)
1850
1987
 
@@ -1873,7 +2010,7 @@ class Module:
1873
2010
  )
1874
2011
 
1875
2012
  except Exception as e:
1876
- self.cuda_build_failed = True
2013
+ self.failed_builds.add(device.context)
1877
2014
  module_load_timer.extra_msg = " (error)"
1878
2015
  raise (e)
1879
2016
 
@@ -1914,15 +2051,18 @@ class Module:
1914
2051
  # -----------------------------------------------------------
1915
2052
  # Load CPU or CUDA binary
1916
2053
  if device.is_cpu:
1917
- runtime.llvm.load_obj(binary_path.encode("utf-8"), module_name.encode("utf-8"))
1918
- module_exec = ModuleExec(module_name, device)
1919
- self.cpu_exec = module_exec
2054
+ # LLVM modules are identified using strings, so we need to ensure uniqueness
2055
+ module_handle = f"{module_name}_{self.cpu_exec_id}"
2056
+ self.cpu_exec_id += 1
2057
+ runtime.llvm.load_obj(binary_path.encode("utf-8"), module_handle.encode("utf-8"))
2058
+ module_exec = ModuleExec(module_handle, module_hash, device)
2059
+ self.execs[None] = module_exec
1920
2060
 
1921
2061
  elif device.is_cuda:
1922
2062
  cuda_module = warp.build.load_cuda(binary_path, device)
1923
2063
  if cuda_module is not None:
1924
- module_exec = ModuleExec(cuda_module, device)
1925
- self.cuda_execs[device.context] = module_exec
2064
+ module_exec = ModuleExec(cuda_module, module_hash, device)
2065
+ self.execs[device.context] = module_exec
1926
2066
  else:
1927
2067
  module_load_timer.extra_msg = " (error)"
1928
2068
  raise Exception(f"Failed to load CUDA module '{self.name}'")
@@ -1936,20 +2076,22 @@ class Module:
1936
2076
  return module_exec
1937
2077
 
1938
2078
  def unload(self):
2079
+ # force rehashing on next load
2080
+ self.mark_modified()
2081
+
1939
2082
  # clear loaded modules
1940
- self.cpu_exec = None
1941
- self.cuda_execs = {}
2083
+ self.execs = {}
2084
+
2085
+ def mark_modified(self):
2086
+ # clear hash data
2087
+ self.hasher = None
1942
2088
 
1943
- # clear content hash
1944
- self.content_hash = None
2089
+ # clear build failures
2090
+ self.failed_builds = set()
1945
2091
 
1946
2092
  # lookup kernel entry points based on name, called after compilation / module load
1947
2093
  def get_kernel_hooks(self, kernel, device):
1948
- if device.is_cuda:
1949
- module_exec = self.cuda_execs.get(device.context)
1950
- else:
1951
- module_exec = self.cpu_exec
1952
-
2094
+ module_exec = self.execs.get(device.context)
1953
2095
  if module_exec is not None:
1954
2096
  return module_exec.get_kernel_hooks(kernel)
1955
2097
  else:
@@ -2056,6 +2198,63 @@ class ContextGuard:
2056
2198
  runtime.core.cuda_context_set_current(self.saved_context)
2057
2199
 
2058
2200
 
2201
+ class Event:
2202
+ """A CUDA event that can be recorded onto a stream.
2203
+
2204
+ Events can be used for device-side synchronization, which do not block
2205
+ the host thread.
2206
+ """
2207
+
2208
+ # event creation flags
2209
+ class Flags:
2210
+ DEFAULT = 0x0
2211
+ BLOCKING_SYNC = 0x1
2212
+ DISABLE_TIMING = 0x2
2213
+
2214
+ def __new__(cls, *args, **kwargs):
2215
+ """Creates a new event instance."""
2216
+ instance = super(Event, cls).__new__(cls)
2217
+ instance.owner = False
2218
+ return instance
2219
+
2220
+ def __init__(self, device: "Devicelike" = None, cuda_event=None, enable_timing: bool = False):
2221
+ """Initializes the event on a CUDA device.
2222
+
2223
+ Args:
2224
+ device: The CUDA device whose streams this event may be recorded onto.
2225
+ If ``None``, then the current default device will be used.
2226
+ cuda_event: A pointer to a previously allocated CUDA event. If
2227
+ `None`, then a new event will be allocated on the associated device.
2228
+ enable_timing: If ``True`` this event will record timing data.
2229
+ :func:`~warp.get_event_elapsed_time` can be used to measure the
2230
+ time between two events created with ``enable_timing=True`` and
2231
+ recorded onto streams.
2232
+ """
2233
+
2234
+ device = get_device(device)
2235
+ if not device.is_cuda:
2236
+ raise RuntimeError(f"Device {device} is not a CUDA device")
2237
+
2238
+ self.device = device
2239
+
2240
+ if cuda_event is not None:
2241
+ self.cuda_event = cuda_event
2242
+ else:
2243
+ flags = Event.Flags.DEFAULT
2244
+ if not enable_timing:
2245
+ flags |= Event.Flags.DISABLE_TIMING
2246
+ self.cuda_event = runtime.core.cuda_event_create(device.context, flags)
2247
+ if not self.cuda_event:
2248
+ raise RuntimeError(f"Failed to create event on device {device}")
2249
+ self.owner = True
2250
+
2251
+ def __del__(self):
2252
+ if not self.owner:
2253
+ return
2254
+
2255
+ runtime.core.cuda_event_destroy(self.cuda_event)
2256
+
2257
+
2059
2258
  class Stream:
2060
2259
  def __new__(cls, *args, **kwargs):
2061
2260
  instance = super(Stream, cls).__new__(cls)
@@ -2063,7 +2262,27 @@ class Stream:
2063
2262
  instance.owner = False
2064
2263
  return instance
2065
2264
 
2066
- def __init__(self, device=None, **kwargs):
2265
+ def __init__(self, device: Optional[Union["Device", str]] = None, priority: int = 0, **kwargs):
2266
+ """Initialize the stream on a device with an optional specified priority.
2267
+
2268
+ Args:
2269
+ device: The CUDA device on which this stream will be created.
2270
+ priority: An optional integer specifying the requested stream priority.
2271
+ Can be -1 (high priority) or 0 (low/default priority).
2272
+ Values outside this range will be clamped.
2273
+ cuda_stream (int): A optional external stream handle passed as an
2274
+ integer. The caller is responsible for ensuring that the external
2275
+ stream does not get destroyed while it is referenced by this
2276
+ object.
2277
+
2278
+ Raises:
2279
+ RuntimeError: If function is called before Warp has completed
2280
+ initialization with a ``device`` that is not an instance of
2281
+ :class:`Device``.
2282
+ RuntimeError: ``device`` is not a CUDA Device.
2283
+ RuntimeError: The stream could not be created on the device.
2284
+ TypeError: The requested stream priority is not an integer.
2285
+ """
2067
2286
  # event used internally for synchronization (cached to avoid creating temporary events)
2068
2287
  self._cached_event = None
2069
2288
 
@@ -2072,7 +2291,7 @@ class Stream:
2072
2291
  device = runtime.get_device(device)
2073
2292
  elif not isinstance(device, Device):
2074
2293
  raise RuntimeError(
2075
- "A device object is required when creating a stream before or during Warp initialization"
2294
+ "A Device object is required when creating a stream before or during Warp initialization"
2076
2295
  )
2077
2296
 
2078
2297
  if not device.is_cuda:
@@ -2085,7 +2304,11 @@ class Stream:
2085
2304
  self.cuda_stream = kwargs["cuda_stream"]
2086
2305
  device.runtime.core.cuda_stream_register(device.context, self.cuda_stream)
2087
2306
  else:
2088
- self.cuda_stream = device.runtime.core.cuda_stream_create(device.context)
2307
+ if not isinstance(priority, int):
2308
+ raise TypeError("Stream priority must be an integer.")
2309
+ clamped_priority = max(-1, min(priority, 0)) # Only support two priority levels
2310
+ self.cuda_stream = device.runtime.core.cuda_stream_create(device.context, clamped_priority)
2311
+
2089
2312
  if not self.cuda_stream:
2090
2313
  raise RuntimeError(f"Failed to create stream on device {device}")
2091
2314
  self.owner = True
@@ -2100,12 +2323,22 @@ class Stream:
2100
2323
  runtime.core.cuda_stream_unregister(self.device.context, self.cuda_stream)
2101
2324
 
2102
2325
  @property
2103
- def cached_event(self):
2326
+ def cached_event(self) -> Event:
2104
2327
  if self._cached_event is None:
2105
2328
  self._cached_event = Event(self.device)
2106
2329
  return self._cached_event
2107
2330
 
2108
- def record_event(self, event=None):
2331
+ def record_event(self, event: Optional[Event] = None) -> Event:
2332
+ """Record an event onto the stream.
2333
+
2334
+ Args:
2335
+ event: A warp.Event instance to be recorded onto the stream. If not
2336
+ provided, an :class:`~warp.Event` on the same device will be created.
2337
+
2338
+ Raises:
2339
+ RuntimeError: The provided :class:`~warp.Event` is from a different device than
2340
+ the recording stream.
2341
+ """
2109
2342
  if event is None:
2110
2343
  event = Event(self.device)
2111
2344
  elif event.device != self.device:
@@ -2117,56 +2350,45 @@ class Stream:
2117
2350
 
2118
2351
  return event
2119
2352
 
2120
- def wait_event(self, event):
2353
+ def wait_event(self, event: Event):
2354
+ """Makes all future work in this stream wait until `event` has completed.
2355
+
2356
+ This function does not block the host thread.
2357
+ """
2121
2358
  runtime.core.cuda_stream_wait_event(self.cuda_stream, event.cuda_event)
2122
2359
 
2123
- def wait_stream(self, other_stream, event=None):
2360
+ def wait_stream(self, other_stream: "Stream", event: Optional[Event] = None):
2361
+ """Records an event on `other_stream` and makes this stream wait on it.
2362
+
2363
+ All work added to this stream after this function has been called will
2364
+ delay their execution until all preceding commands in `other_stream`
2365
+ have completed.
2366
+
2367
+ This function does not block the host thread.
2368
+
2369
+ Args:
2370
+ other_stream: The stream on which the calling stream will wait for
2371
+ previously issued commands to complete before executing subsequent
2372
+ commands.
2373
+ event: An optional :class:`Event` instance that will be used to
2374
+ record an event onto ``other_stream``. If ``None``, an internally
2375
+ managed :class:`Event` instance will be used.
2376
+ """
2377
+
2124
2378
  if event is None:
2125
2379
  event = other_stream.cached_event
2126
2380
 
2127
2381
  runtime.core.cuda_stream_wait_stream(self.cuda_stream, other_stream.cuda_stream, event.cuda_event)
2128
2382
 
2129
- # whether a graph capture is currently ongoing on this stream
2130
2383
  @property
2131
- def is_capturing(self):
2384
+ def is_capturing(self) -> bool:
2385
+ """A boolean indicating whether a graph capture is currently ongoing on this stream."""
2132
2386
  return bool(runtime.core.cuda_stream_is_capturing(self.cuda_stream))
2133
2387
 
2134
-
2135
- class Event:
2136
- # event creation flags
2137
- class Flags:
2138
- DEFAULT = 0x0
2139
- BLOCKING_SYNC = 0x1
2140
- DISABLE_TIMING = 0x2
2141
-
2142
- def __new__(cls, *args, **kwargs):
2143
- instance = super(Event, cls).__new__(cls)
2144
- instance.owner = False
2145
- return instance
2146
-
2147
- def __init__(self, device=None, cuda_event=None, enable_timing=False):
2148
- device = get_device(device)
2149
- if not device.is_cuda:
2150
- raise RuntimeError(f"Device {device} is not a CUDA device")
2151
-
2152
- self.device = device
2153
-
2154
- if cuda_event is not None:
2155
- self.cuda_event = cuda_event
2156
- else:
2157
- flags = Event.Flags.DEFAULT
2158
- if not enable_timing:
2159
- flags |= Event.Flags.DISABLE_TIMING
2160
- self.cuda_event = runtime.core.cuda_event_create(device.context, flags)
2161
- if not self.cuda_event:
2162
- raise RuntimeError(f"Failed to create event on device {device}")
2163
- self.owner = True
2164
-
2165
- def __del__(self):
2166
- if not self.owner:
2167
- return
2168
-
2169
- runtime.core.cuda_event_destroy(self.cuda_event)
2388
+ @property
2389
+ def priority(self) -> int:
2390
+ """An integer representing the priority of the stream."""
2391
+ return runtime.core.cuda_stream_get_priority(self.cuda_stream)
2170
2392
 
2171
2393
 
2172
2394
  class Device:
@@ -2178,14 +2400,14 @@ class Device:
2178
2400
  or ``"CPU"`` if the processor name cannot be determined.
2179
2401
  arch: An integer representing the compute capability version number calculated as
2180
2402
  ``10 * major + minor``. ``0`` for CPU devices.
2181
- is_uva: A boolean indicating whether or not the device supports unified addressing.
2403
+ is_uva: A boolean indicating whether the device supports unified addressing.
2182
2404
  ``False`` for CPU devices.
2183
- is_cubin_supported: A boolean indicating whether or not Warp's version of NVRTC can directly
2405
+ is_cubin_supported: A boolean indicating whether Warp's version of NVRTC can directly
2184
2406
  generate CUDA binary files (cubin) for this device's architecture. ``False`` for CPU devices.
2185
- is_mempool_supported: A boolean indicating whether or not the device supports using the
2407
+ is_mempool_supported: A boolean indicating whether the device supports using the
2186
2408
  ``cuMemAllocAsync`` and ``cuMemPool`` family of APIs for stream-ordered memory allocations. ``False`` for
2187
2409
  CPU devices.
2188
- is_primary: A boolean indicating whether or not this device's CUDA context is also the
2410
+ is_primary: A boolean indicating whether this device's CUDA context is also the
2189
2411
  device's primary context.
2190
2412
  uuid: A string representing the UUID of the CUDA device. The UUID is in the same format used by
2191
2413
  ``nvidia-smi -L``. ``None`` for CPU devices.
@@ -2274,7 +2496,7 @@ class Device:
2274
2496
 
2275
2497
  # initialize streams unless context acquisition is postponed
2276
2498
  if self._context is not None:
2277
- self.init_streams()
2499
+ self._init_streams()
2278
2500
 
2279
2501
  # TODO: add more device-specific dispatch functions
2280
2502
  self.memset = lambda ptr, value, size: runtime.core.memset_device(self.context, ptr, value, size)
@@ -2285,7 +2507,13 @@ class Device:
2285
2507
  else:
2286
2508
  raise RuntimeError(f"Invalid device ordinal ({ordinal})'")
2287
2509
 
2288
- def get_allocator(self, pinned=False):
2510
+ def get_allocator(self, pinned: bool = False):
2511
+ """Get the memory allocator for this device.
2512
+
2513
+ Args:
2514
+ pinned: If ``True``, an allocator for pinned memory will be
2515
+ returned. Only applicable when this device is a CPU device.
2516
+ """
2289
2517
  if self.is_cuda:
2290
2518
  return self.current_allocator
2291
2519
  else:
@@ -2294,7 +2522,8 @@ class Device:
2294
2522
  else:
2295
2523
  return self.default_allocator
2296
2524
 
2297
- def init_streams(self):
2525
+ def _init_streams(self):
2526
+ """Initializes the device's current stream and the device's null stream."""
2298
2527
  # create a stream for asynchronous work
2299
2528
  self.set_stream(Stream(self))
2300
2529
 
@@ -2302,17 +2531,18 @@ class Device:
2302
2531
  self.null_stream = Stream(self, cuda_stream=None)
2303
2532
 
2304
2533
  @property
2305
- def is_cpu(self):
2306
- """A boolean indicating whether or not the device is a CPU device."""
2534
+ def is_cpu(self) -> bool:
2535
+ """A boolean indicating whether the device is a CPU device."""
2307
2536
  return self.ordinal < 0
2308
2537
 
2309
2538
  @property
2310
- def is_cuda(self):
2311
- """A boolean indicating whether or not the device is a CUDA device."""
2539
+ def is_cuda(self) -> bool:
2540
+ """A boolean indicating whether the device is a CUDA device."""
2312
2541
  return self.ordinal >= 0
2313
2542
 
2314
2543
  @property
2315
- def is_capturing(self):
2544
+ def is_capturing(self) -> bool:
2545
+ """A boolean indicating whether this device's default stream is currently capturing a graph."""
2316
2546
  if self.is_cuda and self.stream is not None:
2317
2547
  # There is no CUDA API to check if graph capture was started on a device, so we
2318
2548
  # can't tell if a capture was started by external code on a different stream.
@@ -2336,17 +2566,17 @@ class Device:
2336
2566
  raise RuntimeError(f"Failed to acquire primary context for device {self}")
2337
2567
  self.runtime.context_map[self._context] = self
2338
2568
  # initialize streams
2339
- self.init_streams()
2569
+ self._init_streams()
2340
2570
  runtime.core.cuda_context_set_current(prev_context)
2341
2571
  return self._context
2342
2572
 
2343
2573
  @property
2344
- def has_context(self):
2345
- """A boolean indicating whether or not the device has a CUDA context associated with it."""
2574
+ def has_context(self) -> bool:
2575
+ """A boolean indicating whether the device has a CUDA context associated with it."""
2346
2576
  return self._context is not None
2347
2577
 
2348
2578
  @property
2349
- def stream(self):
2579
+ def stream(self) -> Stream:
2350
2580
  """The stream associated with a CUDA device.
2351
2581
 
2352
2582
  Raises:
@@ -2361,7 +2591,22 @@ class Device:
2361
2591
  def stream(self, stream):
2362
2592
  self.set_stream(stream)
2363
2593
 
2364
- def set_stream(self, stream, sync=True):
2594
+ def set_stream(self, stream: Stream, sync: bool = True) -> None:
2595
+ """Set the current stream for this CUDA device.
2596
+
2597
+ The current stream will be used by default for all kernel launches and
2598
+ memory operations on this device.
2599
+
2600
+ If this is an external stream, the caller is responsible for
2601
+ guaranteeing the lifetime of the stream.
2602
+
2603
+ Consider using :class:`warp.ScopedStream` instead.
2604
+
2605
+ Args:
2606
+ stream: The stream to set as this device's current stream.
2607
+ sync: If ``True``, then ``stream`` will perform a device-side
2608
+ synchronization with the device's previous current stream.
2609
+ """
2365
2610
  if self.is_cuda:
2366
2611
  if stream.device != self:
2367
2612
  raise RuntimeError(f"Stream from device {stream.device} cannot be used on device {self}")
@@ -2372,12 +2617,12 @@ class Device:
2372
2617
  raise RuntimeError(f"Device {self} is not a CUDA device")
2373
2618
 
2374
2619
  @property
2375
- def has_stream(self):
2376
- """A boolean indicating whether or not the device has a stream associated with it."""
2620
+ def has_stream(self) -> bool:
2621
+ """A boolean indicating whether the device has a stream associated with it."""
2377
2622
  return self._stream is not None
2378
2623
 
2379
2624
  @property
2380
- def total_memory(self):
2625
+ def total_memory(self) -> int:
2381
2626
  """The total amount of device memory available in bytes.
2382
2627
 
2383
2628
  This function is currently only implemented for CUDA devices. 0 will be returned if called on a CPU device.
@@ -2391,7 +2636,7 @@ class Device:
2391
2636
  return 0
2392
2637
 
2393
2638
  @property
2394
- def free_memory(self):
2639
+ def free_memory(self) -> int:
2395
2640
  """The amount of memory on the device that is free according to the OS in bytes.
2396
2641
 
2397
2642
  This function is currently only implemented for CUDA devices. 0 will be returned if called on a CPU device.
@@ -2755,6 +3000,12 @@ class Runtime:
2755
3000
  self.core.mesh_refit_host.argtypes = [ctypes.c_uint64]
2756
3001
  self.core.mesh_refit_device.argtypes = [ctypes.c_uint64]
2757
3002
 
3003
+ self.core.mesh_set_points_host.argtypes = [ctypes.c_uint64, warp.types.array_t]
3004
+ self.core.mesh_set_points_device.argtypes = [ctypes.c_uint64, warp.types.array_t]
3005
+
3006
+ self.core.mesh_set_velocities_host.argtypes = [ctypes.c_uint64, warp.types.array_t]
3007
+ self.core.mesh_set_velocities_device.argtypes = [ctypes.c_uint64, warp.types.array_t]
3008
+
2758
3009
  self.core.hash_grid_create_host.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
2759
3010
  self.core.hash_grid_create_host.restype = ctypes.c_uint64
2760
3011
  self.core.hash_grid_destroy_host.argtypes = [ctypes.c_uint64]
@@ -3029,7 +3280,7 @@ class Runtime:
3029
3280
  self.core.cuda_set_mempool_access_enabled.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
3030
3281
  self.core.cuda_set_mempool_access_enabled.restype = ctypes.c_int
3031
3282
 
3032
- self.core.cuda_stream_create.argtypes = [ctypes.c_void_p]
3283
+ self.core.cuda_stream_create.argtypes = [ctypes.c_void_p, ctypes.c_int]
3033
3284
  self.core.cuda_stream_create.restype = ctypes.c_void_p
3034
3285
  self.core.cuda_stream_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3035
3286
  self.core.cuda_stream_destroy.restype = None
@@ -3047,6 +3298,8 @@ class Runtime:
3047
3298
  self.core.cuda_stream_is_capturing.restype = ctypes.c_int
3048
3299
  self.core.cuda_stream_get_capture_id.argtypes = [ctypes.c_void_p]
3049
3300
  self.core.cuda_stream_get_capture_id.restype = ctypes.c_uint64
3301
+ self.core.cuda_stream_get_priority.argtypes = [ctypes.c_void_p]
3302
+ self.core.cuda_stream_get_priority.restype = ctypes.c_int
3050
3303
 
3051
3304
  self.core.cuda_event_create.argtypes = [ctypes.c_void_p, ctypes.c_uint]
3052
3305
  self.core.cuda_event_create.restype = ctypes.c_void_p
@@ -3382,10 +3635,10 @@ class Runtime:
3382
3635
 
3383
3636
  raise ValueError(f"Invalid device identifier: {ident}")
3384
3637
 
3385
- def set_default_device(self, ident: Devicelike):
3638
+ def set_default_device(self, ident: Devicelike) -> None:
3386
3639
  self.default_device = self.get_device(ident)
3387
3640
 
3388
- def get_current_cuda_device(self):
3641
+ def get_current_cuda_device(self) -> Device:
3389
3642
  current_context = self.core.cuda_context_get_current()
3390
3643
  if current_context is not None:
3391
3644
  current_device = self.context_map.get(current_context)
@@ -3415,7 +3668,7 @@ class Runtime:
3415
3668
  else:
3416
3669
  raise RuntimeError('"cuda" device requested but CUDA is not supported by the hardware or driver')
3417
3670
 
3418
- def rename_device(self, device, alias):
3671
+ def rename_device(self, device, alias) -> Device:
3419
3672
  del self.device_map[device.alias]
3420
3673
  device.alias = alias
3421
3674
  self.device_map[alias] = device
@@ -3462,7 +3715,7 @@ class Runtime:
3462
3715
 
3463
3716
  return device
3464
3717
 
3465
- def unmap_cuda_device(self, alias):
3718
+ def unmap_cuda_device(self, alias) -> None:
3466
3719
  device = self.device_map.get(alias)
3467
3720
 
3468
3721
  # make sure the alias refers to a CUDA device
@@ -3473,7 +3726,7 @@ class Runtime:
3473
3726
  del self.context_map[device.context]
3474
3727
  self.cuda_devices.remove(device)
3475
3728
 
3476
- def verify_cuda_device(self, device: Devicelike = None):
3729
+ def verify_cuda_device(self, device: Devicelike = None) -> None:
3477
3730
  if warp.config.verify_cuda:
3478
3731
  device = runtime.get_device(device)
3479
3732
  if not device.is_cuda:
@@ -3485,13 +3738,13 @@ class Runtime:
3485
3738
 
3486
3739
 
3487
3740
  # global entry points
3488
- def is_cpu_available():
3741
+ def is_cpu_available() -> bool:
3489
3742
  init()
3490
3743
 
3491
- return runtime.llvm
3744
+ return runtime.llvm is not None
3492
3745
 
3493
3746
 
3494
- def is_cuda_available():
3747
+ def is_cuda_available() -> bool:
3495
3748
  return get_cuda_device_count() > 0
3496
3749
 
3497
3750
 
@@ -3575,8 +3828,8 @@ def get_device(ident: Devicelike = None) -> Device:
3575
3828
  return runtime.get_device(ident)
3576
3829
 
3577
3830
 
3578
- def set_device(ident: Devicelike):
3579
- """Sets the target device identified by the argument."""
3831
+ def set_device(ident: Devicelike) -> None:
3832
+ """Sets the default device identified by the argument."""
3580
3833
 
3581
3834
  init()
3582
3835
 
@@ -3604,7 +3857,7 @@ def map_cuda_device(alias: str, context: ctypes.c_void_p = None) -> Device:
3604
3857
  return runtime.map_cuda_device(alias, context)
3605
3858
 
3606
3859
 
3607
- def unmap_cuda_device(alias: str):
3860
+ def unmap_cuda_device(alias: str) -> None:
3608
3861
  """Remove a CUDA device with the given alias."""
3609
3862
 
3610
3863
  init()
@@ -3612,7 +3865,7 @@ def unmap_cuda_device(alias: str):
3612
3865
  runtime.unmap_cuda_device(alias)
3613
3866
 
3614
3867
 
3615
- def is_mempool_supported(device: Devicelike):
3868
+ def is_mempool_supported(device: Devicelike) -> bool:
3616
3869
  """Check if CUDA memory pool allocators are available on the device."""
3617
3870
 
3618
3871
  init()
@@ -3622,7 +3875,7 @@ def is_mempool_supported(device: Devicelike):
3622
3875
  return device.is_mempool_supported
3623
3876
 
3624
3877
 
3625
- def is_mempool_enabled(device: Devicelike):
3878
+ def is_mempool_enabled(device: Devicelike) -> bool:
3626
3879
  """Check if CUDA memory pool allocators are enabled on the device."""
3627
3880
 
3628
3881
  init()
@@ -3632,7 +3885,7 @@ def is_mempool_enabled(device: Devicelike):
3632
3885
  return device.is_mempool_enabled
3633
3886
 
3634
3887
 
3635
- def set_mempool_enabled(device: Devicelike, enable: bool):
3888
+ def set_mempool_enabled(device: Devicelike, enable: bool) -> None:
3636
3889
  """Enable or disable CUDA memory pool allocators on the device.
3637
3890
 
3638
3891
  Pooled allocators are typically faster and allow allocating memory during graph capture.
@@ -3663,7 +3916,7 @@ def set_mempool_enabled(device: Devicelike, enable: bool):
3663
3916
  raise ValueError("Memory pools are only supported on CUDA devices")
3664
3917
 
3665
3918
 
3666
- def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, float]):
3919
+ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, float]) -> None:
3667
3920
  """Set the CUDA memory pool release threshold on the device.
3668
3921
 
3669
3922
  This is the amount of reserved memory to hold onto before trying to release memory back to the OS.
@@ -3694,7 +3947,7 @@ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, floa
3694
3947
  raise RuntimeError(f"Failed to set memory pool release threshold for device {device}")
3695
3948
 
3696
3949
 
3697
- def get_mempool_release_threshold(device: Devicelike):
3950
+ def get_mempool_release_threshold(device: Devicelike) -> int:
3698
3951
  """Get the CUDA memory pool release threshold on the device."""
3699
3952
 
3700
3953
  init()
@@ -3710,7 +3963,7 @@ def get_mempool_release_threshold(device: Devicelike):
3710
3963
  return runtime.core.cuda_device_get_mempool_release_threshold(device.ordinal)
3711
3964
 
3712
3965
 
3713
- def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike):
3966
+ def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike) -> bool:
3714
3967
  """Check if `peer_device` can directly access the memory of `target_device` on this system.
3715
3968
 
3716
3969
  This applies to memory allocated using default CUDA allocators. For memory allocated using
@@ -3731,7 +3984,7 @@ def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike)
3731
3984
  return bool(runtime.core.cuda_is_peer_access_supported(target_device.ordinal, peer_device.ordinal))
3732
3985
 
3733
3986
 
3734
- def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike):
3987
+ def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike) -> bool:
3735
3988
  """Check if `peer_device` can currently access the memory of `target_device`.
3736
3989
 
3737
3990
  This applies to memory allocated using default CUDA allocators. For memory allocated using
@@ -3752,7 +4005,7 @@ def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike):
3752
4005
  return bool(runtime.core.cuda_is_peer_access_enabled(target_device.context, peer_device.context))
3753
4006
 
3754
4007
 
3755
- def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool):
4008
+ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool) -> None:
3756
4009
  """Enable or disable direct access from `peer_device` to the memory of `target_device`.
3757
4010
 
3758
4011
  Enabling peer access can improve the speed of peer-to-peer memory transfers, but can have
@@ -3784,7 +4037,7 @@ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike,
3784
4037
  raise RuntimeError(f"Failed to {action} peer access from device {peer_device} to device {target_device}")
3785
4038
 
3786
4039
 
3787
- def is_mempool_access_supported(target_device: Devicelike, peer_device: Devicelike):
4040
+ def is_mempool_access_supported(target_device: Devicelike, peer_device: Devicelike) -> bool:
3788
4041
  """Check if `peer_device` can directly access the memory pool of `target_device`.
3789
4042
 
3790
4043
  If mempool access is possible, it can be managed using `set_mempool_access_enabled()` and `is_mempool_access_enabled()`.
@@ -3801,7 +4054,7 @@ def is_mempool_access_supported(target_device: Devicelike, peer_device: Deviceli
3801
4054
  return target_device.is_mempool_supported and is_peer_access_supported(target_device, peer_device)
3802
4055
 
3803
4056
 
3804
- def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike):
4057
+ def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike) -> bool:
3805
4058
  """Check if `peer_device` can currently access the memory pool of `target_device`.
3806
4059
 
3807
4060
  This applies to memory allocated using CUDA pooled allocators. For memory allocated using
@@ -3822,7 +4075,7 @@ def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike
3822
4075
  return bool(runtime.core.cuda_is_mempool_access_enabled(target_device.ordinal, peer_device.ordinal))
3823
4076
 
3824
4077
 
3825
- def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool):
4078
+ def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool) -> None:
3826
4079
  """Enable or disable access from `peer_device` to the memory pool of `target_device`.
3827
4080
 
3828
4081
  This applies to memory allocated using CUDA pooled allocators. For memory allocated using
@@ -3858,26 +4111,41 @@ def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelik
3858
4111
 
3859
4112
 
3860
4113
  def get_stream(device: Devicelike = None) -> Stream:
3861
- """Return the stream currently used by the given device"""
4114
+ """Return the stream currently used by the given device.
4115
+
4116
+ Args:
4117
+ device: An optional :class:`Device` instance or device alias
4118
+ (e.g. "cuda:0") for which the current stream will be returned.
4119
+ If ``None``, the default device will be used.
4120
+
4121
+ Raises:
4122
+ RuntimeError: The device is not a CUDA device.
4123
+ """
3862
4124
 
3863
4125
  return get_device(device).stream
3864
4126
 
3865
4127
 
3866
- def set_stream(stream, device: Devicelike = None, sync: bool = False):
3867
- """Set the stream to be used by the given device.
4128
+ def set_stream(stream: Stream, device: Devicelike = None, sync: bool = False) -> None:
4129
+ """Convenience function for calling :meth:`Device.set_stream` on the given ``device``.
3868
4130
 
3869
- If this is an external stream, caller is responsible for guaranteeing the lifetime of the stream.
3870
- Consider using wp.ScopedStream instead.
4131
+ Args:
4132
+ device: An optional :class:`Device` instance or device alias
4133
+ (e.g. "cuda:0") for which the current stream is to be replaced with
4134
+ ``stream``. If ``None``, the default device will be used.
4135
+ stream: The stream to set as this device's current stream.
4136
+ sync: If ``True``, then ``stream`` will perform a device-side
4137
+ synchronization with the device's previous current stream.
3871
4138
  """
3872
4139
 
3873
4140
  get_device(device).set_stream(stream, sync=sync)
3874
4141
 
3875
4142
 
3876
- def record_event(event: Event = None):
3877
- """Record a CUDA event on the current stream.
4143
+ def record_event(event: Optional[Event] = None):
4144
+ """Convenience function for calling :meth:`Stream.record_event` on the current stream.
3878
4145
 
3879
4146
  Args:
3880
- event: Event to record. If None, a new Event will be created.
4147
+ event: :class:`Event` instance to record. If ``None``, a new :class:`Event`
4148
+ instance will be created.
3881
4149
 
3882
4150
  Returns:
3883
4151
  The recorded event.
@@ -3887,29 +4155,31 @@ def record_event(event: Event = None):
3887
4155
 
3888
4156
 
3889
4157
  def wait_event(event: Event):
3890
- """Make the current stream wait for a CUDA event.
4158
+ """Convenience function for calling :meth:`Stream.wait_event` on the current stream.
3891
4159
 
3892
4160
  Args:
3893
- event: Event to wait for.
4161
+ event: :class:`Event` instance to wait for.
3894
4162
  """
3895
4163
 
3896
4164
  get_stream().wait_event(event)
3897
4165
 
3898
4166
 
3899
- def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: bool = True):
4167
+ def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: Optional[bool] = True):
3900
4168
  """Get the elapsed time between two recorded events.
3901
4169
 
3902
- The result is in milliseconds with a resolution of about 0.5 microsecond.
3903
-
3904
- Both events must have been previously recorded with ``wp.record_event()`` or ``wp.Stream.record_event()``.
4170
+ Both events must have been previously recorded with
4171
+ :func:`~warp.record_event()` or :meth:`warp.Stream.record_event()`.
3905
4172
 
3906
4173
  If ``synchronize`` is False, the caller must ensure that device execution has reached ``end_event``
3907
4174
  prior to calling ``get_event_elapsed_time()``.
3908
4175
 
3909
4176
  Args:
3910
- start_event (Event): The start event.
3911
- end_event (Event): The end event.
3912
- synchronize (bool, optional): Whether Warp should synchronize on the ``end_event``.
4177
+ start_event: The start event.
4178
+ end_event: The end event.
4179
+ synchronize: Whether Warp should synchronize on the ``end_event``.
4180
+
4181
+ Returns:
4182
+ The elapsed time in milliseconds with a resolution about 0.5 ms.
3913
4183
  """
3914
4184
 
3915
4185
  # ensure the end_event is reached
@@ -3919,14 +4189,19 @@ def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: bo
3919
4189
  return runtime.core.cuda_event_elapsed_time(start_event.cuda_event, end_event.cuda_event)
3920
4190
 
3921
4191
 
3922
- def wait_stream(stream: Stream, event: Event = None):
3923
- """Make the current stream wait for another CUDA stream to complete its work.
4192
+ def wait_stream(other_stream: Stream, event: Event = None):
4193
+ """Convenience function for calling :meth:`Stream.wait_stream` on the current stream.
3924
4194
 
3925
4195
  Args:
3926
- event: Event to be used. If None, a new Event will be created.
4196
+ other_stream: The stream on which the calling stream will wait for
4197
+ previously issued commands to complete before executing subsequent
4198
+ commands.
4199
+ event: An optional :class:`Event` instance that will be used to
4200
+ record an event onto ``other_stream``. If ``None``, an internally
4201
+ managed :class:`Event` instance will be used.
3927
4202
  """
3928
4203
 
3929
- get_stream().wait_stream(stream, event=event)
4204
+ get_stream().wait_stream(other_stream, event=event)
3930
4205
 
3931
4206
 
3932
4207
  class RegisteredGLBuffer:
@@ -4362,7 +4637,7 @@ def from_numpy(
4362
4637
  dtype: The data type of the new Warp array. If this is not provided, the data type will be inferred.
4363
4638
  shape: The shape of the Warp array.
4364
4639
  device: The device on which the Warp array will be constructed.
4365
- requires_grad: Whether or not gradients will be tracked for this array.
4640
+ requires_grad: Whether gradients will be tracked for this array.
4366
4641
 
4367
4642
  Raises:
4368
4643
  RuntimeError: The data type of the NumPy array is not supported.
@@ -4398,7 +4673,7 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
4398
4673
  return arg_type.__ctype__()
4399
4674
 
4400
4675
  elif isinstance(value, warp.types.array_t):
4401
- # accept array descriptors verbatum
4676
+ # accept array descriptors verbatim
4402
4677
  return value
4403
4678
 
4404
4679
  else:
@@ -4865,7 +5140,7 @@ def synchronize_device(device: Devicelike = None):
4865
5140
  runtime.core.cuda_context_synchronize(device.context)
4866
5141
 
4867
5142
 
4868
- def synchronize_stream(stream_or_device=None):
5143
+ def synchronize_stream(stream_or_device: Union[Stream, Devicelike, None] = None):
4869
5144
  """Synchronize the calling CPU thread with any outstanding CUDA work on the specified stream.
4870
5145
 
4871
5146
  This function allows the host application code to ensure that all kernel launches
@@ -4989,7 +5264,7 @@ def set_module_options(options: Dict[str, Any], module: Optional[Any] = None):
4989
5264
  m = module
4990
5265
 
4991
5266
  get_module(m.__name__).options.update(options)
4992
- get_module(m.__name__).unload()
5267
+ get_module(m.__name__).mark_modified()
4993
5268
 
4994
5269
 
4995
5270
  def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
@@ -5016,7 +5291,7 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=None
5016
5291
  Args:
5017
5292
  device: The CUDA device to capture on
5018
5293
  stream: The CUDA stream to capture on
5019
- force_module_load: Whether or not to force loading of all kernels before capture.
5294
+ force_module_load: Whether to force loading of all kernels before capture.
5020
5295
  In general it is better to use :func:`~warp.load_module()` to selectively load kernels.
5021
5296
  When running with CUDA drivers that support CUDA 12.3 or newer, this option is not recommended to be set to
5022
5297
  ``True`` because kernels can be loaded during graph capture on more recent drivers. If this argument is
@@ -5574,7 +5849,7 @@ def export_stubs(file): # pragma: no cover
5574
5849
 
5575
5850
  return_str = ""
5576
5851
 
5577
- if not f.export or f.hidden: # or f.generic:
5852
+ if f.hidden: # or f.generic:
5578
5853
  continue
5579
5854
 
5580
5855
  return_type = f.value_func(None, None)