warp-lang 1.7.2rc1__py3-none-manylinux_2_34_aarch64.whl → 1.8.1__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +130 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +272 -104
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +770 -238
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +99 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +181 -95
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +210 -67
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +103 -4
- warp/native/builtin.h +182 -35
- warp/native/coloring.cpp +6 -2
- warp/native/cuda_util.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +5 -5
- warp/native/mat.h +8 -13
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/mesh.h +1 -1
- warp/native/quat.h +34 -6
- warp/native/rand.h +7 -7
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/svd.h +23 -8
- warp/native/tile.h +603 -73
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +239 -13
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +10 -20
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +588 -52
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +110 -80
- warp/render/render_usd.py +124 -62
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +253 -80
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +761 -322
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +54 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +91 -2
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +23 -24
- warp/tests/test_quat.py +28 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +83 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_static.py +48 -0
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tape.py +38 -0
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/test_vec.py +38 -408
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +438 -131
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_reduce.py +307 -5
- warp/tests/tile/test_tile_shared_memory.py +136 -7
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/unittest_suites.py +14 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1046 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import tempfile
|
|
18
|
+
import unittest
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
|
|
22
|
+
import warp as wp
|
|
23
|
+
from warp.context import assert_conditional_graph_support
|
|
24
|
+
from warp.tests.unittest_utils import *
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def check_conditional_graph_support():
    """Report whether ``assert_conditional_graph_support()`` completes without raising.

    Returns:
        ``True`` if the probe returns normally, ``False`` if it raises any ``Exception``.
    """
    supported = True
    try:
        assert_conditional_graph_support()
    except Exception:
        # any failure of the probe is reported as "not supported"
        supported = False
    return supported
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@wp.kernel
def multiply_by_one_kernel(array: wp.array(dtype=wp.float32)):
    # multiply the element at this thread's index by 1.0 and store it back
    i = wp.tid()
    value = array[i]
    array[i] = value * 1.0
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def launch_multiply_by_one(array: wp.array(dtype=wp.float32)):
    """Launch ``multiply_by_one_kernel`` on ``array`` with ``dim=array.size``."""
    thread_count = array.size
    wp.launch(multiply_by_one_kernel, dim=thread_count, inputs=[array])
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@wp.kernel
def multiply_by_two_kernel(array: wp.array(dtype=wp.float32)):
    # multiply the element at this thread's index by 2.0 and store it back
    i = wp.tid()
    value = array[i]
    array[i] = value * 2.0
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def launch_multiply_by_two(array: wp.array(dtype=wp.float32)):
    """Launch ``multiply_by_two_kernel`` on ``array`` with ``dim=array.size``."""
    thread_count = array.size
    wp.launch(multiply_by_two_kernel, dim=thread_count, inputs=[array])
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@wp.kernel
def multiply_by_two_kernel_limited(
    array: wp.array(dtype=wp.float32),
    condition: wp.array(dtype=wp.int32),
    limit: float,
):
    # double the element at this thread's index
    i = wp.tid()
    doubled = array[i] * 2.0
    array[i] = doubled

    # clear the condition flag once the stored value is above the limit
    if doubled > limit:
        condition[0] = 0
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def launch_multiply_by_two_until_limit(array: wp.array(dtype=wp.float32), cond: wp.array(dtype=wp.int32), limit: float):
    """Launch ``multiply_by_two_kernel_limited`` with ``dim=array.size``.

    Args:
        array: Data array handed to the kernel as its first input.
        cond: Single-element flag array the kernel sets to 0 once a value exceeds ``limit``.
        limit: Threshold forwarded to the kernel.
    """
    kernel_inputs = [array, cond, limit]
    wp.launch(multiply_by_two_kernel_limited, dim=array.size, inputs=kernel_inputs)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@wp.kernel
