warp-lang 1.0.1__py3-none-macosx_10_13_universal2.whl → 1.1.0__py3-none-macosx_10_13_universal2.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic.
- warp/__init__.py +108 -97
- warp/__init__.pyi +1 -1
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +115 -113
- warp/build_dll.py +383 -375
- warp/builtins.py +3425 -3354
- warp/codegen.py +2878 -2792
- warp/config.py +40 -36
- warp/constants.py +45 -45
- warp/context.py +5194 -5102
- warp/dlpack.py +442 -442
- warp/examples/__init__.py +16 -16
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -92
- warp/examples/assets/nv_humanoid.xml +183 -183
- warp/examples/assets/quadruped.urdf +267 -267
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +383 -383
- warp/examples/benchmarks/benchmark_cloth.py +278 -279
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
- warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
- warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
- warp/examples/benchmarks/benchmark_cloth_warp.py +146 -146
- warp/examples/benchmarks/benchmark_launches.py +295 -295
- warp/examples/browse.py +29 -28
- warp/examples/core/example_dem.py +234 -221
- warp/examples/core/example_fluid.py +293 -267
- warp/examples/core/example_graph_capture.py +144 -129
- warp/examples/core/example_marching_cubes.py +188 -176
- warp/examples/core/example_mesh.py +174 -154
- warp/examples/core/example_mesh_intersect.py +205 -193
- warp/examples/core/example_nvdb.py +176 -169
- warp/examples/core/example_raycast.py +105 -89
- warp/examples/core/example_raymarch.py +199 -178
- warp/examples/core/example_render_opengl.py +185 -141
- warp/examples/core/example_sph.py +405 -389
- warp/examples/core/example_torch.py +222 -181
- warp/examples/core/example_wave.py +263 -249
- warp/examples/fem/bsr_utils.py +378 -380
- warp/examples/fem/example_apic_fluid.py +407 -391
- warp/examples/fem/example_convection_diffusion.py +182 -168
- warp/examples/fem/example_convection_diffusion_dg.py +219 -209
- warp/examples/fem/example_convection_diffusion_dg0.py +204 -194
- warp/examples/fem/example_deformed_geometry.py +177 -159
- warp/examples/fem/example_diffusion.py +201 -173
- warp/examples/fem/example_diffusion_3d.py +177 -152
- warp/examples/fem/example_diffusion_mgpu.py +221 -214
- warp/examples/fem/example_mixed_elasticity.py +244 -222
- warp/examples/fem/example_navier_stokes.py +259 -243
- warp/examples/fem/example_stokes.py +220 -192
- warp/examples/fem/example_stokes_transfer.py +265 -249
- warp/examples/fem/mesh_utils.py +133 -109
- warp/examples/fem/plot_utils.py +292 -287
- warp/examples/optim/example_bounce.py +260 -248
- warp/examples/optim/example_cloth_throw.py +222 -210
- warp/examples/optim/example_diffray.py +566 -535
- warp/examples/optim/example_drone.py +864 -835
- warp/examples/optim/example_inverse_kinematics.py +176 -169
- warp/examples/optim/example_inverse_kinematics_torch.py +185 -170
- warp/examples/optim/example_spring_cage.py +239 -234
- warp/examples/optim/example_trajectory.py +223 -201
- warp/examples/optim/example_walker.py +306 -292
- warp/examples/sim/example_cartpole.py +139 -128
- warp/examples/sim/example_cloth.py +196 -184
- warp/examples/sim/example_granular.py +124 -113
- warp/examples/sim/example_granular_collision_sdf.py +197 -185
- warp/examples/sim/example_jacobian_ik.py +236 -213
- warp/examples/sim/example_particle_chain.py +118 -106
- warp/examples/sim/example_quadruped.py +193 -179
- warp/examples/sim/example_rigid_chain.py +197 -189
- warp/examples/sim/example_rigid_contact.py +189 -176
- warp/examples/sim/example_rigid_force.py +127 -126
- warp/examples/sim/example_rigid_gyroscopic.py +109 -97
- warp/examples/sim/example_rigid_soft_contact.py +134 -124
- warp/examples/sim/example_soft_body.py +190 -178
- warp/fabric.py +337 -335
- warp/fem/__init__.py +60 -27
- warp/fem/cache.py +401 -388
- warp/fem/dirichlet.py +178 -179
- warp/fem/domain.py +262 -263
- warp/fem/field/__init__.py +100 -101
- warp/fem/field/field.py +148 -149
- warp/fem/field/nodal_field.py +298 -299
- warp/fem/field/restriction.py +22 -21
- warp/fem/field/test.py +180 -181
- warp/fem/field/trial.py +183 -183
- warp/fem/geometry/__init__.py +15 -19
- warp/fem/geometry/closest_point.py +69 -70
- warp/fem/geometry/deformed_geometry.py +270 -271
- warp/fem/geometry/element.py +744 -744
- warp/fem/geometry/geometry.py +184 -186
- warp/fem/geometry/grid_2d.py +380 -373
- warp/fem/geometry/grid_3d.py +441 -435
- warp/fem/geometry/hexmesh.py +953 -953
- warp/fem/geometry/partition.py +374 -376
- warp/fem/geometry/quadmesh_2d.py +532 -532
- warp/fem/geometry/tetmesh.py +840 -840
- warp/fem/geometry/trimesh_2d.py +577 -577
- warp/fem/integrate.py +1630 -1615
- warp/fem/operator.py +190 -191
- warp/fem/polynomial.py +214 -213
- warp/fem/quadrature/__init__.py +2 -2
- warp/fem/quadrature/pic_quadrature.py +243 -245
- warp/fem/quadrature/quadrature.py +295 -294
- warp/fem/space/__init__.py +294 -292
- warp/fem/space/basis_space.py +488 -489
- warp/fem/space/collocated_function_space.py +100 -105
- warp/fem/space/dof_mapper.py +236 -236
- warp/fem/space/function_space.py +148 -145
- warp/fem/space/grid_2d_function_space.py +267 -267
- warp/fem/space/grid_3d_function_space.py +305 -306
- warp/fem/space/hexmesh_function_space.py +350 -352
- warp/fem/space/partition.py +350 -350
- warp/fem/space/quadmesh_2d_function_space.py +368 -369
- warp/fem/space/restriction.py +158 -160
- warp/fem/space/shape/__init__.py +13 -15
- warp/fem/space/shape/cube_shape_function.py +738 -738
- warp/fem/space/shape/shape_function.py +102 -103
- warp/fem/space/shape/square_shape_function.py +611 -611
- warp/fem/space/shape/tet_shape_function.py +565 -567
- warp/fem/space/shape/triangle_shape_function.py +429 -429
- warp/fem/space/tetmesh_function_space.py +294 -292
- warp/fem/space/topology.py +297 -295
- warp/fem/space/trimesh_2d_function_space.py +223 -221
- warp/fem/types.py +77 -77
- warp/fem/utils.py +495 -495
- warp/jax.py +166 -141
- warp/jax_experimental.py +341 -339
- warp/native/array.h +1072 -1025
- warp/native/builtin.h +1560 -1560
- warp/native/bvh.cpp +398 -398
- warp/native/bvh.cu +525 -525
- warp/native/bvh.h +429 -429
- warp/native/clang/clang.cpp +495 -464
- warp/native/crt.cpp +31 -31
- warp/native/crt.h +334 -334
- warp/native/cuda_crt.h +1049 -1049
- warp/native/cuda_util.cpp +549 -540
- warp/native/cuda_util.h +288 -203
- warp/native/cutlass_gemm.cpp +34 -34
- warp/native/cutlass_gemm.cu +372 -372
- warp/native/error.cpp +66 -66
- warp/native/error.h +27 -27
- warp/native/fabric.h +228 -228
- warp/native/hashgrid.cpp +301 -278
- warp/native/hashgrid.cu +78 -77
- warp/native/hashgrid.h +227 -227
- warp/native/initializer_array.h +32 -32
- warp/native/intersect.h +1204 -1204
- warp/native/intersect_adj.h +365 -365
- warp/native/intersect_tri.h +322 -322
- warp/native/marching.cpp +2 -2
- warp/native/marching.cu +497 -497
- warp/native/marching.h +2 -2
- warp/native/mat.h +1498 -1498
- warp/native/matnn.h +333 -333
- warp/native/mesh.cpp +203 -203
- warp/native/mesh.cu +293 -293
- warp/native/mesh.h +1887 -1887
- warp/native/nanovdb/NanoVDB.h +4782 -4782
- warp/native/nanovdb/PNanoVDB.h +2553 -2553
- warp/native/nanovdb/PNanoVDBWrite.h +294 -294
- warp/native/noise.h +850 -850
- warp/native/quat.h +1084 -1084
- warp/native/rand.h +299 -299
- warp/native/range.h +108 -108
- warp/native/reduce.cpp +156 -156
- warp/native/reduce.cu +348 -348
- warp/native/runlength_encode.cpp +61 -61
- warp/native/runlength_encode.cu +46 -46
- warp/native/scan.cpp +30 -30
- warp/native/scan.cu +36 -36
- warp/native/scan.h +7 -7
- warp/native/solid_angle.h +442 -442
- warp/native/sort.cpp +94 -94
- warp/native/sort.cu +97 -97
- warp/native/sort.h +14 -14
- warp/native/sparse.cpp +337 -337
- warp/native/sparse.cu +544 -544
- warp/native/spatial.h +630 -630
- warp/native/svd.h +562 -562
- warp/native/temp_buffer.h +30 -30
- warp/native/vec.h +1132 -1132
- warp/native/volume.cpp +297 -297
- warp/native/volume.cu +32 -32
- warp/native/volume.h +538 -538
- warp/native/volume_builder.cu +425 -425
- warp/native/volume_builder.h +19 -19
- warp/native/warp.cpp +1057 -1052
- warp/native/warp.cu +2943 -2828
- warp/native/warp.h +313 -305
- warp/optim/__init__.py +9 -9
- warp/optim/adam.py +120 -120
- warp/optim/linear.py +1104 -939
- warp/optim/sgd.py +104 -92
- warp/render/__init__.py +10 -10
- warp/render/render_opengl.py +3217 -3204
- warp/render/render_usd.py +768 -749
- warp/render/utils.py +152 -150
- warp/sim/__init__.py +52 -59
- warp/sim/articulation.py +685 -685
- warp/sim/collide.py +1594 -1590
- warp/sim/import_mjcf.py +489 -481
- warp/sim/import_snu.py +220 -221
- warp/sim/import_urdf.py +536 -516
- warp/sim/import_usd.py +887 -881
- warp/sim/inertia.py +316 -317
- warp/sim/integrator.py +234 -233
- warp/sim/integrator_euler.py +1956 -1956
- warp/sim/integrator_featherstone.py +1910 -1991
- warp/sim/integrator_xpbd.py +3294 -3312
- warp/sim/model.py +4473 -4314
- warp/sim/particles.py +113 -112
- warp/sim/render.py +417 -403
- warp/sim/utils.py +413 -410
- warp/sparse.py +1227 -1227
- warp/stubs.py +2109 -2469
- warp/tape.py +1162 -225
- warp/tests/__init__.py +1 -1
- warp/tests/__main__.py +4 -4
- warp/tests/assets/torus.usda +105 -105
- warp/tests/aux_test_class_kernel.py +26 -26
- warp/tests/aux_test_compile_consts_dummy.py +10 -10
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
- warp/tests/aux_test_dependent.py +22 -22
- warp/tests/aux_test_grad_customs.py +23 -23
- warp/tests/aux_test_reference.py +11 -11
- warp/tests/aux_test_reference_reference.py +10 -10
- warp/tests/aux_test_square.py +17 -17
- warp/tests/aux_test_unresolved_func.py +14 -14
- warp/tests/aux_test_unresolved_symbol.py +14 -14
- warp/tests/disabled_kinematics.py +239 -239
- warp/tests/run_coverage_serial.py +31 -31
- warp/tests/test_adam.py +157 -157
- warp/tests/test_arithmetic.py +1124 -1124
- warp/tests/test_array.py +2417 -2326
- warp/tests/test_array_reduce.py +150 -150
- warp/tests/test_async.py +668 -656
- warp/tests/test_atomic.py +141 -141
- warp/tests/test_bool.py +204 -149
- warp/tests/test_builtins_resolution.py +1292 -1292
- warp/tests/test_bvh.py +164 -171
- warp/tests/test_closest_point_edge_edge.py +228 -228
- warp/tests/test_codegen.py +566 -553
- warp/tests/test_compile_consts.py +97 -101
- warp/tests/test_conditional.py +246 -246
- warp/tests/test_copy.py +232 -215
- warp/tests/test_ctypes.py +632 -632
- warp/tests/test_dense.py +67 -67
- warp/tests/test_devices.py +91 -98
- warp/tests/test_dlpack.py +530 -529
- warp/tests/test_examples.py +400 -378
- warp/tests/test_fabricarray.py +955 -955
- warp/tests/test_fast_math.py +62 -54
- warp/tests/test_fem.py +1277 -1278
- warp/tests/test_fp16.py +130 -130
- warp/tests/test_func.py +338 -337
- warp/tests/test_generics.py +571 -571
- warp/tests/test_grad.py +746 -640
- warp/tests/test_grad_customs.py +333 -336
- warp/tests/test_hash_grid.py +210 -164
- warp/tests/test_import.py +39 -39
- warp/tests/test_indexedarray.py +1134 -1134
- warp/tests/test_intersect.py +67 -67
- warp/tests/test_jax.py +307 -307
- warp/tests/test_large.py +167 -164
- warp/tests/test_launch.py +354 -354
- warp/tests/test_lerp.py +261 -261
- warp/tests/test_linear_solvers.py +191 -171
- warp/tests/test_lvalue.py +421 -493
- warp/tests/test_marching_cubes.py +65 -65
- warp/tests/test_mat.py +1801 -1827
- warp/tests/test_mat_lite.py +115 -115
- warp/tests/test_mat_scalar_ops.py +2907 -2889
- warp/tests/test_math.py +126 -193
- warp/tests/test_matmul.py +500 -499
- warp/tests/test_matmul_lite.py +410 -410
- warp/tests/test_mempool.py +188 -190
- warp/tests/test_mesh.py +284 -324
- warp/tests/test_mesh_query_aabb.py +228 -241
- warp/tests/test_mesh_query_point.py +692 -702
- warp/tests/test_mesh_query_ray.py +292 -303
- warp/tests/test_mlp.py +276 -276
- warp/tests/test_model.py +110 -110
- warp/tests/test_modules_lite.py +39 -39
- warp/tests/test_multigpu.py +163 -163
- warp/tests/test_noise.py +248 -248
- warp/tests/test_operators.py +250 -250
- warp/tests/test_options.py +123 -125
- warp/tests/test_peer.py +133 -137
- warp/tests/test_pinned.py +78 -78
- warp/tests/test_print.py +54 -54
- warp/tests/test_quat.py +2086 -2086
- warp/tests/test_rand.py +288 -288
- warp/tests/test_reload.py +217 -217
- warp/tests/test_rounding.py +179 -179
- warp/tests/test_runlength_encode.py +190 -190
- warp/tests/test_sim_grad.py +243 -0
- warp/tests/test_sim_kinematics.py +91 -97
- warp/tests/test_smoothstep.py +168 -168
- warp/tests/test_snippet.py +305 -266
- warp/tests/test_sparse.py +468 -460
- warp/tests/test_spatial.py +2148 -2148
- warp/tests/test_streams.py +486 -473
- warp/tests/test_struct.py +710 -675
- warp/tests/test_tape.py +173 -148
- warp/tests/test_torch.py +743 -743
- warp/tests/test_transient_module.py +87 -87
- warp/tests/test_types.py +556 -659
- warp/tests/test_utils.py +490 -499
- warp/tests/test_vec.py +1264 -1268
- warp/tests/test_vec_lite.py +73 -73
- warp/tests/test_vec_scalar_ops.py +2099 -2099
- warp/tests/test_verify_fp.py +94 -94
- warp/tests/test_volume.py +737 -736
- warp/tests/test_volume_write.py +255 -265
- warp/tests/unittest_serial.py +37 -37
- warp/tests/unittest_suites.py +363 -359
- warp/tests/unittest_utils.py +603 -578
- warp/tests/unused_test_misc.py +71 -71
- warp/tests/walkthrough_debug.py +85 -85
- warp/thirdparty/appdirs.py +598 -598
- warp/thirdparty/dlpack.py +143 -143
- warp/thirdparty/unittest_parallel.py +566 -561
- warp/torch.py +321 -295
- warp/types.py +4504 -4450
- warp/utils.py +1008 -821
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/LICENSE.md +126 -126
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/METADATA +338 -400
- warp_lang-1.1.0.dist-info/RECORD +352 -0
- warp/examples/assets/cube.usda +0 -42
- warp/examples/assets/sphere.usda +0 -56
- warp/examples/assets/torus.usda +0 -105
- warp_lang-1.0.1.dist-info/RECORD +0 -352
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/top_level.txt +0 -0
warp/tests/test_async.py
CHANGED
@@ -1,656 +1,668 @@
# Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import unittest

