warp_lang-1.3.3-py3-none-manylinux2014_x86_64.whl → warp_lang-1.4.1-py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +6 -0
- warp/autograd.py +59 -6
- warp/bin/warp.so +0 -0
- warp/build_dll.py +8 -10
- warp/builtins.py +103 -3
- warp/codegen.py +447 -53
- warp/config.py +1 -1
- warp/context.py +682 -405
- warp/dlpack.py +2 -0
- warp/examples/benchmarks/benchmark_cloth.py +10 -0
- warp/examples/core/example_render_opengl.py +12 -10
- warp/examples/fem/example_adaptive_grid.py +251 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +2 -2
- warp/examples/fem/example_magnetostatics.py +1 -1
- warp/examples/fem/example_streamlines.py +1 -0
- warp/examples/fem/utils.py +25 -5
- warp/examples/sim/example_cloth.py +50 -6
- warp/fem/__init__.py +2 -0
- warp/fem/adaptivity.py +493 -0
- warp/fem/field/field.py +2 -1
- warp/fem/field/nodal_field.py +18 -26
- warp/fem/field/test.py +4 -4
- warp/fem/field/trial.py +4 -4
- warp/fem/geometry/__init__.py +1 -0
- warp/fem/geometry/adaptive_nanogrid.py +843 -0
- warp/fem/geometry/nanogrid.py +55 -28
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/nanogrid_function_space.py +69 -35
- warp/fem/utils.py +118 -107
- warp/jax_experimental.py +28 -15
- warp/native/array.h +0 -1
- warp/native/builtin.h +103 -6
- warp/native/bvh.cu +4 -2
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/error.cpp +4 -2
- warp/native/exports.h +99 -0
- warp/native/mat.h +97 -0
- warp/native/mesh.cpp +36 -0
- warp/native/mesh.cu +52 -1
- warp/native/mesh.h +1 -0
- warp/native/quat.h +43 -0
- warp/native/range.h +11 -2
- warp/native/spatial.h +6 -0
- warp/native/vec.h +74 -0
- warp/native/warp.cpp +2 -1
- warp/native/warp.cu +10 -3
- warp/native/warp.h +8 -1
- warp/paddle.py +382 -0
- warp/sim/__init__.py +1 -0
- warp/sim/collide.py +519 -0
- warp/sim/integrator_euler.py +18 -5
- warp/sim/integrator_featherstone.py +5 -5
- warp/sim/integrator_vbd.py +1026 -0
- warp/sim/integrator_xpbd.py +2 -6
- warp/sim/model.py +50 -25
- warp/sparse.py +9 -7
- warp/stubs.py +459 -0
- warp/tape.py +2 -0
- warp/tests/aux_test_dependent.py +1 -0
- warp/tests/aux_test_name_clash1.py +32 -0
- warp/tests/aux_test_name_clash2.py +32 -0
- warp/tests/aux_test_square.py +1 -0
- warp/tests/test_array.py +188 -0
- warp/tests/test_async.py +3 -3
- warp/tests/test_atomic.py +6 -0
- warp/tests/test_closest_point_edge_edge.py +93 -1
- warp/tests/test_codegen.py +93 -15
- warp/tests/test_codegen_instancing.py +1457 -0
- warp/tests/test_collision.py +486 -0
- warp/tests/test_compile_consts.py +3 -28
- warp/tests/test_dlpack.py +170 -0
- warp/tests/test_examples.py +22 -8
- warp/tests/test_fast_math.py +10 -4
- warp/tests/test_fem.py +81 -1
- warp/tests/test_func.py +46 -0
- warp/tests/test_implicit_init.py +49 -0
- warp/tests/test_jax.py +58 -0
- warp/tests/test_mat.py +84 -0
- warp/tests/test_mesh_query_point.py +188 -0
- warp/tests/test_model.py +13 -0
- warp/tests/test_module_hashing.py +40 -0
- warp/tests/test_multigpu.py +3 -3
- warp/tests/test_overwrite.py +8 -0
- warp/tests/test_paddle.py +852 -0
- warp/tests/test_print.py +89 -0
- warp/tests/test_quat.py +111 -0
- warp/tests/test_reload.py +31 -1
- warp/tests/test_scalar_ops.py +2 -0
- warp/tests/test_static.py +568 -0
- warp/tests/test_streams.py +64 -3
- warp/tests/test_struct.py +4 -4
- warp/tests/test_torch.py +24 -0
- warp/tests/test_triangle_closest_point.py +137 -0
- warp/tests/test_types.py +1 -1
- warp/tests/test_vbd.py +386 -0
- warp/tests/test_vec.py +143 -0
- warp/tests/test_vec_scalar_ops.py +139 -0
- warp/tests/unittest_suites.py +12 -0
- warp/tests/unittest_utils.py +9 -5
- warp/thirdparty/dlpack.py +3 -1
- warp/types.py +167 -36
- warp/utils.py +37 -14
- {warp_lang-1.3.3.dist-info → warp_lang-1.4.1.dist-info}/METADATA +10 -8
- {warp_lang-1.3.3.dist-info → warp_lang-1.4.1.dist-info}/RECORD +109 -97
- warp/tests/test_point_triangle_closest_point.py +0 -143
- {warp_lang-1.3.3.dist-info → warp_lang-1.4.1.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.3.3.dist-info → warp_lang-1.4.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.3.3.dist-info → warp_lang-1.4.1.dist-info}/top_level.txt +0 -0
warp/context.py
CHANGED
@@ -6,7 +6,6 @@
 # license agreement from NVIDIA CORPORATION is strictly prohibited.

 import ast
-import builtins
 import ctypes
 import functools
 import hashlib
@@ -19,9 +18,9 @@ import platform
 import sys
 import types
 import typing
+import weakref
 from copy import copy as shallowcopy
 from pathlib import Path
-from struct import pack as struct_pack
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

 import numpy as np
@@ -36,6 +35,10 @@ import warp.config

 def create_value_func(type):
     def value_func(arg_types, arg_values):
+        hint_origin = getattr(type, "__origin__", None)
+        if hint_origin is not None and issubclass(hint_origin, typing.Tuple):
+            return type.__args__
+
         return type

     return value_func
@@ -54,6 +57,38 @@ def get_function_args(func):
 complex_type_hints = (Any, Callable, Tuple)
 sequence_types = (list, tuple)

+function_key_counts = {}
+
+
+def generate_unique_function_identifier(key):
+    # Generate unique identifiers for user-defined functions in native code.
+    # - Prevents conflicts when a function is redefined and old versions are still in use.
+    # - Prevents conflicts between multiple closures returned from the same function.
+    # - Prevents conflicts between identically named functions from different modules.
+    #
+    # Currently, we generate a unique id when a new Function is created, which produces
+    # globally unique identifiers.
+    #
+    # NOTE:
+    # We could move this to the Module class for generating unique identifiers at module scope,
+    # but then we need another solution for preventing conflicts across modules (e.g., different namespaces).
+    # That would require more Python code, generate more native code, and would be slightly slower
+    # with no clear advantages over globally-unique identifiers (non-global shared state is still shared state).
+    #
+    # TODO:
+    # Kernels and structs use unique identifiers based on their hash. Using hash-based identifiers
+    # for functions would allow filtering out duplicate identical functions during codegen,
+    # like we do with kernels and structs. This is worth investigating further, but might require
+    # additional refactoring. For example, the code that deals with custom gradient and replay functions
+    # requires matching function names, but these special functions get created before the hash
+    # for the parent function can be computed. In addition to these complications, computing hashes
+    # for all function instances would increase the cost of module hashing when generic functions
+    # are involved (currently we only hash the generic templates, which is sufficient).
+
+    unique_id = function_key_counts.get(key, 0)
+    function_key_counts[key] = unique_id + 1
+    return f"{key}_{unique_id}"
+

 class Function:
     def __init__(
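For context, the collision scenario this helper guards against shows up with factories that return Warp functions. The sketch below is illustrative only, not code from the package (the make_scale factory is hypothetical, and it assumes closure capture of factor behaves as a compile-time constant, which the comments above suggest):

import warp as wp

def make_scale(factor: float):
    # Each call defines a function with the same qualified name;
    # generate_unique_function_identifier() gives every instance a
    # distinct native name (key_0, key_1, ...), so previously returned
    # closures keep referring to their own generated code.
    @wp.func
    def scale(x: float) -> float:
        return x * factor

    return scale

double = make_scale(2.0)
triple = make_scale(3.0)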
@@ -90,6 +125,7 @@ class Function:
         code_transformers=None,
         skip_adding_overload=False,
         require_original_output_arg=False,
+        scope_locals=None,  # the locals() where the function is defined, used for overload management
     ):
         if code_transformers is None:
             code_transformers = []
@@ -115,6 +151,7 @@ class Function:
         self.replay_snippet = replay_snippet
         self.custom_grad_func = None
         self.require_original_output_arg = require_original_output_arg
+        self.generic_parent = None  # generic function that was used to instantiate this overload

         if initializer_list_func is None:
             self.initializer_list_func = lambda x, y: False
@@ -124,14 +161,19 @@ class Function:
         )
         self.hidden = hidden  # function will not be listed in docs
         self.skip_replay = (
-            skip_replay  # whether
+            skip_replay  # whether operation will be performed during the forward replay in the backward pass
         )
-        self.missing_grad = missing_grad  # whether
+        self.missing_grad = missing_grad  # whether builtin is missing a corresponding adjoint
         self.generic = generic

-        # allow registering
+        # allow registering functions with a different name in Python and native code
         if native_func is None:
-            self.native_func = key
+            if func is None:
+                # builtin function
+                self.native_func = key
+            else:
+                # user functions need unique identifiers to avoid conflicts
+                self.native_func = generate_unique_function_identifier(key)
         else:
             self.native_func = native_func

@@ -162,6 +204,11 @@ class Function:
             else:
                 self.input_types[name] = type

+            # Record any default parameter values.
+            if not self.defaults:
+                signature = inspect.signature(func)
+                self.defaults = {k: v.default for k, v in signature.parameters.items() if v.default is not v.empty}
+
         else:
             # builtin function

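Recording defaults from inspect.signature() means a user function's optional parameters mirror the Python signature at call sites inside kernels. A hedged sketch of what this enables (the lerp3 helper is hypothetical; it assumes default arguments for @wp.func are honored as this change suggests):

import warp as wp

@wp.func
def lerp3(a: wp.vec3, b: wp.vec3, t: float = 0.5) -> wp.vec3:
    return a + (b - a) * t

@wp.kernel
def blend(out: wp.array(dtype=wp.vec3)):
    i = wp.tid()
    out[i] = lerp3(wp.vec3(0.0), wp.vec3(1.0), 0.25)  # explicit t
    out[i] = lerp3(out[i], wp.vec3(1.0))              # falls back to t = 0.5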
@@ -210,9 +257,13 @@ class Function:
             signature_params.append(param)
         self.signature = inspect.Signature(signature_params)

+        # scope for resolving overloads
+        if scope_locals is None:
+            scope_locals = inspect.currentframe().f_back.f_locals
+
         # add to current module
         if module:
-            module.register_function(self, skip_adding_overload)
+            module.register_function(self, scope_locals, skip_adding_overload)

     def __call__(self, *args, **kwargs):
         # handles calling a builtin (native) function
@@ -323,57 +374,52 @@ class Function:

         # check if generic
         if warp.types.is_generic_signature(sig):
-            if sig in self.user_templates:
-                raise RuntimeError(
-                    f"Duplicate generic function overload {self.key} with arguments {f.input_types.values()}"
-                )
             self.user_templates[sig] = f
         else:
-            if sig in self.user_overloads:
-                raise RuntimeError(
-                    f"Duplicate function overload {self.key} with arguments {f.input_types.values()}"
-                )
             self.user_overloads[sig] = f

     def get_overload(self, arg_types, kwarg_types):
         assert not self.is_builtin()

+        for f in self.user_overloads.values():
+            if warp.codegen.func_match_args(f, arg_types, kwarg_types):
+                return f
+
+        for f in self.user_templates.values():
+            if not warp.codegen.func_match_args(f, arg_types, kwarg_types):
+                continue
+
+            if len(f.input_types) != len(arg_types):
+                continue
+
+            # try to match the given types to the function template types
+            template_types = list(f.input_types.values())
+            args_matched = True
+
+            for i in range(len(arg_types)):
+                if not warp.types.type_matches_template(arg_types[i], template_types[i]):
+                    args_matched = False
+                    break
+
+            if args_matched:
+                # instantiate this function with the specified argument types
+
+                arg_names = f.input_types.keys()
+                overload_annotations = dict(zip(arg_names, arg_types))
+
+                ovl = shallowcopy(f)
+                ovl.adj = warp.codegen.Adjoint(f.func, overload_annotations)
+                ovl.input_types = overload_annotations
+                ovl.value_func = None
+                ovl.generic_parent = f
+
+                sig = warp.types.get_signature(arg_types, func_name=self.key)
+                self.user_overloads[sig] = ovl
+
+                return ovl
+
+        # failed to find overload
+        return None

     def __repr__(self):
         inputs_str = ", ".join([f"{k}: {warp.types.type_repr(v)}" for k, v in self.input_types.items()])
@@ -589,8 +635,7 @@ class Kernel:
         self.module = module

         if key is None:
-            self.key = unique_key
+            self.key = warp.codegen.make_full_qualified_name(func)
         else:
             self.key = key

@@ -614,9 +659,15 @@ class Kernel:
         # known overloads for generic kernels, indexed by type signature
         self.overloads = {}

+        # generic kernel that was used to instantiate this overload
+        self.generic_parent = None
+
         # argument indices by name
         self.arg_indices = {a.label: i for i, a in enumerate(self.adj.args)}

+        # hash will be computed when the module is built
+        self.hash = None
+
         if self.module:
             self.module.register_kernel(self)

@@ -664,10 +715,11 @@ class Kernel:
         ovl.is_generic = False
         ovl.overloads = {}
         ovl.sig = sig
+        ovl.generic_parent = self

         self.overloads[sig] = ovl

-        self.module.
+        self.module.mark_modified()

         return ovl

@@ -676,10 +728,13 @@ class Kernel:
         return self.overloads.get(sig)

     def get_mangled_name(self):
-        if self.
+        if self.hash is None:
+            raise RuntimeError(f"Missing hash for kernel {self.key} in module {self.module.name}")
+
+        # TODO: allow customizing the number of hash characters used
+        hash_suffix = self.hash.hex()[:8]
+
+        return f"{self.key}_{hash_suffix}"


 # ----------------------
@@ -689,9 +744,11 @@ class Kernel:
 def func(f):
     name = warp.codegen.make_full_qualified_name(f)

+    scope_locals = inspect.currentframe().f_back.f_locals
+
     m = get_module(f.__module__)
     Function(
-        func=f, key=name, namespace="", module=m, value_func=None
+        func=f, key=name, namespace="", module=m, value_func=None, scope_locals=scope_locals
     )  # value_type not known yet, will be inferred during Adjoint.build()

     # use the top of the list of overloads for this key
@@ -705,6 +762,8 @@ def func_native(snippet, adj_snippet=None, replay_snippet=None):
     Decorator to register native code snippet, @func_native
     """

+    scope_locals = inspect.currentframe().f_back.f_locals
+
     def snippet_func(f):
         name = warp.codegen.make_full_qualified_name(f)

@@ -717,6 +776,7 @@ def func_native(snippet, adj_snippet=None, replay_snippet=None):
             native_snippet=snippet,
             adj_native_snippet=adj_snippet,
             replay_snippet=replay_snippet,
+            scope_locals=scope_locals,
         )  # value_type not known yet, will be inferred during Adjoint.build()
         g = m.functions[name]
         # copy over the function attributes, including docstring
@@ -783,6 +843,7 @@ def func_grad(forward_fn):
         f.custom_grad_func = Function(
             grad_fn,
             key=f.key,
+            native_func=f.native_func,
             namespace=f.namespace,
             input_types=reverse_args,
             value_func=None,
@@ -941,7 +1002,7 @@ def overload(kernel, arg_types=None):
     # ensure this function name corresponds to a kernel
     fn = kernel
     module = get_module(fn.__module__)
-    kernel = module.
+    kernel = module.find_kernel(fn)
     if kernel is None:
         raise RuntimeError(f"Failed to find a kernel named '{fn.__name__}' in module {fn.__module__}")

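module.find_kernel() resolves the decorated Python function back to its Kernel object, which is what wp.overload() needs when pre-declaring concrete versions of a generic kernel. A minimal usage sketch (the scale_kernel name is hypothetical):

import warp as wp
from typing import Any

@wp.kernel
def scale_kernel(a: wp.array(dtype=Any), s: Any):
    i = wp.tid()
    a[i] = a[i] * s

# pre-declare concrete overloads so they are compiled with the module
wp.overload(scale_kernel, [wp.array(dtype=wp.float32), wp.float32])
wp.overload(scale_kernel, [wp.array(dtype=wp.float64), wp.float64])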
@@ -1277,7 +1338,6 @@ def get_module(name):
         # clear out old kernels, funcs, struct definitions
         old_module.kernels = {}
         old_module.functions = {}
-        old_module.constants = {}
         old_module.structs = {}
         old_module.loader = parent_loader

@@ -1289,30 +1349,206 @@ def get_module(name):
     return user_modules[name]


+# ModuleHasher computes the module hash based on all the kernels, module options,
+# and build configuration. For each kernel, it computes a deep hash by recursively
+# hashing all referenced functions, structs, and constants, even those defined in
+# other modules. The module hash is computed in the constructor and can be retrieved
+# using get_module_hash(). In addition, the ModuleHasher takes care of filtering out
+# duplicate kernels for codegen (see get_unique_kernels()).
+class ModuleHasher:
+    def __init__(self, module):
+        # cache function hashes to avoid hashing multiple times
+        self.function_hashes = {}  # (function: hash)
+
+        # avoid recursive spiral of doom (e.g., function calling an overload of itself)
+        self.functions_in_progress = set()
+
+        # all unique kernels for codegen, filtered by hash
+        self.unique_kernels = {}  # (hash: kernel)
+
+        # start hashing the module
+        ch = hashlib.sha256()
+
+        # hash all non-generic kernels
+        for kernel in module.live_kernels:
+            if kernel.is_generic:
+                for ovl in kernel.overloads.values():
+                    if not ovl.adj.skip_build:
+                        ovl.hash = self.hash_kernel(ovl)
+            else:
+                if not kernel.adj.skip_build:
+                    kernel.hash = self.hash_kernel(kernel)
+
+        # include all unique kernels in the module hash
+        for kernel_hash in sorted(self.unique_kernels.keys()):
+            ch.update(kernel_hash)
+
+        # configuration parameters
+        for opt in sorted(module.options.keys()):
+            s = f"{opt}:{module.options[opt]}"
+            ch.update(bytes(s, "utf-8"))
+
+        # ensure to trigger recompilation if flags affecting kernel compilation are changed
+        if warp.config.verify_fp:
+            ch.update(bytes("verify_fp", "utf-8"))
+
+        # build config
+        ch.update(bytes(warp.config.mode, "utf-8"))
+
+        # save the module hash
+        self.module_hash = ch.digest()
+
+    def hash_kernel(self, kernel):
+        # NOTE: We only hash non-generic kernels, so we don't traverse kernel overloads here.
+
+        ch = hashlib.sha256()
+
+        ch.update(bytes(kernel.key, "utf-8"))
+        ch.update(self.hash_adjoint(kernel.adj))
+
+        h = ch.digest()
+
+        self.unique_kernels[h] = kernel
+
+        return h
+
+    def hash_function(self, func):
+        # NOTE: This method hashes all possible overloads that a function call could resolve to.
+        # The exact overload will be resolved at build time, when the argument types are known.
+
+        h = self.function_hashes.get(func)
+        if h is not None:
+            return h
+
+        self.functions_in_progress.add(func)
+
+        ch = hashlib.sha256()
+
+        ch.update(bytes(func.key, "utf-8"))
+
+        # include all concrete and generic overloads
+        overloads = {**func.user_overloads, **func.user_templates}
+        for sig in sorted(overloads.keys()):
+            ovl = overloads[sig]
+
+            # skip instantiations of generic functions
+            if ovl.generic_parent is not None:
+                continue
+
+            # adjoint
+            ch.update(self.hash_adjoint(ovl.adj))
+
+            # custom bits
+            if ovl.custom_grad_func:
+                ch.update(self.hash_adjoint(ovl.custom_grad_func.adj))
+            if ovl.custom_replay_func:
+                ch.update(self.hash_adjoint(ovl.custom_replay_func.adj))
+            if ovl.replay_snippet:
+                ch.update(bytes(ovl.replay_snippet, "utf-8"))
+            if ovl.native_snippet:
+                ch.update(bytes(ovl.native_snippet, "utf-8"))
+            if ovl.adj_native_snippet:
+                ch.update(bytes(ovl.adj_native_snippet, "utf-8"))
+
+        h = ch.digest()
+
+        self.function_hashes[func] = h
+
+        self.functions_in_progress.remove(func)
+
+        return h
+
+    def hash_adjoint(self, adj):
+        # NOTE: We don't cache adjoint hashes, because adjoints are always unique.
+        # Even instances of generic kernels and functions have unique adjoints with
+        # different argument types.
+
+        ch = hashlib.sha256()
+
+        # source
+        ch.update(bytes(adj.source, "utf-8"))
+
+        # args
+        for arg, arg_type in adj.arg_types.items():
+            s = f"{arg}:{warp.types.get_type_code(arg_type)}"
+            ch.update(bytes(s, "utf-8"))
+
+            # hash struct types
+            if isinstance(arg_type, warp.codegen.Struct):
+                ch.update(arg_type.hash)
+            elif warp.types.is_array(arg_type) and isinstance(arg_type.dtype, warp.codegen.Struct):
+                ch.update(arg_type.dtype.hash)
+
+        # find referenced constants, types, and functions
+        constants, types, functions = adj.get_references()
+
+        # hash referenced constants
+        for name, value in constants.items():
+            ch.update(bytes(name, "utf-8"))
+            ch.update(self.get_constant_bytes(value))
+
+        # hash wp.static() expressions that were evaluated at declaration time
+        for k, v in adj.static_expressions.items():
+            ch.update(bytes(k, "utf-8"))
+            if isinstance(v, Function):
+                if v not in self.functions_in_progress:
+                    ch.update(self.hash_function(v))
+            else:
+                ch.update(self.get_constant_bytes(v))
+
+        # hash referenced types
+        for t in types.keys():
+            ch.update(bytes(warp.types.get_type_code(t), "utf-8"))
+
+        # hash referenced functions
+        for f in functions.keys():
+            if f not in self.functions_in_progress:
+                ch.update(self.hash_function(f))
+
+        return ch.digest()
+
+    def get_constant_bytes(self, value):
+        if isinstance(value, int):
+            # this also handles builtins.bool
+            return bytes(ctypes.c_int(value))
+        elif isinstance(value, float):
+            return bytes(ctypes.c_float(value))
+        elif isinstance(value, warp.types.float16):
+            # float16 is a special case
+            return bytes(ctypes.c_float(value.value))
+        elif isinstance(value, tuple(warp.types.scalar_and_bool_types)):
+            return bytes(value._type_(value.value))
+        elif hasattr(value, "_wp_scalar_type_"):
+            return bytes(value)
+        elif isinstance(value, warp.codegen.StructInstance):
+            return bytes(value._ctype)
+        else:
+            raise TypeError(f"Invalid constant type: {type(value)}")
+
+    def get_module_hash(self):
+        return self.module_hash
+
+    def get_unique_kernels(self):
+        return self.unique_kernels.values()
+
+
 class ModuleBuilder:
-    def __init__(self, module, options):
+    def __init__(self, module, options, hasher=None):
         self.functions = {}
         self.structs = {}
         self.options = options
         self.module = module
         self.deferred_functions = []

-        for f in func.user_overloads.values():
-            self.build_function(f)
-            if f.custom_replay_func is not None:
-                self.build_function(f.custom_replay_func)
-
-        # build all kernel entry points
-        for kernel in module.kernels.values():
-            if not kernel.is_generic:
-                self.build_kernel(kernel)
-            else:
-                for k in kernel.overloads.values():
-                    self.build_kernel(k)
+        if hasher is None:
+            hasher = ModuleHasher(module)

-        # build all
+        # build all unique kernels
+        self.kernels = hasher.get_unique_kernels()
+        for kernel in self.kernels:
+            self.build_kernel(kernel)
+
+        # build deferred functions
         for func in self.deferred_functions:
            self.build_function(func)

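The deduplication above boils down to content addressing: two kernels whose key, source, and typed arguments hash identically are generated once. A self-contained, illustrative sketch of the idea (plain Python, not the class above):

import hashlib

def kernel_content_hash(key: str, source: str, arg_types: dict) -> bytes:
    # mirrors hash_kernel()/hash_adjoint(): key + source + typed argument list
    ch = hashlib.sha256()
    ch.update(key.encode("utf-8"))
    ch.update(source.encode("utf-8"))
    for name, type_code in arg_types.items():
        ch.update(f"{name}:{type_code}".encode("utf-8"))
    return ch.digest()

unique = {}
kernels = [("saxpy", "a[i] += s * b[i]", {"a": "array(float32)", "s": "float32", "b": "array(float32)"})] * 2
for k in kernels:
    unique[kernel_content_hash(*k)] = k  # duplicates collapse onto one hash

assert len(unique) == 1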
@@ -1328,7 +1564,7 @@ class ModuleBuilder:
             for var in s.vars.values():
                 if isinstance(var.type, warp.codegen.Struct):
                     stack.append(var.type)
-                elif
+                elif warp.types.is_array(var.type) and isinstance(var.type.dtype, warp.codegen.Struct):
                     stack.append(var.type.dtype)

         # Build them in reverse to generate a correct dependency order.
@@ -1374,8 +1610,12 @@ class ModuleBuilder:
         source = ""

         # code-gen structs
+        visited_structs = set()
         for struct in self.structs.keys():
+            # avoid emitting duplicates
+            if struct.hash not in visited_structs:
+                source += warp.codegen.codegen_struct(struct)
+                visited_structs.add(struct.hash)

         # code-gen all imported functions
         for func in self.functions.keys():
@@ -1386,21 +1626,15 @@ class ModuleBuilder:
             else:
                 source += warp.codegen.codegen_snippet(
                     func.adj,
-                    name=func.
+                    name=func.native_func,
                     snippet=func.native_snippet,
                     adj_snippet=func.adj_native_snippet,
                     replay_snippet=func.replay_snippet,
                 )

-        for kernel in self.
-            source += warp.codegen.codegen_kernel(kernel, device=device, options=self.options)
-            source += warp.codegen.codegen_module(kernel, device=device)
-            else:
-                for k in kernel.overloads.values():
-                    source += warp.codegen.codegen_kernel(k, device=device, options=self.options)
-                    source += warp.codegen.codegen_module(k, device=device)
+        for kernel in self.kernels:
+            source += warp.codegen.codegen_kernel(kernel, device=device, options=self.options)
+            source += warp.codegen.codegen_module(kernel, device=device)

         # add headers
         if device == "cpu":
@@ -1425,8 +1659,9 @@ class ModuleExec:
         instance.handle = None
         return instance

-    def __init__(self, handle, device):
+    def __init__(self, handle, module_hash, device):
         self.handle = handle
+        self.module_hash = module_hash
         self.device = device
         self.kernel_hooks = {}

@@ -1457,8 +1692,12 @@ class ModuleExec:
             )
         else:
             func = ctypes.CFUNCTYPE(None)
-            forward =
+            forward = (
+                func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_forward").encode("utf-8"))) or None
+            )
+            backward = (
+                func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8"))) or None
+            )

         hooks = KernelHooks(forward, backward)
         self.kernel_hooks[kernel] = hooks
@@ -1475,16 +1714,33 @@ class Module:
         self.name = name
         self.loader = loader

-        self.
-        self.
-        self.structs = {}
+        # lookup the latest versions of kernels, functions, and structs by key
+        self.kernels = {}  # (key: kernel)
+        self.functions = {}  # (key: function)
+        self.structs = {}  # (key: struct)
+
+        # Set of all "live" kernels in this module.
+        # The difference between `live_kernels` and `kernels` is that `live_kernels` may contain
+        # multiple kernels with the same key (which is essential to support closures), while `kernels`
+        # only holds the latest kernel for each key. When the module is built, we compute the hash
+        # of each kernel in `live_kernels` and filter out duplicates for codegen.
+        self.live_kernels = weakref.WeakSet()

-        self.
+        # executable modules currently loaded
+        self.execs = {}  # (device.context: ModuleExec)

-        self.
+        # set of device contexts where the build has failed
+        self.failed_builds = set()
+
+        # hash data, including the module hash
+        self.hasher = None
+
+        # LLVM executable modules are identified using strings. Since it's possible for multiple
+        # executable versions to be loaded at the same time, we need a way to ensure uniqueness.
+        # A unique handle is created from the module name and this auto-incremented integer id.
+        # NOTE: The module hash is not sufficient for uniqueness in rare cases where a module
+        # is retained and later reloaded with the same hash.
+        self.cpu_exec_id = 0

         self.options = {
             "max_unroll": warp.config.max_unroll,
@@ -1497,11 +1753,6 @@ class Module:
         # Module dependencies are determined by scanning each function
         # and kernel for references to external functions and structs.
         #
-        # When a referenced module is modified, all of its dependents need to be reloaded
-        # on the next launch. To detect this, a module's hash recursively includes
-        # all of its references.
-        # -> See ``Module.hash_module()``
-        #
         # The dependency mechanism works for both static and dynamic (runtime) modifications.
         # When a module is reloaded at runtime, we recursively unload all of its
         # dependents, so that they will be re-hashed and reloaded on the next launch.
@@ -1510,40 +1761,39 @@ class Module:
         self.references = set()  # modules whose content we depend on
         self.dependents = set()  # modules that depend on our content

-        # Since module hashing is recursive, we improve performance by caching the hash of the
-        # module contents (kernel source, function source, and struct source).
-        # After all kernels, functions, and structs are added to the module (usually at import time),
-        # the content hash doesn't change.
-        # -> See ``Module.hash_module_recursive()``
-
-        self.content_hash = None
-
-        # number of times module auto-generates kernel key for user
-        # used to ensure unique kernel keys
-        self.count = 0
-
     def register_struct(self, struct):
         self.structs[struct.key] = struct

         # for a reload of module on next launch
-        self.
+        self.mark_modified()

     def register_kernel(self, kernel):
+        # keep a reference to the latest version
         self.kernels[kernel.key] = kernel

+        # track all kernel objects, even if they are duplicates
+        self.live_kernels.add(kernel)
+
         self.find_references(kernel.adj)

         # for a reload of module on next launch
-        self.
+        self.mark_modified()

-    def register_function(self, func, skip_adding_overload=False):
+    def register_function(self, func, scope_locals, skip_adding_overload=False):
+        # check for another Function with the same name in the same scope
+        obj = scope_locals.get(func.func.__name__)
+        if isinstance(obj, Function):
+            func_existing = obj
         else:
+            func_existing = None
+
+        # keep a reference to the latest version
+        self.functions[func.key] = func_existing or func
+
+        if func_existing:
             # Check whether the new function's signature matches any that has
             # already been registered. If so, then we simply override it, as
             # Python would do it, otherwise we register it as a new overload.
-            func_existing = self.functions[func.key]
             sig = warp.types.get_signature(
                 func.input_types.values(),
                 func_name=func.key,
@@ -1555,19 +1805,43 @@ class Module:
                 arg_names=list(func_existing.input_types.keys()),
             )
             if sig == sig_existing:
+                # replace the top-level function, but keep existing overloads
+
+                # copy generic overloads
+                func.user_templates = func_existing.user_templates.copy()
+
+                # copy concrete overloads
+                if warp.types.is_generic_signature(sig):
+                    # skip overloads that were instantiated from the function being replaced
+                    for k, v in func_existing.user_overloads.items():
+                        if v.generic_parent != func_existing:
+                            func.user_overloads[k] = v
+                    func.user_templates[sig] = func
+                else:
+                    func.user_overloads = func_existing.user_overloads.copy()
+                    func.user_overloads[sig] = func
+
                 self.functions[func.key] = func
             elif not skip_adding_overload:
+                # check if this is a generic overload that replaces an existing one
+                if warp.types.is_generic_signature(sig):
+                    old_generic = func_existing.user_templates.get(sig)
+                    if old_generic is not None:
+                        # purge any concrete overloads that were instantiated from the old one
+                        for k, v in list(func_existing.user_overloads.items()):
+                            if v.generic_parent == old_generic:
+                                del func_existing.user_overloads[k]
                 func_existing.add_overload(func)

         self.find_references(func.adj)

         # for a reload of module on next launch
-        self.
+        self.mark_modified()

-        return
+    # find kernel corresponding to a Python function
+    def find_kernel(self, func):
+        qualname = warp.codegen.make_full_qualified_name(func)
+        return self.kernels.get(qualname)

     # collect all referenced functions / structs
     # given the AST of a function or kernel
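In practice this means redefining a @wp.func with the same signature replaces it, as plain Python would, while a new signature adds an overload. A hedged sketch of both behaviors (the clamp01 helper is hypothetical):

import warp as wp

@wp.func
def clamp01(x: float) -> float:
    return wp.clamp(x, 0.0, 1.0)

# same name, same signature: replaces the definition above
@wp.func
def clamp01(x: float) -> float:
    return wp.min(wp.max(x, 0.0), 1.0)

# same name, different signature: registered as an additional overload
@wp.func
def clamp01(v: wp.vec3) -> wp.vec3:
    return wp.vec3(clamp01(v[0]), clamp01(v[1]), clamp01(v[2]))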
@@ -1599,165 +1873,30 @@ class Module:
             if isinstance(arg.type, warp.codegen.Struct) and arg.type.module is not None:
                 add_ref(arg.type.module)

-    def hash_module(self
-        computed ``content_hash`` will be used.
-        """
-
-        def get_type_name(type_hint) -> str:
-            if isinstance(type_hint, warp.codegen.Struct):
-                return get_type_name(type_hint.cls)
-            elif isinstance(type_hint, warp.array) and isinstance(type_hint.dtype, warp.codegen.Struct):
-                return f"array{get_type_name(type_hint.dtype)}"
-
-            return str(type_hint)
-
-        def hash_recursive(module, visited):
-            # Hash this module, including all referenced modules recursively.
-            # The visited set tracks modules already visited to avoid circular references.
-
-            # check if we need to update the content hash
-            if not module.content_hash or recompute_content_hash:
-                # recompute content hash
-                ch = hashlib.sha256()
-
-                # Start with an empty constants dictionary in case any have been removed
-                module.constants = {}
-
-                # struct source
-                for struct in module.structs.values():
-                    s = ",".join(
-                        "{}: {}".format(name, get_type_name(type_hint))
-                        for name, type_hint in warp.codegen.get_annotations(struct.cls).items()
-                    )
-                    ch.update(bytes(s, "utf-8"))
-
-                # functions source
-                for function in module.functions.values():
-                    # include all concrete and generic overloads
-                    overloads = itertools.chain(function.user_overloads.items(), function.user_templates.items())
-                    for sig, func in overloads:
-                        # signature
-                        ch.update(bytes(sig, "utf-8"))
-
-                        # source
-                        ch.update(bytes(func.adj.source, "utf-8"))
-
-                        if func.custom_grad_func:
-                            ch.update(bytes(func.custom_grad_func.adj.source, "utf-8"))
-                        if func.custom_replay_func:
-                            ch.update(bytes(func.custom_replay_func.adj.source, "utf-8"))
-                        if func.replay_snippet:
-                            ch.update(bytes(func.replay_snippet, "utf-8"))
-                        if func.native_snippet:
-                            ch.update(bytes(func.native_snippet, "utf-8"))
-                        if func.adj_native_snippet:
-                            ch.update(bytes(func.adj_native_snippet, "utf-8"))
-
-                        # Populate constants referenced in this function
-                        if func.adj:
-                            module.constants.update(func.adj.get_constant_references())
-
-                # kernel source
-                for kernel in module.kernels.values():
-                    ch.update(bytes(kernel.key, "utf-8"))
-                    ch.update(bytes(kernel.adj.source, "utf-8"))
-                    # cache kernel arg types
-                    for arg, arg_type in kernel.adj.arg_types.items():
-                        s = f"{arg}: {get_type_name(arg_type)}"
-                        ch.update(bytes(s, "utf-8"))
-                    # for generic kernels the Python source is always the same,
-                    # but we hash the type signatures of all the overloads
-                    if kernel.is_generic:
-                        for sig in sorted(kernel.overloads.keys()):
-                            ch.update(bytes(sig, "utf-8"))
-
-                    # Populate constants referenced in this kernel
-                    module.constants.update(kernel.adj.get_constant_references())
-
-                # constants referenced in this module
-                for constant_name, constant_value in module.constants.items():
-                    ch.update(bytes(constant_name, "utf-8"))
-
-                    # hash the constant value
-                    if isinstance(constant_value, builtins.bool):
-                        # This needs to come before the check for `int` since all boolean
-                        # values are also instances of `int`.
-                        ch.update(struct_pack("?", constant_value))
-                    elif isinstance(constant_value, int):
-                        ch.update(struct_pack("<q", constant_value))
-                    elif isinstance(constant_value, float):
-                        ch.update(struct_pack("<d", constant_value))
-                    elif isinstance(constant_value, warp.types.float16):
-                        # float16 is a special case
-                        p = ctypes.pointer(ctypes.c_float(constant_value.value))
-                        ch.update(p.contents)
-                    elif isinstance(constant_value, tuple(warp.types.scalar_types)):
-                        p = ctypes.pointer(constant_value._type_(constant_value.value))
-                        ch.update(p.contents)
-                    elif isinstance(constant_value, ctypes.Array):
-                        ch.update(bytes(constant_value))
-                    else:
-                        raise RuntimeError(f"Invalid constant type: {type(constant_value)}")
-
-                module.content_hash = ch.digest()
-
-            h = hashlib.sha256()
-
-            # content hash
-            h.update(module.content_hash)
-
-            # configuration parameters
-            for k in sorted(module.options.keys()):
-                s = f"{k}={module.options[k]}"
-                h.update(bytes(s, "utf-8"))
-
-            # ensure to trigger recompilation if flags affecting kernel compilation are changed
-            if warp.config.verify_fp:
-                h.update(bytes("verify_fp", "utf-8"))
-
-            h.update(bytes(warp.config.mode, "utf-8"))
-
-            # recurse on references
-            visited.add(module)
-
-            sorted_deps = sorted(module.references, key=lambda m: m.name)
-            for dep in sorted_deps:
-                if dep not in visited:
-                    dep_hash = hash_recursive(dep, visited)
-                    h.update(dep_hash)
-
-            return h.digest()
-
-        return hash_recursive(self, visited=set())
+    def hash_module(self):
+        # compute latest hash
+        self.hasher = ModuleHasher(self)
+        return self.hasher.get_module_hash()

     def load(self, device) -> ModuleExec:
         device = runtime.get_device(device)

-        if
-        if
-            return cuda_exec
-        # avoid repeated build attempts
-        if self.cuda_build_failed:
-            return None
-        if not warp.is_cuda_available():
-            raise RuntimeError("Failed to build CUDA module because CUDA is not available")
+        # compute the hash if needed
+        if self.hasher is None:
+            self.hasher = ModuleHasher(self)
+
+        # check if executable module is already loaded and not stale
+        exec = self.execs.get(device.context)
+        if exec is not None:
+            if exec.module_hash == self.hasher.module_hash:
+                return exec
+
+        # quietly avoid repeated build attempts to reduce error spew
+        if device.context in self.failed_builds:
+            return None

         module_name = "wp_" + self.name
-        module_hash = self.
+        module_hash = self.hasher.module_hash

         # use a unique module path using the module short hash
         module_dir = os.path.join(warp.config.kernel_cache_dir, f"{module_name}_{module_hash.hex()[:7]}")
@@ -1807,7 +1946,7 @@ class Module:
             or not warp.config.cache_kernels
             or warp.config.verify_autograd_array_access
         ):
-            builder = ModuleBuilder(self, self.options)
+            builder = ModuleBuilder(self, self.options, hasher=self.hasher)

             # create a temporary (process unique) dir for build outputs before moving to the binary dir
             build_dir = os.path.join(
@@ -1844,7 +1983,7 @@ class Module:
                     )

                 except Exception as e:
-                    self.
+                    self.failed_builds.add(None)
                     module_load_timer.extra_msg = " (error)"
                     raise (e)

@@ -1873,7 +2012,7 @@ class Module:
                     )

                 except Exception as e:
-                    self.
+                    self.failed_builds.add(device.context)
                     module_load_timer.extra_msg = " (error)"
                     raise (e)

@@ -1914,15 +2053,18 @@ class Module:
             # -----------------------------------------------------------
             # Load CPU or CUDA binary
             if device.is_cpu:
-                self.
+                # LLVM modules are identified using strings, so we need to ensure uniqueness
+                module_handle = f"{module_name}_{self.cpu_exec_id}"
+                self.cpu_exec_id += 1
+                runtime.llvm.load_obj(binary_path.encode("utf-8"), module_handle.encode("utf-8"))
+                module_exec = ModuleExec(module_handle, module_hash, device)
+                self.execs[None] = module_exec

             elif device.is_cuda:
                 cuda_module = warp.build.load_cuda(binary_path, device)
                 if cuda_module is not None:
-                    module_exec = ModuleExec(cuda_module, device)
-                    self.
+                    module_exec = ModuleExec(cuda_module, module_hash, device)
+                    self.execs[device.context] = module_exec
                 else:
                     module_load_timer.extra_msg = " (error)"
                     raise Exception(f"Failed to load CUDA module '{self.name}'")
@@ -1936,20 +2078,22 @@ class Module:
             return module_exec

     def unload(self):
+        # force rehashing on next load
+        self.mark_modified()
+
         # clear loaded modules
-        self.
-        self.cuda_execs = {}
+        self.execs = {}

+    def mark_modified(self):
+        # clear hash data
+        self.hasher = None
+
+        # clear build failures
+        self.failed_builds = set()

     # lookup kernel entry points based on name, called after compilation / module load
     def get_kernel_hooks(self, kernel, device):
-        module_exec = self.cuda_execs.get(device.context)
-        else:
-            module_exec = self.cpu_exec
-
+        module_exec = self.execs.get(device.context)
         if module_exec is not None:
             return module_exec.get_kernel_hooks(kernel)
         else:
@@ -2056,6 +2200,63 @@ class ContextGuard:
             runtime.core.cuda_context_set_current(self.saved_context)


+class Event:
+    """A CUDA event that can be recorded onto a stream.
+
+    Events can be used for device-side synchronization, which does not block
+    the host thread.
+    """
+
+    # event creation flags
+    class Flags:
+        DEFAULT = 0x0
+        BLOCKING_SYNC = 0x1
+        DISABLE_TIMING = 0x2
+
+    def __new__(cls, *args, **kwargs):
+        """Creates a new event instance."""
+        instance = super(Event, cls).__new__(cls)
+        instance.owner = False
+        return instance
+
+    def __init__(self, device: "Devicelike" = None, cuda_event=None, enable_timing: bool = False):
+        """Initializes the event on a CUDA device.
+
+        Args:
+            device: The CUDA device whose streams this event may be recorded onto.
+                If ``None``, then the current default device will be used.
+            cuda_event: A pointer to a previously allocated CUDA event. If
+                ``None``, then a new event will be allocated on the associated device.
+            enable_timing: If ``True`` this event will record timing data.
+                :func:`~warp.get_event_elapsed_time` can be used to measure the
+                time between two events created with ``enable_timing=True`` and
+                recorded onto streams.
+        """
+
+        device = get_device(device)
+        if not device.is_cuda:
+            raise RuntimeError(f"Device {device} is not a CUDA device")
+
+        self.device = device
+
+        if cuda_event is not None:
+            self.cuda_event = cuda_event
+        else:
+            flags = Event.Flags.DEFAULT
+            if not enable_timing:
+                flags |= Event.Flags.DISABLE_TIMING
+            self.cuda_event = runtime.core.cuda_event_create(device.context, flags)
+            if not self.cuda_event:
+                raise RuntimeError(f"Failed to create event on device {device}")
+            self.owner = True
+
+    def __del__(self):
+        if not self.owner:
+            return
+
+        runtime.core.cuda_event_destroy(self.cuda_event)
+
+
 class Stream:
     def __new__(cls, *args, **kwargs):
         instance = super(Stream, cls).__new__(cls)
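Typical usage of the relocated Event class, assuming a CUDA-capable machine; wp.get_event_elapsed_time() is the timing helper referenced in the docstring above:

import warp as wp

wp.init()
with wp.ScopedDevice("cuda:0"):
    stream = wp.get_stream()
    start = wp.Event(enable_timing=True)
    end = wp.Event(enable_timing=True)

    stream.record_event(start)
    # ... launch kernels on `stream` here ...
    stream.record_event(end)

    wp.synchronize_device()
    print(wp.get_event_elapsed_time(start, end), "ms")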
@@ -2063,7 +2264,27 @@ class Stream:
         instance.owner = False
         return instance

-    def __init__(self, device=None, **kwargs):
+    def __init__(self, device: Optional[Union["Device", str]] = None, priority: int = 0, **kwargs):
+        """Initialize the stream on a device with an optional specified priority.
+
+        Args:
+            device: The CUDA device on which this stream will be created.
+            priority: An optional integer specifying the requested stream priority.
+                Can be -1 (high priority) or 0 (low/default priority).
+                Values outside this range will be clamped.
+            cuda_stream (int): An optional external stream handle passed as an
+                integer. The caller is responsible for ensuring that the external
+                stream does not get destroyed while it is referenced by this
+                object.
+
+        Raises:
+            RuntimeError: If function is called before Warp has completed
+                initialization with a ``device`` that is not an instance of
+                :class:`Device`.
+            RuntimeError: ``device`` is not a CUDA Device.
+            RuntimeError: The stream could not be created on the device.
+            TypeError: The requested stream priority is not an integer.
+        """
         # event used internally for synchronization (cached to avoid creating temporary events)
         self._cached_event = None

@@ -2072,7 +2293,7 @@ class Stream:
             device = runtime.get_device(device)
         elif not isinstance(device, Device):
             raise RuntimeError(
-                "A
+                "A Device object is required when creating a stream before or during Warp initialization"
             )

         if not device.is_cuda:
@@ -2085,7 +2306,11 @@ class Stream:
             self.cuda_stream = kwargs["cuda_stream"]
             device.runtime.core.cuda_stream_register(device.context, self.cuda_stream)
         else:
+            if not isinstance(priority, int):
+                raise TypeError("Stream priority must be an integer.")
+            clamped_priority = max(-1, min(priority, 0))  # Only support two priority levels
+            self.cuda_stream = device.runtime.core.cuda_stream_create(device.context, clamped_priority)
+
         if not self.cuda_stream:
             raise RuntimeError(f"Failed to create stream on device {device}")
         self.owner = True
@@ -2100,12 +2325,22 @@ class Stream:
             runtime.core.cuda_stream_unregister(self.device.context, self.cuda_stream)

     @property
-    def cached_event(self):
+    def cached_event(self) -> Event:
         if self._cached_event is None:
             self._cached_event = Event(self.device)
         return self._cached_event

-    def record_event(self, event=None):
+    def record_event(self, event: Optional[Event] = None) -> Event:
+        """Record an event onto the stream.
+
+        Args:
+            event: A warp.Event instance to be recorded onto the stream. If not
+                provided, an :class:`~warp.Event` on the same device will be created.
+
+        Raises:
+            RuntimeError: The provided :class:`~warp.Event` is from a different device than
+                the recording stream.
+        """
         if event is None:
             event = Event(self.device)
         elif event.device != self.device:
@@ -2117,56 +2352,45 @@ class Stream:
 
         return event
 
-    def wait_event(self, event):
+    def wait_event(self, event: Event):
+        """Makes all future work in this stream wait until `event` has completed.
+
+        This function does not block the host thread.
+        """
         runtime.core.cuda_stream_wait_event(self.cuda_stream, event.cuda_event)
 
-    def wait_stream(self, other_stream, event=None):
+    def wait_stream(self, other_stream: "Stream", event: Optional[Event] = None):
+        """Records an event on `other_stream` and makes this stream wait on it.
+
+        All work added to this stream after this function has been called will
+        delay their execution until all preceding commands in `other_stream`
+        have completed.
+
+        This function does not block the host thread.
+
+        Args:
+            other_stream: The stream on which the calling stream will wait for
+                previously issued commands to complete before executing subsequent
+                commands.
+            event: An optional :class:`Event` instance that will be used to
+                record an event onto ``other_stream``. If ``None``, an internally
+                managed :class:`Event` instance will be used.
+        """
+
         if event is None:
             event = other_stream.cached_event
 
         runtime.core.cuda_stream_wait_stream(self.cuda_stream, other_stream.cuda_stream, event.cuda_event)
 
-    # whether a graph capture is currently ongoing on this stream
     @property
-    def is_capturing(self):
+    def is_capturing(self) -> bool:
+        """A boolean indicating whether a graph capture is currently ongoing on this stream."""
         return bool(runtime.core.cuda_stream_is_capturing(self.cuda_stream))
 
-
-
-
-
-        DEFAULT = 0x0
-        BLOCKING_SYNC = 0x1
-        DISABLE_TIMING = 0x2
-
-    def __new__(cls, *args, **kwargs):
-        instance = super(Event, cls).__new__(cls)
-        instance.owner = False
-        return instance
-
-    def __init__(self, device=None, cuda_event=None, enable_timing=False):
-        device = get_device(device)
-        if not device.is_cuda:
-            raise RuntimeError(f"Device {device} is not a CUDA device")
-
-        self.device = device
-
-        if cuda_event is not None:
-            self.cuda_event = cuda_event
-        else:
-            flags = Event.Flags.DEFAULT
-            if not enable_timing:
-                flags |= Event.Flags.DISABLE_TIMING
-            self.cuda_event = runtime.core.cuda_event_create(device.context, flags)
-            if not self.cuda_event:
-                raise RuntimeError(f"Failed to create event on device {device}")
-            self.owner = True
-
-    def __del__(self):
-        if not self.owner:
-            return
-
-        runtime.core.cuda_event_destroy(self.cuda_event)
+    @property
+    def priority(self) -> int:
+        """An integer representing the priority of the stream."""
+        return runtime.core.cuda_stream_get_priority(self.cuda_stream)
 
 
 class Device:
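`wait_stream()` above inserts a device-side dependency between two streams without blocking the host. A sketch of a producer/consumer pipeline; the kernel is illustrative only:

import warp as wp

wp.init()

@wp.kernel
def fill(a: wp.array(dtype=float), value: float):
    tid = wp.tid()
    a[tid] = value

device = wp.get_device("cuda:0")
producer = wp.Stream(device)
consumer = wp.Stream(device)

a = wp.zeros(1024, dtype=float, device=device)

wp.launch(fill, dim=a.shape, inputs=[a, 1.0], stream=producer)

# Commands issued to `consumer` after this call wait for the producer's
# previously issued commands (the launch above) to complete.
consumer.wait_stream(producer)
wp.launch(fill, dim=a.shape, inputs=[a, 2.0], stream=consumer)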
@@ -2178,14 +2402,14 @@ class Device:
             or ``"CPU"`` if the processor name cannot be determined.
         arch: An integer representing the compute capability version number calculated as
             ``10 * major + minor``. ``0`` for CPU devices.
-        is_uva: A boolean indicating whether
+        is_uva: A boolean indicating whether the device supports unified addressing.
             ``False`` for CPU devices.
-        is_cubin_supported: A boolean indicating whether
+        is_cubin_supported: A boolean indicating whether Warp's version of NVRTC can directly
             generate CUDA binary files (cubin) for this device's architecture. ``False`` for CPU devices.
-        is_mempool_supported: A boolean indicating whether
+        is_mempool_supported: A boolean indicating whether the device supports using the
             ``cuMemAllocAsync`` and ``cuMemPool`` family of APIs for stream-ordered memory allocations. ``False`` for
             CPU devices.
-        is_primary: A boolean indicating whether
+        is_primary: A boolean indicating whether this device's CUDA context is also the
             device's primary context.
         uuid: A string representing the UUID of the CUDA device. The UUID is in the same format used by
             ``nvidia-smi -L``. ``None`` for CPU devices.
@@ -2274,7 +2498,7 @@ class Device:
 
         # initialize streams unless context acquisition is postponed
         if self._context is not None:
-            self.init_streams()
+            self._init_streams()
 
         # TODO: add more device-specific dispatch functions
         self.memset = lambda ptr, value, size: runtime.core.memset_device(self.context, ptr, value, size)
@@ -2285,7 +2509,13 @@
         else:
             raise RuntimeError(f"Invalid device ordinal ({ordinal})'")
 
-    def get_allocator(self, pinned=False):
+    def get_allocator(self, pinned: bool = False):
+        """Get the memory allocator for this device.
+
+        Args:
+            pinned: If ``True``, an allocator for pinned memory will be
+                returned. Only applicable when this device is a CPU device.
+        """
         if self.is_cuda:
             return self.current_allocator
         else:
@@ -2294,7 +2524,8 @@
         else:
             return self.default_allocator
 
-    def init_streams(self):
+    def _init_streams(self):
+        """Initializes the device's current stream and the device's null stream."""
         # create a stream for asynchronous work
         self.set_stream(Stream(self))
 
@@ -2302,17 +2533,18 @@
         self.null_stream = Stream(self, cuda_stream=None)
 
     @property
-    def is_cpu(self):
-        """A boolean indicating whether
+    def is_cpu(self) -> bool:
+        """A boolean indicating whether the device is a CPU device."""
         return self.ordinal < 0
 
     @property
-    def is_cuda(self):
-        """A boolean indicating whether
+    def is_cuda(self) -> bool:
+        """A boolean indicating whether the device is a CUDA device."""
         return self.ordinal >= 0
 
     @property
-    def is_capturing(self):
+    def is_capturing(self) -> bool:
+        """A boolean indicating whether this device's default stream is currently capturing a graph."""
         if self.is_cuda and self.stream is not None:
             # There is no CUDA API to check if graph capture was started on a device, so we
             # can't tell if a capture was started by external code on a different stream.
@@ -2336,17 +2568,17 @@
             raise RuntimeError(f"Failed to acquire primary context for device {self}")
         self.runtime.context_map[self._context] = self
         # initialize streams
-        self.init_streams()
+        self._init_streams()
         runtime.core.cuda_context_set_current(prev_context)
         return self._context
 
     @property
-    def has_context(self):
-        """A boolean indicating whether
+    def has_context(self) -> bool:
+        """A boolean indicating whether the device has a CUDA context associated with it."""
         return self._context is not None
 
     @property
-    def stream(self):
+    def stream(self) -> Stream:
         """The stream associated with a CUDA device.
 
         Raises:
@@ -2361,7 +2593,22 @@
     def stream(self, stream):
         self.set_stream(stream)
 
-    def set_stream(self, stream, sync=True):
+    def set_stream(self, stream: Stream, sync: bool = True) -> None:
+        """Set the current stream for this CUDA device.
+
+        The current stream will be used by default for all kernel launches and
+        memory operations on this device.
+
+        If this is an external stream, the caller is responsible for
+        guaranteeing the lifetime of the stream.
+
+        Consider using :class:`warp.ScopedStream` instead.
+
+        Args:
+            stream: The stream to set as this device's current stream.
+            sync: If ``True``, then ``stream`` will perform a device-side
+                synchronization with the device's previous current stream.
+        """
         if self.is_cuda:
             if stream.device != self:
                 raise RuntimeError(f"Stream from device {stream.device} cannot be used on device {self}")
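Since `set_stream()` leaves lifetime management of external streams to the caller, the docstring recommends the scoped alternative. A sketch contrasting the two:

import warp as wp

wp.init()

device = wp.get_device("cuda:0")
s = wp.Stream(device)

# Replace the device's current stream; with sync=True the new stream first
# waits on work previously submitted to the old current stream.
device.set_stream(s, sync=True)

# Scoped alternative: the previous current stream is restored on exit.
with wp.ScopedStream(wp.Stream(device)):
    pass  # launches in this block target the scoped stream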
@@ -2372,12 +2619,12 @@
             raise RuntimeError(f"Device {self} is not a CUDA device")
 
     @property
-    def has_stream(self):
-        """A boolean indicating whether
+    def has_stream(self) -> bool:
+        """A boolean indicating whether the device has a stream associated with it."""
         return self._stream is not None
 
     @property
-    def total_memory(self):
+    def total_memory(self) -> int:
         """The total amount of device memory available in bytes.
 
         This function is currently only implemented for CUDA devices. 0 will be returned if called on a CPU device.
@@ -2391,7 +2638,7 @@
             return 0
 
     @property
-    def free_memory(self):
+    def free_memory(self) -> int:
         """The amount of memory on the device that is free according to the OS in bytes.
 
         This function is currently only implemented for CUDA devices. 0 will be returned if called on a CPU device.
@@ -2755,6 +3002,12 @@ class Runtime:
         self.core.mesh_refit_host.argtypes = [ctypes.c_uint64]
         self.core.mesh_refit_device.argtypes = [ctypes.c_uint64]
 
+        self.core.mesh_set_points_host.argtypes = [ctypes.c_uint64, warp.types.array_t]
+        self.core.mesh_set_points_device.argtypes = [ctypes.c_uint64, warp.types.array_t]
+
+        self.core.mesh_set_velocities_host.argtypes = [ctypes.c_uint64, warp.types.array_t]
+        self.core.mesh_set_velocities_device.argtypes = [ctypes.c_uint64, warp.types.array_t]
+
         self.core.hash_grid_create_host.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
         self.core.hash_grid_create_host.restype = ctypes.c_uint64
         self.core.hash_grid_destroy_host.argtypes = [ctypes.c_uint64]
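The new `mesh_set_points_*` and `mesh_set_velocities_*` entry points back in-place updates of a mesh's vertex data. A sketch of the Python-level usage, assuming the setters are surfaced through the `wp.Mesh.points` property (as in recent Warp releases) rather than called directly:

import numpy as np
import warp as wp

wp.init()

points = wp.array(np.random.rand(3, 3), dtype=wp.vec3, device="cuda:0")
indices = wp.array([0, 1, 2], dtype=wp.int32, device="cuda:0")
mesh = wp.Mesh(points=points, indices=indices)

# Assumption: assigning a new array routes through mesh_set_points_device
# and refits the BVH instead of rebuilding the mesh from scratch.
mesh.points = wp.array(np.random.rand(3, 3) + 1.0, dtype=wp.vec3, device="cuda:0")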
@@ -3029,7 +3282,7 @@ class Runtime:
         self.core.cuda_set_mempool_access_enabled.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
         self.core.cuda_set_mempool_access_enabled.restype = ctypes.c_int
 
-        self.core.cuda_stream_create.argtypes = [ctypes.c_void_p]
+        self.core.cuda_stream_create.argtypes = [ctypes.c_void_p, ctypes.c_int]
         self.core.cuda_stream_create.restype = ctypes.c_void_p
         self.core.cuda_stream_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
         self.core.cuda_stream_destroy.restype = None
@@ -3047,6 +3300,8 @@
         self.core.cuda_stream_is_capturing.restype = ctypes.c_int
         self.core.cuda_stream_get_capture_id.argtypes = [ctypes.c_void_p]
         self.core.cuda_stream_get_capture_id.restype = ctypes.c_uint64
+        self.core.cuda_stream_get_priority.argtypes = [ctypes.c_void_p]
+        self.core.cuda_stream_get_priority.restype = ctypes.c_int
 
         self.core.cuda_event_create.argtypes = [ctypes.c_void_p, ctypes.c_uint]
         self.core.cuda_event_create.restype = ctypes.c_void_p
@@ -3382,10 +3637,10 @@ class Runtime:
 
         raise ValueError(f"Invalid device identifier: {ident}")
 
-    def set_default_device(self, ident: Devicelike):
+    def set_default_device(self, ident: Devicelike) -> None:
         self.default_device = self.get_device(ident)
 
-    def get_current_cuda_device(self):
+    def get_current_cuda_device(self) -> Device:
         current_context = self.core.cuda_context_get_current()
         if current_context is not None:
             current_device = self.context_map.get(current_context)
@@ -3415,7 +3670,7 @@
         else:
             raise RuntimeError('"cuda" device requested but CUDA is not supported by the hardware or driver')
 
-    def rename_device(self, device, alias):
+    def rename_device(self, device, alias) -> Device:
         del self.device_map[device.alias]
         device.alias = alias
         self.device_map[alias] = device
@@ -3462,7 +3717,7 @@
 
         return device
 
-    def unmap_cuda_device(self, alias):
+    def unmap_cuda_device(self, alias) -> None:
         device = self.device_map.get(alias)
 
         # make sure the alias refers to a CUDA device
@@ -3473,7 +3728,7 @@
         del self.context_map[device.context]
         self.cuda_devices.remove(device)
 
-    def verify_cuda_device(self, device: Devicelike = None):
+    def verify_cuda_device(self, device: Devicelike = None) -> None:
         if warp.config.verify_cuda:
             device = runtime.get_device(device)
             if not device.is_cuda:
@@ -3485,13 +3740,13 @@
 
 
 # global entry points
-def is_cpu_available():
+def is_cpu_available() -> bool:
     init()
 
-    return runtime.llvm
+    return runtime.llvm is not None
 
 
-def is_cuda_available():
+def is_cuda_available() -> bool:
     return get_cuda_device_count() > 0
 
 
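`is_cpu_available()` now returns a proper bool (previously the raw `runtime.llvm` handle leaked out). A small device-selection sketch:

import warp as wp

wp.init()

device = "cuda:0" if wp.is_cuda_available() else "cpu"
a = wp.zeros(16, dtype=float, device=device)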
@@ -3575,8 +3830,8 @@ def get_device(ident: Devicelike = None) -> Device:
     return runtime.get_device(ident)
 
 
-def set_device(ident: Devicelike):
-    """Sets the
+def set_device(ident: Devicelike) -> None:
+    """Sets the default device identified by the argument."""
 
     init()
 
@@ -3604,7 +3859,7 @@ def map_cuda_device(alias: str, context: ctypes.c_void_p = None) -> Device:
     return runtime.map_cuda_device(alias, context)
 
 
-def unmap_cuda_device(alias: str):
+def unmap_cuda_device(alias: str) -> None:
     """Remove a CUDA device with the given alias."""
 
     init()
@@ -3612,7 +3867,7 @@ def unmap_cuda_device(alias: str):
     runtime.unmap_cuda_device(alias)
 
 
-def is_mempool_supported(device: Devicelike):
+def is_mempool_supported(device: Devicelike) -> bool:
     """Check if CUDA memory pool allocators are available on the device."""
 
     init()
@@ -3622,7 +3877,7 @@ def is_mempool_supported(device: Devicelike):
     return device.is_mempool_supported
 
 
-def is_mempool_enabled(device: Devicelike):
+def is_mempool_enabled(device: Devicelike) -> bool:
     """Check if CUDA memory pool allocators are enabled on the device."""
 
     init()
@@ -3632,7 +3887,7 @@ def is_mempool_enabled(device: Devicelike):
     return device.is_mempool_enabled
 
 
-def set_mempool_enabled(device: Devicelike, enable: bool):
+def set_mempool_enabled(device: Devicelike, enable: bool) -> None:
     """Enable or disable CUDA memory pool allocators on the device.
 
     Pooled allocators are typically faster and allow allocating memory during graph capture.
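A sketch of guarding `set_mempool_enabled()` behind the support check, following the docstrings above:

import warp as wp

wp.init()

if wp.is_mempool_supported("cuda:0"):
    # Pooled (stream-ordered) allocators are typically faster and allow
    # allocating memory during graph capture.
    wp.set_mempool_enabled("cuda:0", True)
    assert wp.is_mempool_enabled("cuda:0")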
@@ -3663,7 +3918,7 @@ def set_mempool_enabled(device: Devicelike, enable: bool):
         raise ValueError("Memory pools are only supported on CUDA devices")
 
 
-def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, float]):
+def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, float]) -> None:
     """Set the CUDA memory pool release threshold on the device.
 
     This is the amount of reserved memory to hold onto before trying to release memory back to the OS.
@@ -3694,7 +3949,7 @@ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, float]):
         raise RuntimeError(f"Failed to set memory pool release threshold for device {device}")
 
 
-def get_mempool_release_threshold(device: Devicelike):
+def get_mempool_release_threshold(device: Devicelike) -> int:
     """Get the CUDA memory pool release threshold on the device."""
 
     init()
@@ -3710,7 +3965,7 @@ def get_mempool_release_threshold(device: Devicelike):
     return runtime.core.cuda_device_get_mempool_release_threshold(device.ordinal)
 
 
-def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike):
+def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike) -> bool:
     """Check if `peer_device` can directly access the memory of `target_device` on this system.
 
     This applies to memory allocated using default CUDA allocators. For memory allocated using
@@ -3731,7 +3986,7 @@ def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike):
     return bool(runtime.core.cuda_is_peer_access_supported(target_device.ordinal, peer_device.ordinal))
 
 
-def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike):
+def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike) -> bool:
     """Check if `peer_device` can currently access the memory of `target_device`.
 
     This applies to memory allocated using default CUDA allocators. For memory allocated using
@@ -3752,7 +4007,7 @@ def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike):
     return bool(runtime.core.cuda_is_peer_access_enabled(target_device.context, peer_device.context))
 
 
-def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool):
+def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool) -> None:
     """Enable or disable direct access from `peer_device` to the memory of `target_device`.
 
     Enabling peer access can improve the speed of peer-to-peer memory transfers, but can have
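A sketch of the check-then-enable pattern these peer-access functions document, for a system with at least two CUDA devices:

import warp as wp

wp.init()

if wp.get_cuda_device_count() > 1 and wp.is_peer_access_supported("cuda:0", "cuda:1"):
    # Allow cuda:1 to access memory allocated on cuda:0 by the default allocators.
    wp.set_peer_access_enabled("cuda:0", "cuda:1", True)
    assert wp.is_peer_access_enabled("cuda:0", "cuda:1")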
@@ -3784,7 +4039,7 @@ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool):
         raise RuntimeError(f"Failed to {action} peer access from device {peer_device} to device {target_device}")
 
 
-def is_mempool_access_supported(target_device: Devicelike, peer_device: Devicelike):
+def is_mempool_access_supported(target_device: Devicelike, peer_device: Devicelike) -> bool:
     """Check if `peer_device` can directly access the memory pool of `target_device`.
 
     If mempool access is possible, it can be managed using `set_mempool_access_enabled()` and `is_mempool_access_enabled()`.
@@ -3801,7 +4056,7 @@ def is_mempool_access_supported(target_device: Devicelike, peer_device: Devicelike):
     return target_device.is_mempool_supported and is_peer_access_supported(target_device, peer_device)
 
 
-def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike):
+def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike) -> bool:
     """Check if `peer_device` can currently access the memory pool of `target_device`.
 
     This applies to memory allocated using CUDA pooled allocators. For memory allocated using
@@ -3822,7 +4077,7 @@ def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike):
     return bool(runtime.core.cuda_is_mempool_access_enabled(target_device.ordinal, peer_device.ordinal))
 
 
-def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool):
+def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool) -> None:
     """Enable or disable access from `peer_device` to the memory pool of `target_device`.
 
     This applies to memory allocated using CUDA pooled allocators. For memory allocated using
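The pooled-allocator analogue of peer access follows the same shape; a sketch mirroring the previous example for memory allocated from CUDA memory pools:

import warp as wp

wp.init()

if wp.get_cuda_device_count() > 1 and wp.is_mempool_access_supported("cuda:0", "cuda:1"):
    wp.set_mempool_access_enabled("cuda:0", "cuda:1", True)
    assert wp.is_mempool_access_enabled("cuda:0", "cuda:1")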
@@ -3858,26 +4113,41 @@ def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike, enable: bool):
 
 
 def get_stream(device: Devicelike = None) -> Stream:
-    """Return the stream currently used by the given device
+    """Return the stream currently used by the given device.
+
+    Args:
+        device: An optional :class:`Device` instance or device alias
+            (e.g. "cuda:0") for which the current stream will be returned.
+            If ``None``, the default device will be used.
+
+    Raises:
+        RuntimeError: The device is not a CUDA device.
+    """
 
     return get_device(device).stream
 
 
-def set_stream(stream, device: Devicelike = None, sync: bool = False):
-    """
+def set_stream(stream: Stream, device: Devicelike = None, sync: bool = False) -> None:
+    """Convenience function for calling :meth:`Device.set_stream` on the given ``device``.
 
-
-
+    Args:
+        device: An optional :class:`Device` instance or device alias
+            (e.g. "cuda:0") for which the current stream is to be replaced with
+            ``stream``. If ``None``, the default device will be used.
+        stream: The stream to set as this device's current stream.
+        sync: If ``True``, then ``stream`` will perform a device-side
+            synchronization with the device's previous current stream.
     """
 
     get_device(device).set_stream(stream, sync=sync)
 
 
-def record_event(event: Event = None):
-    """
+def record_event(event: Optional[Event] = None):
+    """Convenience function for calling :meth:`Stream.record_event` on the current stream.
 
     Args:
-        event: Event to record. If None
+        event: :class:`Event` instance to record. If ``None``, a new :class:`Event`
+            instance will be created.
 
     Returns:
         The recorded event.
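A sketch of the module-level convenience wrappers documented above:

import warp as wp

wp.init()

s = wp.Stream("cuda:0")

# Equivalent to wp.get_device("cuda:0").set_stream(s, sync=True).
wp.set_stream(s, device="cuda:0", sync=True)
assert wp.get_stream("cuda:0") is s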
@@ -3887,29 +4157,31 @@ def record_event(event: Event = None):
 
 
 def wait_event(event: Event):
-    """
+    """Convenience function for calling :meth:`Stream.wait_event` on the current stream.
 
     Args:
-        event: Event to wait for.
+        event: :class:`Event` instance to wait for.
     """
 
     get_stream().wait_event(event)
 
 
-def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: bool = True):
+def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: Optional[bool] = True):
     """Get the elapsed time between two recorded events.
 
-
-
-    Both events must have been previously recorded with ``wp.record_event()`` or ``wp.Stream.record_event()``.
+    Both events must have been previously recorded with
+    :func:`~warp.record_event()` or :meth:`warp.Stream.record_event()`.
 
     If ``synchronize`` is False, the caller must ensure that device execution has reached ``end_event``
     prior to calling ``get_event_elapsed_time()``.
 
     Args:
-        start_event
-        end_event
-        synchronize
+        start_event: The start event.
+        end_event: The end event.
+        synchronize: Whether Warp should synchronize on the ``end_event``.
+
+    Returns:
+        The elapsed time in milliseconds with a resolution about 0.5 ms.
     """
 
     # ensure the end_event is reached
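`get_event_elapsed_time()` requires events created with timing enabled; the `enable_timing` flag is visible in the relocated `Event.__init__` earlier in this diff. A timing sketch:

import warp as wp

wp.init()

device = wp.get_device("cuda:0")
start = wp.Event(device, enable_timing=True)
end = wp.Event(device, enable_timing=True)

stream = wp.get_stream(device)
stream.record_event(start)
a = wp.zeros(1 << 20, dtype=float, device=device)  # some GPU work to time
stream.record_event(end)

# With synchronize=True (the default) Warp blocks until end_event is reached.
ms = wp.get_event_elapsed_time(start, end)
print(f"{ms:.2f} ms")  # resolution is roughly 0.5 ms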
@@ -3919,14 +4191,19 @@ def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: bool = True):
     return runtime.core.cuda_event_elapsed_time(start_event.cuda_event, end_event.cuda_event)
 
 
-def wait_stream(
-    """
+def wait_stream(other_stream: Stream, event: Event = None):
+    """Convenience function for calling :meth:`Stream.wait_stream` on the current stream.
 
     Args:
-
+        other_stream: The stream on which the calling stream will wait for
+            previously issued commands to complete before executing subsequent
+            commands.
+        event: An optional :class:`Event` instance that will be used to
+            record an event onto ``other_stream``. If ``None``, an internally
+            managed :class:`Event` instance will be used.
     """
 
-    get_stream().wait_stream(
+    get_stream().wait_stream(other_stream, event=event)
 
 
 class RegisteredGLBuffer:
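The reconstructed signature shows `wp.wait_stream()` is a thin wrapper over `Stream.wait_stream` on the current stream. A short sketch:

import warp as wp

wp.init()

side = wp.Stream("cuda:0")
# ... enqueue work on `side` ...

# Make the default device's current stream wait for work submitted to `side`.
wp.wait_stream(side)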
@@ -4362,7 +4639,7 @@ def from_numpy(
         dtype: The data type of the new Warp array. If this is not provided, the data type will be inferred.
         shape: The shape of the Warp array.
         device: The device on which the Warp array will be constructed.
-        requires_grad: Whether
+        requires_grad: Whether gradients will be tracked for this array.
 
     Raises:
         RuntimeError: The data type of the NumPy array is not supported.
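A sketch of `from_numpy()` with the clarified `requires_grad` argument:

import numpy as np
import warp as wp

wp.init()

data = np.linspace(0.0, 1.0, 16, dtype=np.float32)

# requires_grad=True allocates a matching gradient array so the tape can
# track this array during backward passes.
a = wp.from_numpy(data, dtype=wp.float32, device="cuda:0", requires_grad=True)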
@@ -4398,7 +4675,7 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
         return arg_type.__ctype__()
 
     elif isinstance(value, warp.types.array_t):
-        # accept array descriptors
+        # accept array descriptors verbatim
         return value
 
     else:
@@ -4865,7 +5142,7 @@ def synchronize_device(device: Devicelike = None):
     runtime.core.cuda_context_synchronize(device.context)
 
 
-def synchronize_stream(stream_or_device=None):
+def synchronize_stream(stream_or_device: Union[Stream, Devicelike, None] = None):
     """Synchronize the calling CPU thread with any outstanding CUDA work on the specified stream.
 
     This function allows the host application code to ensure that all kernel launches
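`synchronize_stream()` now documents that it accepts either a stream or a device-like argument. A sketch:

import warp as wp

wp.init()

stream = wp.Stream("cuda:0")
with wp.ScopedStream(stream):
    a = wp.zeros(1024, dtype=float)

# Block the host until all work submitted to `stream` has completed; passing
# "cuda:0" instead would synchronize that device's current stream.
wp.synchronize_stream(stream)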
@@ -4989,7 +5266,7 @@ def set_module_options(options: Dict[str, Any], module: Optional[Any] = None):
         m = module
 
     get_module(m.__name__).options.update(options)
-    get_module(m.__name__).
+    get_module(m.__name__).mark_modified()
 
 
 def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
@@ -5016,7 +5293,7 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=None
     Args:
         device: The CUDA device to capture on
         stream: The CUDA stream to capture on
-        force_module_load: Whether
+        force_module_load: Whether to force loading of all kernels before capture.
            In general it is better to use :func:`~warp.load_module()` to selectively load kernels.
            When running with CUDA drivers that support CUDA 12.3 or newer, this option is not recommended to be set to
            ``True`` because kernels can be loaded during graph capture on more recent drivers. If this argument is
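Per the updated `force_module_load` note, explicitly preloading modules is preferred over forcing a load at capture time. A capture sketch; the kernel is illustrative only:

import warp as wp

wp.init()

@wp.kernel
def scale(a: wp.array(dtype=float), s: float):
    tid = wp.tid()
    a[tid] = a[tid] * s

a = wp.ones(1024, dtype=float, device="cuda:0")

wp.load_module(device="cuda:0")  # selectively load kernels before capture

wp.capture_begin(device="cuda:0")
try:
    wp.launch(scale, dim=a.shape, inputs=[a, 2.0], device="cuda:0")
finally:
    graph = wp.capture_end(device="cuda:0")

wp.capture_launch(graph)  # replay the captured work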
@@ -5574,7 +5851,7 @@ def export_stubs(file): # pragma: no cover
 
         return_str = ""
 
-        if
+        if f.hidden: # or f.generic:
             continue
 
         return_type = f.value_func(None, None)