warp-lang 0.11.0__py3-none-manylinux2014_x86_64.whl → 1.0.0__py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of warp-lang has been flagged as potentially problematic.
- warp/__init__.py +8 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +7 -6
- warp/build_dll.py +70 -79
- warp/builtins.py +10 -6
- warp/codegen.py +51 -19
- warp/config.py +7 -8
- warp/constants.py +3 -0
- warp/context.py +948 -245
- warp/dlpack.py +198 -113
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usda +42 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usda +56 -0
- warp/examples/assets/torus.usda +105 -0
- warp/examples/benchmarks/benchmark_api.py +383 -0
- warp/examples/benchmarks/benchmark_cloth.py +279 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
- warp/examples/benchmarks/benchmark_launches.py +295 -0
- warp/examples/core/example_dem.py +221 -0
- warp/examples/core/example_fluid.py +267 -0
- warp/examples/core/example_graph_capture.py +129 -0
- warp/examples/core/example_marching_cubes.py +177 -0
- warp/examples/core/example_mesh.py +154 -0
- warp/examples/core/example_mesh_intersect.py +193 -0
- warp/examples/core/example_nvdb.py +169 -0
- warp/examples/core/example_raycast.py +89 -0
- warp/examples/core/example_raymarch.py +178 -0
- warp/examples/core/example_render_opengl.py +141 -0
- warp/examples/core/example_sph.py +389 -0
- warp/examples/core/example_torch.py +181 -0
- warp/examples/core/example_wave.py +249 -0
- warp/examples/fem/bsr_utils.py +380 -0
- warp/examples/fem/example_apic_fluid.py +391 -0
- warp/examples/fem/example_convection_diffusion.py +168 -0
- warp/examples/fem/example_convection_diffusion_dg.py +209 -0
- warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
- warp/examples/fem/example_deformed_geometry.py +159 -0
- warp/examples/fem/example_diffusion.py +173 -0
- warp/examples/fem/example_diffusion_3d.py +152 -0
- warp/examples/fem/example_diffusion_mgpu.py +214 -0
- warp/examples/fem/example_mixed_elasticity.py +222 -0
- warp/examples/fem/example_navier_stokes.py +243 -0
- warp/examples/fem/example_stokes.py +192 -0
- warp/examples/fem/example_stokes_transfer.py +249 -0
- warp/examples/fem/mesh_utils.py +109 -0
- warp/examples/fem/plot_utils.py +287 -0
- warp/examples/optim/example_bounce.py +248 -0
- warp/examples/optim/example_cloth_throw.py +210 -0
- warp/examples/optim/example_diffray.py +535 -0
- warp/examples/optim/example_drone.py +850 -0
- warp/examples/optim/example_inverse_kinematics.py +169 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
- warp/examples/optim/example_spring_cage.py +234 -0
- warp/examples/optim/example_trajectory.py +201 -0
- warp/examples/sim/example_cartpole.py +128 -0
- warp/examples/sim/example_cloth.py +184 -0
- warp/examples/sim/example_granular.py +113 -0
- warp/examples/sim/example_granular_collision_sdf.py +185 -0
- warp/examples/sim/example_jacobian_ik.py +213 -0
- warp/examples/sim/example_particle_chain.py +106 -0
- warp/examples/sim/example_quadruped.py +179 -0
- warp/examples/sim/example_rigid_chain.py +191 -0
- warp/examples/sim/example_rigid_contact.py +176 -0
- warp/examples/sim/example_rigid_force.py +126 -0
- warp/examples/sim/example_rigid_gyroscopic.py +97 -0
- warp/examples/sim/example_rigid_soft_contact.py +124 -0
- warp/examples/sim/example_soft_body.py +178 -0
- warp/fabric.py +29 -20
- warp/fem/cache.py +0 -1
- warp/fem/dirichlet.py +0 -2
- warp/fem/integrate.py +0 -1
- warp/jax.py +45 -0
- warp/jax_experimental.py +339 -0
- warp/native/builtin.h +12 -0
- warp/native/bvh.cu +18 -18
- warp/native/clang/clang.cpp +8 -3
- warp/native/cuda_util.cpp +94 -5
- warp/native/cuda_util.h +35 -6
- warp/native/cutlass_gemm.cpp +1 -1
- warp/native/cutlass_gemm.cu +4 -1
- warp/native/error.cpp +66 -0
- warp/native/error.h +27 -0
- warp/native/mesh.cu +2 -2
- warp/native/reduce.cu +4 -4
- warp/native/runlength_encode.cu +2 -2
- warp/native/scan.cu +2 -2
- warp/native/sparse.cu +0 -1
- warp/native/temp_buffer.h +2 -2
- warp/native/warp.cpp +95 -60
- warp/native/warp.cu +1053 -218
- warp/native/warp.h +49 -32
- warp/optim/linear.py +33 -16
- warp/render/render_opengl.py +202 -101
- warp/render/render_usd.py +82 -40
- warp/sim/__init__.py +13 -4
- warp/sim/articulation.py +4 -5
- warp/sim/collide.py +320 -175
- warp/sim/import_mjcf.py +25 -30
- warp/sim/import_urdf.py +94 -63
- warp/sim/import_usd.py +51 -36
- warp/sim/inertia.py +3 -2
- warp/sim/integrator.py +233 -0
- warp/sim/integrator_euler.py +447 -469
- warp/sim/integrator_featherstone.py +1991 -0
- warp/sim/integrator_xpbd.py +1420 -640
- warp/sim/model.py +765 -487
- warp/sim/particles.py +2 -1
- warp/sim/render.py +35 -13
- warp/sim/utils.py +222 -11
- warp/stubs.py +8 -0
- warp/tape.py +16 -1
- warp/tests/aux_test_grad_customs.py +23 -0
- warp/tests/test_array.py +190 -1
- warp/tests/test_async.py +656 -0
- warp/tests/test_bool.py +50 -0
- warp/tests/test_dlpack.py +164 -11
- warp/tests/test_examples.py +166 -74
- warp/tests/test_fem.py +8 -1
- warp/tests/test_generics.py +15 -5
- warp/tests/test_grad.py +1 -1
- warp/tests/test_grad_customs.py +172 -12
- warp/tests/test_jax.py +254 -0
- warp/tests/test_large.py +29 -6
- warp/tests/test_launch.py +25 -0
- warp/tests/test_linear_solvers.py +20 -3
- warp/tests/test_matmul.py +61 -16
- warp/tests/test_matmul_lite.py +13 -13
- warp/tests/test_mempool.py +186 -0
- warp/tests/test_multigpu.py +3 -0
- warp/tests/test_options.py +16 -2
- warp/tests/test_peer.py +137 -0
- warp/tests/test_print.py +3 -1
- warp/tests/test_quat.py +23 -0
- warp/tests/test_sim_kinematics.py +97 -0
- warp/tests/test_snippet.py +126 -3
- warp/tests/test_streams.py +108 -79
- warp/tests/test_torch.py +16 -8
- warp/tests/test_utils.py +32 -27
- warp/tests/test_verify_fp.py +65 -0
- warp/tests/test_volume.py +1 -1
- warp/tests/unittest_serial.py +2 -0
- warp/tests/unittest_suites.py +12 -0
- warp/tests/unittest_utils.py +14 -7
- warp/thirdparty/unittest_parallel.py +15 -3
- warp/torch.py +10 -8
- warp/types.py +363 -246
- warp/utils.py +143 -19
- warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
- warp_lang-1.0.0.dist-info/METADATA +394 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
- warp/sim/optimizer.py +0 -138
- warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
- warp_lang-0.11.0.dist-info/METADATA +0 -238
- /warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
warp/jax_experimental.py
ADDED
@@ -0,0 +1,339 @@
# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import ctypes
import warp as wp
from warp.types import array_t, launch_bounds_t, strides_from_shape
from warp.context import type_str
import jax
import jax.numpy as jp