import numpy as np

import warp as wp
from warp.tests.unittest_utils import *
from warp.utils import check_iommu

wp.init()


class Capturable:
    def __init__(self, use_graph=True, stream=None):
        self.use_graph = use_graph
        self.stream = stream

    def __enter__(self):
        if self.use_graph:
            wp.capture_begin(stream=self.stream)

    def __exit__(self, exc_type, exc_value, traceback):
        if self.use_graph:
            try:
                # need to call capture_end() to terminate the CUDA stream capture
                graph = wp.capture_end(stream=self.stream)
            except Exception:
                # capture_end() will raise if there was an error during capture, but we squash it here
                # if we already had an exception so that the original exception percolates to the caller
                if exc_type is None:
                    raise
            else:
                # capture can succeed despite some errors during capture (e.g. cudaInvalidValue during copy)
                # but if we had an exception during capture, don't launch the graph
                if exc_type is None:
                    wp.capture_launch(graph, stream=self.stream)


@wp.kernel
def inc(a: wp.array(dtype=float)):
    tid = wp.tid()
    a[tid] = a[tid] + 1.0


def test_async_empty(test, device, use_mempools, use_graph):
    with wp.ScopedDevice(device), wp.ScopedMempool(device, use_mempools):
        n = 100

        with Capturable(use_graph):
            a = wp.empty(n, dtype=float)

        test.assertIsInstance(a, wp.array)
        test.assertIsNotNone(a.ptr)
        test.assertEqual(a.size, n)
        test.assertEqual(a.dtype, wp.float32)
        test.assertEqual(a.device, device)