def multiply_by_three_kernel(array: wp.array(dtype=wp.float32)):
    # multiply the element at this thread's index by 3.0 and store it back
    i = wp.tid()
    value = array[i]
    array[i] = value * 3.0
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def launch_multiply_by_three(array: wp.array(dtype=wp.float32)):
    """Launch ``multiply_by_three_kernel`` on ``array`` with ``dim=array.size``."""
    thread_count = array.size
    wp.launch(multiply_by_three_kernel, dim=thread_count, inputs=[array])
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@wp.kernel
def multiply_by_five_kernel(array: wp.array(dtype=wp.float32)):
    # multiply the element at this thread's index by 5.0 and store it back
    i = wp.tid()
    value = array[i]
    array[i] = value * 5.0
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def launch_multiply_by_five(array: wp.array(dtype=wp.float32)):
    """Launch ``multiply_by_five_kernel`` on ``array`` with ``dim=array.size``."""
    thread_count = array.size
    wp.launch(multiply_by_five_kernel, dim=thread_count, inputs=[array])
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@wp.kernel
def multiply_by_seven_kernel(array: wp.array(dtype=wp.float32)):
    # multiply the element at this thread's index by 7.0 and store it back
    i = wp.tid()
    value = array[i]
    array[i] = value * 7.0
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def launch_multiply_by_seven(array: wp.array(dtype=wp.float32)):
    """Launch ``multiply_by_seven_kernel`` on ``array`` with ``dim=array.size``."""
    thread_count = array.size
    wp.launch(multiply_by_seven_kernel, dim=thread_count, inputs=[array])
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@wp.kernel
def multiply_by_eleven_kernel(array: wp.array(dtype=wp.float32)):
    # multiply the element at this thread's index by 11.0 and store it back
    i = wp.tid()
    value = array[i]
    array[i] = value * 11.0
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def launch_multiply_by_eleven(array: wp.array(dtype=wp.float32)):
    """Launch ``multiply_by_eleven_kernel`` on ``array`` with ``dim=array.size``."""
    thread_count = array.size
    wp.launch(multiply_by_eleven_kernel, dim=thread_count, inputs=[array])
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@wp.kernel
def multiply_by_thirteen_kernel(array: wp.array(dtype=wp.float32)):
    # multiply the element at this thread's index by 13.0 and store it back
    i = wp.tid()
    value = array[i]
    array[i] = value * 13.0
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def launch_multiply_by_thirteen(array: wp.array(dtype=wp.float32)):
    """Launch ``multiply_by_thirteen_kernel`` on ``array`` with ``dim=array.size``."""
    thread_count = array.size
    wp.launch(multiply_by_thirteen_kernel, dim=thread_count, inputs=[array])
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def launch_multiply_by_two_or_thirteen(array: wp.array(dtype=wp.float32), cond: wp.array(dtype=wp.int32)):
    """Issue a ``wp.capture_if`` on ``cond`` choosing between the x2 and x13 launchers.

    The first branch callable launches the multiply-by-two kernel on ``array``;
    the second launches the multiply-by-thirteen kernel.
    """

    def times_two():
        return launch_multiply_by_two(array)

    def times_thirteen():
        return launch_multiply_by_thirteen(array)

    wp.capture_if(cond, times_two, times_thirteen)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def launch_multiply_by_three_or_eleven(array: wp.array(dtype=wp.float32), cond: wp.array(dtype=wp.int32)):
    """Issue a ``wp.capture_if`` on ``cond`` choosing between the x3 and x11 launchers.

    The first branch callable launches the multiply-by-three kernel on ``array``;
    the second launches the multiply-by-eleven kernel.
    """

    def times_three():
        return launch_multiply_by_three(array)

    def times_eleven():
        return launch_multiply_by_eleven(array)

    wp.capture_if(cond, times_three, times_eleven)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
@unittest.skipUnless(check_conditional_graph_support(), "Conditional graph nodes not supported")
def test_if_capture(test, device):
    """Capture a graph holding a ``wp.capture_if`` node and replay it for conditions 0 and 1.

    The data must be unchanged for condition 0 and doubled for condition 1.
    """
    assert device.is_cuda

    initial = [1.0, 2.0, 3.0, 4.0]
    unchanged = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
    doubled = np.array([2.0, 4.0, 6.0, 8.0], dtype=np.float32)

    with wp.ScopedDevice(device):
        array = wp.zeros(4, dtype=wp.float32)
        condition = wp.zeros(1, dtype=wp.int32)

        # load the module up front; the capture below does not force a load
        wp.load_module(device=device)

        # record the conditional launch into a graph
        with wp.ScopedCapture(force_module_load=False) as capture:
            wp.capture_if(condition, launch_multiply_by_two, array=array)

        # replay the same graph with each condition value
        for cond in (0, 1):
            array.assign(initial)
            condition.assign([cond])

            wp.capture_launch(capture.graph)

            expected = unchanged if cond == 0 else doubled
            np.testing.assert_array_equal(array.numpy(), expected)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
@unittest.skipUnless(check_conditional_graph_support(), "Conditional graph nodes not supported")
def test_if_capture_with_subgraph(test, device):
    """Capture a ``wp.capture_if`` whose branch is a previously captured graph, then replay it.

    The data must be unchanged for condition 0 and doubled for condition 1.
    """
    assert device.is_cuda

    initial = [1.0, 2.0, 3.0, 4.0]
    unchanged = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
    doubled = np.array([2.0, 4.0, 6.0, 8.0], dtype=np.float32)

    with wp.ScopedDevice(device):
        array = wp.zeros(4, dtype=wp.float32)
        condition = wp.zeros(1, dtype=wp.int32)

        # load the module up front; the captures below do not force a load
        wp.load_module(device=device)

        # record the branch body as its own graph
        with wp.ScopedCapture(force_module_load=False) as if_capture:
            launch_multiply_by_two(array)

        # record the main graph, passing the branch graph to the conditional
        with wp.ScopedCapture(force_module_load=False) as capture:
            wp.capture_if(condition, if_capture.graph, array=array)

        # replay the main graph with each condition value
        for cond in (0, 1):
            array.assign(initial)
            condition.assign([cond])

            wp.capture_launch(capture.graph)

            expected = unchanged if cond == 0 else doubled
            np.testing.assert_array_equal(array.numpy(), expected)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def test_if_nocapture(test, device):
    """Call ``wp.capture_if`` directly (no enclosing capture) for conditions 0 and 1.

    The data must be unchanged for condition 0 and doubled for condition 1.
    """
    initial = [1.0, 2.0, 3.0, 4.0]
    unchanged = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
    doubled = np.array([2.0, 4.0, 6.0, 8.0], dtype=np.float32)

    with wp.ScopedDevice(device):
        for cond in (0, 1):
            # fresh data and condition arrays for every case
            array = wp.array(initial, dtype=wp.float32)
            condition = wp.array([cond], dtype=wp.int32)

            wp.capture_if(condition, launch_multiply_by_two, array=array)

            expected = unchanged if cond == 0 else doubled
            np.testing.assert_array_equal(array.numpy(), expected)
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def test_if_with_subgraph(test, device):
    """Call ``wp.capture_if`` outside a capture with a pre-captured graph as the branch.

    The data must be unchanged for condition 0 and doubled for condition 1.
    """
    assert device.is_cuda

    initial = [1.0, 2.0, 3.0, 4.0]
    unchanged = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
    doubled = np.array([2.0, 4.0, 6.0, 8.0], dtype=np.float32)

    with wp.ScopedDevice(device):
        for cond in (0, 1):
            # fresh data and condition arrays for every case
            array = wp.array(initial, dtype=wp.float32)
            condition = wp.array([cond], dtype=wp.int32)

            # record the branch body as a graph
            with wp.ScopedCapture(force_module_load=False) as if_capture:
                launch_multiply_by_two(array)

            wp.capture_if(condition, if_capture.graph)

            expected = unchanged if cond == 0 else doubled
            np.testing.assert_array_equal(array.numpy(), expected)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