_jax_warp_p = None

# Holder for the custom callback to keep it alive.
_cc_callback = None
_registered_kernels = [None]
_registered_kernel_to_id = {}


def jax_kernel(wp_kernel):
    """Create a Jax primitive from a Warp kernel.

    NOTE: This is an experimental feature under development.

    Current limitations:
    - All kernel arguments must be arrays.
    - Kernel launch dimensions are inferred from the shape of the first argument.
    - Input arguments are followed by output arguments in the Warp kernel definition.
    - There must be at least one input argument and at least one output argument.
    - Output shapes must match the launch dimensions (i.e., output shapes must match the shape of the first argument).
    - All arrays must be contiguous.
    - Only the CUDA backend is supported.
    """

    if _jax_warp_p == None:
        # Create and register the primitive
        _create_jax_warp_primitive()
    if not wp_kernel in _registered_kernel_to_id:
        id = len(_registered_kernels)
        _registered_kernels.append(wp_kernel)
        _registered_kernel_to_id[wp_kernel] = id
    else:
        id = _registered_kernel_to_id[wp_kernel]

    def bind(*args):
        return _jax_warp_p.bind(*args, kernel=id)

    return bind


def _warp_custom_callback(stream, buffers, opaque, opaque_len):
    # The descriptor is the form
    # <kernel-id>|<launch-dims>|<arg-dims-list>
    # Example: 42|16,32|16,32;100;16,32
    kernel_id_str, dim_str, args_str = opaque.decode().split("|")

    # Get the kernel from the registry.
    kernel_id = int(kernel_id_str)
    kernel = _registered_kernels[kernel_id]

    # Parse launch dimensions.
    dims = [int(d) for d in dim_str.split(",")]
    bounds = launch_bounds_t(dims)

    # Parse arguments.
    arg_strings = args_str.split(";")
    num_args = len(arg_strings)
    assert num_args == len(kernel.adj.args), "Incorrect number of arguments"

    # First param is the launch bounds.
    kernel_params = (ctypes.c_void_p * (1 + num_args))()
    kernel_params[0] = ctypes.addressof(bounds)

    # Parse array descriptors.
    args = []
    for i in range(num_args):
        dtype = kernel.adj.args[i].type.dtype
        shape = [int(d) for d in arg_strings[i].split(",")]
        strides = strides_from_shape(shape, dtype)

        arr = array_t(buffers[i], 0, len(shape), shape, strides)
        args.append(arr)  # keep a reference
        arg_ptr = ctypes.addressof(arr)

        kernel_params[i + 1] = arg_ptr

    # Get current device.
    device = wp.device_from_jax(_get_jax_device())

    # Get kernel hooks.
    # Note: module was loaded during jit lowering.
    hooks = kernel.module.get_kernel_hooks(kernel, device)
    assert hooks.forward, "Failed to find kernel entry point"

    # Launch the kernel.
    wp.context.runtime.core.cuda_launch_kernel(
        device.context, hooks.forward, bounds.size, 0, kernel_params, stream
    )