def test_async_zeros(test, device, use_mempools, use_graph):
    with wp.ScopedDevice(device), wp.ScopedMempool(device, use_mempools):
        n = 100

        with Capturable(use_graph):
            a = wp.zeros(n, dtype=float)

        assert_np_equal(a.numpy(), np.zeros(n, dtype=np.float32))


def test_async_zero_v1(test, device, use_mempools, use_graph):
    with wp.ScopedDevice(device), wp.ScopedMempool(device, use_mempools):
        n = 100

        with Capturable(use_graph):
            a = wp.empty(n, dtype=float)
            a.zero_()

        assert_np_equal(a.numpy(), np.zeros(n, dtype=np.float32))


def test_async_zero_v2(test, device, use_mempools, use_graph):
    with wp.ScopedDevice(device), wp.ScopedMempool(device, use_mempools):
        n = 100

        a = wp.empty(n, dtype=float)

        with Capturable(use_graph):
            a.zero_()

        assert_np_equal(a.numpy(), np.zeros(n, dtype=np.float32))


def test_async_full(test, device, use_mempools, use_graph):
    with wp.ScopedDevice(device), wp.ScopedMempool(device, use_mempools):
        n = 100
        value = 42

        with Capturable(use_graph):
            a = wp.full(n, value, dtype=float)

        assert_np_equal(a.numpy(), np.full(n, value, dtype=np.float32))


