warp-lang 1.3.2-py3-none-win_amd64.whl → 1.4.0-py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic.
- warp/__init__.py +6 -0
- warp/autograd.py +59 -6
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build_dll.py +8 -10
- warp/builtins.py +126 -4
- warp/codegen.py +435 -53
- warp/config.py +1 -1
- warp/context.py +678 -403
- warp/dlpack.py +2 -0
- warp/examples/benchmarks/benchmark_cloth.py +10 -0
- warp/examples/core/example_render_opengl.py +12 -10
- warp/examples/fem/example_adaptive_grid.py +251 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +2 -2
- warp/examples/fem/example_magnetostatics.py +1 -1
- warp/examples/fem/example_streamlines.py +1 -0
- warp/examples/fem/utils.py +23 -4
- warp/examples/sim/example_cloth.py +50 -6
- warp/fem/__init__.py +2 -0
- warp/fem/adaptivity.py +493 -0
- warp/fem/field/field.py +2 -1
- warp/fem/field/nodal_field.py +18 -26
- warp/fem/field/test.py +4 -4
- warp/fem/field/trial.py +4 -4
- warp/fem/geometry/__init__.py +1 -0
- warp/fem/geometry/adaptive_nanogrid.py +843 -0
- warp/fem/geometry/nanogrid.py +55 -28
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/nanogrid_function_space.py +69 -35
- warp/fem/utils.py +113 -107
- warp/jax_experimental.py +28 -15
- warp/native/array.h +0 -1
- warp/native/builtin.h +103 -6
- warp/native/bvh.cu +2 -0
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/error.cpp +4 -2
- warp/native/exports.h +99 -17
- warp/native/mat.h +97 -0
- warp/native/mesh.cpp +36 -0
- warp/native/mesh.cu +51 -0
- warp/native/mesh.h +1 -0
- warp/native/quat.h +43 -0
- warp/native/spatial.h +6 -0
- warp/native/vec.h +74 -0
- warp/native/warp.cpp +2 -1
- warp/native/warp.cu +10 -3
- warp/native/warp.h +8 -1
- warp/paddle.py +382 -0
- warp/sim/__init__.py +1 -0
- warp/sim/collide.py +519 -0
- warp/sim/integrator_euler.py +18 -5
- warp/sim/integrator_featherstone.py +5 -5
- warp/sim/integrator_vbd.py +1026 -0
- warp/sim/model.py +49 -23
- warp/stubs.py +459 -0
- warp/tape.py +2 -0
- warp/tests/aux_test_dependent.py +1 -0
- warp/tests/aux_test_name_clash1.py +32 -0
- warp/tests/aux_test_name_clash2.py +32 -0
- warp/tests/aux_test_square.py +1 -0
- warp/tests/test_array.py +222 -0
- warp/tests/test_async.py +3 -3
- warp/tests/test_atomic.py +6 -0
- warp/tests/test_closest_point_edge_edge.py +93 -1
- warp/tests/test_codegen.py +62 -15
- warp/tests/test_codegen_instancing.py +1457 -0
- warp/tests/test_collision.py +486 -0
- warp/tests/test_compile_consts.py +3 -28
- warp/tests/test_dlpack.py +170 -0
- warp/tests/test_examples.py +22 -8
- warp/tests/test_fast_math.py +10 -4
- warp/tests/test_fem.py +64 -0
- warp/tests/test_func.py +46 -0
- warp/tests/test_implicit_init.py +49 -0
- warp/tests/test_jax.py +58 -0
- warp/tests/test_mat.py +84 -0
- warp/tests/test_mesh_query_point.py +188 -0
- warp/tests/test_module_hashing.py +40 -0
- warp/tests/test_multigpu.py +3 -3
- warp/tests/test_overwrite.py +8 -0
- warp/tests/test_paddle.py +852 -0
- warp/tests/test_print.py +89 -0
- warp/tests/test_quat.py +111 -0
- warp/tests/test_reload.py +31 -1
- warp/tests/test_scalar_ops.py +2 -0
- warp/tests/test_static.py +412 -0
- warp/tests/test_streams.py +64 -3
- warp/tests/test_struct.py +4 -4
- warp/tests/test_torch.py +24 -0
- warp/tests/test_triangle_closest_point.py +137 -0
- warp/tests/test_types.py +1 -1
- warp/tests/test_vbd.py +386 -0
- warp/tests/test_vec.py +143 -0
- warp/tests/test_vec_scalar_ops.py +139 -0
- warp/tests/test_volume.py +30 -0
- warp/tests/unittest_suites.py +12 -0
- warp/tests/unittest_utils.py +9 -5
- warp/thirdparty/dlpack.py +3 -1
- warp/types.py +157 -34
- warp/utils.py +37 -14
- {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/METADATA +10 -8
- {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/RECORD +107 -95
- warp/tests/test_point_triangle_closest_point.py +0 -143
- {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/top_level.txt +0 -0
warp/context.py
CHANGED
@@ -19,6 +19,7 @@ import platform
 import sys
 import types
 import typing
+import weakref
 from copy import copy as shallowcopy
 from pathlib import Path
 from struct import pack as struct_pack
@@ -36,6 +37,10 @@ import warp.config

 def create_value_func(type):
     def value_func(arg_types, arg_values):
+        hint_origin = getattr(type, "__origin__", None)
+        if hint_origin is not None and issubclass(hint_origin, typing.Tuple):
+            return type.__args__
+
         return type

     return value_func
@@ -54,6 +59,38 @@ def get_function_args(func):
 complex_type_hints = (Any, Callable, Tuple)
 sequence_types = (list, tuple)

+function_key_counts = {}
+
+
+def generate_unique_function_identifier(key):
+    # Generate unique identifiers for user-defined functions in native code.
+    # - Prevents conflicts when a function is redefined and old versions are still in use.
+    # - Prevents conflicts between multiple closures returned from the same function.
+    # - Prevents conflicts between identically named functions from different modules.
+    #
+    # Currently, we generate a unique id when a new Function is created, which produces
+    # globally unique identifiers.
+    #
+    # NOTE:
+    # We could move this to the Module class for generating unique identifiers at module scope,
+    # but then we need another solution for preventing conflicts across modules (e.g., different namespaces).
+    # That would requires more Python code, generate more native code, and would be slightly slower
+    # with no clear advantages over globally-unique identifiers (non-global shared state is still shared state).
+    #
+    # TODO:
+    # Kernels and structs use unique identifiers based on their hash. Using hash-based identifiers
+    # for functions would allow filtering out duplicate identical functions during codegen,
+    # like we do with kernels and structs. This is worth investigating further, but might require
+    # additional refactoring. For example, the code that deals with custom gradient and replay functions
+    # requires matching function names, but these special functions get created before the hash
+    # for the parent function can be computed. In addition to these complications, computing hashes
+    # for all function instances would increase the cost of module hashing when generic functions
+    # are involved (currently we only hash the generic templates, which is sufficient).
+
+    unique_id = function_key_counts.get(key, 0)
+    function_key_counts[key] = unique_id + 1
+    return f"{key}_{unique_id}"
+

 class Function:
     def __init__(
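The per-key counter above is what lets multiple live versions of a function share the same key, e.g. closures returned from a factory. A minimal usage sketch of the pattern this enables (the factory below is illustrative, not part of the diff):

import warp as wp

def make_scale(factor: float):
    # each call defines a new wp.func with the same key;
    # generate_unique_function_identifier() gives each instance its own
    # native symbol, so previously returned closures keep working
    @wp.func
    def scale(x: float) -> float:
        return x * factor

    return scale

scale2 = make_scale(2.0)
scale3 = make_scale(3.0)  # does not clobber scale2's native code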
@@ -90,6 +127,7 @@ class Function:
         code_transformers=None,
         skip_adding_overload=False,
         require_original_output_arg=False,
+        scope_locals=None,  # the locals() where the function is defined, used for overload management
     ):
         if code_transformers is None:
             code_transformers = []
@@ -115,6 +153,7 @@ class Function:
         self.replay_snippet = replay_snippet
         self.custom_grad_func = None
         self.require_original_output_arg = require_original_output_arg
+        self.generic_parent = None  # generic function that was used to instantiate this overload

         if initializer_list_func is None:
             self.initializer_list_func = lambda x, y: False
@@ -124,14 +163,19 @@
         )
         self.hidden = hidden  # function will not be listed in docs
         self.skip_replay = (
-            skip_replay  # whether
+            skip_replay  # whether operation will be performed during the forward replay in the backward pass
         )
-        self.missing_grad = missing_grad  # whether
+        self.missing_grad = missing_grad  # whether builtin is missing a corresponding adjoint
         self.generic = generic

-        # allow registering
+        # allow registering functions with a different name in Python and native code
         if native_func is None:
-
+            if func is None:
+                # builtin function
+                self.native_func = key
+            else:
+                # user functions need unique identifiers to avoid conflicts
+                self.native_func = generate_unique_function_identifier(key)
         else:
             self.native_func = native_func
@@ -162,6 +206,11 @@
             else:
                 self.input_types[name] = type

+            # Record any default parameter values.
+            if not self.defaults:
+                signature = inspect.signature(func)
+                self.defaults = {k: v.default for k, v in signature.parameters.items() if v.default is not v.empty}
+
         else:
             # builtin function

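Recording defaults via inspect.signature() is what backs default argument values for user functions. A hedged sketch, assuming @wp.func defaults behave like regular Python defaults when the argument is omitted:

import warp as wp

@wp.func
def lerp01(a: float, b: float, t: float = 0.5) -> float:
    return a + t * (b - a)

@wp.kernel
def k(out: wp.array(dtype=float)):
    tid = wp.tid()
    out[tid] = lerp01(0.0, 2.0)  # omitted t falls back to the recorded default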
@@ -210,9 +259,13 @@
                 signature_params.append(param)
             self.signature = inspect.Signature(signature_params)

+        # scope for resolving overloads
+        if scope_locals is None:
+            scope_locals = inspect.currentframe().f_back.f_locals
+
         # add to current module
         if module:
-            module.register_function(self, skip_adding_overload)
+            module.register_function(self, scope_locals, skip_adding_overload)

     def __call__(self, *args, **kwargs):
         # handles calling a builtin (native) function
@@ -323,57 +376,52 @@

         # check if generic
         if warp.types.is_generic_signature(sig):
-            if sig in self.user_templates:
-                raise RuntimeError(
-                    f"Duplicate generic function overload {self.key} with arguments {f.input_types.values()}"
-                )
             self.user_templates[sig] = f
         else:
-            if sig in self.user_overloads:
-                raise RuntimeError(
-                    f"Duplicate function overload {self.key} with arguments {f.input_types.values()}"
-                )
             self.user_overloads[sig] = f

     def get_overload(self, arg_types, kwarg_types):
         assert not self.is_builtin()

-
+        for f in self.user_overloads.values():
+            if warp.codegen.func_match_args(f, arg_types, kwarg_types):
+                return f

-        f
-
-
-
-
-            continue
+        for f in self.user_templates.values():
+            if not warp.codegen.func_match_args(f, arg_types, kwarg_types):
+                continue
+
+            if len(f.input_types) != len(arg_types):
+                continue

-
-
-
+            # try to match the given types to the function template types
+            template_types = list(f.input_types.values())
+            args_matched = True

-
-
-
-
+            for i in range(len(arg_types)):
+                if not warp.types.type_matches_template(arg_types[i], template_types[i]):
+                    args_matched = False
+                    break

-
-
+            if args_matched:
+                # instantiate this function with the specified argument types

-
-
+                arg_names = f.input_types.keys()
+                overload_annotations = dict(zip(arg_names, arg_types))

-
-
-
-
+                ovl = shallowcopy(f)
+                ovl.adj = warp.codegen.Adjoint(f.func, overload_annotations)
+                ovl.input_types = overload_annotations
+                ovl.value_func = None
+                ovl.generic_parent = f

-
+                sig = warp.types.get_signature(arg_types, func_name=self.key)
+                self.user_overloads[sig] = ovl

-
+                return ovl

-
-
+        # failed to find overload
+        return None

     def __repr__(self):
         inputs_str = ", ".join([f"{k}: {warp.types.type_repr(v)}" for k, v in self.input_types.items()])
@@ -589,8 +637,7 @@ class Kernel:
         self.module = module

         if key is None:
-
-            self.key = unique_key
+            self.key = warp.codegen.make_full_qualified_name(func)
         else:
             self.key = key

@@ -614,9 +661,15 @@
         # known overloads for generic kernels, indexed by type signature
         self.overloads = {}

+        # generic kernel that was used to instantiate this overload
+        self.generic_parent = None
+
         # argument indices by name
         self.arg_indices = {a.label: i for i, a in enumerate(self.adj.args)}

+        # hash will be computed when the module is built
+        self.hash = None
+
         if self.module:
             self.module.register_kernel(self)

@@ -664,10 +717,11 @@
         ovl.is_generic = False
         ovl.overloads = {}
         ovl.sig = sig
+        ovl.generic_parent = self

         self.overloads[sig] = ovl

-        self.module.
+        self.module.mark_modified()

         return ovl

@@ -676,10 +730,13 @@
         return self.overloads.get(sig)

     def get_mangled_name(self):
-        if self.
-
-
-
+        if self.hash is None:
+            raise RuntimeError(f"Missing hash for kernel {self.key} in module {self.module.name}")
+
+        # TODO: allow customizing the number of hash characters used
+        hash_suffix = self.hash.hex()[:8]
+
+        return f"{self.key}_{hash_suffix}"


 # ----------------------
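The mangled name is simply the kernel key plus the first eight hex characters of its hash; a standalone illustration (the hashed bytes and resulting suffix below are arbitrary):

import hashlib

kernel_key = "my_module.my_kernel"
kernel_hash = hashlib.sha256(b"<kernel key + adjoint hash>").digest()

mangled = f"{kernel_key}_{kernel_hash.hex()[:8]}"
# e.g. "my_module.my_kernel_1f2e3d4c" -- stable as long as the kernel's
# source, argument types, and referenced symbols are unchanged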
@@ -689,9 +746,11 @@
 def func(f):
     name = warp.codegen.make_full_qualified_name(f)

+    scope_locals = inspect.currentframe().f_back.f_locals
+
     m = get_module(f.__module__)
     Function(
-        func=f, key=name, namespace="", module=m, value_func=None
+        func=f, key=name, namespace="", module=m, value_func=None, scope_locals=scope_locals
     )  # value_type not known yet, will be inferred during Adjoint.build()

     # use the top of the list of overloads for this key
@@ -705,6 +764,8 @@ def func_native(snippet, adj_snippet=None, replay_snippet=None):
     Decorator to register native code snippet, @func_native
     """

+    scope_locals = inspect.currentframe().f_back.f_locals
+
     def snippet_func(f):
         name = warp.codegen.make_full_qualified_name(f)

@@ -717,6 +778,7 @@ def func_native(snippet, adj_snippet=None, replay_snippet=None):
             native_snippet=snippet,
             adj_native_snippet=adj_snippet,
             replay_snippet=replay_snippet,
+            scope_locals=scope_locals,
         )  # value_type not known yet, will be inferred during Adjoint.build()
         g = m.functions[name]
         # copy over the function attributes, including docstring
@@ -783,6 +845,7 @@ def func_grad(forward_fn):
     f.custom_grad_func = Function(
         grad_fn,
         key=f.key,
+        native_func=f.native_func,
         namespace=f.namespace,
         input_types=reverse_args,
         value_func=None,
@@ -941,7 +1004,7 @@ def overload(kernel, arg_types=None):
         # ensure this function name corresponds to a kernel
         fn = kernel
         module = get_module(fn.__module__)
-        kernel = module.
+        kernel = module.find_kernel(fn)
         if kernel is None:
             raise RuntimeError(f"Failed to find a kernel named '{fn.__name__}' in module {fn.__module__}")

@@ -1277,7 +1340,6 @@ def get_module(name):
             # clear out old kernels, funcs, struct definitions
             old_module.kernels = {}
             old_module.functions = {}
-            old_module.constants = {}
             old_module.structs = {}
             old_module.loader = parent_loader

@@ -1289,30 +1351,202 @@ def get_module(name):
     return user_modules[name]


+# ModuleHasher computes the module hash based on all the kernels, module options,
+# and build configuration. For each kernel, it computes a deep hash by recursively
+# hashing all referenced functions, structs, and constants, even those defined in
+# other modules. The module hash is computed in the constructor and can be retrieved
+# using get_module_hash(). In addition, the ModuleHasher takes care of filtering out
+# duplicate kernels for codegen (see get_unique_kernels()).
+class ModuleHasher:
+    def __init__(self, module):
+        # cache function hashes to avoid hashing multiple times
+        self.function_hashes = {}  # (function: hash)
+
+        # avoid recursive spiral of doom (e.g., function calling an overload of itself)
+        self.functions_in_progress = set()
+
+        # all unique kernels for codegen, filtered by hash
+        self.unique_kernels = {}  # (hash: kernel)
+
+        # start hashing the module
+        ch = hashlib.sha256()
+
+        # hash all non-generic kernels
+        for kernel in module.live_kernels:
+            if kernel.is_generic:
+                for ovl in kernel.overloads.values():
+                    if not ovl.adj.skip_build:
+                        ovl.hash = self.hash_kernel(ovl)
+            else:
+                if not kernel.adj.skip_build:
+                    kernel.hash = self.hash_kernel(kernel)
+
+        # include all unique kernels in the module hash
+        for kernel_hash in sorted(self.unique_kernels.keys()):
+            ch.update(kernel_hash)
+
+        # configuration parameters
+        for opt in sorted(module.options.keys()):
+            s = f"{opt}:{module.options[opt]}"
+            ch.update(bytes(s, "utf-8"))
+
+        # ensure to trigger recompilation if flags affecting kernel compilation are changed
+        if warp.config.verify_fp:
+            ch.update(bytes("verify_fp", "utf-8"))
+
+        # build config
+        ch.update(bytes(warp.config.mode, "utf-8"))
+
+        # save the module hash
+        self.module_hash = ch.digest()
+
+    def hash_kernel(self, kernel):
+        # NOTE: We only hash non-generic kernels, so we don't traverse kernel overloads here.
+
+        ch = hashlib.sha256()
+
+        ch.update(bytes(kernel.key, "utf-8"))
+        ch.update(self.hash_adjoint(kernel.adj))
+
+        h = ch.digest()
+
+        self.unique_kernels[h] = kernel
+
+        return h
+
+    def hash_function(self, func):
+        # NOTE: This method hashes all possible overloads that a function call could resolve to.
+        # The exact overload will be resolved at build time, when the argument types are known.
+
+        h = self.function_hashes.get(func)
+        if h is not None:
+            return h
+
+        self.functions_in_progress.add(func)
+
+        ch = hashlib.sha256()
+
+        ch.update(bytes(func.key, "utf-8"))
+
+        # include all concrete and generic overloads
+        overloads = {**func.user_overloads, **func.user_templates}
+        for sig in sorted(overloads.keys()):
+            ovl = overloads[sig]
+
+            # skip instantiations of generic functions
+            if ovl.generic_parent is not None:
+                continue
+
+            # adjoint
+            ch.update(self.hash_adjoint(ovl.adj))
+
+            # custom bits
+            if ovl.custom_grad_func:
+                ch.update(self.hash_adjoint(ovl.custom_grad_func.adj))
+            if ovl.custom_replay_func:
+                ch.update(self.hash_adjoint(ovl.custom_replay_func.adj))
+            if ovl.replay_snippet:
+                ch.update(bytes(ovl.replay_snippet, "utf-8"))
+            if ovl.native_snippet:
+                ch.update(bytes(ovl.native_snippet, "utf-8"))
+            if ovl.adj_native_snippet:
+                ch.update(bytes(ovl.adj_native_snippet, "utf-8"))
+
+        h = ch.digest()
+
+        self.function_hashes[func] = h
+
+        self.functions_in_progress.remove(func)
+
+        return h
+
+    def hash_adjoint(self, adj):
+        # NOTE: We don't cache adjoint hashes, because adjoints are always unique.
+        # Even instances of generic kernels and functions have unique adjoints with
+        # different argument types.
+
+        ch = hashlib.sha256()
+
+        # source
+        ch.update(bytes(adj.source, "utf-8"))
+
+        # args
+        for arg, arg_type in adj.arg_types.items():
+            s = f"{arg}:{warp.types.get_type_code(arg_type)}"
+            ch.update(bytes(s, "utf-8"))
+
+            # hash struct types
+            if isinstance(arg_type, warp.codegen.Struct):
+                ch.update(arg_type.hash)
+            elif warp.types.is_array(arg_type) and isinstance(arg_type.dtype, warp.codegen.Struct):
+                ch.update(arg_type.dtype.hash)
+
+        # find referenced constants, types, and functions
+        constants, types, functions = adj.get_references()
+
+        # hash referenced constants
+        for name, value in constants.items():
+            ch.update(bytes(name, "utf-8"))
+            # hash the referenced object
+            if isinstance(value, builtins.bool):
+                # This needs to come before the check for `int` since all boolean
+                # values are also instances of `int`.
+                ch.update(struct_pack("?", value))
+            elif isinstance(value, int):
+                ch.update(struct_pack("<q", value))
+            elif isinstance(value, float):
+                ch.update(struct_pack("<d", value))
+            elif isinstance(value, warp.types.float16):
+                # float16 is a special case
+                p = ctypes.pointer(ctypes.c_float(value.value))
+                ch.update(p.contents)
+            elif isinstance(value, tuple(warp.types.scalar_types)):
+                p = ctypes.pointer(value._type_(value.value))
+                ch.update(p.contents)
+            elif isinstance(value, ctypes.Array):
+                ch.update(bytes(value))
+            else:
+                raise RuntimeError(f"Invalid constant type: {type(value)}")
+
+        # hash wp.static() expressions that were evaluated at declaration time
+        for k, v in adj.static_expressions.items():
+            ch.update(bytes(f"{k} = {v}", "utf-8"))
+
+        # hash referenced types
+        for t in types.keys():
+            ch.update(bytes(warp.types.get_type_code(t), "utf-8"))
+
+        # hash referenced functions
+        for f in functions.keys():
+            if f not in self.functions_in_progress:
+                ch.update(self.hash_function(f))
+
+        return ch.digest()
+
+    def get_module_hash(self):
+        return self.module_hash
+
+    def get_unique_kernels(self):
+        return self.unique_kernels.values()
+
+
 class ModuleBuilder:
-    def __init__(self, module, options):
+    def __init__(self, module, options, hasher=None):
         self.functions = {}
         self.structs = {}
         self.options = options
         self.module = module
         self.deferred_functions = []

-
-
-            for f in func.user_overloads.values():
-                self.build_function(f)
-                if f.custom_replay_func is not None:
-                    self.build_function(f.custom_replay_func)
-
-        # build all kernel entry points
-        for kernel in module.kernels.values():
-            if not kernel.is_generic:
-                self.build_kernel(kernel)
-            else:
-                for k in kernel.overloads.values():
-                    self.build_kernel(k)
+        if hasher is None:
+            hasher = ModuleHasher(module)

-        # build all
+        # build all unique kernels
+        self.kernels = hasher.get_unique_kernels()
+        for kernel in self.kernels:
+            self.build_kernel(kernel)
+
+        # build deferred functions
         for func in self.deferred_functions:
             self.build_function(func)

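The memoization (function_hashes) and recursion guard (functions_in_progress) used by ModuleHasher generalize to any dependency graph that may contain cycles. A self-contained sketch of the same pattern, with illustrative names that are not part of Warp's API:

import hashlib
from dataclasses import dataclass, field

@dataclass
class Node:
    name: str
    source: str
    deps: list = field(default_factory=list)

def deep_hash(node, cache, in_progress):
    h = cache.get(node.name)
    if h is not None:
        return h  # memoized, like function_hashes above
    in_progress.add(node.name)
    ch = hashlib.sha256(node.source.encode("utf-8"))
    for dep in node.deps:
        if dep.name not in in_progress:  # break cycles, like functions_in_progress
            ch.update(deep_hash(dep, cache, in_progress))
    h = ch.digest()
    cache[node.name] = h
    in_progress.discard(node.name)
    return h

a = Node("a", "def a(): ...")
b = Node("b", "def b(): ...", deps=[a])
a.deps.append(b)  # cycle: a <-> b
digest = deep_hash(a, {}, set())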
@@ -1328,7 +1562,7 @@ class ModuleBuilder:
                 for var in s.vars.values():
                     if isinstance(var.type, warp.codegen.Struct):
                         stack.append(var.type)
-                    elif
+                    elif warp.types.is_array(var.type) and isinstance(var.type.dtype, warp.codegen.Struct):
                         stack.append(var.type.dtype)

         # Build them in reverse to generate a correct dependency order.
@@ -1374,8 +1608,12 @@
         source = ""

         # code-gen structs
+        visited_structs = set()
         for struct in self.structs.keys():
-
+            # avoid emitting duplicates
+            if struct.hash not in visited_structs:
+                source += warp.codegen.codegen_struct(struct)
+                visited_structs.add(struct.hash)

         # code-gen all imported functions
         for func in self.functions.keys():
@@ -1386,21 +1624,15 @@
             else:
                 source += warp.codegen.codegen_snippet(
                     func.adj,
-                    name=func.
+                    name=func.native_func,
                     snippet=func.native_snippet,
                     adj_snippet=func.adj_native_snippet,
                     replay_snippet=func.replay_snippet,
                 )

-        for kernel in self.
-
-
-                source += warp.codegen.codegen_kernel(kernel, device=device, options=self.options)
-                source += warp.codegen.codegen_module(kernel, device=device)
-            else:
-                for k in kernel.overloads.values():
-                    source += warp.codegen.codegen_kernel(k, device=device, options=self.options)
-                    source += warp.codegen.codegen_module(k, device=device)
+        for kernel in self.kernels:
+            source += warp.codegen.codegen_kernel(kernel, device=device, options=self.options)
+            source += warp.codegen.codegen_module(kernel, device=device)

         # add headers
         if device == "cpu":
@@ -1425,8 +1657,9 @@ class ModuleExec:
         instance.handle = None
         return instance

-    def __init__(self, handle, device):
+    def __init__(self, handle, module_hash, device):
         self.handle = handle
+        self.module_hash = module_hash
         self.device = device
         self.kernel_hooks = {}

@@ -1457,8 +1690,12 @@
             )
         else:
             func = ctypes.CFUNCTYPE(None)
-            forward =
-
+            forward = (
+                func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_forward").encode("utf-8"))) or None
+            )
+            backward = (
+                func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8"))) or None
+            )

         hooks = KernelHooks(forward, backward)
         self.kernel_hooks[kernel] = hooks
@@ -1475,16 +1712,33 @@ class Module:
         self.name = name
         self.loader = loader

-
-        self.
-        self.
-        self.structs = {}
+        # lookup the latest versions of kernels, functions, and structs by key
+        self.kernels = {}  # (key: kernel)
+        self.functions = {}  # (key: function)
+        self.structs = {}  # (key: struct)
+
+        # Set of all "live" kernels in this module.
+        # The difference between `live_kernels` and `kernels` is that `live_kernels` may contain
+        # multiple kernels with the same key (which is essential to support closures), while `kernels`
+        # only holds the latest kernel for each key. When the module is built, we compute the hash
+        # of each kernel in `live_kernels` and filter out duplicates for codegen.
+        self.live_kernels = weakref.WeakSet()

-
-        self.
+        # executable modules currently loaded
+        self.execs = {}  # (device.context: ModuleExec)

-
-        self.
+        # set of device contexts where the build has failed
+        self.failed_builds = set()
+
+        # hash data, including the module hash
+        self.hasher = None
+
+        # LLVM executable modules are identified using strings. Since it's possible for multiple
+        # executable versions to be loaded at the same time, we need a way to ensure uniqueness.
+        # A unique handle is created from the module name and this auto-incremented integer id.
+        # NOTE: The module hash is not sufficient for uniqueness in rare cases where a module
+        # is retained and later reloaded with the same hash.
+        self.cpu_exec_id = 0

         self.options = {
             "max_unroll": warp.config.max_unroll,
@@ -1497,11 +1751,6 @@
         # Module dependencies are determined by scanning each function
         # and kernel for references to external functions and structs.
         #
-        # When a referenced module is modified, all of its dependents need to be reloaded
-        # on the next launch. To detect this, a module's hash recursively includes
-        # all of its references.
-        # -> See ``Module.hash_module()``
-        #
         # The dependency mechanism works for both static and dynamic (runtime) modifications.
         # When a module is reloaded at runtime, we recursively unload all of its
         # dependents, so that they will be re-hashed and reloaded on the next launch.
@@ -1510,40 +1759,39 @@
         self.references = set()  # modules whose content we depend on
         self.dependents = set()  # modules that depend on our content

-        # Since module hashing is recursive, we improve performance by caching the hash of the
-        # module contents (kernel source, function source, and struct source).
-        # After all kernels, functions, and structs are added to the module (usually at import time),
-        # the content hash doesn't change.
-        # -> See ``Module.hash_module_recursive()``
-
-        self.content_hash = None
-
-        # number of times module auto-generates kernel key for user
-        # used to ensure unique kernel keys
-        self.count = 0
-
     def register_struct(self, struct):
         self.structs[struct.key] = struct

         # for a reload of module on next launch
-        self.
+        self.mark_modified()

     def register_kernel(self, kernel):
+        # keep a reference to the latest version
         self.kernels[kernel.key] = kernel

+        # track all kernel objects, even if they are duplicates
+        self.live_kernels.add(kernel)
+
         self.find_references(kernel.adj)

         # for a reload of module on next launch
-        self.
+        self.mark_modified()

-    def register_function(self, func, skip_adding_overload=False):
-
-
+    def register_function(self, func, scope_locals, skip_adding_overload=False):
+        # check for another Function with the same name in the same scope
+        obj = scope_locals.get(func.func.__name__)
+        if isinstance(obj, Function):
+            func_existing = obj
         else:
+            func_existing = None
+
+        # keep a reference to the latest version
+        self.functions[func.key] = func_existing or func
+
+        if func_existing:
             # Check whether the new function's signature match any that has
             # already been registered. If so, then we simply override it, as
             # Python would do it, otherwise we register it as a new overload.
-            func_existing = self.functions[func.key]
             sig = warp.types.get_signature(
                 func.input_types.values(),
                 func_name=func.key,
@@ -1555,19 +1803,43 @@
                 arg_names=list(func_existing.input_types.keys()),
             )
             if sig == sig_existing:
+                # replace the top-level function, but keep existing overloads
+
+                # copy generic overloads
+                func.user_templates = func_existing.user_templates.copy()
+
+                # copy concrete overloads
+                if warp.types.is_generic_signature(sig):
+                    # skip overloads that were instantiated from the function being replaced
+                    for k, v in func_existing.user_overloads.items():
+                        if v.generic_parent != func_existing:
+                            func.user_overloads[k] = v
+                    func.user_templates[sig] = func
+                else:
+                    func.user_overloads = func_existing.user_overloads.copy()
+                    func.user_overloads[sig] = func
+
                 self.functions[func.key] = func
             elif not skip_adding_overload:
+                # check if this is a generic overload that replaces an existing one
+                if warp.types.is_generic_signature(sig):
+                    old_generic = func_existing.user_templates.get(sig)
+                    if old_generic is not None:
+                        # purge any concrete overloads that were instantiated from the old one
+                        for k, v in list(func_existing.user_overloads.items()):
+                            if v.generic_parent == old_generic:
+                                del func_existing.user_overloads[k]
                 func_existing.add_overload(func)

         self.find_references(func.adj)

         # for a reload of module on next launch
-        self.
+        self.mark_modified()

-
-
-
-        return
+    # find kernel corresponding to a Python function
+    def find_kernel(self, func):
+        qualname = warp.codegen.make_full_qualified_name(func)
+        return self.kernels.get(qualname)

     # collect all referenced functions / structs
     # given the AST of a function or kernel
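At the user level, this replacement logic means redefining a @wp.func with an identical signature replaces that overload while keeping the others, and a new signature adds an overload. A hedged sketch (the function below is illustrative):

import warp as wp

@wp.func
def measure(x: float) -> float:
    return x * x

@wp.func
def measure(x: wp.vec2) -> float:  # new signature: registered as an overload
    return x[0] * x[1]

@wp.func
def measure(x: float) -> float:  # same signature: replaces the float overload;
    return 2.0 * x               # the vec2 overload is carried over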
@@ -1599,165 +1871,30 @@
             if isinstance(arg.type, warp.codegen.Struct) and arg.type.module is not None:
                 add_ref(arg.type.module)

-    def hash_module(self
-
-
-
-        computed ``content_hash`` will be used.
-        """
-
-        def get_type_name(type_hint) -> str:
-            if isinstance(type_hint, warp.codegen.Struct):
-                return get_type_name(type_hint.cls)
-            elif isinstance(type_hint, warp.array) and isinstance(type_hint.dtype, warp.codegen.Struct):
-                return f"array{get_type_name(type_hint.dtype)}"
-
-            return str(type_hint)
-
-        def hash_recursive(module, visited):
-            # Hash this module, including all referenced modules recursively.
-            # The visited set tracks modules already visited to avoid circular references.
-
-            # check if we need to update the content hash
-            if not module.content_hash or recompute_content_hash:
-                # recompute content hash
-                ch = hashlib.sha256()
-
-                # Start with an empty constants dictionary in case any have been removed
-                module.constants = {}
-
-                # struct source
-                for struct in module.structs.values():
-                    s = ",".join(
-                        "{}: {}".format(name, get_type_name(type_hint))
-                        for name, type_hint in warp.codegen.get_annotations(struct.cls).items()
-                    )
-                    ch.update(bytes(s, "utf-8"))
-
-                # functions source
-                for function in module.functions.values():
-                    # include all concrete and generic overloads
-                    overloads = itertools.chain(function.user_overloads.items(), function.user_templates.items())
-                    for sig, func in overloads:
-                        # signature
-                        ch.update(bytes(sig, "utf-8"))
-
-                        # source
-                        ch.update(bytes(func.adj.source, "utf-8"))
-
-                        if func.custom_grad_func:
-                            ch.update(bytes(func.custom_grad_func.adj.source, "utf-8"))
-                        if func.custom_replay_func:
-                            ch.update(bytes(func.custom_replay_func.adj.source, "utf-8"))
-                        if func.replay_snippet:
-                            ch.update(bytes(func.replay_snippet, "utf-8"))
-                        if func.native_snippet:
-                            ch.update(bytes(func.native_snippet, "utf-8"))
-                        if func.adj_native_snippet:
-                            ch.update(bytes(func.adj_native_snippet, "utf-8"))
-
-                        # Populate constants referenced in this function
-                        if func.adj:
-                            module.constants.update(func.adj.get_constant_references())
-
-                # kernel source
-                for kernel in module.kernels.values():
-                    ch.update(bytes(kernel.key, "utf-8"))
-                    ch.update(bytes(kernel.adj.source, "utf-8"))
-                    # cache kernel arg types
-                    for arg, arg_type in kernel.adj.arg_types.items():
-                        s = f"{arg}: {get_type_name(arg_type)}"
-                        ch.update(bytes(s, "utf-8"))
-                    # for generic kernels the Python source is always the same,
-                    # but we hash the type signatures of all the overloads
-                    if kernel.is_generic:
-                        for sig in sorted(kernel.overloads.keys()):
-                            ch.update(bytes(sig, "utf-8"))
-
-                    # Populate constants referenced in this kernel
-                    module.constants.update(kernel.adj.get_constant_references())
-
-                # constants referenced in this module
-                for constant_name, constant_value in module.constants.items():
-                    ch.update(bytes(constant_name, "utf-8"))
-
-                    # hash the constant value
-                    if isinstance(constant_value, builtins.bool):
-                        # This needs to come before the check for `int` since all boolean
-                        # values are also instances of `int`.
-                        ch.update(struct_pack("?", constant_value))
-                    elif isinstance(constant_value, int):
-                        ch.update(struct_pack("<q", constant_value))
-                    elif isinstance(constant_value, float):
-                        ch.update(struct_pack("<d", constant_value))
-                    elif isinstance(constant_value, warp.types.float16):
-                        # float16 is a special case
-                        p = ctypes.pointer(ctypes.c_float(constant_value.value))
-                        ch.update(p.contents)
-                    elif isinstance(constant_value, tuple(warp.types.scalar_types)):
-                        p = ctypes.pointer(constant_value._type_(constant_value.value))
-                        ch.update(p.contents)
-                    elif isinstance(constant_value, ctypes.Array):
-                        ch.update(bytes(constant_value))
-                    else:
-                        raise RuntimeError(f"Invalid constant type: {type(constant_value)}")
-
-                module.content_hash = ch.digest()
-
-            h = hashlib.sha256()
-
-            # content hash
-            h.update(module.content_hash)
-
-            # configuration parameters
-            for k in sorted(module.options.keys()):
-                s = f"{k}={module.options[k]}"
-                h.update(bytes(s, "utf-8"))
-
-            # ensure to trigger recompilation if flags affecting kernel compilation are changed
-            if warp.config.verify_fp:
-                h.update(bytes("verify_fp", "utf-8"))
-
-            h.update(bytes(warp.config.mode, "utf-8"))
-
-            # recurse on references
-            visited.add(module)
-
-            sorted_deps = sorted(module.references, key=lambda m: m.name)
-            for dep in sorted_deps:
-                if dep not in visited:
-                    dep_hash = hash_recursive(dep, visited)
-                    h.update(dep_hash)
-
-            return h.digest()
-
-        return hash_recursive(self, visited=set())
+    def hash_module(self):
+        # compute latest hash
+        self.hasher = ModuleHasher(self)
+        return self.hasher.get_module_hash()

     def load(self, device) -> ModuleExec:
         device = runtime.get_device(device)

-        if
-
-
-
-
-
-        if
-
-
-
-
-                return cuda_exec
-        # avoid repeated build attempts
-        if self.cuda_build_failed:
-            return None
-        if not warp.is_cuda_available():
-            raise RuntimeError("Failed to build CUDA module because CUDA is not available")
+        # compute the hash if needed
+        if self.hasher is None:
+            self.hasher = ModuleHasher(self)
+
+        # check if executable module is already loaded and not stale
+        exec = self.execs.get(device.context)
+        if exec is not None:
+            if exec.module_hash == self.hasher.module_hash:
+                return exec
+
+        # quietly avoid repeated build attempts to reduce error spew
+        if device.context in self.failed_builds:
+            return None

         module_name = "wp_" + self.name
-        module_hash = self.
+        module_hash = self.hasher.module_hash

         # use a unique module path using the module short hash
         module_dir = os.path.join(warp.config.kernel_cache_dir, f"{module_name}_{module_hash.hex()[:7]}")
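The staleness check above is what lets a redefined kernel take effect on the next launch without restarting the process. A hedged sketch of the user-visible behavior:

import warp as wp

out = wp.zeros(1, dtype=float)

@wp.kernel
def fill(a: wp.array(dtype=float)):
    a[0] = 1.0

wp.launch(fill, dim=1, inputs=[out])  # hashes, builds, and loads the module

@wp.kernel
def fill(a: wp.array(dtype=float)):  # redefinition marks the module modified
    a[0] = 2.0

wp.launch(fill, dim=1, inputs=[out])  # rehash -> stale exec -> rebuild and reload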
@@ -1807,7 +1944,7 @@
             or not warp.config.cache_kernels
             or warp.config.verify_autograd_array_access
         ):
-            builder = ModuleBuilder(self, self.options)
+            builder = ModuleBuilder(self, self.options, hasher=self.hasher)

             # create a temporary (process unique) dir for build outputs before moving to the binary dir
             build_dir = os.path.join(
@@ -1844,7 +1981,7 @@
                     )

                 except Exception as e:
-                    self.
+                    self.failed_builds.add(None)
                     module_load_timer.extra_msg = " (error)"
                     raise (e)

@@ -1873,7 +2010,7 @@
                     )

                 except Exception as e:
-                    self.
+                    self.failed_builds.add(device.context)
                     module_load_timer.extra_msg = " (error)"
                     raise (e)

@@ -1914,15 +2051,18 @@
         # -----------------------------------------------------------
         # Load CPU or CUDA binary
         if device.is_cpu:
-
-
-            self.
+            # LLVM modules are identified using strings, so we need to ensure uniqueness
+            module_handle = f"{module_name}_{self.cpu_exec_id}"
+            self.cpu_exec_id += 1
+            runtime.llvm.load_obj(binary_path.encode("utf-8"), module_handle.encode("utf-8"))
+            module_exec = ModuleExec(module_handle, module_hash, device)
+            self.execs[None] = module_exec

         elif device.is_cuda:
             cuda_module = warp.build.load_cuda(binary_path, device)
             if cuda_module is not None:
-                module_exec = ModuleExec(cuda_module, device)
-                self.
+                module_exec = ModuleExec(cuda_module, module_hash, device)
+                self.execs[device.context] = module_exec
             else:
                 module_load_timer.extra_msg = " (error)"
                 raise Exception(f"Failed to load CUDA module '{self.name}'")
@@ -1936,20 +2076,22 @@
         return module_exec

     def unload(self):
+        # force rehashing on next load
+        self.mark_modified()
+
         # clear loaded modules
-        self.
-
+        self.execs = {}
+
+    def mark_modified(self):
+        # clear hash data
+        self.hasher = None

-        # clear
-        self.
+        # clear build failures
+        self.failed_builds = set()

     # lookup kernel entry points based on name, called after compilation / module load
     def get_kernel_hooks(self, kernel, device):
-
-            module_exec = self.cuda_execs.get(device.context)
-        else:
-            module_exec = self.cpu_exec
-
+        module_exec = self.execs.get(device.context)
         if module_exec is not None:
             return module_exec.get_kernel_hooks(kernel)
         else:
@@ -2056,6 +2198,63 @@ class ContextGuard:
             runtime.core.cuda_context_set_current(self.saved_context)


+class Event:
+    """A CUDA event that can be recorded onto a stream.
+
+    Events can be used for device-side synchronization, which do not block
+    the host thread.
+    """
+
+    # event creation flags
+    class Flags:
+        DEFAULT = 0x0
+        BLOCKING_SYNC = 0x1
+        DISABLE_TIMING = 0x2
+
+    def __new__(cls, *args, **kwargs):
+        """Creates a new event instance."""
+        instance = super(Event, cls).__new__(cls)
+        instance.owner = False
+        return instance
+
+    def __init__(self, device: "Devicelike" = None, cuda_event=None, enable_timing: bool = False):
+        """Initializes the event on a CUDA device.
+
+        Args:
+            device: The CUDA device whose streams this event may be recorded onto.
+                If ``None``, then the current default device will be used.
+            cuda_event: A pointer to a previously allocated CUDA event. If
+                `None`, then a new event will be allocated on the associated device.
+            enable_timing: If ``True`` this event will record timing data.
+                :func:`~warp.get_event_elapsed_time` can be used to measure the
+                time between two events created with ``enable_timing=True`` and
+                recorded onto streams.
+        """
+
+        device = get_device(device)
+        if not device.is_cuda:
+            raise RuntimeError(f"Device {device} is not a CUDA device")
+
+        self.device = device
+
+        if cuda_event is not None:
+            self.cuda_event = cuda_event
+        else:
+            flags = Event.Flags.DEFAULT
+            if not enable_timing:
+                flags |= Event.Flags.DISABLE_TIMING
+            self.cuda_event = runtime.core.cuda_event_create(device.context, flags)
+            if not self.cuda_event:
+                raise RuntimeError(f"Failed to create event on device {device}")
+            self.owner = True
+
+    def __del__(self):
+        if not self.owner:
+            return
+
+        runtime.core.cuda_event_destroy(self.cuda_event)
+
+
 class Stream:
     def __new__(cls, *args, **kwargs):
         instance = super(Stream, cls).__new__(cls)
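A usage sketch for the relocated Event class, pairing enable_timing with get_event_elapsed_time as the new docstring suggests (assumes an available CUDA device):

import warp as wp

wp.init()

stream = wp.get_stream("cuda:0")
start = wp.Event("cuda:0", enable_timing=True)
stop = wp.Event("cuda:0", enable_timing=True)

stream.record_event(start)
# ... launch kernels on this stream ...
stream.record_event(stop)
wp.synchronize_device("cuda:0")

elapsed_ms = wp.get_event_elapsed_time(start, stop)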
@@ -2063,7 +2262,27 @@ class Stream:
         instance.owner = False
         return instance

-    def __init__(self, device=None, **kwargs):
+    def __init__(self, device: Optional[Union["Device", str]] = None, priority: int = 0, **kwargs):
+        """Initialize the stream on a device with an optional specified priority.
+
+        Args:
+            device: The CUDA device on which this stream will be created.
+            priority: An optional integer specifying the requested stream priority.
+                Can be -1 (high priority) or 0 (low/default priority).
+                Values outside this range will be clamped.
+            cuda_stream (int): A optional external stream handle passed as an
+                integer. The caller is responsible for ensuring that the external
+                stream does not get destroyed while it is referenced by this
+                object.
+
+        Raises:
+            RuntimeError: If function is called before Warp has completed
+                initialization with a ``device`` that is not an instance of
+                :class:`Device``.
+            RuntimeError: ``device`` is not a CUDA Device.
+            RuntimeError: The stream could not be created on the device.
+            TypeError: The requested stream priority is not an integer.
+        """
         # event used internally for synchronization (cached to avoid creating temporary events)
         self._cached_event = None

@@ -2072,7 +2291,7 @@
             device = runtime.get_device(device)
         elif not isinstance(device, Device):
             raise RuntimeError(
-                "A
+                "A Device object is required when creating a stream before or during Warp initialization"
             )

         if not device.is_cuda:
@@ -2085,7 +2304,11 @@
             self.cuda_stream = kwargs["cuda_stream"]
             device.runtime.core.cuda_stream_register(device.context, self.cuda_stream)
         else:
-
+            if not isinstance(priority, int):
+                raise TypeError("Stream priority must be an integer.")
+            clamped_priority = max(-1, min(priority, 0))  # Only support two priority levels
+            self.cuda_stream = device.runtime.core.cuda_stream_create(device.context, clamped_priority)
+
         if not self.cuda_stream:
             raise RuntimeError(f"Failed to create stream on device {device}")
         self.owner = True
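A brief usage sketch for the new priority argument; values outside the supported range are clamped, as the code above shows (assumes an available CUDA device):

import warp as wp

wp.init()

high = wp.Stream("cuda:0", priority=-1)     # high priority
default = wp.Stream("cuda:0")               # priority defaults to 0
clamped = wp.Stream("cuda:0", priority=-5)  # clamped to -1

# typically prints -1 0 -1 on devices that support stream priorities
print(high.priority, default.priority, clamped.priority)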
@@ -2100,12 +2323,22 @@
             runtime.core.cuda_stream_unregister(self.device.context, self.cuda_stream)

     @property
-    def cached_event(self):
+    def cached_event(self) -> Event:
         if self._cached_event is None:
             self._cached_event = Event(self.device)
         return self._cached_event

-    def record_event(self, event=None):
+    def record_event(self, event: Optional[Event] = None) -> Event:
+        """Record an event onto the stream.
+
+        Args:
+            event: A warp.Event instance to be recorded onto the stream. If not
+                provided, an :class:`~warp.Event` on the same device will be created.
+
+        Raises:
+            RuntimeError: The provided :class:`~warp.Event` is from a different device than
+                the recording stream.
+        """
         if event is None:
             event = Event(self.device)
         elif event.device != self.device:
@@ -2117,56 +2350,45 @@

         return event

-    def wait_event(self, event):
+    def wait_event(self, event: Event):
+        """Makes all future work in this stream wait until `event` has completed.
+
+        This function does not block the host thread.
+        """
         runtime.core.cuda_stream_wait_event(self.cuda_stream, event.cuda_event)

-    def wait_stream(self, other_stream, event=None):
+    def wait_stream(self, other_stream: "Stream", event: Optional[Event] = None):
+        """Records an event on `other_stream` and makes this stream wait on it.
+
+        All work added to this stream after this function has been called will
+        delay their execution until all preceding commands in `other_stream`
+        have completed.
+
+        This function does not block the host thread.
+
+        Args:
+            other_stream: The stream on which the calling stream will wait for
+                previously issued commands to complete before executing subsequent
+                commands.
+            event: An optional :class:`Event` instance that will be used to
+                record an event onto ``other_stream``. If ``None``, an internally
+                managed :class:`Event` instance will be used.
+        """
+
         if event is None:
             event = other_stream.cached_event

         runtime.core.cuda_stream_wait_stream(self.cuda_stream, other_stream.cuda_stream, event.cuda_event)

-    # whether a graph capture is currently ongoing on this stream
     @property
-    def is_capturing(self):
+    def is_capturing(self) -> bool:
+        """A boolean indicating whether a graph capture is currently ongoing on this stream."""
         return bool(runtime.core.cuda_stream_is_capturing(self.cuda_stream))

-
-
-
-
-        DEFAULT = 0x0
-        BLOCKING_SYNC = 0x1
-        DISABLE_TIMING = 0x2
-
-    def __new__(cls, *args, **kwargs):
-        instance = super(Event, cls).__new__(cls)
-        instance.owner = False
-        return instance
-
-    def __init__(self, device=None, cuda_event=None, enable_timing=False):
-        device = get_device(device)
-        if not device.is_cuda:
-            raise RuntimeError(f"Device {device} is not a CUDA device")
-
-        self.device = device
-
-        if cuda_event is not None:
-            self.cuda_event = cuda_event
-        else:
-            flags = Event.Flags.DEFAULT
-            if not enable_timing:
-                flags |= Event.Flags.DISABLE_TIMING
-            self.cuda_event = runtime.core.cuda_event_create(device.context, flags)
-            if not self.cuda_event:
-                raise RuntimeError(f"Failed to create event on device {device}")
-        self.owner = True
-
-    def __del__(self):
-        if not self.owner:
-            return
-
-        runtime.core.cuda_event_destroy(self.cuda_event)
+    @property
+    def priority(self) -> int:
+        """An integer representing the priority of the stream."""
+        return runtime.core.cuda_stream_get_priority(self.cuda_stream)


 class Device:
@@ -2178,14 +2400,14 @@ class Device:
             or ``"CPU"`` if the processor name cannot be determined.
         arch: An integer representing the compute capability version number calculated as
             ``10 * major + minor``. ``0`` for CPU devices.
-        is_uva: A boolean indicating whether
+        is_uva: A boolean indicating whether the device supports unified addressing.
             ``False`` for CPU devices.
-        is_cubin_supported: A boolean indicating whether
+        is_cubin_supported: A boolean indicating whether Warp's version of NVRTC can directly
             generate CUDA binary files (cubin) for this device's architecture. ``False`` for CPU devices.
-        is_mempool_supported: A boolean indicating whether
+        is_mempool_supported: A boolean indicating whether the device supports using the
             ``cuMemAllocAsync`` and ``cuMemPool`` family of APIs for stream-ordered memory allocations. ``False`` for
             CPU devices.
-        is_primary: A boolean indicating whether
+        is_primary: A boolean indicating whether this device's CUDA context is also the
             device's primary context.
         uuid: A string representing the UUID of the CUDA device. The UUID is in the same format used by
             ``nvidia-smi -L``. ``None`` for CPU devices.
@@ -2274,7 +2496,7 @@ class Device:
 
         # initialize streams unless context acquisition is postponed
         if self._context is not None:
-            self.init_streams()
+            self._init_streams()
 
         # TODO: add more device-specific dispatch functions
         self.memset = lambda ptr, value, size: runtime.core.memset_device(self.context, ptr, value, size)
@@ -2285,7 +2507,13 @@ class Device:
         else:
             raise RuntimeError(f"Invalid device ordinal ({ordinal})'")
 
-    def get_allocator(self, pinned=False):
+    def get_allocator(self, pinned: bool = False):
+        """Get the memory allocator for this device.
+
+        Args:
+            pinned: If ``True``, an allocator for pinned memory will be
+                returned. Only applicable when this device is a CPU device.
+        """
         if self.is_cuda:
             return self.current_allocator
         else:
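Usage note: the pinned allocator selected here is reachable from public code through the `pinned` flag on CPU array creation; a sketch:

    import warp as wp

    wp.init()

    # pinned (page-locked) host memory enables asynchronous host<->device copies
    staging = wp.zeros(1024, dtype=float, device="cpu", pinned=True)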
@@ -2294,7 +2522,8 @@ class Device:
         else:
             return self.default_allocator
 
-    def init_streams(self):
+    def _init_streams(self):
+        """Initializes the device's current stream and the device's null stream."""
         # create a stream for asynchronous work
         self.set_stream(Stream(self))
 
@@ -2302,17 +2531,18 @@ class Device:
         self.null_stream = Stream(self, cuda_stream=None)
 
     @property
-    def is_cpu(self):
-        """A boolean indicating whether
+    def is_cpu(self) -> bool:
+        """A boolean indicating whether the device is a CPU device."""
         return self.ordinal < 0
 
     @property
-    def is_cuda(self):
-        """A boolean indicating whether
+    def is_cuda(self) -> bool:
+        """A boolean indicating whether the device is a CUDA device."""
         return self.ordinal >= 0
 
     @property
-    def is_capturing(self):
+    def is_capturing(self) -> bool:
+        """A boolean indicating whether this device's default stream is currently capturing a graph."""
         if self.is_cuda and self.stream is not None:
             # There is no CUDA API to check if graph capture was started on a device, so we
             # can't tell if a capture was started by external code on a different stream.
@@ -2336,17 +2566,17 @@ class Device:
             raise RuntimeError(f"Failed to acquire primary context for device {self}")
         self.runtime.context_map[self._context] = self
         # initialize streams
-        self.init_streams()
+        self._init_streams()
         runtime.core.cuda_context_set_current(prev_context)
         return self._context
 
     @property
-    def has_context(self):
-        """A boolean indicating whether
+    def has_context(self) -> bool:
+        """A boolean indicating whether the device has a CUDA context associated with it."""
         return self._context is not None
 
     @property
-    def stream(self):
+    def stream(self) -> Stream:
         """The stream associated with a CUDA device.
 
         Raises:
@@ -2361,7 +2591,22 @@ class Device:
     def stream(self, stream):
         self.set_stream(stream)
 
-    def set_stream(self, stream, sync=True):
+    def set_stream(self, stream: Stream, sync: bool = True) -> None:
+        """Set the current stream for this CUDA device.
+
+        The current stream will be used by default for all kernel launches and
+        memory operations on this device.
+
+        If this is an external stream, the caller is responsible for
+        guaranteeing the lifetime of the stream.
+
+        Consider using :class:`warp.ScopedStream` instead.
+
+        Args:
+            stream: The stream to set as this device's current stream.
+            sync: If ``True``, then ``stream`` will perform a device-side
+                synchronization with the device's previous current stream.
+        """
         if self.is_cuda:
             if stream.device != self:
                 raise RuntimeError(f"Stream from device {stream.device} cannot be used on device {self}")
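Usage note: as the new docstring suggests, `wp.ScopedStream` is usually preferable to calling `set_stream` directly, since it restores the previous current stream on exit; a sketch:

    import warp as wp

    wp.init()

    stream = wp.Stream("cuda:0")
    with wp.ScopedStream(stream):
        # launches and copies inside this block target `stream` by default
        pass
    # the device's previous current stream is restored here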
@@ -2372,12 +2617,12 @@ class Device:
             raise RuntimeError(f"Device {self} is not a CUDA device")
 
     @property
-    def has_stream(self):
-        """A boolean indicating whether
+    def has_stream(self) -> bool:
+        """A boolean indicating whether the device has a stream associated with it."""
         return self._stream is not None
 
     @property
-    def total_memory(self):
+    def total_memory(self) -> int:
         """The total amount of device memory available in bytes.
 
         This function is currently only implemented for CUDA devices. 0 will be returned if called on a CPU device.
@@ -2391,7 +2636,7 @@ class Device:
         return 0
 
     @property
-    def free_memory(self):
+    def free_memory(self) -> int:
         """The amount of memory on the device that is free according to the OS in bytes.
 
         This function is currently only implemented for CUDA devices. 0 will be returned if called on a CPU device.
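Usage note: both properties return plain byte counts (0 on CPU devices), so a capacity check is simple arithmetic; a sketch:

    import warp as wp

    wp.init()

    device = wp.get_device("cuda:0")
    print(f"free: {device.free_memory / 1e9:.1f} GB of {device.total_memory / 1e9:.1f} GB")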
@@ -2755,6 +3000,12 @@ class Runtime:
         self.core.mesh_refit_host.argtypes = [ctypes.c_uint64]
         self.core.mesh_refit_device.argtypes = [ctypes.c_uint64]
 
+        self.core.mesh_set_points_host.argtypes = [ctypes.c_uint64, warp.types.array_t]
+        self.core.mesh_set_points_device.argtypes = [ctypes.c_uint64, warp.types.array_t]
+
+        self.core.mesh_set_velocities_host.argtypes = [ctypes.c_uint64, warp.types.array_t]
+        self.core.mesh_set_velocities_device.argtypes = [ctypes.c_uint64, warp.types.array_t]
+
         self.core.hash_grid_create_host.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
         self.core.hash_grid_create_host.restype = ctypes.c_uint64
         self.core.hash_grid_destroy_host.argtypes = [ctypes.c_uint64]
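Usage note: these bindings suggest that an existing mesh can be re-pointed at new vertex data without rebuilding it from scratch. A sketch, assuming the native entry points back a `points` setter on `wp.Mesh` (the setter is an assumption here; consult the 1.4.0 docs for the exact surface API):

    import warp as wp
    import numpy as np

    wp.init()

    points = wp.array(np.random.rand(3, 3), dtype=wp.vec3, device="cuda:0")
    indices = wp.array([0, 1, 2], dtype=int, device="cuda:0")
    mesh = wp.Mesh(points=points, indices=indices)

    # assumption: assigning a new points array updates the mesh in place
    # and refits its acceleration structure
    mesh.points = wp.array(np.random.rand(3, 3), dtype=wp.vec3, device="cuda:0")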
@@ -3029,7 +3280,7 @@ class Runtime:
         self.core.cuda_set_mempool_access_enabled.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
         self.core.cuda_set_mempool_access_enabled.restype = ctypes.c_int
 
-        self.core.cuda_stream_create.argtypes = [ctypes.c_void_p]
+        self.core.cuda_stream_create.argtypes = [ctypes.c_void_p, ctypes.c_int]
         self.core.cuda_stream_create.restype = ctypes.c_void_p
         self.core.cuda_stream_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
         self.core.cuda_stream_destroy.restype = None
@@ -3047,6 +3298,8 @@
         self.core.cuda_stream_is_capturing.restype = ctypes.c_int
         self.core.cuda_stream_get_capture_id.argtypes = [ctypes.c_void_p]
         self.core.cuda_stream_get_capture_id.restype = ctypes.c_uint64
+        self.core.cuda_stream_get_priority.argtypes = [ctypes.c_void_p]
+        self.core.cuda_stream_get_priority.restype = ctypes.c_int
 
         self.core.cuda_event_create.argtypes = [ctypes.c_void_p, ctypes.c_uint]
         self.core.cuda_event_create.restype = ctypes.c_void_p
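Usage note: every binding in this file follows the same ctypes recipe: declare the foreign function's argument and return types once on the loaded library, then call it like a Python function. A standalone illustration of the pattern (the library and function names here are hypothetical):

    import ctypes

    lib = ctypes.CDLL("example.dll")                      # hypothetical library
    lib.stream_get_priority.argtypes = [ctypes.c_void_p]  # stream handle
    lib.stream_get_priority.restype = ctypes.c_int        # priority value

    # priority = lib.stream_get_priority(stream_handle)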
@@ -3382,10 +3635,10 @@ class Runtime:
 
         raise ValueError(f"Invalid device identifier: {ident}")
 
-    def set_default_device(self, ident: Devicelike):
+    def set_default_device(self, ident: Devicelike) -> None:
         self.default_device = self.get_device(ident)
 
-    def get_current_cuda_device(self):
+    def get_current_cuda_device(self) -> Device:
         current_context = self.core.cuda_context_get_current()
         if current_context is not None:
             current_device = self.context_map.get(current_context)
@@ -3415,7 +3668,7 @@ class Runtime:
         else:
             raise RuntimeError('"cuda" device requested but CUDA is not supported by the hardware or driver')
 
-    def rename_device(self, device, alias):
+    def rename_device(self, device, alias) -> Device:
         del self.device_map[device.alias]
         device.alias = alias
         self.device_map[alias] = device
@@ -3462,7 +3715,7 @@ class Runtime:
 
         return device
 
-    def unmap_cuda_device(self, alias):
+    def unmap_cuda_device(self, alias) -> None:
         device = self.device_map.get(alias)
 
         # make sure the alias refers to a CUDA device
@@ -3473,7 +3726,7 @@ class Runtime:
         del self.context_map[device.context]
         self.cuda_devices.remove(device)
 
-    def verify_cuda_device(self, device: Devicelike = None):
+    def verify_cuda_device(self, device: Devicelike = None) -> None:
         if warp.config.verify_cuda:
             device = runtime.get_device(device)
             if not device.is_cuda:
@@ -3485,13 +3738,13 @@ class Runtime:
 
 
 # global entry points
-def is_cpu_available():
+def is_cpu_available() -> bool:
     init()
 
-    return runtime.llvm
+    return runtime.llvm is not None
 
 
-def is_cuda_available():
+def is_cuda_available() -> bool:
     return get_cuda_device_count() > 0
 
 
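Usage note: with the corrected return value, `is_cpu_available()` now yields a proper boolean, so both entry points can be used directly when picking a device; a sketch:

    import warp as wp

    wp.init()
    device = "cuda:0" if wp.is_cuda_available() else "cpu"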
@@ -3575,8 +3828,8 @@ def get_device(ident: Devicelike = None) -> Device:
     return runtime.get_device(ident)
 
 
-def set_device(ident: Devicelike):
-    """Sets the
+def set_device(ident: Devicelike) -> None:
+    """Sets the default device identified by the argument."""
 
     init()
 
@@ -3604,7 +3857,7 @@ def map_cuda_device(alias: str, context: ctypes.c_void_p = None) -> Device:
     return runtime.map_cuda_device(alias, context)
 
 
-def unmap_cuda_device(alias: str):
+def unmap_cuda_device(alias: str) -> None:
     """Remove a CUDA device with the given alias."""
 
     init()
@@ -3612,7 +3865,7 @@ def unmap_cuda_device(alias: str):
     runtime.unmap_cuda_device(alias)
 
 
-def is_mempool_supported(device: Devicelike):
+def is_mempool_supported(device: Devicelike) -> bool:
     """Check if CUDA memory pool allocators are available on the device."""
 
    init()
@@ -3622,7 +3875,7 @@ def is_mempool_supported(device: Devicelike):
     return device.is_mempool_supported
 
 
-def is_mempool_enabled(device: Devicelike):
+def is_mempool_enabled(device: Devicelike) -> bool:
     """Check if CUDA memory pool allocators are enabled on the device."""
 
     init()
@@ -3632,7 +3885,7 @@ def is_mempool_enabled(device: Devicelike):
     return device.is_mempool_enabled
 
 
-def set_mempool_enabled(device: Devicelike, enable: bool):
+def set_mempool_enabled(device: Devicelike, enable: bool) -> None:
     """Enable or disable CUDA memory pool allocators on the device.
 
     Pooled allocators are typically faster and allow allocating memory during graph capture.
@@ -3663,7 +3916,7 @@ def set_mempool_enabled(device: Devicelike, enable: bool):
         raise ValueError("Memory pools are only supported on CUDA devices")
 
 
-def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, float]):
+def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, float]) -> None:
     """Set the CUDA memory pool release threshold on the device.
 
     This is the amount of reserved memory to hold onto before trying to release memory back to the OS.
@@ -3694,7 +3947,7 @@ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, floa
         raise RuntimeError(f"Failed to set memory pool release threshold for device {device}")
 
 
-def get_mempool_release_threshold(device: Devicelike):
+def get_mempool_release_threshold(device: Devicelike) -> int:
     """Get the CUDA memory pool release threshold on the device."""
 
     init()
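Usage note: together, the mempool helpers enable the pooled allocator and tune how much reserved memory it retains; a sketch (the threshold below is an absolute byte count):

    import warp as wp

    wp.init()

    if wp.is_mempool_supported("cuda:0"):
        wp.set_mempool_enabled("cuda:0", True)
        # hold up to 1 GiB before releasing memory back to the OS
        wp.set_mempool_release_threshold("cuda:0", 1024 * 1024 * 1024)
        print(wp.get_mempool_release_threshold("cuda:0"))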
@@ -3710,7 +3963,7 @@ def get_mempool_release_threshold(device: Devicelike):
     return runtime.core.cuda_device_get_mempool_release_threshold(device.ordinal)
 
 
-def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike):
+def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike) -> bool:
     """Check if `peer_device` can directly access the memory of `target_device` on this system.
 
     This applies to memory allocated using default CUDA allocators. For memory allocated using
@@ -3731,7 +3984,7 @@ def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike)
     return bool(runtime.core.cuda_is_peer_access_supported(target_device.ordinal, peer_device.ordinal))
 
 
-def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike):
+def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike) -> bool:
     """Check if `peer_device` can currently access the memory of `target_device`.
 
     This applies to memory allocated using default CUDA allocators. For memory allocated using
@@ -3752,7 +4005,7 @@ def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike):
     return bool(runtime.core.cuda_is_peer_access_enabled(target_device.context, peer_device.context))
 
 
-def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool):
+def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool) -> None:
     """Enable or disable direct access from `peer_device` to the memory of `target_device`.
 
     Enabling peer access can improve the speed of peer-to-peer memory transfers, but can have
@@ -3784,7 +4037,7 @@ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike,
         raise RuntimeError(f"Failed to {action} peer access from device {peer_device} to device {target_device}")
 
 
-def is_mempool_access_supported(target_device: Devicelike, peer_device: Devicelike):
+def is_mempool_access_supported(target_device: Devicelike, peer_device: Devicelike) -> bool:
     """Check if `peer_device` can directly access the memory pool of `target_device`.
 
     If mempool access is possible, it can be managed using `set_mempool_access_enabled()` and `is_mempool_access_enabled()`.
@@ -3801,7 +4054,7 @@ def is_mempool_access_supported(target_device: Devicelike, peer_device: Deviceli
     return target_device.is_mempool_supported and is_peer_access_supported(target_device, peer_device)
 
 
-def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike):
+def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike) -> bool:
     """Check if `peer_device` can currently access the memory pool of `target_device`.
 
     This applies to memory allocated using CUDA pooled allocators. For memory allocated using
@@ -3822,7 +4075,7 @@ def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike
     return bool(runtime.core.cuda_is_mempool_access_enabled(target_device.ordinal, peer_device.ordinal))
 
 
-def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool):
+def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool) -> None:
     """Enable or disable access from `peer_device` to the memory pool of `target_device`.
 
     This applies to memory allocated using CUDA pooled allocators. For memory allocated using
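Usage note: a sketch of the two-GPU access workflow, using only the functions shown above; pooled and default allocations are managed separately:

    import warp as wp

    wp.init()

    # let cuda:1 read memory that lives on cuda:0
    if wp.is_mempool_access_supported("cuda:0", "cuda:1"):
        wp.set_mempool_access_enabled("cuda:0", "cuda:1", True)  # pooled allocations
    if wp.is_peer_access_supported("cuda:0", "cuda:1"):
        wp.set_peer_access_enabled("cuda:0", "cuda:1", True)     # default allocations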
@@ -3858,26 +4111,41 @@ def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelik
 
 
 def get_stream(device: Devicelike = None) -> Stream:
-    """Return the stream currently used by the given device
+    """Return the stream currently used by the given device.
+
+    Args:
+        device: An optional :class:`Device` instance or device alias
+            (e.g. "cuda:0") for which the current stream will be returned.
+            If ``None``, the default device will be used.
+
+    Raises:
+        RuntimeError: The device is not a CUDA device.
+    """
 
     return get_device(device).stream
 
 
-def set_stream(stream, device: Devicelike = None, sync: bool = False):
-    """
+def set_stream(stream: Stream, device: Devicelike = None, sync: bool = False) -> None:
+    """Convenience function for calling :meth:`Device.set_stream` on the given ``device``.
 
-
-
+    Args:
+        device: An optional :class:`Device` instance or device alias
+            (e.g. "cuda:0") for which the current stream is to be replaced with
+            ``stream``. If ``None``, the default device will be used.
+        stream: The stream to set as this device's current stream.
+        sync: If ``True``, then ``stream`` will perform a device-side
+            synchronization with the device's previous current stream.
     """
 
     get_device(device).set_stream(stream, sync=sync)
 
 
-def record_event(event: Event = None):
-    """
+def record_event(event: Optional[Event] = None):
+    """Convenience function for calling :meth:`Stream.record_event` on the current stream.
 
     Args:
-        event: Event to record. If None
+        event: :class:`Event` instance to record. If ``None``, a new :class:`Event`
+            instance will be created.
 
     Returns:
         The recorded event.
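Usage note: the module-level helpers mirror the `Device` methods; a sketch:

    import warp as wp

    wp.init()

    s = wp.Stream("cuda:0")
    wp.set_stream(s, device="cuda:0", sync=True)
    assert wp.get_stream("cuda:0") is s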
@@ -3887,29 +4155,31 @@ def record_event(event: Event = None):
 
 
 def wait_event(event: Event):
-    """
+    """Convenience function for calling :meth:`Stream.wait_event` on the current stream.
 
     Args:
-        event: Event to wait for.
+        event: :class:`Event` instance to wait for.
     """
 
     get_stream().wait_event(event)
 
 
-def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: bool = True):
+def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: Optional[bool] = True):
     """Get the elapsed time between two recorded events.
 
-
-
-    Both events must have been previously recorded with ``wp.record_event()`` or ``wp.Stream.record_event()``.
+    Both events must have been previously recorded with
+    :func:`~warp.record_event()` or :meth:`warp.Stream.record_event()`.
 
     If ``synchronize`` is False, the caller must ensure that device execution has reached ``end_event``
     prior to calling ``get_event_elapsed_time()``.
 
     Args:
-        start_event
-        end_event
-        synchronize
+        start_event: The start event.
+        end_event: The end event.
+        synchronize: Whether Warp should synchronize on the ``end_event``.
+
+    Returns:
+        The elapsed time in milliseconds with a resolution about 0.5 ms.
     """
 
     # ensure the end_event is reached
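Usage note: timing requires events created with `enable_timing=True` (the flag handled by the `Event` constructor shown earlier in this diff); a sketch with an illustrative kernel:

    import warp as wp

    wp.init()

    @wp.kernel
    def scale(a: wp.array(dtype=float), s: float):
        i = wp.tid()
        a[i] = a[i] * s

    with wp.ScopedDevice("cuda:0"):
        a = wp.zeros(1 << 20, dtype=float)
        start = wp.Event(enable_timing=True)
        end = wp.Event(enable_timing=True)

        wp.record_event(start)
        wp.launch(scale, dim=a.shape, inputs=[a, 2.0])
        wp.record_event(end)

        print(wp.get_event_elapsed_time(start, end), "ms")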
@@ -3919,14 +4189,19 @@ def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: bo
     return runtime.core.cuda_event_elapsed_time(start_event.cuda_event, end_event.cuda_event)
 
 
-def wait_stream(
-    """
+def wait_stream(other_stream: Stream, event: Event = None):
+    """Convenience function for calling :meth:`Stream.wait_stream` on the current stream.
 
     Args:
-
+        other_stream: The stream on which the calling stream will wait for
+            previously issued commands to complete before executing subsequent
+            commands.
+        event: An optional :class:`Event` instance that will be used to
+            record an event onto ``other_stream``. If ``None``, an internally
+            managed :class:`Event` instance will be used.
     """
 
-    get_stream().wait_stream(
+    get_stream().wait_stream(other_stream, event=event)
 
 
 class RegisteredGLBuffer:
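Usage note: a sketch of the documented usage, making the current stream wait for work issued to a secondary stream:

    import warp as wp

    wp.init()

    side = wp.Stream("cuda:0")
    # ... launch work on `side` ...
    wp.wait_stream(side)  # the current stream now waits for `side` to finish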
@@ -4362,7 +4637,7 @@ def from_numpy(
         dtype: The data type of the new Warp array. If this is not provided, the data type will be inferred.
         shape: The shape of the Warp array.
         device: The device on which the Warp array will be constructed.
-        requires_grad: Whether
+        requires_grad: Whether gradients will be tracked for this array.
 
     Raises:
         RuntimeError: The data type of the NumPy array is not supported.
@@ -4398,7 +4673,7 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
         return arg_type.__ctype__()
 
     elif isinstance(value, warp.types.array_t):
-        # accept array descriptors
+        # accept array descriptors verbatim
         return value
 
     else:
@@ -4865,7 +5140,7 @@ def synchronize_device(device: Devicelike = None):
     runtime.core.cuda_context_synchronize(device.context)
 
 
-def synchronize_stream(stream_or_device=None):
+def synchronize_stream(stream_or_device: Union[Stream, Devicelike, None] = None):
     """Synchronize the calling CPU thread with any outstanding CUDA work on the specified stream.
 
     This function allows the host application code to ensure that all kernel launches
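Usage note: the widened annotation documents that either a stream or a device identifier is accepted; a sketch:

    import warp as wp

    wp.init()

    stream = wp.Stream("cuda:0")
    wp.synchronize_stream(stream)    # by Stream instance
    wp.synchronize_stream("cuda:0")  # or by device alias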
@@ -4989,7 +5264,7 @@ def set_module_options(options: Dict[str, Any], module: Optional[Any] = None):
     m = module
 
     get_module(m.__name__).options.update(options)
-    get_module(m.__name__).unload()
+    get_module(m.__name__).mark_modified()
 
 
 def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
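Usage note: changing module options now marks the module as modified (so it is rebuilt on next load) rather than unloading it outright; typical usage is unchanged. `fast_math` and `enable_backward` are standard Warp module options:

    import warp as wp

    # applies to kernels defined in the calling module
    wp.set_module_options({"fast_math": True, "enable_backward": False})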
@@ -5016,7 +5291,7 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=None
     Args:
         device: The CUDA device to capture on
         stream: The CUDA stream to capture on
-        force_module_load: Whether
+        force_module_load: Whether to force loading of all kernels before capture.
             In general it is better to use :func:`~warp.load_module()` to selectively load kernels.
             When running with CUDA drivers that support CUDA 12.3 or newer, this option is not recommended to be set to
             ``True`` because kernels can be loaded during graph capture on more recent drivers. If this argument is
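Usage note: the capture entry points are used as a begin/end pair around regular launches; a sketch following the docstring's advice to leave `force_module_load` off on recent drivers (the `inc` kernel is illustrative):

    import warp as wp

    wp.init()

    @wp.kernel
    def inc(a: wp.array(dtype=float)):
        i = wp.tid()
        a[i] = a[i] + 1.0

    with wp.ScopedDevice("cuda:0"):
        a = wp.zeros(1024, dtype=float)
        wp.capture_begin()
        try:
            wp.launch(inc, dim=a.shape, inputs=[a])
        finally:
            graph = wp.capture_end()
        wp.capture_launch(graph)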
@@ -5574,7 +5849,7 @@ def export_stubs(file): # pragma: no cover
 
         return_str = ""
 
-        if f.hidden or f.generic:
+        if f.hidden: # or f.generic:
            continue
 
        return_type = f.value_func(None, None)