warp-lang 1.4.2__py3-none-manylinux2014_aarch64.whl → 1.5.0__py3-none-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of warp-lang might be problematic.
- warp/__init__.py +4 -0
- warp/autograd.py +43 -8
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +21 -2
- warp/build_dll.py +23 -6
- warp/builtins.py +1783 -2
- warp/codegen.py +177 -45
- warp/config.py +2 -2
- warp/context.py +321 -73
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
- warp/examples/benchmarks/benchmark_gemm.py +121 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
- warp/examples/benchmarks/benchmark_tile.py +179 -0
- warp/examples/fem/example_adaptive_grid.py +37 -10
- warp/examples/fem/example_apic_fluid.py +3 -2
- warp/examples/fem/example_convection_diffusion_dg.py +4 -5
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +47 -4
- warp/examples/fem/example_distortion_energy.py +220 -0
- warp/examples/fem/example_magnetostatics.py +127 -85
- warp/examples/fem/example_nonconforming_contact.py +5 -5
- warp/examples/fem/example_stokes.py +3 -1
- warp/examples/fem/example_streamlines.py +12 -19
- warp/examples/fem/utils.py +38 -15
- warp/examples/sim/example_cloth.py +2 -25
- warp/examples/sim/example_quadruped.py +2 -1
- warp/examples/tile/example_tile_convolution.py +58 -0
- warp/examples/tile/example_tile_fft.py +47 -0
- warp/examples/tile/example_tile_filtering.py +105 -0
- warp/examples/tile/example_tile_matmul.py +79 -0
- warp/examples/tile/example_tile_mlp.py +375 -0
- warp/fem/__init__.py +8 -0
- warp/fem/cache.py +16 -12
- warp/fem/dirichlet.py +1 -1
- warp/fem/domain.py +44 -1
- warp/fem/field/__init__.py +1 -2
- warp/fem/field/field.py +31 -19
- warp/fem/field/nodal_field.py +101 -49
- warp/fem/field/virtual.py +794 -0
- warp/fem/geometry/__init__.py +2 -2
- warp/fem/geometry/deformed_geometry.py +3 -105
- warp/fem/geometry/element.py +13 -0
- warp/fem/geometry/geometry.py +165 -5
- warp/fem/geometry/grid_2d.py +3 -6
- warp/fem/geometry/grid_3d.py +31 -28
- warp/fem/geometry/hexmesh.py +3 -46
- warp/fem/geometry/nanogrid.py +3 -2
- warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
- warp/fem/geometry/tetmesh.py +2 -43
- warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
- warp/fem/integrate.py +683 -261
- warp/fem/linalg.py +404 -0
- warp/fem/operator.py +101 -18
- warp/fem/polynomial.py +5 -5
- warp/fem/quadrature/quadrature.py +45 -21
- warp/fem/space/__init__.py +45 -11
- warp/fem/space/basis_function_space.py +451 -0
- warp/fem/space/basis_space.py +58 -11
- warp/fem/space/function_space.py +146 -5
- warp/fem/space/grid_2d_function_space.py +80 -66
- warp/fem/space/grid_3d_function_space.py +113 -68
- warp/fem/space/hexmesh_function_space.py +96 -108
- warp/fem/space/nanogrid_function_space.py +62 -110
- warp/fem/space/quadmesh_function_space.py +208 -0
- warp/fem/space/shape/__init__.py +45 -7
- warp/fem/space/shape/cube_shape_function.py +328 -54
- warp/fem/space/shape/shape_function.py +10 -1
- warp/fem/space/shape/square_shape_function.py +328 -60
- warp/fem/space/shape/tet_shape_function.py +269 -19
- warp/fem/space/shape/triangle_shape_function.py +238 -19
- warp/fem/space/tetmesh_function_space.py +69 -37
- warp/fem/space/topology.py +38 -0
- warp/fem/space/trimesh_function_space.py +179 -0
- warp/fem/utils.py +6 -331
- warp/jax_experimental.py +3 -1
- warp/native/array.h +15 -0
- warp/native/builtin.h +66 -26
- warp/native/bvh.h +4 -0
- warp/native/coloring.cpp +600 -0
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -1
- warp/native/fabric.h +8 -0
- warp/native/hashgrid.h +4 -0
- warp/native/marching.cu +8 -0
- warp/native/mat.h +14 -3
- warp/native/mathdx.cpp +59 -0
- warp/native/mesh.h +4 -0
- warp/native/range.h +13 -1
- warp/native/reduce.cpp +9 -1
- warp/native/reduce.cu +7 -0
- warp/native/runlength_encode.cpp +9 -1
- warp/native/runlength_encode.cu +7 -1
- warp/native/scan.cpp +8 -0
- warp/native/scan.cu +8 -0
- warp/native/scan.h +8 -1
- warp/native/sparse.cpp +8 -0
- warp/native/sparse.cu +8 -0
- warp/native/temp_buffer.h +7 -0
- warp/native/tile.h +1857 -0
- warp/native/tile_gemm.h +341 -0
- warp/native/tile_reduce.h +210 -0
- warp/native/volume_builder.cu +8 -0
- warp/native/volume_builder.h +8 -0
- warp/native/warp.cpp +10 -2
- warp/native/warp.cu +369 -15
- warp/native/warp.h +12 -2
- warp/optim/adam.py +39 -4
- warp/paddle.py +29 -12
- warp/render/render_opengl.py +137 -65
- warp/sim/graph_coloring.py +292 -0
- warp/sim/integrator_euler.py +4 -2
- warp/sim/integrator_featherstone.py +115 -44
- warp/sim/integrator_vbd.py +6 -0
- warp/sim/model.py +88 -15
- warp/stubs.py +569 -4
- warp/tape.py +12 -7
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/aux_test_instancing_gc.py +18 -0
- warp/tests/test_array.py +39 -0
- warp/tests/test_codegen.py +81 -1
- warp/tests/test_codegen_instancing.py +30 -0
- warp/tests/test_collision.py +110 -0
- warp/tests/test_coloring.py +241 -0
- warp/tests/test_context.py +34 -0
- warp/tests/test_examples.py +18 -4
- warp/tests/test_fem.py +453 -113
- warp/tests/test_func.py +13 -0
- warp/tests/test_generics.py +52 -0
- warp/tests/test_iter.py +68 -0
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_mesh_query_point.py +1 -1
- warp/tests/test_module_hashing.py +23 -0
- warp/tests/test_paddle.py +27 -87
- warp/tests/test_print.py +56 -1
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_tile.py +700 -0
- warp/tests/test_tile_mathdx.py +144 -0
- warp/tests/test_tile_mlp.py +383 -0
- warp/tests/test_tile_reduce.py +374 -0
- warp/tests/test_tile_shared_memory.py +190 -0
- warp/tests/test_vbd.py +12 -20
- warp/tests/test_volume.py +43 -0
- warp/tests/unittest_suites.py +19 -2
- warp/tests/unittest_utils.py +4 -0
- warp/types.py +338 -72
- warp/utils.py +22 -1
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/RECORD +153 -126
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
- warp/fem/field/test.py +0 -180
- warp/fem/field/trial.py +0 -183
- warp/fem/space/collocated_function_space.py +0 -102
- warp/fem/space/quadmesh_2d_function_space.py +0 -261
- warp/fem/space/trimesh_2d_function_space.py +0 -153
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
warp/context.py
CHANGED
@@ -12,6 +12,7 @@ import hashlib
 import inspect
 import io
 import itertools