def test_async_fill_v1(test, device, use_mempools, use_graph):
    with wp.ScopedDevice(device), wp.ScopedMempool(device, use_mempools):
        n = 100
        value = 17

        with Capturable(use_graph):
            a = wp.empty(n, dtype=float)
            a.fill_(value)

        assert_np_equal(a.numpy(), np.full(n, value, dtype=np.float32))


def test_async_fill_v2(test, device, use_mempools, use_graph):
    with wp.ScopedDevice(device), wp.ScopedMempool(device, use_mempools):
        n = 100
        value = 17

        a = wp.empty(n, dtype=float)

        with Capturable(use_graph):
            a.fill_(value)

        assert_np_equal(a.numpy(), np.full(n, value, dtype=np.float32))


def test_async_kernels_v1(test, device, use_mempools, use_graph):
    with wp.ScopedDevice(device), wp.ScopedMempool(device, use_mempools):
        n = 100
        num_iters = 10

        with Capturable(use_graph):
            a = wp.zeros(n, dtype=float)
            for _i in range(num_iters):
                wp.launch(inc, dim=a.size, inputs=[a])

        assert_np_equal(a.numpy(), np.full(n, num_iters, dtype=np.float32))


def test_async_kernels_v2(test, device, use_mempools, use_graph):
    with wp.ScopedDevice(device), wp.ScopedMempool(device, use_mempools):
        n = 100
        num_iters = 10

        a = wp.zeros(n, dtype=float)

        with Capturable(use_graph):
            for _i in range(num_iters):
                wp.launch(inc, dim=a.size, inputs=[a])

        assert_np_equal(a.numpy(), np.full(n, num_iters, dtype=np.float32))


