warp_lang-1.2.2-py3-none-manylinux2014_aarch64.whl → warp_lang-1.3.1-py3-none-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of warp-lang has been flagged by the registry; see the package listing for details.
- warp/__init__.py +8 -6
- warp/autograd.py +823 -0
- warp/bin/warp.so +0 -0
- warp/build.py +6 -2
- warp/builtins.py +1412 -888
- warp/codegen.py +503 -166
- warp/config.py +48 -18
- warp/context.py +400 -198
- warp/dlpack.py +8 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
- warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
- warp/examples/benchmarks/benchmark_launches.py +1 -1
- warp/examples/core/example_cupy.py +78 -0
- warp/examples/fem/example_apic_fluid.py +17 -36
- warp/examples/fem/example_burgers.py +9 -18
- warp/examples/fem/example_convection_diffusion.py +7 -17
- warp/examples/fem/example_convection_diffusion_dg.py +27 -47
- warp/examples/fem/example_deformed_geometry.py +11 -22
- warp/examples/fem/example_diffusion.py +7 -18
- warp/examples/fem/example_diffusion_3d.py +24 -28
- warp/examples/fem/example_diffusion_mgpu.py +7 -14
- warp/examples/fem/example_magnetostatics.py +190 -0
- warp/examples/fem/example_mixed_elasticity.py +111 -80
- warp/examples/fem/example_navier_stokes.py +30 -34
- warp/examples/fem/example_nonconforming_contact.py +290 -0
- warp/examples/fem/example_stokes.py +17 -32
- warp/examples/fem/example_stokes_transfer.py +12 -21
- warp/examples/fem/example_streamlines.py +350 -0
- warp/examples/fem/utils.py +936 -0
- warp/fabric.py +5 -2
- warp/fem/__init__.py +13 -3
- warp/fem/cache.py +161 -11
- warp/fem/dirichlet.py +37 -28
- warp/fem/domain.py +105 -14
- warp/fem/field/__init__.py +14 -3
- warp/fem/field/field.py +454 -11
- warp/fem/field/nodal_field.py +33 -18
- warp/fem/geometry/deformed_geometry.py +50 -15
- warp/fem/geometry/hexmesh.py +12 -24
- warp/fem/geometry/nanogrid.py +106 -31
- warp/fem/geometry/quadmesh_2d.py +6 -11
- warp/fem/geometry/tetmesh.py +103 -61
- warp/fem/geometry/trimesh_2d.py +98 -47
- warp/fem/integrate.py +231 -186
- warp/fem/operator.py +14 -9
- warp/fem/quadrature/pic_quadrature.py +35 -9
- warp/fem/quadrature/quadrature.py +119 -32
- warp/fem/space/basis_space.py +98 -22
- warp/fem/space/collocated_function_space.py +3 -1
- warp/fem/space/function_space.py +7 -2
- warp/fem/space/grid_2d_function_space.py +3 -3
- warp/fem/space/grid_3d_function_space.py +4 -4
- warp/fem/space/hexmesh_function_space.py +3 -2
- warp/fem/space/nanogrid_function_space.py +12 -14
- warp/fem/space/partition.py +45 -47
- warp/fem/space/restriction.py +19 -16
- warp/fem/space/shape/cube_shape_function.py +91 -3
- warp/fem/space/shape/shape_function.py +7 -0
- warp/fem/space/shape/square_shape_function.py +32 -0
- warp/fem/space/shape/tet_shape_function.py +11 -7
- warp/fem/space/shape/triangle_shape_function.py +10 -1
- warp/fem/space/topology.py +116 -42
- warp/fem/types.py +8 -1
- warp/fem/utils.py +301 -83
- warp/native/array.h +16 -0
- warp/native/builtin.h +0 -15
- warp/native/cuda_util.cpp +14 -6
- warp/native/exports.h +1348 -1308
- warp/native/quat.h +79 -0
- warp/native/rand.h +27 -4
- warp/native/sparse.cpp +83 -81
- warp/native/sparse.cu +381 -453
- warp/native/vec.h +64 -0
- warp/native/volume.cpp +40 -49
- warp/native/volume_builder.cu +2 -3
- warp/native/volume_builder.h +12 -17
- warp/native/warp.cu +3 -3
- warp/native/warp.h +69 -59
- warp/render/render_opengl.py +17 -9
- warp/sim/articulation.py +117 -17
- warp/sim/collide.py +35 -29
- warp/sim/model.py +123 -18
- warp/sim/render.py +3 -1
- warp/sparse.py +867 -203
- warp/stubs.py +312 -541
- warp/tape.py +29 -1
- warp/tests/disabled_kinematics.py +1 -1
- warp/tests/test_adam.py +1 -1
- warp/tests/test_arithmetic.py +1 -1
- warp/tests/test_array.py +58 -1
- warp/tests/test_array_reduce.py +1 -1
- warp/tests/test_async.py +1 -1
- warp/tests/test_atomic.py +1 -1
- warp/tests/test_bool.py +1 -1
- warp/tests/test_builtins_resolution.py +1 -1
- warp/tests/test_bvh.py +6 -1
- warp/tests/test_closest_point_edge_edge.py +1 -1
- warp/tests/test_codegen.py +91 -1
- warp/tests/test_compile_consts.py +1 -1
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_copy.py +1 -1
- warp/tests/test_ctypes.py +1 -1
- warp/tests/test_dense.py +1 -1
- warp/tests/test_devices.py +1 -1
- warp/tests/test_dlpack.py +1 -1
- warp/tests/test_examples.py +33 -4
- warp/tests/test_fabricarray.py +5 -2
- warp/tests/test_fast_math.py +1 -1
- warp/tests/test_fem.py +213 -6
- warp/tests/test_fp16.py +1 -1
- warp/tests/test_func.py +1 -1
- warp/tests/test_future_annotations.py +90 -0
- warp/tests/test_generics.py +1 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_grad_customs.py +1 -1
- warp/tests/test_grad_debug.py +247 -0
- warp/tests/test_hash_grid.py +6 -1
- warp/tests/test_implicit_init.py +354 -0
- warp/tests/test_import.py +1 -1
- warp/tests/test_indexedarray.py +1 -1
- warp/tests/test_intersect.py +1 -1
- warp/tests/test_jax.py +1 -1
- warp/tests/test_large.py +1 -1
- warp/tests/test_launch.py +1 -1
- warp/tests/test_lerp.py +1 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_lvalue.py +1 -1
- warp/tests/test_marching_cubes.py +5 -2
- warp/tests/test_mat.py +34 -35
- warp/tests/test_mat_lite.py +2 -1
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_math.py +1 -1
- warp/tests/test_matmul.py +20 -16
- warp/tests/test_matmul_lite.py +1 -1
- warp/tests/test_mempool.py +1 -1
- warp/tests/test_mesh.py +5 -2
- warp/tests/test_mesh_query_aabb.py +1 -1
- warp/tests/test_mesh_query_point.py +1 -1
- warp/tests/test_mesh_query_ray.py +1 -1
- warp/tests/test_mlp.py +1 -1
- warp/tests/test_model.py +1 -1
- warp/tests/test_module_hashing.py +77 -1
- warp/tests/test_modules_lite.py +1 -1
- warp/tests/test_multigpu.py +1 -1
- warp/tests/test_noise.py +1 -1
- warp/tests/test_operators.py +1 -1
- warp/tests/test_options.py +1 -1
- warp/tests/test_overwrite.py +542 -0
- warp/tests/test_peer.py +1 -1
- warp/tests/test_pinned.py +1 -1
- warp/tests/test_print.py +1 -1
- warp/tests/test_quat.py +15 -1
- warp/tests/test_rand.py +1 -1
- warp/tests/test_reload.py +1 -1
- warp/tests/test_rounding.py +1 -1
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +95 -0
- warp/tests/test_sim_grad.py +1 -1
- warp/tests/test_sim_kinematics.py +1 -1
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +82 -15
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_special_values.py +2 -11
- warp/tests/test_streams.py +11 -1
- warp/tests/test_struct.py +1 -1
- warp/tests/test_tape.py +1 -1
- warp/tests/test_torch.py +194 -1
- warp/tests/test_transient_module.py +1 -1
- warp/tests/test_types.py +1 -1
- warp/tests/test_utils.py +1 -1
- warp/tests/test_vec.py +15 -63
- warp/tests/test_vec_lite.py +2 -1
- warp/tests/test_vec_scalar_ops.py +65 -1
- warp/tests/test_verify_fp.py +1 -1
- warp/tests/test_volume.py +28 -2
- warp/tests/test_volume_write.py +1 -1
- warp/tests/unittest_serial.py +1 -1
- warp/tests/unittest_suites.py +9 -1
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +2 -5
- warp/torch.py +103 -41
- warp/types.py +341 -224
- warp/utils.py +11 -2
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/METADATA +99 -46
- warp_lang-1.3.1.dist-info/RECORD +368 -0
- warp/examples/fem/bsr_utils.py +0 -378
- warp/examples/fem/mesh_utils.py +0 -133
- warp/examples/fem/plot_utils.py +0 -292
- warp_lang-1.2.2.dist-info/RECORD +0 -359
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/top_level.txt +0 -0
warp/context.py
CHANGED
@@ -18,6 +18,7 @@ import os
 import platform
 import sys
 import types
+import typing
 from copy import copy as shallowcopy
 from pathlib import Path
 from struct import pack as struct_pack

@@ -34,7 +35,7 @@ import warp.config


 def create_value_func(type):
-    def value_func(
+    def value_func(arg_types, arg_values):
         return type

     return value_func

@@ -42,7 +43,7 @@ def create_value_func(type):


 def get_function_args(func):
     """Ensures that all function arguments are annotated and returns a dictionary mapping from argument name to its type."""
-    argspec =
+    argspec = warp.codegen.get_full_arg_spec(func)

     # use source-level argument annotations
     if len(argspec.annotations) < len(argspec.args):

@@ -63,7 +64,8 @@ class Function:
         input_types=None,
         value_type=None,
         value_func=None,
-
+        export_func=None,
+        dispatch_func=None,
         module=None,
         variadic=False,
         initializer_list_func=None,

@@ -97,14 +99,15 @@ class Function:
         self.namespace = namespace
         self.value_type = value_type
         self.value_func = value_func  # a function that takes a list of args and a list of templates and returns the value type, e.g.: load(array, index) returns the type of value being loaded
-        self.
+        self.export_func = export_func
+        self.dispatch_func = dispatch_func
         self.input_types = {}
         self.export = export
         self.doc = doc
         self.group = group
         self.module = module
         self.variadic = variadic  # function can take arbitrary number of inputs, e.g.: printf()
-        self.defaults = defaults
+        self.defaults = {} if defaults is None else defaults
         # Function instance for a custom implementation of the replay pass
         self.custom_replay_func = custom_replay_func
         self.native_snippet = native_snippet
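Note: the `defaults` change above swaps a raw assignment for the standard guard against Python's shared-mutable-default pitfall. A minimal standalone sketch of the idiom (the "device" key is illustrative only):

```python
# Sharing one dict across calls/instances is avoided by creating the
# fallback inside the function; mutating the result is then safe per-call.
def make_defaults(defaults=None):
    defaults = {} if defaults is None else defaults
    defaults.setdefault("device", "cpu")  # hypothetical key, for illustration
    return defaults

assert make_defaults() is not make_defaults()  # each call gets a fresh dict
```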
@@ -180,6 +183,33 @@ class Function:
         if not skip_adding_overload:
             self.add_overload(self)

+        # Store a description of the function's signature that can be used
+        # to resolve a bunch of positional/keyword/variadic arguments against,
+        # in a way that is compatible with Python's semantics.
+        signature_params = []
+        signature_default_param_kind = inspect.Parameter.POSITIONAL_OR_KEYWORD
+        for param_name in self.input_types.keys():
+            if param_name.startswith("**"):
+                param_name = param_name[2:]
+                param_kind = inspect.Parameter.VAR_KEYWORD
+            elif param_name.startswith("*"):
+                param_name = param_name[1:]
+                param_kind = inspect.Parameter.VAR_POSITIONAL
+
+                # Once a variadic argument like `*args` is found, any following
+                # arguments need to be passed using keywords.
+                signature_default_param_kind = inspect.Parameter.KEYWORD_ONLY
+            else:
+                param_kind = signature_default_param_kind
+
+            param = inspect.Parameter(
+                param_name,
+                param_kind,
+                default=self.defaults.get(param_name, inspect.Parameter.empty),
+            )
+            signature_params.append(param)
+        self.signature = inspect.Signature(signature_params)
+
         # add to current module
         if module:
             module.register_function(self, skip_adding_overload)
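Note: the `inspect.Signature` built above is what lets Warp resolve positional, keyword, and variadic call arguments with ordinary Python semantics. A self-contained sketch of the same machinery, using made-up parameter names:

```python
import inspect

# Parameters mirroring the construction above: a "*"-prefixed name becomes
# VAR_POSITIONAL, and every parameter after it must be keyword-only.
params = [
    inspect.Parameter("shape", inspect.Parameter.POSITIONAL_OR_KEYWORD),
    inspect.Parameter("args", inspect.Parameter.VAR_POSITIONAL),
    inspect.Parameter("device", inspect.Parameter.KEYWORD_ONLY, default="cpu"),
]
sig = inspect.Signature(params)

# bind() applies Python's own call semantics, raising TypeError on mismatch.
bound = sig.bind((2, 3), 1.0, 2.0, device="cuda:0")
bound.apply_defaults()
print(bound.arguments)  # {'shape': (2, 3), 'args': (1.0, 2.0), 'device': 'cuda:0'}
```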
@@ -247,7 +277,7 @@ class Function:

         # only export simple types that don't use arrays
         for v in self.input_types.values():
-            if
+            if warp.types.is_array(v) or v in complex_type_hints:
                 return False

         if type(self.value_type) in sequence_types:

@@ -261,8 +291,14 @@ class Function:

         name = "builtin_" + self.key

+        # Runtime arguments that are to be passed to the function, not its template signature.
+        if self.export_func is not None:
+            func_args = self.export_func(self.input_types)
+        else:
+            func_args = self.input_types
+
         types = []
-        for t in
+        for t in func_args.values():
             types.append(t.__name__)

         return "_".join([name, *types])

@@ -299,7 +335,7 @@ class Function:
         )
         self.user_overloads[sig] = f

-    def get_overload(self, arg_types):
+    def get_overload(self, arg_types, kwarg_types):
         assert not self.is_builtin()

         sig = warp.types.get_signature(arg_types, func_name=self.key)
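Note: the mangled name computed above is how each exported overload is looked up in the native library: the builtin's key followed by one token per runtime argument type. A sketch of the scheme with illustrative inputs (the "min" example is not taken from Warp's registry):

```python
# Mirrors the join performed above.
def mangled_name(key, func_args):
    name = "builtin_" + key
    types = [t.__name__ for t in func_args.values()]
    return "_".join([name, *types])

print(mangled_name("min", {"x": float, "y": float}))  # builtin_min_float_float
```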
@@ -347,15 +383,21 @@ class Function:
 def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
     uses_non_warp_array_type = False

-
+    init()

     # Retrieve the built-in function from Warp's dll.
     c_func = getattr(warp.context.runtime.core, func.mangled_name)

+    # Runtime arguments that are to be passed to the function, not its template signature.
+    if func.export_func is not None:
+        func_args = func.export_func(func.input_types)
+    else:
+        func_args = func.input_types
+
     # Try gathering the parameters that the function expects and pack them
     # into their corresponding C types.
     c_params = []
-    for i, (_, arg_type) in enumerate(
+    for i, (_, arg_type) in enumerate(func_args.items()):
         param = params[i]

         try:

@@ -485,7 +527,8 @@ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
             c_params.append(arg_type._type_(param))

     # returns the corresponding ctype for a scalar or vector warp type
-    value_type = func.value_func(None, None
+    value_type = func.value_func(None, None)
+
     if value_type == float:
         value_ctype = ctypes.c_float
     elif value_type == int:

@@ -521,10 +564,12 @@ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
         return (True, ret)

     if value_type == warp.types.float16:
-
+        value = warp.types.half_bits_to_float(ret.value)
+    else:
+        value = ret.value

     # return scalar types as int/float
-    return (True,
+    return (True, value)


 class KernelHooks:
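Note: the C side returns `float16` results as raw 16-bit patterns, which is why the value is routed through a bits-to-float conversion above. A standalone sketch of that conversion, assuming `warp.types.half_bits_to_float` plays this role in the real code:

```python
import numpy as np

def half_bits_to_float(bits: int) -> float:
    # Reinterpret a 16-bit pattern as an IEEE-754 half-precision value.
    return float(np.frombuffer(np.uint16(bits).tobytes(), dtype=np.float16)[0])

print(half_bits_to_float(0x3C00))  # 1.0
print(half_bits_to_float(0xC000))  # -2.0
```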
@@ -742,7 +787,6 @@ def func_grad(forward_fn):
         input_types=reverse_args,
         value_func=None,
         module=f.module,
-        template_func=f.template_func,
         skip_forward_codegen=True,
         custom_reverse_mode=True,
         custom_reverse_num_input_args=len(f.input_types),

@@ -807,7 +851,7 @@ def func_replay(forward_fn):
             f"Cannot define custom replay definition for {forward_fn.key} since the provided replay function has generic input arguments."
         )

-    f = forward_fn.get_overload(arg_types)
+    f = forward_fn.get_overload(arg_types, {})
     if f is None:
         inputs_str = ", ".join([f"{k}: {v.__name__}" for k, v in args.items()])
         raise RuntimeError(

@@ -819,8 +863,9 @@ def func_replay(forward_fn):
         namespace=f.namespace,
         input_types=f.input_types,
         value_func=f.value_func,
+        export_func=f.export_func,
+        dispatch_func=f.dispatch_func,
         module=f.module,
-        template_func=f.template_func,
         skip_reverse_codegen=True,
         skip_adding_overload=True,
         code_transformers=f.adj.transformers,

@@ -920,7 +965,7 @@ def overload(kernel, arg_types=None):
     )

     # ensure all arguments are annotated
-    argspec =
+    argspec = warp.codegen.get_full_arg_spec(fn)
     if len(argspec.annotations) < len(argspec.args):
         raise RuntimeError(f"Incomplete argument annotations on kernel overload {fn.__name__}")
@@ -965,7 +1010,8 @@ def add_builtin(
     constraint=None,
     value_type=None,
     value_func=None,
-
+    export_func=None,
+    dispatch_func=None,
     doc="",
     namespace="wp::",
     variadic=False,

@@ -979,18 +1025,66 @@ def add_builtin(
     defaults=None,
     require_original_output_arg=False,
 ):
+    """Main entry point to register a new built-in function.
+
+    Args:
+        key (str): Function name. Multiple overloaded functions can be registered
+            under the same name as long as their signatures differ.
+        input_types (Mapping[str, Any]): Signature of the user-facing function.
+            Variadic arguments are supported by prefixing the parameter names
+            with asterisks as in `*args` and `**kwargs`. Generic arguments are
+            supported with types such as `Any`, `Float`, `Scalar`, etc.
+        constraint (Callable): For functions that define generic arguments and
+            are to be exported, this callback is used to specify whether some
+            combination of inferred arguments is valid or not.
+        value_type (Any): Type returned by the function.
+        value_func (Callable): Callback used to specify the return type when
+            `value_type` isn't enough.
+        export_func (Callable): Callback used during the context stage to specify
+            the signature of the underlying C++ function, not accounting for
+            the template parameters.
+            If not provided, `input_types` is used.
+        dispatch_func (Callable): Callback used during the codegen stage to specify
+            the runtime and template arguments to be passed to the underlying C++
+            function. In other words, this allows defining a mapping between
+            the signatures of the user-facing and the C++ functions, and even to
+            dynamically create new arguments on the fly.
+            The arguments returned must be of type `codegen.Var`.
+            If not provided, all arguments passed by the users when calling
+            the built-in are passed as-is as runtime arguments to the C++ function.
+        doc (str): Used to generate the Python docstring and the HTML documentation.
+        namespace: Namespace for the underlying C++ function.
+        variadic (bool): Whether the function declares variadic arguments.
+        initializer_list_func (bool): Whether to use the initializer list syntax
+            when passing the arguments to the underlying C++ function.
+        export (bool): Whether the function is to be exposed to the Python
+            interpreter so that it becomes available from within the `warp`
+            module.
+        group (str): Classification used for the documentation.
+        hidden (bool): Whether to add that function into the documentation.
+        skip_replay (bool): Whether the operation will be performed during
+            the forward replay in the backward pass.
+        missing_grad (bool): Whether the function is missing a corresponding
+            adjoint.
+        native_func (str): Name of the underlying C++ function.
+        defaults (Mapping[str, Any]): Default values for the parameters defined
+            in `input_types`.
+        require_original_output_arg (bool): Used during the codegen stage to
+            specify whether an adjoint parameter corresponding to the return
+            value should be included in the signature of the backward function.
+    """
     if input_types is None:
         input_types = {}

     # wrap simple single-type functions with a value_func()
     if value_func is None:

-        def value_func(
+        def value_func(arg_types, arg_values):
             return value_type

     if initializer_list_func is None:

-        def initializer_list_func(args,
+        def initializer_list_func(args, return_type):
             return False

     if defaults is None:
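The docstring above describes `export_func` and `dispatch_func` as a mapping between the user-facing signature and the C++ one. Hypothetical callbacks illustrating the expected shapes (the parameter names and the dropped-argument example are invented for illustration, not taken from Warp's registry):

```python
def my_export_func(input_types):
    # Drop a template-only parameter ("dtype") from the exported C++ signature.
    return {k: v for k, v in input_types.items() if k != "dtype"}

def my_value_func(arg_types, arg_values):
    # Resolve the return type from the argument types when a fixed
    # `value_type` isn't expressive enough; None means "unresolved call".
    if arg_types is None:
        return float
    return arg_types["x"]
```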
@@ -998,8 +1092,13 @@ def add_builtin(

     # Add specialized versions of this builtin if it's generic by matching arguments against
     # hard coded types. We do this so you can use hard coded warp types outside kernels:
+    if export_func is not None:
+        func_arg_types = export_func(input_types)
+    else:
+        func_arg_types = input_types
+
     generic = False
-    for x in
+    for x in func_arg_types.values():
         if warp.types.type_is_generic(x):
             generic = True
             break

@@ -1007,7 +1106,7 @@ def add_builtin(
     if generic and export:
         # collect the parent type names of all the generic arguments:
         genericset = set()
-        for t in
+        for t in func_arg_types.values():
            if hasattr(t, "_wp_generic_type_hint_"):
                genericset.add(t._wp_generic_type_hint_)
            elif warp.types.type_is_generic_scalar(t):

@@ -1059,15 +1158,17 @@ def add_builtin(

             typelists.append(l)

-        for
+        for arg_types in itertools.product(*typelists):
+            arg_types = dict(zip(input_types.keys(), arg_types))
+
             # Some of these argument lists won't work, eg if the function is mul(), we won't be
             # able to do a matrix vector multiplication for a mat22 and a vec3. The `constraint`
             # function determines which combinations are valid:
             if constraint:
-                if constraint(
+                if constraint(arg_types) is False:
                     continue

-            return_type = value_func(
+            return_type = value_func(arg_types, None)

             # The return_type might just be vector_t(length=3,dtype=wp.float32), so we've got to match that
             # in the list of hard coded types so it knows it's returning one of them:

@@ -1085,8 +1186,10 @@ def add_builtin(
             # finally we can generate a function call for these concrete types:
             add_builtin(
                 key,
-                input_types=
+                input_types=arg_types,
                 value_type=return_type,
+                export_func=export_func,
+                dispatch_func=dispatch_func,
                 doc=doc,
                 namespace=namespace,
                 variadic=variadic,

@@ -1096,6 +1199,7 @@ def add_builtin(
                 hidden=True,
                 skip_replay=skip_replay,
                 missing_grad=missing_grad,
+                defaults=defaults,
                 require_original_output_arg=require_original_output_arg,
             )

@@ -1106,7 +1210,8 @@ def add_builtin(
         input_types=input_types,
         value_type=value_type,
         value_func=value_func,
-
+        export_func=export_func,
+        dispatch_func=dispatch_func,
         variadic=variadic,
         initializer_list_func=initializer_list_func,
         export=export,
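Note: the generic-expansion loop above enumerates every combination of concrete types and zips each combination back against the parameter names. The core of that iteration, reduced to a runnable sketch with made-up type lists:

```python
import itertools

input_types = {"x": "Scalar", "y": "Scalar"}  # illustrative generic signature
typelists = [[int, float], [int, float]]      # candidate concrete types

for combo in itertools.product(*typelists):
    arg_types = dict(zip(input_types.keys(), combo))
    print(arg_types)
# {'x': <class 'int'>, 'y': <class 'int'>}, {'x': <class 'int'>, 'y': <class 'float'>}, ...
```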
@@ -1250,7 +1355,7 @@ class ModuleBuilder:
         if not func.value_func:

             def wrap(adj):
-                def value_type(arg_types,
+                def value_type(arg_types, arg_values):
                     if adj.return_var is None or len(adj.return_var) == 0:
                         return None
                     if len(adj.return_var) == 1:

@@ -1453,14 +1558,6 @@ class Module:
         computed ``content_hash`` will be used.
         """

-        def get_annotations(obj: Any) -> Mapping[str, Any]:
-            """Alternative to `inspect.get_annotations()` for Python 3.9 and older."""
-            # See https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
-            if isinstance(obj, type):
-                return obj.__dict__.get("__annotations__", {})
-
-            return getattr(obj, "__annotations__", {})
-
         def get_type_name(type_hint):
             if isinstance(type_hint, warp.codegen.Struct):
                 return get_type_name(type_hint.cls)

@@ -1482,7 +1579,7 @@ class Module:
         for struct in module.structs.values():
             s = ",".join(
                 "{}: {}".format(name, get_type_name(type_hint))
-                for name, type_hint in get_annotations(struct.cls).items()
+                for name, type_hint in warp.codegen.get_annotations(struct.cls).items()
             )
             ch.update(bytes(s, "utf-8"))

@@ -1495,22 +1592,18 @@ class Module:
             ch.update(bytes(sig, "utf-8"))

             # source
-
-            ch.update(bytes(s, "utf-8"))
+            ch.update(bytes(func.adj.source, "utf-8"))

             if func.custom_grad_func:
-
-                ch.update(bytes(s, "utf-8"))
+                ch.update(bytes(func.custom_grad_func.adj.source, "utf-8"))
             if func.custom_replay_func:
-
+                ch.update(bytes(func.custom_replay_func.adj.source, "utf-8"))
             if func.replay_snippet:
-
+                ch.update(bytes(func.replay_snippet, "utf-8"))
             if func.native_snippet:
-
-                ch.update(bytes(s, "utf-8"))
+                ch.update(bytes(func.native_snippet, "utf-8"))
             if func.adj_native_snippet:
-
-                ch.update(bytes(s, "utf-8"))
+                ch.update(bytes(func.adj_native_snippet, "utf-8"))

             # Populate constants referenced in this function
             if func.adj:
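Note: the hashing change above feeds each function's source (and any grad/replay/native snippets) directly into the module's running content hash. The scheme in miniature, with an illustrative source string:

```python
import hashlib

ch = hashlib.sha256()
# Illustrative input; the real code also hashes adjoint sources and snippets.
for source in ("def f(x: float):\n    return x * 2.0\n",):
    ch.update(bytes(source, "utf-8"))
print(ch.hexdigest()[:7])  # short id like the one in the module load message
```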
@@ -1621,7 +1714,7 @@ class Module:

         with ScopedTimer(
             f"Module {self.name} {module_hash.hex()[:7]} load on device '{device}'", active=not warp.config.quiet
-        ):
+        ) as module_load_timer:
             # -----------------------------------------------------------
             # determine output paths
             if device.is_cpu:

@@ -1657,7 +1750,13 @@ class Module:

             build_dir = None

-            if
+            # we always want to build if binary doesn't exist yet
+            # and we want to rebuild if we are not caching kernels or if we are tracking array access
+            if (
+                not os.path.exists(binary_path)
+                or not warp.config.cache_kernels
+                or warp.config.verify_autograd_array_access
+            ):
                 builder = ModuleBuilder(self, self.options)

                 # create a temporary (process unique) dir for build outputs before moving to the binary dir

@@ -1668,6 +1767,8 @@ class Module:
                 # dir may exist from previous attempts / runs / archs
                 Path(build_dir).mkdir(parents=True, exist_ok=True)

+                module_load_timer.extra_msg = " (compiled)"  # For wp.ScopedTimer informational purposes
+
                 # build CPU
                 if device.is_cpu:
                     # build

@@ -1694,6 +1795,7 @@ class Module:

                     except Exception as e:
                         self.cpu_build_failed = True
+                        module_load_timer.extra_msg = " (error)"
                         raise (e)

                 elif device.is_cuda:

@@ -1722,6 +1824,7 @@ class Module:

                     except Exception as e:
                         self.cuda_build_failed = True
+                        module_load_timer.extra_msg = " (error)"
                         raise (e)

             # -----------------------------------------------------------

@@ -1755,6 +1858,8 @@ class Module:
                 except Exception as e:
                     # We don't need source_code_path to be copied successfully to proceed, so warn and keep running
                     warp.utils.warn(f"Exception when renaming {source_code_path}: {e}")
+            else:
+                module_load_timer.extra_msg = " (cached)"  # For wp.ScopedTimer informational purposes

             # -----------------------------------------------------------
             # Load CPU or CUDA binary

@@ -1767,6 +1872,7 @@ class Module:
                 if cuda_module is not None:
                     self.cuda_modules[device.context] = cuda_module
                 else:
+                    module_load_timer.extra_msg = " (error)"
                     raise Exception(f"Failed to load CUDA module '{self.name}'")

         if build_dir:
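The `module_load_timer.extra_msg` assignments above let a single timer report whether the module was compiled, served from cache, or failed. A minimal sketch of that labelled-timer pattern (not Warp's actual ScopedTimer implementation):

```python
import time

class TimerSketch:
    def __init__(self, name):
        self.name = name
        self.extra_msg = ""  # filled in by the body before exit

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, *exc):
        elapsed_ms = (time.perf_counter() - self.start) * 1000.0
        print(f"{self.name} took {elapsed_ms:.2f} ms{self.extra_msg}")

with TimerSketch("Module example load") as timer:
    timer.extra_msg = " (cached)"
```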
@@ -1937,10 +2043,13 @@ class ContextGuard:


 class Stream:
-    def
-
-
+    def __new__(cls, *args, **kwargs):
+        instance = super(Stream, cls).__new__(cls)
+        instance.cuda_stream = None
+        instance.owner = False
+        return instance

+    def __init__(self, device=None, **kwargs):
         # event used internally for synchronization (cached to avoid creating temporary events)
         self._cached_event = None

@@ -2016,9 +2125,12 @@ class Event:
     BLOCKING_SYNC = 0x1
     DISABLE_TIMING = 0x2

-    def
-
+    def __new__(cls, *args, **kwargs):
+        instance = super(Event, cls).__new__(cls)
+        instance.owner = False
+        return instance

+    def __init__(self, device=None, cuda_event=None, enable_timing=False):
         device = get_device(device)
         if not device.is_cuda:
             raise RuntimeError(f"Device {device} is not a CUDA device")

@@ -2320,6 +2432,11 @@ Devicelike = Union[Device, str, None]


 class Graph:
+    def __new__(cls, *args, **kwargs):
+        instance = super(Graph, cls).__new__(cls)
+        instance.exec = None
+        return instance
+
     def __init__(self, device: Device, exec: ctypes.c_void_p):
         self.device = device
         self.exec = exec
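Note: the new `__new__` methods pre-seed the attributes that cleanup code touches, so a destructor stays safe even when `__init__` raises partway through. The pattern in isolation, with a hypothetical `Handle` class:

```python
class Handle:
    def __new__(cls, *args, **kwargs):
        instance = super(Handle, cls).__new__(cls)
        instance.resource = None  # always defined, even if __init__ fails
        return instance

    def __init__(self, fail=False):
        if fail:
            raise RuntimeError("acquisition failed")
        self.resource = object()

    def __del__(self):
        if self.resource is not None:
            pass  # release the native resource here

try:
    Handle(fail=True)
except RuntimeError:
    pass  # __del__ can still run safely: `resource` exists and is None
```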
@@ -2682,48 +2799,38 @@ class Runtime:
             ctypes.c_void_p,
             ctypes.c_void_p,
             ctypes.c_int,
-            ctypes.c_float,
-            ctypes.c_float,
-            ctypes.c_float,
-            ctypes.c_float,
-            ctypes.c_float,
+            ctypes.c_float * 9,
+            ctypes.c_float * 3,
             ctypes.c_bool,
+            ctypes.c_float,
         ]
         self.core.volume_f_from_tiles_device.restype = ctypes.c_uint64
         self.core.volume_v_from_tiles_device.argtypes = [
             ctypes.c_void_p,
             ctypes.c_void_p,
             ctypes.c_int,
-            ctypes.c_float,
-            ctypes.c_float,
-            ctypes.c_float,
-            ctypes.c_float,
-            ctypes.c_float,
-            ctypes.c_float,
-            ctypes.c_float,
+            ctypes.c_float * 9,
+            ctypes.c_float * 3,
             ctypes.c_bool,
+            ctypes.c_float * 3,
         ]
         self.core.volume_v_from_tiles_device.restype = ctypes.c_uint64
         self.core.volume_i_from_tiles_device.argtypes = [
             ctypes.c_void_p,
             ctypes.c_void_p,
             ctypes.c_int,
-            ctypes.c_float,
-            ctypes.
-            ctypes.c_float,
-            ctypes.c_float,
-            ctypes.c_float,
+            ctypes.c_float * 9,
+            ctypes.c_float * 3,
             ctypes.c_bool,
+            ctypes.c_int,
         ]
         self.core.volume_i_from_tiles_device.restype = ctypes.c_uint64
         self.core.volume_index_from_tiles_device.argtypes = [
             ctypes.c_void_p,
             ctypes.c_void_p,
             ctypes.c_int,
-            ctypes.c_float,
-            ctypes.c_float,
-            ctypes.c_float,
-            ctypes.c_float,
+            ctypes.c_float * 9,
+            ctypes.c_float * 3,
             ctypes.c_bool,
         ]
         self.core.volume_index_from_tiles_device.restype = ctypes.c_uint64

@@ -2731,10 +2838,8 @@ class Runtime:
             ctypes.c_void_p,
             ctypes.c_void_p,
             ctypes.c_int,
-            ctypes.c_float,
-            ctypes.c_float,
-            ctypes.c_float,
-            ctypes.c_float,
+            ctypes.c_float * 9,
+            ctypes.c_float * 3,
             ctypes.c_bool,
         ]
         self.core.volume_from_active_voxels_device.restype = ctypes.c_uint64

@@ -2780,39 +2885,38 @@ class Runtime:
         self.core.volume_get_blind_data_info.restype = ctypes.c_char_p

         bsr_matrix_from_triplets_argtypes = [
-            ctypes.c_int,
-            ctypes.c_int,
-            ctypes.c_int,
-            ctypes.c_int,
-            ctypes.
-            ctypes.
-            ctypes.
-            ctypes.
-            ctypes.
-            ctypes.
+            ctypes.c_int,  # rows_per_block
+            ctypes.c_int,  # cols_per_block
+            ctypes.c_int,  # row_count
+            ctypes.c_int,  # tpl_nnz
+            ctypes.POINTER(ctypes.c_int),  # tpl_rows
+            ctypes.POINTER(ctypes.c_int),  # tpl_cols
+            ctypes.c_void_p,  # tpl_values
+            ctypes.c_bool,  # prune_numerical_zeros
+            ctypes.POINTER(ctypes.c_int),  # bsr_offsets
+            ctypes.POINTER(ctypes.c_int),  # bsr_columns
+            ctypes.c_void_p,  # bsr_values
+            ctypes.POINTER(ctypes.c_int),  # bsr_nnz
+            ctypes.c_void_p,  # bsr_nnz_event
         ]
+
         self.core.bsr_matrix_from_triplets_float_host.argtypes = bsr_matrix_from_triplets_argtypes
         self.core.bsr_matrix_from_triplets_double_host.argtypes = bsr_matrix_from_triplets_argtypes
         self.core.bsr_matrix_from_triplets_float_device.argtypes = bsr_matrix_from_triplets_argtypes
         self.core.bsr_matrix_from_triplets_double_device.argtypes = bsr_matrix_from_triplets_argtypes

-        self.core.bsr_matrix_from_triplets_float_host.restype = ctypes.c_int
-        self.core.bsr_matrix_from_triplets_double_host.restype = ctypes.c_int
-        self.core.bsr_matrix_from_triplets_float_device.restype = ctypes.c_int
-        self.core.bsr_matrix_from_triplets_double_device.restype = ctypes.c_int
-
         bsr_transpose_argtypes = [
-            ctypes.c_int,
-            ctypes.c_int,
-            ctypes.c_int,
-            ctypes.c_int,
-            ctypes.c_int,
-            ctypes.
-            ctypes.
-            ctypes.
-            ctypes.
-            ctypes.
-            ctypes.
+            ctypes.c_int,  # rows_per_block
+            ctypes.c_int,  # cols_per_block
+            ctypes.c_int,  # row_count
+            ctypes.c_int,  # col_count
+            ctypes.c_int,  # nnz
+            ctypes.POINTER(ctypes.c_int),  # bsr_offsets
+            ctypes.POINTER(ctypes.c_int),  # bsr_columns
+            ctypes.c_void_p,  # bsr_values
+            ctypes.POINTER(ctypes.c_int),  # transposed_bsr_offsets
+            ctypes.POINTER(ctypes.c_int),  # transposed_bsr_columns
+            ctypes.c_void_p,  # transposed_bsr_values
         ]
         self.core.bsr_transpose_float_host.argtypes = bsr_transpose_argtypes
         self.core.bsr_transpose_double_host.argtypes = bsr_transpose_argtypes
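Note: the volume and BSR bindings now use `ctypes` array and pointer types instead of flat runs of scalars. A short refresher on those constructs, with illustrative values:

```python
import ctypes

# `ctypes.c_float * 9` is an array *type*; an instance marshals as one
# contiguous block, matching the new 3x3 transform / translation arguments.
transform = (ctypes.c_float * 9)(1, 0, 0, 0, 1, 0, 0, 0, 1)
translation = (ctypes.c_float * 3)(0.0, 0.0, 0.0)

# POINTER types declare out-parameters such as bsr_offsets and bsr_nnz.
nnz = ctypes.c_int(0)
nnz_ptr = ctypes.pointer(nnz)
print(list(transform), nnz_ptr.contents.value)
```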
@@ -3019,35 +3123,63 @@ class Runtime:
         self.device_map["cpu"] = self.cpu_device
         self.context_map[None] = self.cpu_device

-
+        self.is_cuda_enabled = bool(self.core.is_cuda_enabled())
+        self.is_cuda_compatibility_enabled = bool(self.core.is_cuda_compatibility_enabled())

-
+        self.toolkit_version = None  # CTK version used to build the core lib
+        self.driver_version = None  # installed driver version
+        self.min_driver_version = None  # minimum required driver version
+
+        self.cuda_devices = []
+        self.cuda_primary_devices = []
+
+        cuda_device_count = 0
+
+        if self.is_cuda_enabled:
             # get CUDA Toolkit and driver versions
-
-
-
-            #
-
-
-
-
+            toolkit_version = self.core.cuda_toolkit_version()
+            driver_version = self.core.cuda_driver_version()
+
+            # save versions as tuples, e.g., (12, 4)
+            self.toolkit_version = (toolkit_version // 1000, (toolkit_version % 1000) // 10)
+            self.driver_version = (driver_version // 1000, (driver_version % 1000) // 10)
+
+            # determine minimum required driver version
+            if self.is_cuda_compatibility_enabled:
+                # we can rely on minor version compatibility, but 11.4 is the absolute minimum required from the driver
+                if self.toolkit_version[0] > 11:
+                    self.min_driver_version = (self.toolkit_version[0], 0)
+                else:
+                    self.min_driver_version = (11, 4)
             else:
-
+                # we can't rely on minor version compatibility, so the driver can't be older than the toolkit
+                self.min_driver_version = self.toolkit_version
+
+            # determine if the installed driver is sufficient
+            if self.driver_version >= self.min_driver_version:
+                # get all architectures supported by NVRTC
+                num_archs = self.core.nvrtc_supported_arch_count()
+                if num_archs > 0:
+                    archs = (ctypes.c_int * num_archs)()
+                    self.core.nvrtc_supported_archs(archs)
+                    self.nvrtc_supported_archs = set(archs)
+                else:
+                    self.nvrtc_supported_archs = set()

-
-
-        self.cuda_custom_context_count = [0] * cuda_device_count
+                # get CUDA device count
+                cuda_device_count = self.core.cuda_device_get_count()

-
-
-
-
-
-
-
-
+                # register primary CUDA devices
+                for i in range(cuda_device_count):
+                    alias = f"cuda:{i}"
+                    device = Device(self, alias, ordinal=i, is_primary=True)
+                    self.cuda_devices.append(device)
+                    self.cuda_primary_devices.append(device)
+                    self.device_map[alias] = device
+
+                # count known non-primary contexts on each physical device so we can
+                # give them reasonable aliases (e.g., "cuda:0.0", "cuda:0.1")
+                self.cuda_custom_context_count = [0] * cuda_device_count

         # set default device
         if cuda_device_count > 0:
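Note: the raw CUDA version integers encode `major * 1000 + minor * 10`, so the tuple form above both reads naturally and compares correctly:

```python
driver_version = 12040  # e.g., the integer form of CUDA 12.4
as_tuple = (driver_version // 1000, (driver_version % 1000) // 10)
assert as_tuple == (12, 4)
assert (11, 8) < (12, 0) < (12, 4)  # tuple ordering matches version ordering
```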
@@ -3066,14 +3198,8 @@ class Runtime:
         # initialize kernel cache
         warp.build.init_kernel_cache(warp.config.kernel_cache_dir)

-
-
-        for cuda_device in self.cuda_devices:
-            if cuda_device.is_primary:
-                if not cuda_device.is_uva:
-                    devices_without_uva.append(cuda_device)
-                if not cuda_device.is_mempool_supported:
-                    devices_without_mempool.append(cuda_device)
+        # global tape
+        self.tape = None

         # print device and version information
         if not warp.config.quiet:

@@ -3081,18 +3207,24 @@ class Runtime:

             greeting.append(f"Warp {warp.config.version} initialized:")
             if cuda_device_count > 0:
-
-                driver_version = (self.driver_version // 1000, (self.driver_version % 1000) // 10)
+                # print CUDA version info
                 greeting.append(
-                    f"   CUDA Toolkit {toolkit_version[0]}.{toolkit_version[1]}, Driver {driver_version[0]}.{driver_version[1]}"
+                    f"   CUDA Toolkit {self.toolkit_version[0]}.{self.toolkit_version[1]}, Driver {self.driver_version[0]}.{self.driver_version[1]}"
                 )
             else:
-
-
-                greeting.append("   CUDA devices not available")
-            else:
+                # briefly explain lack of CUDA devices
+                if not self.is_cuda_enabled:
                     # Warp was compiled without CUDA support
-                greeting.append("   CUDA
+                    greeting.append("   CUDA not enabled in this build")
+                elif self.driver_version < self.min_driver_version:
+                    # insufficient CUDA driver version
+                    greeting.append(
+                        f"   CUDA Toolkit {self.toolkit_version[0]}.{self.toolkit_version[1]}, Driver {self.driver_version[0]}.{self.driver_version[1]}"
+                        " (insufficient CUDA driver version!)"
+                    )
+                else:
+                    # CUDA is supported, but no devices are available
+                    greeting.append("   CUDA devices not available")
             greeting.append("   Devices:")
             alias_str = f'"{self.cpu_device.alias}"'
             name_str = f'"{self.cpu_device.name}"'

@@ -3151,41 +3283,44 @@ class Runtime:
             print("\n".join(greeting))

         if cuda_device_count > 0:
-            #
+            # ensure initialization did not change the initial context (e.g. querying available memory)
+            self.core.cuda_context_set_current(initial_context)
+
+            # detect possible misconfiguration of the system
+            devices_without_uva = []
+            devices_without_mempool = []
+            for cuda_device in self.cuda_primary_devices:
+                if not cuda_device.is_uva:
+                    devices_without_uva.append(cuda_device)
+                if not cuda_device.is_mempool_supported:
+                    devices_without_mempool.append(cuda_device)
+
             if devices_without_uva:
                 # This should not happen on any system officially supported by Warp. UVA is not available
                 # on 32-bit Windows, which we don't support. Nonetheless, we should check and report a
                 # warning out of abundance of caution. It may help with debugging a broken VM setup etc.
                 warp.utils.warn(
-                    f"Support for Unified Virtual Addressing (UVA) was not detected on devices {devices_without_uva}."
+                    f"\n   Support for Unified Virtual Addressing (UVA) was not detected on devices {devices_without_uva}."
                 )
             if devices_without_mempool:
                 warp.utils.warn(
-                    f"Support for CUDA memory pools was not detected on devices {devices_without_mempool}.
-                    "This prevents memory allocations in CUDA graphs and may result in poor performance.
-                    "Is the UVM driver enabled?"
+                    f"\n   Support for CUDA memory pools was not detected on devices {devices_without_mempool}."
+                    "\n   This prevents memory allocations in CUDA graphs and may result in poor performance."
+                    "\n   Is the UVM driver enabled?"
                 )

-
-            #
-            #
-            if self.driver_version < self.
-
-
-
-                "
-                "
-                "* Warp is not fully supported by the current driver.          *\n"
-                "* Some CUDA functionality may not work correctly!             *\n"
-                "* Update the driver or rebuild Warp without the --quick flag. *\n"
-                "******************************************************************\n"
+        elif self.is_cuda_enabled:
+            # Report a warning about insufficient driver version. The warning should appear even in quiet mode
+            # when the greeting message is suppressed. Also try to provide guidance for resolving the situation.
+            if self.driver_version < self.min_driver_version:
+                msg = []
+                msg.append("\n   Insufficient CUDA driver version.")
+                msg.append(
+                    f"The minimum required CUDA driver version is {self.min_driver_version[0]}.{self.min_driver_version[1]}, "
+                    f"but the installed CUDA driver version is {self.driver_version[0]}.{self.driver_version[1]}."
                 )
-
-
-            self.core.cuda_context_set_current(initial_context)
-
-            # global tape
-            self.tape = None
+                msg.append("Visit https://github.com/NVIDIA/warp/blob/main/README.md#installing for guidance.")
+                warp.utils.warn("\n   ".join(msg))

     def get_error_string(self):
         return self.core.get_error_string().decode("utf-8")
@@ -3208,17 +3343,20 @@ class Runtime:
         return dll

     def get_device(self, ident: Devicelike = None) -> Device:
-
+        # special cases
+        if type(ident) is Device:
             return ident
         elif ident is None:
             return self.default_device
-
-
-
-
-
-
-
+
+        # string lookup
+        device = self.device_map.get(ident)
+        if device is not None:
+            return device
+        elif ident == "cuda":
+            return self.get_current_cuda_device()
+
+        raise ValueError(f"Invalid device identifier: {ident}")

     def set_default_device(self, ident: Devicelike):
         self.default_device = self.get_device(ident)

@@ -3248,7 +3386,7 @@ class Runtime:
             return self.cuda_devices[0]
         else:
             # CUDA is not available
-            if not self.
+            if not self.is_cuda_enabled:
                 raise RuntimeError('"cuda" device requested but this build of Warp does not support CUDA')
             else:
                 raise RuntimeError('"cuda" device requested but CUDA is not supported by the hardware or driver')

@@ -3821,6 +3959,11 @@ class RegisteredGLBuffer:

     __fallback_warning_shown = False

+    def __new__(cls, *args, **kwargs):
+        instance = super(RegisteredGLBuffer, cls).__new__(cls)
+        instance.resource = None
+        return instance
+
     def __init__(self, gl_buffer_id: int, device: Devicelike = None, flags: int = NONE, fallback_to_copy: bool = True):
         """
         Args:
@@ -4230,6 +4373,10 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
             # allow for NULL arrays
             return arg_type.__ctype__()

+        elif isinstance(value, warp.types.array_t):
+            # accept array descriptors verbatim
+            return value
+
         else:
             # check for array type
             # - in forward passes, array types have to match

@@ -4240,6 +4387,32 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
             array_matches = type(value) is type(arg_type)

             if not array_matches:
+                # if a regular Warp array is required, try converting from __cuda_array_interface__ or __array_interface__
+                if isinstance(arg_type, warp.array):
+                    if device.is_cuda:
+                        # check for __cuda_array_interface__
+                        try:
+                            interface = value.__cuda_array_interface__
+                        except AttributeError:
+                            pass
+                        else:
+                            return warp.types.array_ctype_from_interface(interface, dtype=arg_type.dtype, owner=value)
+                    else:
+                        # check for __array_interface__
+                        try:
+                            interface = value.__array_interface__
+                        except AttributeError:
+                            pass
+                        else:
+                            return warp.types.array_ctype_from_interface(interface, dtype=arg_type.dtype, owner=value)
+                        # check for __array__() method, e.g. Torch CPU tensors
+                        try:
+                            interface = value.__array__().__array_interface__
+                        except AttributeError:
+                            pass
+                        else:
+                            return warp.types.array_ctype_from_interface(interface, dtype=arg_type.dtype, owner=value)
+
                 adj = "adjoint " if adjoint else ""
                 raise RuntimeError(
                     f"Error launching kernel '{kernel.key}', {adj}argument '{arg_name}' expects an array of type {type(arg_type)}, but passed value has type {type(value)}."
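Note: the interop path above probes each interchange protocol with try/except AttributeError before falling through to the next. The detection pattern on its own, using a NumPy array as the example producer:

```python
import numpy as np

def describe(value):
    for attr in ("__cuda_array_interface__", "__array_interface__"):
        try:
            interface = getattr(value, attr)
        except AttributeError:
            continue
        return attr, interface["shape"], interface["typestr"]
    return None

print(describe(np.zeros(4, dtype=np.float32)))  # ('__array_interface__', (4,), '<f4')
```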
@@ -4603,6 +4776,10 @@ def launch(
         caller = {"file": frame.f_code.co_filename, "lineno": frame.f_lineno, "func": frame.f_code.co_name}
         runtime.tape.record_launch(kernel, dim, max_blocks, inputs, outputs, device, metadata={"caller": caller})

+        # detect illegal inter-kernel read/write access patterns if verification flag is set
+        if warp.config.verify_autograd_array_access:
+            runtime.tape._check_kernel_array_access(kernel, fwd_args)
+

 def synchronize():
     """Manually synchronize the calling CPU thread with any outstanding CUDA work on all devices

@@ -4808,7 +4985,7 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=None
     """

     if force_module_load is None:
-        if runtime.driver_version >=
+        if runtime.driver_version >= (12, 3):
            # Driver versions 12.3 and later can compile modules during graph capture
            force_module_load = False
        else:

@@ -5084,6 +5261,9 @@ def copy(
             ),
             arrays=[dest, src],
         )
+        if warp.config.verify_autograd_array_access:
+            dest.mark_write()
+            src.mark_read()


 def adj_copy(
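A hedged usage sketch for the verification hooks above: since the flag is read at launch and copy time (and influences whether modules are rebuilt), it should be set before any kernels run, assuming a standard Warp setup:

```python
import warp as wp

wp.config.verify_autograd_array_access = True  # enable access tracking
wp.init()
# subsequent wp.copy()/wp.launch() calls now record reads and writes
```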
@@ -5106,8 +5286,16 @@ def type_str(t):
         return "Any"
     elif t == Callable:
         return "Callable"
+    elif t == Tuple[int]:
+        return "Tuple[int]"
     elif t == Tuple[int, int]:
         return "Tuple[int, int]"
+    elif t == Tuple[int, int, int]:
+        return "Tuple[int, int, int]"
+    elif t == Tuple[int, int, int, int]:
+        return "Tuple[int, int, int, int]"
+    elif t == Tuple[int, ...]:
+        return "Tuple[int, ...]"
     elif isinstance(t, int):
         return str(t)
     elif isinstance(t, List):

@@ -5142,6 +5330,9 @@ def type_str(t):
             return f"Transformation[{type_str(t._wp_scalar_type_)}]"

         raise TypeError("Invalid vector or matrix dimensions")
+    elif typing.get_origin(t) in (List, Mapping, Sequence, Union, Tuple):
+        args_repr = ", ".join(type_str(x) for x in typing.get_args(t))
+        return f"{t.__name__}[{args_repr}]"

     return t.__name__
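Note: the new generic fallback leans on `typing` introspection; its two helpers behave like this:

```python
import typing
from typing import Tuple, Union

print(typing.get_origin(Tuple[int, ...]))    # tuple
print(typing.get_args(Tuple[int, str]))      # (<class 'int'>, <class 'str'>)
print(typing.get_origin(Union[int, float]))  # typing.Union
```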
@@ -5169,7 +5360,7 @@ def print_function(f, file, noentry=False):  # pragma: no cover
     try:
         # todo: construct a default value for each of the functions args
         # so we can generate the return type for overloaded functions
-        return_type = " -> " + type_str(f.value_func(None, None
+        return_type = " -> " + type_str(f.value_func(None, None))
     except Exception:
         pass

@@ -5232,14 +5423,6 @@ def export_functions_rst(file):  # pragma: no cover
     print(".. class:: Transformation", file=file)
     print(".. class:: Array", file=file)

-    print("\nQuery Types", file=file)
-    print("-------------", file=file)
-    print(".. autoclass:: bvh_query_t", file=file)
-    print(".. autoclass:: hash_grid_query_t", file=file)
-    print(".. autoclass:: mesh_query_aabb_t", file=file)
-    print(".. autoclass:: mesh_query_point_t", file=file)
-    print(".. autoclass:: mesh_query_ray_t", file=file)
-
     # build dictionary of all functions by group
     groups = {}

@@ -5252,8 +5435,17 @@ def export_functions_rst(file):  # pragma: no cover
         for o in f.overloads:
             groups[f.group].append(o)

-    # Keep track of what function
-    written_functions =
+    # Keep track of what function and query types have been written
+    written_functions = set()
+    written_query_types = set()
+
+    query_types = (
+        ("bvh_query", "BvhQuery"),
+        ("mesh_query_aabb", "MeshQueryAABB"),
+        ("mesh_query_point", "MeshQueryPoint"),
+        ("mesh_query_ray", "MeshQueryRay"),
+        ("hash_grid_query", "HashGridQuery"),
+    )

     for k, g in groups.items():
         print("\n", file=file)

@@ -5261,12 +5453,18 @@ def export_functions_rst(file):  # pragma: no cover
         print("---------------", file=file)

         for f in g:
+            for f_prefix, query_type in query_types:
+                if f.key.startswith(f_prefix) and query_type not in written_query_types:
+                    print(f".. autoclass:: {query_type}", file=file)
+                    written_query_types.add(query_type)
+                    break
+
             if f.key in written_functions:
                 # Add :noindex: + :nocontentsentry: since Sphinx gets confused
                 print_function(f, file=file, noentry=True)
             else:
                 if print_function(f, file=file):
-                    written_functions
+                    written_functions.add(f.key)

     # footnotes
     print(".. rubric:: Footnotes", file=file)
@@ -5327,7 +5525,7 @@ def export_stubs(file):  # pragma: no cover
     try:
         # todo: construct a default value for each of the functions args
         # so we can generate the return type for overloaded functions
-        return_type = f.value_func(None, None
+        return_type = f.value_func(None, None)
         if return_type:
             return_str = " -> " + type_str(return_type)

@@ -5373,21 +5571,25 @@ def export_builtins(file: io.TextIOBase):  # pragma: no cover
         if not f.is_simple():
             continue

-        args = ", ".join(f"{ctype_arg_str(v)} {k}" for k, v in f.input_types.items())
-        params = ", ".join(f.input_types.keys())
-
-        return_type = ""
-
         try:
             # todo: construct a default value for each of the functions args
             # so we can generate the return type for overloaded functions
-            return_type = ctype_ret_str(f.value_func(None, None
+            return_type = ctype_ret_str(f.value_func(None, None))
         except Exception:
             continue

         if return_type.startswith("Tuple"):
             continue

+        # Runtime arguments that are to be passed to the function, not its template signature.
+        if f.export_func is not None:
+            func_args = f.export_func(f.input_types)
+        else:
+            func_args = f.input_types
+
+        args = ", ".join(f"{ctype_arg_str(v)} {k}" for k, v in func_args.items())
+        params = ", ".join(func_args.keys())
+
         if args == "":
             file.write(f"WP_API void {f.mangled_name}({return_type}* ret) {{ *ret = wp::{f.key}({params}); }}\n")
         elif return_type == "None":