warp-lang 1.4.2-py3-none-manylinux2014_x86_64.whl → 1.5.1-py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of warp-lang might be problematic.
- warp/__init__.py +4 -0
- warp/autograd.py +43 -8
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +21 -2
- warp/build_dll.py +23 -6
- warp/builtins.py +1819 -7
- warp/codegen.py +197 -61
- warp/config.py +2 -2
- warp/context.py +379 -107
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
- warp/examples/benchmarks/benchmark_gemm.py +121 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
- warp/examples/benchmarks/benchmark_tile.py +179 -0
- warp/examples/fem/example_adaptive_grid.py +37 -10
- warp/examples/fem/example_apic_fluid.py +3 -2
- warp/examples/fem/example_convection_diffusion_dg.py +4 -5
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +47 -4
- warp/examples/fem/example_distortion_energy.py +220 -0
- warp/examples/fem/example_magnetostatics.py +127 -85
- warp/examples/fem/example_nonconforming_contact.py +5 -5
- warp/examples/fem/example_stokes.py +3 -1
- warp/examples/fem/example_streamlines.py +12 -19
- warp/examples/fem/utils.py +38 -15
- warp/examples/sim/example_cloth.py +4 -25
- warp/examples/sim/example_quadruped.py +2 -1
- warp/examples/tile/example_tile_convolution.py +58 -0
- warp/examples/tile/example_tile_fft.py +47 -0
- warp/examples/tile/example_tile_filtering.py +105 -0
- warp/examples/tile/example_tile_matmul.py +79 -0
- warp/examples/tile/example_tile_mlp.py +375 -0
- warp/fem/__init__.py +8 -0
- warp/fem/cache.py +16 -12
- warp/fem/dirichlet.py +1 -1
- warp/fem/domain.py +44 -1
- warp/fem/field/__init__.py +1 -2
- warp/fem/field/field.py +31 -19
- warp/fem/field/nodal_field.py +101 -49
- warp/fem/field/virtual.py +794 -0
- warp/fem/geometry/__init__.py +2 -2
- warp/fem/geometry/deformed_geometry.py +3 -105
- warp/fem/geometry/element.py +13 -0
- warp/fem/geometry/geometry.py +165 -7
- warp/fem/geometry/grid_2d.py +3 -6
- warp/fem/geometry/grid_3d.py +31 -28
- warp/fem/geometry/hexmesh.py +3 -46
- warp/fem/geometry/nanogrid.py +3 -2
- warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
- warp/fem/geometry/tetmesh.py +2 -43
- warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
- warp/fem/integrate.py +683 -261
- warp/fem/linalg.py +404 -0
- warp/fem/operator.py +101 -18
- warp/fem/polynomial.py +5 -5
- warp/fem/quadrature/quadrature.py +45 -21
- warp/fem/space/__init__.py +45 -11
- warp/fem/space/basis_function_space.py +451 -0
- warp/fem/space/basis_space.py +58 -11
- warp/fem/space/function_space.py +146 -5
- warp/fem/space/grid_2d_function_space.py +80 -66
- warp/fem/space/grid_3d_function_space.py +113 -68
- warp/fem/space/hexmesh_function_space.py +96 -108
- warp/fem/space/nanogrid_function_space.py +62 -110
- warp/fem/space/quadmesh_function_space.py +208 -0
- warp/fem/space/shape/__init__.py +45 -7
- warp/fem/space/shape/cube_shape_function.py +328 -54
- warp/fem/space/shape/shape_function.py +10 -1
- warp/fem/space/shape/square_shape_function.py +328 -60
- warp/fem/space/shape/tet_shape_function.py +269 -19
- warp/fem/space/shape/triangle_shape_function.py +238 -19
- warp/fem/space/tetmesh_function_space.py +69 -37
- warp/fem/space/topology.py +38 -0
- warp/fem/space/trimesh_function_space.py +179 -0
- warp/fem/utils.py +6 -331
- warp/jax_experimental.py +3 -1
- warp/native/array.h +15 -0
- warp/native/builtin.h +66 -26
- warp/native/bvh.h +4 -0
- warp/native/coloring.cpp +604 -0
- warp/native/cuda_util.cpp +68 -51
- warp/native/cuda_util.h +2 -1
- warp/native/fabric.h +8 -0
- warp/native/hashgrid.h +4 -0
- warp/native/marching.cu +8 -0
- warp/native/mat.h +14 -3
- warp/native/mathdx.cpp +59 -0
- warp/native/mesh.h +4 -0
- warp/native/range.h +13 -1
- warp/native/reduce.cpp +9 -1
- warp/native/reduce.cu +7 -0
- warp/native/runlength_encode.cpp +9 -1
- warp/native/runlength_encode.cu +7 -1
- warp/native/scan.cpp +8 -0
- warp/native/scan.cu +8 -0
- warp/native/scan.h +8 -1
- warp/native/sparse.cpp +8 -0
- warp/native/sparse.cu +8 -0
- warp/native/temp_buffer.h +7 -0
- warp/native/tile.h +1854 -0
- warp/native/tile_gemm.h +341 -0
- warp/native/tile_reduce.h +210 -0
- warp/native/volume_builder.cu +8 -0
- warp/native/volume_builder.h +8 -0
- warp/native/warp.cpp +10 -2
- warp/native/warp.cu +369 -15
- warp/native/warp.h +12 -2
- warp/optim/adam.py +39 -4
- warp/paddle.py +29 -12
- warp/render/render_opengl.py +140 -67
- warp/sim/graph_coloring.py +292 -0
- warp/sim/import_urdf.py +8 -8
- warp/sim/integrator_euler.py +4 -2
- warp/sim/integrator_featherstone.py +115 -44
- warp/sim/integrator_vbd.py +6 -0
- warp/sim/model.py +109 -32
- warp/sparse.py +1 -1
- warp/stubs.py +569 -4
- warp/tape.py +12 -7
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/aux_test_instancing_gc.py +18 -0
- warp/tests/test_array.py +39 -0
- warp/tests/test_codegen.py +81 -1
- warp/tests/test_codegen_instancing.py +30 -0
- warp/tests/test_collision.py +110 -0
- warp/tests/test_coloring.py +251 -0
- warp/tests/test_context.py +34 -0
- warp/tests/test_examples.py +21 -5
- warp/tests/test_fem.py +453 -113
- warp/tests/test_func.py +34 -4
- warp/tests/test_generics.py +52 -0
- warp/tests/test_iter.py +68 -0
- warp/tests/test_lerp.py +13 -87
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_matmul.py +6 -9
- warp/tests/test_matmul_lite.py +6 -11
- warp/tests/test_mesh_query_point.py +1 -1
- warp/tests/test_module_hashing.py +23 -0
- warp/tests/test_overwrite.py +45 -0
- warp/tests/test_paddle.py +27 -87
- warp/tests/test_print.py +56 -1
- warp/tests/test_smoothstep.py +17 -83
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_static.py +3 -3
- warp/tests/test_tile.py +744 -0
- warp/tests/test_tile_mathdx.py +144 -0
- warp/tests/test_tile_mlp.py +383 -0
- warp/tests/test_tile_reduce.py +374 -0
- warp/tests/test_tile_shared_memory.py +190 -0
- warp/tests/test_vbd.py +12 -20
- warp/tests/test_volume.py +43 -0
- warp/tests/unittest_suites.py +19 -2
- warp/tests/unittest_utils.py +4 -2
- warp/types.py +340 -74
- warp/utils.py +23 -3
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/METADATA +32 -7
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/RECORD +161 -134
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/WHEEL +1 -1
- warp/fem/field/test.py +0 -180
- warp/fem/field/trial.py +0 -183
- warp/fem/space/collocated_function_space.py +0 -102
- warp/fem/space/quadmesh_2d_function_space.py +0 -261
- warp/fem/space/trimesh_2d_function_space.py +0 -153
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/top_level.txt +0 -0
warp/context.py
CHANGED

@@ -7,21 +7,24 @@

 import ast
 import ctypes
+import errno
 import functools
 import hashlib
 import inspect
 import io
 import itertools
+import json
 import operator
 import os
 import platform
 import sys
+import time
 import types
 import typing
 import weakref
 from copy import copy as shallowcopy
 from pathlib import Path
-from typing import Any, Callable, Dict, List,
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

 import numpy as np

@@ -101,6 +104,7 @@ class Function:
         value_func=None,
         export_func=None,
         dispatch_func=None,
+        lto_dispatch_func=None,
         module=None,
         variadic=False,
         initializer_list_func=None,
@@ -137,6 +141,7 @@ class Function:
         self.value_func = value_func  # a function that takes a list of args and a list of templates and returns the value type, e.g.: load(array, index) returns the type of value being loaded
         self.export_func = export_func
         self.dispatch_func = dispatch_func
+        self.lto_dispatch_func = lto_dispatch_func
         self.input_types = {}
         self.export = export
         self.doc = doc
@@ -235,24 +240,23 @@ class Function:
         # in a way that is compatible with Python's semantics.
         signature_params = []
         signature_default_param_kind = inspect.Parameter.POSITIONAL_OR_KEYWORD
-        for
-            if
-                param_name =
+        for raw_param_name in self.input_types.keys():
+            if raw_param_name.startswith("**"):
+                param_name = raw_param_name[2:]
                 param_kind = inspect.Parameter.VAR_KEYWORD
-            elif
-                param_name =
+            elif raw_param_name.startswith("*"):
+                param_name = raw_param_name[1:]
                 param_kind = inspect.Parameter.VAR_POSITIONAL

                 # Once a variadic argument like `*args` is found, any following
                 # arguments need to be passed using keywords.
                 signature_default_param_kind = inspect.Parameter.KEYWORD_ONLY
             else:
+                param_name = raw_param_name
                 param_kind = signature_default_param_kind

-            param =
-                param_name,
-                param_kind,
-                default=self.defaults.get(param_name, inspect.Parameter.empty),
+            param = inspect.Parameter(
+                param_name, param_kind, default=self.defaults.get(param_name, inspect.Parameter.empty)
             )
             signature_params.append(param)
         self.signature = inspect.Signature(signature_params)
@@ -291,22 +295,22 @@ class Function:

         if hasattr(self, "user_overloads") and len(self.user_overloads):
             # user-defined function with overloads
+            bound_args = self.signature.bind(*args, **kwargs)
+            if self.defaults:
+                warp.codegen.apply_defaults(bound_args, self.defaults)

-
-                raise RuntimeError(
-                    f"Error calling function '{self.key}', keyword arguments are not supported for user-defined overloads."
-                )
+            arguments = tuple(bound_args.arguments.values())

             # try and find a matching overload
             for overload in self.user_overloads.values():
-                if len(overload.input_types) != len(
+                if len(overload.input_types) != len(arguments):
                     continue
                 template_types = list(overload.input_types.values())
                 arg_names = list(overload.input_types.keys())
                 try:
                     # attempt to unify argument types with function template types
-                    warp.types.infer_argument_types(
-                    return overload.func(*
+                    warp.types.infer_argument_types(arguments, template_types, arg_names)
+                    return overload.func(*arguments)
                 except Exception:
                     continue

@@ -506,11 +510,10 @@ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
             if elem_count != arg_type._length_:
                 return (False, None)

-            # Retrieve the element type of the sequence while ensuring
-            # that it's homogeneous.
+            # Retrieve the element type of the sequence while ensuring that it's homogeneous.
             elem_type = type(arr[0])
-            for
-                if type(arr[
+            for array_index in range(1, elem_count):
+                if type(arr[array_index]) is not elem_type:
                     raise ValueError("All array elements must share the same type.")

             expected_elem_type = arg_type._wp_scalar_type_
@@ -540,10 +543,10 @@ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
             c_param = arg_type()
             if warp.types.type_is_matrix(arg_type):
                 rows, cols = arg_type._shape_
-                for
-                    idx_start =
+                for row_index in range(rows):
+                    idx_start = row_index * cols
                     idx_end = idx_start + cols
-                    c_param[
+                    c_param[row_index] = arr[idx_start:idx_end]
             else:
                 c_param[:] = arr

@@ -619,10 +622,13 @@ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:


 class KernelHooks:
-    def __init__(self, forward, backward):
+    def __init__(self, forward, backward, forward_smem_bytes=0, backward_smem_bytes=0):
         self.forward = forward
         self.backward = backward

+        self.forward_smem_bytes = forward_smem_bytes
+        self.backward_smem_bytes = backward_smem_bytes
+

 # caches source and compiled entry points for a kernel (will be populated after module loads)
 class Kernel:
@@ -970,8 +976,17 @@ def struct(c):
     return s


-
-
+def overload(kernel, arg_types=Union[None, Dict[str, Any], List[Any]]):
+    """Overload a generic kernel with the given argument types.
+
+    Can be called directly or used as a function decorator.
+
+    Args:
+        kernel: The generic kernel to be instantiated with concrete types.
+        arg_types: A list of concrete argument types for the kernel or a
+            dictionary specifying generic argument names as keys and concrete
+            types as variables.
+    """
     if isinstance(kernel, Kernel):
         # handle cases where user calls us directly, e.g. wp.overload(kernel, [args...])

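Note: the new overload() docstring above describes both calling conventions. Below is a minimal usage sketch; the generic kernel and types are illustrative (not part of the diff) and assume the usual Warp idiom of annotating generic kernel arguments with typing.Any.

    from typing import Any

    import warp as wp


    @wp.kernel
    def scale(x: wp.array(dtype=Any), s: Any):
        # generic kernel: argument types are resolved per overload
        i = wp.tid()
        x[i] = x[i] * s


    # instantiate concrete overloads of the generic kernel, either from a list of types...
    wp.overload(scale, [wp.array(dtype=wp.float32), wp.float32])

    # ...or from a dict mapping generic argument names to concrete types
    wp.overload(scale, {"x": wp.array(dtype=wp.float64), "s": wp.float64})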
@@ -1073,6 +1088,7 @@ def add_builtin(
     value_func=None,
     export_func=None,
     dispatch_func=None,
+    lto_dispatch_func=None,
     doc="",
     namespace="wp::",
     variadic=False,
@@ -1113,6 +1129,9 @@ def add_builtin(
             The arguments returned must be of type `codegen.Var`.
             If not provided, all arguments passed by the users when calling
             the built-in are passed as-is as runtime arguments to the C++ function.
+        lto_dispatch_func (Callable): Same as dispatch_func, but takes an 'option' dict
+            as extra argument (indicating tile_size and target architecture) and returns
+            an LTO-IR buffer as extra return value
         doc (str): Used to generate the Python's docstring and the HTML documentation.
         namespace: Namespace for the underlying C++ function.
         variadic (bool): Whether the function declares variadic arguments.
@@ -1220,16 +1239,16 @@ def add_builtin(
             typelists.append(l)

         for arg_types in itertools.product(*typelists):
-
+            concrete_arg_types = dict(zip(input_types.keys(), arg_types))

             # Some of these argument lists won't work, eg if the function is mul(), we won't be
             # able to do a matrix vector multiplication for a mat22 and a vec3. The `constraint`
             # function determines which combinations are valid:
             if constraint:
-                if constraint(
+                if constraint(concrete_arg_types) is False:
                     continue

-            return_type = value_func(
+            return_type = value_func(concrete_arg_types, None)

             # The return_type might just be vector_t(length=3,dtype=wp.float32), so we've got to match that
             # in the list of hard coded types so it knows it's returning one of them:
@@ -1247,11 +1266,12 @@ def add_builtin(
             # finally we can generate a function call for these concrete types:
             add_builtin(
                 key,
-                input_types=
+                input_types=concrete_arg_types,
                 value_type=return_type,
                 value_func=value_func if return_type is Any else None,
                 export_func=export_func,
                 dispatch_func=dispatch_func,
+                lto_dispatch_func=lto_dispatch_func,
                 doc=doc,
                 namespace=namespace,
                 variadic=variadic,
@@ -1274,6 +1294,7 @@ def add_builtin(
         value_func=value_func,
         export_func=export_func,
         dispatch_func=dispatch_func,
+        lto_dispatch_func=lto_dispatch_func,
         variadic=variadic,
         initializer_list_func=initializer_list_func,
         export=export,
@@ -1540,6 +1561,8 @@ class ModuleBuilder:
         self.options = options
         self.module = module
         self.deferred_functions = []
+        self.ltoirs = {}  # map from lto symbol to lto binary
+        self.ltoirs_decl = {}  # map from lto symbol to lto forward declaration

         if hasher is None:
             hasher = ModuleHasher(module)
@@ -1607,9 +1630,26 @@ class ModuleBuilder:
             # use dict to preserve import order
             self.functions[func] = None

+    def build_meta(self):
+        meta = {}
+
+        for kernel in self.kernels:
+            name = kernel.get_mangled_name()
+
+            meta[name + "_cuda_kernel_forward_smem_bytes"] = kernel.adj.get_total_required_shared()
+            meta[name + "_cuda_kernel_backward_smem_bytes"] = kernel.adj.get_total_required_shared() * 2
+
+        return meta
+
     def codegen(self, device):
         source = ""

+        # code-gen LTO forward declarations
+        source += 'extern "C" {\n'
+        for fwd in self.ltoirs_decl.values():
+            source += fwd + "\n"
+        source += "}\n"
+
         # code-gen structs
         visited_structs = set()
         for struct in self.structs.keys():
@@ -1639,9 +1679,9 @@ class ModuleBuilder:

         # add headers
         if device == "cpu":
-            source = warp.codegen.cpu_module_header + source
+            source = warp.codegen.cpu_module_header.format(tile_size=self.options["block_dim"]) + source
         else:
-            source = warp.codegen.cuda_module_header + source
+            source = warp.codegen.cuda_module_header.format(tile_size=self.options["block_dim"]) + source

         return source

@@ -1660,11 +1700,12 @@ class ModuleExec:
         instance.handle = None
         return instance

-    def __init__(self, handle, module_hash, device):
+    def __init__(self, handle, module_hash, device, meta):
         self.handle = handle
         self.module_hash = module_hash
         self.device = device
         self.kernel_hooks = {}
+        self.meta = meta

     # release the loaded module
     def __del__(self):
@@ -1678,19 +1719,50 @@ class ModuleExec:

     # lookup and cache kernel entry points
     def get_kernel_hooks(self, kernel):
-
+        # Use kernel.adj as a unique key for cache lookups instead of the kernel itself.
+        # This avoids holding a reference to the kernel and is faster than using
+        # a WeakKeyDictionary with kernels as keys.
+        hooks = self.kernel_hooks.get(kernel.adj)
         if hooks is not None:
             return hooks

         name = kernel.get_mangled_name()

         if self.device.is_cuda:
-
-
+            forward_name = name + "_cuda_kernel_forward"
+            forward_kernel = runtime.core.cuda_get_kernel(
+                self.device.context, self.handle, forward_name.encode("utf-8")
             )
-
-
+
+            backward_name = name + "_cuda_kernel_backward"
+            backward_kernel = runtime.core.cuda_get_kernel(
+                self.device.context, self.handle, backward_name.encode("utf-8")
             )
+
+            # look up the required shared memory size for each kernel from module metadata
+            forward_smem_bytes = self.meta[forward_name + "_smem_bytes"]
+            backward_smem_bytes = self.meta[backward_name + "_smem_bytes"]
+
+            # configure kernels maximum shared memory size
+            max_smem_bytes = runtime.core.cuda_get_max_shared_memory(self.device.context)
+
+            if not runtime.core.cuda_configure_kernel_shared_memory(forward_kernel, forward_smem_bytes):
+                print(
+                    f"Warning: Failed to configure kernel dynamic shared memory for this device, tried to configure {forward_name} kernel for {forward_smem_bytes} bytes, but maximum available is {max_smem_bytes}"
+                )
+
+            options = dict(kernel.module.options)
+            options.update(kernel.options)
+
+            if options["enable_backward"] and not runtime.core.cuda_configure_kernel_shared_memory(
+                backward_kernel, backward_smem_bytes
+            ):
+                print(
+                    f"Warning: Failed to configure kernel dynamic shared memory for this device, tried to configure {backward_name} kernel for {backward_smem_bytes} bytes, but maximum available is {max_smem_bytes}"
+                )
+
+            hooks = KernelHooks(forward_kernel, backward_kernel, forward_smem_bytes, backward_smem_bytes)
+
         else:
             func = ctypes.CFUNCTYPE(None)
             forward = (
@@ -1700,9 +1772,9 @@ class ModuleExec:
                 func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8"))) or None
             )

-
-            self.kernel_hooks[kernel] = hooks
+            hooks = KernelHooks(forward, backward)

+        self.kernel_hooks[kernel.adj] = hooks
         return hooks

@@ -1712,7 +1784,8 @@ class ModuleExec:
 # build cache
 class Module:
     def __init__(self, name, loader):
-        self.name = name
+        self.name = name if name is not None else "None"
+
         self.loader = loader

         # lookup the latest versions of kernels, functions, and structs by key
@@ -1720,12 +1793,14 @@ class Module:
         self.functions = {}  # (key: function)
         self.structs = {}  # (key: struct)

-        # Set of all "live" kernels in this module.
+        # Set of all "live" kernels in this module, i.e., kernels that still have references.
+        # We keep a weak reference to every kernel ever created in this module and rely on Python GC
+        # to release kernels that no longer have any references (in user code or internal bookkeeping).
         # The difference between `live_kernels` and `kernels` is that `live_kernels` may contain
         # multiple kernels with the same key (which is essential to support closures), while `kernels`
         # only holds the latest kernel for each key. When the module is built, we compute the hash
         # of each kernel in `live_kernels` and filter out duplicates for codegen.
-        self.
+        self._live_kernels = weakref.WeakSet()

         # executable modules currently loaded
         self.execs = {}  # (device.context: ModuleExec)
@@ -1749,6 +1824,7 @@ class Module:
             "fast_math": False,
             "cuda_output": None,  # supported values: "ptx", "cubin", or None (automatic)
             "mode": warp.config.mode,
+            "block_dim": 256,
         }

         # Module dependencies are determined by scanning each function
@@ -1773,7 +1849,7 @@ class Module:
         self.kernels[kernel.key] = kernel

         # track all kernel objects, even if they are duplicates
-        self.
+        self._live_kernels.add(kernel)

         self.find_references(kernel.adj)

@@ -1839,6 +1915,19 @@ class Module:
         # for a reload of module on next launch
         self.mark_modified()

+    @property
+    def live_kernels(self):
+        # Return a list of kernels that still have references.
+        # We return a regular list instead of the WeakSet to avoid undesirable issues
+        # if kernels are garbage collected before the caller is done using this list.
+        # Note that we should avoid retaining strong references to kernels unnecessarily
+        # so that Python GC can release kernels that no longer have user references.
+        # It is tempting to call gc.collect() here to force garbage collection,
+        # but this can have undesirable consequences (e.g., GC during graph capture),
+        # so we should avoid it as a general rule. Instead, we rely on Python's
+        # reference counting GC to collect kernels that have gone out of scope.
+        return list(self._live_kernels)
+
     # find kernel corresponding to a Python function
     def find_kernel(self, func):
         qualname = warp.codegen.make_full_qualified_name(func)
@@ -1879,9 +1968,17 @@ class Module:
             self.hasher = ModuleHasher(self)
         return self.hasher.get_module_hash()

-    def load(self, device) -> ModuleExec:
+    def load(self, device, block_dim=None) -> ModuleExec:
         device = runtime.get_device(device)

+        # re-compile module if tile size (blockdim) changes
+        # todo: it would be better to have a method such as `module.get_kernel(block_dim=N)`
+        # that can return a single kernel instance with a given block size
+        if block_dim is not None:
+            if self.options["block_dim"] != block_dim:
+                self.unload()
+                self.options["block_dim"] = block_dim
+
         # compute the hash if needed
         if self.hasher is None:
             self.hasher = ModuleHasher(self)
@@ -1909,6 +2006,7 @@ class Module:
         # determine output paths
         if device.is_cpu:
             output_name = "module_codegen.o"
+            output_arch = None

         elif device.is_cuda:
             # determine whether to use PTX or CUBIN
@@ -1947,7 +2045,12 @@ class Module:
                 or not warp.config.cache_kernels
                 or warp.config.verify_autograd_array_access
             ):
-
+                builder_options = {
+                    **self.options,
+                    # Some of the Tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
+                    "output_arch": output_arch,
+                }
+                builder = ModuleBuilder(self, builder_options, hasher=self.hasher)

                 # create a temporary (process unique) dir for build outputs before moving to the binary dir
                 build_dir = os.path.join(
@@ -2010,6 +2113,7 @@ class Module:
                         config=self.options["mode"],
                         fast_math=self.options["fast_math"],
                         verify_fp=warp.config.verify_fp,
+                        ltoirs=builder.ltoirs.values(),
                     )

                 except Exception as e:
@@ -2017,15 +2121,46 @@ class Module:
                     module_load_timer.extra_msg = " (error)"
                     raise (e)

+                # ------------------------------------------------------------
+                # build meta data
+
+                meta = builder.build_meta()
+                meta_path = os.path.join(build_dir, "module_codegen.meta")
+
+                with open(meta_path, "w") as meta_file:
+                    json.dump(meta, meta_file)
+
                 # -----------------------------------------------------------
                 # update cache

-
-
-
-
-
-
+                def safe_rename(src, dst, attempts=5, delay=0.1):
+                    for i in range(attempts):
+                        try:
+                            os.rename(src, dst)
+                            return
+                        except FileExistsError:
+                            return
+                        except OSError as e:
+                            if e.errno == errno.ENOTEMPTY:
+                                # if directory exists we assume another process
+                                # got there first, in which case we will copy
+                                # our output to the directory manually in second step
+                                return
+                            else:
+                                # otherwise assume directory creation failed e.g.: access denied
+                                # on Windows we see occasional failures to rename directories due to
+                                # some process holding a lock on a file to be moved to workaround
+                                # this we make multiple attempts to rename with some delay
+                                if i < attempts - 1:
+                                    time.sleep(delay)
+                                else:
+                                    print(
+                                        f"Could not update Warp cache with module binaries, trying to rename {build_dir} to {module_dir}, error {e}"
+                                    )
+                                    raise e
+
+                # try to move process outputs to cache
+                safe_rename(build_dir, module_dir)

             if os.path.exists(module_dir):
                 if not os.path.exists(binary_path):
@@ -2053,18 +2188,23 @@ class Module:

             # -----------------------------------------------------------
             # Load CPU or CUDA binary
+
+            meta_path = os.path.join(module_dir, "module_codegen.meta")
+            with open(meta_path, "r") as meta_file:
+                meta = json.load(meta_file)
+
             if device.is_cpu:
                 # LLVM modules are identified using strings, so we need to ensure uniqueness
                 module_handle = f"{module_name}_{self.cpu_exec_id}"
                 self.cpu_exec_id += 1
                 runtime.llvm.load_obj(binary_path.encode("utf-8"), module_handle.encode("utf-8"))
-                module_exec = ModuleExec(module_handle, module_hash, device)
+                module_exec = ModuleExec(module_handle, module_hash, device, meta)
                 self.execs[None] = module_exec

             elif device.is_cuda:
                 cuda_module = warp.build.load_cuda(binary_path, device)
                 if cuda_module is not None:
-                    module_exec = ModuleExec(cuda_module, module_hash, device)
+                    module_exec = ModuleExec(cuda_module, module_hash, device, meta)
                     self.execs[device.context] = module_exec
                 else:
                     module_load_timer.extra_msg = " (error)"
@@ -2719,21 +2859,16 @@ class Graph:

 class Runtime:
     def __init__(self):
-        if sys.version_info < (3,
-            raise RuntimeError("Warp requires Python 3.
+        if sys.version_info < (3, 8):
+            raise RuntimeError("Warp requires Python 3.8 as a minimum")
         if sys.version_info < (3, 9):
             warp.utils.warn(f"Python 3.9 or newer is recommended for running Warp, detected {sys.version_info}")

         bin_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bin")

         if os.name == "nt":
-
-
-                os.add_dll_directory(bin_path)
-
-            else:
-                # Python < 3.8 we add dll directory to path
-                os.environ["PATH"] = bin_path + os.pathsep + os.environ["PATH"]
+            # Python >= 3.8 this method to add dll search paths
+            os.add_dll_directory(bin_path)

             warp_lib = os.path.join(bin_path, "warp.dll")
             llvm_lib = os.path.join(bin_path, "warp-clang.dll")
@@ -3205,6 +3340,8 @@ class Runtime:
             self.core.is_cuda_compatibility_enabled.restype = ctypes.c_int
             self.core.is_cutlass_enabled.argtypes = None
             self.core.is_cutlass_enabled.restype = ctypes.c_int
+            self.core.is_mathdx_enabled.argtypes = None
+            self.core.is_mathdx_enabled.restype = ctypes.c_int

             self.core.cuda_driver_version.argtypes = None
             self.core.cuda_driver_version.restype = ctypes.c_int
@@ -3329,17 +3466,58 @@ class Runtime:
             self.core.cuda_graph_destroy.restype = ctypes.c_bool

             self.core.cuda_compile_program.argtypes = [
-                ctypes.c_char_p,
-                ctypes.c_int,
-                ctypes.c_char_p,
-                ctypes.
-                ctypes.
-                ctypes.c_bool,
-                ctypes.c_bool,
-                ctypes.
+                ctypes.c_char_p,  # cuda_src
+                ctypes.c_int,  # arch
+                ctypes.c_char_p,  # include_dir
+                ctypes.c_int,  # num_cuda_include_dirs
+                ctypes.POINTER(ctypes.c_char_p),  # cuda include dirs
+                ctypes.c_bool,  # debug
+                ctypes.c_bool,  # verbose
+                ctypes.c_bool,  # verify_fp
+                ctypes.c_bool,  # fast_math
+                ctypes.c_char_p,  # output_path
+                ctypes.c_size_t,  # num_ltoirs
+                ctypes.POINTER(ctypes.c_char_p),  # ltoirs
+                ctypes.POINTER(ctypes.c_size_t),  # ltoir_sizes
             ]
             self.core.cuda_compile_program.restype = ctypes.c_size_t

+            self.core.cuda_compile_fft.argtypes = [
+                ctypes.c_char_p,  # lto
+                ctypes.c_char_p,  # function name
+                ctypes.c_int,  # num include dirs
+                ctypes.POINTER(ctypes.c_char_p),  # include dirs
+                ctypes.c_char_p,  # mathdx include dir
+                ctypes.c_int,  # arch
+                ctypes.c_int,  # size
+                ctypes.c_int,  # ept
+                ctypes.c_int,  # direction
+                ctypes.c_int,  # precision
+                ctypes.POINTER(ctypes.c_int),  # smem (out)
+            ]
+            self.core.cuda_compile_fft.restype = ctypes.c_bool
+
+            self.core.cuda_compile_dot.argtypes = [
+                ctypes.c_char_p,  # lto
+                ctypes.c_char_p,  # function name
+                ctypes.c_int,  # num include dirs
+                ctypes.POINTER(ctypes.c_char_p),  # include dirs
+                ctypes.c_char_p,  # mathdx include dir
+                ctypes.c_int,  # arch
+                ctypes.c_int,  # M
+                ctypes.c_int,  # N
+                ctypes.c_int,  # K
+                ctypes.c_int,  # a_precision
+                ctypes.c_int,  # b_precision
+                ctypes.c_int,  # c_precision
+                ctypes.c_int,  # type
+                ctypes.c_int,  # a_arrangement
+                ctypes.c_int,  # b_arrangement
+                ctypes.c_int,  # c_arrangement
+                ctypes.c_int,  # num threads
+            ]
+            self.core.cuda_compile_dot.restype = ctypes.c_bool
+
             self.core.cuda_load_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
             self.core.cuda_load_module.restype = ctypes.c_void_p

@@ -3349,11 +3527,19 @@ class Runtime:
             self.core.cuda_get_kernel.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_char_p]
             self.core.cuda_get_kernel.restype = ctypes.c_void_p

+            self.core.cuda_get_max_shared_memory.argtypes = [ctypes.c_void_p]
+            self.core.cuda_get_max_shared_memory.restype = ctypes.c_int
+
+            self.core.cuda_configure_kernel_shared_memory.argtypes = [ctypes.c_void_p, ctypes.c_int]
+            self.core.cuda_configure_kernel_shared_memory.restype = ctypes.c_bool
+
             self.core.cuda_launch_kernel.argtypes = [
                 ctypes.c_void_p,
                 ctypes.c_void_p,
                 ctypes.c_size_t,
                 ctypes.c_int,
+                ctypes.c_int,
+                ctypes.c_int,
                 ctypes.POINTER(ctypes.c_void_p),
                 ctypes.c_void_p,
             ]
@@ -3382,6 +3568,23 @@ class Runtime:
             self.core.cuda_timing_end.argtypes = []
             self.core.cuda_timing_end.restype = None

+            self.core.graph_coloring.argtypes = [
+                ctypes.c_int,
+                warp.types.array_t,
+                ctypes.c_int,
+                warp.types.array_t,
+            ]
+            self.core.graph_coloring.restype = ctypes.c_int
+
+            self.core.balance_coloring.argtypes = [
+                ctypes.c_int,
+                warp.types.array_t,
+                ctypes.c_int,
+                ctypes.c_float,
+                warp.types.array_t,
+            ]
+            self.core.balance_coloring.restype = ctypes.c_float
+
             self.core.init.restype = ctypes.c_int

         except AttributeError as e:
@@ -3607,10 +3810,7 @@ class Runtime:

     def load_dll(self, dll_path):
         try:
-
-                dll = ctypes.CDLL(dll_path, winmode=0)
-            else:
-                dll = ctypes.CDLL(dll_path)
+            dll = ctypes.CDLL(dll_path, winmode=0)
         except OSError as e:
             if "GLIBCXX" in str(e):
                 raise RuntimeError(
@@ -3751,7 +3951,7 @@ def is_cuda_available() -> bool:
     return get_cuda_device_count() > 0


-def is_device_available(device):
+def is_device_available(device: Device) -> bool:
     return device in get_devices()


@@ -3811,7 +4011,7 @@ def get_cuda_devices() -> List[Device]:


 def get_preferred_device() -> Device:
-    """Returns the preferred compute device,
+    """Returns the preferred compute device, ``cuda:0`` if available and ``cpu`` otherwise."""

     init()

@@ -3896,7 +4096,7 @@ def set_mempool_enabled(device: Devicelike, enable: bool) -> None:
    They should generally be enabled, but there is a rare caveat. Copying data between different GPUs
    may fail during graph capture if the memory was allocated using pooled allocators and memory pool
    access is not enabled between the two GPUs. This is an internal CUDA limitation that is not related
-    to Warp. The preferred solution is to enable memory pool access using
+    to Warp. The preferred solution is to enable memory pool access using :func:`set_mempool_access_enabled`.
    If peer access is not supported, then the default CUDA allocators must be used to pre-allocate the memory
    prior to graph capture.
    """
@@ -3951,7 +4151,7 @@ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, floa


 def get_mempool_release_threshold(device: Devicelike) -> int:
-    """Get the CUDA memory pool release threshold on the device."""
+    """Get the CUDA memory pool release threshold on the device in bytes."""

     init()

@@ -3970,7 +4170,7 @@ def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike)
     """Check if `peer_device` can directly access the memory of `target_device` on this system.

     This applies to memory allocated using default CUDA allocators. For memory allocated using
-    CUDA pooled allocators, use
+    CUDA pooled allocators, use :func:`is_mempool_access_supported()`.

     Returns:
         A Boolean value indicating if this peer access is supported by the system.
@@ -3991,7 +4191,7 @@ def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike) -
     """Check if `peer_device` can currently access the memory of `target_device`.

     This applies to memory allocated using default CUDA allocators. For memory allocated using
-    CUDA pooled allocators, use
+    CUDA pooled allocators, use :func:`is_mempool_access_enabled()`.

     Returns:
         A Boolean value indicating if this peer access is currently enabled.
@@ -4015,7 +4215,7 @@ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike,
     a negative impact on memory consumption and allocation performance.

     This applies to memory allocated using default CUDA allocators. For memory allocated using
-    CUDA pooled allocators, use
+    CUDA pooled allocators, use :func:`set_mempool_access_enabled()`.
     """

     init()
@@ -4043,7 +4243,8 @@ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike,
 def is_mempool_access_supported(target_device: Devicelike, peer_device: Devicelike) -> bool:
     """Check if `peer_device` can directly access the memory pool of `target_device`.

-    If mempool access is possible, it can be managed using
+    If mempool access is possible, it can be managed using :func:`set_mempool_access_enabled()`
+    and :func:`is_mempool_access_enabled()`.

     Returns:
         A Boolean value indicating if this memory pool access is supported by the system.
@@ -4061,7 +4262,7 @@ def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike
     """Check if `peer_device` can currently access the memory pool of `target_device`.

     This applies to memory allocated using CUDA pooled allocators. For memory allocated using
-    default CUDA allocators, use
+    default CUDA allocators, use :func:`is_peer_access_enabled()`.

     Returns:
         A Boolean value indicating if this peer access is currently enabled.
@@ -4082,7 +4283,7 @@ def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelik
     """Enable or disable access from `peer_device` to the memory pool of `target_device`.

     This applies to memory allocated using CUDA pooled allocators. For memory allocated using
-    default CUDA allocators, use
+    default CUDA allocators, use :func:`set_peer_access_enabled()`.
     """

     init()
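Note: the peer-access and mempool-access docstrings above now cross-reference each other. A minimal, hedged sketch of how these helpers are typically combined on a two-GPU system (the device names are illustrative):

    import warp as wp

    wp.init()

    # allow cuda:1 to access memory allocated from cuda:0's pool, if the system supports it
    if wp.is_mempool_access_supported("cuda:0", "cuda:1"):
        wp.set_mempool_access_enabled("cuda:0", "cuda:1", True)
        assert wp.is_mempool_access_enabled("cuda:0", "cuda:1")
    elif wp.is_peer_access_supported("cuda:0", "cuda:1"):
        # fall back to classic peer access for memory from the default allocators
        wp.set_peer_access_enabled("cuda:0", "cuda:1", True)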
@@ -4791,7 +4992,9 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
 # represents all data required for a kernel launch
 # so that launches can be replayed quickly, use `wp.launch(..., record_cmd=True)`
 class Launch:
-    def __init__(
+    def __init__(
+        self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None, max_blocks=0, block_dim=256
+    ):
         # retain the module executable so it doesn't get unloaded
         self.module_exec = kernel.module.load(device)
         if not self.module_exec:
@@ -4830,6 +5033,7 @@ class Launch:
         self.device = device
         self.bounds = bounds
         self.max_blocks = max_blocks
+        self.block_dim = block_dim

     def set_dim(self, dim):
         self.bounds = warp.types.launch_bounds_t(dim)
@@ -4911,6 +5115,8 @@ class Launch:
                 self.hooks.forward,
                 self.bounds.size,
                 self.max_blocks,
+                self.block_dim,
+                self.hooks.forward_smem_bytes,
                 self.params_addr,
                 stream.cuda_stream,
             )
@@ -4929,6 +5135,7 @@ def launch(
     record_tape=True,
    record_cmd=False,
     max_blocks=0,
+    block_dim=256,
 ):
     """Launch a Warp kernel on the target device

@@ -4948,6 +5155,7 @@ def launch(
         record_cmd: When True the launch will be returned as a ``Launch`` command object, the launch will not occur until the user calls ``cmd.launch()``
         max_blocks: The maximum number of CUDA thread blocks to use. Only has an effect for CUDA kernel launches.
             If negative or zero, the maximum hardware value will be used.
+        block_dim: The number of threads per block.
     """

     init()
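Note: block_dim is the new per-launch thread-block size added in this release. A minimal sketch of passing it through wp.launch(); the kernel, array size, and device string below are illustrative, not part of the diff:

    import warp as wp

    @wp.kernel
    def add_one(x: wp.array(dtype=float)):
        i = wp.tid()
        x[i] = x[i] + 1.0

    x = wp.zeros(1024, dtype=float, device="cuda:0")

    # request 128 threads per block instead of the default 256
    wp.launch(add_one, dim=x.shape, inputs=[x], device="cuda:0", block_dim=128)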
@@ -5001,7 +5209,12 @@ def launch(
         kernel = kernel.add_overload(fwd_types)

     # delay load modules, including new overload if needed
-
+    try:
+        module_exec = kernel.module.load(device, block_dim)
+    except Exception:
+        kernel.adj.skip_build = True
+        raise
+
     if not module_exec:
         return

@@ -5057,7 +5270,14 @@ def launch(
                 )

                 runtime.core.cuda_launch_kernel(
-                    device.context,
+                    device.context,
+                    hooks.backward,
+                    bounds.size,
+                    max_blocks,
+                    block_dim,
+                    hooks.backward_smem_bytes,
+                    kernel_params,
+                    stream.cuda_stream,
                 )

             else:
@@ -5074,13 +5294,22 @@ def launch(
                 params_addr=kernel_params,
                 bounds=bounds,
                 device=device,
+                max_blocks=max_blocks,
+                block_dim=block_dim,
             )
             return launch

         else:
             # launch
             runtime.core.cuda_launch_kernel(
-                device.context,
+                device.context,
+                hooks.forward,
+                bounds.size,
+                max_blocks,
+                block_dim,
+                hooks.forward_smem_bytes,
+                kernel_params,
+                stream.cuda_stream,
             )

             try:
@@ -5094,13 +5323,65 @@ def launch(
             # record file, lineno, func as metadata
             frame = inspect.currentframe().f_back
             caller = {"file": frame.f_code.co_filename, "lineno": frame.f_lineno, "func": frame.f_code.co_name}
-            runtime.tape.record_launch(
+            runtime.tape.record_launch(
+                kernel, dim, max_blocks, inputs, outputs, device, block_dim, metadata={"caller": caller}
+            )

         # detect illegal inter-kernel read/write access patterns if verification flag is set
         if warp.config.verify_autograd_array_access:
             runtime.tape._check_kernel_array_access(kernel, fwd_args)


+def launch_tiled(*args, **kwargs):
+    """A helper method for launching a grid with an extra trailing dimension equal to the block size.
+
+    For example, to launch a 2D grid, where each element has 64 threads assigned you would use the following:
+
+    .. code-block:: python
+
+        wp.launch_tiled(kernel, [M, N], inputs=[...], block_dim=64)
+
+    Which is equivalent to the following:
+
+    .. code-block:: python
+
+        wp.launch(kernel, [M, N, 64], inputs=[...], block_dim=64)
+
+    Inside your kernel code you can retrieve the first two indices of the thread as usual, ignoring the implicit third dimension if desired:
+
+    .. code-block:: python
+
+        @wp.kernel
+        def compute()
+
+            i, j = wp.tid()
+
+            ...
+    """
+
+    # promote dim to a list in case it was passed as a scalar or tuple
+    if "dim" not in kwargs:
+        raise RuntimeError("Launch dimensions 'dim' argument should be passed via. keyword args for wp.launch_tiled()")
+
+    if "block_dim" not in kwargs:
+        raise RuntimeError(
+            "Launch block dimension 'block_dim' argument should be passed via. keyword args for wp.launch_tiled()"
+        )
+
+    dim = kwargs["dim"]
+    if not isinstance(dim, list):
+        dim = list(dim) if isinstance(dim, tuple) else [dim]
+
+    if len(dim) > 3:
+        raise RuntimeError("wp.launch_tiled() requires a grid with fewer than 4 dimensions")
+
+    # add trailing dimension
+    kwargs["dim"] = dim + [kwargs["block_dim"]]
+
+    # forward to original launch method
+    return launch(*args, **kwargs)
+
+
 def synchronize():
     """Manually synchronize the calling CPU thread with any outstanding CUDA work on all devices

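Note: launch_tiled() simply appends block_dim as a trailing grid dimension before forwarding to launch(). A self-contained sketch of that equivalence; the kernel, shapes, and device string are illustrative and not part of the diff:

    import warp as wp

    @wp.kernel
    def fill(out: wp.array3d(dtype=float)):
        # the trailing index t comes from the implicit block dimension
        i, j, t = wp.tid()
        out[i, j, t] = 1.0

    out = wp.zeros((4, 8, 64), dtype=float, device="cuda:0")

    # launches a [4, 8, 64] grid; equivalent to wp.launch(fill, dim=[4, 8, 64], inputs=[out], block_dim=64)
    wp.launch_tiled(fill, dim=[4, 8], inputs=[out], device="cuda:0", block_dim=64)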
@@ -5619,16 +5900,6 @@ def type_str(t):
         return "Any"
     elif t == Callable:
         return "Callable"
-    elif t == Tuple[int]:
-        return "Tuple[int]"
-    elif t == Tuple[int, int]:
-        return "Tuple[int, int]"
-    elif t == Tuple[int, int, int]:
-        return "Tuple[int, int, int]"
-    elif t == Tuple[int, int, int, int]:
-        return "Tuple[int, int, int, int]"
-    elif t == Tuple[int, ...]:
-        return "Tuple[int, ...]"
     elif isinstance(t, int):
         return str(t)
     elif isinstance(t, List):
@@ -5663,9 +5934,13 @@ def type_str(t):
             return f"Transformation[{type_str(t._wp_scalar_type_)}]"

         raise TypeError("Invalid vector or matrix dimensions")
-    elif
-        args_repr = ", ".join(type_str(x) for x in
-        return f"{t.
+    elif warp.codegen.get_type_origin(t) in (list, tuple):
+        args_repr = ", ".join(type_str(x) for x in warp.codegen.get_type_args(t))
+        return f"{t._name}[{args_repr}]"
+    elif t is Ellipsis:
+        return "..."
+    elif warp.types.is_tile(t):
+        return "Tile"

     return t.__name__

@@ -5826,9 +6101,6 @@ def export_stubs(file):  # pragma: no cover
     print('Cols = TypeVar("Cols", bound=int)', file=file)
     print('DType = TypeVar("DType")', file=file)

-    print('Int = TypeVar("Int")', file=file)
-    print('Float = TypeVar("Float")', file=file)
-    print('Scalar = TypeVar("Scalar")', file=file)
     print("Vector = Generic[Length, Scalar]", file=file)
     print("Matrix = Generic[Rows, Cols, Scalar]", file=file)
     print("Quaternion = Generic[Float]", file=file)