class TestAsync(unittest.TestCase):
    pass


# get all CUDA devices
cuda_devices = wp.get_cuda_devices()

# get CUDA devices that support mempools
cuda_devices_with_mempools = []
for d in cuda_devices:
    if d.is_mempool_supported:
        cuda_devices_with_mempools.append(d)

# get a pair of CUDA devices that support mempool access
cuda_devices_with_mempool_access = []
for target_device in cuda_devices_with_mempools:
    for peer_device in cuda_devices_with_mempools:
        if peer_device != target_device:
            if wp.is_mempool_access_supported(target_device, peer_device):
                cuda_devices_with_mempool_access = [target_device, peer_device]
                break
    if cuda_devices_with_mempool_access:
        break


def add_test_variants(
    func,
    device_count=1,
    graph_allocs=False,
    requires_mempool_access_with_graph=False,
):
    # test that works with default allocators
    if not graph_allocs and device_count <= len(cuda_devices):
        devices = cuda_devices[:device_count]

        def func1(t, d):
            return func(t, *devices, False, False)

        def func2(t, d):
            return func(t, *devices, False, True)

        name1 = f"{func.__name__}_DefaultAlloc_NoGraph"
        name2 = f"{func.__name__}_DefaultAlloc_WithGraph"
        if device_count == 1:
            add_function_test(TestAsync, name1, func1, devices=devices)
            add_function_test(TestAsync, name2, func2, devices=devices)
        else:
            add_function_test(TestAsync, name1, func1)
            add_function_test(TestAsync, name2, func2)

    # test that works with mempool allocators
    if device_count <= len(cuda_devices_with_mempools):
        devices = cuda_devices_with_mempools[:device_count]

        def func3(t, d):
            return func(t, *devices, True, False)

        name3 = f"{func.__name__}_MempoolAlloc_NoGraph"
        if device_count == 1:
            add_function_test(TestAsync, name3, func3, devices=devices)
        else:
            add_function_test(TestAsync, name3, func3)

    # test that requires devices with mutual mempool access during graph capture (e.g., p2p memcpy limitation)
    if requires_mempool_access_with_graph:
        suitable_devices = cuda_devices_with_mempool_access
    else:
        suitable_devices = cuda_devices_with_mempools

    if device_count <= len(suitable_devices):
        devices = suitable_devices[:device_count]

        def func4(t, d):
            return func(t, *devices, True, True)

        name4 = f"{func.__name__}_MempoolAlloc_WithGraph"
        if device_count == 1:
            add_function_test(TestAsync, name4, func4, devices=devices)
        else:
            add_function_test(TestAsync, name4, func4)


add_test_variants(test_async_empty, graph_allocs=True)
add_test_variants(test_async_zeros, graph_allocs=True)
add_test_variants(test_async_zero_v1, graph_allocs=True)
add_test_variants(test_async_zero_v2, graph_allocs=False)
add_test_variants(test_async_full, graph_allocs=True)
add_test_variants(test_async_fill_v1, graph_allocs=True)
add_test_variants(test_async_fill_v2, graph_allocs=False)
add_test_variants(test_async_kernels_v1, graph_allocs=True)
add_test_variants(test_async_kernels_v2, graph_allocs=False)


