warp-lang 1.7.2__py3-none-win_amd64.whl → 1.8.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (181):
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.dll +0 -0
  5. warp/bin/warp.dll +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +125 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +257 -101
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +657 -223
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/optim/example_drone.py +1 -1
  37. warp/examples/sim/example_cloth.py +1 -1
  38. warp/examples/sim/example_cloth_self_contact.py +48 -54
  39. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  40. warp/examples/tile/example_tile_cholesky.py +2 -1
  41. warp/examples/tile/example_tile_convolution.py +1 -1
  42. warp/examples/tile/example_tile_filtering.py +1 -1
  43. warp/examples/tile/example_tile_matmul.py +1 -1
  44. warp/examples/tile/example_tile_mlp.py +2 -0
  45. warp/fabric.py +7 -7
  46. warp/fem/__init__.py +5 -0
  47. warp/fem/adaptivity.py +1 -1
  48. warp/fem/cache.py +152 -63
  49. warp/fem/dirichlet.py +2 -2
  50. warp/fem/domain.py +136 -6
  51. warp/fem/field/field.py +141 -99
  52. warp/fem/field/nodal_field.py +85 -39
  53. warp/fem/field/virtual.py +97 -52
  54. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  55. warp/fem/geometry/closest_point.py +13 -0
  56. warp/fem/geometry/deformed_geometry.py +102 -40
  57. warp/fem/geometry/element.py +56 -2
  58. warp/fem/geometry/geometry.py +323 -22
  59. warp/fem/geometry/grid_2d.py +157 -62
  60. warp/fem/geometry/grid_3d.py +116 -20
  61. warp/fem/geometry/hexmesh.py +86 -20
  62. warp/fem/geometry/nanogrid.py +166 -86
  63. warp/fem/geometry/partition.py +59 -25
  64. warp/fem/geometry/quadmesh.py +86 -135
  65. warp/fem/geometry/tetmesh.py +47 -119
  66. warp/fem/geometry/trimesh.py +77 -270
  67. warp/fem/integrate.py +107 -52
  68. warp/fem/linalg.py +25 -58
  69. warp/fem/operator.py +124 -27
  70. warp/fem/quadrature/pic_quadrature.py +36 -14
  71. warp/fem/quadrature/quadrature.py +40 -16
  72. warp/fem/space/__init__.py +1 -1
  73. warp/fem/space/basis_function_space.py +66 -46
  74. warp/fem/space/basis_space.py +17 -4
  75. warp/fem/space/dof_mapper.py +1 -1
  76. warp/fem/space/function_space.py +2 -2
  77. warp/fem/space/grid_2d_function_space.py +4 -1
  78. warp/fem/space/hexmesh_function_space.py +4 -2
  79. warp/fem/space/nanogrid_function_space.py +3 -1
  80. warp/fem/space/partition.py +11 -2
  81. warp/fem/space/quadmesh_function_space.py +4 -1
  82. warp/fem/space/restriction.py +5 -2
  83. warp/fem/space/shape/__init__.py +10 -8
  84. warp/fem/space/tetmesh_function_space.py +4 -1
  85. warp/fem/space/topology.py +52 -21
  86. warp/fem/space/trimesh_function_space.py +4 -1
  87. warp/fem/utils.py +53 -8
  88. warp/jax.py +1 -2
  89. warp/jax_experimental/ffi.py +12 -17
  90. warp/jax_experimental/xla_ffi.py +37 -24
  91. warp/math.py +171 -1
  92. warp/native/array.h +99 -0
  93. warp/native/builtin.h +174 -31
  94. warp/native/coloring.cpp +1 -1
  95. warp/native/exports.h +118 -63
  96. warp/native/intersect.h +3 -3
  97. warp/native/mat.h +5 -10
  98. warp/native/mathdx.cpp +11 -5
  99. warp/native/matnn.h +1 -123
  100. warp/native/quat.h +28 -4
  101. warp/native/sparse.cpp +121 -258
  102. warp/native/sparse.cu +181 -274
  103. warp/native/spatial.h +305 -17
  104. warp/native/tile.h +583 -72
  105. warp/native/tile_radix_sort.h +1108 -0
  106. warp/native/tile_reduce.h +237 -2
  107. warp/native/tile_scan.h +240 -0
  108. warp/native/tuple.h +189 -0
  109. warp/native/vec.h +6 -16
  110. warp/native/warp.cpp +36 -4
  111. warp/native/warp.cu +574 -51
  112. warp/native/warp.h +47 -74
  113. warp/optim/linear.py +5 -1
  114. warp/paddle.py +7 -8
  115. warp/py.typed +0 -0
  116. warp/render/render_opengl.py +58 -29
  117. warp/render/render_usd.py +124 -61
  118. warp/sim/__init__.py +9 -0
  119. warp/sim/collide.py +252 -78
  120. warp/sim/graph_coloring.py +8 -1
  121. warp/sim/import_mjcf.py +4 -3
  122. warp/sim/import_usd.py +11 -7
  123. warp/sim/integrator.py +5 -2
  124. warp/sim/integrator_euler.py +1 -1
  125. warp/sim/integrator_featherstone.py +1 -1
  126. warp/sim/integrator_vbd.py +751 -320
  127. warp/sim/integrator_xpbd.py +1 -1
  128. warp/sim/model.py +265 -260
  129. warp/sim/utils.py +10 -7
  130. warp/sparse.py +303 -166
  131. warp/tape.py +52 -51
  132. warp/tests/cuda/test_conditional_captures.py +1046 -0
  133. warp/tests/cuda/test_streams.py +1 -1
  134. warp/tests/geometry/test_volume.py +2 -2
  135. warp/tests/interop/test_dlpack.py +9 -9
  136. warp/tests/interop/test_jax.py +0 -1
  137. warp/tests/run_coverage_serial.py +1 -1
  138. warp/tests/sim/disabled_kinematics.py +2 -2
  139. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  140. warp/tests/sim/test_collision.py +159 -51
  141. warp/tests/sim/test_coloring.py +15 -1
  142. warp/tests/test_array.py +254 -2
  143. warp/tests/test_array_reduce.py +2 -2
  144. warp/tests/test_atomic_cas.py +299 -0
  145. warp/tests/test_codegen.py +142 -19
  146. warp/tests/test_conditional.py +47 -1
  147. warp/tests/test_ctypes.py +0 -20
  148. warp/tests/test_devices.py +8 -0
  149. warp/tests/test_fabricarray.py +4 -2
  150. warp/tests/test_fem.py +58 -25
  151. warp/tests/test_func.py +42 -1
  152. warp/tests/test_grad.py +1 -1
  153. warp/tests/test_lerp.py +1 -3
  154. warp/tests/test_map.py +481 -0
  155. warp/tests/test_mat.py +1 -24
  156. warp/tests/test_quat.py +6 -15
  157. warp/tests/test_rounding.py +10 -38
  158. warp/tests/test_runlength_encode.py +7 -7
  159. warp/tests/test_smoothstep.py +1 -1
  160. warp/tests/test_sparse.py +51 -2
  161. warp/tests/test_spatial.py +507 -1
  162. warp/tests/test_struct.py +2 -2
  163. warp/tests/test_tuple.py +265 -0
  164. warp/tests/test_types.py +2 -2
  165. warp/tests/test_utils.py +24 -18
  166. warp/tests/tile/test_tile.py +420 -1
  167. warp/tests/tile/test_tile_mathdx.py +518 -14
  168. warp/tests/tile/test_tile_reduce.py +213 -0
  169. warp/tests/tile/test_tile_shared_memory.py +130 -1
  170. warp/tests/tile/test_tile_sort.py +117 -0
  171. warp/tests/unittest_suites.py +4 -6
  172. warp/types.py +462 -308
  173. warp/utils.py +647 -86
  174. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  175. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
  176. warp/stubs.py +0 -3381
  177. warp/tests/sim/test_xpbd.py +0 -399
  178. warp/tests/test_mlp.py +0 -282
  179. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  180. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  181. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1046 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import tempfile