@unittest.skipUnless(check_conditional_graph_support(), "Conditional graph nodes not supported")
def test_if_else_capture(test, device):
    """Capture a two-branch ``wp.capture_if`` and replay it for conditions 0 and 1.

    The data must be tripled for condition 0 and doubled for condition 1.
    """
    assert device.is_cuda

    initial = [1.0, 2.0, 3.0, 4.0]
    tripled = np.array([3.0, 6.0, 9.0, 12.0], dtype=np.float32)
    doubled = np.array([2.0, 4.0, 6.0, 8.0], dtype=np.float32)

    with wp.ScopedDevice(device):
        array = wp.zeros(4, dtype=wp.float32)
        condition = wp.zeros(1, dtype=wp.int32)

        # load the module up front; the capture below does not force a load
        wp.load_module(device=device)

        # record the if/else into a graph
        with wp.ScopedCapture(force_module_load=False) as capture:
            wp.capture_if(condition, launch_multiply_by_two, launch_multiply_by_three, array=array)

        # replay the same graph with each condition value
        for cond in (0, 1):
            array.assign(initial)
            condition.assign([cond])

            wp.capture_launch(capture.graph)

            expected = tripled if cond == 0 else doubled
            np.testing.assert_array_equal(array.numpy(), expected)
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
@unittest.skipUnless(check_conditional_graph_support(), "Conditional graph nodes not supported")
def test_if_else_capture_with_subgraph(test, device):
    """Capture a two-branch ``wp.capture_if`` whose branches are pre-captured graphs.

    A multiply-by-one launch is recorded after the conditional in the main graph.
    The data must be tripled for condition 0 and doubled for condition 1.
    """
    assert device.is_cuda

    initial = [1.0, 2.0, 3.0, 4.0]
    tripled = np.array([3.0, 6.0, 9.0, 12.0], dtype=np.float32)
    doubled = np.array([2.0, 4.0, 6.0, 8.0], dtype=np.float32)

    with wp.ScopedDevice(device):
        array = wp.zeros(4, dtype=wp.float32)
        condition = wp.zeros(1, dtype=wp.int32)

        # load the module up front; the captures below do not force a load
        wp.load_module(device=device)

        # record each branch body as its own graph
        with wp.ScopedCapture(force_module_load=False) as capture_true:
            launch_multiply_by_two(array)

        with wp.ScopedCapture(force_module_load=False) as capture_false:
            launch_multiply_by_three(array)

        # record the main graph: the conditional followed by a x1 launch
        with wp.ScopedCapture(force_module_load=False) as capture:
            wp.capture_if(condition, capture_true.graph, capture_false.graph, array=array)

            launch_multiply_by_one(array)

        # replay the main graph with each condition value
        for cond in (0, 1):
            array.assign(initial)
            condition.assign([cond])

            wp.capture_launch(capture.graph)

            expected = tripled if cond == 0 else doubled
            np.testing.assert_array_equal(array.numpy(), expected)
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def test_if_else_nocapture(test, device):
    """Call a two-branch ``wp.capture_if`` directly (no enclosing capture) for conditions 0 and 1.

    The data must be tripled for condition 0 and doubled for condition 1.
    """
    initial = [1.0, 2.0, 3.0, 4.0]
    tripled = np.array([3.0, 6.0, 9.0, 12.0], dtype=np.float32)
    doubled = np.array([2.0, 4.0, 6.0, 8.0], dtype=np.float32)

    with wp.ScopedDevice(device):
        for cond in (0, 1):
            # fresh data and condition arrays for every case
            array = wp.array(initial, dtype=wp.float32)
            condition = wp.array([cond], dtype=wp.int32)

            wp.capture_if(condition, launch_multiply_by_two, launch_multiply_by_three, array=array)

            expected = tripled if cond == 0 else doubled
            np.testing.assert_array_equal(array.numpy(), expected)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def test_if_else_with_subgraph(test, device):
    """Call a two-branch ``wp.capture_if`` outside a capture with pre-captured branch graphs.

    The data must be tripled for condition 0 and doubled for condition 1.
    """
    initial = [1.0, 2.0, 3.0, 4.0]
    tripled = np.array([3.0, 6.0, 9.0, 12.0], dtype=np.float32)
    doubled = np.array([2.0, 4.0, 6.0, 8.0], dtype=np.float32)

    with wp.ScopedDevice(device):
        for cond in (0, 1):
            # fresh data and condition arrays for every case
            array = wp.array(initial, dtype=wp.float32)
            condition = wp.array([cond], dtype=wp.int32)

            # record the branch taken when the condition is nonzero
            with wp.ScopedCapture(force_module_load=False) as if_true_capture:
                launch_multiply_by_two(array)

            # record the branch taken when the condition is zero
            with wp.ScopedCapture(force_module_load=False) as if_false_capture:
                launch_multiply_by_three(array)

            wp.capture_if(condition, if_true_capture.graph, if_false_capture.graph)

            expected = tripled if cond == 0 else doubled
            np.testing.assert_array_equal(array.numpy(), expected)
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
@unittest.skipUnless(check_conditional_graph_support(), "Conditional graph nodes not supported")
def test_else_capture(test, device):
    """Capture a ``wp.capture_if`` that only has an ``on_false`` branch and replay it.

    The data must be doubled for condition 0 and unchanged for condition 1.
    """
    assert device.is_cuda

    initial = [1.0, 2.0, 3.0, 4.0]
    doubled = np.array([2.0, 4.0, 6.0, 8.0], dtype=np.float32)
    unchanged = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)

    with wp.ScopedDevice(device):
        array = wp.zeros(4, dtype=wp.float32)
        condition = wp.zeros(1, dtype=wp.int32)

        # load the module up front; the capture below does not force a load
        wp.load_module(device=device)

        # record the else-only conditional into a graph
        with wp.ScopedCapture(force_module_load=False) as capture:
            wp.capture_if(condition, on_false=launch_multiply_by_two, array=array)

        # replay the same graph with each condition value
        for cond in (0, 1):
            array.assign(initial)
            condition.assign([cond])

            wp.capture_launch(capture.graph)

            expected = doubled if cond == 0 else unchanged
            np.testing.assert_array_equal(array.numpy(), expected)
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
@unittest.skipUnless(check_conditional_graph_support(), "Conditional graph nodes not supported")
def test_else_capture_with_subgraph(test, device):
    """Capture an else-only ``wp.capture_if`` whose ``on_false`` branch is a pre-captured graph.

    The data must be doubled for condition 0 and unchanged for condition 1.
    """
    assert device.is_cuda

    initial = [1.0, 2.0, 3.0, 4.0]
    doubled = np.array([2.0, 4.0, 6.0, 8.0], dtype=np.float32)
    unchanged = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)

    with wp.ScopedDevice(device):
        array = wp.zeros(4, dtype=wp.float32)
        condition = wp.zeros(1, dtype=wp.int32)

        # load the module up front; the captures below do not force a load
        wp.load_module(device=device)

        # record the multiply-by-two launch as its own graph
        with wp.ScopedCapture(force_module_load=False) as multiply_capture:
            launch_multiply_by_two(array=array)

        # record the main graph, passing the branch graph as on_false
        with wp.ScopedCapture(force_module_load=False) as capture:
            wp.capture_if(condition, on_false=multiply_capture.graph, array=array)

        # replay the main graph with each condition value
        for cond in (0, 1):
            array.assign(initial)
            condition.assign([cond])

            wp.capture_launch(capture.graph)

            expected = doubled if cond == 0 else unchanged
            np.testing.assert_array_equal(array.numpy(), expected)
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
def test_else_nocapture(test, device):
    """Call an else-only ``wp.capture_if`` directly (no enclosing capture) for conditions 0 and 1.

    The data must be doubled for condition 0 and unchanged for condition 1.
    """
    initial = [1.0, 2.0, 3.0, 4.0]
    doubled = np.array([2.0, 4.0, 6.0, 8.0], dtype=np.float32)
    unchanged = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)

    with wp.ScopedDevice(device):
        for cond in (0, 1):
            # fresh data and condition arrays for every case
            array = wp.array(initial, dtype=wp.float32)
            condition = wp.array([cond], dtype=wp.int32)

            wp.capture_if(condition, on_false=launch_multiply_by_two, array=array)

            expected = doubled if cond == 0 else unchanged
            np.testing.assert_array_equal(array.numpy(), expected)
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def test_else_with_subgraph(test, device):
    """Call an else-only ``wp.capture_if`` outside a capture with a pre-captured ``on_false`` graph.

    The data must be doubled for condition 0 and unchanged for condition 1.
    """
    assert device.is_cuda

    initial = [1.0, 2.0, 3.0, 4.0]
    doubled = np.array([2.0, 4.0, 6.0, 8.0], dtype=np.float32)
    unchanged = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)

    with wp.ScopedDevice(device):
        for cond in (0, 1):
            # fresh data and condition arrays for every case
            array = wp.array(initial, dtype=wp.float32)
            condition = wp.array([cond], dtype=wp.int32)

            # record the else-branch body as a graph
            with wp.ScopedCapture(force_module_load=False) as else_capture:
                launch_multiply_by_two(array)

            wp.capture_if(condition, on_false=else_capture.graph)

            expected = doubled if cond == 0 else unchanged
            np.testing.assert_array_equal(array.numpy(), expected)
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
@unittest.skipUnless(check_conditional_graph_support(), "Conditional graph nodes not supported")
def test_while_capture(test, device):
    """Capture a graph holding a ``wp.capture_while`` loop and replay it for conditions 0 and 1.

    With condition 0 the data must be unchanged; with condition 1 the data is
    doubled repeatedly and must end at ``[256, 512, 768, 1024]``.
    """
    assert device.is_cuda

    initial = [1.0, 2.0, 3.0, 4.0]
    # condition 0: no iterations executed
    untouched = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
    # condition 1: multiple iterations until the limit is reached
    looped = np.array([256.0, 512.0, 768.0, 1024.0], dtype=np.float32)

    with wp.ScopedDevice(device):
        array = wp.zeros(4, dtype=wp.float32)
        condition = wp.zeros(1, dtype=wp.int32)

        # load the module up front; the capture below does not force a load
        wp.load_module(device=device)

        # record the while loop into a graph
        with wp.ScopedCapture(force_module_load=False) as capture:
            wp.capture_while(
                condition,
                launch_multiply_by_two_until_limit,
                array=array,
                cond=condition,
                limit=1000,
            )

        # replay the same graph with each starting condition
        for cond in (0, 1):
            array.assign(initial)
            condition.assign([cond])

            wp.capture_launch(capture.graph)

            expected = untouched if cond == 0 else looped
            np.testing.assert_array_equal(array.numpy(), expected)