# =================================================================================
# wp.copy() tests
# =================================================================================


def as_contiguous_array(data, device=None, grad_data=None):
    a = wp.array(data=data, device=device, copy=True)
    if grad_data is not None:
        a.grad = as_contiguous_array(grad_data, device=device)
    return a


def as_strided_array(data, device=None, grad_data=None):
    a = wp.array(data=data, device=device)
    # make a copy with non-contiguous strides
    strides = (*a.strides[:-1], 2 * a.strides[-1])
    strided_a = wp.zeros(shape=a.shape, strides=strides, dtype=a.dtype, device=device)
    wp.copy(strided_a, a)
    if grad_data is not None:
        strided_a.grad = as_strided_array(grad_data, device=device)
    return strided_a


def as_indexed_array(data, device=None, **kwargs):
    a = wp.array(data=data, device=device)
    # allocate double the elements so we can index half of them
    shape = (*a.shape[:-1], 2 * a.shape[-1])
    big_a = wp.zeros(shape=shape, dtype=a.dtype, device=device)
    indices = wp.array(data=np.arange(0, shape[-1], 2, dtype=np.int32), device=device)
    indexed_a = big_a[indices]
    wp.copy(indexed_a, a)
    return indexed_a


def as_fabric_array(data, device=None, **kwargs):
    from warp.tests.test_fabricarray import _create_fabric_array_interface

    a = wp.array(data=data, device=device)
    iface = _create_fabric_array_interface(a, "foo")
    fa = wp.fabricarray(data=iface, attrib="foo")
    fa._iface = iface  # save data reference
    return fa


def as_indexed_fabric_array(data, device=None, **kwargs):
    from warp.tests.test_fabricarray import _create_fabric_array_interface

    a = wp.array(data=data, device=device)
    shape = (*a.shape[:-1], 2 * a.shape[-1])
    # allocate double the elements so we can index half of them
    big_a = wp.zeros(shape=shape, dtype=a.dtype, device=device)
    indices = wp.array(data=np.arange(0, shape[-1], 2, dtype=np.int32), device=device)
    iface = _create_fabric_array_interface(big_a, "foo", copy=True)
    fa = wp.fabricarray(data=iface, attrib="foo")
    fa._iface = iface  # save data reference
    indexed_fa = fa[indices]
    wp.copy(indexed_fa, a)
    return indexed_fa


class CopyParams:
    def __init__(
        self,
        with_grad=False,  # whether to use arrays with gradients (contiguous and strided only)
        src_use_mempool=False,  # whether to enable memory pool on source device
        dst_use_mempool=False,  # whether to enable memory pool on destination device
        access_dst_src=False,  # whether destination device has access to the source mempool
        access_src_dst=False,  # whether source device has access to the destination mempool
        stream_device=None,  # the device for the stream (None for default behaviour)
        use_graph=False,  # whether to use a graph
        value_offset=0,  # unique offset for generated data values per test
    ):
        self.with_grad = with_grad
        self.src_use_mempool = src_use_mempool
        self.dst_use_mempool = dst_use_mempool
        self.access_dst_src = access_dst_src
        self.access_src_dst = access_src_dst
        self.stream_device = stream_device
        self.use_graph = use_graph
        self.value_offset = value_offset