18
+ import unittest
19
+
20
+ import numpy as np
21
+
22
+ import warp as wp
23
+ from warp.context import assert_conditional_graph_support
24
+ from warp.tests.unittest_utils import *
25
+
26
+
27
def check_conditional_graph_support():
    """Report whether conditional CUDA graph nodes are usable in this environment.

    Any exception raised by ``assert_conditional_graph_support`` is treated as
    "not supported", so the result can be fed directly to ``unittest.skipUnless``
    when the test functions are decorated at import time.

    Returns:
        ``True`` if the support check passes, ``False`` otherwise.
    """
    try:
        assert_conditional_graph_support()
    except Exception:
        supported = False
    else:
        supported = True
    return supported
33
+
34
+
35
@wp.kernel
def multiply_by_one_kernel(array: wp.array(dtype=wp.float32)):
    """Scale every element of ``array`` by 1.0 in place (one thread per element)."""
    i = wp.tid()
    value = array[i]
    array[i] = value * 1.0
39
+
40
+
41
def launch_multiply_by_one(array: wp.array(dtype=wp.float32)):
    """Launch ``multiply_by_one_kernel`` with one thread per element of ``array``."""
    thread_count = array.size
    wp.launch(kernel=multiply_by_one_kernel, dim=thread_count, inputs=[array])
43
+
44
+
45
@wp.kernel
def multiply_by_two_kernel(array: wp.array(dtype=wp.float32)):
    """Double every element of ``array`` in place (one thread per element)."""
    i = wp.tid()
    value = array[i]
    array[i] = value * 2.0
49
+
50
+
51
def launch_multiply_by_two(array: wp.array(dtype=wp.float32)):
    """Launch ``multiply_by_two_kernel`` with one thread per element of ``array``."""
    thread_count = array.size
    wp.launch(kernel=multiply_by_two_kernel, dim=thread_count, inputs=[array])
53
+
54
+
55
@wp.kernel
def multiply_by_two_kernel_limited(
    array: wp.array(dtype=wp.float32), condition: wp.array(dtype=wp.int32), limit: float
):
    """Double every element of ``array``; clear ``condition[0]`` once an element exceeds ``limit``.

    Used as the body of ``wp.capture_while`` in this file, where ``condition`` is
    the loop condition: writing 0 ends the loop. Several threads may store to
    ``condition[0]`` at once, but they all store the same value (0).
    """
    i = wp.tid()
    doubled = array[i] * 2.0
    array[i] = doubled

    # request loop termination once this element has grown past the limit
    if doubled > limit:
        condition[0] = 0
65
+
66
+
67
def launch_multiply_by_two_until_limit(array: wp.array(dtype=wp.float32), cond: wp.array(dtype=wp.int32), limit: float):
    """Launch ``multiply_by_two_kernel_limited`` over ``array``, clearing ``cond[0]`` past ``limit``."""
    kernel_args = [array, cond, limit]
    wp.launch(kernel=multiply_by_two_kernel_limited, dim=array.size, inputs=kernel_args)
69
+
70
+
71
@wp.kernel
def multiply_by_three_kernel(array: wp.array(dtype=wp.float32)):
    """Triple every element of ``array`` in place (one thread per element)."""
    i = wp.tid()
    value = array[i]
    array[i] = value * 3.0
75
+
76
+
77
def launch_multiply_by_three(array: wp.array(dtype=wp.float32)):
    """Launch ``multiply_by_three_kernel`` with one thread per element of ``array``."""
    thread_count = array.size
    wp.launch(kernel=multiply_by_three_kernel, dim=thread_count, inputs=[array])