# TODO: is there a simpler way of getting the Jax "current" device?
def _get_jax_device():
    # check if jax.default_device() context manager is active
    device = jax.config.jax_default_device
    # if default device is not set, use first device
    if device is None:
        device = jax.devices()[0]
    return device


def _create_jax_warp_primitive():
    from functools import reduce

    import jax
    from jax._src.interpreters import batching
    from jax.interpreters import mlir
    from jax.interpreters.mlir import ir
    from jaxlib.hlo_helpers import custom_call

    global _jax_warp_p
    global _cc_callback

    # Create and register the primitive.
    # TODO add default implementation that calls the kernel via warp.
    _jax_warp_p = jax.core.Primitive("jax_warp")
    _jax_warp_p.multiple_results = True

    # TODO Just launch the kernel directly, but make sure the argument
    # shapes are massaged the same way as below so that vmap works.
    def impl(*args):
        raise Exception("Not implemented")

    _jax_warp_p.def_impl(impl)

    # Auto-batching. Make sure all the arguments are fully broadcasted
    # so that Warp is not confused about dimensions.
    def vectorized_multi_batcher(args, dims, **params):
        # Figure out the number of outputs.
        wp_kernel = _registered_kernels[params["kernel"]]
        output_count = len(wp_kernel.adj.args) - len(args)
        shape, dim = next((a.shape, d) for a, d in zip(args, dims) if d is not None)
        size = shape[dim]
        args = [batching.bdim_at_front(a, d, size) if len(a.shape) else a for a, d in zip(args, dims)]
        # Create the batched primitive.
        return _jax_warp_p.bind(*args, **params), [dims[0]] * output_count

    batching.primitive_batchers[_jax_warp_p] = vectorized_multi_batcher

    def get_vecmat_shape(warp_type):
        if hasattr(warp_type.dtype, "_shape_"):
            return warp_type.dtype._shape_
        return []

    def strip_vecmat_dimensions(warp_arg, actual_shape):
        shape = get_vecmat_shape(warp_arg.type)
        for i, s in enumerate(reversed(shape)):
            item = actual_shape[-i - 1]
            if s != item:
                raise Exception(f"The vector/matrix shape for argument {warp_arg.label} does not match")
        return actual_shape[: len(actual_shape) - len(shape)]

    def collapse_into_leading_dimension(warp_arg, actual_shape):
        if len(actual_shape) < warp_arg.type.ndim:
            raise Exception(f"Argument {warp_arg.label} has too few non-matrix/vector dimensions")
        index_rest = len(actual_shape) - warp_arg.type.ndim + 1
        leading_size = reduce(lambda x, y: x * y, actual_shape[:index_rest])
        return [leading_size] + actual_shape[index_rest:]

    # Infer array dimensions from input type.
    def infer_dimensions(warp_arg, actual_shape):
        actual_shape = strip_vecmat_dimensions(warp_arg, actual_shape)
        return collapse_into_leading_dimension(warp_arg, actual_shape)

    def base_type_to_jax(warp_dtype):
        if hasattr(warp_dtype, "_wp_scalar_type_"):
            return wp.jax.dtype_to_jax(warp_dtype._wp_scalar_type_)
        return wp.jax.dtype_to_jax(warp_dtype)

    def base_type_to_jax_ir(warp_dtype):
        warp_to_jax_dict = {
            wp.float16: ir.F16Type.get(),
            wp.float32: ir.F32Type.get(),
            wp.float64: ir.F64Type.get(),
            wp.int8: ir.IntegerType.get_signless(8),
            wp.int16: ir.IntegerType.get_signless(16),
            wp.int32: ir.IntegerType.get_signless(32),
            wp.int64: ir.IntegerType.get_signless(64),
            wp.uint8: ir.IntegerType.get_unsigned(8),
            wp.uint16: ir.IntegerType.get_unsigned(16),
            wp.uint32: ir.IntegerType.get_unsigned(32),
            wp.uint64: ir.IntegerType.get_unsigned(64),
        }
        if hasattr(warp_dtype, "_wp_scalar_type_"):
            warp_dtype = warp_dtype._wp_scalar_type_
        jax_dtype = warp_to_jax_dict.get(warp_dtype)
        if jax_dtype is None:
            raise TypeError(f"Invalid or unsupported data type: {warp_dtype}")
        return jax_dtype

    def base_type_is_compatible(warp_type, jax_ir_type):
        jax_ir_to_warp = {
            "f16": wp.float16,
            "f32": wp.float32,
            "f64": wp.float64,
            "i8": wp.int8,
            "i16": wp.int16,
            "i32": wp.int32,
            "i64": wp.int64,
            "ui8": wp.uint8,
            "ui16": wp.uint16,
            "ui32": wp.uint32,
            "ui64": wp.uint64,
        }
        expected_warp_type = jax_ir_to_warp.get(str(jax_ir_type))
        if expected_warp_type is not None:
            if hasattr(warp_type, "_wp_scalar_type_"):
                return warp_type._wp_scalar_type_ == expected_warp_type
            else:
                return warp_type == expected_warp_type
        else:
            raise TypeError(f"Invalid or unsupported data type: {jax_ir_type}")

    # Abstract evaluation.
    def jax_warp_abstract(*args, kernel=None):
        wp_kernel = _registered_kernels[kernel]
        # All the extra arguments to the warp kernel are outputs.
        warp_outputs = [o.type for o in wp_kernel.adj.args[len(args) :]]
        # TODO. Let's just use the first input dimension to infer the output's dimensions.
        dims = strip_vecmat_dimensions(wp_kernel.adj.args[0], list(args[0].shape))
        jax_outputs = []
        for o in warp_outputs:
            shape = list(dims) + list(get_vecmat_shape(o))
            dtype = base_type_to_jax(o.dtype)
            jax_outputs.append(jax.core.ShapedArray(shape, dtype))
        return jax_outputs

    _jax_warp_p.def_abstract_eval(jax_warp_abstract)

    # Lowering to MLIR.

    # Create python-land custom call target.
    CCALLFUNC = ctypes.CFUNCTYPE(
        ctypes.c_voidp, ctypes.c_void_p, ctypes.POINTER(ctypes.c_void_p), ctypes.c_char_p, ctypes.c_size_t
    )
    _cc_callback = CCALLFUNC(_warp_custom_callback)
    ccall_address = ctypes.cast(_cc_callback, ctypes.c_void_p)

    # Put the custom call into a capsule, as required by XLA.
    PyCapsule_Destructor = ctypes.CFUNCTYPE(None, ctypes.py_object)
    PyCapsule_New = ctypes.pythonapi.PyCapsule_New
    PyCapsule_New.restype = ctypes.py_object
    PyCapsule_New.argtypes = (ctypes.c_void_p, ctypes.c_char_p, PyCapsule_Destructor)
    capsule = PyCapsule_New(ccall_address.value, b"xla._CUSTOM_CALL_TARGET", PyCapsule_Destructor(0))

    # Register the callback in XLA.
    jax.lib.xla_client.register_custom_call_target("warp_call", capsule, platform="gpu")

    def default_layout(shape):
        return range(len(shape) - 1, -1, -1)

    def warp_call_lowering(ctx, *args, kernel=None):
        if not kernel:
            raise Exception("Unknown kernel id " + str(kernel))
        wp_kernel = _registered_kernels[kernel]

        # TODO This may not be necessary, but it is perhaps better not to be
        # mucking with kernel loading while already running the workload.
        module = wp_kernel.module
        device = wp.device_from_jax(_get_jax_device())
        if not module.load(device):
            raise Exception("Could not load kernel on device")

        # Infer dimensions from the first input.
        warp_arg0 = wp_kernel.adj.args[0]
        actual_shape0 = ir.RankedTensorType(args[0].type).shape
        dims = strip_vecmat_dimensions(warp_arg0, actual_shape0)
        warp_dims = collapse_into_leading_dimension(warp_arg0, dims)

        # Figure out the types and shapes of the input arrays.
        arg_strings = []
        operand_layouts = []
        for actual, warg in zip(args, wp_kernel.adj.args):
            wtype = warg.type
            rtt = ir.RankedTensorType(actual.type)

            if not isinstance(wtype, wp.array):
                raise Exception("Only contiguous arrays are supported for Jax kernel arguments")

            if not base_type_is_compatible(wtype.dtype, rtt.element_type):
                raise TypeError(f"Incompatible data type for argument '{warg.label}', expected {type_str(wtype.dtype)}, got {rtt.element_type}")

            # Infer array dimension (by removing the vector/matrix dimensions and
            # collapsing the initial dimensions).
            shape = infer_dimensions(warg, rtt.shape)

            if len(shape) != wtype.ndim:
                raise TypeError(f"Incompatible array dimensionality for argument '{warg.label}'")

            arg_strings.append(",".join([str(d) for d in shape]))
            operand_layouts.append(default_layout(rtt.shape))

        # Figure out the types and shapes of the output arrays.
        result_types = []
        result_layouts = []
        for warg in wp_kernel.adj.args[len(args) :]:
            wtype = warg.type

            if not isinstance(wtype, wp.array):
                raise Exception("Only contiguous arrays are supported for Jax kernel arguments")

            # Infer dimensions from the first input.
            arg_strings.append(",".join([str(d) for d in warp_dims]))

            result_shape = list(dims) + list(get_vecmat_shape(wtype))
            result_types.append(ir.RankedTensorType.get(result_shape, base_type_to_jax_ir(wtype.dtype)))
            result_layouts.append(default_layout(result_shape))

        # Build opaque descriptor for callback.
        shape_str = ",".join([str(d) for d in warp_dims])
        args_str = ";".join(arg_strings)
        descriptor = f"{kernel}|{shape_str}|{args_str}"

        out = custom_call(
            b"warp_call",
            result_types=result_types,
            operands=args,
            backend_config=descriptor.encode("utf-8"),
            operand_layouts=operand_layouts,
            result_layouts=result_layouts,
        ).results
        return out

    mlir.register_lowering(
        _jax_warp_p,
        warp_call_lowering,
        platform="gpu",
    )