def copy_template(test, src_ctor, dst_ctor, src_device, dst_device, n, params: CopyParams):
    # activate the given memory pool configuration
    with wp.ScopedMempool(src_device, params.src_use_mempool), wp.ScopedMempool(
        dst_device, params.dst_use_mempool
    ), wp.ScopedMempoolAccess(dst_device, src_device, params.access_dst_src), wp.ScopedMempoolAccess(
        src_device, dst_device, params.access_src_dst
    ):
        # make sure the data are different between tests by adding a unique offset
        # this avoids aliasing issues with older memory
        src_data = np.arange(params.value_offset, params.value_offset + n, dtype=np.float32)
        dst_data = np.zeros(n, dtype=np.float32)

        if params.with_grad:
            src_grad_data = -np.arange(params.value_offset, params.value_offset + n, dtype=np.float32)
            dst_grad_data = np.zeros(n, dtype=np.float32)
        else:
            src_grad_data = None
            dst_grad_data = None

        # create Warp arrays for the copy
        src = src_ctor(src_data, device=src_device, grad_data=src_grad_data)
        dst = dst_ctor(dst_data, device=dst_device, grad_data=dst_grad_data)

        # determine the stream argument to pass to wp.copy()
        if params.stream_device is not None:
            stream_arg = wp.Stream(params.stream_device)
        else:
            stream_arg = None

        # determine the actual stream used for the copy
        if stream_arg is not None:
            stream = stream_arg
        else:
            if dst_device.is_cuda:
                stream = dst_device.stream
            elif src_device.is_cuda:
                stream = src_device.stream
            else:
                stream = None

        # check if an exception is expected given the arguments and system configuration
        expected_error_type = None
        expected_error_regex = None

        # restrictions on copying between different devices during graph capture
        if params.use_graph and src_device != dst_device:
            # errors with allocating staging buffer on source device
            if not src.is_contiguous:
                if src_device.is_cuda and not src_device.is_mempool_enabled:
                    # can't allocate staging buffer using default CUDA allocator during capture
                    expected_error_type, expected_error_regex = RuntimeError, r"^Failed to allocate"
                elif src_device.is_cpu:
                    # can't allocate CPU staging buffer during capture
                    expected_error_type, expected_error_regex = RuntimeError, r"^Failed to allocate"

            # errors with allocating staging buffer on destination device
            if expected_error_type is None:
                if not dst.is_contiguous:
                    if dst_device.is_cuda and not dst_device.is_mempool_enabled:
                        # can't allocate staging buffer using default CUDA allocator during capture
                        expected_error_type, expected_error_regex = RuntimeError, r"^Failed to allocate"
                    elif dst_device.is_cpu and src_device.is_cuda:
                        # can't allocate CPU staging buffer during capture
                        expected_error_type, expected_error_regex = RuntimeError, r"^Failed to allocate"

        # p2p copies and mempool access
        if expected_error_type is None and src_device.is_cuda and dst_device.is_cuda:
            # If the source is a contiguous mempool allocation or a non-contiguous array
            # AND the destination is a contiguous mempool allocation or a non-contiguous array,
            # then memory pool access needs to be enabled EITHER from src_device to dst_device
            # OR from dst_device to src_device.
            if (
                ((src.is_contiguous and params.src_use_mempool) or not src.is_contiguous)
                and ((dst.is_contiguous and params.dst_use_mempool) or not dst.is_contiguous)
                and not wp.is_mempool_access_enabled(src_device, dst_device)
                and not wp.is_mempool_access_enabled(dst_device, src_device)
            ):
                expected_error_type, expected_error_regex = RuntimeError, r"^Warp copy error"

        # synchronize before test
        wp.synchronize()

        if expected_error_type is not None:
            # disable error output from Warp if we expect an exception
            try:
                saved_error_output_enabled = wp.context.runtime.core.is_error_output_enabled()
                wp.context.runtime.core.set_error_output_enabled(False)
                with test.assertRaisesRegex(expected_error_type, expected_error_regex):
                    with Capturable(use_graph=params.use_graph, stream=stream):
                        wp.copy(dst, src, stream=stream_arg)
            finally:
                wp.context.runtime.core.set_error_output_enabled(saved_error_output_enabled)
                wp.synchronize()

            # print(f"SUCCESSFUL ERROR PREDICTION: {expected_error_regex}")

        else:
            with Capturable(use_graph=params.use_graph, stream=stream):
                wp.copy(dst, src, stream=stream_arg)

            # synchronize the stream where the copy was running (None for h2h copies)
            if stream is not None:
                wp.synchronize_stream(stream)

            assert_np_equal(dst.numpy(), src.numpy())

            if params.with_grad:
                assert_np_equal(dst.grad.numpy(), src.grad.numpy())

            # print("SUCCESSFUL COPY")


array_constructors = {
    "contiguous": as_contiguous_array,
    "strided": as_strided_array,
    "indexed": as_indexed_array,
    "fabric": as_fabric_array,
    "indexedfabric": as_indexed_fabric_array,
}

array_type_codes = {
    "contiguous": "c",
    "strided": "s",
    "indexed": "i",
    "fabric": "f",
    "indexedfabric": "fi",
}

device_pairs = {}
cpu = None
cuda0 = None
cuda1 = None
cuda2 = None
if wp.is_cpu_available():
    cpu = wp.get_device("cpu")
    device_pairs["h2h"] = (cpu, cpu)
if wp.is_cuda_available():
    cuda0 = wp.get_device("cuda:0")
    device_pairs["d2d"] = (cuda0, cuda0)
    if wp.is_cpu_available():
        device_pairs["h2d"] = (cpu, cuda0)
        device_pairs["d2h"] = (cuda0, cpu)
if wp.get_cuda_device_count() > 1:
    cuda1 = wp.get_device("cuda:1")
    device_pairs["p2p"] = (cuda0, cuda1)
if wp.get_cuda_device_count() > 2:
    cuda2 = wp.get_device("cuda:2")

num_copy_elems = 1000000
num_copy_tests = 0


def add_copy_test(test_name, src_ctor, dst_ctor, src_device, dst_device, n, params):
    def test_func(
        test,
        device,
        src_ctor=src_ctor,
        dst_ctor=dst_ctor,
        src_device=src_device,
        dst_device=dst_device,
        n=n,
        params=params,
    ):
        return copy_template(test, src_ctor, dst_ctor, src_device, dst_device, n, params)

    add_function_test(TestAsync, test_name, test_func, check_output=False)