79
+
80
+
81
@wp.kernel
def multiply_by_five_kernel(array: wp.array(dtype=wp.float32)):
    """Scale every element of ``array`` by 5.0 in place (one thread per element)."""
    i = wp.tid()
    value = array[i]
    array[i] = value * 5.0
85
+
86
+
87
def launch_multiply_by_five(array: wp.array(dtype=wp.float32)):
    """Launch ``multiply_by_five_kernel`` with one thread per element of ``array``."""
    thread_count = array.size
    wp.launch(kernel=multiply_by_five_kernel, dim=thread_count, inputs=[array])
89
+
90
+
91
@wp.kernel
def multiply_by_seven_kernel(array: wp.array(dtype=wp.float32)):
    """Scale every element of ``array`` by 7.0 in place (one thread per element)."""
    i = wp.tid()
    value = array[i]
    array[i] = value * 7.0
95
+
96
+
97
def launch_multiply_by_seven(array: wp.array(dtype=wp.float32)):
    """Launch ``multiply_by_seven_kernel`` with one thread per element of ``array``."""
    thread_count = array.size
    wp.launch(kernel=multiply_by_seven_kernel, dim=thread_count, inputs=[array])
99
+
100
+
101
@wp.kernel
def multiply_by_eleven_kernel(array: wp.array(dtype=wp.float32)):
    """Scale every element of ``array`` by 11.0 in place (one thread per element)."""
    i = wp.tid()
    value = array[i]
    array[i] = value * 11.0
105
+
106
+
107
def launch_multiply_by_eleven(array: wp.array(dtype=wp.float32)):
    """Launch ``multiply_by_eleven_kernel`` with one thread per element of ``array``."""
    thread_count = array.size
    wp.launch(kernel=multiply_by_eleven_kernel, dim=thread_count, inputs=[array])
109
+
110
+
111
@wp.kernel
def multiply_by_thirteen_kernel(array: wp.array(dtype=wp.float32)):
    """Scale every element of ``array`` by 13.0 in place (one thread per element)."""
    i = wp.tid()
    value = array[i]
    array[i] = value * 13.0
115
+
116
+
117
def launch_multiply_by_thirteen(array: wp.array(dtype=wp.float32)):
    """Launch ``multiply_by_thirteen_kernel`` with one thread per element of ``array``."""
    thread_count = array.size
    wp.launch(kernel=multiply_by_thirteen_kernel, dim=thread_count, inputs=[array])
119
+
120
+
121
def launch_multiply_by_two_or_thirteen(array: wp.array(dtype=wp.float32), cond: wp.array(dtype=wp.int32)):
    """Scale ``array`` by 2 when ``cond`` is non-zero, otherwise by 13.

    Serves as a branch callable of an outer ``wp.capture_if`` to build nested
    conditionals. ``array`` is forwarded to the selected launcher as a keyword.
    """
    wp.capture_if(cond, launch_multiply_by_two, launch_multiply_by_thirteen, array=array)
127
+
128
+
129
def launch_multiply_by_three_or_eleven(array: wp.array(dtype=wp.float32), cond: wp.array(dtype=wp.int32)):
    """Scale ``array`` by 3 when ``cond`` is non-zero, otherwise by 11.

    Serves as a branch callable of an outer ``wp.capture_if`` to build nested
    conditionals. ``array`` is forwarded to the selected launcher as a keyword.
    """
    wp.capture_if(cond, launch_multiply_by_three, launch_multiply_by_eleven, array=array)
135
+
136
+
137
@unittest.skipUnless(check_conditional_graph_support(), "Conditional graph nodes not supported")
def test_if_capture(test, device):
    """Captured ``wp.capture_if`` with only a true branch: data is doubled when the flag is non-zero."""
    assert device.is_cuda

    with wp.ScopedDevice(device):
        data = wp.zeros(4, dtype=wp.float32)
        flag = wp.zeros(1, dtype=wp.int32)

        # the capture below does not force module loading, so load kernels now
        wp.load_module(device=device)

        with wp.ScopedCapture(force_module_load=False) as capture:
            wp.capture_if(flag, launch_multiply_by_two, array=data)

        initial = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
        for cond in (0, 1):
            data.assign(initial)
            flag.assign([cond])

            wp.capture_launch(capture.graph)

            # only a non-zero flag executes the true branch
            expected = initial * 2.0 if cond else initial
            np.testing.assert_array_equal(data.numpy(), expected)
169
+
170
+
171
@unittest.skipUnless(check_conditional_graph_support(), "Conditional graph nodes not supported")
def test_if_capture_with_subgraph(test, device):
    """Captured ``wp.capture_if`` whose true branch is a separately captured graph."""
    assert device.is_cuda

    with wp.ScopedDevice(device):
        data = wp.zeros(4, dtype=wp.float32)
        flag = wp.zeros(1, dtype=wp.int32)

        # neither capture forces module loading, so load kernels up front
        wp.load_module(device=device)

        # graph executed by the true branch
        with wp.ScopedCapture(force_module_load=False) as branch:
            launch_multiply_by_two(data)

        # outer graph embedding the branch graph behind the condition
        with wp.ScopedCapture(force_module_load=False) as outer:
            wp.capture_if(flag, branch.graph, array=data)

        initial = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
        for cond in (0, 1):
            data.assign(initial)
            flag.assign([cond])

            wp.capture_launch(outer.graph)

            expected = initial * 2.0 if cond else initial
            np.testing.assert_array_equal(data.numpy(), expected)