warp/native/builtin.h
CHANGED
@@ -354,6 +354,12 @@ inline CUDA_CALLABLE uint32 sign(uint32 x) { return 1; }
 inline CUDA_CALLABLE uint64 sign(uint64 x) { return 1; }


+// Catch-all for non-float types
+template<typename T>
+inline bool CUDA_CALLABLE isfinite(const T&)
+{
+    return true;
+}

 inline bool CUDA_CALLABLE isfinite(half x)
 {
@@ -368,6 +374,12 @@ inline bool CUDA_CALLABLE isfinite(double x)
     return ::isfinite(x);
 }

+template<typename T>
+inline CUDA_CALLABLE void print(const T&)
+{
+    printf("<type without print implementation>\n");
+}
+
 inline CUDA_CALLABLE void print(float16 f)
 {
     printf("%g\n", half_to_float(f));
warp/native/bvh.cu
CHANGED
@@ -373,16 +373,16 @@ LinearBVHBuilderGPU::LinearBVHBuilderGPU()
     , total_upper(NULL)
     , total_inv_edges(NULL)
 {
-    total_lower = (vec3*)
-    total_upper = (vec3*)
-    total_inv_edges = (vec3*)
+    total_lower = (vec3*)alloc_device(WP_CURRENT_CONTEXT, sizeof(vec3));
+    total_upper = (vec3*)alloc_device(WP_CURRENT_CONTEXT, sizeof(vec3));
+    total_inv_edges = (vec3*)alloc_device(WP_CURRENT_CONTEXT, sizeof(vec3));
 }

 LinearBVHBuilderGPU::~LinearBVHBuilderGPU()
 {
-
-
-
+    free_device(WP_CURRENT_CONTEXT, total_lower);
+    free_device(WP_CURRENT_CONTEXT, total_upper);
+    free_device(WP_CURRENT_CONTEXT, total_inv_edges);
 }


@@ -390,12 +390,12 @@ LinearBVHBuilderGPU::~LinearBVHBuilderGPU()
 void LinearBVHBuilderGPU::build(BVH& bvh, const vec3* item_lowers, const vec3* item_uppers, int num_items, bounds3* total_bounds)
 {
     // allocate temporary memory used during building
-    indices = (int*)
-    keys = (int*)
-    deltas = (int*)
-    range_lefts = (int*)
-    range_rights = (int*)
-    num_children = (int*)
+    indices = (int*)alloc_device(WP_CURRENT_CONTEXT, sizeof(int)*num_items*2);   // *2 for radix sort
+    keys = (int*)alloc_device(WP_CURRENT_CONTEXT, sizeof(int)*num_items*2);      // *2 for radix sort
+    deltas = (int*)alloc_device(WP_CURRENT_CONTEXT, sizeof(int)*num_items);      // highest differenting bit between keys for item i and i+1
+    range_lefts = (int*)alloc_device(WP_CURRENT_CONTEXT, sizeof(int)*bvh.max_nodes);
+    range_rights = (int*)alloc_device(WP_CURRENT_CONTEXT, sizeof(int)*bvh.max_nodes);
+    num_children = (int*)alloc_device(WP_CURRENT_CONTEXT, sizeof(int)*bvh.max_nodes);

     // if total bounds supplied by the host then we just
     // compute our edge length and upload it to the GPU directly
@@ -445,13 +445,13 @@ void LinearBVHBuilderGPU::build(BVH& bvh, const vec3* item_lowers, const vec3* item_uppers, int num_items, bounds3* total_bounds)
     wp_launch_device(WP_CURRENT_CONTEXT, build_hierarchy, num_items, (num_items, bvh.root, deltas, num_children, range_lefts, range_rights, bvh.node_parents, bvh.node_lowers, bvh.node_uppers));

     // free temporary memory
-
-
-
+    free_device(WP_CURRENT_CONTEXT, indices);
+    free_device(WP_CURRENT_CONTEXT, keys);
+    free_device(WP_CURRENT_CONTEXT, deltas);

-
-
-
+    free_device(WP_CURRENT_CONTEXT, range_lefts);
+    free_device(WP_CURRENT_CONTEXT, range_rights);
+    free_device(WP_CURRENT_CONTEXT, num_children);

 }

warp/native/clang/clang.cpp
CHANGED
@@ -81,7 +81,7 @@ static void initialize_llvm()
     llvm::InitializeAllAsmPrinters();
 }

