warp-lang 1.5.1-py3-none-manylinux2014_x86_64.whl → 1.6.0-py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +5 -0
- warp/autograd.py +414 -191
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +40 -12
- warp/build_dll.py +13 -6
- warp/builtins.py +1076 -480
- warp/codegen.py +240 -119
- warp/config.py +1 -1
- warp/context.py +298 -84
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_gemm.py +27 -18
- warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
- warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
- warp/examples/core/example_torch.py +18 -34
- warp/examples/fem/example_apic_fluid.py +1 -0
- warp/examples/fem/example_mixed_elasticity.py +1 -1
- warp/examples/optim/example_bounce.py +1 -1
- warp/examples/optim/example_cloth_throw.py +1 -1
- warp/examples/optim/example_diffray.py +4 -15
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/optim/example_softbody_properties.py +392 -0
- warp/examples/optim/example_trajectory.py +1 -3
- warp/examples/optim/example_walker.py +5 -0
- warp/examples/sim/example_cartpole.py +0 -2
- warp/examples/sim/example_cloth_self_contact.py +260 -0
- warp/examples/sim/example_granular_collision_sdf.py +4 -5
- warp/examples/sim/example_jacobian_ik.py +0 -2
- warp/examples/sim/example_quadruped.py +5 -2
- warp/examples/tile/example_tile_cholesky.py +79 -0
- warp/examples/tile/example_tile_convolution.py +2 -2
- warp/examples/tile/example_tile_fft.py +2 -2
- warp/examples/tile/example_tile_filtering.py +3 -3
- warp/examples/tile/example_tile_matmul.py +4 -4
- warp/examples/tile/example_tile_mlp.py +12 -12
- warp/examples/tile/example_tile_nbody.py +180 -0
- warp/examples/tile/example_tile_walker.py +319 -0
- warp/math.py +147 -0
- warp/native/array.h +12 -0
- warp/native/builtin.h +0 -1
- warp/native/bvh.cpp +149 -70
- warp/native/bvh.cu +287 -68
- warp/native/bvh.h +195 -85
- warp/native/clang/clang.cpp +5 -1
- warp/native/cuda_util.cpp +35 -0
- warp/native/cuda_util.h +5 -0
- warp/native/exports.h +40 -40
- warp/native/intersect.h +17 -0
- warp/native/mat.h +41 -0
- warp/native/mathdx.cpp +19 -0
- warp/native/mesh.cpp +25 -8
- warp/native/mesh.cu +153 -101
- warp/native/mesh.h +482 -403
- warp/native/quat.h +40 -0
- warp/native/solid_angle.h +7 -0
- warp/native/sort.cpp +85 -0
- warp/native/sort.cu +34 -0
- warp/native/sort.h +3 -1
- warp/native/spatial.h +11 -0
- warp/native/tile.h +1185 -664
- warp/native/tile_reduce.h +8 -6
- warp/native/vec.h +41 -0
- warp/native/warp.cpp +8 -1
- warp/native/warp.cu +263 -40
- warp/native/warp.h +19 -5
- warp/optim/linear.py +22 -4
- warp/render/render_opengl.py +124 -59
- warp/sim/__init__.py +6 -1
- warp/sim/collide.py +270 -26
- warp/sim/integrator_euler.py +25 -7
- warp/sim/integrator_featherstone.py +154 -35
- warp/sim/integrator_vbd.py +842 -40
- warp/sim/model.py +111 -53
- warp/stubs.py +248 -115
- warp/tape.py +28 -30
- warp/tests/aux_test_module_unload.py +15 -0
- warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
- warp/tests/test_array.py +74 -0
- warp/tests/test_assert.py +242 -0
- warp/tests/test_codegen.py +14 -61
- warp/tests/test_collision.py +2 -2
- warp/tests/test_examples.py +9 -0
- warp/tests/test_grad_debug.py +87 -2
- warp/tests/test_hash_grid.py +1 -1
- warp/tests/test_ipc.py +116 -0
- warp/tests/test_mat.py +138 -167
- warp/tests/test_math.py +47 -1
- warp/tests/test_matmul.py +11 -7
- warp/tests/test_matmul_lite.py +4 -4
- warp/tests/test_mesh.py +84 -60
- warp/tests/test_mesh_query_aabb.py +165 -0
- warp/tests/test_mesh_query_point.py +328 -286
- warp/tests/test_mesh_query_ray.py +134 -121
- warp/tests/test_mlp.py +2 -2
- warp/tests/test_operators.py +43 -0
- warp/tests/test_overwrite.py +2 -2
- warp/tests/test_quat.py +77 -0
- warp/tests/test_reload.py +29 -0
- warp/tests/test_sim_grad_bounce_linear.py +204 -0
- warp/tests/test_static.py +16 -0
- warp/tests/test_tape.py +25 -0
- warp/tests/test_tile.py +134 -191
- warp/tests/test_tile_load.py +356 -0
- warp/tests/test_tile_mathdx.py +61 -8
- warp/tests/test_tile_mlp.py +17 -17
- warp/tests/test_tile_reduce.py +24 -18
- warp/tests/test_tile_shared_memory.py +66 -17
- warp/tests/test_tile_view.py +165 -0
- warp/tests/test_torch.py +35 -0
- warp/tests/test_utils.py +36 -24
- warp/tests/test_vec.py +110 -0
- warp/tests/unittest_suites.py +29 -4
- warp/tests/unittest_utils.py +30 -11
- warp/thirdparty/unittest_parallel.py +2 -2
- warp/types.py +409 -99
- warp/utils.py +9 -5
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/METADATA +68 -44
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/RECORD +121 -110
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/WHEEL +1 -1
- warp/examples/benchmarks/benchmark_tile.py +0 -179
- warp/native/tile_gemm.h +0 -341
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/top_level.txt +0 -0
warp/context.py
CHANGED
```diff
@@ -5,6 +5,8 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
+from __future__ import annotations
+
 import ast
 import ctypes
 import errno
@@ -393,7 +395,8 @@ class Function:
             if not warp.codegen.func_match_args(f, arg_types, kwarg_types):
                 continue
 
-
+            acceptable_arg_num = len(f.input_types) - len(f.defaults) <= len(arg_types) <= len(f.input_types)
+            if not acceptable_arg_num:
                 continue
 
             # try to match the given types to the function template types
@@ -410,6 +413,10 @@ class Function:
 
             arg_names = f.input_types.keys()
             overload_annotations = dict(zip(arg_names, arg_types))
+            # add defaults
+            for k, d in f.defaults.items():
+                if k not in overload_annotations:
+                    overload_annotations[k] = warp.codegen.strip_reference(warp.codegen.get_arg_type(d))
 
             ovl = shallowcopy(f)
             ovl.adj = warp.codegen.Adjoint(f.func, overload_annotations)
```
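These two hunks make overload resolution aware of `Function.defaults`: a call may omit trailing parameters, and the omitted ones get their annotations inferred from their default values. A minimal sketch of what this enables, assuming 1.6.0 accepts default parameter values on `@wp.func` (the kernel and names below are illustrative):

```python
import warp as wp

@wp.func
def scale(x: float, factor: float = 2.0) -> float:
    # 'factor' may be omitted at call sites; the new defaults handling
    # annotates it from the default value during overload resolution
    return x * factor

@wp.kernel
def apply_scale(values: wp.array(dtype=float)):
    tid = wp.tid()
    values[tid] = scale(values[tid])  # len(args) < len(input_types) now matches

wp.init()
values = wp.array([1.0, 2.0, 3.0], dtype=float)
wp.launch(apply_scale, dim=values.shape[0], inputs=[values])
```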
```diff
@@ -753,8 +760,15 @@ def func(f):
     scope_locals = inspect.currentframe().f_back.f_locals
 
     m = get_module(f.__module__)
+    doc = getattr(f, "__doc__", "") or ""
     Function(
-        func=f,
+        func=f,
+        key=name,
+        namespace="",
+        module=m,
+        value_func=None,
+        scope_locals=scope_locals,
+        doc=doc.strip(),
     )  # value_type not known yet, will be inferred during Adjoint.build()
 
     # use the top of the list of overloads for this key
@@ -1059,7 +1073,8 @@ def overload(kernel, arg_types=Union[None, Dict[str, Any], List[Any]]):
         raise RuntimeError("wp.overload() called with invalid argument!")
 
 
-builtin_functions = {}
+# native functions that are part of the Warp API
+builtin_functions: Dict[str, Function] = {}
 
 
 def get_generic_vtypes():
```
```diff
@@ -1328,6 +1343,28 @@ def add_builtin(
         setattr(warp, key, func)
 
 
+def register_api_function(
+    function: Function,
+    group: str = "Other",
+    hidden=False,
+):
+    """Main entry point to register a Warp Python function to be part of the Warp API and appear in the documentation.
+
+    Args:
+        function (Function): Warp function to be registered.
+        group (str): Classification used for the documentation.
+        input_types (Mapping[str, Any]): Signature of the user-facing function.
+            Variadic arguments are supported by prefixing the parameter names
+            with asterisks as in `*args` and `**kwargs`. Generic arguments are
+            supported with types such as `Any`, `Float`, `Scalar`, etc.
+        value_type (Any): Type returned by the function.
+        hidden (bool): Whether to add that function into the documentation.
+    """
+    function.group = group
+    function.hidden = hidden
+    builtin_functions[function.key] = function
+
+
 # global dictionary of modules
 user_modules = {}
 
```
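`register_api_function()` is the new hook that routes pure-Python functions (for example the helpers added in `warp/math.py` in this release) through the same `builtin_functions` registry used for documentation and stub generation. A sketch of the registration pattern with a hypothetical helper; note how the `doc=doc.strip()` change above supplies the docstring:

```python
import warp as wp
import warp.context

@wp.func
def clamp01(x: float) -> float:
    """Clamp a value to the unit interval."""
    return wp.clamp(x, 0.0, 1.0)

# `@wp.func` yields a warp.context.Function; registering it makes
# export_functions_rst() emit an autofunction entry under this group
warp.context.register_api_function(clamp01, group="Utility")
```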
```diff
@@ -1561,6 +1598,7 @@ class ModuleBuilder:
         self.options = options
         self.module = module
         self.deferred_functions = []
+        self.fatbins = {}  # map from <some identifier> to fatbins, to add at link time
         self.ltoirs = {}  # map from lto symbol to lto binary
         self.ltoirs_decl = {}  # map from lto symbol to lto forward declaration
 
@@ -1675,7 +1713,7 @@ class ModuleBuilder:
 
         for kernel in self.kernels:
             source += warp.codegen.codegen_kernel(kernel, device=device, options=self.options)
-            source += warp.codegen.codegen_module(kernel, device=device)
+            source += warp.codegen.codegen_module(kernel, device=device, options=self.options)
 
         # add headers
         if device == "cpu":
@@ -1728,20 +1766,26 @@ class ModuleExec:
 
         name = kernel.get_mangled_name()
 
+        options = dict(kernel.module.options)
+        options.update(kernel.options)
+
         if self.device.is_cuda:
             forward_name = name + "_cuda_kernel_forward"
             forward_kernel = runtime.core.cuda_get_kernel(
                 self.device.context, self.handle, forward_name.encode("utf-8")
             )
 
-            backward_name = name + "_cuda_kernel_backward"
-            backward_kernel = runtime.core.cuda_get_kernel(
-                self.device.context, self.handle, backward_name.encode("utf-8")
-            )
+            if options["enable_backward"]:
+                backward_name = name + "_cuda_kernel_backward"
+                backward_kernel = runtime.core.cuda_get_kernel(
+                    self.device.context, self.handle, backward_name.encode("utf-8")
+                )
+            else:
+                backward_kernel = None
 
             # look up the required shared memory size for each kernel from module metadata
             forward_smem_bytes = self.meta[forward_name + "_smem_bytes"]
-            backward_smem_bytes = self.meta[backward_name + "_smem_bytes"]
+            backward_smem_bytes = self.meta[backward_name + "_smem_bytes"] if options["enable_backward"] else 0
 
             # configure kernels maximum shared memory size
             max_smem_bytes = runtime.core.cuda_get_max_shared_memory(self.device.context)
@@ -1751,9 +1795,6 @@ class ModuleExec:
                     f"Warning: Failed to configure kernel dynamic shared memory for this device, tried to configure {forward_name} kernel for {forward_smem_bytes} bytes, but maximum available is {max_smem_bytes}"
                 )
 
-            options = dict(kernel.module.options)
-            options.update(kernel.options)
-
             if options["enable_backward"] and not runtime.core.cuda_configure_kernel_shared_memory(
                 backward_kernel, backward_smem_bytes
             ):
@@ -1768,9 +1809,14 @@ class ModuleExec:
             forward = (
                 func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_forward").encode("utf-8"))) or None
             )
-            backward = (
-                func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8"))) or None
-            )
+
+            if options["enable_backward"]:
+                backward = (
+                    func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8")))
+                    or None
+                )
+            else:
+                backward = None
 
         hooks = KernelHooks(forward, backward)
 
```
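With the hooks lookup now gated on `enable_backward`, a module built without adjoints gets `backward = None` on both the CUDA and CPU paths instead of a failing metadata lookup. Opting out is a one-liner via the stable module-options API (the kernel below is illustrative):

```python
import warp as wp

# forward-only module: no *_backward kernels are generated or resolved
wp.set_module_options({"enable_backward": False})

@wp.kernel
def saxpy(x: wp.array(dtype=float), y: wp.array(dtype=float), a: float):
    i = wp.tid()
    y[i] = a * x[i] + y[i]
```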
```diff
@@ -1803,13 +1849,13 @@ class Module:
         self._live_kernels = weakref.WeakSet()
 
         # executable modules currently loaded
-        self.execs = {}  # (device.context: ModuleExec)
+        self.execs = {}  # ((device.context, blockdim): ModuleExec)
 
         # set of device contexts where the build has failed
         self.failed_builds = set()
 
-        # hash data, including the module hash
-        self.hasher = None
+        # hash data, including the module hash. Module may store multiple hashes (one per block_dim used)
+        self.hashers = {}
 
         # LLVM executable modules are identified using strings. Since it's possible for multiple
         # executable versions to be loaded at the same time, we need a way to ensure uniqueness.
@@ -1822,6 +1868,8 @@ class Module:
             "max_unroll": warp.config.max_unroll,
             "enable_backward": warp.config.enable_backward,
             "fast_math": False,
+            "fuse_fp": True,
+            "lineinfo": False,
             "cuda_output": None,  # supported values: "ptx", "cubin", or None (automatic)
             "mode": warp.config.mode,
             "block_dim": 256,
```
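`fuse_fp` and `lineinfo` join the per-module option set with the defaults shown above and are threaded through to both the CPU and CUDA build paths further down. A sketch of flipping them, using the documented `wp.set_module_options()`:

```python
import warp as wp

wp.set_module_options(
    {
        "fuse_fp": False,  # disable FMA contraction when bitwise-reproducible math matters
        "lineinfo": True,  # keep source line info in generated CUDA for profilers
    }
)
```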
```diff
@@ -1965,28 +2013,27 @@ class Module:
 
     def hash_module(self):
         # compute latest hash
-        self.hasher = ModuleHasher(self)
-        return self.hasher.get_module_hash()
+        block_dim = self.options["block_dim"]
+        self.hashers[block_dim] = ModuleHasher(self)
+        return self.hashers[block_dim].get_module_hash()
 
     def load(self, device, block_dim=None) -> ModuleExec:
         device = runtime.get_device(device)
 
-        #
-        # todo: it would be better to have a method such as `module.get_kernel(block_dim=N)`
-        # that can return a single kernel instance with a given block size
+        # update module options if launching with a new block dim
         if block_dim is not None:
-            if self.options["block_dim"] != block_dim:
-                self.unload()
             self.options["block_dim"] = block_dim
 
+        active_block_dim = self.options["block_dim"]
+
         # compute the hash if needed
-        if self.hasher is None:
-            self.hasher = ModuleHasher(self)
+        if active_block_dim not in self.hashers:
+            self.hashers[active_block_dim] = ModuleHasher(self)
 
         # check if executable module is already loaded and not stale
-        exec = self.execs.get(device.context)
+        exec = self.execs.get((device.context, active_block_dim))
         if exec is not None:
-            if exec.module_hash == self.hasher.get_module_hash():
+            if exec.module_hash == self.hashers[active_block_dim].get_module_hash():
                 return exec
 
         # quietly avoid repeated build attempts to reduce error spew
```
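Because hashers and cached executables are now keyed by block dimension, changing `block_dim` no longer unloads the previous build (note the deleted `self.unload()` above); each block size keeps its own `ModuleExec`. An illustrative sketch, assuming `wp.launch()`'s `block_dim` parameter:

```python
import warp as wp

@wp.kernel
def inc(a: wp.array(dtype=float)):
    i = wp.tid()
    a[i] = a[i] + 1.0

wp.init()
a = wp.zeros(1024, dtype=float)

wp.launch(inc, dim=a.shape[0], inputs=[a], block_dim=128)  # builds and caches a 128-wide exec
wp.launch(inc, dim=a.shape[0], inputs=[a], block_dim=256)  # builds and caches a 256-wide exec
wp.launch(inc, dim=a.shape[0], inputs=[a], block_dim=128)  # cache hit, no rebuild
```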
```diff
@@ -1994,10 +2041,11 @@ class Module:
             return None
 
         module_name = "wp_" + self.name
-        module_hash = self.hasher.get_module_hash()
+        module_hash = self.hashers[active_block_dim].get_module_hash()
 
         # use a unique module path using the module short hash
-
+        module_name_short = f"{module_name}_{module_hash.hex()[:7]}"
+        module_dir = os.path.join(warp.config.kernel_cache_dir, module_name_short)
 
         with warp.ScopedTimer(
             f"Module {self.name} {module_hash.hex()[:7]} load on device '{device}'", active=not warp.config.quiet
@@ -2005,7 +2053,7 @@ class Module:
             # -----------------------------------------------------------
             # determine output paths
             if device.is_cpu:
-                output_name = "
+                output_name = f"{module_name_short}.o"
                 output_arch = None
 
             elif device.is_cuda:
@@ -2025,10 +2073,10 @@ class Module:
 
                 if use_ptx:
                     output_arch = min(device.arch, warp.config.ptx_target_arch)
-                    output_name = f"
+                    output_name = f"{module_name_short}.sm{output_arch}.ptx"
                 else:
                     output_arch = device.arch
-                    output_name = f"
+                    output_name = f"{module_name_short}.sm{output_arch}.cubin"
 
             # final object binary path
             binary_path = os.path.join(module_dir, output_name)
@@ -2050,7 +2098,7 @@ class Module:
                 # Some of the Tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
                 "output_arch": output_arch,
             }
-            builder = ModuleBuilder(self, builder_options, hasher=self.hasher)
+            builder = ModuleBuilder(self, builder_options, hasher=self.hashers[active_block_dim])
 
             # create a temporary (process unique) dir for build outputs before moving to the binary dir
             build_dir = os.path.join(
@@ -2066,7 +2114,7 @@ class Module:
                 if device.is_cpu:
                     # build
                     try:
-                        source_code_path = os.path.join(build_dir, "
+                        source_code_path = os.path.join(build_dir, f"{module_name_short}.cpp")
 
                         # write cpp sources
                         cpp_source = builder.codegen("cpu")
@@ -2084,6 +2132,7 @@ class Module:
                             mode=self.options["mode"],
                             fast_math=self.options["fast_math"],
                             verify_fp=warp.config.verify_fp,
+                            fuse_fp=self.options["fuse_fp"],
                         )
 
                     except Exception as e:
@@ -2094,7 +2143,7 @@ class Module:
                 elif device.is_cuda:
                     # build
                     try:
-                        source_code_path = os.path.join(build_dir, "
+                        source_code_path = os.path.join(build_dir, f"{module_name_short}.cu")
 
                         # write cuda sources
                         cu_source = builder.codegen("cuda")
@@ -2111,9 +2160,12 @@ class Module:
                             output_arch,
                             output_path,
                             config=self.options["mode"],
-                            fast_math=self.options["fast_math"],
                             verify_fp=warp.config.verify_fp,
+                            fast_math=self.options["fast_math"],
+                            fuse_fp=self.options["fuse_fp"],
+                            lineinfo=self.options["lineinfo"],
                             ltoirs=builder.ltoirs.values(),
+                            fatbins=builder.fatbins.values(),
                         )
 
                     except Exception as e:
@@ -2125,7 +2177,7 @@ class Module:
                 # build meta data
 
                 meta = builder.build_meta()
-                meta_path = os.path.join(build_dir, "
+                meta_path = os.path.join(build_dir, f"{module_name_short}.meta")
 
                 with open(meta_path, "w") as meta_file:
                     json.dump(meta, meta_file)
@@ -2189,7 +2241,7 @@ class Module:
            # -----------------------------------------------------------
            # Load CPU or CUDA binary
 
-            meta_path = os.path.join(module_dir, "
+            meta_path = os.path.join(module_dir, f"{module_name_short}.meta")
             with open(meta_path, "r") as meta_file:
                 meta = json.load(meta_file)
 
@@ -2199,13 +2251,13 @@ class Module:
                 self.cpu_exec_id += 1
                 runtime.llvm.load_obj(binary_path.encode("utf-8"), module_handle.encode("utf-8"))
                 module_exec = ModuleExec(module_handle, module_hash, device, meta)
-                self.execs[None] = module_exec
+                self.execs[(None, active_block_dim)] = module_exec
 
             elif device.is_cuda:
                 cuda_module = warp.build.load_cuda(binary_path, device)
                 if cuda_module is not None:
                     module_exec = ModuleExec(cuda_module, module_hash, device, meta)
-                    self.execs[device.context] = module_exec
+                    self.execs[(device.context, active_block_dim)] = module_exec
                 else:
                     module_load_timer.extra_msg = " (error)"
                     raise Exception(f"Failed to load CUDA module '{self.name}'")
@@ -2227,14 +2279,14 @@ class Module:
 
     def mark_modified(self):
         # clear hash data
-        self.hasher = None
+        self.hashers = {}
 
         # clear build failures
         self.failed_builds = set()
 
     # lookup kernel entry points based on name, called after compilation / module load
     def get_kernel_hooks(self, kernel, device):
-        module_exec = self.execs.get(device.context)
+        module_exec = self.execs.get((device.context, self.options["block_dim"]))
         if module_exec is not None:
             return module_exec.get_kernel_hooks(kernel)
         else:
```
```diff
@@ -2353,6 +2405,7 @@ class Event:
         DEFAULT = 0x0
         BLOCKING_SYNC = 0x1
         DISABLE_TIMING = 0x2
+        INTERPROCESS = 0x4
 
     def __new__(cls, *args, **kwargs):
         """Creates a new event instance."""
@@ -2360,7 +2413,9 @@ class Event:
         instance.owner = False
         return instance
 
-    def __init__(
+    def __init__(
+        self, device: "Devicelike" = None, cuda_event=None, enable_timing: bool = False, interprocess: bool = False
+    ):
         """Initializes the event on a CUDA device.
 
         Args:
@@ -2372,6 +2427,12 @@ class Event:
             :func:`~warp.get_event_elapsed_time` can be used to measure the
             time between two events created with ``enable_timing=True`` and
             recorded onto streams.
+            interprocess: If ``True`` this event may be used as an interprocess event.
+
+        Raises:
+            RuntimeError: The event could not be created.
+            ValueError: The combination of ``enable_timing=True`` and
+                ``interprocess=True`` is not allowed.
         """
 
         device = get_device(device)
@@ -2386,11 +2447,48 @@ class Event:
         flags = Event.Flags.DEFAULT
         if not enable_timing:
             flags |= Event.Flags.DISABLE_TIMING
+        if interprocess:
+            if enable_timing:
+                raise ValueError("The combination of 'enable_timing=True' and 'interprocess=True' is not allowed.")
+            flags |= Event.Flags.INTERPROCESS
+
         self.cuda_event = runtime.core.cuda_event_create(device.context, flags)
         if not self.cuda_event:
             raise RuntimeError(f"Failed to create event on device {device}")
         self.owner = True
 
+    def ipc_handle(self) -> bytes:
+        """Return a CUDA IPC handle of the event as a 64-byte ``bytes`` object.
+
+        The event must have been created with ``interprocess=True`` in order to
+        obtain a valid interprocess handle.
+
+        IPC is currently only supported on Linux.
+
+        Example:
+            Create an event and get its IPC handle::
+
+                e1 = wp.Event(interprocess=True)
+                event_handle = e1.ipc_handle()
+
+        Raises:
+            RuntimeError: Device does not support IPC.
+        """
+
+        if self.device.is_ipc_supported is not False:
+            # Allocate a buffer for the data (64-element char array)
+            ipc_handle_buffer = (ctypes.c_char * 64)()
+
+            warp.context.runtime.core.cuda_ipc_get_event_handle(self.device.context, self.cuda_event, ipc_handle_buffer)
+
+            if ipc_handle_buffer.raw == bytes(64):
+                warp.utils.warn("IPC event handle appears to be invalid. Was interprocess=True used?")
+
+            return ipc_handle_buffer.raw
+
+        else:
+            raise RuntimeError(f"Device {self.device} does not support IPC.")
+
     def __del__(self):
         if not self.owner:
             return
```
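Together, the `INTERPROCESS` flag and `ipc_handle()` let one process hand a CUDA event to another. A producer-side sketch (Linux only; the transport and function layout are illustrative, and `wp.event_from_ipc_handle()` appears further down in this diff):

```python
import multiprocessing as mp

import warp as wp

def consumer(handle: bytes):
    wp.init()
    ev = wp.event_from_ipc_handle(handle)  # rehydrate the shared event
    wp.synchronize_event(ev)  # block until the producer's recorded work completes

if __name__ == "__main__":
    wp.init()
    ev = wp.Event(interprocess=True)  # enable_timing must remain False
    wp.record_event(ev)

    # CUDA contexts do not survive fork(); use a spawned child instead
    ctx = mp.get_context("spawn")
    p = ctx.Process(target=consumer, args=(ev.ipc_handle(),))
    p.start()
    p.join()
```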
```diff
@@ -2538,23 +2636,27 @@ class Device:
     """A device to allocate Warp arrays and to launch kernels on.
 
     Attributes:
-        ordinal: A Warp-specific
-        name: A
+        ordinal (int): A Warp-specific label for the device. ``-1`` for CPU devices.
+        name (str): A label for the device. By default, CPU devices will be named according to the processor name,
             or ``"CPU"`` if the processor name cannot be determined.
-        arch:
-        ``
-        is_uva:
+        arch (int): The compute capability version number calculated as ``10 * major + minor``.
+            ``0`` for CPU devices.
+        is_uva (bool): Indicates whether the device supports unified addressing.
             ``False`` for CPU devices.
-        is_cubin_supported:
+        is_cubin_supported (bool): Indicates whether Warp's version of NVRTC can directly
             generate CUDA binary files (cubin) for this device's architecture. ``False`` for CPU devices.
-        is_mempool_supported:
-        ``
-
-
-
-
-
-
+        is_mempool_supported (bool): Indicates whether the device supports using the ``cuMemAllocAsync`` and
+            ``cuMemPool`` family of APIs for stream-ordered memory allocations. ``False`` for CPU devices.
+        is_ipc_supported (Optional[bool]): Indicates whether the device supports IPC.
+
+            - ``True`` if supported.
+            - ``False`` if not supported.
+            - ``None`` if IPC support could not be determined (e.g. CUDA 11).
+
+        is_primary (bool): Indicates whether this device's CUDA context is also the device's primary context.
+        uuid (str): The UUID of the CUDA device. The UUID is in the same format used by ``nvidia-smi -L``.
+            ``None`` for CPU devices.
+        pci_bus_id (str): An identifier for the CUDA device in the format ``[domain]:[bus]:[device]``, in which
             ``domain``, ``bus``, and ``device`` are all hexadecimal values. ``None`` for CPU devices.
     """
 
@@ -2587,6 +2689,7 @@ class Device:
             self.is_uva = False
             self.is_mempool_supported = False
             self.is_mempool_enabled = False
+            self.is_ipc_supported = False  # TODO: Support IPC for CPU arrays
             self.is_cubin_supported = False
             self.uuid = None
             self.pci_bus_id = None
@@ -2602,8 +2705,14 @@ class Device:
             # CUDA device
             self.name = runtime.core.cuda_device_get_name(ordinal).decode()
             self.arch = runtime.core.cuda_device_get_arch(ordinal)
-            self.is_uva = runtime.core.cuda_device_is_uva(ordinal)
-            self.is_mempool_supported = runtime.core.cuda_device_is_mempool_supported(ordinal)
+            self.is_uva = runtime.core.cuda_device_is_uva(ordinal) > 0
+            self.is_mempool_supported = runtime.core.cuda_device_is_mempool_supported(ordinal) > 0
+            if platform.system() == "Linux":
+                # Use None when IPC support cannot be determined
+                ipc_support_api_query = runtime.core.cuda_device_is_ipc_supported(ordinal)
+                self.is_ipc_supported = bool(ipc_support_api_query) if ipc_support_api_query >= 0 else None
+            else:
+                self.is_ipc_supported = False
             if warp.config.enable_mempools_at_init:
                 # enable if supported
                 self.is_mempool_enabled = self.is_mempool_supported
```
```diff
@@ -3084,6 +3193,9 @@ class Runtime:
         self.core.radix_sort_pairs_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
         self.core.radix_sort_pairs_int_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
 
+        self.core.radix_sort_pairs_float_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
+        self.core.radix_sort_pairs_float_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
+
         self.core.runlength_encode_int_host.argtypes = [
             ctypes.c_uint64,
             ctypes.c_uint64,
```
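These are the new float-keyed counterparts of the integer radix sort entry points; `warp.utils.radix_sort_pairs()` dispatches on the key dtype. A sketch assuming 1.6.0's `float32` key support (the wrapper requires the arrays to hold `2 * count` elements, the upper half being scratch space):

```python
import numpy as np
import warp as wp
from warp.utils import radix_sort_pairs

wp.init()
n = 4
keys = wp.array(np.array([3.0, 1.0, 2.0, 0.5] + [0.0] * n, dtype=np.float32))
vals = wp.array(np.array([0, 1, 2, 3] + [0] * n, dtype=np.int32))

radix_sort_pairs(keys, vals, count=n)
print(keys.numpy()[:n])  # [0.5 1.  2.  3. ]
print(vals.numpy()[:n])  # [3 1 2 0]
```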
```diff
@@ -3100,10 +3212,16 @@ class Runtime:
         ]
 
         self.core.bvh_create_host.restype = ctypes.c_uint64
-        self.core.bvh_create_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
+        self.core.bvh_create_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_int]
 
         self.core.bvh_create_device.restype = ctypes.c_uint64
-        self.core.bvh_create_device.argtypes = [
+        self.core.bvh_create_device.argtypes = [
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.c_int,
+            ctypes.c_int,
+        ]
 
         self.core.bvh_destroy_host.argtypes = [ctypes.c_uint64]
         self.core.bvh_destroy_device.argtypes = [ctypes.c_uint64]
@@ -3119,6 +3237,7 @@ class Runtime:
             ctypes.c_int,
             ctypes.c_int,
             ctypes.c_int,
+            ctypes.c_int,
         ]
 
         self.core.mesh_create_device.restype = ctypes.c_uint64
@@ -3130,6 +3249,7 @@ class Runtime:
             ctypes.c_int,
             ctypes.c_int,
             ctypes.c_int,
+            ctypes.c_int,
        ]
 
         self.core.mesh_destroy_host.argtypes = [ctypes.c_uint64]
```
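`bvh_create_*` and `mesh_create_*` each gain an extra integer argument, which lines up with this release's selectable BVH construction algorithm. A heavily hedged sketch, assuming a `constructor` keyword on `wp.Bvh` (consult the 1.6.0 docs for the exact parameter name and accepted values):

```python
import numpy as np
import warp as wp

wp.init()
rng = np.random.default_rng(42)
lo = rng.random((16, 3), dtype=np.float32)

lowers = wp.array(lo, dtype=wp.vec3)
uppers = wp.array(lo + 0.1, dtype=wp.vec3)

# assumption: the extra c_int selects the build algorithm,
# e.g. "sah" or "median" in addition to the default LBVH
bvh = wp.Bvh(lowers, uppers, constructor="sah")
```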
```diff
@@ -3367,6 +3487,8 @@ class Runtime:
         self.core.cuda_device_is_uva.restype = ctypes.c_int
         self.core.cuda_device_is_mempool_supported.argtypes = [ctypes.c_int]
         self.core.cuda_device_is_mempool_supported.restype = ctypes.c_int
+        self.core.cuda_device_is_ipc_supported.argtypes = [ctypes.c_int]
+        self.core.cuda_device_is_ipc_supported.restype = ctypes.c_int
         self.core.cuda_device_set_mempool_release_threshold.argtypes = [ctypes.c_int, ctypes.c_uint64]
         self.core.cuda_device_set_mempool_release_threshold.restype = ctypes.c_int
         self.core.cuda_device_get_mempool_release_threshold.argtypes = [ctypes.c_int]
@@ -3420,6 +3542,22 @@ class Runtime:
         self.core.cuda_set_mempool_access_enabled.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
         self.core.cuda_set_mempool_access_enabled.restype = ctypes.c_int
 
+        # inter-process communication
+        self.core.cuda_ipc_get_mem_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
+        self.core.cuda_ipc_get_mem_handle.restype = None
+        self.core.cuda_ipc_open_mem_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
+        self.core.cuda_ipc_open_mem_handle.restype = ctypes.c_void_p
+        self.core.cuda_ipc_close_mem_handle.argtypes = [ctypes.c_void_p]
+        self.core.cuda_ipc_close_mem_handle.restype = None
+        self.core.cuda_ipc_get_event_handle.argtypes = [
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.POINTER(ctypes.c_char),
+        ]
+        self.core.cuda_ipc_get_event_handle.restype = None
+        self.core.cuda_ipc_open_event_handle.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_char)]
+        self.core.cuda_ipc_open_event_handle.restype = ctypes.c_void_p
+
         self.core.cuda_stream_create.argtypes = [ctypes.c_void_p, ctypes.c_int]
         self.core.cuda_stream_create.restype = ctypes.c_void_p
         self.core.cuda_stream_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
@@ -3467,6 +3605,7 @@ class Runtime:
 
         self.core.cuda_compile_program.argtypes = [
             ctypes.c_char_p,  # cuda_src
+            ctypes.c_char_p,  # program name
             ctypes.c_int,  # arch
             ctypes.c_char_p,  # include_dir
             ctypes.c_int,  # num_cuda_include_dirs
@@ -3475,10 +3614,13 @@ class Runtime:
             ctypes.c_bool,  # verbose
             ctypes.c_bool,  # verify_fp
             ctypes.c_bool,  # fast_math
+            ctypes.c_bool,  # fuse_fp
+            ctypes.c_bool,  # lineinfo
             ctypes.c_char_p,  # output_path
             ctypes.c_size_t,  # num_ltoirs
             ctypes.POINTER(ctypes.c_char_p),  # ltoirs
             ctypes.POINTER(ctypes.c_size_t),  # ltoir_sizes
+            ctypes.POINTER(ctypes.c_int),  # ltoir_input_types, each of type nvJitLinkInputType
         ]
         self.core.cuda_compile_program.restype = ctypes.c_size_t
 
@@ -3518,6 +3660,22 @@ class Runtime:
         ]
         self.core.cuda_compile_dot.restype = ctypes.c_bool
 
+        self.core.cuda_compile_solver.argtypes = [
+            ctypes.c_char_p,  # universal fatbin
+            ctypes.c_char_p,  # lto
+            ctypes.c_char_p,  # function name
+            ctypes.c_int,  # num include dirs
+            ctypes.POINTER(ctypes.c_char_p),  # include dirs
+            ctypes.c_char_p,  # mathdx include dir
+            ctypes.c_int,  # arch
+            ctypes.c_int,  # M
+            ctypes.c_int,  # N
+            ctypes.c_int,  # precision
+            ctypes.c_int,  # fill_mode
+            ctypes.c_int,  # num threads
+        ]
+        self.core.cuda_compile_fft.restype = ctypes.c_bool
+
         self.core.cuda_load_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
         self.core.cuda_load_module.restype = ctypes.c_void_p
 
```
```diff
@@ -4868,6 +5026,40 @@ def from_numpy(
     )
 
 
+def event_from_ipc_handle(handle, device: "Devicelike" = None) -> Event:
+    """Create an event from an IPC handle.
+
+    Args:
+        handle: The interprocess event handle for an existing CUDA event.
+        device (Devicelike): Device to associate with the array.
+
+    Returns:
+        An event created from the interprocess event handle ``handle``.
+
+    Raises:
+        RuntimeError: IPC is not supported on ``device``.
+    """
+
+    try:
+        # Performance note: try first, ask questions later
+        device = warp.context.runtime.get_device(device)
+    except Exception:
+        # Fallback to using the public API for retrieving the device,
+        # which takes take of initializing Warp if needed.
+        device = warp.context.get_device(device)
+
+    if device.is_ipc_supported is False:
+        raise RuntimeError(f"IPC is not supported on device {device}.")
+
+    event = Event(
+        device=device, cuda_event=warp.context.runtime.core.cuda_ipc_open_event_handle(device.context, handle)
+    )
+    # Events created from IPC handles must be freed with cuEventDestroy
+    event.owner = True
+
+    return event
+
+
 # given a kernel destination argument type and a value convert
 # to a c-type that can be passed to a kernel
 def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
```
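`event_from_ipc_handle()` is the consumer-side complement of `Event.ipc_handle()` shown earlier: given the 64-byte handle, another process can rebuild the event and synchronize against it. A minimal sketch:

```python
import warp as wp

def open_shared_event(handle: bytes) -> wp.Event:
    """Rehydrate an event shared by another process (Linux only)."""
    wp.init()
    # raises RuntimeError when the device reports no IPC support
    return wp.event_from_ipc_handle(handle)

# typical use in the receiving process:
#   ev = open_shared_event(handle_bytes_from_a_pipe)
#   wp.get_stream().wait_event(ev)  # GPU-side ordering, or
#   wp.synchronize_event(ev)        # CPU-side wait
```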
```diff
@@ -4949,6 +5141,9 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
 
     # try to convert to a value type (vec3, mat33, etc)
     elif issubclass(arg_type, ctypes.Array):
+        # simple value types don't have gradient arrays, but native built-in signatures still expect a non-null adjoint value of the correct type
+        if value is None and adjoint:
+            return arg_type(0)
         if warp.types.types_equal(type(value), arg_type):
             return value
         else:
@@ -4958,9 +5153,6 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
         except Exception as e:
             raise ValueError(f"Failed to convert argument for param {arg_name} to {type_str(arg_type)}") from e
 
-    elif isinstance(value, bool):
-        return ctypes.c_bool(value)
-
     elif isinstance(value, arg_type):
         try:
             # try to pack as a scalar type
@@ -4975,6 +5167,9 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
             ) from e
 
     else:
+        # scalar args don't have gradient arrays, but native built-in signatures still expect a non-null scalar adjoint
+        if value is None and adjoint:
+            return arg_type._type_(0)
         try:
             # try to pack as a scalar type
             if arg_type is warp.types.float16:
```
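The two `pack_arg()` additions matter for adjoint launches where value-type parameters (scalars, vectors, matrices) have no gradient array to pass: `None` is now packed as a correctly typed zero instead of being rejected. An illustrative sketch:

```python
import warp as wp

@wp.kernel
def scale(a: wp.array(dtype=float), s: float, out: wp.array(dtype=float)):
    i = wp.tid()
    out[i] = s * a[i]

wp.init()
a = wp.ones(8, dtype=float, requires_grad=True)
out = wp.zeros(8, dtype=float, requires_grad=True)

wp.launch(scale, dim=8, inputs=[a, 2.0], outputs=[out])
out.grad.fill_(1.0)

# the scalar 's' has no gradient array: passing None for its adjoint is
# now packed as arg_type._type_(0) rather than raising
wp.launch(
    scale,
    dim=8,
    inputs=[a, 2.0],
    outputs=[out],
    adj_inputs=[a.grad, None],
    adj_outputs=[out.grad],
    adjoint=True,
)
```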
```diff
@@ -6034,14 +6229,19 @@ def export_functions_rst(file):  # pragma: no cover
     # build dictionary of all functions by group
     groups = {}
 
-
+    functions = list(builtin_functions.values())
+
+    for f in functions:
         # build dict of groups
         if f.group not in groups:
             groups[f.group] = []
 
-        # append all overloads to the group
-        for o in f.overloads:
-            groups[f.group].append(o)
+        if hasattr(f, "overloads"):
+            # append all overloads to the group
+            for o in f.overloads:
+                groups[f.group].append(o)
+        else:
+            groups[f.group].append(f)
 
     # Keep track of what function and query types have been written
     written_functions = set()
@@ -6061,6 +6261,10 @@ def export_functions_rst(file):  # pragma: no cover
         print("---------------", file=file)
 
         for f in g:
+            if f.func:
+                # f is a Warp function written in Python, we can use autofunction
+                print(f".. autofunction:: {f.func.__module__}.{f.key}", file=file)
+                continue
             for f_prefix, query_type in query_types:
                 if f.key.startswith(f_prefix) and query_type not in written_query_types:
                     print(f".. autoclass:: {query_type}", file=file)
@@ -6118,24 +6322,32 @@ def export_stubs(file):  # pragma: no cover
     print(header, file=file)
     print(file=file)
 
-
-
-        args = ", ".join(f"{k}: {type_str(v)}" for k, v in f.input_types.items())
+    def add_stub(f):
+        args = ", ".join(f"{k}: {type_str(v)}" for k, v in f.input_types.items())
 
-
+        return_str = ""
 
-
-
+        if f.hidden:  # or f.generic:
+            return
 
+        return_type = f.value_type
+        if f.value_func:
             return_type = f.value_func(None, None)
-
-
-
-
-
-
-
-
+        if return_type:
+            return_str = " -> " + type_str(return_type)
+
+        print("@over", file=file)
+        print(f"def {f.key}({args}){return_str}:", file=file)
+        print(f'    """{f.doc}', file=file)
+        print('    """', file=file)
+        print("    ...\n\n", file=file)
+
+    for g in builtin_functions.values():
+        if hasattr(g, "overloads"):
+            for f in g.overloads:
+                add_stub(f)
+        else:
+            add_stub(g)
 
 
 def export_builtins(file: io.TextIOBase):  # pragma: no cover
@@ -6161,6 +6373,8 @@ def export_builtins(file: io.TextIOBase):  # pragma: no cover
     file.write('extern "C" {\n\n')
 
     for k, g in builtin_functions.items():
+        if not hasattr(g, "overloads"):
+            continue
         for f in g.overloads:
             if not f.export or f.generic:
                 continue
```