207
+
208
+
209
def test_if_nocapture(test, device):
    """``wp.capture_if`` called outside a capture with a callable true branch.

    No graph is launched, so the branch's effect must be visible right after the call.
    """
    initial = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)

    with wp.ScopedDevice(device):
        for cond in (0, 1):
            data = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32)
            flag = wp.array([cond], dtype=wp.int32)

            wp.capture_if(flag, launch_multiply_by_two, array=data)

            expected = initial * 2.0 if cond else initial
            np.testing.assert_array_equal(data.numpy(), expected)
228
+
229
+
230
def test_if_with_subgraph(test, device):
    """``wp.capture_if`` called outside a capture with a pre-captured graph as the true branch.

    The branch graph is captured with ``force_module_load=False``. The module is
    therefore loaded explicitly first, as the other capture tests do, instead of
    relying on an earlier test having loaded it.
    """
    assert device.is_cuda

    with wp.ScopedDevice(device):
        # preload module before graph capture
        wp.load_module(device=device)

        # test different conditions
        for cond in [0, 1]:
            array = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32)
            condition = wp.array([cond], dtype=wp.int32)

            # capture if branch graph
            with wp.ScopedCapture(force_module_load=False) as if_capture:
                launch_multiply_by_two(array)

            wp.capture_if(
                condition,
                if_capture.graph,
            )

            if cond == 0:
                expected = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
            else:
                expected = np.array([2.0, 4.0, 6.0, 8.0], dtype=np.float32)

            np.testing.assert_array_equal(array.numpy(), expected)
254
+
255
+
256
@unittest.skipUnless(check_conditional_graph_support(), "Conditional graph nodes not supported")
def test_if_else_capture(test, device):
    """Captured ``wp.capture_if`` with both branches: x2 when the flag is set, x3 otherwise."""
    assert device.is_cuda

    with wp.ScopedDevice(device):
        data = wp.zeros(4, dtype=wp.float32)
        flag = wp.zeros(1, dtype=wp.int32)

        # the capture below does not force module loading, so load kernels now
        wp.load_module(device=device)

        with wp.ScopedCapture(force_module_load=False) as capture:
            wp.capture_if(flag, launch_multiply_by_two, launch_multiply_by_three, array=data)

        initial = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
        scale_for_cond = {0: 3.0, 1: 2.0}

        for cond, scale in scale_for_cond.items():
            data.assign(initial)
            flag.assign([cond])

            wp.capture_launch(capture.graph)

            np.testing.assert_array_equal(data.numpy(), initial * scale)
289
+
290
+
291
@unittest.skipUnless(check_conditional_graph_support(), "Conditional graph nodes not supported")
def test_if_else_capture_with_subgraph(test, device):
    """Captured ``wp.capture_if`` whose branches are separately captured graphs (x2 if set, x3 otherwise)."""
    assert device.is_cuda

    with wp.ScopedDevice(device):
        data = wp.zeros(4, dtype=wp.float32)
        flag = wp.zeros(1, dtype=wp.int32)

        # none of the captures below force module loading, so load kernels up front
        wp.load_module(device=device)

        with wp.ScopedCapture(force_module_load=False) as true_branch:
            launch_multiply_by_two(data)

        with wp.ScopedCapture(force_module_load=False) as false_branch:
            launch_multiply_by_three(data)

        with wp.ScopedCapture(force_module_load=False) as outer:
            wp.capture_if(flag, true_branch.graph, false_branch.graph, array=data)
            # an identity launch recorded after the conditional node
            launch_multiply_by_one(data)

        initial = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
        for cond in (0, 1):
            data.assign(initial)
            flag.assign([cond])

            wp.capture_launch(outer.graph)

            if cond:
                expected = initial * 2.0
            else:
                expected = initial * 3.0
            np.testing.assert_array_equal(data.numpy(), expected)
332
+
333
+
334
def test_if_else_nocapture(test, device):
    """``wp.capture_if`` called outside a capture with callable branches (x2 if set, x3 otherwise)."""
    initial = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)

    with wp.ScopedDevice(device):
        for cond in (0, 1):
            data = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32)
            flag = wp.array([cond], dtype=wp.int32)

            wp.capture_if(flag, launch_multiply_by_two, launch_multiply_by_three, array=data)

            scale = 2.0 if cond else 3.0
            np.testing.assert_array_equal(data.numpy(), initial * scale)
354
+
355
+
356
def test_if_else_with_subgraph(test, device):
    """``wp.capture_if`` called outside a capture with pre-captured graphs for both branches.

    Graph capture needs a CUDA device, so the device is asserted up front like the
    other capture tests. The branch graphs are captured with
    ``force_module_load=False``, so the module is loaded explicitly first instead
    of relying on an earlier test having loaded it.
    """
    assert device.is_cuda

    with wp.ScopedDevice(device):
        # preload module before graph capture
        wp.load_module(device=device)

        # test different conditions
        for cond in [0, 1]:
            array = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32)
            condition = wp.array([cond], dtype=wp.int32)

            # capture if-true branch graph
            with wp.ScopedCapture(force_module_load=False) as if_true_capture:
                launch_multiply_by_two(array)
            if_true_graph = if_true_capture.graph

            # capture if-false branch graph
            with wp.ScopedCapture(force_module_load=False) as if_false_capture:
                launch_multiply_by_three(array)
            if_false_graph = if_false_capture.graph

            wp.capture_if(
                condition,
                if_true_graph,
                if_false_graph,
            )

            if cond == 0:
                expected = np.array([3.0, 6.0, 9.0, 12.0], dtype=np.float32)
            else:
                expected = np.array([2.0, 4.0, 6.0, 8.0], dtype=np.float32)

            np.testing.assert_array_equal(array.numpy(), expected)