-static std::unique_ptr<llvm::Module> cpp_to_llvm(const std::string& input_file, const char* cpp_src, const char* include_dir, bool debug, llvm::LLVMContext& context)
+static std::unique_ptr<llvm::Module> cpp_to_llvm(const std::string& input_file, const char* cpp_src, const char* include_dir, bool debug, bool verify_fp, llvm::LLVMContext& context)
 {
     // Compilation arguments
     std::vector<const char*> args;
@@ -126,6 +126,11 @@ static std::unique_ptr<llvm::Module> cpp_to_llvm(const std::string& input_file,
         compiler_instance.getPreprocessorOpts().addMacroDef("NDEBUG");
     }

+    if(verify_fp)
+    {
+        compiler_instance.getPreprocessorOpts().addMacroDef("WP_VERIFY_FP");
+    }
+
     compiler_instance.getLangOpts().MicrosoftExt = 1;  // __forceinline / __int64
     compiler_instance.getLangOpts().DeclSpecKeyword = 1;  // __declspec

@@ -201,12 +206,12 @@ static std::unique_ptr<llvm::Module> cuda_to_llvm(const std::string& input_file,

 extern "C" {

-WP_API int compile_cpp(const char* cpp_src, const char *input_file, const char* include_dir, const char* output_file, bool debug)
+WP_API int compile_cpp(const char* cpp_src, const char *input_file, const char* include_dir, const char* output_file, bool debug, bool verify_fp)
 {
     initialize_llvm();

     llvm::LLVMContext context;
-    std::unique_ptr<llvm::Module> module = cpp_to_llvm(input_file, cpp_src, include_dir, debug, context);
+    std::unique_ptr<llvm::Module> module = cpp_to_llvm(input_file, cpp_src, include_dir, debug, verify_fp, context);

     if(!module)
     {
warp/native/cuda_util.cpp
CHANGED
@@ -9,6 +9,7 @@
 #if WP_ENABLE_CUDA

 #include "cuda_util.h"
+#include "error.h"

 #if defined(_WIN32)
 #define WIN32_LEAN_AND_MEAN
@@ -19,6 +20,9 @@
 #include <dlfcn.h>
 #endif

+#include <set>
+#include <stack>
+
 // the minimum CUDA version required from the driver
 #define WP_CUDA_DRIVER_VERSION 11030

@@ -63,6 +67,7 @@ static PFN_cuDeviceGetUuid_v11040 pfn_cuDeviceGetUuid;
 static PFN_cuDevicePrimaryCtxRetain_v7000 pfn_cuDevicePrimaryCtxRetain;
 static PFN_cuDevicePrimaryCtxRelease_v11000 pfn_cuDevicePrimaryCtxRelease;
 static PFN_cuDeviceCanAccessPeer_v4000 pfn_cuDeviceCanAccessPeer;
+static PFN_cuMemGetInfo_v3020 pfn_cuMemGetInfo;
 static PFN_cuCtxGetCurrent_v4000 pfn_cuCtxGetCurrent;
 static PFN_cuCtxSetCurrent_v4000 pfn_cuCtxSetCurrent;
 static PFN_cuCtxPushCurrent_v4000 pfn_cuCtxPushCurrent;
@@ -72,18 +77,23 @@ static PFN_cuCtxGetDevice_v2000 pfn_cuCtxGetDevice;
 static PFN_cuCtxCreate_v3020 pfn_cuCtxCreate;
 static PFN_cuCtxDestroy_v4000 pfn_cuCtxDestroy;
 static PFN_cuCtxEnablePeerAccess_v4000 pfn_cuCtxEnablePeerAccess;
+static PFN_cuCtxDisablePeerAccess_v4000 pfn_cuCtxDisablePeerAccess;
 static PFN_cuStreamCreate_v2000 pfn_cuStreamCreate;
 static PFN_cuStreamDestroy_v4000 pfn_cuStreamDestroy;
 static PFN_cuStreamSynchronize_v2000 pfn_cuStreamSynchronize;
 static PFN_cuStreamWaitEvent_v3020 pfn_cuStreamWaitEvent;
+static PFN_cuStreamGetCaptureInfo_v11030 pfn_cuStreamGetCaptureInfo;
+static PFN_cuStreamUpdateCaptureDependencies_v11030 pfn_cuStreamUpdateCaptureDependencies;
 static PFN_cuEventCreate_v2000 pfn_cuEventCreate;
 static PFN_cuEventDestroy_v4000 pfn_cuEventDestroy;
 static PFN_cuEventRecord_v2000 pfn_cuEventRecord;
+static PFN_cuEventRecordWithFlags_v11010 pfn_cuEventRecordWithFlags;
 static PFN_cuModuleLoadDataEx_v2010 pfn_cuModuleLoadDataEx;
 static PFN_cuModuleUnload_v2000 pfn_cuModuleUnload;
 static PFN_cuModuleGetFunction_v2000 pfn_cuModuleGetFunction;
 static PFN_cuLaunchKernel_v4000 pfn_cuLaunchKernel;
 static PFN_cuMemcpyPeerAsync_v4000 pfn_cuMemcpyPeerAsync;
+static PFN_cuPointerGetAttribute_v4000 pfn_cuPointerGetAttribute;
 static PFN_cuGraphicsMapResources_v3000 pfn_cuGraphicsMapResources;
 static PFN_cuGraphicsUnmapResources_v3000 pfn_cuGraphicsUnmapResources;
 static PFN_cuGraphicsResourceGetMappedPointer_v3020 pfn_cuGraphicsResourceGetMappedPointer;
@@ -171,6 +181,7 @@ bool init_cuda_driver()
     get_driver_entry_point("cuDevicePrimaryCtxRetain", &(void*&)pfn_cuDevicePrimaryCtxRetain);
     get_driver_entry_point("cuDevicePrimaryCtxRelease", &(void*&)pfn_cuDevicePrimaryCtxRelease);
     get_driver_entry_point("cuDeviceCanAccessPeer", &(void*&)pfn_cuDeviceCanAccessPeer);
+    get_driver_entry_point("cuMemGetInfo", &(void*&)pfn_cuMemGetInfo);
     get_driver_entry_point("cuCtxSetCurrent", &(void*&)pfn_cuCtxSetCurrent);
     get_driver_entry_point("cuCtxGetCurrent", &(void*&)pfn_cuCtxGetCurrent);
     get_driver_entry_point("cuCtxPushCurrent", &(void*&)pfn_cuCtxPushCurrent);
@@ -180,18 +191,23 @@ bool init_cuda_driver()
     get_driver_entry_point("cuCtxCreate", &(void*&)pfn_cuCtxCreate);
     get_driver_entry_point("cuCtxDestroy", &(void*&)pfn_cuCtxDestroy);
     get_driver_entry_point("cuCtxEnablePeerAccess", &(void*&)pfn_cuCtxEnablePeerAccess);
+    get_driver_entry_point("cuCtxDisablePeerAccess", &(void*&)pfn_cuCtxDisablePeerAccess);
     get_driver_entry_point("cuStreamCreate", &(void*&)pfn_cuStreamCreate);
     get_driver_entry_point("cuStreamDestroy", &(void*&)pfn_cuStreamDestroy);
     get_driver_entry_point("cuStreamSynchronize", &(void*&)pfn_cuStreamSynchronize);
     get_driver_entry_point("cuStreamWaitEvent", &(void*&)pfn_cuStreamWaitEvent);
+    get_driver_entry_point("cuStreamGetCaptureInfo", &(void*&)pfn_cuStreamGetCaptureInfo);
+    get_driver_entry_point("cuStreamUpdateCaptureDependencies", &(void*&)pfn_cuStreamUpdateCaptureDependencies);
     get_driver_entry_point("cuEventCreate", &(void*&)pfn_cuEventCreate);
     get_driver_entry_point("cuEventDestroy", &(void*&)pfn_cuEventDestroy);
     get_driver_entry_point("cuEventRecord", &(void*&)pfn_cuEventRecord);
+    get_driver_entry_point("cuEventRecordWithFlags", &(void*&)pfn_cuEventRecordWithFlags);
     get_driver_entry_point("cuModuleLoadDataEx", &(void*&)pfn_cuModuleLoadDataEx);
     get_driver_entry_point("cuModuleUnload", &(void*&)pfn_cuModuleUnload);
     get_driver_entry_point("cuModuleGetFunction", &(void*&)pfn_cuModuleGetFunction);
     get_driver_entry_point("cuLaunchKernel", &(void*&)pfn_cuLaunchKernel);
     get_driver_entry_point("cuMemcpyPeerAsync", &(void*&)pfn_cuMemcpyPeerAsync);
+    get_driver_entry_point("cuPointerGetAttribute", &(void*&)pfn_cuPointerGetAttribute);
     get_driver_entry_point("cuGraphicsMapResources", &(void*&)pfn_cuGraphicsMapResources);
     get_driver_entry_point("cuGraphicsUnmapResources", &(void*&)pfn_cuGraphicsUnmapResources);
     get_driver_entry_point("cuGraphicsResourceGetMappedPointer", &(void*&)pfn_cuGraphicsResourceGetMappedPointer);
@@ -209,16 +225,16 @@ bool is_cuda_driver_initialized()
     return cuda_driver_initialized;
 }

-bool check_cuda_result(cudaError_t code, const char* file, int line)
+bool check_cuda_result(cudaError_t code, const char* func, const char* file, int line)
 {
     if (code == cudaSuccess)
         return true;

-
+    wp::set_error_string("Warp CUDA error %u: %s (in function %s, %s:%d)", unsigned(code), cudaGetErrorString(code), func, file, line);
     return false;
 }

-bool check_cu_result(CUresult result, const char* file, int line)
+bool check_cu_result(CUresult result, const char* func, const char* file, int line)
 {
     if (result == CUDA_SUCCESS)
         return true;
@@ -228,13 +244,56 @@ bool check_cu_result(CUresult result, const char* file, int line)
     pfn_cuGetErrorString(result, &errString);

     if (errString)
-
+        wp::set_error_string("Warp CUDA error %u: %s (in function %s, %s:%d)", unsigned(result), errString, func, file, line);
     else
-
+        wp::set_error_string("Warp CUDA error %u (in function %s, %s:%d)", unsigned(result), func, file, line);

     return false;
 }

+bool get_capture_dependencies(CUstream stream, std::vector<CUgraphNode>& dependencies_ret)
+{
+    CUstreamCaptureStatus status;
+    size_t num_dependencies = 0;
+    const CUgraphNode* dependencies = NULL;
+    dependencies_ret.clear();
+    if (check_cu(cuStreamGetCaptureInfo_f(stream, &status, NULL, NULL, &dependencies, &num_dependencies)))
+    {
+        if (dependencies && num_dependencies > 0)
+            dependencies_ret.insert(dependencies_ret.begin(), dependencies, dependencies + num_dependencies);
+        return true;
+    }
+    return false;
+}
+
+bool get_graph_leaf_nodes(cudaGraph_t graph, std::vector<cudaGraphNode_t>& leaf_nodes_ret)
+{
+    if (!graph)
+        return false;
+
+    size_t node_count = 0;
+    if (!check_cuda(cudaGraphGetNodes(graph, NULL, &node_count)))
+        return false;
+
+    std::vector<cudaGraphNode_t> nodes(node_count);
+    if (!check_cuda(cudaGraphGetNodes(graph, nodes.data(), &node_count)))
+        return false;
+
+    leaf_nodes_ret.clear();
+
+    for (cudaGraphNode_t node : nodes)
+    {
+        size_t dependent_count;
+        if (!check_cuda(cudaGraphNodeGetDependentNodes(node, NULL, &dependent_count)))
+            return false;
+
+        if (dependent_count == 0)
+            leaf_nodes_ret.push_back(node);
+    }
+
+    return true;
+}
+

 #define DRIVER_ENTRY_POINT_ERROR driver_entry_point_error(__FUNCTION__)

@@ -311,6 +370,11 @@ CUresult cuDeviceCanAccessPeer_f(int* can_access, CUdevice dev, CUdevice peer_de
     return pfn_cuDeviceCanAccessPeer ? pfn_cuDeviceCanAccessPeer(can_access, dev, peer_dev) : DRIVER_ENTRY_POINT_ERROR;
 }

+CUresult cuMemGetInfo_f(size_t* free, size_t* total)
+{
+    return pfn_cuMemGetInfo ? pfn_cuMemGetInfo(free, total) : DRIVER_ENTRY_POINT_ERROR;
+}
+
 CUresult cuCtxGetCurrent_f(CUcontext* ctx)
 {
     return pfn_cuCtxGetCurrent ? pfn_cuCtxGetCurrent(ctx) : DRIVER_ENTRY_POINT_ERROR;
@@ -356,6 +420,11 @@ CUresult cuCtxEnablePeerAccess_f(CUcontext peer_ctx, unsigned int flags)
     return pfn_cuCtxEnablePeerAccess ? pfn_cuCtxEnablePeerAccess(peer_ctx, flags) : DRIVER_ENTRY_POINT_ERROR;
 }

+CUresult cuCtxDisablePeerAccess_f(CUcontext peer_ctx)
+{
+    return pfn_cuCtxDisablePeerAccess ? pfn_cuCtxDisablePeerAccess(peer_ctx) : DRIVER_ENTRY_POINT_ERROR;
+}
+
 CUresult cuStreamCreate_f(CUstream* stream, unsigned int flags)
 {
     return pfn_cuStreamCreate ? pfn_cuStreamCreate(stream, flags) : DRIVER_ENTRY_POINT_ERROR;
@@ -376,6 +445,16 @@ CUresult cuStreamWaitEvent_f(CUstream stream, CUevent event, unsigned int flags)
     return pfn_cuStreamWaitEvent ? pfn_cuStreamWaitEvent(stream, event, flags) : DRIVER_ENTRY_POINT_ERROR;
 }

+CUresult cuStreamGetCaptureInfo_f(CUstream stream, CUstreamCaptureStatus *captureStatus_out, cuuint64_t *id_out, CUgraph *graph_out, const CUgraphNode **dependencies_out, size_t *numDependencies_out)
+{
+    return pfn_cuStreamGetCaptureInfo ? pfn_cuStreamGetCaptureInfo(stream, captureStatus_out, id_out, graph_out, dependencies_out, numDependencies_out) : DRIVER_ENTRY_POINT_ERROR;
+}
+
+CUresult cuStreamUpdateCaptureDependencies_f(CUstream stream, CUgraphNode *dependencies, size_t numDependencies, unsigned int flags)
+{
+    return pfn_cuStreamUpdateCaptureDependencies ? pfn_cuStreamUpdateCaptureDependencies(stream, dependencies, numDependencies, flags) : DRIVER_ENTRY_POINT_ERROR;
+}
+
 CUresult cuEventCreate_f(CUevent* event, unsigned int flags)
 {
     return pfn_cuEventCreate ? pfn_cuEventCreate(event, flags) : DRIVER_ENTRY_POINT_ERROR;
@@ -391,6 +470,11 @@ CUresult cuEventRecord_f(CUevent event, CUstream stream)
     return pfn_cuEventRecord ? pfn_cuEventRecord(event, stream) : DRIVER_ENTRY_POINT_ERROR;
 }

+CUresult cuEventRecordWithFlags_f(CUevent event, CUstream stream, unsigned int flags)
+{
+    return pfn_cuEventRecordWithFlags ? pfn_cuEventRecordWithFlags(event, stream, flags) : DRIVER_ENTRY_POINT_ERROR;
+}
+
 CUresult cuModuleLoadDataEx_f(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues)
 {
     return pfn_cuModuleLoadDataEx ? pfn_cuModuleLoadDataEx(module, image, numOptions, options, optionValues) : DRIVER_ENTRY_POINT_ERROR;
@@ -416,6 +500,11 @@ CUresult cuMemcpyPeerAsync_f(CUdeviceptr dst_ptr, CUcontext dst_ctx, CUdeviceptr src_ptr, CUcontext src_ctx, size_t n, CUstream stream)
     return pfn_cuMemcpyPeerAsync ? pfn_cuMemcpyPeerAsync(dst_ptr, dst_ctx, src_ptr, src_ctx, n, stream) : DRIVER_ENTRY_POINT_ERROR;
 }

+CUresult cuPointerGetAttribute_f(void* data, CUpointer_attribute attribute, CUdeviceptr ptr)
+{
+    return pfn_cuPointerGetAttribute ? pfn_cuPointerGetAttribute(data, attribute, ptr) : DRIVER_ENTRY_POINT_ERROR;
+}
+
 CUresult cuGraphicsMapResources_f(unsigned int count, CUgraphicsResource* resources, CUstream stream)
 {
     return pfn_cuGraphicsMapResources ? pfn_cuGraphicsMapResources(count, resources, stream) : DRIVER_ENTRY_POINT_ERROR;