warp-lang 1.4.1-py3-none-manylinux2014_aarch64.whl → 1.5.0-py3-none-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +4 -0
- warp/autograd.py +43 -8
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +21 -2
- warp/build_dll.py +23 -6
- warp/builtins.py +1920 -111
- warp/codegen.py +186 -62
- warp/config.py +2 -2
- warp/context.py +322 -73
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
- warp/examples/benchmarks/benchmark_gemm.py +121 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
- warp/examples/benchmarks/benchmark_tile.py +179 -0
- warp/examples/core/example_dem.py +2 -1
- warp/examples/core/example_mesh_intersect.py +3 -3
- warp/examples/fem/example_adaptive_grid.py +37 -10
- warp/examples/fem/example_apic_fluid.py +3 -2
- warp/examples/fem/example_convection_diffusion_dg.py +4 -5
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +47 -4
- warp/examples/fem/example_distortion_energy.py +220 -0
- warp/examples/fem/example_magnetostatics.py +127 -85
- warp/examples/fem/example_nonconforming_contact.py +5 -5
- warp/examples/fem/example_stokes.py +3 -1
- warp/examples/fem/example_streamlines.py +12 -19
- warp/examples/fem/utils.py +38 -15
- warp/examples/optim/example_walker.py +2 -2
- warp/examples/sim/example_cloth.py +2 -25
- warp/examples/sim/example_jacobian_ik.py +6 -2
- warp/examples/sim/example_quadruped.py +2 -1
- warp/examples/tile/example_tile_convolution.py +58 -0
- warp/examples/tile/example_tile_fft.py +47 -0
- warp/examples/tile/example_tile_filtering.py +105 -0
- warp/examples/tile/example_tile_matmul.py +79 -0
- warp/examples/tile/example_tile_mlp.py +375 -0
- warp/fem/__init__.py +8 -0
- warp/fem/cache.py +16 -12
- warp/fem/dirichlet.py +1 -1
- warp/fem/domain.py +44 -1
- warp/fem/field/__init__.py +1 -2
- warp/fem/field/field.py +31 -19
- warp/fem/field/nodal_field.py +101 -49
- warp/fem/field/virtual.py +794 -0
- warp/fem/geometry/__init__.py +2 -2
- warp/fem/geometry/deformed_geometry.py +3 -105
- warp/fem/geometry/element.py +13 -0
- warp/fem/geometry/geometry.py +165 -5
- warp/fem/geometry/grid_2d.py +3 -6
- warp/fem/geometry/grid_3d.py +31 -28
- warp/fem/geometry/hexmesh.py +3 -46
- warp/fem/geometry/nanogrid.py +3 -2
- warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
- warp/fem/geometry/tetmesh.py +2 -43
- warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
- warp/fem/integrate.py +683 -261
- warp/fem/linalg.py +404 -0
- warp/fem/operator.py +101 -18
- warp/fem/polynomial.py +5 -5
- warp/fem/quadrature/quadrature.py +45 -21
- warp/fem/space/__init__.py +45 -11
- warp/fem/space/basis_function_space.py +451 -0
- warp/fem/space/basis_space.py +58 -11
- warp/fem/space/function_space.py +146 -5
- warp/fem/space/grid_2d_function_space.py +80 -66
- warp/fem/space/grid_3d_function_space.py +113 -68
- warp/fem/space/hexmesh_function_space.py +96 -108
- warp/fem/space/nanogrid_function_space.py +62 -110
- warp/fem/space/quadmesh_function_space.py +208 -0
- warp/fem/space/shape/__init__.py +45 -7
- warp/fem/space/shape/cube_shape_function.py +328 -54
- warp/fem/space/shape/shape_function.py +10 -1
- warp/fem/space/shape/square_shape_function.py +328 -60
- warp/fem/space/shape/tet_shape_function.py +269 -19
- warp/fem/space/shape/triangle_shape_function.py +238 -19
- warp/fem/space/tetmesh_function_space.py +69 -37
- warp/fem/space/topology.py +38 -0
- warp/fem/space/trimesh_function_space.py +179 -0
- warp/fem/utils.py +6 -331
- warp/jax_experimental.py +3 -1
- warp/native/array.h +55 -40
- warp/native/builtin.h +124 -43
- warp/native/bvh.h +4 -0
- warp/native/coloring.cpp +600 -0
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -1
- warp/native/fabric.h +8 -0
- warp/native/hashgrid.h +4 -0
- warp/native/marching.cu +8 -0
- warp/native/mat.h +14 -3
- warp/native/mathdx.cpp +59 -0
- warp/native/mesh.h +4 -0
- warp/native/range.h +13 -1
- warp/native/reduce.cpp +9 -1
- warp/native/reduce.cu +7 -0
- warp/native/runlength_encode.cpp +9 -1
- warp/native/runlength_encode.cu +7 -1
- warp/native/scan.cpp +8 -0
- warp/native/scan.cu +8 -0
- warp/native/scan.h +8 -1
- warp/native/sparse.cpp +8 -0
- warp/native/sparse.cu +8 -0
- warp/native/temp_buffer.h +7 -0
- warp/native/tile.h +1857 -0
- warp/native/tile_gemm.h +341 -0
- warp/native/tile_reduce.h +210 -0
- warp/native/volume_builder.cu +8 -0
- warp/native/volume_builder.h +8 -0
- warp/native/warp.cpp +10 -2
- warp/native/warp.cu +369 -15
- warp/native/warp.h +12 -2
- warp/optim/adam.py +39 -4
- warp/paddle.py +29 -12
- warp/render/render_opengl.py +137 -65
- warp/sim/graph_coloring.py +292 -0
- warp/sim/integrator_euler.py +4 -2
- warp/sim/integrator_featherstone.py +115 -44
- warp/sim/integrator_vbd.py +6 -0
- warp/sim/model.py +90 -17
- warp/stubs.py +651 -85
- warp/tape.py +12 -7
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/aux_test_instancing_gc.py +18 -0
- warp/tests/test_array.py +207 -48
- warp/tests/test_closest_point_edge_edge.py +8 -8
- warp/tests/test_codegen.py +120 -1
- warp/tests/test_codegen_instancing.py +30 -0
- warp/tests/test_collision.py +110 -0
- warp/tests/test_coloring.py +241 -0
- warp/tests/test_context.py +34 -0
- warp/tests/test_examples.py +18 -4
- warp/tests/test_fabricarray.py +33 -0
- warp/tests/test_fem.py +453 -113
- warp/tests/test_func.py +48 -1
- warp/tests/test_generics.py +52 -0
- warp/tests/test_iter.py +68 -0
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_mesh_query_point.py +5 -4
- warp/tests/test_module_hashing.py +23 -0
- warp/tests/test_paddle.py +27 -87
- warp/tests/test_print.py +191 -1
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_tile.py +700 -0
- warp/tests/test_tile_mathdx.py +144 -0
- warp/tests/test_tile_mlp.py +383 -0
- warp/tests/test_tile_reduce.py +374 -0
- warp/tests/test_tile_shared_memory.py +190 -0
- warp/tests/test_vbd.py +12 -20
- warp/tests/test_volume.py +43 -0
- warp/tests/unittest_suites.py +23 -2
- warp/tests/unittest_utils.py +4 -0
- warp/types.py +339 -73
- warp/utils.py +22 -1
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/RECORD +159 -132
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
- warp/fem/field/test.py +0 -180
- warp/fem/field/trial.py +0 -183
- warp/fem/space/collocated_function_space.py +0 -102
- warp/fem/space/quadmesh_2d_function_space.py +0 -261
- warp/fem/space/trimesh_2d_function_space.py +0 -153
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
warp/context.py
CHANGED
@@ -12,6 +12,7 @@ import hashlib
 import inspect
 import io
 import itertools
+import json
 import operator
 import os
 import platform
@@ -21,7 +22,7 @@ import typing
 import weakref
 from copy import copy as shallowcopy
 from pathlib import Path
-from typing import Any, Callable, Dict, List,
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

 import numpy as np

@@ -101,6 +102,7 @@ class Function:
         value_func=None,
         export_func=None,
         dispatch_func=None,
+        lto_dispatch_func=None,
         module=None,
         variadic=False,
         initializer_list_func=None,
@@ -137,6 +139,7 @@ class Function:
         self.value_func = value_func  # a function that takes a list of args and a list of templates and returns the value type, e.g.: load(array, index) returns the type of value being loaded
         self.export_func = export_func
         self.dispatch_func = dispatch_func
+        self.lto_dispatch_func = lto_dispatch_func
         self.input_types = {}
         self.export = export
         self.doc = doc
@@ -619,10 +622,13 @@ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:


 class KernelHooks:
-    def __init__(self, forward, backward):
+    def __init__(self, forward, backward, forward_smem_bytes=0, backward_smem_bytes=0):
         self.forward = forward
         self.backward = backward

+        self.forward_smem_bytes = forward_smem_bytes
+        self.backward_smem_bytes = backward_smem_bytes
+

 # caches source and compiled entry points for a kernel (will be populated after module loads)
 class Kernel:
@@ -970,8 +976,17 @@ def struct(c):
     return s


-
-
+def overload(kernel, arg_types=Union[None, Dict[str, Any], List[Any]]):
+    """Overload a generic kernel with the given argument types.
+
+    Can be called directly or used as a function decorator.
+
+    Args:
+        kernel: The generic kernel to be instantiated with concrete types.
+        arg_types: A list of concrete argument types for the kernel or a
+            dictionary specifying generic argument names as keys and concrete
+            types as variables.
+    """
     if isinstance(kernel, Kernel):
         # handle cases where user calls us directly, e.g. wp.overload(kernel, [args...])
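The new wp.overload() docstring above describes both call forms. Below is a minimal usage sketch, assuming a generic kernel whose arguments are annotated with typing.Any; the kernel name and dtypes are illustrative and not taken from this diff.

import warp as wp
from typing import Any

@wp.kernel
def scale(x: wp.array(dtype=Any), s: Any):
    i = wp.tid()
    x[i] = x[i] * s

# direct form: instantiate the generic kernel with a list of concrete types
wp.overload(scale, [wp.array(dtype=wp.float32), wp.float32])

# decorator form: re-declare the kernel with concrete annotations and an empty body
@wp.overload
def scale(x: wp.array(dtype=wp.float64), s: wp.float64):
    ...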
@@ -1073,6 +1088,7 @@ def add_builtin(
     value_func=None,
     export_func=None,
     dispatch_func=None,
+    lto_dispatch_func=None,
     doc="",
     namespace="wp::",
     variadic=False,
@@ -1113,6 +1129,9 @@ def add_builtin(
             The arguments returned must be of type `codegen.Var`.
             If not provided, all arguments passed by the users when calling
             the built-in are passed as-is as runtime arguments to the C++ function.
+        lto_dispatch_func (Callable): Same as dispatch_func, but takes an 'option' dict
+            as extra argument (indicating tile_size and target architecture) and returns
+            an LTO-IR buffer as extra return value
         doc (str): Used to generate the Python's docstring and the HTML documentation.
         namespace: Namespace for the underlying C++ function.
         variadic (bool): Whether the function declares variadic arguments.
@@ -1249,8 +1268,10 @@ def add_builtin(
         key,
         input_types=arg_types,
         value_type=return_type,
+        value_func=value_func if return_type is Any else None,
         export_func=export_func,
         dispatch_func=dispatch_func,
+        lto_dispatch_func=lto_dispatch_func,
         doc=doc,
         namespace=namespace,
         variadic=variadic,
@@ -1273,6 +1294,7 @@ def add_builtin(
         value_func=value_func,
         export_func=export_func,
         dispatch_func=dispatch_func,
+        lto_dispatch_func=lto_dispatch_func,
         variadic=variadic,
         initializer_list_func=initializer_list_func,
         export=export,
@@ -1539,6 +1561,8 @@ class ModuleBuilder:
         self.options = options
         self.module = module
         self.deferred_functions = []
+        self.ltoirs = {}  # map from lto symbol to lto binary
+        self.ltoirs_decl = {}  # map from lto symbol to lto forward declaration

         if hasher is None:
             hasher = ModuleHasher(module)
@@ -1606,9 +1630,26 @@ class ModuleBuilder:
             # use dict to preserve import order
             self.functions[func] = None

+    def build_meta(self):
+        meta = {}
+
+        for kernel in self.kernels:
+            name = kernel.get_mangled_name()
+
+            meta[name + "_cuda_kernel_forward_smem_bytes"] = kernel.adj.get_total_required_shared()
+            meta[name + "_cuda_kernel_backward_smem_bytes"] = kernel.adj.get_total_required_shared() * 2
+
+        return meta
+
     def codegen(self, device):
         source = ""

+        # code-gen LTO forward declarations
+        source += 'extern "C" {\n'
+        for fwd in self.ltoirs_decl.values():
+            source += fwd + "\n"
+        source += "}\n"
+
         # code-gen structs
         visited_structs = set()
         for struct in self.structs.keys():
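For reference, build_meta() above produces a flat dictionary keyed by mangled kernel entry-point names, which Module.load() later writes to module_codegen.meta as JSON. A hypothetical example of its contents (the mangled name and byte counts are made up for illustration; the backward entry is simply twice the forward requirement, matching the * 2 above):

meta = {
    "wp_scale_b1c2d3_cuda_kernel_forward_smem_bytes": 4096,
    "wp_scale_b1c2d3_cuda_kernel_backward_smem_bytes": 8192,
}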
@@ -1638,9 +1679,9 @@ class ModuleBuilder:

         # add headers
         if device == "cpu":
-            source = warp.codegen.cpu_module_header + source
+            source = warp.codegen.cpu_module_header.format(tile_size=self.options["block_dim"]) + source
         else:
-            source = warp.codegen.cuda_module_header + source
+            source = warp.codegen.cuda_module_header.format(tile_size=self.options["block_dim"]) + source

         return source

@@ -1659,11 +1700,12 @@ class ModuleExec:
         instance.handle = None
         return instance

-    def __init__(self, handle, module_hash, device):
+    def __init__(self, handle, module_hash, device, meta):
         self.handle = handle
         self.module_hash = module_hash
         self.device = device
         self.kernel_hooks = {}
+        self.meta = meta

     # release the loaded module
     def __del__(self):
@@ -1677,19 +1719,50 @@ class ModuleExec:

     # lookup and cache kernel entry points
     def get_kernel_hooks(self, kernel):
-
+        # Use kernel.adj as a unique key for cache lookups instead of the kernel itself.
+        # This avoids holding a reference to the kernel and is faster than using
+        # a WeakKeyDictionary with kernels as keys.
+        hooks = self.kernel_hooks.get(kernel.adj)
         if hooks is not None:
             return hooks

         name = kernel.get_mangled_name()

         if self.device.is_cuda:
-
-
+            forward_name = name + "_cuda_kernel_forward"
+            forward_kernel = runtime.core.cuda_get_kernel(
+                self.device.context, self.handle, forward_name.encode("utf-8")
             )
-
-
+
+            backward_name = name + "_cuda_kernel_backward"
+            backward_kernel = runtime.core.cuda_get_kernel(
+                self.device.context, self.handle, backward_name.encode("utf-8")
             )
+
+            # look up the required shared memory size for each kernel from module metadata
+            forward_smem_bytes = self.meta[forward_name + "_smem_bytes"]
+            backward_smem_bytes = self.meta[backward_name + "_smem_bytes"]
+
+            # configure kernels maximum shared memory size
+            max_smem_bytes = runtime.core.cuda_get_max_shared_memory(self.device.context)
+
+            if not runtime.core.cuda_configure_kernel_shared_memory(forward_kernel, forward_smem_bytes):
+                print(
+                    f"Warning: Failed to configure kernel dynamic shared memory for this device, tried to configure {forward_name} kernel for {forward_smem_bytes} bytes, but maximum available is {max_smem_bytes}"
+                )
+
+            options = dict(kernel.module.options)
+            options.update(kernel.options)
+
+            if options["enable_backward"] and not runtime.core.cuda_configure_kernel_shared_memory(
+                backward_kernel, backward_smem_bytes
+            ):
+                print(
+                    f"Warning: Failed to configure kernel dynamic shared memory for this device, tried to configure {backward_name} kernel for {backward_smem_bytes} bytes, but maximum available is {max_smem_bytes}"
+                )
+
+            hooks = KernelHooks(forward_kernel, backward_kernel, forward_smem_bytes, backward_smem_bytes)
+
         else:
             func = ctypes.CFUNCTYPE(None)
             forward = (
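Note how the metadata lookups here line up with the entries written by ModuleBuilder.build_meta() earlier in this diff. A small sketch of the key construction (the mangled name is illustrative):

name = "wp_scale_b1c2d3"                      # kernel.get_mangled_name()
forward_name = name + "_cuda_kernel_forward"
# build_meta() writes:       meta[name + "_cuda_kernel_forward_smem_bytes"]
# get_kernel_hooks() reads:  meta[forward_name + "_smem_bytes"]  -- the same key
assert forward_name + "_smem_bytes" == name + "_cuda_kernel_forward_smem_bytes"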
@@ -1699,9 +1772,9 @@ class ModuleExec:
                 func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8"))) or None
             )

-
-            self.kernel_hooks[kernel] = hooks
+            hooks = KernelHooks(forward, backward)

+        self.kernel_hooks[kernel.adj] = hooks
         return hooks


@@ -1711,7 +1784,8 @@ class ModuleExec:
 # build cache
 class Module:
     def __init__(self, name, loader):
-        self.name = name
+        self.name = name if name is not None else "None"
+
         self.loader = loader

         # lookup the latest versions of kernels, functions, and structs by key
@@ -1719,12 +1793,14 @@ class Module:
         self.functions = {}  # (key: function)
         self.structs = {}  # (key: struct)

-        # Set of all "live" kernels in this module.
+        # Set of all "live" kernels in this module, i.e., kernels that still have references.
+        # We keep a weak reference to every kernel ever created in this module and rely on Python GC
+        # to release kernels that no longer have any references (in user code or internal bookkeeping).
         # The difference between `live_kernels` and `kernels` is that `live_kernels` may contain
         # multiple kernels with the same key (which is essential to support closures), while `kernels`
         # only holds the latest kernel for each key. When the module is built, we compute the hash
         # of each kernel in `live_kernels` and filter out duplicates for codegen.
-        self.
+        self._live_kernels = weakref.WeakSet()

         # executable modules currently loaded
         self.execs = {}  # (device.context: ModuleExec)
@@ -1748,6 +1824,7 @@ class Module:
             "fast_math": False,
             "cuda_output": None,  # supported values: "ptx", "cubin", or None (automatic)
             "mode": warp.config.mode,
+            "block_dim": 256,
         }

         # Module dependencies are determined by scanning each function
@@ -1772,7 +1849,7 @@ class Module:
         self.kernels[kernel.key] = kernel

         # track all kernel objects, even if they are duplicates
-        self.
+        self._live_kernels.add(kernel)

         self.find_references(kernel.adj)

@@ -1838,6 +1915,19 @@ class Module:
         # for a reload of module on next launch
         self.mark_modified()

+    @property
+    def live_kernels(self):
+        # Return a list of kernels that still have references.
+        # We return a regular list instead of the WeakSet to avoid undesirable issues
+        # if kernels are garbage collected before the caller is done using this list.
+        # Note that we should avoid retaining strong references to kernels unnecessarily
+        # so that Python GC can release kernels that no longer have user references.
+        # It is tempting to call gc.collect() here to force garbage collection,
+        # but this can have undesirable consequences (e.g., GC during graph capture),
+        # so we should avoid it as a general rule. Instead, we rely on Python's
+        # reference counting GC to collect kernels that have gone out of scope.
+        return list(self._live_kernels)
+
     # find kernel corresponding to a Python function
     def find_kernel(self, func):
         qualname = warp.codegen.make_full_qualified_name(func)
@@ -1878,9 +1968,17 @@ class Module:
             self.hasher = ModuleHasher(self)
         return self.hasher.get_module_hash()

-    def load(self, device) -> ModuleExec:
+    def load(self, device, block_dim=None) -> ModuleExec:
         device = runtime.get_device(device)

+        # re-compile module if tile size (blockdim) changes
+        # todo: it would be better to have a method such as `module.get_kernel(block_dim=N)`
+        # that can return a single kernel instance with a given block size
+        if block_dim is not None:
+            if self.options["block_dim"] != block_dim:
+                self.unload()
+                self.options["block_dim"] = block_dim
+
         # compute the hash if needed
         if self.hasher is None:
             self.hasher = ModuleHasher(self)
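Given the re-compile logic above, passing a different block_dim to Module.load() invalidates the cached executable and triggers a rebuild. A hedged sketch of the resulting behaviour, assuming a previously defined kernel my_kernel and a CUDA device:

module = my_kernel.module
module.load("cuda:0")                 # built with the default options["block_dim"] = 256
module.load("cuda:0", block_dim=128)  # differs from the cached value -> unload() and rebuild with 128
module.load("cuda:0", block_dim=128)  # matches the cached value -> reuses the loaded module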
@@ -1908,6 +2006,7 @@ class Module:
         # determine output paths
         if device.is_cpu:
             output_name = "module_codegen.o"
+            output_arch = None

         elif device.is_cuda:
             # determine whether to use PTX or CUBIN
@@ -1946,7 +2045,12 @@ class Module:
             or not warp.config.cache_kernels
             or warp.config.verify_autograd_array_access
         ):
-
+            builder_options = {
+                **self.options,
+                # Some of the Tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
+                "output_arch": output_arch,
+            }
+            builder = ModuleBuilder(self, builder_options, hasher=self.hasher)

             # create a temporary (process unique) dir for build outputs before moving to the binary dir
             build_dir = os.path.join(
@@ -2009,6 +2113,7 @@ class Module:
                         config=self.options["mode"],
                         fast_math=self.options["fast_math"],
                         verify_fp=warp.config.verify_fp,
+                        ltoirs=builder.ltoirs.values(),
                     )

                 except Exception as e:
@@ -2016,6 +2121,15 @@ class Module:
                     module_load_timer.extra_msg = " (error)"
                     raise (e)

+            # ------------------------------------------------------------
+            # build meta data
+
+            meta = builder.build_meta()
+            meta_path = os.path.join(build_dir, "module_codegen.meta")
+
+            with open(meta_path, "w") as meta_file:
+                json.dump(meta, meta_file)
+
             # -----------------------------------------------------------
             # update cache

@@ -2052,18 +2166,23 @@ class Module:

         # -----------------------------------------------------------
         # Load CPU or CUDA binary
+
+        meta_path = os.path.join(module_dir, "module_codegen.meta")
+        with open(meta_path, "r") as meta_file:
+            meta = json.load(meta_file)
+
         if device.is_cpu:
             # LLVM modules are identified using strings, so we need to ensure uniqueness
             module_handle = f"{module_name}_{self.cpu_exec_id}"
             self.cpu_exec_id += 1
             runtime.llvm.load_obj(binary_path.encode("utf-8"), module_handle.encode("utf-8"))
-            module_exec = ModuleExec(module_handle, module_hash, device)
+            module_exec = ModuleExec(module_handle, module_hash, device, meta)
             self.execs[None] = module_exec

         elif device.is_cuda:
             cuda_module = warp.build.load_cuda(binary_path, device)
             if cuda_module is not None:
-                module_exec = ModuleExec(cuda_module, module_hash, device)
+                module_exec = ModuleExec(cuda_module, module_hash, device, meta)
                 self.execs[device.context] = module_exec
             else:
                 module_load_timer.extra_msg = " (error)"
@@ -2718,21 +2837,16 @@ class Graph:

 class Runtime:
     def __init__(self):
-        if sys.version_info < (3,
-            raise RuntimeError("Warp requires Python 3.
+        if sys.version_info < (3, 8):
+            raise RuntimeError("Warp requires Python 3.8 as a minimum")
         if sys.version_info < (3, 9):
             warp.utils.warn(f"Python 3.9 or newer is recommended for running Warp, detected {sys.version_info}")

         bin_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bin")

         if os.name == "nt":
-
-
-            os.add_dll_directory(bin_path)
-
-            else:
-                # Python < 3.8 we add dll directory to path
-                os.environ["PATH"] = bin_path + os.pathsep + os.environ["PATH"]
+            # Python >= 3.8 this method to add dll search paths
+            os.add_dll_directory(bin_path)

             warp_lib = os.path.join(bin_path, "warp.dll")
             llvm_lib = os.path.join(bin_path, "warp-clang.dll")
@@ -3204,6 +3318,8 @@ class Runtime:
             self.core.is_cuda_compatibility_enabled.restype = ctypes.c_int
             self.core.is_cutlass_enabled.argtypes = None
             self.core.is_cutlass_enabled.restype = ctypes.c_int
+            self.core.is_mathdx_enabled.argtypes = None
+            self.core.is_mathdx_enabled.restype = ctypes.c_int

             self.core.cuda_driver_version.argtypes = None
             self.core.cuda_driver_version.restype = ctypes.c_int
@@ -3328,17 +3444,58 @@ class Runtime:
             self.core.cuda_graph_destroy.restype = ctypes.c_bool

             self.core.cuda_compile_program.argtypes = [
-                ctypes.c_char_p,
-                ctypes.c_int,
-                ctypes.c_char_p,
-                ctypes.
-                ctypes.
-                ctypes.c_bool,
-                ctypes.c_bool,
-                ctypes.
+                ctypes.c_char_p,  # cuda_src
+                ctypes.c_int,  # arch
+                ctypes.c_char_p,  # include_dir
+                ctypes.c_int,  # num_cuda_include_dirs
+                ctypes.POINTER(ctypes.c_char_p),  # cuda include dirs
+                ctypes.c_bool,  # debug
+                ctypes.c_bool,  # verbose
+                ctypes.c_bool,  # verify_fp
+                ctypes.c_bool,  # fast_math
+                ctypes.c_char_p,  # output_path
+                ctypes.c_size_t,  # num_ltoirs
+                ctypes.POINTER(ctypes.c_char_p),  # ltoirs
+                ctypes.POINTER(ctypes.c_size_t),  # ltoir_sizes
             ]
             self.core.cuda_compile_program.restype = ctypes.c_size_t

+            self.core.cuda_compile_fft.argtypes = [
+                ctypes.c_char_p,  # lto
+                ctypes.c_char_p,  # function name
+                ctypes.c_int,  # num include dirs
+                ctypes.POINTER(ctypes.c_char_p),  # include dirs
+                ctypes.c_char_p,  # mathdx include dir
+                ctypes.c_int,  # arch
+                ctypes.c_int,  # size
+                ctypes.c_int,  # ept
+                ctypes.c_int,  # direction
+                ctypes.c_int,  # precision
+                ctypes.POINTER(ctypes.c_int),  # smem (out)
+            ]
+            self.core.cuda_compile_fft.restype = ctypes.c_bool
+
+            self.core.cuda_compile_dot.argtypes = [
+                ctypes.c_char_p,  # lto
+                ctypes.c_char_p,  # function name
+                ctypes.c_int,  # num include dirs
+                ctypes.POINTER(ctypes.c_char_p),  # include dirs
+                ctypes.c_char_p,  # mathdx include dir
+                ctypes.c_int,  # arch
+                ctypes.c_int,  # M
+                ctypes.c_int,  # N
+                ctypes.c_int,  # K
+                ctypes.c_int,  # a_precision
+                ctypes.c_int,  # b_precision
+                ctypes.c_int,  # c_precision
+                ctypes.c_int,  # type
+                ctypes.c_int,  # a_arrangement
+                ctypes.c_int,  # b_arrangement
+                ctypes.c_int,  # c_arrangement
+                ctypes.c_int,  # num threads
+            ]
+            self.core.cuda_compile_dot.restype = ctypes.c_bool
+
             self.core.cuda_load_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
             self.core.cuda_load_module.restype = ctypes.c_void_p

@@ -3348,11 +3505,19 @@ class Runtime:
             self.core.cuda_get_kernel.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_char_p]
             self.core.cuda_get_kernel.restype = ctypes.c_void_p

+            self.core.cuda_get_max_shared_memory.argtypes = [ctypes.c_void_p]
+            self.core.cuda_get_max_shared_memory.restype = ctypes.c_int
+
+            self.core.cuda_configure_kernel_shared_memory.argtypes = [ctypes.c_void_p, ctypes.c_int]
+            self.core.cuda_configure_kernel_shared_memory.restype = ctypes.c_bool
+
             self.core.cuda_launch_kernel.argtypes = [
                 ctypes.c_void_p,
                 ctypes.c_void_p,
                 ctypes.c_size_t,
                 ctypes.c_int,
+                ctypes.c_int,
+                ctypes.c_int,
                 ctypes.POINTER(ctypes.c_void_p),
                 ctypes.c_void_p,
             ]
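The two ctypes.c_int entries added to cuda_launch_kernel's argtypes correspond to the extra arguments passed at the call sites later in this diff (Launch.launch() and wp.launch()). A sketch of the mapping, using the argument names from those call sites:

# ctypes.c_void_p                  -> device.context
# ctypes.c_void_p                  -> hooks.forward / hooks.backward
# ctypes.c_size_t                  -> bounds.size
# ctypes.c_int                     -> max_blocks
# ctypes.c_int  (new)              -> block_dim
# ctypes.c_int  (new)              -> hooks.forward_smem_bytes / hooks.backward_smem_bytes
# ctypes.POINTER(ctypes.c_void_p)  -> kernel_params
# ctypes.c_void_p                  -> stream.cuda_stream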
@@ -3381,6 +3546,23 @@ class Runtime:
             self.core.cuda_timing_end.argtypes = []
             self.core.cuda_timing_end.restype = None

+            self.core.graph_coloring.argtypes = [
+                ctypes.c_int,
+                warp.types.array_t,
+                ctypes.c_int,
+                warp.types.array_t,
+            ]
+            self.core.graph_coloring.restype = ctypes.c_int
+
+            self.core.balance_coloring.argtypes = [
+                ctypes.c_int,
+                warp.types.array_t,
+                ctypes.c_int,
+                ctypes.c_float,
+                warp.types.array_t,
+            ]
+            self.core.balance_coloring.restype = ctypes.c_float
+
             self.core.init.restype = ctypes.c_int

         except AttributeError as e:
@@ -3606,10 +3788,7 @@ class Runtime:

     def load_dll(self, dll_path):
         try:
-
-            dll = ctypes.CDLL(dll_path, winmode=0)
-            else:
-                dll = ctypes.CDLL(dll_path)
+            dll = ctypes.CDLL(dll_path, winmode=0)
         except OSError as e:
             if "GLIBCXX" in str(e):
                 raise RuntimeError(
@@ -3750,7 +3929,7 @@ def is_cuda_available() -> bool:
     return get_cuda_device_count() > 0


-def is_device_available(device):
+def is_device_available(device: Device) -> bool:
     return device in get_devices()


@@ -3810,7 +3989,7 @@ def get_cuda_devices() -> List[Device]:


 def get_preferred_device() -> Device:
-    """Returns the preferred compute device,
+    """Returns the preferred compute device, ``cuda:0`` if available and ``cpu`` otherwise."""

     init()

@@ -3950,7 +4129,7 @@ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, floa


 def get_mempool_release_threshold(device: Devicelike) -> int:
-    """Get the CUDA memory pool release threshold on the device."""
+    """Get the CUDA memory pool release threshold on the device in bytes."""

     init()

@@ -3969,7 +4148,7 @@ def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike)
     """Check if `peer_device` can directly access the memory of `target_device` on this system.

     This applies to memory allocated using default CUDA allocators. For memory allocated using
-    CUDA pooled allocators, use
+    CUDA pooled allocators, use :func:`is_mempool_access_supported()`.

     Returns:
         A Boolean value indicating if this peer access is supported by the system.
@@ -3990,7 +4169,7 @@ def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike) -
     """Check if `peer_device` can currently access the memory of `target_device`.

     This applies to memory allocated using default CUDA allocators. For memory allocated using
-    CUDA pooled allocators, use
+    CUDA pooled allocators, use :func:`is_mempool_access_enabled()`.

     Returns:
         A Boolean value indicating if this peer access is currently enabled.
@@ -4014,7 +4193,7 @@ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike,
     a negative impact on memory consumption and allocation performance.

     This applies to memory allocated using default CUDA allocators. For memory allocated using
-    CUDA pooled allocators, use
+    CUDA pooled allocators, use :func:`set_mempool_access_enabled()`.
     """

     init()
@@ -4042,7 +4221,8 @@ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike,
 def is_mempool_access_supported(target_device: Devicelike, peer_device: Devicelike) -> bool:
     """Check if `peer_device` can directly access the memory pool of `target_device`.

-    If mempool access is possible, it can be managed using
+    If mempool access is possible, it can be managed using :func:`set_mempool_access_enabled()`
+    and :func:`is_mempool_access_enabled()`.

     Returns:
         A Boolean value indicating if this memory pool access is supported by the system.
@@ -4060,7 +4240,7 @@ def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike
     """Check if `peer_device` can currently access the memory pool of `target_device`.

     This applies to memory allocated using CUDA pooled allocators. For memory allocated using
-    default CUDA allocators, use
+    default CUDA allocators, use :func:`is_peer_access_enabled()`.

     Returns:
         A Boolean value indicating if this peer access is currently enabled.
@@ -4081,7 +4261,7 @@ def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelik
     """Enable or disable access from `peer_device` to the memory pool of `target_device`.

     This applies to memory allocated using CUDA pooled allocators. For memory allocated using
-    default CUDA allocators, use
+    default CUDA allocators, use :func:`set_peer_access_enabled()`.
     """

     init()
@@ -4790,7 +4970,9 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
 # represents all data required for a kernel launch
 # so that launches can be replayed quickly, use `wp.launch(..., record_cmd=True)`
 class Launch:
-    def __init__(
+    def __init__(
+        self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None, max_blocks=0, block_dim=256
+    ):
         # retain the module executable so it doesn't get unloaded
         self.module_exec = kernel.module.load(device)
         if not self.module_exec:
@@ -4829,6 +5011,7 @@ class Launch:
         self.device = device
         self.bounds = bounds
         self.max_blocks = max_blocks
+        self.block_dim = block_dim

     def set_dim(self, dim):
         self.bounds = warp.types.launch_bounds_t(dim)
@@ -4910,6 +5093,8 @@ class Launch:
                 self.hooks.forward,
                 self.bounds.size,
                 self.max_blocks,
+                self.block_dim,
+                self.hooks.forward_smem_bytes,
                 self.params_addr,
                 stream.cuda_stream,
             )
@@ -4928,6 +5113,7 @@ def launch(
     record_tape=True,
     record_cmd=False,
     max_blocks=0,
+    block_dim=256,
 ):
     """Launch a Warp kernel on the target device

@@ -4947,6 +5133,7 @@ def launch(
         record_cmd: When True the launch will be returned as a ``Launch`` command object, the launch will not occur until the user calls ``cmd.launch()``
         max_blocks: The maximum number of CUDA thread blocks to use. Only has an effect for CUDA kernel launches.
             If negative or zero, the maximum hardware value will be used.
+        block_dim: The number of threads per block.
     """

     init()
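A brief usage sketch of the new block_dim launch argument documented above; the kernel, arrays, and sizes are illustrative:

import warp as wp

wp.init()

@wp.kernel
def saxpy(x: wp.array(dtype=float), y: wp.array(dtype=float), a: float):
    i = wp.tid()
    y[i] = a * x[i] + y[i]

x = wp.ones(1024, dtype=float)
y = wp.zeros(1024, dtype=float)

# use 128 threads per block instead of the default 256
wp.launch(saxpy, dim=1024, inputs=[x, y, 2.0], block_dim=128)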
@@ -5000,7 +5187,12 @@ def launch(
         kernel = kernel.add_overload(fwd_types)

     # delay load modules, including new overload if needed
-
+    try:
+        module_exec = kernel.module.load(device, block_dim)
+    except Exception:
+        kernel.adj.skip_build = True
+        raise
+
     if not module_exec:
         return

@@ -5056,7 +5248,14 @@ def launch(
                 )

                 runtime.core.cuda_launch_kernel(
-                    device.context,
+                    device.context,
+                    hooks.backward,
+                    bounds.size,
+                    max_blocks,
+                    block_dim,
+                    hooks.backward_smem_bytes,
+                    kernel_params,
+                    stream.cuda_stream,
                 )

             else:
@@ -5079,7 +5278,14 @@ def launch(
             else:
                 # launch
                 runtime.core.cuda_launch_kernel(
-                    device.context,
+                    device.context,
+                    hooks.forward,
+                    bounds.size,
+                    max_blocks,
+                    block_dim,
+                    hooks.forward_smem_bytes,
+                    kernel_params,
+                    stream.cuda_stream,
                 )

     try:
@@ -5093,13 +5299,65 @@ def launch(
         # record file, lineno, func as metadata
         frame = inspect.currentframe().f_back
         caller = {"file": frame.f_code.co_filename, "lineno": frame.f_lineno, "func": frame.f_code.co_name}
-        runtime.tape.record_launch(
+        runtime.tape.record_launch(
+            kernel, dim, max_blocks, inputs, outputs, device, block_dim, metadata={"caller": caller}
+        )

         # detect illegal inter-kernel read/write access patterns if verification flag is set
         if warp.config.verify_autograd_array_access:
             runtime.tape._check_kernel_array_access(kernel, fwd_args)


+def launch_tiled(*args, **kwargs):
+    """A helper method for launching a grid with an extra trailing dimension equal to the block size.
+
+    For example, to launch a 2D grid, where each element has 64 threads assigned you would use the following:
+
+    .. code-block:: python
+
+        wp.launch_tiled(kernel, [M, N], inputs=[...], block_dim=64)
+
+    Which is equivalent to the following:
+
+    .. code-block:: python
+
+        wp.launch(kernel, [M, N, 64], inputs=[...], block_dim=64)
+
+    Inside your kernel code you can retrieve the first two indices of the thread as usual, ignoring the implicit third dimension if desired:
+
+    .. code-block:: python
+
+        @wp.kernel
+        def compute()
+
+            i, j = wp.tid()
+
+            ...
+    """
+
+    # promote dim to a list in case it was passed as a scalar or tuple
+    if "dim" not in kwargs:
+        raise RuntimeError("Launch dimensions 'dim' argument should be passed via. keyword args for wp.launch_tiled()")
+
+    if "block_dim" not in kwargs:
+        raise RuntimeError(
+            "Launch block dimension 'block_dim' argument should be passed via. keyword args for wp.launch_tiled()"
+        )
+
+    dim = kwargs["dim"]
+    if not isinstance(dim, list):
+        dim = list(dim) if isinstance(dim, tuple) else [dim]
+
+    if len(dim) > 3:
+        raise RuntimeError("wp.launch_tiled() requires a grid with fewer than 4 dimensions")
+
+    # add trailing dimension
+    kwargs["dim"] = dim + [kwargs["block_dim"]]
+
+    # forward to original launch method
+    launch(*args, **kwargs)
+
+
 def synchronize():
     """Manually synchronize the calling CPU thread with any outstanding CUDA work on all devices

@@ -5618,16 +5876,6 @@ def type_str(t):
         return "Any"
     elif t == Callable:
         return "Callable"
-    elif t == Tuple[int]:
-        return "Tuple[int]"
-    elif t == Tuple[int, int]:
-        return "Tuple[int, int]"
-    elif t == Tuple[int, int, int]:
-        return "Tuple[int, int, int]"
-    elif t == Tuple[int, int, int, int]:
-        return "Tuple[int, int, int, int]"
-    elif t == Tuple[int, ...]:
-        return "Tuple[int, ...]"
     elif isinstance(t, int):
         return str(t)
     elif isinstance(t, List):
@@ -5662,9 +5910,13 @@ def type_str(t):
             return f"Transformation[{type_str(t._wp_scalar_type_)}]"

         raise TypeError("Invalid vector or matrix dimensions")
-    elif
-        args_repr = ", ".join(type_str(x) for x in
-        return f"{t.
+    elif warp.codegen.get_type_origin(t) in (list, tuple):
+        args_repr = ", ".join(type_str(x) for x in warp.codegen.get_type_args(t))
+        return f"{t._name}[{args_repr}]"
+    elif t is Ellipsis:
+        return "..."
+    elif warp.types.is_tile(t):
+        return "Tile"

     return t.__name__

@@ -5825,9 +6077,6 @@ def export_stubs(file):  # pragma: no cover
     print('Cols = TypeVar("Cols", bound=int)', file=file)
     print('DType = TypeVar("DType")', file=file)

-    print('Int = TypeVar("Int")', file=file)
-    print('Float = TypeVar("Float")', file=file)
-    print('Scalar = TypeVar("Scalar")', file=file)
     print("Vector = Generic[Length, Scalar]", file=file)
     print("Matrix = Generic[Rows, Cols, Scalar]", file=file)
     print("Quaternion = Generic[Float]", file=file)