385
+
386
+
387
@unittest.skipUnless(check_conditional_graph_support(), "Conditional graph nodes not supported")
def test_else_capture(test, device):
    """Captured ``wp.capture_if`` with only a false branch: data is doubled when the flag is zero."""
    assert device.is_cuda

    with wp.ScopedDevice(device):
        data = wp.zeros(4, dtype=wp.float32)
        flag = wp.zeros(1, dtype=wp.int32)

        # the capture below does not force module loading, so load kernels now
        wp.load_module(device=device)

        with wp.ScopedCapture(force_module_load=False) as capture:
            wp.capture_if(flag, on_false=launch_multiply_by_two, array=data)

        initial = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
        for cond in (0, 1):
            data.assign(initial)
            flag.assign([cond])

            wp.capture_launch(capture.graph)

            # only a zero flag executes the false branch
            expected = initial if cond else initial * 2.0
            np.testing.assert_array_equal(data.numpy(), expected)
419
+
420
+
421
@unittest.skipUnless(check_conditional_graph_support(), "Conditional graph nodes not supported")
def test_else_capture_with_subgraph(test, device):
    """Captured ``wp.capture_if`` whose false branch is a separately captured graph."""
    assert device.is_cuda

    with wp.ScopedDevice(device):
        data = wp.zeros(4, dtype=wp.float32)
        flag = wp.zeros(1, dtype=wp.int32)

        # neither capture forces module loading, so load kernels up front
        wp.load_module(device=device)

        # graph executed by the false branch
        with wp.ScopedCapture(force_module_load=False) as branch:
            launch_multiply_by_two(array=data)
        branch_graph = branch.graph

        with wp.ScopedCapture(force_module_load=False) as outer:
            wp.capture_if(flag, on_false=branch_graph, array=data)

        initial = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
        for cond in (0, 1):
            data.assign(initial)
            flag.assign([cond])

            wp.capture_launch(outer.graph)

            expected = initial if cond else initial * 2.0
            np.testing.assert_array_equal(data.numpy(), expected)
458
+
459
+
460
def test_else_nocapture(test, device):
    """``wp.capture_if`` called outside a capture with only a callable false branch."""
    initial = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)

    with wp.ScopedDevice(device):
        for cond in (0, 1):
            data = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32)
            flag = wp.array([cond], dtype=wp.int32)

            wp.capture_if(flag, on_false=launch_multiply_by_two, array=data)

            if cond:
                expected = initial
            else:
                expected = initial * 2.0
            np.testing.assert_array_equal(data.numpy(), expected)
479
+
480
+
481
def test_else_with_subgraph(test, device):
    """``wp.capture_if`` called outside a capture with a pre-captured graph as the false branch.

    The branch graph is captured with ``force_module_load=False``. The module is
    therefore loaded explicitly first instead of relying on an earlier test having
    loaded it.
    """
    assert device.is_cuda

    with wp.ScopedDevice(device):
        # preload module before graph capture
        wp.load_module(device=device)

        # test different conditions
        for cond in [0, 1]:
            array = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32)
            condition = wp.array([cond], dtype=wp.int32)

            # capture else branch graph
            with wp.ScopedCapture(force_module_load=False) as else_capture:
                launch_multiply_by_two(array)
            else_graph = else_capture.graph

            wp.capture_if(
                condition,
                on_false=else_graph,
            )

            if cond == 0:
                expected = np.array([2.0, 4.0, 6.0, 8.0], dtype=np.float32)
            else:
                expected = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)

            np.testing.assert_array_equal(array.numpy(), expected)
506
+
507
+
508
@unittest.skipUnless(check_conditional_graph_support(), "Conditional graph nodes not supported")
def test_while_capture(test, device):
    """Captured ``wp.capture_while`` whose callable body doubles the data until an element exceeds 1000."""
    assert device.is_cuda

    with wp.ScopedDevice(device):
        data = wp.zeros(4, dtype=wp.float32)
        flag = wp.zeros(1, dtype=wp.int32)

        # the capture below does not force module loading, so load kernels now
        wp.load_module(device=device)

        with wp.ScopedCapture(force_module_load=False) as capture:
            wp.capture_while(flag, launch_multiply_by_two_until_limit, array=data, cond=flag, limit=1000)

        expected_by_cond = {
            # the loop flag starts cleared, so the body never runs
            0: np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
            # eight doublings: 4.0 * 2**8 = 1024 is the first value above the limit
            1: np.array([256.0, 512.0, 768.0, 1024.0], dtype=np.float32),
        }

        for cond, expected in expected_by_cond.items():
            data.assign([1.0, 2.0, 3.0, 4.0])
            flag.assign([cond])

            wp.capture_launch(capture.graph)

            np.testing.assert_array_equal(data.numpy(), expected)
545
+
546
+
547
@unittest.skipUnless(check_conditional_graph_support(), "Conditional graph nodes not supported")
def test_while_capture_with_subgraph(test, device):
    """Captured ``wp.capture_while`` whose body is a separately captured graph."""
    assert device.is_cuda

    with wp.ScopedDevice(device):
        data = wp.zeros(4, dtype=wp.float32)
        flag = wp.zeros(1, dtype=wp.int32)

        # neither capture forces module loading, so load kernels up front
        wp.load_module(device=device)

        # loop body: double the data and clear the flag past the limit
        with wp.ScopedCapture(force_module_load=False) as body:
            launch_multiply_by_two_until_limit(array=data, cond=flag, limit=1000)

        with wp.ScopedCapture(force_module_load=False) as outer:
            wp.capture_while(flag, body.graph, array=data, cond=flag, limit=1000)

        expected_by_cond = {
            # the loop flag starts cleared, so the body never runs
            0: np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
            # eight doublings: 4.0 * 2**8 = 1024 is the first value above the limit
            1: np.array([256.0, 512.0, 768.0, 1024.0], dtype=np.float32),
        }

        for cond, expected in expected_by_cond.items():
            data.assign([1.0, 2.0, 3.0, 4.0])
            flag.assign([cond])

            wp.capture_launch(outer.graph)

            np.testing.assert_array_equal(data.numpy(), expected)