+import json
 import operator
 import os
 import platform
@@ -21,7 +22,7 @@ import typing
 import weakref
 from copy import copy as shallowcopy
 from pathlib import Path
-from typing import Any, Callable, Dict, List,
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 
@@ -101,6 +102,7 @@ class Function:
         value_func=None,
         export_func=None,
         dispatch_func=None,
+        lto_dispatch_func=None,
         module=None,
         variadic=False,
         initializer_list_func=None,
@@ -137,6 +139,7 @@ class Function:
         self.value_func = value_func  # a function that takes a list of args and a list of templates and returns the value type, e.g.: load(array, index) returns the type of value being loaded
         self.export_func = export_func
         self.dispatch_func = dispatch_func
+        self.lto_dispatch_func = lto_dispatch_func
         self.input_types = {}
         self.export = export
         self.doc = doc
@@ -619,10 +622,13 @@ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
 
 
 class KernelHooks:
-    def __init__(self, forward, backward):
+    def __init__(self, forward, backward, forward_smem_bytes=0, backward_smem_bytes=0):
         self.forward = forward
         self.backward = backward
 
+        self.forward_smem_bytes = forward_smem_bytes
+        self.backward_smem_bytes = backward_smem_bytes
+
 
 # caches source and compiled entry points for a kernel (will be populated after module loads)
 class Kernel:
@@ -970,8 +976,17 @@ def struct(c):
     return s
 
 
-
-
+def overload(kernel, arg_types=Union[None, Dict[str, Any], List[Any]]):
+    """Overload a generic kernel with the given argument types.
+
+    Can be called directly or used as a function decorator.
+
+    Args:
+        kernel: The generic kernel to be instantiated with concrete types.
+        arg_types: A list of concrete argument types for the kernel or a
+            dictionary specifying generic argument names as keys and concrete
+            types as variables.
+    """
     if isinstance(kernel, Kernel):
         # handle cases where user calls us directly, e.g. wp.overload(kernel, [args...])
 