|
|
545
|
+
|
|
546
|
+
|
|
547
|
+
@unittest.skipUnless(check_conditional_graph_support(), "Conditional graph nodes not supported")
def test_while_capture_with_subgraph(test, device):
    """Capture a ``wp.capture_while`` whose loop body is a pre-captured graph, then replay it.

    With condition 0 the data must be unchanged; with condition 1 the data is
    doubled repeatedly and must end at ``[256, 512, 768, 1024]``.
    """
    assert device.is_cuda

    initial = [1.0, 2.0, 3.0, 4.0]
    # condition 0: no iterations executed
    untouched = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
    # condition 1: multiple iterations until the limit is reached
    looped = np.array([256.0, 512.0, 768.0, 1024.0], dtype=np.float32)

    with wp.ScopedDevice(device):
        array = wp.zeros(4, dtype=wp.float32)
        condition = wp.zeros(1, dtype=wp.int32)

        # load the module up front; the captures below do not force a load
        wp.load_module(device=device)

        # record the loop body as its own graph
        with wp.ScopedCapture(force_module_load=False) as body_capture:
            launch_multiply_by_two_until_limit(array=array, cond=condition, limit=1000)

        # record the main graph containing the while node
        with wp.ScopedCapture(force_module_load=False) as capture:
            wp.capture_while(
                condition,
                body_capture.graph,
                array=array,
                cond=condition,
                limit=1000,
            )

        # replay the main graph with each starting condition
        for cond in (0, 1):
            array.assign(initial)
            condition.assign([cond])

            wp.capture_launch(capture.graph)

            expected = untouched if cond == 0 else looped
            np.testing.assert_array_equal(array.numpy(), expected)
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
def test_while_nocapture(test, device):
    """Call ``wp.capture_while`` directly (no enclosing capture) for conditions 0 and 1.

    With condition 0 the data must be unchanged; with condition 1 the data is
    doubled repeatedly and must end at ``[256, 512, 768, 1024]``.
    """
    initial = [1.0, 2.0, 3.0, 4.0]
    # condition 0: no iterations executed
    untouched = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
    # condition 1: multiple iterations until the limit is reached
    looped = np.array([256.0, 512.0, 768.0, 1024.0], dtype=np.float32)

    with wp.ScopedDevice(device):
        for cond in (0, 1):
            # fresh data and condition arrays for every case
            array = wp.array(initial, dtype=wp.float32)
            condition = wp.array([cond], dtype=wp.int32)

            wp.capture_while(
                condition,
                launch_multiply_by_two_until_limit,
                array=array,
                cond=condition,
                limit=1000,
            )

            expected = untouched if cond == 0 else looped
            np.testing.assert_array_equal(array.numpy(), expected)