588
+
589
+
590
def test_while_nocapture(test, device):
    """``wp.capture_while`` called outside a capture with a callable body that doubles until 1000."""
    expected_by_cond = {
        # the loop flag starts cleared, so the body never runs
        0: np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
        # eight doublings: 4.0 * 2**8 = 1024 is the first value above the limit
        1: np.array([256.0, 512.0, 768.0, 1024.0], dtype=np.float32),
    }

    with wp.ScopedDevice(device):
        for cond, expected in expected_by_cond.items():
            data = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32)
            flag = wp.array([cond], dtype=wp.int32)

            wp.capture_while(flag, launch_multiply_by_two_until_limit, array=data, cond=flag, limit=1000)

            np.testing.assert_array_equal(data.numpy(), expected)
614
+
615
+
616
def test_while_with_subgraph(test, device):
    """``wp.capture_while`` called outside a capture with a pre-captured graph as the loop body.

    Graph capture needs a CUDA device, so the device is asserted up front like the
    other capture tests. The body graph is captured with
    ``force_module_load=False``, so the module is loaded explicitly first instead
    of relying on an earlier test having loaded it.
    """
    assert device.is_cuda

    with wp.ScopedDevice(device):
        # preload module before graph capture
        wp.load_module(device=device)

        # test different conditions
        for cond in [0, 1]:
            array = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32)
            condition = wp.array([cond], dtype=wp.int32)

            # capture body graph
            with wp.ScopedCapture(force_module_load=False) as body_capture:
                launch_multiply_by_two_until_limit(array=array, cond=condition, limit=1000)
            body_graph = body_capture.graph

            wp.capture_while(
                condition,
                body_graph,
            )

            # Check the output matches expected values
            if cond == 0:
                # No iterations executed since condition was false
                expected = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
            else:
                # Multiple iterations until limit reached
                expected = np.array([256.0, 512.0, 768.0, 1024.0], dtype=np.float32)

            np.testing.assert_array_equal(array.numpy(), expected)
642
+
643
+
644
@unittest.skipUnless(check_conditional_graph_support(), "Conditional graph nodes not supported")
def test_complex_capture(test, device):
    """Capture a while loop, a nested if/else and plain launches into one graph; check all flag combinations.

    The captured graph applies, in order:

    * while ``while_condition``: x *= 2, clearing the flag once an element exceeds ``limit``
    * x *= 7
    * if ``condition1``: (x *= 2 if ``condition2`` else x *= 13),
      else: (x *= 3 if ``condition2`` else x *= 11)
    * x *= 5

    The expected result is recomputed on the host with NumPy for each combination.
    Failures report the active combination through ``err_msg``.
    """
    assert device.is_cuda

    with wp.ScopedDevice(device):
        # data array
        array = wp.zeros(4, dtype=wp.float32)

        # condition arrays
        condition1 = wp.zeros(1, dtype=wp.int32)
        condition2 = wp.zeros(1, dtype=wp.int32)
        while_condition = wp.zeros(1, dtype=wp.int32)

        limit = 1000

        # preload module before graph capture
        wp.load_module(device=device)

        # capture graph
        with wp.ScopedCapture(force_module_load=False) as capture:
            wp.capture_while(
                while_condition,
                launch_multiply_by_two_until_limit,
                array=array,
                cond=while_condition,
                limit=limit,
            )

            launch_multiply_by_seven(array)

            wp.capture_if(
                condition1,
                launch_multiply_by_two_or_thirteen,  # nested if-else
                launch_multiply_by_three_or_eleven,  # nested if-else
                array=array,
                cond=condition2,
            )

            launch_multiply_by_five(array)

        # test different conditions
        for cond1 in [0, 1]:
            for cond2 in [0, 1]:
                for while_cond in [0, 1]:
                    # reset data
                    array.assign([1.0, 2.0, 3.0, 4.0])

                    # set conditions
                    condition1.assign([cond1])
                    condition2.assign([cond2])
                    while_condition.assign([while_cond])

                    # launch the graph
                    wp.capture_launch(capture.graph)

                    # calculate expected values based on conditions
                    base = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
                    cond = while_cond
                    while cond != 0:
                        base = 2 * base
                        # set cond to zero if any value exceeds limit
                        if np.any(base > limit):
                            cond = 0

                    # multiply by 7
                    base *= 7.0

                    # apply nested conditions
                    if cond1:
                        if cond2:
                            base *= 2.0  # multiply by 2
                        else:
                            base *= 13.0  # multiply by 13
                    else:
                        if cond2:
                            base *= 3.0  # multiply by 3
                        else:
                            base *= 11.0  # multiply by 11

                    # multiply by 5
                    base *= 5.0

                    np.testing.assert_array_equal(
                        array.numpy(),
                        base,
                        err_msg=f"Conditions: while_cond={while_cond}, cond1={cond1}, cond2={cond2}, limit={limit}",
                    )