@@ -1073,6 +1088,7 @@ def add_builtin(
     value_func=None,
     export_func=None,
     dispatch_func=None,
+    lto_dispatch_func=None,
    doc="",
    namespace="wp::",
    variadic=False,
@@ -1113,6 +1129,9 @@ def add_builtin(
             The arguments returned must be of type `codegen.Var`.
             If not provided, all arguments passed by the users when calling
             the built-in are passed as-is as runtime arguments to the C++ function.
+        lto_dispatch_func (Callable): Same as dispatch_func, but takes an 'option' dict
+            as extra argument (indicating tile_size and target architecture) and returns
+            an LTO-IR buffer as extra return value
         doc (str): Used to generate the Python's docstring and the HTML documentation.
         namespace: Namespace for the underlying C++ function.
         variadic (bool): Whether the function declares variadic arguments.
@@ -1252,6 +1271,7 @@ def add_builtin(
             value_func=value_func if return_type is Any else None,
             export_func=export_func,
             dispatch_func=dispatch_func,
+            lto_dispatch_func=lto_dispatch_func,
             doc=doc,
             namespace=namespace,
             variadic=variadic,
@@ -1274,6 +1294,7 @@ def add_builtin(
         value_func=value_func,
         export_func=export_func,
         dispatch_func=dispatch_func,
+        lto_dispatch_func=lto_dispatch_func,
         variadic=variadic,
         initializer_list_func=initializer_list_func,
         export=export,
@@ -1540,6 +1561,8 @@ class ModuleBuilder:
         self.options = options
         self.module = module
         self.deferred_functions = []
+        self.ltoirs = {}  # map from lto symbol to lto binary
+        self.ltoirs_decl = {}  # map from lto symbol to lto forward declaration
 
         if hasher is None:
             hasher = ModuleHasher(module)
@@ -1607,9 +1630,26 @@ class ModuleBuilder:
             # use dict to preserve import order
             self.functions[func] = None
 
+    def build_meta(self):
+        meta = {}
+
+        for kernel in self.kernels:
+            name = kernel.get_mangled_name()
+
+            meta[name + "_cuda_kernel_forward_smem_bytes"] = kernel.adj.get_total_required_shared()
+            meta[name + "_cuda_kernel_backward_smem_bytes"] = kernel.adj.get_total_required_shared() * 2
+
+        return meta
+
     def codegen(self, device):
         source = ""
 
+        # code-gen LTO forward declarations
+        source += 'extern "C" {\n'
+        for fwd in self.ltoirs_decl.values():
+            source += fwd + "\n"
+        source += "}\n"
+
         # code-gen structs
         visited_structs = set()
         for struct in self.structs.keys():
@@ -1639,9 +1679,9 @@ class ModuleBuilder:
 
         # add headers
         if device == "cpu":
-            source = warp.codegen.cpu_module_header + source
+            source = warp.codegen.cpu_module_header.format(tile_size=self.options["block_dim"]) + source
         else:
-            source = warp.codegen.cuda_module_header + source
+            source = warp.codegen.cuda_module_header.format(tile_size=self.options["block_dim"]) + source
 
         return source
 
@@ -1660,11 +1700,12 @@ class ModuleExec:
         instance.handle = None
         return instance
 
-    def __init__(self, handle, module_hash, device):
+    def __init__(self, handle, module_hash, device, meta):
         self.handle = handle
         self.module_hash = module_hash
         self.device = device
         self.kernel_hooks = {}
+        self.meta = meta
 
     # release the loaded module
     def __del__(self):
@@ -1678,19 +1719,50 @@ class ModuleExec:
 
     # lookup and cache kernel entry points
    def get_kernel_hooks(self, kernel):
-
+        # Use kernel.adj as a unique key for cache lookups instead of the kernel itself.
+        # This avoids holding a reference to the kernel and is faster than using
+        # a WeakKeyDictionary with kernels as keys.
+        hooks = self.kernel_hooks.get(kernel.adj)
         if hooks is not None:
             return hooks
 
         name = kernel.get_mangled_name()
 
         if self.device.is_cuda:
-
-
+            forward_name = name + "_cuda_kernel_forward"
+            forward_kernel = runtime.core.cuda_get_kernel(
+                self.device.context, self.handle, forward_name.encode("utf-8")
             )
-
-
+
+            backward_name = name + "_cuda_kernel_backward"
+            backward_kernel = runtime.core.cuda_get_kernel(
+                self.device.context, self.handle, backward_name.encode("utf-8")
             )
+
+            # look up the required shared memory size for each kernel from module metadata
+            forward_smem_bytes = self.meta[forward_name + "_smem_bytes"]
+            backward_smem_bytes = self.meta[backward_name + "_smem_bytes"]
+
+            # configure kernels maximum shared memory size
+            max_smem_bytes = runtime.core.cuda_get_max_shared_memory(self.device.context)
+
+            if not runtime.core.cuda_configure_kernel_shared_memory(forward_kernel, forward_smem_bytes):
+                print(
+                    f"Warning: Failed to configure kernel dynamic shared memory for this device, tried to configure {forward_name} kernel for {forward_smem_bytes} bytes, but maximum available is {max_smem_bytes}"
+                )
+
+            options = dict(kernel.module.options)
+            options.update(kernel.options)
+
+            if options["enable_backward"] and not runtime.core.cuda_configure_kernel_shared_memory(
+                backward_kernel, backward_smem_bytes
+            ):
+                print(
+                    f"Warning: Failed to configure kernel dynamic shared memory for this device, tried to configure {backward_name} kernel for {backward_smem_bytes} bytes, but maximum available is {max_smem_bytes}"
+                )
+
+            hooks = KernelHooks(forward_kernel, backward_kernel, forward_smem_bytes, backward_smem_bytes)
+
         else:
             func = ctypes.CFUNCTYPE(None)
             forward = (
@@ -1700,9 +1772,9 @@ class ModuleExec:
                 func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8"))) or None
             )
 
-
-        self.kernel_hooks[kernel] = hooks
+            hooks = KernelHooks(forward, backward)
 
+        self.kernel_hooks[kernel.adj] = hooks
         return hooks
 
 
@@ -1712,7 +1784,8 @@ class ModuleExec:
 # build cache
 class Module:
     def __init__(self, name, loader):
-        self.name = name
+        self.name = name if name is not None else "None"
+
         self.loader = loader
 
         # lookup the latest versions of kernels, functions, and structs by key
@@ -1720,12 +1793,14 @@ class Module:
         self.functions = {}  # (key: function)
         self.structs = {}  # (key: struct)
 
-        # Set of all "live" kernels in this module.
+        # Set of all "live" kernels in this module, i.e., kernels that still have references.
+        # We keep a weak reference to every kernel ever created in this module and rely on Python GC
+        # to release kernels that no longer have any references (in user code or internal bookkeeping).
         # The difference between `live_kernels` and `kernels` is that `live_kernels` may contain
         # multiple kernels with the same key (which is essential to support closures), while `kernels`
         # only holds the latest kernel for each key. When the module is built, we compute the hash
         # of each kernel in `live_kernels` and filter out duplicates for codegen.
-        self.
+        self._live_kernels = weakref.WeakSet()
 
         # executable modules currently loaded
         self.execs = {}  # (device.context: ModuleExec)
@@ -1749,6 +1824,7 @@ class Module:
             "fast_math": False,
            "cuda_output": None,  # supported values: "ptx", "cubin", or None (automatic)
            "mode": warp.config.mode,
+            "block_dim": 256,
         }
 
         # Module dependencies are determined by scanning each function
@@ -1773,7 +1849,7 @@ class Module:
         self.kernels[kernel.key] = kernel
 
         # track all kernel objects, even if they are duplicates
-        self.
+        self._live_kernels.add(kernel)
 
         self.find_references(kernel.adj)
 
@@ -1839,6 +1915,19 @@ class Module:
         # for a reload of module on next launch
         self.mark_modified()
 
+    @property
+    def live_kernels(self):
+        # Return a list of kernels that still have references.
+        # We return a regular list instead of the WeakSet to avoid undesirable issues
+        # if kernels are garbage collected before the caller is done using this list.
+        # Note that we should avoid retaining strong references to kernels unnecessarily
+        # so that Python GC can release kernels that no longer have user references.
+        # It is tempting to call gc.collect() here to force garbage collection,
+        # but this can have undesirable consequences (e.g., GC during graph capture),
+        # so we should avoid it as a general rule. Instead, we rely on Python's
+        # reference counting GC to collect kernels that have gone out of scope.
+        return list(self._live_kernels)
+
     # find kernel corresponding to a Python function
     def find_kernel(self, func):
         qualname = warp.codegen.make_full_qualified_name(func)
@@ -1879,9 +1968,17 @@ class Module:
             self.hasher = ModuleHasher(self)
         return self.hasher.get_module_hash()
 
-    def load(self, device) -> ModuleExec:
+    def load(self, device, block_dim=None) -> ModuleExec:
         device = runtime.get_device(device)
 
+        # re-compile module if tile size (blockdim) changes
+        # todo: it would be better to have a method such as `module.get_kernel(block_dim=N)`
+        # that can return a single kernel instance with a given block size
+        if block_dim is not None:
+            if self.options["block_dim"] != block_dim:
+                self.unload()
+                self.options["block_dim"] = block_dim
+
         # compute the hash if needed
         if self.hasher is None:
             self.hasher = ModuleHasher(self)
@@ -1909,6 +2006,7 @@ class Module:
         # determine output paths
         if device.is_cpu:
             output_name = "module_codegen.o"
+            output_arch = None
 
         elif device.is_cuda:
             # determine whether to use PTX or CUBIN
@@ -1947,7 +2045,12 @@ class Module:
            or not warp.config.cache_kernels
            or warp.config.verify_autograd_array_access
        ):
-
+            builder_options = {
+                **self.options,
+                # Some of the Tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
+                "output_arch": output_arch,
+            }
+            builder = ModuleBuilder(self, builder_options, hasher=self.hasher)
 
            # create a temporary (process unique) dir for build outputs before moving to the binary dir
            build_dir = os.path.join(
@@ -2010,6 +2113,7 @@ class Module:
                        config=self.options["mode"],
                        fast_math=self.options["fast_math"],
                        verify_fp=warp.config.verify_fp,
+                        ltoirs=builder.ltoirs.values(),
                    )
 
                except Exception as e:
@@ -2017,6 +2121,15 @@ class Module:
                    module_load_timer.extra_msg = " (error)"
                    raise (e)
 
+            # ------------------------------------------------------------
+            # build meta data
+
+            meta = builder.build_meta()
+            meta_path = os.path.join(build_dir, "module_codegen.meta")
+
+            with open(meta_path, "w") as meta_file:
+                json.dump(meta, meta_file)
+
            # -----------------------------------------------------------
            # update cache
 
@@ -2053,18 +2166,23 @@ class Module:
 
        # -----------------------------------------------------------
        # Load CPU or CUDA binary
+
+        meta_path = os.path.join(module_dir, "module_codegen.meta")
+        with open(meta_path, "r") as meta_file:
+            meta = json.load(meta_file)
+
        if device.is_cpu:
            # LLVM modules are identified using strings, so we need to ensure uniqueness
            module_handle = f"{module_name}_{self.cpu_exec_id}"
            self.cpu_exec_id += 1
            runtime.llvm.load_obj(binary_path.encode("utf-8"), module_handle.encode("utf-8"))
-            module_exec = ModuleExec(module_handle, module_hash, device)
+            module_exec = ModuleExec(module_handle, module_hash, device, meta)
            self.execs[None] = module_exec
 
        elif device.is_cuda:
            cuda_module = warp.build.load_cuda(binary_path, device)
            if cuda_module is not None:
-                module_exec = ModuleExec(cuda_module, module_hash, device)
+                module_exec = ModuleExec(cuda_module, module_hash, device, meta)
                self.execs[device.context] = module_exec
            else:
                module_load_timer.extra_msg = " (error)"
@@ -2719,21 +2837,16 @@ class Graph:
 
 class Runtime:
     def __init__(self):
-        if sys.version_info < (3,
-            raise RuntimeError("Warp requires Python 3.
+        if sys.version_info < (3, 8):
+            raise RuntimeError("Warp requires Python 3.8 as a minimum")
         if sys.version_info < (3, 9):
             warp.utils.warn(f"Python 3.9 or newer is recommended for running Warp, detected {sys.version_info}")
 
         bin_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bin")
 
         if os.name == "nt":
-
-
-                os.add_dll_directory(bin_path)
-
-            else:
-                # Python < 3.8 we add dll directory to path
-                os.environ["PATH"] = bin_path + os.pathsep + os.environ["PATH"]
+            # Python >= 3.8 this method to add dll search paths
+            os.add_dll_directory(bin_path)
 
             warp_lib = os.path.join(bin_path, "warp.dll")
             llvm_lib = os.path.join(bin_path, "warp-clang.dll")
@@ -3205,6 +3318,8 @@ class Runtime:
            self.core.is_cuda_compatibility_enabled.restype = ctypes.c_int
            self.core.is_cutlass_enabled.argtypes = None
            self.core.is_cutlass_enabled.restype = ctypes.c_int
+            self.core.is_mathdx_enabled.argtypes = None
+            self.core.is_mathdx_enabled.restype = ctypes.c_int
 
            self.core.cuda_driver_version.argtypes = None
            self.core.cuda_driver_version.restype = ctypes.c_int
@@ -3329,17 +3444,58 @@ class Runtime:
            self.core.cuda_graph_destroy.restype = ctypes.c_bool
 
            self.core.cuda_compile_program.argtypes = [
-                ctypes.c_char_p,
-                ctypes.c_int,
-                ctypes.c_char_p,
-                ctypes.
-                ctypes.
-                ctypes.c_bool,
-                ctypes.c_bool,
-                ctypes.
+                ctypes.c_char_p,  # cuda_src
+                ctypes.c_int,  # arch
+                ctypes.c_char_p,  # include_dir
+                ctypes.c_int,  # num_cuda_include_dirs
+                ctypes.POINTER(ctypes.c_char_p),  # cuda include dirs
+                ctypes.c_bool,  # debug
+                ctypes.c_bool,  # verbose
+                ctypes.c_bool,  # verify_fp
+                ctypes.c_bool,  # fast_math
+                ctypes.c_char_p,  # output_path
+                ctypes.c_size_t,  # num_ltoirs
+                ctypes.POINTER(ctypes.c_char_p),  # ltoirs
+                ctypes.POINTER(ctypes.c_size_t),  # ltoir_sizes
            ]
            self.core.cuda_compile_program.restype = ctypes.c_size_t
 
+            self.core.cuda_compile_fft.argtypes = [
+                ctypes.c_char_p,  # lto
+                ctypes.c_char_p,  # function name
+                ctypes.c_int,  # num include dirs
+                ctypes.POINTER(ctypes.c_char_p),  # include dirs
+                ctypes.c_char_p,  # mathdx include dir
+                ctypes.c_int,  # arch
+                ctypes.c_int,  # size
+                ctypes.c_int,  # ept
+                ctypes.c_int,  # direction
+                ctypes.c_int,  # precision
+                ctypes.POINTER(ctypes.c_int),  # smem (out)
+            ]
+            self.core.cuda_compile_fft.restype = ctypes.c_bool
+
+            self.core.cuda_compile_dot.argtypes = [
+                ctypes.c_char_p,  # lto
+                ctypes.c_char_p,  # function name
+                ctypes.c_int,  # num include dirs
+                ctypes.POINTER(ctypes.c_char_p),  # include dirs
+                ctypes.c_char_p,  # mathdx include dir
+                ctypes.c_int,  # arch
+                ctypes.c_int,  # M
+                ctypes.c_int,  # N
+                ctypes.c_int,  # K
+                ctypes.c_int,  # a_precision
+                ctypes.c_int,  # b_precision
+                ctypes.c_int,  # c_precision
+                ctypes.c_int,  # type
+                ctypes.c_int,  # a_arrangement
+                ctypes.c_int,  # b_arrangement
+                ctypes.c_int,  # c_arrangement
+                ctypes.c_int,  # num threads
+            ]
+            self.core.cuda_compile_dot.restype = ctypes.c_bool
+
            self.core.cuda_load_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
            self.core.cuda_load_module.restype = ctypes.c_void_p
 
@@ -3349,11 +3505,19 @@ class Runtime:
            self.core.cuda_get_kernel.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_char_p]
            self.core.cuda_get_kernel.restype = ctypes.c_void_p
 
+            self.core.cuda_get_max_shared_memory.argtypes = [ctypes.c_void_p]
+            self.core.cuda_get_max_shared_memory.restype = ctypes.c_int
+
+            self.core.cuda_configure_kernel_shared_memory.argtypes = [ctypes.c_void_p, ctypes.c_int]
+            self.core.cuda_configure_kernel_shared_memory.restype = ctypes.c_bool
+
            self.core.cuda_launch_kernel.argtypes = [
                ctypes.c_void_p,
                ctypes.c_void_p,
                ctypes.c_size_t,
                ctypes.c_int,
+                ctypes.c_int,
+                ctypes.c_int,
                ctypes.POINTER(ctypes.c_void_p),
                ctypes.c_void_p,
            ]
@@ -3382,6 +3546,23 @@ class Runtime:
            self.core.cuda_timing_end.argtypes = []
            self.core.cuda_timing_end.restype = None
 
+            self.core.graph_coloring.argtypes = [
+                ctypes.c_int,
+                warp.types.array_t,
+                ctypes.c_int,
+                warp.types.array_t,
+            ]
+            self.core.graph_coloring.restype = ctypes.c_int
+
+            self.core.balance_coloring.argtypes = [
+                ctypes.c_int,
+                warp.types.array_t,
+                ctypes.c_int,
+                ctypes.c_float,
+                warp.types.array_t,
+            ]
+            self.core.balance_coloring.restype = ctypes.c_float
+
            self.core.init.restype = ctypes.c_int
 
        except AttributeError as e:
@@ -3607,10 +3788,7 @@ class Runtime:
 
     def load_dll(self, dll_path):
         try:
-
-                dll = ctypes.CDLL(dll_path, winmode=0)
-            else:
-                dll = ctypes.CDLL(dll_path)
+            dll = ctypes.CDLL(dll_path, winmode=0)
         except OSError as e:
             if "GLIBCXX" in str(e):
                 raise RuntimeError(
@@ -3751,7 +3929,7 @@ def is_cuda_available() -> bool:
     return get_cuda_device_count() > 0
 
 
-def is_device_available(device):
+def is_device_available(device: Device) -> bool:
     return device in get_devices()
 
 
@@ -3811,7 +3989,7 @@ def get_cuda_devices() -> List[Device]:
 
 
 def get_preferred_device() -> Device:
-    """Returns the preferred compute device,
+    """Returns the preferred compute device, ``cuda:0`` if available and ``cpu`` otherwise."""
 
     init()
 
@@ -3951,7 +4129,7 @@ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, floa
 
 
 def get_mempool_release_threshold(device: Devicelike) -> int:
-    """Get the CUDA memory pool release threshold on the device."""
+    """Get the CUDA memory pool release threshold on the device in bytes."""
 
     init()
 
@@ -3970,7 +4148,7 @@ def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike)
     """Check if `peer_device` can directly access the memory of `target_device` on this system.
 
     This applies to memory allocated using default CUDA allocators. For memory allocated using
-    CUDA pooled allocators, use
+    CUDA pooled allocators, use :func:`is_mempool_access_supported()`.
 
     Returns:
         A Boolean value indicating if this peer access is supported by the system.
@@ -3991,7 +4169,7 @@ def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike) -
     """Check if `peer_device` can currently access the memory of `target_device`.
 
     This applies to memory allocated using default CUDA allocators. For memory allocated using
-    CUDA pooled allocators, use
+    CUDA pooled allocators, use :func:`is_mempool_access_enabled()`.
 
     Returns:
         A Boolean value indicating if this peer access is currently enabled.
@@ -4015,7 +4193,7 @@ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike,
     a negative impact on memory consumption and allocation performance.
 
     This applies to memory allocated using default CUDA allocators. For memory allocated using
-    CUDA pooled allocators, use
+    CUDA pooled allocators, use :func:`set_mempool_access_enabled()`.
     """
 
     init()
@@ -4043,7 +4221,8 @@ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike,
 def is_mempool_access_supported(target_device: Devicelike, peer_device: Devicelike) -> bool:
     """Check if `peer_device` can directly access the memory pool of `target_device`.
 
-    If mempool access is possible, it can be managed using
+    If mempool access is possible, it can be managed using :func:`set_mempool_access_enabled()`
+    and :func:`is_mempool_access_enabled()`.
 
     Returns:
         A Boolean value indicating if this memory pool access is supported by the system.
@@ -4061,7 +4240,7 @@ def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike
     """Check if `peer_device` can currently access the memory pool of `target_device`.
 
     This applies to memory allocated using CUDA pooled allocators. For memory allocated using
-    default CUDA allocators, use
+    default CUDA allocators, use :func:`is_peer_access_enabled()`.
 
     Returns:
         A Boolean value indicating if this peer access is currently enabled.
@@ -4082,7 +4261,7 @@ def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelik
     """Enable or disable access from `peer_device` to the memory pool of `target_device`.
 
     This applies to memory allocated using CUDA pooled allocators. For memory allocated using
-    default CUDA allocators, use
+    default CUDA allocators, use :func:`set_peer_access_enabled()`.
     """
 
     init()
@@ -4791,7 +4970,9 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
 # represents all data required for a kernel launch
 # so that launches can be replayed quickly, use `wp.launch(..., record_cmd=True)`
 class Launch:
-    def __init__(
+    def __init__(
+        self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None, max_blocks=0, block_dim=256
+    ):
         # retain the module executable so it doesn't get unloaded
         self.module_exec = kernel.module.load(device)
         if not self.module_exec:
@@ -4830,6 +5011,7 @@ class Launch:
         self.device = device
         self.bounds = bounds
         self.max_blocks = max_blocks
+        self.block_dim = block_dim
 
     def set_dim(self, dim):
         self.bounds = warp.types.launch_bounds_t(dim)
@@ -4911,6 +5093,8 @@ class Launch:
                 self.hooks.forward,
                 self.bounds.size,
                 self.max_blocks,
+                self.block_dim,
+                self.hooks.forward_smem_bytes,
                 self.params_addr,
                 stream.cuda_stream,
             )
@@ -4929,6 +5113,7 @@ def launch(
     record_tape=True,
     record_cmd=False,
     max_blocks=0,
+    block_dim=256,
 ):
     """Launch a Warp kernel on the target device
 
@@ -4948,6 +5133,7 @@ def launch(
         record_cmd: When True the launch will be returned as a ``Launch`` command object, the launch will not occur until the user calls ``cmd.launch()``
         max_blocks: The maximum number of CUDA thread blocks to use. Only has an effect for CUDA kernel launches.
             If negative or zero, the maximum hardware value will be used.
+        block_dim: The number of threads per block.
     """
 
     init()
@@ -5001,7 +5187,12 @@ def launch(
         kernel = kernel.add_overload(fwd_types)
 
     # delay load modules, including new overload if needed
-
+    try:
+        module_exec = kernel.module.load(device, block_dim)
+    except Exception:
+        kernel.adj.skip_build = True
+        raise
+
     if not module_exec:
         return
 
@@ -5057,7 +5248,14 @@ def launch(
                )
 
                runtime.core.cuda_launch_kernel(
-                    device.context,
+                    device.context,
+                    hooks.backward,
+                    bounds.size,
+                    max_blocks,
+                    block_dim,
+                    hooks.backward_smem_bytes,
+                    kernel_params,
+                    stream.cuda_stream,
                )
 
            else:
@@ -5080,7 +5278,14 @@ def launch(
            else:
                # launch
                runtime.core.cuda_launch_kernel(
-                    device.context,
+                    device.context,
+                    hooks.forward,
+                    bounds.size,
+                    max_blocks,
+                    block_dim,
+                    hooks.forward_smem_bytes,
+                    kernel_params,
+                    stream.cuda_stream,
                )
 
        try:
@@ -5094,13 +5299,65 @@ def launch(
        # record file, lineno, func as metadata
        frame = inspect.currentframe().f_back
        caller = {"file": frame.f_code.co_filename, "lineno": frame.f_lineno, "func": frame.f_code.co_name}
-        runtime.tape.record_launch(
+        runtime.tape.record_launch(
+            kernel, dim, max_blocks, inputs, outputs, device, block_dim, metadata={"caller": caller}
+        )
 
    # detect illegal inter-kernel read/write access patterns if verification flag is set
    if warp.config.verify_autograd_array_access:
        runtime.tape._check_kernel_array_access(kernel, fwd_args)
 
 
+def launch_tiled(*args, **kwargs):
+    """A helper method for launching a grid with an extra trailing dimension equal to the block size.
+
+    For example, to launch a 2D grid, where each element has 64 threads assigned you would use the following:
+
+    .. code-block:: python
+
+        wp.launch_tiled(kernel, [M, N], inputs=[...], block_dim=64)
+
+    Which is equivalent to the following:
+
+    .. code-block:: python
+
+        wp.launch(kernel, [M, N, 64], inputs=[...], block_dim=64)
+
+    Inside your kernel code you can retrieve the first two indices of the thread as usual, ignoring the implicit third dimension if desired:
+
+    .. code-block:: python
+
+        @wp.kernel
+        def compute()
+
+            i, j = wp.tid()
+
+            ...
+    """
+
+    # promote dim to a list in case it was passed as a scalar or tuple
+    if "dim" not in kwargs:
+        raise RuntimeError("Launch dimensions 'dim' argument should be passed via. keyword args for wp.launch_tiled()")
+
+    if "block_dim" not in kwargs:
+        raise RuntimeError(
+            "Launch block dimension 'block_dim' argument should be passed via. keyword args for wp.launch_tiled()"
+        )
+
+    dim = kwargs["dim"]
+    if not isinstance(dim, list):
+        dim = list(dim) if isinstance(dim, tuple) else [dim]
+
+    if len(dim) > 3:
+        raise RuntimeError("wp.launch_tiled() requires a grid with fewer than 4 dimensions")
+
+    # add trailing dimension
+    kwargs["dim"] = dim + [kwargs["block_dim"]]
+
+    # forward to original launch method
+    launch(*args, **kwargs)
+
+
 def synchronize():
     """Manually synchronize the calling CPU thread with any outstanding CUDA work on all devices
 
@@ -5619,16 +5876,6 @@ def type_str(t):
         return "Any"
     elif t == Callable:
         return "Callable"
-    elif t == Tuple[int]:
-        return "Tuple[int]"
-    elif t == Tuple[int, int]:
-        return "Tuple[int, int]"
-    elif t == Tuple[int, int, int]:
-        return "Tuple[int, int, int]"
-    elif t == Tuple[int, int, int, int]:
-        return "Tuple[int, int, int, int]"
-    elif t == Tuple[int, ...]:
-        return "Tuple[int, ...]"
     elif isinstance(t, int):
         return str(t)
     elif isinstance(t, List):
@@ -5663,9 +5910,13 @@ def type_str(t):
             return f"Transformation[{type_str(t._wp_scalar_type_)}]"
 
         raise TypeError("Invalid vector or matrix dimensions")
-    elif
-        args_repr = ", ".join(type_str(x) for x in
-        return f"{t.
+    elif warp.codegen.get_type_origin(t) in (list, tuple):
+        args_repr = ", ".join(type_str(x) for x in warp.codegen.get_type_args(t))
+        return f"{t._name}[{args_repr}]"
+    elif t is Ellipsis:
+        return "..."
+    elif warp.types.is_tile(t):
+        return "Tile"
 
     return t.__name__
 
@@ -5826,9 +6077,6 @@ def export_stubs(file):  # pragma: no cover
     print('Cols = TypeVar("Cols", bound=int)', file=file)
     print('DType = TypeVar("DType")', file=file)
 
-    print('Int = TypeVar("Int")', file=file)
-    print('Float = TypeVar("Float")', file=file)
-    print('Scalar = TypeVar("Scalar")', file=file)
     print("Vector = Generic[Length, Scalar]", file=file)
     print("Matrix = Generic[Rows, Cols, Scalar]", file=file)
     print("Quaternion = Generic[Float]", file=file)