|
|
614
|
+
|
|
615
|
+
|
|
616
|
+
def test_while_with_subgraph(test, device):
    """Call ``wp.capture_while`` outside a capture with a pre-captured graph as the loop body.

    With condition 0 the data must be unchanged; with condition 1 the data is
    doubled repeatedly and must end at ``[256, 512, 768, 1024]``.
    """
    initial = [1.0, 2.0, 3.0, 4.0]
    # condition 0: no iterations executed
    untouched = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
    # condition 1: multiple iterations until the limit is reached
    looped = np.array([256.0, 512.0, 768.0, 1024.0], dtype=np.float32)

    with wp.ScopedDevice(device):
        for cond in (0, 1):
            # fresh data and condition arrays for every case
            array = wp.array(initial, dtype=wp.float32)
            condition = wp.array([cond], dtype=wp.int32)

            # record the loop body as a graph
            with wp.ScopedCapture(force_module_load=False) as body_capture:
                launch_multiply_by_two_until_limit(array=array, cond=condition, limit=1000)

            wp.capture_while(condition, body_capture.graph)

            expected = untouched if cond == 0 else looped
            np.testing.assert_array_equal(array.numpy(), expected)
|
|
642
|
+
|
|
643
|
+
|
|
644
|
+
@unittest.skipUnless(check_conditional_graph_support(), "Conditional graph nodes not supported")
def test_complex_capture(test, device):
    """Capture a graph mixing a while loop and a nested if/else, then replay it.

    The captured sequence is: a ``wp.capture_while`` loop, a multiply by 7, a
    ``wp.capture_if`` whose branches are themselves conditional on ``condition2``,
    and a multiply by 5.  The graph is captured once and launched for every
    combination of the three condition values; after each launch the array is
    compared against a reference computed on the host with NumPy.

    Args:
        test: The test case instance this function is registered on (unused here).
        device: The device to run on; must be a CUDA device.
    """
    assert device.is_cuda

    with wp.ScopedDevice(device):
        # data array
        array = wp.zeros(4, dtype=wp.float32)

        # condition arrays
        condition1 = wp.zeros(1, dtype=wp.int32)
        condition2 = wp.zeros(1, dtype=wp.int32)
        while_condition = wp.zeros(1, dtype=wp.int32)

        limit = 1000

        # preload module before graph capture
        wp.load_module(device=device)

        # capture graph
        with wp.ScopedCapture(force_module_load=False) as capture:
            wp.capture_while(
                while_condition,
                launch_multiply_by_two_until_limit,
                array=array,
                cond=while_condition,
                limit=limit,
            )

            launch_multiply_by_seven(array)

            wp.capture_if(
                condition1,
                launch_multiply_by_two_or_thirteen,  # nested if-else
                launch_multiply_by_three_or_eleven,  # nested if-else
                array=array,
                cond=condition2,
            )

            launch_multiply_by_five(array)

        # test different conditions
        for cond1 in [0, 1]:
            for cond2 in [0, 1]:
                for while_cond in [0, 1]:
                    # reset data
                    array.assign([1.0, 2.0, 3.0, 4.0])

                    # set conditions
                    condition1.assign([cond1])
                    condition2.assign([cond2])
                    while_condition.assign([while_cond])

                    # launch the graph
                    wp.capture_launch(capture.graph)

                    # calculate expected values based on conditions
                    base = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
                    cond = while_cond
                    while cond != 0:
                        base = 2 * base
                        # set cond to zero if any value exceeds limit
                        if np.any(base > limit):
                            cond = 0

                    # multiply by 7
                    base *= 7.0

                    # apply nested conditions
                    if cond1:
                        base *= 2.0 if cond2 else 13.0
                    else:
                        base *= 3.0 if cond2 else 11.0

                    # multiply by 5
                    base *= 5.0

                    # A single assertion reads the array back once; the failing
                    # combination of conditions is reported through err_msg.
                    np.testing.assert_array_equal(
                        array.numpy(),
                        base,
                        err_msg=f"Conditions: while_cond={while_cond}, cond1={cond1}, cond2={cond2}, limit={limit}",
                    )