729
+
730
+
731
@unittest.skipUnless(check_conditional_graph_support(), "Conditional graph nodes not supported")
def test_complex_capture_with_subgraphs(test, device):
    """Like ``test_complex_capture``, but every conditional body is a pre-captured graph.

    The captured graph applies, in order:

    * while ``while_condition``: x *= 2, clearing the flag once an element exceeds ``limit``
    * x *= 7
    * if ``condition1``: x *= 2 then x *= 13, else: x *= 3 then x *= 11
    * x *= 5

    Failures report the active flag combination through ``err_msg``.
    """
    assert device.is_cuda

    with wp.ScopedDevice(device):
        # data array
        array = wp.zeros(4, dtype=wp.float32)

        # condition arrays
        condition1 = wp.zeros(1, dtype=wp.int32)
        while_condition = wp.zeros(1, dtype=wp.int32)

        limit = 1000

        # preload module before graph capture
        wp.load_module(device=device)

        # capture subgraphs
        with wp.ScopedCapture(force_module_load=False) as while_capture:
            launch_multiply_by_two_until_limit(array, while_condition, limit)
        while_graph = while_capture.graph

        with wp.ScopedCapture(force_module_load=False) as if_true_capture:
            launch_multiply_by_two(array)
            launch_multiply_by_thirteen(array)
        if_true_graph = if_true_capture.graph

        with wp.ScopedCapture(force_module_load=False) as if_false_capture:
            launch_multiply_by_three(array)
            launch_multiply_by_eleven(array)
        if_false_graph = if_false_capture.graph

        # capture main graph
        with wp.ScopedCapture(force_module_load=False) as capture:
            wp.capture_while(while_condition, while_graph)

            launch_multiply_by_seven(array)

            wp.capture_if(condition1, if_true_graph, if_false_graph)

            launch_multiply_by_five(array)

        # test different conditions
        for cond1 in [0, 1]:
            for while_cond in [0, 1]:
                # reset data
                array.assign([1.0, 2.0, 3.0, 4.0])

                # set conditions
                condition1.assign([cond1])
                while_condition.assign([while_cond])

                # launch the graph
                wp.capture_launch(capture.graph)

                # calculate expected values based on conditions
                base = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
                cond = while_cond
                while cond != 0:
                    base = 2 * base
                    # set cond to zero if any value exceeds limit
                    if np.any(base > limit):
                        cond = 0

                # multiply by 7
                base *= 7.0

                # apply the selected if/else branch (each branch runs two launches)
                if cond1:
                    base *= 2.0  # multiply by 2
                    base *= 13.0  # multiply by 13
                else:
                    base *= 3.0  # multiply by 3
                    base *= 11.0  # multiply by 11

                # multiply by 5
                base *= 5.0

                np.testing.assert_array_equal(
                    array.numpy(),
                    base,
                    err_msg=f"Conditions: while_cond={while_cond}, cond1={cond1}, limit={limit}",
                )
812
+
813
+
814
def test_complex_nocapture(test, device):
    """Run the while / if-else pipeline without graph capture, using callable branch bodies.

    ``wp.capture_while`` and ``wp.capture_if`` receive Python callables plus keyword arguments
    instead of captured graphs. The if/else branches are themselves if/else launchers driven
    by ``condition2``, which gives nested branching. Every combination of ``cond1``,
    ``cond2`` and ``while_cond`` is checked against a NumPy reference model.

    Args:
        test: The ``unittest.TestCase`` instance the test is attached to.
        device: The device under test.
    """
    with wp.ScopedDevice(device):
        limit = 1000

        # test different conditions
        for cond1 in [0, 1]:
            for cond2 in [0, 1]:
                for while_cond in [0, 1]:
                    # set data
                    array = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32)

                    # set conditions
                    condition1 = wp.array([cond1], dtype=wp.int32)
                    condition2 = wp.array([cond2], dtype=wp.int32)
                    while_condition = wp.array([while_cond], dtype=wp.int32)

                    wp.capture_while(
                        while_condition,
                        launch_multiply_by_two_until_limit,
                        array=array,
                        cond=while_condition,
                        limit=limit,
                    )

                    launch_multiply_by_seven(array)

                    wp.capture_if(
                        condition1,
                        launch_multiply_by_two_or_thirteen,  # nested if-else
                        launch_multiply_by_three_or_eleven,  # nested if-else
                        array=array,
                        cond=condition2,
                    )

                    launch_multiply_by_five(array)

                    # NumPy reference for the while loop: when entered, double until an element exceeds `limit`
                    expected = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
                    cond = while_cond
                    while cond != 0:
                        expected = 2 * expected
                        if np.any(expected > limit):
                            cond = 0

                    # multiply by 7
                    expected *= 7.0

                    # nested conditions: outer on cond1, inner on cond2
                    if cond1:
                        expected *= 2.0 if cond2 else 13.0
                    else:
                        expected *= 3.0 if cond2 else 11.0

                    # multiply by 5
                    expected *= 5.0

                    np.testing.assert_array_equal(
                        array.numpy(),
                        expected,
                        err_msg=(
                            f"Conditions: while_cond={while_cond}, cond1={cond1}, cond2={cond2}, limit={limit}"
                        ),
                    )
