warp-lang 1.7.2rc1__py3-none-manylinux_2_34_aarch64.whl → 1.8.1__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +130 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +272 -104
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +770 -238
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +99 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +181 -95
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +210 -67
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +103 -4
- warp/native/builtin.h +182 -35
- warp/native/coloring.cpp +6 -2
- warp/native/cuda_util.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +5 -5
- warp/native/mat.h +8 -13
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/mesh.h +1 -1
- warp/native/quat.h +34 -6
- warp/native/rand.h +7 -7
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/svd.h +23 -8
- warp/native/tile.h +603 -73
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +239 -13
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +10 -20
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +588 -52
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +110 -80
- warp/render/render_usd.py +124 -62
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +253 -80
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +761 -322
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +54 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +91 -2
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +23 -24
- warp/tests/test_quat.py +28 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +83 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_static.py +48 -0
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tape.py +38 -0
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/test_vec.py +38 -408
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +438 -131
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_reduce.py +307 -5
- warp/tests/tile/test_tile_shared_memory.py +136 -7
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/unittest_suites.py +14 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/context.py
CHANGED
|
@@ -32,22 +32,7 @@ import typing
|
|
|
32
32
|
import weakref
|
|
33
33
|
from copy import copy as shallowcopy
|
|
34
34
|
from pathlib import Path
|
|
35
|
-
from typing import (
|
|
36
|
-
Any,
|
|
37
|
-
Callable,
|
|
38
|
-
Dict,
|
|
39
|
-
List,
|
|
40
|
-
Literal,
|
|
41
|
-
Mapping,
|
|
42
|
-
Optional,
|
|
43
|
-
Sequence,
|
|
44
|
-
Set,
|
|
45
|
-
Tuple,
|
|
46
|
-
TypeVar,
|
|
47
|
-
Union,
|
|
48
|
-
get_args,
|
|
49
|
-
get_origin,
|
|
50
|
-
)
|
|
35
|
+
from typing import Any, Callable, Dict, List, Literal, Mapping, Sequence, Tuple, TypeVar, Union, get_args, get_origin
|
|
51
36
|
|
|
52
37
|
import numpy as np
|
|
53
38
|
|
|
@@ -84,7 +69,7 @@ def get_function_args(func):
|
|
|
84
69
|
complex_type_hints = (Any, Callable, Tuple)
|
|
85
70
|
sequence_types = (list, tuple)
|
|
86
71
|
|
|
87
|
-
function_key_counts: Dict[str, int] = {}
|
|
72
|
+
function_key_counts: dict[str, int] = {}
|
|
88
73
|
|
|
89
74
|
|
|
90
75
|
def generate_unique_function_identifier(key: str) -> str:
|
|
@@ -120,40 +105,41 @@ def generate_unique_function_identifier(key: str) -> str:
|
|
|
120
105
|
class Function:
|
|
121
106
|
def __init__(
|
|
122
107
|
self,
|
|
123
|
-
func:
|
|
108
|
+
func: Callable | None,
|
|
124
109
|
key: str,
|
|
125
110
|
namespace: str,
|
|
126
|
-
input_types:
|
|
127
|
-
value_type:
|
|
128
|
-
value_func:
|
|
129
|
-
export_func:
|
|
130
|
-
dispatch_func:
|
|
131
|
-
lto_dispatch_func:
|
|
132
|
-
module:
|
|
111
|
+
input_types: dict[str, type | TypeVar] | None = None,
|
|
112
|
+
value_type: type | None = None,
|
|
113
|
+
value_func: Callable[[Mapping[str, type], Mapping[str, Any]], type] | None = None,
|
|
114
|
+
export_func: Callable[[dict[str, type]], dict[str, type]] | None = None,
|
|
115
|
+
dispatch_func: Callable | None = None,
|
|
116
|
+
lto_dispatch_func: Callable | None = None,
|
|
117
|
+
module: Module | None = None,
|
|
133
118
|
variadic: bool = False,
|
|
134
|
-
initializer_list_func:
|
|
119
|
+
initializer_list_func: Callable[[dict[str, Any], type], bool] | None = None,
|
|
135
120
|
export: bool = False,
|
|
121
|
+
source: str | None = None,
|
|
136
122
|
doc: str = "",
|
|
137
123
|
group: str = "",
|
|
138
124
|
hidden: bool = False,
|
|
139
125
|
skip_replay: bool = False,
|
|
140
126
|
missing_grad: bool = False,
|
|
141
127
|
generic: bool = False,
|
|
142
|
-
native_func:
|
|
143
|
-
defaults:
|
|
144
|
-
custom_replay_func:
|
|
145
|
-
native_snippet:
|
|
146
|
-
adj_native_snippet:
|
|
147
|
-
replay_snippet:
|
|
128
|
+
native_func: str | None = None,
|
|
129
|
+
defaults: dict[str, Any] | None = None,
|
|
130
|
+
custom_replay_func: Function | None = None,
|
|
131
|
+
native_snippet: str | None = None,
|
|
132
|
+
adj_native_snippet: str | None = None,
|
|
133
|
+
replay_snippet: str | None = None,
|
|
148
134
|
skip_forward_codegen: bool = False,
|
|
149
135
|
skip_reverse_codegen: bool = False,
|
|
150
136
|
custom_reverse_num_input_args: int = -1,
|
|
151
137
|
custom_reverse_mode: bool = False,
|
|
152
|
-
overloaded_annotations:
|
|
153
|
-
code_transformers:
|
|
138
|
+
overloaded_annotations: dict[str, type] | None = None,
|
|
139
|
+
code_transformers: list[ast.NodeTransformer] | None = None,
|
|
154
140
|
skip_adding_overload: bool = False,
|
|
155
141
|
require_original_output_arg: bool = False,
|
|
156
|
-
scope_locals:
|
|
142
|
+
scope_locals: dict[str, Any] | None = None,
|
|
157
143
|
):
|
|
158
144
|
if code_transformers is None:
|
|
159
145
|
code_transformers = []
|
|
@@ -178,7 +164,7 @@ class Function:
|
|
|
178
164
|
self.native_snippet = native_snippet
|
|
179
165
|
self.adj_native_snippet = adj_native_snippet
|
|
180
166
|
self.replay_snippet = replay_snippet
|
|
181
|
-
self.custom_grad_func:
|
|
167
|
+
self.custom_grad_func: Function | None = None
|
|
182
168
|
self.require_original_output_arg = require_original_output_arg
|
|
183
169
|
self.generic_parent = None # generic function that was used to instantiate this overload
|
|
184
170
|
|
|
@@ -194,7 +180,7 @@ class Function:
|
|
|
194
180
|
)
|
|
195
181
|
self.missing_grad = missing_grad # whether builtin is missing a corresponding adjoint
|
|
196
182
|
self.generic = generic
|
|
197
|
-
self.mangled_name:
|
|
183
|
+
self.mangled_name: str | None = None
|
|
198
184
|
|
|
199
185
|
# allow registering functions with a different name in Python and native code
|
|
200
186
|
if native_func is None:
|
|
@@ -211,12 +197,13 @@ class Function:
|
|
|
211
197
|
# user-defined function
|
|
212
198
|
|
|
213
199
|
# generic and concrete overload lookups by type signature
|
|
214
|
-
self.user_templates:
|
|
215
|
-
self.user_overloads:
|
|
200
|
+
self.user_templates: dict[str, Function] = {}
|
|
201
|
+
self.user_overloads: dict[str, Function] = {}
|
|
216
202
|
|
|
217
203
|
# user defined (Python) function
|
|
218
204
|
self.adj = warp.codegen.Adjoint(
|
|
219
205
|
func,
|
|
206
|
+
source=source,
|
|
220
207
|
is_user_function=True,
|
|
221
208
|
skip_forward_codegen=skip_forward_codegen,
|
|
222
209
|
skip_reverse_codegen=skip_reverse_codegen,
|
|
@@ -244,7 +231,7 @@ class Function:
|
|
|
244
231
|
|
|
245
232
|
# embedded linked list of all overloads
|
|
246
233
|
# the builtin_functions dictionary holds the list head for a given key (func name)
|
|
247
|
-
self.overloads:
|
|
234
|
+
self.overloads: list[Function] = []
|
|
248
235
|
|
|
249
236
|
# builtin (native) function, canonicalize argument types
|
|
250
237
|
if input_types is not None:
|
|
@@ -293,10 +280,11 @@ class Function:
|
|
|
293
280
|
module.register_function(self, scope_locals, skip_adding_overload)
|
|
294
281
|
|
|
295
282
|
def __call__(self, *args, **kwargs):
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
283
|
+
"""Call this function from the CPython interpreter.
|
|
284
|
+
|
|
285
|
+
This is used to call built-in or user functions from the CPython
|
|
286
|
+
interpreter, rather than from within a kernel.
|
|
287
|
+
"""
|
|
300
288
|
|
|
301
289
|
if self.is_builtin() and self.mangled_name:
|
|
302
290
|
# For each of this function's existing overloads, we attempt to pack
|
|
@@ -306,7 +294,23 @@ class Function:
|
|
|
306
294
|
if overload.generic:
|
|
307
295
|
continue
|
|
308
296
|
|
|
309
|
-
|
|
297
|
+
try:
|
|
298
|
+
# Try to bind the given arguments to the function's signature.
|
|
299
|
+
# This is not checking whether the argument types are matching,
|
|
300
|
+
# rather it's just assigning each argument to the corresponding
|
|
301
|
+
# function parameter.
|
|
302
|
+
bound_args = self.signature.bind(*args, **kwargs)
|
|
303
|
+
except TypeError:
|
|
304
|
+
continue
|
|
305
|
+
|
|
306
|
+
if self.defaults:
|
|
307
|
+
# Populate the bound arguments with any default values.
|
|
308
|
+
default_args = {k: v for k, v in self.defaults.items() if k not in bound_args.arguments}
|
|
309
|
+
warp.codegen.apply_defaults(bound_args, default_args)
|
|
310
|
+
|
|
311
|
+
bound_args = tuple(bound_args.arguments.values())
|
|
312
|
+
|
|
313
|
+
success, return_value = call_builtin(overload, bound_args)
|
|
310
314
|
if success:
|
|
311
315
|
return return_value
|
|
312
316
|
|
|
@@ -324,6 +328,9 @@ class Function:
|
|
|
324
328
|
|
|
325
329
|
arguments = tuple(bound_args.arguments.values())
|
|
326
330
|
|
|
331
|
+
# Store the last runtime error we encountered from a function execution
|
|
332
|
+
last_execution_error = None
|
|
333
|
+
|
|
327
334
|
# try and find a matching overload
|
|
328
335
|
for overload in self.user_overloads.values():
|
|
329
336
|
if len(overload.input_types) != len(arguments):
|
|
@@ -334,10 +341,25 @@ class Function:
|
|
|
334
341
|
# attempt to unify argument types with function template types
|
|
335
342
|
warp.types.infer_argument_types(arguments, template_types, arg_names)
|
|
336
343
|
return overload.func(*arguments)
|
|
337
|
-
except Exception:
|
|
344
|
+
except Exception as e:
|
|
345
|
+
# The function was callable but threw an error during its execution.
|
|
346
|
+
# This might be the intended overload, but it failed, or it might be the wrong overload.
|
|
347
|
+
# We save this specific error and continue, just in case another overload later in the
|
|
348
|
+
# list is a better match and doesn't fail.
|
|
349
|
+
last_execution_error = e
|
|
338
350
|
continue
|
|
339
351
|
|
|
340
|
-
|
|
352
|
+
if last_execution_error:
|
|
353
|
+
# Raise a new, more contextual RuntimeError, but link it to the
|
|
354
|
+
# original error that was caught. This preserves the original
|
|
355
|
+
# traceback and error type for easier debugging.
|
|
356
|
+
raise RuntimeError(
|
|
357
|
+
f"Error calling function '{self.key}'. No version succeeded. "
|
|
358
|
+
f"See above for the error from the last version that was tried."
|
|
359
|
+
) from last_execution_error
|
|
360
|
+
else:
|
|
361
|
+
# We got here without ever calling an overload.func
|
|
362
|
+
raise RuntimeError(f"Error calling function '{self.key}', no overload found for arguments {args}")
|
|
341
363
|
|
|
342
364
|
# user-defined function with no overloads
|
|
343
365
|
if self.func is None:
|
|
@@ -358,9 +380,6 @@ class Function:
|
|
|
358
380
|
if warp.types.is_array(v) or v in complex_type_hints:
|
|
359
381
|
return False
|
|
360
382
|
|
|
361
|
-
if type(self.value_type) in sequence_types:
|
|
362
|
-
return False
|
|
363
|
-
|
|
364
383
|
return True
|
|
365
384
|
|
|
366
385
|
def mangle(self) -> str:
|
|
@@ -404,8 +423,12 @@ class Function:
|
|
|
404
423
|
else:
|
|
405
424
|
self.user_overloads[sig] = f
|
|
406
425
|
|
|
407
|
-
def get_overload(self, arg_types:
|
|
408
|
-
|
|
426
|
+
def get_overload(self, arg_types: list[type], kwarg_types: Mapping[str, type]) -> Function | None:
|
|
427
|
+
if self.is_builtin():
|
|
428
|
+
for f in self.overloads:
|
|
429
|
+
if warp.codegen.func_match_args(f, arg_types, kwarg_types):
|
|
430
|
+
return f
|
|
431
|
+
return None
|
|
409
432
|
|
|
410
433
|
for f in self.user_overloads.values():
|
|
411
434
|
if warp.codegen.func_match_args(f, arg_types, kwarg_types):
|
|
@@ -439,7 +462,7 @@ class Function:
|
|
|
439
462
|
overload_annotations[k] = warp.codegen.strip_reference(warp.codegen.get_arg_type(d))
|
|
440
463
|
|
|
441
464
|
ovl = shallowcopy(f)
|
|
442
|
-
ovl.adj = warp.codegen.Adjoint(f.func, overload_annotations)
|
|
465
|
+
ovl.adj = warp.codegen.Adjoint(f.func, overload_annotations, source=f.adj.source)
|
|
443
466
|
ovl.input_types = overload_annotations
|
|
444
467
|
ovl.value_func = None
|
|
445
468
|
ovl.generic_parent = f
|
|
@@ -475,11 +498,25 @@ def get_builtin_type(return_type: type) -> type:
|
|
|
475
498
|
return return_type
|
|
476
499
|
|
|
477
500
|
|
|
478
|
-
def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
|
|
501
|
+
def extract_return_value(value_type: type, value_ctype: type, ret: Any) -> Any:
|
|
502
|
+
if issubclass(value_ctype, ctypes.Array) or issubclass(value_ctype, ctypes.Structure):
|
|
503
|
+
# return vector types as ctypes
|
|
504
|
+
return ret
|
|
505
|
+
|
|
506
|
+
if value_type is warp.types.float16:
|
|
507
|
+
return warp.types.half_bits_to_float(ret.value)
|
|
508
|
+
|
|
509
|
+
return ret.value
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def call_builtin(func: Function, params: tuple) -> tuple[bool, Any]:
|
|
479
513
|
uses_non_warp_array_type = False
|
|
480
514
|
|
|
481
515
|
init()
|
|
482
516
|
|
|
517
|
+
if func.mangled_name is None:
|
|
518
|
+
return (False, None)
|
|
519
|
+
|
|
483
520
|
# Retrieve the built-in function from Warp's dll.
|
|
484
521
|
c_func = getattr(warp.context.runtime.core, func.mangled_name)
|
|
485
522
|
|
|
@@ -489,6 +526,8 @@ def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
|
|
|
489
526
|
else:
|
|
490
527
|
func_args = func.input_types
|
|
491
528
|
|
|
529
|
+
value_type = func.value_func(None, None)
|
|
530
|
+
|
|
492
531
|
# Try gathering the parameters that the function expects and pack them
|
|
493
532
|
# into their corresponding C types.
|
|
494
533
|
c_params = []
|
|
@@ -604,9 +643,9 @@ def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
|
|
|
604
643
|
|
|
605
644
|
if not (
|
|
606
645
|
isinstance(param, arg_type)
|
|
607
|
-
or (type(param) is float and arg_type is warp.types.float32)
|
|
608
|
-
or (type(param) is int and arg_type is warp.types.int32)
|
|
609
|
-
or (type(param) is bool and arg_type is warp.types.bool)
|
|
646
|
+
or (type(param) is float and arg_type is warp.types.float32)
|
|
647
|
+
or (type(param) is int and arg_type is warp.types.int32)
|
|
648
|
+
or (type(param) is bool and arg_type is warp.types.bool)
|
|
610
649
|
or warp.types.np_dtype_to_warp_type.get(getattr(param, "dtype", None)) is arg_type
|
|
611
650
|
):
|
|
612
651
|
return (False, None)
|
|
@@ -620,25 +659,18 @@ def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
|
|
|
620
659
|
else:
|
|
621
660
|
c_params.append(arg_type._type_(param))
|
|
622
661
|
|
|
623
|
-
#
|
|
624
|
-
value_type = func.value_func(
|
|
662
|
+
# Retrieve the return type.
|
|
663
|
+
value_type = func.value_func(func_args, None)
|
|
625
664
|
|
|
626
|
-
if value_type
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
value_ctype = value_type
|
|
634
|
-
else:
|
|
635
|
-
# scalar type
|
|
636
|
-
value_ctype = value_type._type_
|
|
665
|
+
if value_type is not None:
|
|
666
|
+
if not isinstance(value_type, Sequence):
|
|
667
|
+
value_type = (value_type,)
|
|
668
|
+
|
|
669
|
+
value_ctype = tuple(warp.types.type_ctype(x) for x in value_type)
|
|
670
|
+
ret = tuple(x() for x in value_ctype)
|
|
671
|
+
ret_addr = tuple(ctypes.c_void_p(ctypes.addressof(x)) for x in ret)
|
|
637
672
|
|
|
638
|
-
|
|
639
|
-
ret = value_ctype()
|
|
640
|
-
ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
|
|
641
|
-
c_params.append(ret_addr)
|
|
673
|
+
c_params.extend(ret_addr)
|
|
642
674
|
|
|
643
675
|
# Call the built-in function from Warp's dll.
|
|
644
676
|
c_func(*c_params)
|
|
@@ -653,17 +685,14 @@ def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
|
|
|
653
685
|
stacklevel=3,
|
|
654
686
|
)
|
|
655
687
|
|
|
656
|
-
if
|
|
657
|
-
|
|
658
|
-
return (True, ret)
|
|
688
|
+
if value_type is None:
|
|
689
|
+
return (True, None)
|
|
659
690
|
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
value = ret.value
|
|
691
|
+
return_value = tuple(extract_return_value(x, y, z) for x, y, z in zip(value_type, value_ctype, ret))
|
|
692
|
+
if len(return_value) == 1:
|
|
693
|
+
return_value = return_value[0]
|
|
664
694
|
|
|
665
|
-
|
|
666
|
-
return (True, value)
|
|
695
|
+
return (True, return_value)
|
|
667
696
|
|
|
668
697
|
|
|
669
698
|
class KernelHooks:
|
|
@@ -677,7 +706,7 @@ class KernelHooks:
|
|
|
677
706
|
|
|
678
707
|
# caches source and compiled entry points for a kernel (will be populated after module loads)
|
|
679
708
|
class Kernel:
|
|
680
|
-
def __init__(self, func, key=None, module=None, options=None, code_transformers=None):
|
|
709
|
+
def __init__(self, func, key=None, module=None, options=None, code_transformers=None, source=None):
|
|
681
710
|
self.func = func
|
|
682
711
|
|
|
683
712
|
if module is None:
|
|
@@ -695,7 +724,7 @@ class Kernel:
|
|
|
695
724
|
if code_transformers is None:
|
|
696
725
|
code_transformers = []
|
|
697
726
|
|
|
698
|
-
self.adj = warp.codegen.Adjoint(func, transformers=code_transformers)
|
|
727
|
+
self.adj = warp.codegen.Adjoint(func, transformers=code_transformers, source=source)
|
|
699
728
|
|
|
700
729
|
# check if generic
|
|
701
730
|
self.is_generic = False
|
|
@@ -762,7 +791,7 @@ class Kernel:
|
|
|
762
791
|
|
|
763
792
|
# instantiate this kernel with the given argument types
|
|
764
793
|
ovl = shallowcopy(self)
|
|
765
|
-
ovl.adj = warp.codegen.Adjoint(self.func, overload_annotations)
|
|
794
|
+
ovl.adj = warp.codegen.Adjoint(self.func, overload_annotations, source=self.adj.source)
|
|
766
795
|
ovl.is_generic = False
|
|
767
796
|
ovl.overloads = {}
|
|
768
797
|
ovl.sig = sig
|
|
@@ -798,7 +827,7 @@ class Kernel:
|
|
|
798
827
|
|
|
799
828
|
|
|
800
829
|
# decorator to register function, @func
|
|
801
|
-
def func(f: Optional[Callable] = None, *, name: Optional[str] = None):
|
|
830
|
+
def func(f: Callable | None = None, *, name: str | None = None):
|
|
802
831
|
def wrapper(f, *args, **kwargs):
|
|
803
832
|
if name is None:
|
|
804
833
|
key = warp.codegen.make_full_qualified_name(f)
|
|
@@ -831,7 +860,7 @@ def func(f: Optional[Callable] = None, *, name: Optional[str] = None):
|
|
|
831
860
|
return wrapper(f)
|
|
832
861
|
|
|
833
862
|
|
|
834
|
-
def func_native(snippet: str, adj_snippet:
|
|
863
|
+
def func_native(snippet: str, adj_snippet: str | None = None, replay_snippet: str | None = None):
|
|
835
864
|
"""
|
|
836
865
|
Decorator to register native code snippet, @func_native
|
|
837
866
|
"""
|
|
@@ -1015,10 +1044,10 @@ def func_replay(forward_fn):
|
|
|
1015
1044
|
|
|
1016
1045
|
|
|
1017
1046
|
def kernel(
|
|
1018
|
-
f:
|
|
1047
|
+
f: Callable | None = None,
|
|
1019
1048
|
*,
|
|
1020
|
-
enable_backward:
|
|
1021
|
-
module:
|
|
1049
|
+
enable_backward: bool | None = None,
|
|
1050
|
+
module: Module | Literal["unique"] | None = None,
|
|
1022
1051
|
):
|
|
1023
1052
|
"""
|
|
1024
1053
|
Decorator to register a Warp kernel from a Python function.
|
|
@@ -1181,7 +1210,7 @@ def overload(kernel, arg_types=Union[None, Dict[str, Any], List[Any]]):
|
|
|
1181
1210
|
|
|
1182
1211
|
|
|
1183
1212
|
# native functions that are part of the Warp API
|
|
1184
|
-
builtin_functions:
|
|
1213
|
+
builtin_functions: dict[str, Function] = {}
|
|
1185
1214
|
|
|
1186
1215
|
|
|
1187
1216
|
def get_generic_vtypes():
|
|
@@ -1204,13 +1233,13 @@ scalar_types.update({x: x._wp_scalar_type_ for x in warp.types.vector_types})
|
|
|
1204
1233
|
|
|
1205
1234
|
def add_builtin(
|
|
1206
1235
|
key: str,
|
|
1207
|
-
input_types:
|
|
1208
|
-
constraint:
|
|
1209
|
-
value_type:
|
|
1210
|
-
value_func:
|
|
1211
|
-
export_func:
|
|
1212
|
-
dispatch_func:
|
|
1213
|
-
lto_dispatch_func:
|
|
1236
|
+
input_types: dict[str, type | TypeVar] | None = None,
|
|
1237
|
+
constraint: Callable[[Mapping[str, type]], bool] | None = None,
|
|
1238
|
+
value_type: type | None = None,
|
|
1239
|
+
value_func: Callable | None = None,
|
|
1240
|
+
export_func: Callable | None = None,
|
|
1241
|
+
dispatch_func: Callable | None = None,
|
|
1242
|
+
lto_dispatch_func: Callable | None = None,
|
|
1214
1243
|
doc: str = "",
|
|
1215
1244
|
namespace: str = "wp::",
|
|
1216
1245
|
variadic: bool = False,
|
|
@@ -1220,8 +1249,8 @@ def add_builtin(
|
|
|
1220
1249
|
hidden: bool = False,
|
|
1221
1250
|
skip_replay: bool = False,
|
|
1222
1251
|
missing_grad: bool = False,
|
|
1223
|
-
native_func:
|
|
1224
|
-
defaults:
|
|
1252
|
+
native_func: str | None = None,
|
|
1253
|
+
defaults: dict[str, Any] | None = None,
|
|
1225
1254
|
require_original_output_arg: bool = False,
|
|
1226
1255
|
):
|
|
1227
1256
|
"""Main entry point to register a new built-in function.
|
|
@@ -1371,18 +1400,13 @@ def add_builtin(
|
|
|
1371
1400
|
|
|
1372
1401
|
return_type = value_func(concrete_arg_types, None)
|
|
1373
1402
|
|
|
1374
|
-
|
|
1375
|
-
|
|
1376
|
-
|
|
1377
|
-
|
|
1378
|
-
|
|
1379
|
-
|
|
1380
|
-
|
|
1381
|
-
and x._wp_type_params_ == return_type._wp_type_params_
|
|
1382
|
-
)
|
|
1383
|
-
if not return_type_match:
|
|
1384
|
-
continue
|
|
1385
|
-
return_type = return_type_match[0]
|
|
1403
|
+
try:
|
|
1404
|
+
if isinstance(return_type, Sequence):
|
|
1405
|
+
return_type = tuple(get_builtin_type(x) for x in return_type)
|
|
1406
|
+
else:
|
|
1407
|
+
return_type = get_builtin_type(return_type)
|
|
1408
|
+
except RuntimeError:
|
|
1409
|
+
continue
|
|
1386
1410
|
|
|
1387
1411
|
# finally we can generate a function call for these concrete types:
|
|
1388
1412
|
add_builtin(
|
|
@@ -1485,7 +1509,7 @@ def register_api_function(
|
|
|
1485
1509
|
|
|
1486
1510
|
|
|
1487
1511
|
# global dictionary of modules
|
|
1488
|
-
user_modules:
|
|
1512
|
+
user_modules: dict[str, Module] = {}
|
|
1489
1513
|
|
|
1490
1514
|
|
|
1491
1515
|
def get_module(name: str) -> Module:
|
|
@@ -1608,7 +1632,7 @@ class ModuleHasher:
|
|
|
1608
1632
|
ch.update(bytes(func.key, "utf-8"))
|
|
1609
1633
|
|
|
1610
1634
|
# include all concrete and generic overloads
|
|
1611
|
-
overloads:
|
|
1635
|
+
overloads: dict[str, Function] = {**func.user_overloads, **func.user_templates}
|
|
1612
1636
|
for sig in sorted(overloads.keys()):
|
|
1613
1637
|
ovl = overloads[sig]
|
|
1614
1638
|
|
|
@@ -1668,7 +1692,7 @@ class ModuleHasher:
|
|
|
1668
1692
|
ch.update(bytes(name, "utf-8"))
|
|
1669
1693
|
ch.update(self.get_constant_bytes(value))
|
|
1670
1694
|
|
|
1671
|
-
# hash wp.static() expressions
|
|
1695
|
+
# hash wp.static() expressions
|
|
1672
1696
|
for k, v in adj.static_expressions.items():
|
|
1673
1697
|
ch.update(bytes(k, "utf-8"))
|
|
1674
1698
|
if isinstance(v, Function):
|
|
@@ -1857,7 +1881,7 @@ class ModuleBuilder:
|
|
|
1857
1881
|
# the original Modules get reloaded.
|
|
1858
1882
|
class ModuleExec:
|
|
1859
1883
|
def __new__(cls, *args, **kwargs):
|
|
1860
|
-
instance = super(
|
|
1884
|
+
instance = super().__new__(cls)
|
|
1861
1885
|
instance.handle = None
|
|
1862
1886
|
return instance
|
|
1863
1887
|
|
|
@@ -1952,7 +1976,7 @@ class ModuleExec:
|
|
|
1952
1976
|
# creates a hash of the function to use for checking
|
|
1953
1977
|
# build cache
|
|
1954
1978
|
class Module:
|
|
1955
|
-
def __init__(self, name:
|
|
1979
|
+
def __init__(self, name: str | None, loader=None):
|
|
1956
1980
|
self.name = name if name is not None else "None"
|
|
1957
1981
|
|
|
1958
1982
|
self.loader = loader
|
|
@@ -1987,6 +2011,9 @@ class Module:
|
|
|
1987
2011
|
# is retained and later reloaded with the same hash.
|
|
1988
2012
|
self.cpu_exec_id = 0
|
|
1989
2013
|
|
|
2014
|
+
# Indicates whether the module has functions or kernels with unresolved static expressions.
|
|
2015
|
+
self.has_unresolved_static_expressions = False
|
|
2016
|
+
|
|
1990
2017
|
self.options = {
|
|
1991
2018
|
"max_unroll": warp.config.max_unroll,
|
|
1992
2019
|
"enable_backward": warp.config.enable_backward,
|
|
@@ -1994,8 +2021,9 @@ class Module:
|
|
|
1994
2021
|
"fuse_fp": True,
|
|
1995
2022
|
"lineinfo": warp.config.lineinfo,
|
|
1996
2023
|
"cuda_output": None, # supported values: "ptx", "cubin", or None (automatic)
|
|
1997
|
-
"mode":
|
|
2024
|
+
"mode": None,
|
|
1998
2025
|
"block_dim": 256,
|
|
2026
|
+
"compile_time_trace": warp.config.compile_time_trace,
|
|
1999
2027
|
}
|
|
2000
2028
|
|
|
2001
2029
|
# Module dependencies are determined by scanning each function
|
|
@@ -2022,6 +2050,10 @@ class Module:
|
|
|
2022
2050
|
# track all kernel objects, even if they are duplicates
|
|
2023
2051
|
self._live_kernels.add(kernel)
|
|
2024
2052
|
|
|
2053
|
+
# Check for unresolved static expressions in the kernel.
|
|
2054
|
+
if kernel.adj.has_unresolved_static_expressions:
|
|
2055
|
+
self.has_unresolved_static_expressions = True
|
|
2056
|
+
|
|
2025
2057
|
self.find_references(kernel.adj)
|
|
2026
2058
|
|
|
2027
2059
|
# for a reload of module on next launch
|
|
@@ -2081,6 +2113,10 @@ class Module:
|
|
|
2081
2113
|
del func_existing.user_overloads[k]
|
|
2082
2114
|
func_existing.add_overload(func)
|
|
2083
2115
|
|
|
2116
|
+
# Check for unresolved static expressions in the function.
|
|
2117
|
+
if func.adj.has_unresolved_static_expressions:
|
|
2118
|
+
self.has_unresolved_static_expressions = True
|
|
2119
|
+
|
|
2084
2120
|
self.find_references(func.adj)
|
|
2085
2121
|
|
|
2086
2122
|
# for a reload of module on next launch
|
|
@@ -2140,7 +2176,7 @@ class Module:
|
|
|
2140
2176
|
self.hashers[block_dim] = ModuleHasher(self)
|
|
2141
2177
|
return self.hashers[block_dim].get_module_hash()
|
|
2142
2178
|
|
|
2143
|
-
def load(self, device, block_dim=None) -> ModuleExec:
|
|
2179
|
+
def load(self, device, block_dim=None) -> ModuleExec | None:
|
|
2144
2180
|
device = runtime.get_device(device)
|
|
2145
2181
|
|
|
2146
2182
|
# update module options if launching with a new block dim
|
|
@@ -2149,6 +2185,20 @@ class Module:
|
|
|
2149
2185
|
|
|
2150
2186
|
active_block_dim = self.options["block_dim"]
|
|
2151
2187
|
|
|
2188
|
+
if self.has_unresolved_static_expressions:
|
|
2189
|
+
# The module hash currently does not account for unresolved static expressions
|
|
2190
|
+
# (only static expressions evaluated at declaration time so far).
|
|
2191
|
+
# We need to generate the code for the functions and kernels that have
|
|
2192
|
+
# unresolved static expressions and then compute the module hash again.
|
|
2193
|
+
builder_options = {
|
|
2194
|
+
**self.options,
|
|
2195
|
+
"output_arch": None,
|
|
2196
|
+
}
|
|
2197
|
+
# build functions, kernels to resolve static expressions
|
|
2198
|
+
_ = ModuleBuilder(self, builder_options)
|
|
2199
|
+
|
|
2200
|
+
self.has_unresolved_static_expressions = False
|
|
2201
|
+
|
|
2152
2202
|
# compute the hash if needed
|
|
2153
2203
|
if active_block_dim not in self.hashers:
|
|
2154
2204
|
self.hashers[active_block_dim] = ModuleHasher(self)
|
|
@@ -2222,7 +2272,7 @@ class Module:
|
|
|
2222
2272
|
):
|
|
2223
2273
|
builder_options = {
|
|
2224
2274
|
**self.options,
|
|
2225
|
-
# Some of the
|
|
2275
|
+
# Some of the tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
|
|
2226
2276
|
"output_arch": output_arch,
|
|
2227
2277
|
}
|
|
2228
2278
|
builder = ModuleBuilder(self, builder_options, hasher=self.hashers[active_block_dim])
|
|
@@ -2237,6 +2287,8 @@ class Module:
|
|
|
2237
2287
|
|
|
2238
2288
|
module_load_timer.extra_msg = " (compiled)" # For wp.ScopedTimer informational purposes
|
|
2239
2289
|
|
|
2290
|
+
mode = self.options["mode"] if self.options["mode"] is not None else warp.config.mode
|
|
2291
|
+
|
|
2240
2292
|
# build CPU
|
|
2241
2293
|
if device.is_cpu:
|
|
2242
2294
|
# build
|
|
@@ -2256,7 +2308,7 @@ class Module:
|
|
|
2256
2308
|
warp.build.build_cpu(
|
|
2257
2309
|
output_path,
|
|
2258
2310
|
source_code_path,
|
|
2259
|
-
mode=
|
|
2311
|
+
mode=mode,
|
|
2260
2312
|
fast_math=self.options["fast_math"],
|
|
2261
2313
|
verify_fp=warp.config.verify_fp,
|
|
2262
2314
|
fuse_fp=self.options["fuse_fp"],
|
|
@@ -2286,11 +2338,12 @@ class Module:
|
|
|
2286
2338
|
source_code_path,
|
|
2287
2339
|
output_arch,
|
|
2288
2340
|
output_path,
|
|
2289
|
-
config=
|
|
2341
|
+
config=mode,
|
|
2290
2342
|
verify_fp=warp.config.verify_fp,
|
|
2291
2343
|
fast_math=self.options["fast_math"],
|
|
2292
2344
|
fuse_fp=self.options["fuse_fp"],
|
|
2293
2345
|
lineinfo=self.options["lineinfo"],
|
|
2346
|
+
compile_time_trace=self.options["compile_time_trace"],
|
|
2294
2347
|
ltoirs=builder.ltoirs.values(),
|
|
2295
2348
|
fatbins=builder.fatbins.values(),
|
|
2296
2349
|
)
|
|
@@ -2343,7 +2396,7 @@ class Module:
|
|
|
2343
2396
|
# Load CPU or CUDA binary
|
|
2344
2397
|
|
|
2345
2398
|
meta_path = os.path.join(module_dir, f"{module_name_short}.meta")
|
|
2346
|
-
with open(meta_path
|
|
2399
|
+
with open(meta_path) as meta_file:
|
|
2347
2400
|
meta = json.load(meta_file)
|
|
2348
2401
|
|
|
2349
2402
|
if device.is_cpu:
|
|
@@ -2406,7 +2459,7 @@ class CpuDefaultAllocator:
|
|
|
2406
2459
|
def alloc(self, size_in_bytes):
    """Allocate a buffer through ``runtime.core.alloc_host`` and return its pointer.

    Args:
        size_in_bytes: Number of bytes to request.

    Returns:
        The pointer value returned by ``runtime.core.alloc_host``.

    Raises:
        RuntimeError: If the returned pointer is falsy (allocation failed).
    """
    address = runtime.core.alloc_host(size_in_bytes)
    if address:
        return address

    # A falsy pointer signals that the allocation did not succeed.
    raise RuntimeError(f"Failed to allocate {size_in_bytes} bytes on device 'cpu'")
|
|
2411
2464
|
|
|
2412
2465
|
def free(self, ptr, size_in_bytes):
|
|
@@ -2510,12 +2563,12 @@ class Event:
|
|
|
2510
2563
|
|
|
2511
2564
|
def __new__(cls, *args, **kwargs):
    """Creates a new event instance."""
    # ``owner`` is assigned here so the attribute already exists on the
    # object before ``__init__`` runs.
    event = super().__new__(cls)
    event.owner = False
    return event
|
|
2516
2569
|
|
|
2517
2570
|
def __init__(
|
|
2518
|
-
self, device:
|
|
2571
|
+
self, device: Devicelike = None, cuda_event=None, enable_timing: bool = False, interprocess: bool = False
|
|
2519
2572
|
):
|
|
2520
2573
|
"""Initializes the event on a CUDA device.
|
|
2521
2574
|
|
|
@@ -2611,12 +2664,12 @@ class Event:
|
|
|
2611
2664
|
|
|
2612
2665
|
class Stream:
|
|
2613
2666
|
def __new__(cls, *args, **kwargs):
    """Create a bare stream object with its default attributes in place."""
    # Both attributes are assigned here so they already exist on the
    # object before ``__init__`` runs.
    stream = super().__new__(cls)
    stream.cuda_stream = None
    stream.owner = False
    return stream
|
|
2618
2671
|
|
|
2619
|
-
def __init__(self, device:
|
|
2672
|
+
def __init__(self, device: Device | str | None = None, priority: int = 0, **kwargs):
|
|
2620
2673
|
"""Initialize the stream on a device with an optional specified priority.
|
|
2621
2674
|
|
|
2622
2675
|
Args:
|
|
@@ -2682,7 +2735,7 @@ class Stream:
|
|
|
2682
2735
|
self._cached_event = Event(self.device)
|
|
2683
2736
|
return self._cached_event
|
|
2684
2737
|
|
|
2685
|
-
def record_event(self, event:
|
|
2738
|
+
def record_event(self, event: Event | None = None) -> Event:
|
|
2686
2739
|
"""Record an event onto the stream.
|
|
2687
2740
|
|
|
2688
2741
|
Args:
|
|
@@ -2711,7 +2764,7 @@ class Stream:
|
|
|
2711
2764
|
"""
|
|
2712
2765
|
runtime.core.cuda_stream_wait_event(self.cuda_stream, event.cuda_event)
|
|
2713
2766
|
|
|
2714
|
-
def wait_stream(self, other_stream:
|
|
2767
|
+
def wait_stream(self, other_stream: Stream, event: Event | None = None):
|
|
2715
2768
|
"""Records an event on `other_stream` and makes this stream wait on it.
|
|
2716
2769
|
|
|
2717
2770
|
All work added to this stream after this function has been called will
|
|
@@ -2765,6 +2818,8 @@ class Device:
|
|
|
2765
2818
|
or ``"CPU"`` if the processor name cannot be determined.
|
|
2766
2819
|
arch (int): The compute capability version number calculated as ``10 * major + minor``.
|
|
2767
2820
|
``0`` for CPU devices.
|
|
2821
|
+
sm_count (int): The number of streaming multiprocessors on the CUDA device.
|
|
2822
|
+
``0`` for CPU devices.
|
|
2768
2823
|
is_uva (bool): Indicates whether the device supports unified addressing.
|
|
2769
2824
|
``False`` for CPU devices.
|
|
2770
2825
|
is_cubin_supported (bool): Indicates whether Warp's version of NVRTC can directly
|
|
@@ -2810,6 +2865,7 @@ class Device:
|
|
|
2810
2865
|
# CPU device
|
|
2811
2866
|
self.name = platform.processor() or "CPU"
|
|
2812
2867
|
self.arch = 0
|
|
2868
|
+
self.sm_count = 0
|
|
2813
2869
|
self.is_uva = False
|
|
2814
2870
|
self.is_mempool_supported = False
|
|
2815
2871
|
self.is_mempool_enabled = False
|
|
@@ -2829,6 +2885,7 @@ class Device:
|
|
|
2829
2885
|
# CUDA device
|
|
2830
2886
|
self.name = runtime.core.cuda_device_get_name(ordinal).decode()
|
|
2831
2887
|
self.arch = runtime.core.cuda_device_get_arch(ordinal)
|
|
2888
|
+
self.sm_count = runtime.core.cuda_device_get_sm_count(ordinal)
|
|
2832
2889
|
self.is_uva = runtime.core.cuda_device_is_uva(ordinal) > 0
|
|
2833
2890
|
self.is_mempool_supported = runtime.core.cuda_device_is_mempool_supported(ordinal) > 0
|
|
2834
2891
|
if platform.system() == "Linux":
|
|
@@ -3070,16 +3127,23 @@ class Graph:
|
|
|
3070
3127
|
def __init__(self, device: Device, capture_id: int):
    """Set up the bookkeeping state of a graph object.

    Args:
        device: Device stored on the graph as ``self.device``.
        capture_id: Capture identifier stored as ``self.capture_id``.
    """
    self.device = device
    self.capture_id = capture_id

    # Graph handles start out unset.
    self.graph: ctypes.c_void_p | None = None
    self.graph_exec: ctypes.c_void_p | None = None

    # Executable modules retained by this graph (initially none).
    self.module_execs: set[ModuleExec] = set()

    # Track if there are conditional nodes in the graph since they are not allowed in child graphs
    self.has_conditional = False
|
|
3075
3137
|
|
|
3076
3138
|
def __del__(self):
    """Destroy the graph handle, and the executable graph if one is set."""
    # Bail out when the attributes read below are missing or when no
    # graph handle was stored.
    if not (hasattr(self, "graph") and hasattr(self, "device") and self.graph):
        return

    owner = self.device

    # use CUDA context guard to avoid side effects during garbage collection
    with owner.context_guard:
        runtime.core.cuda_graph_destroy(owner.context, self.graph)

        # The executable graph is optional: destroy it only when it is set.
        exec_handle = getattr(self, "graph_exec", None)
        if exec_handle is not None:
            runtime.core.cuda_graph_exec_destroy(owner.context, exec_handle)
|
|
3083
3147
|
|
|
3084
3148
|
# retain executable CUDA modules used by this graph, which prevents them from being unloaded
|
|
3085
3149
|
def retain_module_exec(self, module_exec: ModuleExec):
|
|
@@ -3088,8 +3152,6 @@ class Graph:
|
|
|
3088
3152
|
|
|
3089
3153
|
class Runtime:
|
|
3090
3154
|
def __init__(self):
|
|
3091
|
-
if sys.version_info < (3, 8):
|
|
3092
|
-
raise RuntimeError("Warp requires Python 3.8 as a minimum")
|
|
3093
3155
|
if sys.version_info < (3, 9):
|
|
3094
3156
|
warp.utils.warn(f"Python 3.9 or newer is recommended for running Warp, detected {sys.version_info}")
|
|
3095
3157
|
|
|
@@ -3535,44 +3597,40 @@ class Runtime:
|
|
|
3535
3597
|
self.core.volume_get_blind_data_info.restype = ctypes.c_char_p
|
|
3536
3598
|
|
|
3537
3599
|
bsr_matrix_from_triplets_argtypes = [
|
|
3538
|
-
ctypes.c_int, #
|
|
3539
|
-
ctypes.c_int, #
|
|
3600
|
+
ctypes.c_int, # block_size
|
|
3601
|
+
ctypes.c_int, # scalar size in bytes
|
|
3540
3602
|
ctypes.c_int, # row_count
|
|
3541
|
-
ctypes.c_int, #
|
|
3603
|
+
ctypes.c_int, # col_count
|
|
3604
|
+
ctypes.c_int, # nnz_upper_bound
|
|
3605
|
+
ctypes.POINTER(ctypes.c_int), # tpl_nnz
|
|
3542
3606
|
ctypes.POINTER(ctypes.c_int), # tpl_rows
|
|
3543
3607
|
ctypes.POINTER(ctypes.c_int), # tpl_cols
|
|
3544
3608
|
ctypes.c_void_p, # tpl_values
|
|
3545
|
-
ctypes.
|
|
3609
|
+
ctypes.c_uint64, # zero_value_mask
|
|
3546
3610
|
ctypes.c_bool, # masked
|
|
3547
3611
|
ctypes.POINTER(ctypes.c_int), # bsr_offsets
|
|
3548
3612
|
ctypes.POINTER(ctypes.c_int), # bsr_columns
|
|
3549
|
-
ctypes.
|
|
3613
|
+
ctypes.POINTER(ctypes.c_int), # prefix sum of block count to sum for each bsr block
|
|
3614
|
+
ctypes.POINTER(ctypes.c_int), # indices to ptriplet blocks to sum for each bsr block
|
|
3550
3615
|
ctypes.POINTER(ctypes.c_int), # bsr_nnz
|
|
3551
3616
|
ctypes.c_void_p, # bsr_nnz_event
|
|
3552
3617
|
]
|
|
3553
3618
|
|
|
3554
|
-
self.core.
|
|
3555
|
-
self.core.
|
|
3556
|
-
self.core.bsr_matrix_from_triplets_float_device.argtypes = bsr_matrix_from_triplets_argtypes
|
|
3557
|
-
self.core.bsr_matrix_from_triplets_double_device.argtypes = bsr_matrix_from_triplets_argtypes
|
|
3619
|
+
self.core.bsr_matrix_from_triplets_host.argtypes = bsr_matrix_from_triplets_argtypes
|
|
3620
|
+
self.core.bsr_matrix_from_triplets_device.argtypes = bsr_matrix_from_triplets_argtypes
|
|
3558
3621
|
|
|
3559
3622
|
bsr_transpose_argtypes = [
|
|
3560
|
-
ctypes.c_int, # rows_per_bock
|
|
3561
|
-
ctypes.c_int, # cols_per_blocks
|
|
3562
3623
|
ctypes.c_int, # row_count
|
|
3563
3624
|
ctypes.c_int, # col count
|
|
3564
3625
|
ctypes.c_int, # nnz
|
|
3565
3626
|
ctypes.POINTER(ctypes.c_int), # transposed_bsr_offsets
|
|
3566
3627
|
ctypes.POINTER(ctypes.c_int), # transposed_bsr_columns
|
|
3567
|
-
ctypes.c_void_p, # bsr_values
|
|
3568
3628
|
ctypes.POINTER(ctypes.c_int), # transposed_bsr_offsets
|
|
3569
3629
|
ctypes.POINTER(ctypes.c_int), # transposed_bsr_columns
|
|
3570
|
-
ctypes.
|
|
3630
|
+
ctypes.POINTER(ctypes.c_int), # src to dest block map
|
|
3571
3631
|
]
|
|
3572
|
-
self.core.
|
|
3573
|
-
self.core.
|
|
3574
|
-
self.core.bsr_transpose_float_device.argtypes = bsr_transpose_argtypes
|
|
3575
|
-
self.core.bsr_transpose_double_device.argtypes = bsr_transpose_argtypes
|
|
3632
|
+
self.core.bsr_transpose_host.argtypes = bsr_transpose_argtypes
|
|
3633
|
+
self.core.bsr_transpose_device.argtypes = bsr_transpose_argtypes
|
|
3576
3634
|
|
|
3577
3635
|
self.core.is_cuda_enabled.argtypes = None
|
|
3578
3636
|
self.core.is_cuda_enabled.restype = ctypes.c_int
|
|
@@ -3601,6 +3659,8 @@ class Runtime:
|
|
|
3601
3659
|
self.core.cuda_device_get_name.restype = ctypes.c_char_p
|
|
3602
3660
|
self.core.cuda_device_get_arch.argtypes = [ctypes.c_int]
|
|
3603
3661
|
self.core.cuda_device_get_arch.restype = ctypes.c_int
|
|
3662
|
+
self.core.cuda_device_get_sm_count.argtypes = [ctypes.c_int]
|
|
3663
|
+
self.core.cuda_device_get_sm_count.restype = ctypes.c_int
|
|
3604
3664
|
self.core.cuda_device_is_uva.argtypes = [ctypes.c_int]
|
|
3605
3665
|
self.core.cuda_device_is_uva.restype = ctypes.c_int
|
|
3606
3666
|
self.core.cuda_device_is_mempool_supported.argtypes = [ctypes.c_int]
|
|
@@ -3724,11 +3784,73 @@ class Runtime:
|
|
|
3724
3784
|
ctypes.POINTER(ctypes.c_void_p),
|
|
3725
3785
|
]
|
|
3726
3786
|
self.core.cuda_graph_end_capture.restype = ctypes.c_bool
|
|
3787
|
+
|
|
3788
|
+
self.core.cuda_graph_create_exec.argtypes = [
|
|
3789
|
+
ctypes.c_void_p,
|
|
3790
|
+
ctypes.c_void_p,
|
|
3791
|
+
ctypes.c_void_p,
|
|
3792
|
+
ctypes.POINTER(ctypes.c_void_p),
|
|
3793
|
+
]
|
|
3794
|
+
self.core.cuda_graph_create_exec.restype = ctypes.c_bool
|
|
3795
|
+
|
|
3796
|
+
self.core.capture_debug_dot_print.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_uint32]
|
|
3797
|
+
self.core.capture_debug_dot_print.restype = ctypes.c_bool
|
|
3798
|
+
|
|
3727
3799
|
self.core.cuda_graph_launch.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
|
|
3728
3800
|
self.core.cuda_graph_launch.restype = ctypes.c_bool
|
|
3801
|
+
self.core.cuda_graph_exec_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
|
|
3802
|
+
self.core.cuda_graph_exec_destroy.restype = ctypes.c_bool
|
|
3803
|
+
|
|
3729
3804
|
self.core.cuda_graph_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
|
|
3730
3805
|
self.core.cuda_graph_destroy.restype = ctypes.c_bool
|
|
3731
3806
|
|
|
3807
|
+
self.core.cuda_graph_insert_if_else.argtypes = [
|
|
3808
|
+
ctypes.c_void_p,
|
|
3809
|
+
ctypes.c_void_p,
|
|
3810
|
+
ctypes.POINTER(ctypes.c_int),
|
|
3811
|
+
ctypes.POINTER(ctypes.c_void_p),
|
|
3812
|
+
ctypes.POINTER(ctypes.c_void_p),
|
|
3813
|
+
]
|
|
3814
|
+
self.core.cuda_graph_insert_if_else.restype = ctypes.c_bool
|
|
3815
|
+
|
|
3816
|
+
self.core.cuda_graph_insert_while.argtypes = [
|
|
3817
|
+
ctypes.c_void_p,
|
|
3818
|
+
ctypes.c_void_p,
|
|
3819
|
+
ctypes.POINTER(ctypes.c_int),
|
|
3820
|
+
ctypes.POINTER(ctypes.c_void_p),
|
|
3821
|
+
ctypes.POINTER(ctypes.c_uint64),
|
|
3822
|
+
]
|
|
3823
|
+
self.core.cuda_graph_insert_while.restype = ctypes.c_bool
|
|
3824
|
+
|
|
3825
|
+
self.core.cuda_graph_set_condition.argtypes = [
|
|
3826
|
+
ctypes.c_void_p,
|
|
3827
|
+
ctypes.c_void_p,
|
|
3828
|
+
ctypes.POINTER(ctypes.c_int),
|
|
3829
|
+
ctypes.c_uint64,
|
|
3830
|
+
]
|
|
3831
|
+
self.core.cuda_graph_set_condition.restype = ctypes.c_bool
|
|
3832
|
+
|
|
3833
|
+
self.core.cuda_graph_pause_capture.argtypes = [
|
|
3834
|
+
ctypes.c_void_p,
|
|
3835
|
+
ctypes.c_void_p,
|
|
3836
|
+
ctypes.POINTER(ctypes.c_void_p),
|
|
3837
|
+
]
|
|
3838
|
+
self.core.cuda_graph_pause_capture.restype = ctypes.c_bool
|
|
3839
|
+
|
|
3840
|
+
self.core.cuda_graph_resume_capture.argtypes = [
|
|
3841
|
+
ctypes.c_void_p,
|
|
3842
|
+
ctypes.c_void_p,
|
|
3843
|
+
ctypes.c_void_p,
|
|
3844
|
+
]
|
|
3845
|
+
self.core.cuda_graph_resume_capture.restype = ctypes.c_bool
|
|
3846
|
+
|
|
3847
|
+
self.core.cuda_graph_insert_child_graph.argtypes = [
|
|
3848
|
+
ctypes.c_void_p,
|
|
3849
|
+
ctypes.c_void_p,
|
|
3850
|
+
ctypes.c_void_p,
|
|
3851
|
+
]
|
|
3852
|
+
self.core.cuda_graph_insert_child_graph.restype = ctypes.c_bool
|
|
3853
|
+
|
|
3732
3854
|
self.core.cuda_compile_program.argtypes = [
|
|
3733
3855
|
ctypes.c_char_p, # cuda_src
|
|
3734
3856
|
ctypes.c_char_p, # program name
|
|
@@ -3742,6 +3864,7 @@ class Runtime:
|
|
|
3742
3864
|
ctypes.c_bool, # fast_math
|
|
3743
3865
|
ctypes.c_bool, # fuse_fp
|
|
3744
3866
|
ctypes.c_bool, # lineinfo
|
|
3867
|
+
ctypes.c_bool, # compile_time_trace
|
|
3745
3868
|
ctypes.c_char_p, # output_path
|
|
3746
3869
|
ctypes.c_size_t, # num_ltoirs
|
|
3747
3870
|
ctypes.POINTER(ctypes.c_char_p), # ltoirs
|
|
@@ -3796,11 +3919,17 @@ class Runtime:
|
|
|
3796
3919
|
ctypes.c_int, # arch
|
|
3797
3920
|
ctypes.c_int, # M
|
|
3798
3921
|
ctypes.c_int, # N
|
|
3922
|
+
ctypes.c_int, # NRHS
|
|
3923
|
+
ctypes.c_int, # function
|
|
3924
|
+
ctypes.c_int, # side
|
|
3925
|
+
ctypes.c_int, # diag
|
|
3799
3926
|
ctypes.c_int, # precision
|
|
3927
|
+
ctypes.c_int, # a_arrangement
|
|
3928
|
+
ctypes.c_int, # b_arrangement
|
|
3800
3929
|
ctypes.c_int, # fill_mode
|
|
3801
3930
|
ctypes.c_int, # num threads
|
|
3802
3931
|
]
|
|
3803
|
-
self.core.
|
|
3932
|
+
self.core.cuda_compile_solver.restype = ctypes.c_bool
|
|
3804
3933
|
|
|
3805
3934
|
self.core.cuda_load_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
|
|
3806
3935
|
self.core.cuda_load_module.restype = ctypes.c_void_p
|
|
@@ -3965,9 +4094,14 @@ class Runtime:
|
|
|
3965
4094
|
# Update the default PTX architecture based on devices present in the system.
|
|
3966
4095
|
# Use the lowest architecture among devices that meet the minimum architecture requirement.
|
|
3967
4096
|
# Devices below the required minimum will use the highest architecture they support.
|
|
3968
|
-
|
|
3969
|
-
|
|
3970
|
-
|
|
4097
|
+
try:
|
|
4098
|
+
self.default_ptx_arch = min(
|
|
4099
|
+
d.arch
|
|
4100
|
+
for d in self.cuda_devices
|
|
4101
|
+
if d.arch >= self.default_ptx_arch and d.arch in self.nvrtc_supported_archs
|
|
4102
|
+
)
|
|
4103
|
+
except ValueError:
|
|
4104
|
+
pass # no eligible NVRTC-supported arch ≥ default, retain existing
|
|
3971
4105
|
else:
|
|
3972
4106
|
# CUDA not available
|
|
3973
4107
|
self.set_default_device("cpu")
|
|
@@ -4270,7 +4404,7 @@ def is_cuda_driver_initialized() -> bool:
|
|
|
4270
4404
|
return runtime.core.cuda_driver_is_initialized()
|
|
4271
4405
|
|
|
4272
4406
|
|
|
4273
|
-
def get_devices() ->
|
|
4407
|
+
def get_devices() -> list[Device]:
|
|
4274
4408
|
"""Returns a list of devices supported in this environment."""
|
|
4275
4409
|
|
|
4276
4410
|
init()
|
|
@@ -4291,7 +4425,7 @@ def get_cuda_device_count() -> int:
|
|
|
4291
4425
|
return len(runtime.cuda_devices)
|
|
4292
4426
|
|
|
4293
4427
|
|
|
4294
|
-
def get_cuda_device(ordinal: Union[int, None] = None) -> Device:
|
|
4428
|
+
def get_cuda_device(ordinal: int | None = None) -> Device:
|
|
4295
4429
|
"""Returns the CUDA device with the given ordinal or the current CUDA device if ordinal is None."""
|
|
4296
4430
|
|
|
4297
4431
|
init()
|
|
@@ -4302,7 +4436,7 @@ def get_cuda_device(ordinal: Union[int, None] = None) -> Device:
|
|
|
4302
4436
|
return runtime.cuda_devices[ordinal]
|
|
4303
4437
|
|
|
4304
4438
|
|
|
4305
|
-
def get_cuda_devices() ->
|
|
4439
|
+
def get_cuda_devices() -> list[Device]:
|
|
4306
4440
|
"""Returns a list of CUDA devices supported in this environment."""
|
|
4307
4441
|
|
|
4308
4442
|
init()
|
|
@@ -4341,7 +4475,7 @@ def set_device(ident: Devicelike) -> None:
|
|
|
4341
4475
|
device.make_current()
|
|
4342
4476
|
|
|
4343
4477
|
|
|
4344
|
-
def map_cuda_device(alias: str, context:
|
|
4478
|
+
def map_cuda_device(alias: str, context: ctypes.c_void_p | None = None) -> Device:
|
|
4345
4479
|
"""Assign a device alias to a CUDA context.
|
|
4346
4480
|
|
|
4347
4481
|
This function can be used to create a wp.Device for an external CUDA context.
|
|
@@ -4436,7 +4570,7 @@ def set_mempool_enabled(device: Devicelike, enable: bool) -> None:
|
|
|
4436
4570
|
raise ValueError("Memory pools are only supported on CUDA devices")
|
|
4437
4571
|
|
|
4438
4572
|
|
|
4439
|
-
def set_mempool_release_threshold(device: Devicelike, threshold:
|
|
4573
|
+
def set_mempool_release_threshold(device: Devicelike, threshold: int | float) -> None:
|
|
4440
4574
|
"""Set the CUDA memory pool release threshold on the device.
|
|
4441
4575
|
|
|
4442
4576
|
This is the amount of reserved memory to hold onto before trying to release memory back to the OS.
|
|
@@ -4744,7 +4878,7 @@ def set_stream(stream: Stream, device: Devicelike = None, sync: bool = False) ->
|
|
|
4744
4878
|
get_device(device).set_stream(stream, sync=sync)
|
|
4745
4879
|
|
|
4746
4880
|
|
|
4747
|
-
def record_event(event:
|
|
4881
|
+
def record_event(event: Event | None = None):
|
|
4748
4882
|
"""Convenience function for calling :meth:`Stream.record_event` on the current stream.
|
|
4749
4883
|
|
|
4750
4884
|
Args:
|
|
@@ -4793,7 +4927,7 @@ def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: bo
|
|
|
4793
4927
|
return runtime.core.cuda_event_elapsed_time(start_event.cuda_event, end_event.cuda_event)
|
|
4794
4928
|
|
|
4795
4929
|
|
|
4796
|
-
def wait_stream(other_stream: Stream, event:
|
|
4930
|
+
def wait_stream(other_stream: Stream, event: Event | None = None):
|
|
4797
4931
|
"""Convenience function for calling :meth:`Stream.wait_stream` on the current stream.
|
|
4798
4932
|
|
|
4799
4933
|
Args:
|
|
@@ -4863,7 +4997,7 @@ class RegisteredGLBuffer:
|
|
|
4863
4997
|
__fallback_warning_shown = False
|
|
4864
4998
|
|
|
4865
4999
|
def __new__(cls, *args, **kwargs):
    """Create a bare buffer object with ``resource`` defaulted to ``None``."""
    # ``resource`` is assigned here so the attribute already exists on the
    # object before ``__init__`` runs.
    buffer = super().__new__(cls)
    buffer.resource = None
    return buffer
|
|
4869
5003
|
|
|
@@ -4960,8 +5094,8 @@ class RegisteredGLBuffer:
|
|
|
4960
5094
|
|
|
4961
5095
|
|
|
4962
5096
|
def zeros(
|
|
4963
|
-
shape:
|
|
4964
|
-
dtype=float,
|
|
5097
|
+
shape: int | tuple[int, ...] | list[int] | None = None,
|
|
5098
|
+
dtype: type = float,
|
|
4965
5099
|
device: Devicelike = None,
|
|
4966
5100
|
requires_grad: bool = False,
|
|
4967
5101
|
pinned: bool = False,
|
|
@@ -4988,7 +5122,7 @@ def zeros(
|
|
|
4988
5122
|
|
|
4989
5123
|
|
|
4990
5124
|
def zeros_like(
|
|
4991
|
-
src: Array, device: Devicelike = None, requires_grad:
|
|
5125
|
+
src: Array, device: Devicelike = None, requires_grad: bool | None = None, pinned: bool | None = None
|
|
4992
5126
|
) -> warp.array:
|
|
4993
5127
|
"""Return a zero-initialized array with the same type and dimension of another array
|
|
4994
5128
|
|
|
@@ -5010,8 +5144,8 @@ def zeros_like(
|
|
|
5010
5144
|
|
|
5011
5145
|
|
|
5012
5146
|
def ones(
|
|
5013
|
-
shape:
|
|
5014
|
-
dtype=float,
|
|
5147
|
+
shape: int | tuple[int, ...] | list[int] | None = None,
|
|
5148
|
+
dtype: type = float,
|
|
5015
5149
|
device: Devicelike = None,
|
|
5016
5150
|
requires_grad: bool = False,
|
|
5017
5151
|
pinned: bool = False,
|
|
@@ -5034,7 +5168,7 @@ def ones(
|
|
|
5034
5168
|
|
|
5035
5169
|
|
|
5036
5170
|
def ones_like(
|
|
5037
|
-
src: Array, device: Devicelike = None, requires_grad:
|
|
5171
|
+
src: Array, device: Devicelike = None, requires_grad: bool | None = None, pinned: bool | None = None
|
|
5038
5172
|
) -> warp.array:
|
|
5039
5173
|
"""Return a one-initialized array with the same type and dimension of another array
|
|
5040
5174
|
|
|
@@ -5052,7 +5186,7 @@ def ones_like(
|
|
|
5052
5186
|
|
|
5053
5187
|
|
|
5054
5188
|
def full(
|
|
5055
|
-
shape:
|
|
5189
|
+
shape: int | tuple[int, ...] | list[int] | None = None,
|
|
5056
5190
|
value=0,
|
|
5057
5191
|
dtype=Any,
|
|
5058
5192
|
device: Devicelike = None,
|
|
@@ -5121,8 +5255,8 @@ def full_like(
|
|
|
5121
5255
|
src: Array,
|
|
5122
5256
|
value: Any,
|
|
5123
5257
|
device: Devicelike = None,
|
|
5124
|
-
requires_grad:
|
|
5125
|
-
pinned:
|
|
5258
|
+
requires_grad: bool | None = None,
|
|
5259
|
+
pinned: bool | None = None,
|
|
5126
5260
|
) -> warp.array:
|
|
5127
5261
|
"""Return an array with all elements initialized to the given value with the same type and dimension of another array
|
|
5128
5262
|
|
|
@@ -5145,7 +5279,7 @@ def full_like(
|
|
|
5145
5279
|
|
|
5146
5280
|
|
|
5147
5281
|
def clone(
|
|
5148
|
-
src: warp.array, device: Devicelike = None, requires_grad:
|
|
5282
|
+
src: warp.array, device: Devicelike = None, requires_grad: bool | None = None, pinned: bool | None = None
|
|
5149
5283
|
) -> warp.array:
|
|
5150
5284
|
"""Clone an existing array, allocates a copy of the src memory
|
|
5151
5285
|
|
|
@@ -5167,7 +5301,7 @@ def clone(
|
|
|
5167
5301
|
|
|
5168
5302
|
|
|
5169
5303
|
def empty(
|
|
5170
|
-
shape:
|
|
5304
|
+
shape: int | tuple[int, ...] | list[int] | None = None,
|
|
5171
5305
|
dtype=float,
|
|
5172
5306
|
device: Devicelike = None,
|
|
5173
5307
|
requires_grad: bool = False,
|
|
@@ -5200,7 +5334,7 @@ def empty(
|
|
|
5200
5334
|
|
|
5201
5335
|
|
|
5202
5336
|
def empty_like(
|
|
5203
|
-
src: Array, device: Devicelike = None, requires_grad:
|
|
5337
|
+
src: Array, device: Devicelike = None, requires_grad: bool | None = None, pinned: bool | None = None
|
|
5204
5338
|
) -> warp.array:
|
|
5205
5339
|
"""Return an uninitialized array with the same type and dimension of another array
|
|
5206
5340
|
|
|
@@ -5235,9 +5369,9 @@ def empty_like(
|
|
|
5235
5369
|
|
|
5236
5370
|
def from_numpy(
|
|
5237
5371
|
arr: np.ndarray,
|
|
5238
|
-
dtype:
|
|
5239
|
-
shape:
|
|
5240
|
-
device:
|
|
5372
|
+
dtype: type | None = None,
|
|
5373
|
+
shape: Sequence[int] | None = None,
|
|
5374
|
+
device: Devicelike | None = None,
|
|
5241
5375
|
requires_grad: bool = False,
|
|
5242
5376
|
) -> warp.array:
|
|
5243
5377
|
"""Returns a Warp array created from a NumPy array.
|
|
@@ -5255,7 +5389,7 @@ def from_numpy(
|
|
|
5255
5389
|
if dtype is None:
|
|
5256
5390
|
base_type = warp.types.np_dtype_to_warp_type.get(arr.dtype)
|
|
5257
5391
|
if base_type is None:
|
|
5258
|
-
raise RuntimeError("Unsupported NumPy data type '{}'."
|
|
5392
|
+
raise RuntimeError(f"Unsupported NumPy data type '{arr.dtype}'.")
|
|
5259
5393
|
|
|
5260
5394
|
dim_count = len(arr.shape)
|
|
5261
5395
|
if dim_count == 2:
|
|
@@ -5274,7 +5408,7 @@ def from_numpy(
|
|
|
5274
5408
|
)
|
|
5275
5409
|
|
|
5276
5410
|
|
|
5277
|
-
def event_from_ipc_handle(handle, device:
|
|
5411
|
+
def event_from_ipc_handle(handle, device: Devicelike = None) -> Event:
|
|
5278
5412
|
"""Create an event from an IPC handle.
|
|
5279
5413
|
|
|
5280
5414
|
Args:
|
|
@@ -5443,10 +5577,10 @@ class Launch:
|
|
|
5443
5577
|
self,
|
|
5444
5578
|
kernel,
|
|
5445
5579
|
device: Device,
|
|
5446
|
-
hooks:
|
|
5447
|
-
params:
|
|
5448
|
-
params_addr:
|
|
5449
|
-
bounds:
|
|
5580
|
+
hooks: KernelHooks | None = None,
|
|
5581
|
+
params: Sequence[Any] | None = None,
|
|
5582
|
+
params_addr: Sequence[ctypes.c_void_p] | None = None,
|
|
5583
|
+
bounds: launch_bounds_t | None = None,
|
|
5450
5584
|
max_blocks: int = 0,
|
|
5451
5585
|
block_dim: int = 256,
|
|
5452
5586
|
adjoint: bool = False,
|
|
@@ -5516,7 +5650,7 @@ class Launch:
|
|
|
5516
5650
|
self.adjoint: bool = adjoint
|
|
5517
5651
|
"""Whether to run the adjoint kernel instead of the forward kernel."""
|
|
5518
5652
|
|
|
5519
|
-
def set_dim(self, dim:
|
|
5653
|
+
def set_dim(self, dim: int | list[int] | tuple[int, ...]):
|
|
5520
5654
|
"""Set the launch dimensions.
|
|
5521
5655
|
|
|
5522
5656
|
Args:
|
|
@@ -5554,7 +5688,7 @@ class Launch:
|
|
|
5554
5688
|
if self.params_addr:
|
|
5555
5689
|
self.params_addr[params_index] = ctypes.c_void_p(ctypes.addressof(carg))
|
|
5556
5690
|
|
|
5557
|
-
def set_param_at_index_from_ctype(self, index: int, value:
|
|
5691
|
+
def set_param_at_index_from_ctype(self, index: int, value: ctypes.Structure | int | float):
|
|
5558
5692
|
"""Set a kernel parameter at an index without any type conversion.
|
|
5559
5693
|
|
|
5560
5694
|
Args:
|
|
@@ -5617,7 +5751,7 @@ class Launch:
|
|
|
5617
5751
|
for i, v in enumerate(values):
|
|
5618
5752
|
self.set_param_at_index_from_ctype(i, v)
|
|
5619
5753
|
|
|
5620
|
-
def launch(self, stream:
|
|
5754
|
+
def launch(self, stream: Stream | None = None) -> None:
|
|
5621
5755
|
"""Launch the kernel.
|
|
5622
5756
|
|
|
5623
5757
|
Args:
|
|
@@ -5634,7 +5768,7 @@ class Launch:
|
|
|
5634
5768
|
|
|
5635
5769
|
# If the stream is capturing, we retain the CUDA module so that it doesn't get unloaded
|
|
5636
5770
|
# before the captured graph is released.
|
|
5637
|
-
if runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
|
|
5771
|
+
if len(runtime.captures) > 0 and runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
|
|
5638
5772
|
capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
|
|
5639
5773
|
graph = runtime.captures.get(capture_id)
|
|
5640
5774
|
if graph is not None:
|
|
@@ -5666,13 +5800,13 @@ class Launch:
|
|
|
5666
5800
|
|
|
5667
5801
|
def launch(
|
|
5668
5802
|
kernel,
|
|
5669
|
-
dim:
|
|
5803
|
+
dim: int | Sequence[int],
|
|
5670
5804
|
inputs: Sequence = [],
|
|
5671
5805
|
outputs: Sequence = [],
|
|
5672
5806
|
adj_inputs: Sequence = [],
|
|
5673
5807
|
adj_outputs: Sequence = [],
|
|
5674
5808
|
device: Devicelike = None,
|
|
5675
|
-
stream:
|
|
5809
|
+
stream: Stream | None = None,
|
|
5676
5810
|
adjoint: bool = False,
|
|
5677
5811
|
record_tape: bool = True,
|
|
5678
5812
|
record_cmd: bool = False,
|
|
@@ -5824,7 +5958,7 @@ def launch(
|
|
|
5824
5958
|
|
|
5825
5959
|
# If the stream is capturing, we retain the CUDA module so that it doesn't get unloaded
|
|
5826
5960
|
# before the captured graph is released.
|
|
5827
|
-
if runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
|
|
5961
|
+
if len(runtime.captures) > 0 and runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
|
|
5828
5962
|
capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
|
|
5829
5963
|
graph = runtime.captures.get(capture_id)
|
|
5830
5964
|
if graph is not None:
|
|
@@ -5968,7 +6102,7 @@ def launch_tiled(*args, **kwargs):
|
|
|
5968
6102
|
raise RuntimeError("wp.launch_tiled() requires a grid with fewer than 4 dimensions")
|
|
5969
6103
|
|
|
5970
6104
|
# add trailing dimension
|
|
5971
|
-
kwargs["dim"] = dim
|
|
6105
|
+
kwargs["dim"] = [*dim, kwargs["block_dim"]]
|
|
5972
6106
|
|
|
5973
6107
|
# forward to original launch method
|
|
5974
6108
|
return launch(*args, **kwargs)
|
|
@@ -6016,7 +6150,7 @@ def synchronize_device(device: Devicelike = None):
|
|
|
6016
6150
|
runtime.core.cuda_context_synchronize(device.context)
|
|
6017
6151
|
|
|
6018
6152
|
|
|
6019
|
-
def synchronize_stream(stream_or_device:
|
|
6153
|
+
def synchronize_stream(stream_or_device: Stream | Devicelike | None = None):
|
|
6020
6154
|
"""Synchronize the calling CPU thread with any outstanding CUDA work on the specified stream.
|
|
6021
6155
|
|
|
6022
6156
|
This function allows the host application code to ensure that all kernel launches
|
|
@@ -6046,7 +6180,7 @@ def synchronize_event(event: Event):
|
|
|
6046
6180
|
runtime.core.cuda_event_synchronize(event.cuda_event)
|
|
6047
6181
|
|
|
6048
6182
|
|
|
6049
|
-
def force_load(device:
|
|
6183
|
+
def force_load(device: Device | str | list[Device] | list[str] | None = None, modules: list[Module] | None = None):
|
|
6050
6184
|
"""Force user-defined kernels to be compiled and loaded
|
|
6051
6185
|
|
|
6052
6186
|
Args:
|
|
@@ -6078,7 +6212,7 @@ def force_load(device: Union[Device, str, List[Device], List[str]] = None, modul
|
|
|
6078
6212
|
|
|
6079
6213
|
|
|
6080
6214
|
def load_module(
|
|
6081
|
-
module:
|
|
6215
|
+
module: Module | types.ModuleType | str | None = None, device: Device | str | None = None, recursive: bool = False
|
|
6082
6216
|
):
|
|
6083
6217
|
"""Force user-defined module to be compiled and loaded
|
|
6084
6218
|
|
|
@@ -6120,7 +6254,7 @@ def load_module(
|
|
|
6120
6254
|
force_load(device=device, modules=modules)
|
|
6121
6255
|
|
|
6122
6256
|
|
|
6123
|
-
def set_module_options(options: Dict[str, Any], module: Optional[Any] = None):
|
|
6257
|
+
def set_module_options(options: dict[str, Any], module: Any = None):
|
|
6124
6258
|
"""Set options for the current module.
|
|
6125
6259
|
|
|
6126
6260
|
Options can be used to control runtime compilation and code-generation
|
|
@@ -6144,7 +6278,7 @@ def set_module_options(options: Dict[str, Any], module: Optional[Any] = None):
|
|
|
6144
6278
|
get_module(m.__name__).mark_modified()
|
|
6145
6279
|
|
|
6146
6280
|
|
|
6147
|
-
def get_module_options(module:
|
|
6281
|
+
def get_module_options(module: Any = None) -> dict[str, Any]:
|
|
6148
6282
|
"""Returns a list of options for the current module."""
|
|
6149
6283
|
if module is None:
|
|
6150
6284
|
m = inspect.getmodule(inspect.stack()[1][0])
|
|
@@ -6154,10 +6288,44 @@ def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
|
|
|
6154
6288
|
return get_module(m.__name__).options
|
|
6155
6289
|
|
|
6156
6290
|
|
|
6291
|
+
def _unregister_capture(device: Device, stream: Stream, graph: Graph):
    """Remove a graph capture from the device and runtime lookup tables.

    Call this once a capture stops being active, whether it completed or was paused.
    A graph is meant to stay registered only while it is actively capturing.

    Args:
        device: The CUDA device the graph was being captured on
        stream: The CUDA stream the graph was being captured on
        graph: The Graph object that was being captured
    """
    # per-device table, keyed by the capturing stream
    captures_by_stream = device.captures
    del captures_by_stream[stream]

    # runtime-wide table, keyed by the graph's capture id
    captures_by_id = runtime.captures
    del captures_by_id[graph.capture_id]
|
|
6304
|
+
|
|
6305
|
+
|
|
6306
|
+
def _register_capture(device: Device, stream: Stream, graph: Graph, capture_id: int):
    """Record an active graph capture on the device and in the runtime.

    Registering the graph under its capture_id makes it discoverable so that retain_module_exec()
    can be called when kernels are launched during graph capture, which keeps modules retained
    until graph execution completes.

    Args:
        device: The CUDA device the graph is being captured on
        stream: The CUDA stream the graph is being captured on
        graph: The Graph object being captured
        capture_id: Unique identifier for this graph capture
    """
    # ongoing captures on the device are looked up by stream
    captures_by_stream = device.captures
    captures_by_stream[stream] = graph

    # the runtime keeps a lookup table keyed by the globally unique capture id
    captures_by_id = runtime.captures
    captures_by_id[capture_id] = graph
|
|
6323
|
+
|
|
6324
|
+
|
|
6157
6325
|
def capture_begin(
|
|
6158
6326
|
device: Devicelike = None,
|
|
6159
|
-
stream:
|
|
6160
|
-
force_module_load:
|
|
6327
|
+
stream: Stream | None = None,
|
|
6328
|
+
force_module_load: bool | None = None,
|
|
6161
6329
|
external: bool = False,
|
|
6162
6330
|
):
|
|
6163
6331
|
"""Begin capture of a CUDA graph
|
|
@@ -6219,14 +6387,10 @@ def capture_begin(
|
|
|
6219
6387
|
capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
|
|
6220
6388
|
graph = Graph(device, capture_id)
|
|
6221
6389
|
|
|
6222
|
-
|
|
6223
|
-
device.captures[stream] = graph
|
|
6224
|
-
|
|
6225
|
-
# add to lookup table by globally unique capture id
|
|
6226
|
-
runtime.captures[capture_id] = graph
|
|
6390
|
+
_register_capture(device, stream, graph, capture_id)
|
|
6227
6391
|
|
|
6228
6392
|
|
|
6229
|
-
def capture_end(device: Devicelike = None, stream:
|
|
6393
|
+
def capture_end(device: Devicelike = None, stream: Stream | None = None) -> Graph:
|
|
6230
6394
|
"""End the capture of a CUDA graph.
|
|
6231
6395
|
|
|
6232
6396
|
Args:
|
|
@@ -6251,24 +6415,361 @@ def capture_end(device: Devicelike = None, stream: Optional[Stream] = None) -> G
|
|
|
6251
6415
|
if graph is None:
|
|
6252
6416
|
raise RuntimeError("Graph capture is not active on this stream")
|
|
6253
6417
|
|
|
6254
|
-
|
|
6255
|
-
del runtime.captures[graph.capture_id]
|
|
6418
|
+
_unregister_capture(device, stream, graph)
|
|
6256
6419
|
|
|
6257
6420
|
# get the graph executable
|
|
6258
|
-
|
|
6259
|
-
result = runtime.core.cuda_graph_end_capture(device.context, stream.cuda_stream, ctypes.byref(
|
|
6421
|
+
g = ctypes.c_void_p()
|
|
6422
|
+
result = runtime.core.cuda_graph_end_capture(device.context, stream.cuda_stream, ctypes.byref(g))
|
|
6260
6423
|
|
|
6261
6424
|
if not result:
|
|
6262
6425
|
# A concrete error should've already been reported, so we don't need to go into details here
|
|
6263
6426
|
raise RuntimeError(f"CUDA graph capture failed. {runtime.get_error_string()}")
|
|
6264
6427
|
|
|
6265
6428
|
# set the graph executable
|
|
6266
|
-
graph.
|
|
6429
|
+
graph.graph = g
|
|
6430
|
+
graph.graph_exec = None # Lazy initialization
|
|
6431
|
+
|
|
6432
|
+
return graph
|
|
6433
|
+
|
|
6434
|
+
|
|
6435
|
+
def capture_debug_dot_print(graph: Graph, path: str, verbose: bool = False):
    """Write a CUDA graph to a DOT file so it can be visualized.

    Args:
        graph: A :class:`Graph` as returned by :func:`~warp.capture_end()`
        path: Path to save the DOT file
        verbose: Whether to include additional debug information in the output

    Raises:
        RuntimeError: If the native export call reports failure.
    """
    # the native call receives an integer flag: 0 when ``verbose`` is set, 1 otherwise
    # NOTE(review): confirm this polarity against the native flag definition
    if verbose:
        flags = 0
    else:
        flags = 1

    exported = runtime.core.capture_debug_dot_print(graph.graph, path.encode(), flags)
    if exported:
        return

    raise RuntimeError(f"Graph debug dot print error: {runtime.get_error_string()}")
|
|
6445
|
+
|
|
6446
|
+
|
|
6447
|
+
def assert_conditional_graph_support():
    """Raise if conditional graph nodes cannot be used.

    Calls :func:`init` first when the module-level ``runtime`` is still None, then requires both
    the CUDA Toolkit version Warp was built with and the CUDA driver version to be at least 12.4.

    Raises:
        RuntimeError: If the toolkit version or the driver version is older than 12.4.
    """
    if runtime is None:
        init()

    required_version = (12, 4)

    toolkit_too_old = runtime.toolkit_version < required_version
    if toolkit_too_old:
        raise RuntimeError("Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes")

    driver_too_old = runtime.driver_version < required_version
    if driver_too_old:
        raise RuntimeError("Conditional graph nodes require CUDA driver 12.4+")
|
|
6456
|
+
|
|
6457
|
+
|
|
6458
|
+
def capture_pause(device: Devicelike = None, stream: Stream | None = None) -> Graph:
    """Pause the graph capture that is active on a stream.

    The graph is removed from the device and runtime capture tables, then the native pause call
    is issued and the graph handle it produces is stored in ``graph.graph``.

    Args:
        device: Device whose current stream is used when ``stream`` is None. Must be a CUDA device.
        stream: The stream being captured. When given, its device is used and ``device`` is ignored.

    Returns:
        The :class:`Graph` that was being captured on the stream.

    Raises:
        RuntimeError: If the device is not a CUDA device, if no capture is registered for the
            stream, or if the native pause call fails.
    """
    if stream is None:
        device = runtime.get_device(device)
        if not device.is_cuda:
            raise RuntimeError("Must be a CUDA device")
        stream = device.stream
    else:
        device = stream.device

    # look up the graph currently registered for this stream
    graph = device.captures.get(stream)
    if graph is None:
        raise RuntimeError("Graph capture is not active on this stream")

    # a paused graph is no longer actively capturing, so it must not stay registered
    _unregister_capture(device, stream, graph)

    paused_handle = ctypes.c_void_p()
    paused = runtime.core.cuda_graph_pause_capture(device.context, stream.cuda_stream, ctypes.byref(paused_handle))
    if not paused:
        raise RuntimeError(runtime.get_error_string())

    graph.graph = paused_handle
    return graph
|
|
6269
6482
|
|
|
6270
6483
|
|
|
6271
|
-
def
|
|
6484
|
+
def capture_resume(graph: Graph, device: Devicelike = None, stream: Stream | None = None):
    """Resume capturing into a graph on a stream.

    After the native resume call succeeds, the stream's capture id is queried, stored in
    ``graph.capture_id``, and the graph is registered again as an active capture.

    Args:
        graph: The :class:`Graph` whose ``graph`` handle capture should continue into.
        device: Device whose current stream is used when ``stream`` is None. Must be a CUDA device.
        stream: The stream to capture on. When given, its device is used and ``device`` is ignored.

    Raises:
        RuntimeError: If the device is not a CUDA device or the native resume call fails.
    """
    if stream is None:
        device = runtime.get_device(device)
        if not device.is_cuda:
            raise RuntimeError("Must be a CUDA device")
        stream = device.stream
    else:
        device = stream.device

    resumed = runtime.core.cuda_graph_resume_capture(device.context, stream.cuda_stream, graph.graph)
    if not resumed:
        raise RuntimeError(runtime.get_error_string())

    # the capture id is re-read from the stream after resuming and recorded on the graph
    new_capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
    graph.capture_id = new_capture_id

    _register_capture(device, stream, graph, new_capture_id)
|
|
6500
|
+
|
|
6501
|
+
|
|
6502
|
+
# reusable pinned readback buffer for conditions
# (allocated lazily the first time a condition on a CUDA device is read back outside of graph capture)
condition_host = None


def capture_if(
    condition: warp.array(dtype=int),
    on_true: Callable | Graph | None = None,
    on_false: Callable | Graph | None = None,
    stream: Stream = None,
    **kwargs,
):
    """Create a dynamic branch based on a condition.

    The condition value is retrieved from the first element of the ``condition`` array.

    This function is particularly useful with CUDA graphs, but can be used without graph capture as well.
    CUDA 12.4+ is required to take advantage of conditional graph nodes for dynamic control flow.

    When no graph capture is active on the stream (or ``condition`` lives on a non-CUDA device), the
    condition is read on the host and the selected branch is executed immediately. When a capture is
    active, a conditional node is inserted and each provided branch is recorded into its own body graph.

    Args:
        condition: Warp array holding the condition value.
        on_true: A callback function or :class:`Graph` to execute if the condition is True.
        on_false: A callback function or :class:`Graph` to execute if the condition is False.
        stream: The CUDA stream where the condition was written. If None, use the current stream on the device where ``condition`` resides.

    Any additional keyword arguments are forwarded to the callback functions.

    Raises:
        TypeError: If ``condition`` is not a Warp array of int32, or if a branch is neither a
            callable nor a :class:`Graph`.
        RuntimeError: If conditional graph nodes are unsupported, if a :class:`Graph` branch itself
            contains conditional nodes, or if a native graph call fails.
    """
    global condition_host

    # if neither the IF branch nor the ELSE branch is specified, it's a no-op
    if on_true is None and on_false is None:
        return

    # check condition data type
    if not isinstance(condition, warp.array) or condition.dtype is not warp.int32:
        raise TypeError("Condition must be a Warp array of int32 with a single element")

    device = condition.device

    # determine the stream and whether a graph capture is active
    if device.is_cuda:
        if stream is None:
            stream = device.stream
        graph = device.captures.get(stream)
    else:
        graph = None

    if graph is None:
        # if no graph is active, just execute the correct branch directly
        if device.is_cuda:
            # use a pinned buffer for condition readback to host
            if condition_host is None:
                condition_host = warp.empty(1, dtype=int, device="cpu", pinned=True)
            warp.copy(condition_host, condition, stream=stream)
            warp.synchronize_stream(stream)
            condition_value = bool(ctypes.cast(condition_host.ptr, ctypes.POINTER(ctypes.c_int32)).contents)
        else:
            condition_value = bool(ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)).contents)

        # only the selected branch is inspected and executed
        if condition_value:
            branch_name, branch = "on_true", on_true
        else:
            branch_name, branch = "on_false", on_false

        if branch is not None:
            if isinstance(branch, Callable):
                branch(**kwargs)
            elif isinstance(branch, Graph):
                capture_launch(branch, stream=stream)
            else:
                raise TypeError(f"{branch_name} must be a Callable or a Graph")

        return

    graph.has_conditional = True

    # ensure conditional graph nodes are supported
    assert_conditional_graph_support()

    # validate both branches before touching the graph being captured, so that a bad argument
    # cannot leave the capture paused or pointing at a branch graph
    for branch_name, branch in (("on_true", on_true), ("on_false", on_false)):
        if branch is None or isinstance(branch, Callable):
            continue
        if not isinstance(branch, Graph):
            raise TypeError(f"{branch_name} must be a Callable or a Graph")
        if branch.has_conditional:
            raise RuntimeError(
                f"The {branch_name} graph contains conditional nodes, which are not allowed in child graphs"
            )

    # insert conditional node
    graph_on_true = ctypes.c_void_p()
    graph_on_false = ctypes.c_void_p()
    if not runtime.core.cuda_graph_insert_if_else(
        device.context,
        stream.cuda_stream,
        ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
        None if on_true is None else ctypes.byref(graph_on_true),
        None if on_false is None else ctypes.byref(graph_on_false),
    ):
        raise RuntimeError(runtime.get_error_string())

    # pause capturing parent graph
    main_graph = capture_pause(stream=stream)
    # store the pointer to the cuda graph to restore it later
    main_graph_ptr = main_graph.graph

    # capture the if-graph, then the else-graph
    for branch, branch_graph in ((on_true, graph_on_true), (on_false, graph_on_false)):
        if branch is None:
            continue

        # temporarily repurpose the main_graph python object such that all dependencies
        # added through retain_module_exec() end up in the correct python graph object
        main_graph.graph = branch_graph
        capture_resume(main_graph, stream=stream)

        if isinstance(branch, Callable):
            branch(**kwargs)
        else:
            # validated above: a non-callable branch is a Graph without conditional nodes
            if not runtime.core.cuda_graph_insert_child_graph(
                device.context,
                stream.cuda_stream,
                branch.graph,
            ):
                raise RuntimeError(runtime.get_error_string())

        capture_pause(stream=stream)

    # restore the main graph to its original state
    main_graph.graph = main_graph_ptr

    # resume capturing parent graph
    capture_resume(main_graph, stream=stream)
|
|
6652
|
+
|
|
6653
|
+
|
|
6654
|
+
def capture_while(condition: warp.array(dtype=int), while_body: Callable | Graph, stream: Stream = None, **kwargs):
    """Create a dynamic loop based on a condition.

    The condition value is retrieved from the first element of the ``condition`` array.

    The ``while_body`` callback is responsible for updating the condition value so the loop can terminate.

    This function is particularly useful with CUDA graphs, but can be used without graph capture as well.
    CUDA 12.4+ is required to take advantage of conditional graph nodes for dynamic control flow.

    When no graph capture is active on the stream (or ``condition`` lives on a non-CUDA device), the
    loop runs on the host: the condition is read back before every iteration. When a capture is active,
    a while node is inserted and ``while_body`` is recorded into its body graph.

    Args:
        condition: Warp array holding the condition value.
        while_body: A callback function or :class:`Graph` to execute while the loop condition is True.
        stream: The CUDA stream where the condition was written. If None, use the current stream on the device where ``condition`` resides.

    Any additional keyword arguments are forwarded to the callback function.

    Raises:
        TypeError: If ``condition`` is not a Warp array of int32, or if ``while_body`` is neither a
            callable nor a :class:`Graph`.
        RuntimeError: If conditional graph nodes are unsupported, if a :class:`Graph` body itself
            contains conditional nodes, or if a native graph call fails.
    """
    global condition_host

    # check condition data type
    if not isinstance(condition, warp.array) or condition.dtype is not warp.int32:
        raise TypeError("Condition must be a Warp array of int32 with a single element")

    device = condition.device

    # determine the stream and whether a graph capture is active
    if device.is_cuda:
        if stream is None:
            stream = device.stream
        graph = device.captures.get(stream)
    else:
        graph = None

    if graph is None:
        # since no graph is active, just execute the kernels directly
        while True:
            if device.is_cuda:
                # use a pinned buffer for condition readback to host
                if condition_host is None:
                    condition_host = warp.empty(1, dtype=int, device="cpu", pinned=True)
                warp.copy(condition_host, condition, stream=stream)
                warp.synchronize_stream(stream)
                condition_value = bool(ctypes.cast(condition_host.ptr, ctypes.POINTER(ctypes.c_int32)).contents)
            else:
                condition_value = bool(ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)).contents)

            if not condition_value:
                break

            if isinstance(while_body, Callable):
                while_body(**kwargs)
            elif isinstance(while_body, Graph):
                capture_launch(while_body, stream=stream)
            else:
                raise TypeError("while_body must be a callable or a graph")

        return

    graph.has_conditional = True

    # ensure conditional graph nodes are supported
    assert_conditional_graph_support()

    # validate the body before touching the graph being captured, so that a bad argument cannot
    # leave the capture paused or pointing at the body graph; an invalid type is reported with
    # the same TypeError as in the non-capture path
    body_is_callable = isinstance(while_body, Callable)
    if not body_is_callable:
        if not isinstance(while_body, Graph):
            raise TypeError("while_body must be a callable or a graph")
        if while_body.has_conditional:
            raise RuntimeError("The body graph contains conditional nodes, which are not allowed in child graphs")

    # insert conditional while-node
    body_graph = ctypes.c_void_p()
    cond_handle = ctypes.c_uint64()
    if not runtime.core.cuda_graph_insert_while(
        device.context,
        stream.cuda_stream,
        ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
        ctypes.byref(body_graph),
        ctypes.byref(cond_handle),
    ):
        raise RuntimeError(runtime.get_error_string())

    # pause capturing parent graph and start capturing child graph
    main_graph = capture_pause(stream=stream)
    # store the pointer to the cuda graph to restore it later
    main_graph_ptr = main_graph.graph

    # temporarily repurpose the main_graph python object such that all dependencies
    # added through retain_module_exec() end up in the correct python graph object
    main_graph.graph = body_graph
    capture_resume(main_graph, stream=stream)

    # capture while-body
    if body_is_callable:
        while_body(**kwargs)
    else:
        if not runtime.core.cuda_graph_insert_child_graph(
            device.context,
            stream.cuda_stream,
            while_body.graph,
        ):
            raise RuntimeError(runtime.get_error_string())

    # update condition
    if not runtime.core.cuda_graph_set_condition(
        device.context,
        stream.cuda_stream,
        ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
        cond_handle,
    ):
        raise RuntimeError(runtime.get_error_string())

    # stop capturing child graph and resume capturing parent graph
    capture_pause(stream=stream)
    # restore the main graph to its original state
    main_graph.graph = main_graph_ptr
    capture_resume(main_graph, stream=stream)
|
|
6770
|
+
|
|
6771
|
+
|
|
6772
|
+
def capture_launch(graph: Graph, stream: Stream | None = None):
|
|
6272
6773
|
"""Launch a previously captured CUDA graph
|
|
6273
6774
|
|
|
6274
6775
|
Args:
|
|
@@ -6284,6 +6785,15 @@ def capture_launch(graph: Graph, stream: Optional[Stream] = None):
|
|
|
6284
6785
|
device = graph.device
|
|
6285
6786
|
stream = device.stream
|
|
6286
6787
|
|
|
6788
|
+
if graph.graph_exec is None:
|
|
6789
|
+
g = ctypes.c_void_p()
|
|
6790
|
+
result = runtime.core.cuda_graph_create_exec(
|
|
6791
|
+
graph.device.context, stream.cuda_stream, graph.graph, ctypes.byref(g)
|
|
6792
|
+
)
|
|
6793
|
+
if not result:
|
|
6794
|
+
raise RuntimeError(f"Graph creation error: {runtime.get_error_string()}")
|
|
6795
|
+
graph.graph_exec = g
|
|
6796
|
+
|
|
6287
6797
|
if not runtime.core.cuda_graph_launch(graph.graph_exec, stream.cuda_stream):
|
|
6288
6798
|
raise RuntimeError(f"Graph launch error: {runtime.get_error_string()}")
|
|
6289
6799
|
|
|
@@ -6294,7 +6804,7 @@ def copy(
|
|
|
6294
6804
|
dest_offset: int = 0,
|
|
6295
6805
|
src_offset: int = 0,
|
|
6296
6806
|
count: int = 0,
|
|
6297
|
-
stream:
|
|
6807
|
+
stream: Stream | None = None,
|
|
6298
6808
|
):
|
|
6299
6809
|
"""Copy array contents from `src` to `dest`.
|
|
6300
6810
|
|
|
@@ -6431,11 +6941,8 @@ def copy(
|
|
|
6431
6941
|
|
|
6432
6942
|
# can't copy to/from fabric arrays of arrays, because they are jagged arrays of arbitrary lengths
|
|
6433
6943
|
# TODO?
|
|
6434
|
-
if (
|
|
6435
|
-
isinstance(
|
|
6436
|
-
and src.ndim > 1
|
|
6437
|
-
or isinstance(dest, (warp.fabricarray, warp.indexedfabricarray))
|
|
6438
|
-
and dest.ndim > 1
|
|
6944
|
+
if (isinstance(src, (warp.fabricarray, warp.indexedfabricarray)) and src.ndim > 1) or (
|
|
6945
|
+
isinstance(dest, (warp.fabricarray, warp.indexedfabricarray)) and dest.ndim > 1
|
|
6439
6946
|
):
|
|
6440
6947
|
raise RuntimeError("Copying to/from Fabric arrays of arrays is not supported")
|
|
6441
6948
|
|
|
@@ -6503,7 +7010,7 @@ def type_str(t):
|
|
|
6503
7010
|
return "Callable"
|
|
6504
7011
|
elif isinstance(t, int):
|
|
6505
7012
|
return str(t)
|
|
6506
|
-
elif isinstance(t, List):
|
|
7013
|
+
elif isinstance(t, (List, tuple)):
|
|
6507
7014
|
return "Tuple[" + ", ".join(map(type_str, t)) + "]"
|
|
6508
7015
|
elif isinstance(t, warp.array):
|
|
6509
7016
|
return f"Array[{type_str(t.dtype)}]"
|
|
@@ -6536,12 +7043,16 @@ def type_str(t):
|
|
|
6536
7043
|
|
|
6537
7044
|
raise TypeError("Invalid vector or matrix dimensions")
|
|
6538
7045
|
elif get_origin(t) in (list, tuple):
|
|
6539
|
-
|
|
6540
|
-
|
|
7046
|
+
args = get_args(t)
|
|
7047
|
+
if args:
|
|
7048
|
+
args_repr = ", ".join(type_str(x) for x in get_args(t))
|
|
7049
|
+
return f"{t._name}[{args_repr}]"
|
|
7050
|
+
else:
|
|
7051
|
+
return f"{t._name}"
|
|
6541
7052
|
elif t is Ellipsis:
|
|
6542
7053
|
return "..."
|
|
6543
7054
|
elif warp.types.is_tile(t):
|
|
6544
|
-
return "Tile"
|
|
7055
|
+
return f"Tile[{type_str(t.dtype)},{type_str(t.shape)}]"
|
|
6545
7056
|
|
|
6546
7057
|
return t.__name__
|
|
6547
7058
|
|
|
@@ -6568,14 +7079,14 @@ def resolve_exported_function_sig(f):
|
|
|
6568
7079
|
# so we can generate the return type for overloaded functions
|
|
6569
7080
|
return_type = f.value_func(func_args, None)
|
|
6570
7081
|
|
|
7082
|
+
if return_type is None or (isinstance(return_type, tuple) and len(return_type) > 1):
|
|
7083
|
+
return (func_args, return_type)
|
|
7084
|
+
|
|
6571
7085
|
try:
|
|
6572
|
-
|
|
7086
|
+
ctype_ret_str(return_type)
|
|
6573
7087
|
except Exception:
|
|
6574
7088
|
return None
|
|
6575
7089
|
|
|
6576
|
-
if return_type_str.startswith("Tuple"):
|
|
6577
|
-
return None
|
|
6578
|
-
|
|
6579
7090
|
return (func_args, return_type)
|
|
6580
7091
|
|
|
6581
7092
|
|
|
@@ -6716,13 +7227,18 @@ def export_functions_rst(file): # pragma: no cover
|
|
|
6716
7227
|
print("---------------", file=file)
|
|
6717
7228
|
|
|
6718
7229
|
for f, is_exported in g:
|
|
7230
|
+
if not isinstance(f, Function) and callable(f):
|
|
7231
|
+
# f is a plain Python function
|
|
7232
|
+
print(f".. autofunction:: {f.__module__}.{f.__name__}", file=file)
|
|
7233
|
+
continue
|
|
6719
7234
|
if f.func:
|
|
6720
7235
|
# f is a Warp function written in Python, we can use autofunction
|
|
6721
7236
|
print(f".. autofunction:: {f.func.__module__}.{f.key}", file=file)
|
|
6722
7237
|
continue
|
|
6723
7238
|
for f_prefix, query_type in query_types:
|
|
6724
7239
|
if f.key.startswith(f_prefix) and query_type not in written_query_types:
|
|
6725
|
-
print(f".. autoclass:: {query_type}", file=file)
|
|
7240
|
+
print(f".. autoclass:: warp.{query_type}", file=file)
|
|
7241
|
+
print(" :exclude-members: Var, vars", file=file)
|
|
6726
7242
|
written_query_types.add(query_type)
|
|
6727
7243
|
break
|
|
6728
7244
|
|
|
@@ -6775,6 +7291,7 @@ def export_stubs(file): # pragma: no cover
|
|
|
6775
7291
|
print('Rows = TypeVar("Rows", bound=int)', file=file)
|
|
6776
7292
|
print('Cols = TypeVar("Cols", bound=int)', file=file)
|
|
6777
7293
|
print('DType = TypeVar("DType")', file=file)
|
|
7294
|
+
print('Shape = TypeVar("Shape")', file=file)
|
|
6778
7295
|
|
|
6779
7296
|
print("Vector = Generic[Length, Scalar]", file=file)
|
|
6780
7297
|
print("Matrix = Generic[Rows, Cols, Scalar]", file=file)
|
|
@@ -6783,6 +7300,7 @@ def export_stubs(file): # pragma: no cover
|
|
|
6783
7300
|
print("Array = Generic[DType]", file=file)
|
|
6784
7301
|
print("FabricArray = Generic[DType]", file=file)
|
|
6785
7302
|
print("IndexedFabricArray = Generic[DType]", file=file)
|
|
7303
|
+
print("Tile = Generic[DType, Shape]", file=file)
|
|
6786
7304
|
|
|
6787
7305
|
# prepend __init__.py
|
|
6788
7306
|
with open(os.path.join(os.path.dirname(file.name), "__init__.py")) as header_file:
|
|
@@ -6817,7 +7335,7 @@ def export_stubs(file): # pragma: no cover
|
|
|
6817
7335
|
if hasattr(g, "overloads"):
|
|
6818
7336
|
for f in g.overloads:
|
|
6819
7337
|
add_stub(f)
|
|
6820
|
-
|
|
7338
|
+
elif isinstance(g, Function):
|
|
6821
7339
|
add_stub(g)
|
|
6822
7340
|
|
|
6823
7341
|
|
|
@@ -6848,16 +7366,30 @@ def export_builtins(file: io.TextIOBase): # pragma: no cover
|
|
|
6848
7366
|
args = ", ".join(f"{ctype_arg_str(v)} {k}" for k, v in func_args.items())
|
|
6849
7367
|
params = ", ".join(func_args.keys())
|
|
6850
7368
|
|
|
6851
|
-
|
|
6852
|
-
|
|
6853
|
-
if args == "":
|
|
6854
|
-
file.write(f"WP_API void {f.mangled_name}({return_str}* ret) {{ *ret = wp::{f.key}({params}); }}\n")
|
|
6855
|
-
elif return_type is None:
|
|
7369
|
+
if return_type is None:
|
|
7370
|
+
# void function
|
|
6856
7371
|
file.write(f"WP_API void {f.mangled_name}({args}) {{ wp::{f.key}({params}); }}\n")
|
|
7372
|
+
elif isinstance(return_type, tuple) and len(return_type) > 1:
|
|
7373
|
+
# multiple return value function using output parameters
|
|
7374
|
+
outputs = tuple(f"{ctype_ret_str(x)}& ret_{i}" for i, x in enumerate(return_type))
|
|
7375
|
+
output_params = ", ".join(f"ret_{i}" for i in range(len(outputs)))
|
|
7376
|
+
if args:
|
|
7377
|
+
file.write(
|
|
7378
|
+
f"WP_API void {f.mangled_name}({args}, {', '.join(outputs)}) {{ wp::{f.key}({params}, {output_params}); }}\n"
|
|
7379
|
+
)
|
|
7380
|
+
else:
|
|
7381
|
+
file.write(
|
|
7382
|
+
f"WP_API void {f.mangled_name}({', '.join(outputs)}) {{ wp::{f.key}({params}, {output_params}); }}\n"
|
|
7383
|
+
)
|
|
6857
7384
|
else:
|
|
6858
|
-
|
|
6859
|
-
|
|
6860
|
-
|
|
7385
|
+
# single return value function
|
|
7386
|
+
return_str = ctype_ret_str(return_type)
|
|
7387
|
+
if args:
|
|
7388
|
+
file.write(
|
|
7389
|
+
f"WP_API void {f.mangled_name}({args}, {return_str}* ret) {{ *ret = wp::{f.key}({params}); }}\n"
|
|
7390
|
+
)
|
|
7391
|
+
else:
|
|
7392
|
+
file.write(f"WP_API void {f.mangled_name}({return_str}* ret) {{ *ret = wp::{f.key}({params}); }}\n")
|
|
6861
7393
|
|
|
6862
7394
|
file.write('\n} // extern "C"\n\n')
|
|
6863
7395
|
file.write("} // namespace wp\n")
|