|
|
729
|
+
|
|
730
|
+
|
|
731
|
+
@unittest.skipUnless(check_conditional_graph_support(), "Conditional graph nodes not supported")
def test_complex_capture_with_subgraphs(test, device):
    """Capture a graph whose while body and if/else branches are pre-captured subgraphs.

    Three subgraphs are captured first: the while-loop body, a true branch
    (multiply by 2, then by 13) and a false branch (multiply by 3, then by 11).
    The main graph combines them as: while loop, multiply by 7, if/else, multiply
    by 5.  The main graph is launched for every combination of ``condition1`` and
    ``while_condition`` and the result is compared against a NumPy reference.

    Args:
        test: The test case instance this function is registered on (unused here).
        device: The device to run on; must be a CUDA device.
    """
    assert device.is_cuda

    with wp.ScopedDevice(device):
        # data array
        array = wp.zeros(4, dtype=wp.float32)

        # condition arrays
        condition1 = wp.zeros(1, dtype=wp.int32)
        while_condition = wp.zeros(1, dtype=wp.int32)

        limit = 1000

        # preload module before graph capture
        wp.load_module(device=device)

        # capture subgraphs
        with wp.ScopedCapture(force_module_load=False) as while_capture:
            launch_multiply_by_two_until_limit(array, while_condition, limit)
        while_graph = while_capture.graph

        with wp.ScopedCapture(force_module_load=False) as if_true_capture:
            launch_multiply_by_two(array)
            launch_multiply_by_thirteen(array)
        if_true_graph = if_true_capture.graph

        with wp.ScopedCapture(force_module_load=False) as if_false_capture:
            launch_multiply_by_three(array)
            launch_multiply_by_eleven(array)
        if_false_graph = if_false_capture.graph

        # capture main graph
        with wp.ScopedCapture(force_module_load=False) as capture:
            wp.capture_while(while_condition, while_graph)

            launch_multiply_by_seven(array)

            wp.capture_if(condition1, if_true_graph, if_false_graph)

            launch_multiply_by_five(array)

        # test different conditions
        for cond1 in [0, 1]:
            for while_cond in [0, 1]:
                # reset data
                array.assign([1.0, 2.0, 3.0, 4.0])

                # set conditions
                condition1.assign([cond1])
                while_condition.assign([while_cond])

                # launch the graph
                wp.capture_launch(capture.graph)

                # calculate expected values based on conditions
                base = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
                cond = while_cond
                while cond != 0:
                    base = 2 * base
                    # set cond to zero if any value exceeds limit
                    if np.any(base > limit):
                        cond = 0

                # multiply by 7
                base *= 7.0

                # apply the branch selected by cond1
                if cond1:
                    base *= 2.0  # multiply by 2
                    base *= 13.0  # multiply by 13
                else:
                    base *= 3.0  # multiply by 3
                    base *= 11.0  # multiply by 11

                # multiply by 5
                base *= 5.0

                # A single assertion reads the array back once; the failing
                # combination of conditions is reported through err_msg.
                np.testing.assert_array_equal(
                    array.numpy(),
                    base,
                    err_msg=f"Conditions: while_cond={while_cond}, cond1={cond1}, limit={limit}",
                )