# Procedurally add tests with argument combinations supported by the system.
for src_type, src_ctor in array_constructors.items():
    for dst_type, dst_ctor in array_constructors.items():
        copy_type = f"{array_type_codes[src_type]}2{array_type_codes[dst_type]}"

        for transfer_type, device_pair in device_pairs.items():
            # skip p2p tests if IOMMU is enabled on Linux
            if transfer_type == "p2p" and not check_iommu():
                continue

            src_device = device_pair[0]
            dst_device = device_pair[1]

            # basic copy arguments
            copy_args = (src_ctor, dst_ctor, src_device, dst_device, num_copy_elems)

            if src_device.is_cuda and src_device.is_mempool_supported:
                src_mempool_flags = [False, True]
            else:
                src_mempool_flags = [False]

            if dst_device.is_cuda and dst_device.is_mempool_supported:
                dst_mempool_flags = [False, True]
            else:
                dst_mempool_flags = [False]

            # stream options
            if src_device.is_cuda:
                if dst_device.is_cuda:
                    if src_device == dst_device:
                        # d2d
                        assert src_device == cuda0 and dst_device == cuda0
                        if cuda1 is not None:
                            stream_devices = [None, cuda0, cuda1]
                        else:
                            stream_devices = [None, cuda0]
                    else:
                        # p2p
                        assert src_device == cuda0 and dst_device == cuda1
                        if cuda2 is not None:
                            stream_devices = [None, cuda0, cuda1, cuda2]
                        else:
                            stream_devices = [None, cuda0, cuda1]
                else:
                    # d2h
                    assert src_device == cuda0
                    if cuda1 is not None:
                        stream_devices = [None, cuda0, cuda1]
                    else:
                        stream_devices = [None, cuda0]
            else:
                if dst_device.is_cuda:
                    # h2d
                    assert dst_device == cuda0
                    if cuda1 is not None:
                        stream_devices = [None, cuda0, cuda1]
                    else:
                        stream_devices = [None, cuda0]
                else:
                    # h2h
                    stream_devices = [None]

            # gradient options (only supported with contiguous and strided arrays)
            if src_type in ("contiguous", "strided") and dst_type in ("contiguous", "strided"):
                grad_flags = [False, True]
            else:
                grad_flags = [False]

            # graph capture options (only supported with CUDA devices)
            if src_device.is_cuda or dst_device.is_cuda:
                graph_flags = [False, True]
            else:
                graph_flags = [False]

            # access from destination device to source mempool
            if wp.is_mempool_access_supported(dst_device, src_device):
                access_dst_src_flags = [False, True]
            else:
                access_dst_src_flags = [False]

            # access from source device to destination mempool
            if wp.is_mempool_access_supported(src_device, dst_device):
                access_src_dst_flags = [False, True]
            else:
                access_src_dst_flags = [False]

            for src_use_mempool in src_mempool_flags:
                for dst_use_mempool in dst_mempool_flags:
                    for stream_device in stream_devices:
                        for access_dst_src in access_dst_src_flags:
                            for access_src_dst in access_src_dst_flags:
                                for with_grad in grad_flags:
                                    for use_graph in graph_flags:
                                        test_name = f"test_copy_{copy_type}_{transfer_type}"

                                        if src_use_mempool:
                                            test_name += "_SrcPoolOn"
                                        else:
                                            test_name += "_SrcPoolOff"

                                        if dst_use_mempool:
                                            test_name += "_DstPoolOn"
                                        else:
                                            test_name += "_DstPoolOff"

                                        if stream_device is None:
                                            test_name += "_NoStream"
                                        elif stream_device == cuda0:
                                            test_name += "_Stream0"
                                        elif stream_device == cuda1:
                                            test_name += "_Stream1"
                                        elif stream_device == cuda2:
                                            test_name += "_Stream2"
                                        else:
                                            raise AssertionError

                                        if with_grad:
                                            test_name += "_Grad"
                                        else:
                                            test_name += "_NoGrad"

                                        if use_graph:
                                            test_name += "_Graph"
                                        else:
                                            test_name += "_NoGraph"

                                        if access_dst_src and access_src_dst:
                                            test_name += "_AccessBoth"
                                        elif access_dst_src and not access_src_dst:
                                            test_name += "_AccessDstSrc"
                                        elif not access_dst_src and access_src_dst:
                                            test_name += "_AccessSrcDst"
                                        else:
                                            test_name += "_AccessNone"

                                        copy_params = CopyParams(
                                            src_use_mempool=src_use_mempool,
                                            dst_use_mempool=dst_use_mempool,
                                            access_dst_src=access_dst_src,
                                            access_src_dst=access_src_dst,
                                            stream_device=stream_device,
                                            with_grad=with_grad,
                                            use_graph=use_graph,
                                            value_offset=num_copy_tests,
                                        )

                                        add_copy_test(test_name, *copy_args, copy_params)

                                        num_copy_tests += 1

# Specify individual test(s) for debugging purposes
# add_copy_test("test_a", as_contiguous_array, as_strided_array, cuda0, cuda1, num_copy_elems,
#     CopyParams(
#         src_use_mempool=True,
#         dst_use_mempool=True,
#         access_dst_src=False,
#         access_src_dst=False,
#         stream_device=cuda0,
#         with_grad=False,
#         use_graph=True,
#         value_offset=0))

if __name__ == "__main__":
    wp.build.clear_kernel_cache()
    unittest.main(verbosity=2)