warp-lang 1.8.0__py3-none-manylinux_2_34_aarch64.whl → 1.8.1__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build_dll.py +5 -0
- warp/codegen.py +15 -3
- warp/config.py +1 -1
- warp/context.py +122 -24
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/fem/field/virtual.py +2 -0
- warp/fem/integrate.py +78 -47
- warp/jax_experimental/ffi.py +201 -53
- warp/native/array.h +4 -4
- warp/native/builtin.h +8 -4
- warp/native/coloring.cpp +5 -1
- warp/native/cuda_util.cpp +1 -1
- warp/native/intersect.h +2 -2
- warp/native/mat.h +3 -3
- warp/native/mesh.h +1 -1
- warp/native/quat.h +6 -2
- warp/native/rand.h +7 -7
- warp/native/sparse.cu +1 -1
- warp/native/svd.h +23 -8
- warp/native/tile.h +20 -1
- warp/native/tile_radix_sort.h +5 -1
- warp/native/tile_reduce.h +16 -25
- warp/native/tuple.h +2 -2
- warp/native/vec.h +4 -4
- warp/native/warp.cpp +1 -1
- warp/native/warp.cu +15 -2
- warp/native/warp.h +1 -1
- warp/render/render_opengl.py +52 -51
- warp/render/render_usd.py +0 -1
- warp/sim/collide.py +1 -2
- warp/sim/integrator_vbd.py +10 -2
- warp/sparse.py +1 -1
- warp/tape.py +2 -0
- warp/tests/sim/test_cloth.py +89 -6
- warp/tests/sim/test_coloring.py +76 -1
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +127 -114
- warp/tests/test_mat.py +22 -0
- warp/tests/test_quat.py +22 -0
- warp/tests/test_sparse.py +32 -0
- warp/tests/test_static.py +48 -0
- warp/tests/test_tape.py +38 -0
- warp/tests/test_vec.py +38 -408
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +31 -143
- warp/tests/tile/test_tile_mathdx.py +2 -2
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_reduce.py +100 -11
- warp/tests/tile/test_tile_shared_memory.py +12 -12
- warp/tests/tile/test_tile_sort.py +59 -55
- warp/tests/unittest_suites.py +10 -0
- {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/METADATA +4 -4
- {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/RECORD +59 -57
- {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/bin/warp-clang.so
CHANGED
Binary file
warp/bin/warp.so
CHANGED
Binary file
warp/build_dll.py
CHANGED
@@ -227,6 +227,7 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, arch, libs: Optional[
             "-gencode=arch=compute_61,code=sm_61",
             "-gencode=arch=compute_70,code=sm_70",  # Volta
             "-gencode=arch=compute_75,code=sm_75",  # Turing
+            "-gencode=arch=compute_75,code=compute_75",  # Turing (PTX)
             "-gencode=arch=compute_80,code=sm_80",  # Ampere
             "-gencode=arch=compute_86,code=sm_86",
         ]
@@ -260,6 +261,10 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, arch, libs: Optional[
             "--cuda-gpu-arch=sm_87",  # Orin
         ]

+        if ctk_version >= (12, 8):
+            gencode_opts += ["-gencode=arch=compute_101,code=sm_101"]  # Thor (CUDA 12 numbering)
+            clang_arch_flags += ["--cuda-gpu-arch=sm_101"]
+
         if ctk_version >= (12, 8):
             # Support for Blackwell is available with CUDA Toolkit 12.8+
             gencode_opts += [
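The Thor-related gencode additions above only take effect when the wheel is compiled against CUDA Toolkit 12.8 or newer. As a quick sanity check, here is a minimal sketch (assuming a CUDA-capable machine and the public wp.get_cuda_devices() API) that prints which SM architecture Warp detects for each device, i.e. the same arch numbers the -gencode flags target:

    import warp as wp

    wp.init()

    # Print the SM architecture Warp detects for each CUDA device, e.g. 87 (Orin),
    # 101 (Thor under the CUDA 12.x numbering used above), 120 (Blackwell).
    for d in wp.get_cuda_devices():
        print(d.name, d.arch)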
warp/codegen.py
CHANGED
@@ -616,6 +616,8 @@ def compute_type_str(base_name, template_params):
     def param2str(p):
         if isinstance(p, int):
             return str(p)
+        elif hasattr(p, "_wp_generic_type_str_"):
+            return compute_type_str(f"wp::{p._wp_generic_type_str_}", p._wp_type_params_)
         elif hasattr(p, "_type_"):
             if p.__name__ == "bool":
                 return "bool"
@@ -967,6 +969,11 @@ class Adjoint:
             # this is to avoid registering false references to overshadowed modules
             adj.symbols[name] = arg

+        # Indicates whether there are unresolved static expressions in the function.
+        # These stem from wp.static() expressions that could not be evaluated at declaration time.
+        # This will signal to the module builder that this module needs to be rebuilt even if the module hash is unchanged.
+        adj.has_unresolved_static_expressions = False
+
         # try to replace static expressions by their constant result if the
         # expression can be evaluated at declaration time
         adj.static_expressions: dict[str, Any] = {}
@@ -2322,8 +2329,9 @@

         if adj.is_static_expression(func):
             # try to evaluate wp.static() expressions
-            obj,
+            obj, code = adj.evaluate_static_expression(node)
             if obj is not None:
+                adj.static_expressions[code] = obj
                 if isinstance(obj, warp.context.Function):
                     # special handling for wp.static() evaluating to a function
                     return obj
@@ -3109,6 +3117,7 @@

         # Since this is an expression, we can enforce it to be defined on a single line.
         static_code = static_code.replace("\n", "")
+        code_to_eval = static_code  # code to be evaluated

         vars_dict = adj.get_static_evaluation_context()
         # add constant variables to the static call context
@@ -3150,10 +3159,10 @@
                 loc = end

             new_static_code += static_code[len_value_locs[-1][2] :]
-
+            code_to_eval = new_static_code

         try:
-            value = eval(
+            value = eval(code_to_eval, vars_dict)
             if warp.config.verbose:
                 print(f"Evaluated static command: {static_code} = {value}")
         except NameError as e:
@@ -3206,6 +3215,9 @@
             # (and is therefore not executable and raises this exception), in which
             # case changing the constant, or the code affecting this constant, would lead to
             # a different module hash anyway.
+            # In any case, we mark this Adjoint to have unresolvable static expressions.
+            # This will trigger a code generation step even if the module hash is unchanged.
+            adj.has_unresolved_static_expressions = True
            pass

        return self.generic_visit(node)
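The new has_unresolved_static_expressions flag targets wp.static() expressions that raise a NameError at kernel declaration time and can only be resolved once the module is built. Below is a minimal sketch of that pattern; the kernel name add_offset and the constant OFFSET are illustrative only, chosen to show a symbol that is defined after the kernel is declared:

    import warp as wp


    @wp.kernel
    def add_offset(a: wp.array(dtype=float)):
        tid = wp.tid()
        # OFFSET is not defined yet when this kernel is declared, so the static
        # expression cannot be evaluated here; the new flag ensures code generation
        # still runs for this module even when the module hash is unchanged.
        a[tid] = a[tid] + wp.static(OFFSET)


    OFFSET = 2.0

    a = wp.zeros(8, dtype=float)
    wp.launch(add_offset, dim=a.size, inputs=[a])
    print(a.numpy())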
warp/config.py
CHANGED
warp/context.py
CHANGED
@@ -1692,7 +1692,7 @@ class ModuleHasher:
             ch.update(bytes(name, "utf-8"))
             ch.update(self.get_constant_bytes(value))

-        # hash wp.static() expressions
+        # hash wp.static() expressions
         for k, v in adj.static_expressions.items():
             ch.update(bytes(k, "utf-8"))
             if isinstance(v, Function):
@@ -2011,6 +2011,9 @@
         # is retained and later reloaded with the same hash.
         self.cpu_exec_id = 0

+        # Indicates whether the module has functions or kernels with unresolved static expressions.
+        self.has_unresolved_static_expressions = False
+
         self.options = {
             "max_unroll": warp.config.max_unroll,
             "enable_backward": warp.config.enable_backward,
@@ -2018,7 +2021,7 @@
             "fuse_fp": True,
             "lineinfo": warp.config.lineinfo,
             "cuda_output": None,  # supported values: "ptx", "cubin", or None (automatic)
-            "mode":
+            "mode": None,
             "block_dim": 256,
             "compile_time_trace": warp.config.compile_time_trace,
         }
@@ -2047,6 +2050,10 @@
         # track all kernel objects, even if they are duplicates
         self._live_kernels.add(kernel)

+        # Check for unresolved static expressions in the kernel.
+        if kernel.adj.has_unresolved_static_expressions:
+            self.has_unresolved_static_expressions = True
+
         self.find_references(kernel.adj)

         # for a reload of module on next launch
@@ -2106,6 +2113,10 @@
                     del func_existing.user_overloads[k]
                 func_existing.add_overload(func)

+        # Check for unresolved static expressions in the function.
+        if func.adj.has_unresolved_static_expressions:
+            self.has_unresolved_static_expressions = True
+
         self.find_references(func.adj)

         # for a reload of module on next launch
@@ -2165,7 +2176,7 @@
             self.hashers[block_dim] = ModuleHasher(self)
         return self.hashers[block_dim].get_module_hash()

-    def load(self, device, block_dim=None) -> ModuleExec:
+    def load(self, device, block_dim=None) -> ModuleExec | None:
         device = runtime.get_device(device)

         # update module options if launching with a new block dim
@@ -2174,6 +2185,20 @@

         active_block_dim = self.options["block_dim"]

+        if self.has_unresolved_static_expressions:
+            # The module hash currently does not account for unresolved static expressions
+            # (only static expressions evaluated at declaration time so far).
+            # We need to generate the code for the functions and kernels that have
+            # unresolved static expressions and then compute the module hash again.
+            builder_options = {
+                **self.options,
+                "output_arch": None,
+            }
+            # build functions, kernels to resolve static expressions
+            _ = ModuleBuilder(self, builder_options)
+
+            self.has_unresolved_static_expressions = False
+
         # compute the hash if needed
         if active_block_dim not in self.hashers:
             self.hashers[active_block_dim] = ModuleHasher(self)
@@ -2262,6 +2287,8 @@

             module_load_timer.extra_msg = " (compiled)"  # For wp.ScopedTimer informational purposes

+            mode = self.options["mode"] if self.options["mode"] is not None else warp.config.mode
+
             # build CPU
             if device.is_cpu:
                 # build
@@ -2281,7 +2308,7 @@
                     warp.build.build_cpu(
                         output_path,
                         source_code_path,
-                        mode=
+                        mode=mode,
                         fast_math=self.options["fast_math"],
                         verify_fp=warp.config.verify_fp,
                         fuse_fp=self.options["fuse_fp"],
@@ -2311,7 +2338,7 @@
                         source_code_path,
                         output_arch,
                         output_path,
-                        config=
+                        config=mode,
                         verify_fp=warp.config.verify_fp,
                         fast_math=self.options["fast_math"],
                         fuse_fp=self.options["fuse_fp"],
@@ -3759,6 +3786,7 @@ class Runtime:
         self.core.cuda_graph_end_capture.restype = ctypes.c_bool

         self.core.cuda_graph_create_exec.argtypes = [
+            ctypes.c_void_p,
             ctypes.c_void_p,
             ctypes.c_void_p,
             ctypes.POINTER(ctypes.c_void_p),
@@ -4066,9 +4094,14 @@
             # Update the default PTX architecture based on devices present in the system.
             # Use the lowest architecture among devices that meet the minimum architecture requirement.
             # Devices below the required minimum will use the highest architecture they support.
-
-
-
+            try:
+                self.default_ptx_arch = min(
+                    d.arch
+                    for d in self.cuda_devices
+                    if d.arch >= self.default_ptx_arch and d.arch in self.nvrtc_supported_archs
+                )
+            except ValueError:
+                pass  # no eligible NVRTC-supported arch ≥ default, retain existing
         else:
             # CUDA not available
             self.set_default_device("cpu")
@@ -6255,6 +6288,40 @@ def get_module_options(module: Any = None) -> dict[str, Any]:
     return get_module(m.__name__).options


+def _unregister_capture(device: Device, stream: Stream, graph: Graph):
+    """Unregister a graph capture from the device and runtime.
+
+    This should be called when a graph capture is no longer active, either because it completed or was paused.
+    The graph should only be registered while it is actively capturing.
+
+    Args:
+        device: The CUDA device the graph was being captured on
+        stream: The CUDA stream the graph was being captured on
+        graph: The Graph object that was being captured
+    """
+    del device.captures[stream]
+    del runtime.captures[graph.capture_id]
+
+
+def _register_capture(device: Device, stream: Stream, graph: Graph, capture_id: int):
+    """Register a graph capture with the device and runtime.
+
+    Makes the graph discoverable through its capture_id so that retain_module_exec() can be called
+    when launching kernels during graph capture. This ensures modules are retained until graph execution completes.
+
+    Args:
+        device: The CUDA device the graph is being captured on
+        stream: The CUDA stream the graph is being captured on
+        graph: The Graph object being captured
+        capture_id: Unique identifier for this graph capture
+    """
+    # add to ongoing captures on the device
+    device.captures[stream] = graph
+
+    # add to lookup table by globally unique capture id
+    runtime.captures[capture_id] = graph
+
+
 def capture_begin(
     device: Devicelike = None,
     stream: Stream | None = None,
@@ -6320,11 +6387,7 @@ def capture_begin(
     capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
     graph = Graph(device, capture_id)

-
-    device.captures[stream] = graph
-
-    # add to lookup table by globally unique capture id
-    runtime.captures[capture_id] = graph
+    _register_capture(device, stream, graph, capture_id)


 def capture_end(device: Devicelike = None, stream: Stream | None = None) -> Graph:
@@ -6352,8 +6415,7 @@ def capture_end(device: Devicelike = None, stream: Stream | None = None) -> Graph:
     if graph is None:
         raise RuntimeError("Graph capture is not active on this stream")

-
-    del runtime.captures[graph.capture_id]
+    _unregister_capture(device, stream, graph)

     # get the graph executable
     g = ctypes.c_void_p()
@@ -6393,7 +6455,7 @@ def assert_conditional_graph_support():
         raise RuntimeError("Conditional graph nodes require CUDA driver 12.4+")


-def capture_pause(device: Devicelike = None, stream: Stream | None = None) ->
+def capture_pause(device: Devicelike = None, stream: Stream | None = None) -> Graph:
     if stream is not None:
         device = stream.device
     else:
@@ -6402,14 +6464,24 @@ def capture_pause(device: Devicelike = None, stream: Stream | None = None) -> ct
             raise RuntimeError("Must be a CUDA device")
         stream = device.stream

-    graph
-
+    # get the graph being captured
+    graph = device.captures.get(stream)
+
+    if graph is None:
+        raise RuntimeError("Graph capture is not active on this stream")
+
+    _unregister_capture(device, stream, graph)
+
+    g = ctypes.c_void_p()
+    if not runtime.core.cuda_graph_pause_capture(device.context, stream.cuda_stream, ctypes.byref(g)):
         raise RuntimeError(runtime.get_error_string())

+    graph.graph = g
+
     return graph


-def capture_resume(graph:
+def capture_resume(graph: Graph, device: Devicelike = None, stream: Stream | None = None):
     if stream is not None:
         device = stream.device
     else:
@@ -6418,9 +6490,14 @@ def capture_resume(graph: ctypes.c_void_p, device: Devicelike = None, stream: St
             raise RuntimeError("Must be a CUDA device")
         stream = device.stream

-    if not runtime.core.cuda_graph_resume_capture(device.context, stream.cuda_stream, graph):
+    if not runtime.core.cuda_graph_resume_capture(device.context, stream.cuda_stream, graph.graph):
         raise RuntimeError(runtime.get_error_string())

+    capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
+    graph.capture_id = capture_id
+
+    _register_capture(device, stream, graph, capture_id)
+

 # reusable pinned readback buffer for conditions
 condition_host = None
@@ -6518,10 +6595,15 @@ def capture_if(

     # pause capturing parent graph
     main_graph = capture_pause(stream=stream)
+    # store the pointer to the cuda graph to restore it later
+    main_graph_ptr = main_graph.graph

     # capture if-graph
     if on_true is not None:
-
+        # temporarily repurpose the main_graph python object such that all dependencies
+        # added through retain_module_exec() end up in the correct python graph object
+        main_graph.graph = graph_on_true
+        capture_resume(main_graph, stream=stream)
         if isinstance(on_true, Callable):
             on_true(**kwargs)
         elif isinstance(on_true, Graph):
@@ -6541,7 +6623,10 @@

     # capture else-graph
     if on_false is not None:
-
+        # temporarily repurpose the main_graph python object such that all dependencies
+        # added through retain_module_exec() end up in the correct python graph object
+        main_graph.graph = graph_on_false
+        capture_resume(main_graph, stream=stream)
         if isinstance(on_false, Callable):
             on_false(**kwargs)
         elif isinstance(on_false, Graph):
@@ -6559,6 +6644,9 @@
             raise TypeError("on_false must be a Callable or a Graph")
        capture_pause(stream=stream)

+    # restore the main graph to its original state
+    main_graph.graph = main_graph_ptr
+
     # resume capturing parent graph
     capture_resume(main_graph, stream=stream)

@@ -6641,7 +6729,13 @@ def capture_while(condition: warp.array(dtype=int), while_body: Callable | Graph

     # pause capturing parent graph and start capturing child graph
     main_graph = capture_pause(stream=stream)
-
+    # store the pointer to the cuda graph to restore it later
+    main_graph_ptr = main_graph.graph
+
+    # temporarily repurpose the main_graph python object such that all dependencies
+    # added through retain_module_exec() end up in the correct python graph object
+    main_graph.graph = body_graph
+    capture_resume(main_graph, stream=stream)

     # capture while-body
     if isinstance(while_body, Callable):
@@ -6670,6 +6764,8 @@ def capture_while(condition: warp.array(dtype=int), while_body: Callable | Graph

     # stop capturing child graph and resume capturing parent graph
     capture_pause(stream=stream)
+    # restore the main graph to its original state
+    main_graph.graph = main_graph_ptr
     capture_resume(main_graph, stream=stream)


@@ -6691,7 +6787,9 @@ def capture_launch(graph: Graph, stream: Stream | None = None):

     if graph.graph_exec is None:
         g = ctypes.c_void_p()
-        result = runtime.core.cuda_graph_create_exec(
+        result = runtime.core.cuda_graph_create_exec(
+            graph.device.context, stream.cuda_stream, graph.graph, ctypes.byref(g)
+        )
         if not result:
             raise RuntimeError(f"Graph creation error: {runtime.get_error_string()}")
         graph.graph_exec = g
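The capture registration helpers and the Graph-typed capture_pause()/capture_resume() signatures above mainly serve wp.capture_if() and wp.capture_while(), which splice conditional child graphs into a parent capture (conditional graph nodes require CUDA driver 12.4+, per assert_conditional_graph_support()). Below is a minimal sketch of the public API built on top of these helpers; the kernel, the array names, and the lambda-based branch body are illustrative and assume a CUDA device is available:

    import warp as wp


    @wp.kernel
    def scale(a: wp.array(dtype=float), s: float):
        tid = wp.tid()
        a[tid] = a[tid] * s


    device = "cuda:0"
    cond = wp.ones(1, dtype=int, device=device)  # nonzero selects the on_true branch
    data = wp.ones(16, dtype=float, device=device)

    with wp.ScopedCapture(device=device) as capture:
        # capture_if() pauses the parent capture, records the branch into a child
        # graph, then resumes the parent; the bookkeeping above keeps the Graph
        # object's capture_id and retained modules consistent across pause/resume.
        wp.capture_if(
            cond,
            on_true=lambda: wp.launch(scale, dim=data.size, inputs=[data, 2.0], device=device),
        )

    wp.capture_launch(capture.graph)
    print(data.numpy())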
warp/examples/interop/example_jax_callable.py
CHANGED

@@ -42,7 +42,7 @@ def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtyp

 # The Python function to call.
 # Note the argument annotations, just like Warp kernels.
-def
+def scale_func(
     # inputs
     a: wp.array(dtype=float),
     b: wp.array(dtype=wp.vec2),
@@ -55,8 +55,23 @@ def example_func(
     wp.launch(scale_vec_kernel, dim=b.shape, inputs=[b, s], outputs=[d])


+@wp.kernel
+def accum_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)):
+    tid = wp.tid()
+    b[tid] += a[tid]
+
+
+def in_out_func(
+    a: wp.array(dtype=float),  # input only
+    b: wp.array(dtype=float),  # input and output
+    c: wp.array(dtype=float),  # output only
+):
+    wp.launch(scale_kernel, dim=a.size, inputs=[a, 2.0], outputs=[c])
+    wp.launch(accum_kernel, dim=a.size, inputs=[a, b])  # modifies `b`
+
+
 def example1():
-    jax_func = jax_callable(
+    jax_func = jax_callable(scale_func, num_outputs=2)

     @jax.jit
     def f():
@@ -78,7 +93,7 @@ def example1():


 def example2():
-    jax_func = jax_callable(
+    jax_func = jax_callable(scale_func, num_outputs=2)

     # NOTE: scalar arguments must be static compile-time constants
     @partial(jax.jit, static_argnames=["s"])
@@ -100,11 +115,26 @@ def example2():
     print(r2)


+def example3():
+    # Using input-output arguments
+
+    jax_func = jax_callable(in_out_func, num_outputs=2, in_out_argnames=["b"])
+
+    f = jax.jit(jax_func)
+
+    a = jnp.ones(10, dtype=jnp.float32)
+    b = jnp.arange(10, dtype=jnp.float32)
+
+    b, c = f(a, b)
+    print(b)
+    print(c)
+
+
 def main():
     wp.init()
     wp.load_module(device=wp.get_device())

-    examples = [example1, example2]
+    examples = [example1, example2, example3]

     for example in examples:
         print("\n===========================================================================")
warp/examples/interop/example_jax_kernel.py
CHANGED

@@ -72,6 +72,17 @@ def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtyp
     output[tid] = a[tid] * s


+@wp.kernel
+def in_out_kernel(
+    a: wp.array(dtype=float),  # input only
+    b: wp.array(dtype=float),  # input and output
+    c: wp.array(dtype=float),  # output only
+):
+    tid = wp.tid()
+    b[tid] += a[tid]
+    c[tid] = 2.0 * a[tid]
+
+
 def example1():
     # two inputs and one output
     jax_add = jax_kernel(add_kernel)
@@ -189,11 +200,26 @@ def example7():
     print(f())


+def example8():
+    # Using input-output arguments
+
+    jax_func = jax_kernel(in_out_kernel, num_outputs=2, in_out_argnames=["b"])
+
+    f = jax.jit(jax_func)
+
+    a = jnp.ones(10, dtype=jnp.float32)
+    b = jnp.arange(10, dtype=jnp.float32)
+
+    b, c = f(a, b)
+    print(b)
+    print(c)
+
+
 def main():
     wp.init()
     wp.load_module(device=wp.get_device())

-    examples = [example1, example2, example3, example4, example5, example6, example7]
+    examples = [example1, example2, example3, example4, example5, example6, example7, example8]

     for example in examples:
         print("\n===========================================================================")
warp/fem/field/virtual.py
CHANGED
@@ -365,6 +365,8 @@ class LocalAdjointField(SpaceField):
         self._TAYLOR_DOF_COUNTS = LocalAdjointField.DofOffsets(0)
         self.TAYLOR_DOF_COUNT = 0

+        cache.setup_dynamic_attributes(self)
+
     def notify_operator_usage(self, ops: Set[operator.Operator]):
         # Rebuild degrees-of-freedom offsets based on used operators