|
|
812
|
+
|
|
813
|
+
|
|
814
|
+
def test_complex_nocapture(test, device):
    """Run the while / nested if-else sequence directly, without capturing a graph.

    For every combination of ``cond1``, ``cond2`` and ``while_cond`` the sequence
    is: ``wp.capture_while`` loop, multiply by 7, ``wp.capture_if`` with branches
    that are themselves conditional on ``condition2``, multiply by 5.  The result
    is compared against a reference computed on the host with NumPy.

    Args:
        test: The test case instance this function is registered on (unused here).
        device: The device to run on.
    """
    with wp.ScopedDevice(device):
        limit = 1000

        # test different conditions
        for cond1 in [0, 1]:
            for cond2 in [0, 1]:
                for while_cond in [0, 1]:
                    # set data
                    array = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32)

                    # set conditions
                    condition1 = wp.array([cond1], dtype=wp.int32)
                    condition2 = wp.array([cond2], dtype=wp.int32)
                    while_condition = wp.array([while_cond], dtype=wp.int32)

                    wp.capture_while(
                        while_condition,
                        launch_multiply_by_two_until_limit,
                        array=array,
                        cond=while_condition,
                        limit=limit,
                    )

                    launch_multiply_by_seven(array)

                    wp.capture_if(
                        condition1,
                        launch_multiply_by_two_or_thirteen,  # nested if-else
                        launch_multiply_by_three_or_eleven,  # nested if-else
                        array=array,
                        cond=condition2,
                    )

                    launch_multiply_by_five(array)

                    # calculate expected values based on conditions
                    base = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
                    cond = while_cond
                    while cond != 0:
                        base = 2 * base
                        # set cond to zero if any value exceeds limit
                        if np.any(base > limit):
                            cond = 0

                    # multiply by 7
                    base *= 7.0

                    # apply nested conditions
                    if cond1:
                        base *= 2.0 if cond2 else 13.0
                    else:
                        base *= 3.0 if cond2 else 11.0

                    # multiply by 5
                    base *= 5.0

                    # A single assertion reads the array back once; the failing
                    # combination of conditions is reported through err_msg.
                    np.testing.assert_array_equal(
                        array.numpy(),
                        base,
                        err_msg=f"Conditions: while_cond={while_cond}, cond1={cond1}, cond2={cond2}, limit={limit}",
                    )
|
|
880
|
+
|
|
881
|
+
|
|
882
|
+
def test_complex_with_subgraphs(test, device):
    """Run the while / if-else sequence outside of a capture, using pre-captured subgraphs.

    For every combination of ``cond1`` and ``while_cond`` three subgraphs are
    captured (while body; true branch: multiply by 2 then 13; false branch:
    multiply by 3 then 11) and passed to ``wp.capture_while`` / ``wp.capture_if``,
    with a multiply by 7 in between and a multiply by 5 at the end.  The result is
    compared against a reference computed on the host with NumPy.

    Args:
        test: The test case instance this function is registered on (unused here).
        device: The device to run on.
    """
    with wp.ScopedDevice(device):
        limit = 1000

        # Preload this module's kernels before capturing.  The captures below pass
        # force_module_load=False, so nothing else guarantees the kernels are already
        # loaded when the test runs in isolation.
        wp.load_module(device=device)

        # test different conditions
        for cond1 in [0, 1]:
            for while_cond in [0, 1]:
                # set data
                array = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32)

                # set conditions
                condition1 = wp.array([cond1], dtype=wp.int32)
                while_condition = wp.array([while_cond], dtype=wp.int32)

                # capture while loop body graph
                with wp.ScopedCapture(force_module_load=False) as while_body_capture:
                    launch_multiply_by_two_until_limit(array=array, cond=while_condition, limit=limit)
                while_body_graph = while_body_capture.graph

                # capture if-else true branch
                with wp.ScopedCapture(force_module_load=False) as if_true_capture:
                    launch_multiply_by_two(array=array)
                    launch_multiply_by_thirteen(array=array)
                if_true_graph = if_true_capture.graph

                # capture if-else false branch
                with wp.ScopedCapture(force_module_load=False) as if_false_capture:
                    launch_multiply_by_three(array=array)
                    launch_multiply_by_eleven(array=array)
                if_false_graph = if_false_capture.graph

                wp.capture_while(
                    while_condition,
                    while_body_graph,
                )

                launch_multiply_by_seven(array)

                wp.capture_if(
                    condition1,
                    if_true_graph,
                    if_false_graph,
                )

                launch_multiply_by_five(array)

                # calculate expected values based on conditions
                base = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
                cond = while_cond
                while cond != 0:
                    base = 2 * base
                    # set cond to zero if any value exceeds limit
                    if np.any(base > limit):
                        cond = 0

                # multiply by 7
                base *= 7.0

                # apply the branch selected by cond1
                if cond1:
                    base *= 2.0  # multiply by 2
                    base *= 13.0  # multiply by 13
                else:
                    base *= 3.0  # multiply by 3
                    base *= 11.0  # multiply by 11

                # multiply by 5
                base *= 5.0

                # A single assertion reads the array back once; the failing
                # combination of conditions is reported through err_msg.
                np.testing.assert_array_equal(
                    array.numpy(),
                    base,
                    err_msg=f"Conditions: while_cond={while_cond}, cond1={cond1}, limit={limit}",
                )