880
+
881
+
882
def test_complex_with_subgraphs(test, device):
    """Drive ``wp.capture_while`` / ``wp.capture_if`` with pre-captured subgraphs outside a main capture.

    For each combination of ``cond1`` and ``while_cond``, this test:

    - captures a while-body graph (``launch_multiply_by_two_until_limit``),
    - captures a true-branch graph (x2 then x13) and a false-branch graph (x3 then x11),
    - passes those graphs to ``wp.capture_while`` / ``wp.capture_if``,
    - interleaves plain launches (x7 and x5).

    The result is compared against a NumPy reference model.

    Args:
        test: The ``unittest.TestCase`` instance the test is attached to.
        device: The CUDA device under test.
    """
    with wp.ScopedDevice(device):
        limit = 1000

        # Preload the module once: every capture below uses force_module_load=False, so
        # the kernels must already be loaded instead of relying on earlier tests.
        wp.load_module(device=device)

        # test different conditions
        for cond1 in [0, 1]:
            for while_cond in [0, 1]:
                # set data
                array = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32)

                # set conditions
                condition1 = wp.array([cond1], dtype=wp.int32)
                while_condition = wp.array([while_cond], dtype=wp.int32)

                # capture while loop body graph
                with wp.ScopedCapture(force_module_load=False) as while_body_capture:
                    launch_multiply_by_two_until_limit(array=array, cond=while_condition, limit=limit)
                while_body_graph = while_body_capture.graph

                # capture if-else true branch
                with wp.ScopedCapture(force_module_load=False) as if_true_capture:
                    launch_multiply_by_two(array=array)
                    launch_multiply_by_thirteen(array=array)
                if_true_graph = if_true_capture.graph

                # capture if-else false branch
                with wp.ScopedCapture(force_module_load=False) as if_false_capture:
                    launch_multiply_by_three(array=array)
                    launch_multiply_by_eleven(array=array)
                if_false_graph = if_false_capture.graph

                wp.capture_while(
                    while_condition,
                    while_body_graph,
                )

                launch_multiply_by_seven(array)

                wp.capture_if(
                    condition1,
                    if_true_graph,
                    if_false_graph,
                )

                launch_multiply_by_five(array)

                # NumPy reference for the while loop: when entered, double until an element exceeds `limit`
                expected = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
                cond = while_cond
                while cond != 0:
                    expected = 2 * expected
                    if np.any(expected > limit):
                        cond = 0

                # multiply by 7
                expected *= 7.0

                # single if/else on cond1 (each branch graph performs two multiplications)
                if cond1:
                    expected *= 2.0
                    expected *= 13.0
                else:
                    expected *= 3.0
                    expected *= 11.0

                # multiply by 5
                expected *= 5.0

                np.testing.assert_array_equal(
                    array.numpy(),
                    expected,
                    err_msg=f"Conditions: while_cond={while_cond}, cond1={cond1}, limit={limit}",
                )
954
+
955
+
956
def test_graph_debug_dot_print(test, device):
    """Check that ``wp.capture_debug_dot_print`` writes a non-empty DOT file for a captured graph.

    Args:
        test: The ``unittest.TestCase`` instance the test is attached to.
        device: The CUDA device under test.
    """
    # create a simple graph to test dot file output
    array = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32, device=device)

    # Capture explicitly on `device`: without it the capture would begin on the current
    # default device, which is not `device` when testing a non-default GPU.
    with wp.ScopedCapture(device=device) as capture:
        wp.launch(multiply_by_two_kernel, dim=array.size, inputs=[array], device=device)
        wp.launch(multiply_by_three_kernel, dim=array.size, inputs=[array], device=device)
        wp.launch(multiply_by_five_kernel, dim=array.size, inputs=[array], device=device)
        wp.launch(multiply_by_seven_kernel, dim=array.size, inputs=[array], device=device)

    # A private temporary directory avoids file-name collisions between concurrently running
    # test processes and is removed even when an assertion below fails.
    with tempfile.TemporaryDirectory() as temp_dir:
        dot_file = os.path.join(temp_dir, "test_graph.dot")

        # generate dot file
        wp.capture_debug_dot_print(capture.graph, dot_file, verbose=True)

        # verify file was created and has content
        test.assertTrue(os.path.exists(dot_file))
        test.assertGreater(os.path.getsize(dot_file), 0)
980
+
981
+
982
# Device lists handed to add_function_test: all test devices, and the CUDA test devices.
devices, cuda_devices = get_test_devices(), get_cuda_test_devices()
984
+
985
+
986
class TestConditionalCaptures(unittest.TestCase):
    """Empty test case; the conditional-capture test functions are attached via ``add_function_test``."""
988
+
989
+
990
# Registration table: (test name, test function, devices to run on), registered in this order.
_conditional_capture_registrations = [
    # tests with graph capture
    ("test_if_capture", test_if_capture, cuda_devices),
    ("test_if_capture_with_subgraph", test_if_capture_with_subgraph, cuda_devices),
    ("test_if_else_capture", test_if_else_capture, cuda_devices),
    ("test_if_else_capture_with_subgraph", test_if_else_capture_with_subgraph, cuda_devices),
    ("test_else_capture", test_else_capture, cuda_devices),
    ("test_else_capture_with_subgraph", test_else_capture_with_subgraph, cuda_devices),
    ("test_while_capture", test_while_capture, cuda_devices),
    ("test_while_capture_with_subgraph", test_while_capture_with_subgraph, cuda_devices),
    ("test_complex_capture", test_complex_capture, cuda_devices),
    ("test_complex_capture_with_subgraphs", test_complex_capture_with_subgraphs, cuda_devices),
    # tests without graph capture
    ("test_if_nocapture", test_if_nocapture, devices),
    ("test_if_with_subgraph", test_if_with_subgraph, cuda_devices),
    ("test_if_else_nocapture", test_if_else_nocapture, devices),
    ("test_if_else_with_subgraph", test_if_else_with_subgraph, cuda_devices),
    ("test_else_nocapture", test_else_nocapture, devices),
    ("test_else_with_subgraph", test_else_with_subgraph, cuda_devices),
    ("test_while_nocapture", test_while_nocapture, devices),
    ("test_while_with_subgraph", test_while_with_subgraph, cuda_devices),
    ("test_complex_nocapture", test_complex_nocapture, devices),
    ("test_complex_with_subgraphs", test_complex_with_subgraphs, cuda_devices),
    # graph debug output
    ("test_graph_debug_dot_print", test_graph_debug_dot_print, cuda_devices),
]

for _reg_name, _reg_func, _reg_devices in _conditional_capture_registrations:
    add_function_test(TestConditionalCaptures, _reg_name, _reg_func, devices=_reg_devices)

# keep the module namespace free of registration temporaries
del _conditional_capture_registrations, _reg_name, _reg_func, _reg_devices
1042
+
1043
+
1044
if __name__ == "__main__":
    # Clear Warp's kernel cache, then run the suite verbosely, stopping at the first failure.
    wp.clear_kernel_cache()
    unittest.main(failfast=True, verbosity=2)