warp-lang 1.7.2rc1__py3-none-win_amd64.whl → 1.8.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of warp-lang might be problematic.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +125 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +257 -101
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +657 -223
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +97 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +107 -52
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +12 -17
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +99 -0
- warp/native/builtin.h +174 -31
- warp/native/coloring.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +3 -3
- warp/native/mat.h +5 -10
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/quat.h +28 -4
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/tile.h +583 -72
- warp/native/tile_radix_sort.h +1108 -0
- warp/native/tile_reduce.h +237 -2
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +6 -16
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +574 -51
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +58 -29
- warp/render/render_usd.py +124 -61
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +252 -78
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +751 -320
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +52 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +15 -1
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_atomic_cas.py +299 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +1 -24
- warp/tests/test_quat.py +6 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +51 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/tile/test_tile.py +420 -1
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_reduce.py +213 -0
- warp/tests/tile/test_tile_shared_memory.py +130 -1
- warp/tests/tile/test_tile_sort.py +117 -0
- warp/tests/unittest_suites.py +4 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/context.py
CHANGED
@@ -32,22 +32,7 @@ import typing
 import weakref
 from copy import copy as shallowcopy
 from pathlib import Path
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    Literal,
-    Mapping,
-    Optional,
-    Sequence,
-    Set,
-    Tuple,
-    TypeVar,
-    Union,
-    get_args,
-    get_origin,
-)
+from typing import Any, Callable, Dict, List, Literal, Mapping, Sequence, Tuple, TypeVar, Union, get_args, get_origin

 import numpy as np

@@ -84,7 +69,7 @@ def get_function_args(func):
 complex_type_hints = (Any, Callable, Tuple)
 sequence_types = (list, tuple)

-function_key_counts:
+function_key_counts: dict[str, int] = {}


 def generate_unique_function_identifier(key: str) -> str:
@@ -120,40 +105,41 @@ def generate_unique_function_identifier(key: str) -> str:
 class Function:
     def __init__(
         self,
-        func:
+        func: Callable | None,
         key: str,
         namespace: str,
-        input_types:
-        value_type:
-        value_func:
-        export_func:
-        dispatch_func:
-        lto_dispatch_func:
-        module:
+        input_types: dict[str, type | TypeVar] | None = None,
+        value_type: type | None = None,
+        value_func: Callable[[Mapping[str, type], Mapping[str, Any]], type] | None = None,
+        export_func: Callable[[dict[str, type]], dict[str, type]] | None = None,
+        dispatch_func: Callable | None = None,
+        lto_dispatch_func: Callable | None = None,
+        module: Module | None = None,
         variadic: bool = False,
-        initializer_list_func:
+        initializer_list_func: Callable[[dict[str, Any], type], bool] | None = None,
         export: bool = False,
+        source: str | None = None,
         doc: str = "",
         group: str = "",
         hidden: bool = False,
         skip_replay: bool = False,
         missing_grad: bool = False,
         generic: bool = False,
-        native_func:
-        defaults:
-        custom_replay_func:
-        native_snippet:
-        adj_native_snippet:
-        replay_snippet:
+        native_func: str | None = None,
+        defaults: dict[str, Any] | None = None,
+        custom_replay_func: Function | None = None,
+        native_snippet: str | None = None,
+        adj_native_snippet: str | None = None,
+        replay_snippet: str | None = None,
         skip_forward_codegen: bool = False,
         skip_reverse_codegen: bool = False,
         custom_reverse_num_input_args: int = -1,
         custom_reverse_mode: bool = False,
-        overloaded_annotations:
-        code_transformers:
+        overloaded_annotations: dict[str, type] | None = None,
+        code_transformers: list[ast.NodeTransformer] | None = None,
         skip_adding_overload: bool = False,
        require_original_output_arg: bool = False,
-        scope_locals:
+        scope_locals: dict[str, Any] | None = None,
     ):
         if code_transformers is None:
             code_transformers = []
@@ -178,7 +164,7 @@ class Function:
         self.native_snippet = native_snippet
         self.adj_native_snippet = adj_native_snippet
         self.replay_snippet = replay_snippet
-        self.custom_grad_func:
+        self.custom_grad_func: Function | None = None
         self.require_original_output_arg = require_original_output_arg
         self.generic_parent = None  # generic function that was used to instantiate this overload

@@ -194,7 +180,7 @@ class Function:
         )
         self.missing_grad = missing_grad  # whether builtin is missing a corresponding adjoint
         self.generic = generic
-        self.mangled_name:
+        self.mangled_name: str | None = None

         # allow registering functions with a different name in Python and native code
         if native_func is None:
@@ -211,12 +197,13 @@ class Function:
             # user-defined function

             # generic and concrete overload lookups by type signature
-            self.user_templates:
-            self.user_overloads:
+            self.user_templates: dict[str, Function] = {}
+            self.user_overloads: dict[str, Function] = {}

             # user defined (Python) function
             self.adj = warp.codegen.Adjoint(
                 func,
+                source=source,
                 is_user_function=True,
                 skip_forward_codegen=skip_forward_codegen,
                 skip_reverse_codegen=skip_reverse_codegen,
@@ -244,7 +231,7 @@ class Function:

             # embedded linked list of all overloads
             # the builtin_functions dictionary holds the list head for a given key (func name)
-            self.overloads:
+            self.overloads: list[Function] = []

             # builtin (native) function, canonicalize argument types
             if input_types is not None:
@@ -293,10 +280,11 @@ class Function:
             module.register_function(self, scope_locals, skip_adding_overload)

     def __call__(self, *args, **kwargs):
-
-
-
-
+        """Call this function from the CPython interpreter.
+
+        This is used to call built-in or user functions from the CPython
+        interpreter, rather than from within a kernel.
+        """

         if self.is_builtin() and self.mangled_name:
             # For each of this function's existing overloads, we attempt to pack
@@ -306,7 +294,23 @@ class Function:
                 if overload.generic:
                     continue

-
+                try:
+                    # Try to bind the given arguments to the function's signature.
+                    # This is not checking whether the argument types are matching,
+                    # rather it's just assigning each argument to the corresponding
+                    # function parameter.
+                    bound_args = self.signature.bind(*args, **kwargs)
+                except TypeError:
+                    continue
+
+                if self.defaults:
+                    # Populate the bound arguments with any default values.
+                    default_args = {k: v for k, v in self.defaults.items() if k not in bound_args.arguments}
+                    warp.codegen.apply_defaults(bound_args, default_args)
+
+                bound_args = tuple(bound_args.arguments.values())
+
+                success, return_value = call_builtin(overload, bound_args)
                 if success:
                     return return_value

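Note: the updated `Function.__call__` above binds the given arguments against the built-in's Python signature and applies any registered defaults before packing them into ctypes and dispatching through `call_builtin`. A minimal sketch of calling exported built-ins from the CPython interpreter (which built-ins are exported depends on the Warp build):

    import warp as wp

    wp.init()

    v = wp.vec3(1.0, 2.0, 3.0)

    # Dispatched through Function.__call__: arguments are bound to the
    # built-in's signature, defaults are applied, then packed into ctypes.
    print(wp.length(v))
    print(wp.dot(v, v))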
@@ -324,6 +328,9 @@ class Function:

         arguments = tuple(bound_args.arguments.values())

+        # Store the last runtime error we encountered from a function execution
+        last_execution_error = None
+
         # try and find a matching overload
         for overload in self.user_overloads.values():
             if len(overload.input_types) != len(arguments):
@@ -334,10 +341,25 @@ class Function:
                 # attempt to unify argument types with function template types
                 warp.types.infer_argument_types(arguments, template_types, arg_names)
                 return overload.func(*arguments)
-            except Exception:
+            except Exception as e:
+                # The function was callable but threw an error during its execution.
+                # This might be the intended overload, but it failed, or it might be the wrong overload.
+                # We save this specific error and continue, just in case another overload later in the
+                # list is a better match and doesn't fail.
+                last_execution_error = e
                 continue

-
+        if last_execution_error:
+            # Raise a new, more contextual RuntimeError, but link it to the
+            # original error that was caught. This preserves the original
+            # traceback and error type for easier debugging.
+            raise RuntimeError(
+                f"Error calling function '{self.key}'. No version succeeded. "
+                f"See above for the error from the last version that was tried."
+            ) from last_execution_error
+        else:
+            # We got here without ever calling an overload.func
+            raise RuntimeError(f"Error calling function '{self.key}', no overload found for arguments {args}")

         # user-defined function with no overloads
         if self.func is None:
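Note: with the change above, a user function overload that raises while being tried from Python no longer fails silently; the last exception is re-raised chained to a summary RuntimeError. A small illustrative sketch (the function names are hypothetical, not part of this diff):

    import warp as wp

    @wp.func
    def scale(x: float, s: float):
        return x * s

    @wp.func
    def scale(x: wp.vec3, s: float):
        return x * s

    # Calling a user function from the CPython interpreter tries each overload
    # in turn; if every candidate raises, the last error is re-raised chained to
    # "Error calling function ... No version succeeded."
    print(scale(2.0, 3.0))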
@@ -358,9 +380,6 @@ class Function:
             if warp.types.is_array(v) or v in complex_type_hints:
                 return False

-        if type(self.value_type) in sequence_types:
-            return False
-
         return True

     def mangle(self) -> str:
@@ -404,8 +423,12 @@ class Function:
         else:
             self.user_overloads[sig] = f

-    def get_overload(self, arg_types:
-
+    def get_overload(self, arg_types: list[type], kwarg_types: Mapping[str, type]) -> Function | None:
+        if self.is_builtin():
+            for f in self.overloads:
+                if warp.codegen.func_match_args(f, arg_types, kwarg_types):
+                    return f
+            return None

         for f in self.user_overloads.values():
             if warp.codegen.func_match_args(f, arg_types, kwarg_types):
@@ -439,7 +462,7 @@ class Function:
             overload_annotations[k] = warp.codegen.strip_reference(warp.codegen.get_arg_type(d))

         ovl = shallowcopy(f)
-        ovl.adj = warp.codegen.Adjoint(f.func, overload_annotations)
+        ovl.adj = warp.codegen.Adjoint(f.func, overload_annotations, source=f.adj.source)
         ovl.input_types = overload_annotations
         ovl.value_func = None
         ovl.generic_parent = f
@@ -475,11 +498,25 @@ def get_builtin_type(return_type: type) -> type:
     return return_type


-def
+def extract_return_value(value_type: type, value_ctype: type, ret: Any) -> Any:
+    if issubclass(value_ctype, ctypes.Array) or issubclass(value_ctype, ctypes.Structure):
+        # return vector types as ctypes
+        return ret
+
+    if value_type is warp.types.float16:
+        return warp.types.half_bits_to_float(ret.value)
+
+    return ret.value
+
+
+def call_builtin(func: Function, params: tuple) -> tuple[bool, Any]:
     uses_non_warp_array_type = False

     init()

+    if func.mangled_name is None:
+        return (False, None)
+
     # Retrieve the built-in function from Warp's dll.
     c_func = getattr(warp.context.runtime.core, func.mangled_name)

@@ -489,6 +526,8 @@ def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
     else:
         func_args = func.input_types

+    value_type = func.value_func(None, None)
+
     # Try gathering the parameters that the function expects and pack them
     # into their corresponding C types.
     c_params = []
@@ -604,9 +643,9 @@ def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:

             if not (
                 isinstance(param, arg_type)
-                or (type(param) is float and arg_type is warp.types.float32)
-                or (type(param) is int and arg_type is warp.types.int32)
-                or (type(param) is bool and arg_type is warp.types.bool)
+                or (type(param) is float and arg_type is warp.types.float32)
+                or (type(param) is int and arg_type is warp.types.int32)
+                or (type(param) is bool and arg_type is warp.types.bool)
                 or warp.types.np_dtype_to_warp_type.get(getattr(param, "dtype", None)) is arg_type
             ):
                 return (False, None)
@@ -620,25 +659,18 @@ def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
             else:
                 c_params.append(arg_type._type_(param))

-    #
-    value_type = func.value_func(
+    # Retrieve the return type.
+    value_type = func.value_func(func_args, None)

-    if value_type
-
-
-
-
-
-        value_ctype = value_type
-    else:
-        # scalar type
-        value_ctype = value_type._type_
+    if value_type is not None:
+        if not isinstance(value_type, Sequence):
+            value_type = (value_type,)
+
+        value_ctype = tuple(warp.types.type_ctype(x) for x in value_type)
+        ret = tuple(x() for x in value_ctype)
+        ret_addr = tuple(ctypes.c_void_p(ctypes.addressof(x)) for x in ret)

-
-        ret = value_ctype()
-        ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
-        c_params.append(ret_addr)
+        c_params.extend(ret_addr)

     # Call the built-in function from Warp's dll.
     c_func(*c_params)
@@ -653,17 +685,14 @@ def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
                 stacklevel=3,
             )

-    if
-
-    return (True, ret)
+    if value_type is None:
+        return (True, None)

-
-
-
-    value = ret.value
+    return_value = tuple(extract_return_value(x, y, z) for x, y, z in zip(value_type, value_ctype, ret))
+    if len(return_value) == 1:
+        return_value = return_value[0]

-
-    return (True, value)
+    return (True, return_value)


 class KernelHooks:
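Note: `call_builtin` now supports built-ins with multiple return values by allocating one ctypes slot per return type, appending one output address per slot, and unpacking each slot after the call. A self-contained ctypes sketch of that same pattern, using a made-up stand-in for the native function rather than Warp's DLL:

    import ctypes

    def fake_native_call(out_a_addr, out_b_addr):
        # Stand-in for a native function that writes results through void* slots.
        ctypes.cast(out_a_addr, ctypes.POINTER(ctypes.c_float))[0] = 1.5
        ctypes.cast(out_b_addr, ctypes.POINTER(ctypes.c_int))[0] = 7

    # One ctypes slot per return value, then one address per slot is appended
    # to the call parameters, mirroring c_params.extend(ret_addr) above.
    value_ctype = (ctypes.c_float, ctypes.c_int)
    ret = tuple(t() for t in value_ctype)
    ret_addr = tuple(ctypes.c_void_p(ctypes.addressof(r)) for r in ret)

    fake_native_call(*ret_addr)

    return_value = tuple(r.value for r in ret)
    print(return_value)  # (1.5, 7)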
@@ -677,7 +706,7 @@ class KernelHooks:

 # caches source and compiled entry points for a kernel (will be populated after module loads)
 class Kernel:
-    def __init__(self, func, key=None, module=None, options=None, code_transformers=None):
+    def __init__(self, func, key=None, module=None, options=None, code_transformers=None, source=None):
         self.func = func

         if module is None:
@@ -695,7 +724,7 @@ class Kernel:
         if code_transformers is None:
             code_transformers = []

-        self.adj = warp.codegen.Adjoint(func, transformers=code_transformers)
+        self.adj = warp.codegen.Adjoint(func, transformers=code_transformers, source=source)

         # check if generic
         self.is_generic = False
@@ -762,7 +791,7 @@ class Kernel:

         # instantiate this kernel with the given argument types
         ovl = shallowcopy(self)
-        ovl.adj = warp.codegen.Adjoint(self.func, overload_annotations)
+        ovl.adj = warp.codegen.Adjoint(self.func, overload_annotations, source=self.adj.source)
         ovl.is_generic = False
         ovl.overloads = {}
         ovl.sig = sig
@@ -798,7 +827,7 @@ class Kernel:


 # decorator to register function, @func
-def func(f:
+def func(f: Callable | None = None, *, name: str | None = None):
     def wrapper(f, *args, **kwargs):
         if name is None:
             key = warp.codegen.make_full_qualified_name(f)
@@ -831,7 +860,7 @@ def func(f: Optional[Callable] = None, *, name: Optional[str] = None):
         return wrapper(f)


-def func_native(snippet: str, adj_snippet:
+def func_native(snippet: str, adj_snippet: str | None = None, replay_snippet: str | None = None):
     """
     Decorator to register native code snippet, @func_native
     """
@@ -1015,10 +1044,10 @@ def func_replay(forward_fn):


 def kernel(
-    f:
+    f: Callable | None = None,
     *,
-    enable_backward:
-    module:
+    enable_backward: bool | None = None,
+    module: Module | Literal["unique"] | None = None,
 ):
     """
     Decorator to register a Warp kernel from a Python function.
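Note: the `kernel` decorator now spells `module` as `Module | Literal["unique"] | None`. A minimal sketch of the keyword form of the decorator (the kernel itself is illustrative):

    import warp as wp

    @wp.kernel(enable_backward=False, module="unique")
    def scale_array(values: wp.array(dtype=float), s: float):
        tid = wp.tid()
        values[tid] = values[tid] * s

    values = wp.zeros(8, dtype=float)
    wp.launch(scale_array, dim=values.shape, inputs=[values, 2.0])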
@@ -1181,7 +1210,7 @@ def overload(kernel, arg_types=Union[None, Dict[str, Any], List[Any]]):


 # native functions that are part of the Warp API
-builtin_functions:
+builtin_functions: dict[str, Function] = {}


 def get_generic_vtypes():
@@ -1204,13 +1233,13 @@ scalar_types.update({x: x._wp_scalar_type_ for x in warp.types.vector_types})

 def add_builtin(
     key: str,
-    input_types:
-    constraint:
-    value_type:
-    value_func:
-    export_func:
-    dispatch_func:
-    lto_dispatch_func:
+    input_types: dict[str, type | TypeVar] | None = None,
+    constraint: Callable[[Mapping[str, type]], bool] | None = None,
+    value_type: type | None = None,
+    value_func: Callable | None = None,
+    export_func: Callable | None = None,
+    dispatch_func: Callable | None = None,
+    lto_dispatch_func: Callable | None = None,
     doc: str = "",
     namespace: str = "wp::",
     variadic: bool = False,
@@ -1220,8 +1249,8 @@ def add_builtin(
     hidden: bool = False,
     skip_replay: bool = False,
     missing_grad: bool = False,
-    native_func:
-    defaults:
+    native_func: str | None = None,
+    defaults: dict[str, Any] | None = None,
     require_original_output_arg: bool = False,
 ):
     """Main entry point to register a new built-in function.
@@ -1371,18 +1400,13 @@ def add_builtin(

         return_type = value_func(concrete_arg_types, None)

-
-
-
-
-
-
-
-            and x._wp_type_params_ == return_type._wp_type_params_
-        )
-        if not return_type_match:
-            continue
-        return_type = return_type_match[0]
+        try:
+            if isinstance(return_type, Sequence):
+                return_type = tuple(get_builtin_type(x) for x in return_type)
+            else:
+                return_type = get_builtin_type(return_type)
+        except RuntimeError:
+            continue

         # finally we can generate a function call for these concrete types:
         add_builtin(
@@ -1485,7 +1509,7 @@ def register_api_function(


 # global dictionary of modules
-user_modules:
+user_modules: dict[str, Module] = {}


 def get_module(name: str) -> Module:
@@ -1608,7 +1632,7 @@ class ModuleHasher:
             ch.update(bytes(func.key, "utf-8"))

             # include all concrete and generic overloads
-            overloads:
+            overloads: dict[str, Function] = {**func.user_overloads, **func.user_templates}
             for sig in sorted(overloads.keys()):
                 ovl = overloads[sig]

@@ -1857,7 +1881,7 @@ class ModuleBuilder:
 # the original Modules get reloaded.
 class ModuleExec:
     def __new__(cls, *args, **kwargs):
-        instance = super(
+        instance = super().__new__(cls)
         instance.handle = None
         return instance

@@ -1952,7 +1976,7 @@ class ModuleExec:
 # creates a hash of the function to use for checking
 # build cache
 class Module:
-    def __init__(self, name:
+    def __init__(self, name: str | None, loader=None):
         self.name = name if name is not None else "None"

         self.loader = loader
@@ -1996,6 +2020,7 @@ class Module:
             "cuda_output": None,  # supported values: "ptx", "cubin", or None (automatic)
             "mode": warp.config.mode,
             "block_dim": 256,
+            "compile_time_trace": warp.config.compile_time_trace,
         }

         # Module dependencies are determined by scanning each function
@@ -2222,7 +2247,7 @@ class Module:
         ):
             builder_options = {
                 **self.options,
-                # Some of the
+                # Some of the tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
                 "output_arch": output_arch,
             }
             builder = ModuleBuilder(self, builder_options, hasher=self.hashers[active_block_dim])
@@ -2291,6 +2316,7 @@ class Module:
                 fast_math=self.options["fast_math"],
                 fuse_fp=self.options["fuse_fp"],
                 lineinfo=self.options["lineinfo"],
+                compile_time_trace=self.options["compile_time_trace"],
                 ltoirs=builder.ltoirs.values(),
                 fatbins=builder.fatbins.values(),
             )
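Note: `compile_time_trace` is threaded from `warp.config` into the per-module option dictionary and down to the native `cuda_compile_program` call. A minimal sketch of enabling it, assuming the option is surfaced exactly as shown in this diff and a CUDA device is present:

    import warp as wp

    # Global default picked up when a module's option dict is created.
    wp.config.compile_time_trace = True

    # Or per module, through the regular module-option mechanism.
    wp.set_module_options({"compile_time_trace": True})

    @wp.kernel
    def noop():
        pass

    wp.load_module(device="cuda:0")  # triggers compilation with the trace flag set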
@@ -2343,7 +2369,7 @@ class Module:
         # Load CPU or CUDA binary

         meta_path = os.path.join(module_dir, f"{module_name_short}.meta")
-        with open(meta_path
+        with open(meta_path) as meta_file:
             meta = json.load(meta_file)

         if device.is_cpu:
@@ -2406,7 +2432,7 @@ class CpuDefaultAllocator:
     def alloc(self, size_in_bytes):
         ptr = runtime.core.alloc_host(size_in_bytes)
         if not ptr:
-            raise RuntimeError(f"Failed to allocate {size_in_bytes} bytes on device '
+            raise RuntimeError(f"Failed to allocate {size_in_bytes} bytes on device 'cpu'")
         return ptr

     def free(self, ptr, size_in_bytes):
@@ -2510,12 +2536,12 @@ class Event:

     def __new__(cls, *args, **kwargs):
         """Creates a new event instance."""
-        instance = super(
+        instance = super().__new__(cls)
         instance.owner = False
         return instance

     def __init__(
-        self, device:
+        self, device: Devicelike = None, cuda_event=None, enable_timing: bool = False, interprocess: bool = False
     ):
         """Initializes the event on a CUDA device.

@@ -2611,12 +2637,12 @@ class Event:

 class Stream:
     def __new__(cls, *args, **kwargs):
-        instance = super(
+        instance = super().__new__(cls)
         instance.cuda_stream = None
         instance.owner = False
         return instance

-    def __init__(self, device:
+    def __init__(self, device: Device | str | None = None, priority: int = 0, **kwargs):
         """Initialize the stream on a device with an optional specified priority.

         Args:
@@ -2682,7 +2708,7 @@ class Stream:
             self._cached_event = Event(self.device)
         return self._cached_event

-    def record_event(self, event:
+    def record_event(self, event: Event | None = None) -> Event:
         """Record an event onto the stream.

         Args:
@@ -2711,7 +2737,7 @@ class Stream:
         """
         runtime.core.cuda_stream_wait_event(self.cuda_stream, event.cuda_event)

-    def wait_stream(self, other_stream:
+    def wait_stream(self, other_stream: Stream, event: Event | None = None):
         """Records an event on `other_stream` and makes this stream wait on it.

         All work added to this stream after this function has been called will
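Note: the `Stream` and `Event` signatures above only change their annotations; usage is unchanged. A minimal sketch, assuming a CUDA device is available:

    import warp as wp

    with wp.ScopedDevice("cuda:0"):
        stream = wp.Stream(priority=-1)    # optional stream priority on the scoped device
        ev = stream.record_event()         # record_event() returns the Event it recorded
        wp.get_stream().wait_event(ev)     # make the current stream wait on that event
        wp.get_stream().wait_stream(stream)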
@@ -2765,6 +2791,8 @@ class Device:
             or ``"CPU"`` if the processor name cannot be determined.
         arch (int): The compute capability version number calculated as ``10 * major + minor``.
             ``0`` for CPU devices.
+        sm_count (int): The number of streaming multiprocessors on the CUDA device.
+            ``0`` for CPU devices.
         is_uva (bool): Indicates whether the device supports unified addressing.
             ``False`` for CPU devices.
         is_cubin_supported (bool): Indicates whether Warp's version of NVRTC can directly
@@ -2810,6 +2838,7 @@ class Device:
             # CPU device
             self.name = platform.processor() or "CPU"
             self.arch = 0
+            self.sm_count = 0
             self.is_uva = False
             self.is_mempool_supported = False
             self.is_mempool_enabled = False
@@ -2829,6 +2858,7 @@ class Device:
             # CUDA device
             self.name = runtime.core.cuda_device_get_name(ordinal).decode()
             self.arch = runtime.core.cuda_device_get_arch(ordinal)
+            self.sm_count = runtime.core.cuda_device_get_sm_count(ordinal)
             self.is_uva = runtime.core.cuda_device_is_uva(ordinal) > 0
             self.is_mempool_supported = runtime.core.cuda_device_is_mempool_supported(ordinal) > 0
             if platform.system() == "Linux":
@@ -3070,16 +3100,23 @@ class Graph:
     def __init__(self, device: Device, capture_id: int):
         self.device = device
         self.capture_id = capture_id
-        self.module_execs:
-        self.graph_exec:
+        self.module_execs: set[ModuleExec] = set()
+        self.graph_exec: ctypes.c_void_p | None = None
+
+        self.graph: ctypes.c_void_p | None = None
+        self.has_conditional = (
+            False  # Track if there are conditional nodes in the graph since they are not allowed in child graphs
+        )

     def __del__(self):
-        if not hasattr(self, "
+        if not hasattr(self, "graph") or not hasattr(self, "device") or not self.graph:
            return

         # use CUDA context guard to avoid side effects during garbage collection
         with self.device.context_guard:
-            runtime.core.cuda_graph_destroy(self.device.context, self.
+            runtime.core.cuda_graph_destroy(self.device.context, self.graph)
+            if hasattr(self, "graph_exec") and self.graph_exec is not None:
+                runtime.core.cuda_graph_exec_destroy(self.device.context, self.graph_exec)

     # retain executable CUDA modules used by this graph, which prevents them from being unloaded
     def retain_module_exec(self, module_exec: ModuleExec):
@@ -3088,8 +3125,6 @@ class Graph:

 class Runtime:
     def __init__(self):
-        if sys.version_info < (3, 8):
-            raise RuntimeError("Warp requires Python 3.8 as a minimum")
         if sys.version_info < (3, 9):
             warp.utils.warn(f"Python 3.9 or newer is recommended for running Warp, detected {sys.version_info}")

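Note: the new `Device.sm_count` attribute mirrors `cuda_device_get_sm_count()` and is `0` on CPU devices. Querying it:

    import warp as wp

    wp.init()

    for d in wp.get_cuda_devices():
        # sm_count is populated from cuda_device_get_sm_count() at startup.
        print(d.alias, d.name, d.arch, d.sm_count)

    print(wp.get_device("cpu").sm_count)  # CPU devices report 0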
@@ -3535,44 +3570,40 @@ class Runtime:
         self.core.volume_get_blind_data_info.restype = ctypes.c_char_p

         bsr_matrix_from_triplets_argtypes = [
-            ctypes.c_int,  #
-            ctypes.c_int,  #
+            ctypes.c_int,  # block_size
+            ctypes.c_int,  # scalar size in bytes
             ctypes.c_int,  # row_count
-            ctypes.c_int,  #
+            ctypes.c_int,  # col_count
+            ctypes.c_int,  # nnz_upper_bound
+            ctypes.POINTER(ctypes.c_int),  # tpl_nnz
             ctypes.POINTER(ctypes.c_int),  # tpl_rows
             ctypes.POINTER(ctypes.c_int),  # tpl_cols
             ctypes.c_void_p,  # tpl_values
-            ctypes.
+            ctypes.c_uint64,  # zero_value_mask
             ctypes.c_bool,  # masked
             ctypes.POINTER(ctypes.c_int),  # bsr_offsets
             ctypes.POINTER(ctypes.c_int),  # bsr_columns
-            ctypes.
+            ctypes.POINTER(ctypes.c_int),  # prefix sum of block count to sum for each bsr block
+            ctypes.POINTER(ctypes.c_int),  # indices to ptriplet blocks to sum for each bsr block
             ctypes.POINTER(ctypes.c_int),  # bsr_nnz
             ctypes.c_void_p,  # bsr_nnz_event
         ]

-        self.core.
-        self.core.
-        self.core.bsr_matrix_from_triplets_float_device.argtypes = bsr_matrix_from_triplets_argtypes
-        self.core.bsr_matrix_from_triplets_double_device.argtypes = bsr_matrix_from_triplets_argtypes
+        self.core.bsr_matrix_from_triplets_host.argtypes = bsr_matrix_from_triplets_argtypes
+        self.core.bsr_matrix_from_triplets_device.argtypes = bsr_matrix_from_triplets_argtypes

         bsr_transpose_argtypes = [
-            ctypes.c_int,  # rows_per_bock
-            ctypes.c_int,  # cols_per_blocks
             ctypes.c_int,  # row_count
             ctypes.c_int,  # col count
             ctypes.c_int,  # nnz
             ctypes.POINTER(ctypes.c_int),  # transposed_bsr_offsets
             ctypes.POINTER(ctypes.c_int),  # transposed_bsr_columns
-            ctypes.c_void_p,  # bsr_values
             ctypes.POINTER(ctypes.c_int),  # transposed_bsr_offsets
             ctypes.POINTER(ctypes.c_int),  # transposed_bsr_columns
-            ctypes.
+            ctypes.POINTER(ctypes.c_int),  # src to dest block map
         ]
-        self.core.
-        self.core.
-        self.core.bsr_transpose_float_device.argtypes = bsr_transpose_argtypes
-        self.core.bsr_transpose_double_device.argtypes = bsr_transpose_argtypes
+        self.core.bsr_transpose_host.argtypes = bsr_transpose_argtypes
+        self.core.bsr_transpose_device.argtypes = bsr_transpose_argtypes

         self.core.is_cuda_enabled.argtypes = None
         self.core.is_cuda_enabled.restype = ctypes.c_int
@@ -3601,6 +3632,8 @@ class Runtime:
         self.core.cuda_device_get_name.restype = ctypes.c_char_p
         self.core.cuda_device_get_arch.argtypes = [ctypes.c_int]
         self.core.cuda_device_get_arch.restype = ctypes.c_int
+        self.core.cuda_device_get_sm_count.argtypes = [ctypes.c_int]
+        self.core.cuda_device_get_sm_count.restype = ctypes.c_int
         self.core.cuda_device_is_uva.argtypes = [ctypes.c_int]
         self.core.cuda_device_is_uva.restype = ctypes.c_int
         self.core.cuda_device_is_mempool_supported.argtypes = [ctypes.c_int]
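Note: the per-scalar `bsr_matrix_from_triplets_{float,double}_{host,device}` and `bsr_transpose_*` entry points are replaced by single `_host`/`_device` bindings that take the block and scalar sizes as arguments. From Python this machinery is reached through `warp.sparse`; a minimal sketch, assuming the usual `bsr_from_triplets`/`bsr_transposed` signatures:

    import warp as wp
    import warp.sparse as sparse

    rows = wp.array([0, 1, 2], dtype=int)
    cols = wp.array([0, 1, 2], dtype=int)
    vals = wp.array([1.0, 2.0, 3.0], dtype=wp.float32)

    # 3x3 matrix of 1x1 blocks built from COO-style triplets; duplicate
    # coordinates are merged by the consolidated host/device kernels above.
    A = sparse.bsr_from_triplets(3, 3, rows, cols, vals)
    At = sparse.bsr_transposed(A)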
@@ -3724,11 +3757,72 @@ class Runtime:
             ctypes.POINTER(ctypes.c_void_p),
         ]
         self.core.cuda_graph_end_capture.restype = ctypes.c_bool
+
+        self.core.cuda_graph_create_exec.argtypes = [
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.POINTER(ctypes.c_void_p),
+        ]
+        self.core.cuda_graph_create_exec.restype = ctypes.c_bool
+
+        self.core.capture_debug_dot_print.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_uint32]
+        self.core.capture_debug_dot_print.restype = ctypes.c_bool
+
         self.core.cuda_graph_launch.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
         self.core.cuda_graph_launch.restype = ctypes.c_bool
+        self.core.cuda_graph_exec_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+        self.core.cuda_graph_exec_destroy.restype = ctypes.c_bool
+
         self.core.cuda_graph_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
         self.core.cuda_graph_destroy.restype = ctypes.c_bool

+        self.core.cuda_graph_insert_if_else.argtypes = [
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.POINTER(ctypes.c_int),
+            ctypes.POINTER(ctypes.c_void_p),
+            ctypes.POINTER(ctypes.c_void_p),
+        ]
+        self.core.cuda_graph_insert_if_else.restype = ctypes.c_bool
+
+        self.core.cuda_graph_insert_while.argtypes = [
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.POINTER(ctypes.c_int),
+            ctypes.POINTER(ctypes.c_void_p),
+            ctypes.POINTER(ctypes.c_uint64),
+        ]
+        self.core.cuda_graph_insert_while.restype = ctypes.c_bool
+
+        self.core.cuda_graph_set_condition.argtypes = [
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.POINTER(ctypes.c_int),
+            ctypes.c_uint64,
+        ]
+        self.core.cuda_graph_set_condition.restype = ctypes.c_bool
+
+        self.core.cuda_graph_pause_capture.argtypes = [
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.POINTER(ctypes.c_void_p),
+        ]
+        self.core.cuda_graph_pause_capture.restype = ctypes.c_bool
+
+        self.core.cuda_graph_resume_capture.argtypes = [
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+        ]
+        self.core.cuda_graph_resume_capture.restype = ctypes.c_bool
+
+        self.core.cuda_graph_insert_child_graph.argtypes = [
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+        ]
+        self.core.cuda_graph_insert_child_graph.restype = ctypes.c_bool
+
         self.core.cuda_compile_program.argtypes = [
             ctypes.c_char_p,  # cuda_src
             ctypes.c_char_p,  # program name
@@ -3742,6 +3836,7 @@ class Runtime:
             ctypes.c_bool,  # fast_math
             ctypes.c_bool,  # fuse_fp
             ctypes.c_bool,  # lineinfo
+            ctypes.c_bool,  # compile_time_trace
             ctypes.c_char_p,  # output_path
             ctypes.c_size_t,  # num_ltoirs
             ctypes.POINTER(ctypes.c_char_p),  # ltoirs
@@ -3796,11 +3891,17 @@ class Runtime:
             ctypes.c_int,  # arch
             ctypes.c_int,  # M
             ctypes.c_int,  # N
+            ctypes.c_int,  # NRHS
+            ctypes.c_int,  # function
+            ctypes.c_int,  # side
+            ctypes.c_int,  # diag
             ctypes.c_int,  # precision
+            ctypes.c_int,  # a_arrangement
+            ctypes.c_int,  # b_arrangement
             ctypes.c_int,  # fill_mode
             ctypes.c_int,  # num threads
         ]
-        self.core.
+        self.core.cuda_compile_solver.restype = ctypes.c_bool

         self.core.cuda_load_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
         self.core.cuda_load_module.restype = ctypes.c_void_p
@@ -4270,7 +4371,7 @@ def is_cuda_driver_initialized() -> bool:
     return runtime.core.cuda_driver_is_initialized()


-def get_devices() ->
+def get_devices() -> list[Device]:
     """Returns a list of devices supported in this environment."""

     init()
@@ -4291,7 +4392,7 @@ def get_cuda_device_count() -> int:
     return len(runtime.cuda_devices)


-def get_cuda_device(ordinal:
+def get_cuda_device(ordinal: int | None = None) -> Device:
     """Returns the CUDA device with the given ordinal or the current CUDA device if ordinal is None."""

     init()
@@ -4302,7 +4403,7 @@ def get_cuda_device(ordinal: Union[int, None] = None) -> Device:
     return runtime.cuda_devices[ordinal]


-def get_cuda_devices() ->
+def get_cuda_devices() -> list[Device]:
     """Returns a list of CUDA devices supported in this environment."""

     init()
@@ -4341,7 +4442,7 @@ def set_device(ident: Devicelike) -> None:
     device.make_current()


-def map_cuda_device(alias: str, context:
+def map_cuda_device(alias: str, context: ctypes.c_void_p | None = None) -> Device:
     """Assign a device alias to a CUDA context.

     This function can be used to create a wp.Device for an external CUDA context.
@@ -4436,7 +4537,7 @@ def set_mempool_enabled(device: Devicelike, enable: bool) -> None:
         raise ValueError("Memory pools are only supported on CUDA devices")


-def set_mempool_release_threshold(device: Devicelike, threshold:
+def set_mempool_release_threshold(device: Devicelike, threshold: int | float) -> None:
     """Set the CUDA memory pool release threshold on the device.

     This is the amount of reserved memory to hold onto before trying to release memory back to the OS.
@@ -4744,7 +4845,7 @@ def set_stream(stream: Stream, device: Devicelike = None, sync: bool = False) ->
     get_device(device).set_stream(stream, sync=sync)


-def record_event(event:
+def record_event(event: Event | None = None):
     """Convenience function for calling :meth:`Stream.record_event` on the current stream.

     Args:
@@ -4793,7 +4894,7 @@ def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: bo
     return runtime.core.cuda_event_elapsed_time(start_event.cuda_event, end_event.cuda_event)


-def wait_stream(other_stream: Stream, event:
+def wait_stream(other_stream: Stream, event: Event | None = None):
     """Convenience function for calling :meth:`Stream.wait_stream` on the current stream.

     Args:
@@ -4863,7 +4964,7 @@ class RegisteredGLBuffer:
     __fallback_warning_shown = False

     def __new__(cls, *args, **kwargs):
-        instance = super(
+        instance = super().__new__(cls)
         instance.resource = None
         return instance
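Note: `set_mempool_release_threshold` now annotates the threshold as `int | float`: an integer is taken as a byte count, a float in [0, 1] as a fraction of available memory (per the existing docstring). A sketch:

    import warp as wp

    wp.init()

    for d in wp.get_cuda_devices():
        if d.is_mempool_supported:
            # 0.5 -> half of available memory; an int would mean a byte count.
            wp.set_mempool_release_threshold(d, 0.5)
            print(d.alias, wp.get_mempool_release_threshold(d))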
@@ -4960,8 +5061,8 @@ class RegisteredGLBuffer:


 def zeros(
-    shape:
-    dtype=float,
+    shape: int | tuple[int, ...] | list[int] | None = None,
+    dtype: type = float,
     device: Devicelike = None,
     requires_grad: bool = False,
     pinned: bool = False,
@@ -4988,7 +5089,7 @@ def zeros(


 def zeros_like(
-    src: Array, device: Devicelike = None, requires_grad:
+    src: Array, device: Devicelike = None, requires_grad: bool | None = None, pinned: bool | None = None
 ) -> warp.array:
     """Return a zero-initialized array with the same type and dimension of another array

@@ -5010,8 +5111,8 @@ def zeros_like(


 def ones(
-    shape:
-    dtype=float,
+    shape: int | tuple[int, ...] | list[int] | None = None,
+    dtype: type = float,
     device: Devicelike = None,
     requires_grad: bool = False,
     pinned: bool = False,
@@ -5034,7 +5135,7 @@ def ones(


 def ones_like(
-    src: Array, device: Devicelike = None, requires_grad:
+    src: Array, device: Devicelike = None, requires_grad: bool | None = None, pinned: bool | None = None
 ) -> warp.array:
     """Return a one-initialized array with the same type and dimension of another array

@@ -5052,7 +5153,7 @@ def ones_like(


 def full(
-    shape:
+    shape: int | tuple[int, ...] | list[int] | None = None,
     value=0,
     dtype=Any,
     device: Devicelike = None,
@@ -5121,8 +5222,8 @@ def full_like(
     src: Array,
     value: Any,
     device: Devicelike = None,
-    requires_grad:
-    pinned:
+    requires_grad: bool | None = None,
+    pinned: bool | None = None,
 ) -> warp.array:
     """Return an array with all elements initialized to the given value with the same type and dimension of another array

@@ -5145,7 +5246,7 @@ def full_like(


 def clone(
-    src: warp.array, device: Devicelike = None, requires_grad:
+    src: warp.array, device: Devicelike = None, requires_grad: bool | None = None, pinned: bool | None = None
 ) -> warp.array:
     """Clone an existing array, allocates a copy of the src memory

@@ -5167,7 +5268,7 @@ def clone(


 def empty(
-    shape:
+    shape: int | tuple[int, ...] | list[int] | None = None,
     dtype=float,
     device: Devicelike = None,
     requires_grad: bool = False,
@@ -5200,7 +5301,7 @@ def empty(


 def empty_like(
-    src: Array, device: Devicelike = None, requires_grad:
+    src: Array, device: Devicelike = None, requires_grad: bool | None = None, pinned: bool | None = None
 ) -> warp.array:
     """Return an uninitialized array with the same type and dimension of another array

@@ -5235,9 +5336,9 @@ def empty_like(

 def from_numpy(
     arr: np.ndarray,
-    dtype:
-    shape:
-    device:
+    dtype: type | None = None,
+    shape: Sequence[int] | None = None,
+    device: Devicelike | None = None,
     requires_grad: bool = False,
 ) -> warp.array:
     """Returns a Warp array created from a NumPy array.
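Note: the array constructors above only tighten their annotations (`shape: int | tuple[int, ...] | list[int] | None`, `requires_grad`/`pinned` as `bool | None`); behaviour is unchanged. A minimal sketch:

    import warp as wp

    a = wp.zeros(16, dtype=wp.float32, device="cpu")         # int shape
    b = wp.ones((4, 4), dtype=wp.vec3, requires_grad=True)   # tuple shape
    c = wp.full(shape=[2, 3], value=7, dtype=int)             # list shape
    d = wp.empty_like(b, requires_grad=False)
    e = wp.clone(a, pinned=False)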
@@ -5255,7 +5356,7 @@ def from_numpy(
|
|
|
5255
5356
|
if dtype is None:
|
|
5256
5357
|
base_type = warp.types.np_dtype_to_warp_type.get(arr.dtype)
|
|
5257
5358
|
if base_type is None:
|
|
5258
|
-
raise RuntimeError("Unsupported NumPy data type '{}'."
|
|
5359
|
+
raise RuntimeError(f"Unsupported NumPy data type '{arr.dtype}'.")
|
|
5259
5360
|
|
|
5260
5361
|
dim_count = len(arr.shape)
|
|
5261
5362
|
if dim_count == 2:
|
|
@@ -5274,7 +5375,7 @@ def from_numpy(
|
|
|
5274
5375
|
)
|
|
5275
5376
|
|
|
5276
5377
|
|
|
5277
|
-
def event_from_ipc_handle(handle, device:
|
|
5378
|
+
def event_from_ipc_handle(handle, device: Devicelike = None) -> Event:
|
|
5278
5379
|
"""Create an event from an IPC handle.
|
|
5279
5380
|
|
|
5280
5381
|
Args:
|
|
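Note: `from_numpy` keeps its behaviour; the unsupported-dtype error now reports the offending NumPy dtype via an f-string. A sketch of the inference path shown in the context lines above:

    import numpy as np
    import warp as wp

    arr = np.arange(12, dtype=np.float32).reshape(4, 3)

    # With dtype=None, a 2D NumPy array is inferred as an array of vectors
    # (wp.vec3 here); pass dtype/shape explicitly to keep a flat scalar layout.
    a = wp.from_numpy(arr, device="cpu")
    b = wp.from_numpy(arr, dtype=wp.float32, shape=(4, 3), device="cpu")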
@@ -5443,10 +5544,10 @@ class Launch:
         self,
         kernel,
         device: Device,
-        hooks:
-        params:
-        params_addr:
-        bounds:
+        hooks: KernelHooks | None = None,
+        params: Sequence[Any] | None = None,
+        params_addr: Sequence[ctypes.c_void_p] | None = None,
+        bounds: launch_bounds_t | None = None,
         max_blocks: int = 0,
         block_dim: int = 256,
         adjoint: bool = False,
@@ -5516,7 +5617,7 @@ class Launch:
         self.adjoint: bool = adjoint
         """Whether to run the adjoint kernel instead of the forward kernel."""

-    def set_dim(self, dim:
+    def set_dim(self, dim: int | list[int] | tuple[int, ...]):
         """Set the launch dimensions.

         Args:
@@ -5554,7 +5655,7 @@ class Launch:
         if self.params_addr:
             self.params_addr[params_index] = ctypes.c_void_p(ctypes.addressof(carg))

-    def set_param_at_index_from_ctype(self, index: int, value:
+    def set_param_at_index_from_ctype(self, index: int, value: ctypes.Structure | int | float):
         """Set a kernel parameter at an index without any type conversion.

         Args:
@@ -5617,7 +5718,7 @@ class Launch:
         for i, v in enumerate(values):
             self.set_param_at_index_from_ctype(i, v)

-    def launch(self, stream:
+    def launch(self, stream: Stream | None = None) -> None:
         """Launch the kernel.

         Args:
@@ -5634,7 +5735,7 @@ class Launch:

         # If the stream is capturing, we retain the CUDA module so that it doesn't get unloaded
         # before the captured graph is released.
-        if runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
+        if len(runtime.captures) > 0 and runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
             capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
             graph = runtime.captures.get(capture_id)
             if graph is not None:
@@ -5666,13 +5767,13 @@ class Launch:

 def launch(
     kernel,
-    dim:
+    dim: int | Sequence[int],
     inputs: Sequence = [],
     outputs: Sequence = [],
     adj_inputs: Sequence = [],
     adj_outputs: Sequence = [],
     device: Devicelike = None,
-    stream:
+    stream: Stream | None = None,
     adjoint: bool = False,
     record_tape: bool = True,
     record_cmd: bool = False,
@@ -5824,7 +5925,7 @@ def launch(

     # If the stream is capturing, we retain the CUDA module so that it doesn't get unloaded
     # before the captured graph is released.
-    if runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
+    if len(runtime.captures) > 0 and runtime.core.cuda_stream_is_capturing(stream.cuda_stream):
         capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
         graph = runtime.captures.get(capture_id)
         if graph is not None:
@@ -5968,7 +6069,7 @@ def launch_tiled(*args, **kwargs):
        raise RuntimeError("wp.launch_tiled() requires a grid with fewer than 4 dimensions")

    # add trailing dimension
-    kwargs["dim"] = dim
+    kwargs["dim"] = [*dim, kwargs["block_dim"]]

    # forward to original launch method
    return launch(*args, **kwargs)
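Note: `launch_tiled` now builds the padded dim list by unpacking, appending `block_dim` as the trailing launch dimension so each tile gets one block. A minimal sketch, assuming a CUDA device and that `wp.tid()` excludes the implicit block dimension as in the existing tile examples:

    import warp as wp

    @wp.kernel
    def per_tile(out: wp.array(dtype=int)):
        # With launch_tiled(), wp.tid() indexes the tile; the trailing block
        # dimension added below is not included.
        i = wp.tid()
        out[i] = i

    out = wp.zeros(8, dtype=int, device="cuda:0")

    # Internally forwarded as wp.launch(per_tile, dim=[8, 64], ...): the 64 is
    # the block_dim appended as the trailing launch dimension.
    wp.launch_tiled(per_tile, dim=[8], inputs=[out], block_dim=64)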
@@ -6016,7 +6117,7 @@ def synchronize_device(device: Devicelike = None):
         runtime.core.cuda_context_synchronize(device.context)


-def synchronize_stream(stream_or_device:
+def synchronize_stream(stream_or_device: Stream | Devicelike | None = None):
     """Synchronize the calling CPU thread with any outstanding CUDA work on the specified stream.

     This function allows the host application code to ensure that all kernel launches
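A sketch of the widened `synchronize_stream()` signature above, which now accepts a `Stream`, a device alias, or `None`; assumes a CUDA-capable device:

```python
import warp as wp

s = wp.Stream("cuda:0")
a = wp.zeros(1024, dtype=float, device="cuda:0")
b = wp.empty_like(a)

# queue asynchronous work on the stream, then block the host until it completes
wp.copy(b, a, stream=s)
wp.synchronize_stream(s)         # pass a Stream object...
wp.synchronize_stream("cuda:0")  # ...or a device alias
```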
@@ -6046,7 +6147,7 @@ def synchronize_event(event: Event):
     runtime.core.cuda_event_synchronize(event.cuda_event)


-def force_load(device:
+def force_load(device: Device | str | list[Device] | list[str] | None = None, modules: list[Module] | None = None):
     """Force user-defined kernels to be compiled and loaded

     Args:
@@ -6078,7 +6179,7 @@ def force_load(device: Union[Device, str, List[Device], List[str]] = None, modul


 def load_module(
-    module:
+    module: Module | types.ModuleType | str | None = None, device: Device | str | None = None, recursive: bool = False
 ):
     """Force user-defined module to be compiled and loaded

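A sketch of the widened `force_load()`/`load_module()` signatures above; the device strings and the list form are illustrative:

```python
import warp as wp

# compile and load all registered kernels ahead of time, for one or several devices
wp.force_load(device="cuda:0")
wp.force_load(device=["cpu", "cuda:0"])

# or load only the calling module (a module object, a module name, or None for the current one)
wp.load_module(device="cuda:0")
```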
@@ -6120,7 +6221,7 @@ def load_module(
     force_load(device=device, modules=modules)


-def set_module_options(options:
+def set_module_options(options: dict[str, Any], module: Any = None):
     """Set options for the current module.

     Options can be used to control runtime compilation and code-generation
@@ -6144,7 +6245,7 @@ def set_module_options(options: Dict[str, Any], module: Optional[Any] = None):
         get_module(m.__name__).mark_modified()


-def get_module_options(module:
+def get_module_options(module: Any = None) -> dict[str, Any]:
     """Returns a list of options for the current module."""
     if module is None:
         m = inspect.getmodule(inspect.stack()[1][0])
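A sketch of the `set_module_options()`/`get_module_options()` signatures above; `enable_backward` is an existing module option used here purely as an example:

```python
import warp as wp

# disable adjoint code generation for kernels defined in the calling module
wp.set_module_options({"enable_backward": False})

opts = wp.get_module_options()
print(opts["enable_backward"])
```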
@@ -6156,8 +6257,8 @@ def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:

 def capture_begin(
     device: Devicelike = None,
-    stream:
-    force_module_load:
+    stream: Stream | None = None,
+    force_module_load: bool | None = None,
     external: bool = False,
 ):
     """Begin capture of a CUDA graph
@@ -6226,7 +6327,7 @@ def capture_begin(
     runtime.captures[capture_id] = graph


-def capture_end(device: Devicelike = None, stream:
+def capture_end(device: Devicelike = None, stream: Stream | None = None) -> Graph:
     """End the capture of a CUDA graph.

     Args:
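A minimal sketch of the graph-capture entry points whose signatures change above (`capture_begin`/`capture_end`), paired with `capture_launch`; the `saxpy` kernel and sizes are illustrative:

```python
import warp as wp


@wp.kernel
def saxpy(x: wp.array(dtype=float), y: wp.array(dtype=float), a: float):
    i = wp.tid()
    y[i] = a * x[i] + y[i]


device = "cuda:0"
x = wp.ones(1 << 20, dtype=float, device=device)
y = wp.zeros_like(x)

wp.capture_begin(device)
try:
    wp.launch(saxpy, dim=x.shape, inputs=[x, y, 2.0], device=device)
finally:
    graph = wp.capture_end(device)

# replay the captured work without re-recording it
for _ in range(10):
    wp.capture_launch(graph)
```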
@@ -6255,20 +6356,324 @@ def capture_end(device: Devicelike = None, stream: Optional[Stream] = None) -> G
     del runtime.captures[graph.capture_id]

     # get the graph executable
-
-    result = runtime.core.cuda_graph_end_capture(device.context, stream.cuda_stream, ctypes.byref(
+    g = ctypes.c_void_p()
+    result = runtime.core.cuda_graph_end_capture(device.context, stream.cuda_stream, ctypes.byref(g))

     if not result:
         # A concrete error should've already been reported, so we don't need to go into details here
         raise RuntimeError(f"CUDA graph capture failed. {runtime.get_error_string()}")

     # set the graph executable
-    graph.
+    graph.graph = g
+    graph.graph_exec = None # Lazy initialization

     return graph


-def
+def capture_debug_dot_print(graph: Graph, path: str, verbose: bool = False):
+    """Export a CUDA graph to a DOT file for visualization
+
+    Args:
+        graph: A :class:`Graph` as returned by :func:`~warp.capture_end()`
+        path: Path to save the DOT file
+        verbose: Whether to include additional debug information in the output
+    """
+    if not runtime.core.capture_debug_dot_print(graph.graph, path.encode(), 0 if verbose else 1):
+        raise RuntimeError(f"Graph debug dot print error: {runtime.get_error_string()}")
+
+
+def assert_conditional_graph_support():
+    if runtime is None:
+        init()
+
+    if runtime.toolkit_version < (12, 4):
+        raise RuntimeError("Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes")
+
+    if runtime.driver_version < (12, 4):
+        raise RuntimeError("Conditional graph nodes require CUDA driver 12.4+")
+
+
+def capture_pause(device: Devicelike = None, stream: Stream | None = None) -> ctypes.c_void_p:
+    if stream is not None:
+        device = stream.device
+    else:
+        device = runtime.get_device(device)
+        if not device.is_cuda:
+            raise RuntimeError("Must be a CUDA device")
+        stream = device.stream
+
+    graph = ctypes.c_void_p()
+    if not runtime.core.cuda_graph_pause_capture(device.context, stream.cuda_stream, ctypes.byref(graph)):
+        raise RuntimeError(runtime.get_error_string())
+
+    return graph
+
+
+def capture_resume(graph: ctypes.c_void_p, device: Devicelike = None, stream: Stream | None = None):
+    if stream is not None:
+        device = stream.device
+    else:
+        device = runtime.get_device(device)
+        if not device.is_cuda:
+            raise RuntimeError("Must be a CUDA device")
+        stream = device.stream
+
+    if not runtime.core.cuda_graph_resume_capture(device.context, stream.cuda_stream, graph):
+        raise RuntimeError(runtime.get_error_string())
+
+
+# reusable pinned readback buffer for conditions
+condition_host = None
+
+
+def capture_if(
+    condition: warp.array(dtype=int),
+    on_true: Callable | Graph | None = None,
+    on_false: Callable | Graph | None = None,
+    stream: Stream = None,
+    **kwargs,
+):
+    """Create a dynamic branch based on a condition.
+
+    The condition value is retrieved from the first element of the ``condition`` array.
+
+    This function is particularly useful with CUDA graphs, but can be used without graph capture as well.
+    CUDA 12.4+ is required to take advantage of conditional graph nodes for dynamic control flow.
+
+    Args:
+        condition: Warp array holding the condition value.
+        on_true: A callback function or :class:`Graph` to execute if the condition is True.
+        on_false: A callback function or :class:`Graph` to execute if the condition is False.
+        stream: The CUDA stream where the condition was written. If None, use the current stream on the device where ``condition`` resides.
+
+    Any additional keyword arguments are forwarded to the callback functions.
+    """
+
+    # if neither the IF branch nor the ELSE branch is specified, it's a no-op
+    if on_true is None and on_false is None:
+        return
+
+    # check condition data type
+    if not isinstance(condition, warp.array) or condition.dtype is not warp.int32:
+        raise TypeError("Condition must be a Warp array of int32 with a single element")
+
+    device = condition.device
+
+    # determine the stream and whether a graph capture is active
+    if device.is_cuda:
+        if stream is None:
+            stream = device.stream
+        graph = device.captures.get(stream)
+    else:
+        graph = None
+
+    if graph is None:
+        # if no graph is active, just execute the correct branch directly
+        if device.is_cuda:
+            # use a pinned buffer for condition readback to host
+            global condition_host
+            if condition_host is None:
+                condition_host = warp.empty(1, dtype=int, device="cpu", pinned=True)
+            warp.copy(condition_host, condition, stream=stream)
+            warp.synchronize_stream(stream)
+            condition_value = bool(ctypes.cast(condition_host.ptr, ctypes.POINTER(ctypes.c_int32)).contents)
+        else:
+            condition_value = bool(ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)).contents)
+
+        if condition_value:
+            if on_true is not None:
+                if isinstance(on_true, Callable):
+                    on_true(**kwargs)
+                elif isinstance(on_true, Graph):
+                    capture_launch(on_true, stream=stream)
+                else:
+                    raise TypeError("on_true must be a Callable or a Graph")
+        else:
+            if on_false is not None:
+                if isinstance(on_false, Callable):
+                    on_false(**kwargs)
+                elif isinstance(on_false, Graph):
+                    capture_launch(on_false, stream=stream)
+                else:
+                    raise TypeError("on_false must be a Callable or a Graph")
+
+        return
+
+    graph.has_conditional = True
+
+    # ensure conditional graph nodes are supported
+    assert_conditional_graph_support()
+
+    # insert conditional node
+    graph_on_true = ctypes.c_void_p()
+    graph_on_false = ctypes.c_void_p()
+    if not runtime.core.cuda_graph_insert_if_else(
+        device.context,
+        stream.cuda_stream,
+        ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
+        None if on_true is None else ctypes.byref(graph_on_true),
+        None if on_false is None else ctypes.byref(graph_on_false),
+    ):
+        raise RuntimeError(runtime.get_error_string())
+
+    # pause capturing parent graph
+    main_graph = capture_pause(stream=stream)
+
+    # capture if-graph
+    if on_true is not None:
+        capture_resume(graph_on_true, stream=stream)
+        if isinstance(on_true, Callable):
+            on_true(**kwargs)
+        elif isinstance(on_true, Graph):
+            if on_true.has_conditional:
+                raise RuntimeError(
+                    "The on_true graph contains conditional nodes, which are not allowed in child graphs"
+                )
+            if not runtime.core.cuda_graph_insert_child_graph(
+                device.context,
+                stream.cuda_stream,
+                on_true.graph,
+            ):
+                raise RuntimeError(runtime.get_error_string())
+        else:
+            raise TypeError("on_true must be a Callable or a Graph")
+        capture_pause(stream=stream)
+
+    # capture else-graph
+    if on_false is not None:
+        capture_resume(graph_on_false, stream=stream)
+        if isinstance(on_false, Callable):
+            on_false(**kwargs)
+        elif isinstance(on_false, Graph):
+            if on_false.has_conditional:
+                raise RuntimeError(
+                    "The on_false graph contains conditional nodes, which are not allowed in child graphs"
+                )
+            if not runtime.core.cuda_graph_insert_child_graph(
+                device.context,
+                stream.cuda_stream,
+                on_false.graph,
+            ):
+                raise RuntimeError(runtime.get_error_string())
+        else:
+            raise TypeError("on_false must be a Callable or a Graph")
+        capture_pause(stream=stream)
+
+    # resume capturing parent graph
+    capture_resume(main_graph, stream=stream)
+
+
+def capture_while(condition: warp.array(dtype=int), while_body: Callable | Graph, stream: Stream = None, **kwargs):
+    """Create a dynamic loop based on a condition.
+
+    The condition value is retrieved from the first element of the ``condition`` array.
+
+    The ``while_body`` callback is responsible for updating the condition value so the loop can terminate.
+
+    This function is particularly useful with CUDA graphs, but can be used without graph capture as well.
+    CUDA 12.4+ is required to take advantage of conditional graph nodes for dynamic control flow.
+
+    Args:
+        condition: Warp array holding the condition value.
+        while_body: A callback function or :class:`Graph` to execute while the loop condition is True.
+        stream: The CUDA stream where the condition was written. If None, use the current stream on the device where ``condition`` resides.
+
+    Any additional keyword arguments are forwarded to the callback function.
+    """
+
+    # check condition data type
+    if not isinstance(condition, warp.array) or condition.dtype is not warp.int32:
+        raise TypeError("Condition must be a Warp array of int32 with a single element")
+
+    device = condition.device
+
+    # determine the stream and whether a graph capture is active
+    if device.is_cuda:
+        if stream is None:
+            stream = device.stream
+        graph = device.captures.get(stream)
+    else:
+        graph = None
+
+    if graph is None:
+        # since no graph is active, just execute the kernels directly
+        while True:
+            if device.is_cuda:
+                # use a pinned buffer for condition readback to host
+                global condition_host
+                if condition_host is None:
+                    condition_host = warp.empty(1, dtype=int, device="cpu", pinned=True)
+                warp.copy(condition_host, condition, stream=stream)
+                warp.synchronize_stream(stream)
+                condition_value = bool(ctypes.cast(condition_host.ptr, ctypes.POINTER(ctypes.c_int32)).contents)
+            else:
+                condition_value = bool(ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)).contents)
+
+            if condition_value:
+                if isinstance(while_body, Callable):
+                    while_body(**kwargs)
+                elif isinstance(while_body, Graph):
+                    capture_launch(while_body, stream=stream)
+                else:
+                    raise TypeError("while_body must be a callable or a graph")
+
+            else:
+                break
+
+        return
+
+    graph.has_conditional = True
+
+    # ensure conditional graph nodes are supported
+    assert_conditional_graph_support()
+
+    # insert conditional while-node
+    body_graph = ctypes.c_void_p()
+    cond_handle = ctypes.c_uint64()
+    if not runtime.core.cuda_graph_insert_while(
+        device.context,
+        stream.cuda_stream,
+        ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
+        ctypes.byref(body_graph),
+        ctypes.byref(cond_handle),
+    ):
+        raise RuntimeError(runtime.get_error_string())
+
+    # pause capturing parent graph and start capturing child graph
+    main_graph = capture_pause(stream=stream)
+    capture_resume(body_graph, stream=stream)
+
+    # capture while-body
+    if isinstance(while_body, Callable):
+        while_body(**kwargs)
+    elif isinstance(while_body, Graph):
+        if while_body.has_conditional:
+            raise RuntimeError("The body graph contains conditional nodes, which are not allowed in child graphs")
+
+        if not runtime.core.cuda_graph_insert_child_graph(
+            device.context,
+            stream.cuda_stream,
+            while_body.graph,
+        ):
+            raise RuntimeError(runtime.get_error_string())
+    else:
+        raise RuntimeError(runtime.get_error_string())
+
+    # update condition
+    if not runtime.core.cuda_graph_set_condition(
+        device.context,
+        stream.cuda_stream,
+        ctypes.cast(condition.ptr, ctypes.POINTER(ctypes.c_int32)),
+        cond_handle,
+    ):
+        raise RuntimeError(runtime.get_error_string())
+
+    # stop capturing child graph and resume capturing parent graph
+    capture_pause(stream=stream)
+    capture_resume(main_graph, stream=stream)
+
+
+def capture_launch(graph: Graph, stream: Stream | None = None):
     """Launch a previously captured CUDA graph

     Args:
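A sketch of the new conditional-capture helpers added above (`capture_if`/`capture_while`); the `halve` kernel, the threshold, and the loop structure are illustrative. Per the added checks, conditional nodes inside a capture require CUDA 12.4+:

```python
import warp as wp


@wp.kernel
def halve(x: wp.array(dtype=float), cond: wp.array(dtype=int)):
    i = wp.tid()
    x[i] = x[i] * 0.5
    if i == 0:
        # the loop body is responsible for updating the condition
        if x[0] > 1e-3:
            cond[0] = 1
        else:
            cond[0] = 0


device = "cuda:0"
x = wp.ones(1024, dtype=float, device=device)
cond = wp.ones(1, dtype=int, device=device)  # int32 array with a single element


def body():
    wp.launch(halve, dim=x.shape, inputs=[x, cond], device=device)


wp.capture_begin(device)
try:
    # records a conditional while-node whose body re-runs until cond[0] == 0
    wp.capture_while(cond, while_body=body)
finally:
    graph = wp.capture_end(device)

wp.capture_launch(graph)
```

`wp.capture_if(cond, on_true=..., on_false=...)` follows the same pattern for a one-shot branch instead of a loop.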
@@ -6284,6 +6689,13 @@ def capture_launch(graph: Graph, stream: Optional[Stream] = None):
         device = graph.device
         stream = device.stream

+    if graph.graph_exec is None:
+        g = ctypes.c_void_p()
+        result = runtime.core.cuda_graph_create_exec(graph.device.context, graph.graph, ctypes.byref(g))
+        if not result:
+            raise RuntimeError(f"Graph creation error: {runtime.get_error_string()}")
+        graph.graph_exec = g
+
     if not runtime.core.cuda_graph_launch(graph.graph_exec, stream.cuda_stream):
         raise RuntimeError(f"Graph launch error: {runtime.get_error_string()}")

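A sketch tying together two additions above: `capture_launch()` now instantiates the graph executable lazily on first use, and `capture_debug_dot_print()` exports the captured graph for inspection; the captured copy and the file name are illustrative:

```python
import warp as wp

device = "cuda:0"
a = wp.zeros(16, dtype=float, device=device)
b = wp.ones_like(a)

wp.capture_begin(device)
try:
    wp.copy(a, b)
finally:
    graph = wp.capture_end(device)

# graph.graph_exec starts out as None and is created on the first launch, then reused
wp.capture_launch(graph)
wp.capture_launch(graph)

# write a Graphviz DOT description of the captured graph
wp.capture_debug_dot_print(graph, "capture.dot", verbose=True)
```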
@@ -6294,7 +6706,7 @@ def copy(
     dest_offset: int = 0,
     src_offset: int = 0,
     count: int = 0,
-    stream:
+    stream: Stream | None = None,
 ):
     """Copy array contents from `src` to `dest`.

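A sketch of the `wp.copy()` signature shown above, with offsets, an element count, and an optional stream; the values are illustrative:

```python
import warp as wp

src = wp.array([1.0, 2.0, 3.0, 4.0], dtype=float, device="cuda:0")
dst = wp.zeros(8, dtype=float, device="cuda:0")
s = wp.Stream("cuda:0")

# copy two elements, src[1:3] -> dst[4:6], asynchronously on stream s
wp.copy(dst, src, dest_offset=4, src_offset=1, count=2, stream=s)
wp.synchronize_stream(s)
```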
@@ -6431,11 +6843,8 @@ def copy(

     # can't copy to/from fabric arrays of arrays, because they are jagged arrays of arbitrary lengths
     # TODO?
-    if (
-        isinstance(
-        and src.ndim > 1
-        or isinstance(dest, (warp.fabricarray, warp.indexedfabricarray))
-        and dest.ndim > 1
+    if (isinstance(src, (warp.fabricarray, warp.indexedfabricarray)) and src.ndim > 1) or (
+        isinstance(dest, (warp.fabricarray, warp.indexedfabricarray)) and dest.ndim > 1
     ):
         raise RuntimeError("Copying to/from Fabric arrays of arrays is not supported")

@@ -6503,7 +6912,7 @@ def type_str(t):
         return "Callable"
     elif isinstance(t, int):
         return str(t)
-    elif isinstance(t, List):
+    elif isinstance(t, (List, tuple)):
         return "Tuple[" + ", ".join(map(type_str, t)) + "]"
     elif isinstance(t, warp.array):
         return f"Array[{type_str(t.dtype)}]"
@@ -6536,12 +6945,16 @@ def type_str(t):

         raise TypeError("Invalid vector or matrix dimensions")
     elif get_origin(t) in (list, tuple):
-
-
+        args = get_args(t)
+        if args:
+            args_repr = ", ".join(type_str(x) for x in get_args(t))
+            return f"{t._name}[{args_repr}]"
+        else:
+            return f"{t._name}"
     elif t is Ellipsis:
         return "..."
     elif warp.types.is_tile(t):
-        return "Tile"
+        return f"Tile[{type_str(t.dtype)},{type_str(t.shape)}]"

     return t.__name__

@@ -6568,14 +6981,14 @@ def resolve_exported_function_sig(f):
     # so we can generate the return type for overloaded functions
     return_type = f.value_func(func_args, None)

+    if return_type is None or (isinstance(return_type, tuple) and len(return_type) > 1):
+        return (func_args, return_type)
+
     try:
-
+        ctype_ret_str(return_type)
     except Exception:
         return None

-    if return_type_str.startswith("Tuple"):
-        return None
-
     return (func_args, return_type)

@@ -6716,13 +7129,18 @@ def export_functions_rst(file): # pragma: no cover
         print("---------------", file=file)

         for f, is_exported in g:
+            if not isinstance(f, Function) and callable(f):
+                # f is a plain Python function
+                print(f".. autofunction:: {f.__module__}.{f.__name__}", file=file)
+                continue
             if f.func:
                 # f is a Warp function written in Python, we can use autofunction
                 print(f".. autofunction:: {f.func.__module__}.{f.key}", file=file)
                 continue
             for f_prefix, query_type in query_types:
                 if f.key.startswith(f_prefix) and query_type not in written_query_types:
-                    print(f".. autoclass:: {query_type}", file=file)
+                    print(f".. autoclass:: warp.{query_type}", file=file)
+                    print(" :exclude-members: Var, vars", file=file)
                     written_query_types.add(query_type)
                     break

@@ -6775,6 +7193,7 @@ def export_stubs(file): # pragma: no cover
     print('Rows = TypeVar("Rows", bound=int)', file=file)
     print('Cols = TypeVar("Cols", bound=int)', file=file)
    print('DType = TypeVar("DType")', file=file)
+    print('Shape = TypeVar("Shape")', file=file)

     print("Vector = Generic[Length, Scalar]", file=file)
     print("Matrix = Generic[Rows, Cols, Scalar]", file=file)
@@ -6783,6 +7202,7 @@ def export_stubs(file): # pragma: no cover
     print("Array = Generic[DType]", file=file)
     print("FabricArray = Generic[DType]", file=file)
     print("IndexedFabricArray = Generic[DType]", file=file)
+    print("Tile = Generic[DType, Shape]", file=file)

     # prepend __init__.py
     with open(os.path.join(os.path.dirname(file.name), "__init__.py")) as header_file:
@@ -6817,7 +7237,7 @@ def export_stubs(file): # pragma: no cover
         if hasattr(g, "overloads"):
             for f in g.overloads:
                 add_stub(f)
-
+        elif isinstance(g, Function):
             add_stub(g)


@@ -6848,16 +7268,30 @@ def export_builtins(file: io.TextIOBase): # pragma: no cover
         args = ", ".join(f"{ctype_arg_str(v)} {k}" for k, v in func_args.items())
         params = ", ".join(func_args.keys())

-
-
-        if args == "":
-            file.write(f"WP_API void {f.mangled_name}({return_str}* ret) {{ *ret = wp::{f.key}({params}); }}\n")
-        elif return_type is None:
+        if return_type is None:
+            # void function
             file.write(f"WP_API void {f.mangled_name}({args}) {{ wp::{f.key}({params}); }}\n")
+        elif isinstance(return_type, tuple) and len(return_type) > 1:
+            # multiple return value function using output parameters
+            outputs = tuple(f"{ctype_ret_str(x)}& ret_{i}" for i, x in enumerate(return_type))
+            output_params = ", ".join(f"ret_{i}" for i in range(len(outputs)))
+            if args:
+                file.write(
+                    f"WP_API void {f.mangled_name}({args}, {', '.join(outputs)}) {{ wp::{f.key}({params}, {output_params}); }}\n"
+                )
+            else:
+                file.write(
+                    f"WP_API void {f.mangled_name}({', '.join(outputs)}) {{ wp::{f.key}({params}, {output_params}); }}\n"
+                )
         else:
-
-
-
+            # single return value function
+            return_str = ctype_ret_str(return_type)
+            if args:
+                file.write(
+                    f"WP_API void {f.mangled_name}({args}, {return_str}* ret) {{ *ret = wp::{f.key}({params}); }}\n"
+                )
+            else:
+                file.write(f"WP_API void {f.mangled_name}({return_str}* ret) {{ *ret = wp::{f.key}({params}); }}\n")

     file.write('\n} // extern "C"\n\n')
     file.write("} // namespace wp\n")