|
|
954
|
+
|
|
955
|
+
|
|
956
|
+
def test_graph_debug_dot_print(test, device):
    """Check that ``wp.capture_debug_dot_print`` writes a non-empty DOT file.

    A small graph of four kernel launches is captured on ``device`` and dumped
    with ``verbose=True``; the test then checks that the output file exists and
    is not empty.

    Args:
        test: The test case instance this function is registered on; used for assertions.
        device: The device to capture on.
    """
    # create a simple graph to test dot file output
    array = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32, device=device)

    # Capture explicitly on the device the kernels are launched on, rather than
    # on whatever the current default device happens to be.
    with wp.ScopedCapture(device=device) as capture:
        wp.launch(multiply_by_two_kernel, dim=array.size, inputs=[array], device=device)
        wp.launch(multiply_by_three_kernel, dim=array.size, inputs=[array], device=device)
        wp.launch(multiply_by_five_kernel, dim=array.size, inputs=[array], device=device)
        wp.launch(multiply_by_seven_kernel, dim=array.size, inputs=[array], device=device)

    # A private temporary directory avoids collisions between runs that would
    # otherwise share one fixed path, and is removed even if an assertion fails.
    with tempfile.TemporaryDirectory() as temp_dir:
        dot_file = os.path.join(temp_dir, "test_graph.dot")

        # generate dot file
        wp.capture_debug_dot_print(capture.graph, dot_file, verbose=True)

        # verify file was created and has content
        test.assertTrue(os.path.exists(dot_file))
        test.assertGreater(os.path.getsize(dot_file), 0)
|
|
980
|
+
|
|
981
|
+
|
|
982
|
+
# Device lists used by the add_function_test() registrations below: tests that
# capture graphs are registered on `cuda_devices`, the rest on `devices`.
devices = get_test_devices()
cuda_devices = get_cuda_test_devices()
|
|
984
|
+
|
|
985
|
+
|
|
986
|
+
class TestConditionalCaptures(unittest.TestCase):
    # Intentionally empty: the module-level test functions are registered on
    # this class through the add_function_test() calls that follow.
    pass
|
|
988
|
+
|
|
989
|
+
|
|
990
|
+
# Register every test function on TestConditionalCaptures.  Each entry is
# (test name, test function, devices); entries are listed in registration order.
for _test_name, _test_func, _test_devices in (
    # tests with graph capture
    ("test_if_capture", test_if_capture, cuda_devices),
    ("test_if_capture_with_subgraph", test_if_capture_with_subgraph, cuda_devices),
    ("test_if_else_capture", test_if_else_capture, cuda_devices),
    ("test_if_else_capture_with_subgraph", test_if_else_capture_with_subgraph, cuda_devices),
    ("test_else_capture", test_else_capture, cuda_devices),
    ("test_else_capture_with_subgraph", test_else_capture_with_subgraph, cuda_devices),
    ("test_while_capture", test_while_capture, cuda_devices),
    ("test_while_capture_with_subgraph", test_while_capture_with_subgraph, cuda_devices),
    ("test_complex_capture", test_complex_capture, cuda_devices),
    ("test_complex_capture_with_subgraphs", test_complex_capture_with_subgraphs, cuda_devices),
    # tests without graph capture
    ("test_if_nocapture", test_if_nocapture, devices),
    ("test_if_with_subgraph", test_if_with_subgraph, cuda_devices),
    ("test_if_else_nocapture", test_if_else_nocapture, devices),
    ("test_if_else_with_subgraph", test_if_else_with_subgraph, cuda_devices),
    ("test_else_nocapture", test_else_nocapture, devices),
    ("test_else_with_subgraph", test_else_with_subgraph, cuda_devices),
    ("test_while_nocapture", test_while_nocapture, devices),
    ("test_while_with_subgraph", test_while_with_subgraph, cuda_devices),
    ("test_complex_nocapture", test_complex_nocapture, devices),
    ("test_complex_with_subgraphs", test_complex_with_subgraphs, cuda_devices),
    # graph debug output
    ("test_graph_debug_dot_print", test_graph_debug_dot_print, cuda_devices),
):
    add_function_test(TestConditionalCaptures, _test_name, _test_func, devices=_test_devices)

# Do not leave the loop variables behind in the module namespace.
del _test_name, _test_func, _test_devices
|
|
1042
|
+
|
|
1043
|
+
|
|
1044
|
+
if __name__ == "__main__":
    # When run directly as a script: clear Warp's kernel cache first, then run
    # the tests verbosely and stop at the first failure (failfast=True).
    wp.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=True)
|