warp-lang 1.0.2__py3-none-win_amd64.whl → 1.1.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
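As a rough illustration of how a file listing like the one below can be reproduced locally, here is a minimal Python sketch that compares two wheel archives member by member. The wheel file names and the ndiff-based line counting are assumptions for illustration only; they are not necessarily how the registry computes its numbers, and both wheels would have to be downloaded first.

import difflib
import zipfile

# Hypothetical local file names: both wheels would need to be downloaded
# separately (for example from PyPI); they are not part of this page.
OLD_WHEEL = "warp_lang-1.0.2-py3-none-win_amd64.whl"
NEW_WHEEL = "warp_lang-1.1.0-py3-none-win_amd64.whl"


def wheel_members(path):
    """Map each member of a wheel (a zip archive) to its text lines,
    or to None if the member is not valid UTF-8 (DLLs, USD assets, ...)."""
    members = {}
    with zipfile.ZipFile(path) as zf:
        for name in zf.namelist():
            data = zf.read(name)
            try:
                members[name] = data.decode("utf-8").splitlines()
            except UnicodeDecodeError:
                members[name] = None
    return members


old_members = wheel_members(OLD_WHEEL)
new_members = wheel_members(NEW_WHEEL)

for name in sorted(set(old_members) | set(new_members)):
    a = old_members.get(name)
    b = new_members.get(name)
    if a is None or b is None or a == b:
        # skip binary members, members present in only one wheel, and unchanged text files
        continue
    diff = list(difflib.ndiff(a, b))
    added = sum(line.startswith("+ ") for line in diff)
    removed = sum(line.startswith("- ") for line in diff)
    print(f"{name} +{added} -{removed}")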
- warp/__init__.py +108 -97
- warp/__init__.pyi +1 -1
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +115 -113
- warp/build_dll.py +383 -375
- warp/builtins.py +3425 -3354
- warp/codegen.py +2878 -2792
- warp/config.py +40 -36
- warp/constants.py +45 -45
- warp/context.py +5194 -5102
- warp/dlpack.py +442 -442
- warp/examples/__init__.py +16 -16
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -92
- warp/examples/assets/nv_humanoid.xml +183 -183
- warp/examples/assets/quadruped.urdf +267 -267
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +383 -383
- warp/examples/benchmarks/benchmark_cloth.py +278 -277
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
- warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
- warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
- warp/examples/benchmarks/benchmark_cloth_warp.py +146 -146
- warp/examples/benchmarks/benchmark_launches.py +295 -295
- warp/examples/browse.py +29 -29
- warp/examples/core/example_dem.py +234 -219
- warp/examples/core/example_fluid.py +293 -267
- warp/examples/core/example_graph_capture.py +144 -126
- warp/examples/core/example_marching_cubes.py +188 -174
- warp/examples/core/example_mesh.py +174 -155
- warp/examples/core/example_mesh_intersect.py +205 -193
- warp/examples/core/example_nvdb.py +176 -170
- warp/examples/core/example_raycast.py +105 -90
- warp/examples/core/example_raymarch.py +199 -178
- warp/examples/core/example_render_opengl.py +185 -141
- warp/examples/core/example_sph.py +405 -387
- warp/examples/core/example_torch.py +222 -181
- warp/examples/core/example_wave.py +263 -248
- warp/examples/fem/bsr_utils.py +378 -380
- warp/examples/fem/example_apic_fluid.py +407 -389
- warp/examples/fem/example_convection_diffusion.py +182 -168
- warp/examples/fem/example_convection_diffusion_dg.py +219 -209
- warp/examples/fem/example_convection_diffusion_dg0.py +204 -194
- warp/examples/fem/example_deformed_geometry.py +177 -159
- warp/examples/fem/example_diffusion.py +201 -173
- warp/examples/fem/example_diffusion_3d.py +177 -152
- warp/examples/fem/example_diffusion_mgpu.py +221 -214
- warp/examples/fem/example_mixed_elasticity.py +244 -222
- warp/examples/fem/example_navier_stokes.py +259 -243
- warp/examples/fem/example_stokes.py +220 -192
- warp/examples/fem/example_stokes_transfer.py +265 -249
- warp/examples/fem/mesh_utils.py +133 -109
- warp/examples/fem/plot_utils.py +292 -287
- warp/examples/optim/example_bounce.py +260 -246
- warp/examples/optim/example_cloth_throw.py +222 -209
- warp/examples/optim/example_diffray.py +566 -536
- warp/examples/optim/example_drone.py +864 -835
- warp/examples/optim/example_inverse_kinematics.py +176 -168
- warp/examples/optim/example_inverse_kinematics_torch.py +185 -169
- warp/examples/optim/example_spring_cage.py +239 -231
- warp/examples/optim/example_trajectory.py +223 -199
- warp/examples/optim/example_walker.py +306 -293
- warp/examples/sim/example_cartpole.py +139 -129
- warp/examples/sim/example_cloth.py +196 -186
- warp/examples/sim/example_granular.py +124 -111
- warp/examples/sim/example_granular_collision_sdf.py +197 -186
- warp/examples/sim/example_jacobian_ik.py +236 -214
- warp/examples/sim/example_particle_chain.py +118 -105
- warp/examples/sim/example_quadruped.py +193 -180
- warp/examples/sim/example_rigid_chain.py +197 -187
- warp/examples/sim/example_rigid_contact.py +189 -177
- warp/examples/sim/example_rigid_force.py +127 -125
- warp/examples/sim/example_rigid_gyroscopic.py +109 -95
- warp/examples/sim/example_rigid_soft_contact.py +134 -122
- warp/examples/sim/example_soft_body.py +190 -177
- warp/fabric.py +337 -335
- warp/fem/__init__.py +60 -27
- warp/fem/cache.py +401 -388
- warp/fem/dirichlet.py +178 -179
- warp/fem/domain.py +262 -263
- warp/fem/field/__init__.py +100 -101
- warp/fem/field/field.py +148 -149
- warp/fem/field/nodal_field.py +298 -299
- warp/fem/field/restriction.py +22 -21
- warp/fem/field/test.py +180 -181
- warp/fem/field/trial.py +183 -183
- warp/fem/geometry/__init__.py +15 -19
- warp/fem/geometry/closest_point.py +69 -70
- warp/fem/geometry/deformed_geometry.py +270 -271
- warp/fem/geometry/element.py +744 -744
- warp/fem/geometry/geometry.py +184 -186
- warp/fem/geometry/grid_2d.py +380 -373
- warp/fem/geometry/grid_3d.py +441 -435
- warp/fem/geometry/hexmesh.py +953 -953
- warp/fem/geometry/partition.py +374 -376
- warp/fem/geometry/quadmesh_2d.py +532 -532
- warp/fem/geometry/tetmesh.py +840 -840
- warp/fem/geometry/trimesh_2d.py +577 -577
- warp/fem/integrate.py +1630 -1615
- warp/fem/operator.py +190 -191
- warp/fem/polynomial.py +214 -213
- warp/fem/quadrature/__init__.py +2 -2
- warp/fem/quadrature/pic_quadrature.py +243 -245
- warp/fem/quadrature/quadrature.py +295 -294
- warp/fem/space/__init__.py +294 -292
- warp/fem/space/basis_space.py +488 -489
- warp/fem/space/collocated_function_space.py +100 -105
- warp/fem/space/dof_mapper.py +236 -236
- warp/fem/space/function_space.py +148 -145
- warp/fem/space/grid_2d_function_space.py +267 -267
- warp/fem/space/grid_3d_function_space.py +305 -306
- warp/fem/space/hexmesh_function_space.py +350 -352
- warp/fem/space/partition.py +350 -350
- warp/fem/space/quadmesh_2d_function_space.py +368 -369
- warp/fem/space/restriction.py +158 -160
- warp/fem/space/shape/__init__.py +13 -15
- warp/fem/space/shape/cube_shape_function.py +738 -738
- warp/fem/space/shape/shape_function.py +102 -103
- warp/fem/space/shape/square_shape_function.py +611 -611
- warp/fem/space/shape/tet_shape_function.py +565 -567
- warp/fem/space/shape/triangle_shape_function.py +429 -429
- warp/fem/space/tetmesh_function_space.py +294 -292
- warp/fem/space/topology.py +297 -295
- warp/fem/space/trimesh_2d_function_space.py +223 -221
- warp/fem/types.py +77 -77
- warp/fem/utils.py +495 -495
- warp/jax.py +166 -141
- warp/jax_experimental.py +341 -339
- warp/native/array.h +1072 -1025
- warp/native/builtin.h +1560 -1560
- warp/native/bvh.cpp +398 -398
- warp/native/bvh.cu +525 -525
- warp/native/bvh.h +429 -429
- warp/native/clang/clang.cpp +495 -464
- warp/native/crt.cpp +31 -31
- warp/native/crt.h +334 -334
- warp/native/cuda_crt.h +1049 -1049
- warp/native/cuda_util.cpp +549 -540
- warp/native/cuda_util.h +288 -203
- warp/native/cutlass_gemm.cpp +34 -34
- warp/native/cutlass_gemm.cu +372 -372
- warp/native/error.cpp +66 -66
- warp/native/error.h +27 -27
- warp/native/fabric.h +228 -228
- warp/native/hashgrid.cpp +301 -278
- warp/native/hashgrid.cu +78 -77
- warp/native/hashgrid.h +227 -227
- warp/native/initializer_array.h +32 -32
- warp/native/intersect.h +1204 -1204
- warp/native/intersect_adj.h +365 -365
- warp/native/intersect_tri.h +322 -322
- warp/native/marching.cpp +2 -2
- warp/native/marching.cu +497 -497
- warp/native/marching.h +2 -2
- warp/native/mat.h +1498 -1498
- warp/native/matnn.h +333 -333
- warp/native/mesh.cpp +203 -203
- warp/native/mesh.cu +293 -293
- warp/native/mesh.h +1887 -1887
- warp/native/nanovdb/NanoVDB.h +4782 -4782
- warp/native/nanovdb/PNanoVDB.h +2553 -2553
- warp/native/nanovdb/PNanoVDBWrite.h +294 -294
- warp/native/noise.h +850 -850
- warp/native/quat.h +1084 -1084
- warp/native/rand.h +299 -299
- warp/native/range.h +108 -108
- warp/native/reduce.cpp +156 -156
- warp/native/reduce.cu +348 -348
- warp/native/runlength_encode.cpp +61 -61
- warp/native/runlength_encode.cu +46 -46
- warp/native/scan.cpp +30 -30
- warp/native/scan.cu +36 -36
- warp/native/scan.h +7 -7
- warp/native/solid_angle.h +442 -442
- warp/native/sort.cpp +94 -94
- warp/native/sort.cu +97 -97
- warp/native/sort.h +14 -14
- warp/native/sparse.cpp +337 -337
- warp/native/sparse.cu +544 -544
- warp/native/spatial.h +630 -630
- warp/native/svd.h +562 -562
- warp/native/temp_buffer.h +30 -30
- warp/native/vec.h +1132 -1132
- warp/native/volume.cpp +297 -297
- warp/native/volume.cu +32 -32
- warp/native/volume.h +538 -538
- warp/native/volume_builder.cu +425 -425
- warp/native/volume_builder.h +19 -19
- warp/native/warp.cpp +1057 -1052
- warp/native/warp.cu +2943 -2828
- warp/native/warp.h +313 -305
- warp/optim/__init__.py +9 -9
- warp/optim/adam.py +120 -120
- warp/optim/linear.py +1104 -939
- warp/optim/sgd.py +104 -92
- warp/render/__init__.py +10 -10
- warp/render/render_opengl.py +3217 -3204
- warp/render/render_usd.py +768 -749
- warp/render/utils.py +152 -150
- warp/sim/__init__.py +52 -59
- warp/sim/articulation.py +685 -685
- warp/sim/collide.py +1594 -1590
- warp/sim/import_mjcf.py +489 -481
- warp/sim/import_snu.py +220 -221
- warp/sim/import_urdf.py +536 -516
- warp/sim/import_usd.py +887 -881
- warp/sim/inertia.py +316 -317
- warp/sim/integrator.py +234 -233
- warp/sim/integrator_euler.py +1956 -1956
- warp/sim/integrator_featherstone.py +1910 -1991
- warp/sim/integrator_xpbd.py +3294 -3312
- warp/sim/model.py +4473 -4314
- warp/sim/particles.py +113 -112
- warp/sim/render.py +417 -403
- warp/sim/utils.py +413 -410
- warp/sparse.py +1227 -1227
- warp/stubs.py +2109 -2469
- warp/tape.py +1162 -225
- warp/tests/__init__.py +1 -1
- warp/tests/__main__.py +4 -4
- warp/tests/assets/torus.usda +105 -105
- warp/tests/aux_test_class_kernel.py +26 -26
- warp/tests/aux_test_compile_consts_dummy.py +10 -10
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
- warp/tests/aux_test_dependent.py +22 -22
- warp/tests/aux_test_grad_customs.py +23 -23
- warp/tests/aux_test_reference.py +11 -11
- warp/tests/aux_test_reference_reference.py +10 -10
- warp/tests/aux_test_square.py +17 -17
- warp/tests/aux_test_unresolved_func.py +14 -14
- warp/tests/aux_test_unresolved_symbol.py +14 -14
- warp/tests/disabled_kinematics.py +239 -239
- warp/tests/run_coverage_serial.py +31 -31
- warp/tests/test_adam.py +157 -157
- warp/tests/test_arithmetic.py +1124 -1124
- warp/tests/test_array.py +2417 -2326
- warp/tests/test_array_reduce.py +150 -150
- warp/tests/test_async.py +668 -656
- warp/tests/test_atomic.py +141 -141
- warp/tests/test_bool.py +204 -149
- warp/tests/test_builtins_resolution.py +1292 -1292
- warp/tests/test_bvh.py +164 -171
- warp/tests/test_closest_point_edge_edge.py +228 -228
- warp/tests/test_codegen.py +566 -553
- warp/tests/test_compile_consts.py +97 -101
- warp/tests/test_conditional.py +246 -246
- warp/tests/test_copy.py +232 -215
- warp/tests/test_ctypes.py +632 -632
- warp/tests/test_dense.py +67 -67
- warp/tests/test_devices.py +91 -98
- warp/tests/test_dlpack.py +530 -529
- warp/tests/test_examples.py +400 -378
- warp/tests/test_fabricarray.py +955 -955
- warp/tests/test_fast_math.py +62 -54
- warp/tests/test_fem.py +1277 -1278
- warp/tests/test_fp16.py +130 -130
- warp/tests/test_func.py +338 -337
- warp/tests/test_generics.py +571 -571
- warp/tests/test_grad.py +746 -640
- warp/tests/test_grad_customs.py +333 -336
- warp/tests/test_hash_grid.py +210 -164
- warp/tests/test_import.py +39 -39
- warp/tests/test_indexedarray.py +1134 -1134
- warp/tests/test_intersect.py +67 -67
- warp/tests/test_jax.py +307 -307
- warp/tests/test_large.py +167 -164
- warp/tests/test_launch.py +354 -354
- warp/tests/test_lerp.py +261 -261
- warp/tests/test_linear_solvers.py +191 -171
- warp/tests/test_lvalue.py +421 -493
- warp/tests/test_marching_cubes.py +65 -65
- warp/tests/test_mat.py +1801 -1827
- warp/tests/test_mat_lite.py +115 -115
- warp/tests/test_mat_scalar_ops.py +2907 -2889
- warp/tests/test_math.py +126 -193
- warp/tests/test_matmul.py +500 -499
- warp/tests/test_matmul_lite.py +410 -410
- warp/tests/test_mempool.py +188 -190
- warp/tests/test_mesh.py +284 -324
- warp/tests/test_mesh_query_aabb.py +228 -241
- warp/tests/test_mesh_query_point.py +692 -702
- warp/tests/test_mesh_query_ray.py +292 -303
- warp/tests/test_mlp.py +276 -276
- warp/tests/test_model.py +110 -110
- warp/tests/test_modules_lite.py +39 -39
- warp/tests/test_multigpu.py +163 -163
- warp/tests/test_noise.py +248 -248
- warp/tests/test_operators.py +250 -250
- warp/tests/test_options.py +123 -125
- warp/tests/test_peer.py +133 -137
- warp/tests/test_pinned.py +78 -78
- warp/tests/test_print.py +54 -54
- warp/tests/test_quat.py +2086 -2086
- warp/tests/test_rand.py +288 -288
- warp/tests/test_reload.py +217 -217
- warp/tests/test_rounding.py +179 -179
- warp/tests/test_runlength_encode.py +190 -190
- warp/tests/test_sim_grad.py +243 -0
- warp/tests/test_sim_kinematics.py +91 -97
- warp/tests/test_smoothstep.py +168 -168
- warp/tests/test_snippet.py +305 -266
- warp/tests/test_sparse.py +468 -460
- warp/tests/test_spatial.py +2148 -2148
- warp/tests/test_streams.py +486 -473
- warp/tests/test_struct.py +710 -675
- warp/tests/test_tape.py +173 -148
- warp/tests/test_torch.py +743 -743
- warp/tests/test_transient_module.py +87 -87
- warp/tests/test_types.py +556 -659
- warp/tests/test_utils.py +490 -499
- warp/tests/test_vec.py +1264 -1268
- warp/tests/test_vec_lite.py +73 -73
- warp/tests/test_vec_scalar_ops.py +2099 -2099
- warp/tests/test_verify_fp.py +94 -94
- warp/tests/test_volume.py +737 -736
- warp/tests/test_volume_write.py +255 -265
- warp/tests/unittest_serial.py +37 -37
- warp/tests/unittest_suites.py +363 -359
- warp/tests/unittest_utils.py +603 -578
- warp/tests/unused_test_misc.py +71 -71
- warp/tests/walkthrough_debug.py +85 -85
- warp/thirdparty/appdirs.py +598 -598
- warp/thirdparty/dlpack.py +143 -143
- warp/thirdparty/unittest_parallel.py +566 -561
- warp/torch.py +321 -295
- warp/types.py +4504 -4450
- warp/utils.py +1008 -821
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/LICENSE.md +126 -126
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/METADATA +338 -400
- warp_lang-1.1.0.dist-info/RECORD +352 -0
- warp/examples/assets/cube.usda +0 -42
- warp/examples/assets/sphere.usda +0 -56
- warp/examples/assets/torus.usda +0 -105
- warp_lang-1.0.2.dist-info/RECORD +0 -352
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/top_level.txt +0 -0
warp/tests/test_torch.py
CHANGED
@@ -1,743 +1,743 @@
# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import unittest

import numpy as np

import warp as wp
from warp.tests.unittest_utils import *

wp.init()


@wp.kernel
def op_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    y[tid] = 0.5 - x[tid] * 2.0


@wp.kernel
def inc(a: wp.array(dtype=float)):
    tid = wp.tid()
    a[tid] = a[tid] + 1.0


@wp.kernel
def arange(start: int, step: int, a: wp.array(dtype=int)):
    tid = wp.tid()
    a[tid] = start + step * tid


# copy elements between non-contiguous 1d arrays of float
@wp.kernel
def copy1d_float_kernel(dst: wp.array(dtype=float), src: wp.array(dtype=float)):
    i = wp.tid()
    dst[i] = src[i]


# copy elements between non-contiguous 2d arrays of float
@wp.kernel
def copy2d_float_kernel(dst: wp.array2d(dtype=float), src: wp.array2d(dtype=float)):
    i, j = wp.tid()
    dst[i, j] = src[i, j]


# copy elements between non-contiguous 3d arrays of float
@wp.kernel
def copy3d_float_kernel(dst: wp.array3d(dtype=float), src: wp.array3d(dtype=float)):
    i, j, k = wp.tid()
    dst[i, j, k] = src[i, j, k]


# copy elements between non-contiguous 2d arrays of vec3
@wp.kernel
def copy2d_vec3_kernel(dst: wp.array2d(dtype=wp.vec3), src: wp.array2d(dtype=wp.vec3)):
    i, j = wp.tid()
    dst[i, j] = src[i, j]


# copy elements between non-contiguous 2d arrays of mat22
@wp.kernel
def copy2d_mat22_kernel(dst: wp.array2d(dtype=wp.mat22), src: wp.array2d(dtype=wp.mat22)):
    i, j = wp.tid()
    dst[i, j] = src[i, j]


def test_dtype_from_torch(test, device):
    import torch

    def test_conversions(torch_type, warp_type):
        test.assertEqual(wp.dtype_from_torch(torch_type), warp_type)

    test_conversions(torch.float16, wp.float16)
    test_conversions(torch.float32, wp.float32)
    test_conversions(torch.float64, wp.float64)
    test_conversions(torch.int8, wp.int8)
    test_conversions(torch.int16, wp.int16)
    test_conversions(torch.int32, wp.int32)
    test_conversions(torch.int64, wp.int64)
    test_conversions(torch.uint8, wp.uint8)
    test_conversions(torch.bool, wp.bool)


def test_dtype_to_torch(test, device):
    import torch

    def test_conversions(warp_type, torch_type):
        test.assertEqual(wp.dtype_to_torch(warp_type), torch_type)

    test_conversions(wp.float16, torch.float16)
    test_conversions(wp.float32, torch.float32)
    test_conversions(wp.float64, torch.float64)
    test_conversions(wp.int8, torch.int8)
    test_conversions(wp.int16, torch.int16)
    test_conversions(wp.int32, torch.int32)
    test_conversions(wp.int64, torch.int64)
    test_conversions(wp.uint8, torch.uint8)
    test_conversions(wp.uint16, torch.int16)
    test_conversions(wp.uint32, torch.int32)
    test_conversions(wp.uint64, torch.int64)
    test_conversions(wp.bool, torch.bool)


def test_device_conversion(test, device):
    torch_device = wp.device_to_torch(device)
    warp_device = wp.device_from_torch(torch_device)
    test.assertEqual(warp_device, device)


def test_torch_zerocopy(test, device):
    import torch

    a = wp.zeros(10, dtype=wp.float32, device=device)
    t = wp.to_torch(a)
    assert a.ptr == t.data_ptr()

    torch_device = wp.device_to_torch(device)

    t = torch.zeros(10, dtype=torch.float32, device=torch_device)
    a = wp.from_torch(t)
    assert a.ptr == t.data_ptr()


def test_from_torch(test, device):
    import torch

    torch_device = wp.device_to_torch(device)

    # automatically determine warp dtype
    def wrap_scalar_tensor_implicit(torch_dtype, expected_warp_dtype):
        t = torch.zeros(10, dtype=torch_dtype, device=torch_device)
        a = wp.from_torch(t)
        assert a.dtype == expected_warp_dtype
        assert a.shape == tuple(t.shape)

    wrap_scalar_tensor_implicit(torch.float64, wp.float64)
    wrap_scalar_tensor_implicit(torch.float32, wp.float32)
    wrap_scalar_tensor_implicit(torch.float16, wp.float16)
    wrap_scalar_tensor_implicit(torch.int64, wp.int64)
    wrap_scalar_tensor_implicit(torch.int32, wp.int32)
    wrap_scalar_tensor_implicit(torch.int16, wp.int16)
    wrap_scalar_tensor_implicit(torch.int8, wp.int8)
    wrap_scalar_tensor_implicit(torch.uint8, wp.uint8)
    wrap_scalar_tensor_implicit(torch.bool, wp.bool)

    # explicitly specify warp dtype
    def wrap_scalar_tensor_explicit(torch_dtype, expected_warp_dtype):
        t = torch.zeros(10, dtype=torch_dtype, device=torch_device)
        a = wp.from_torch(t, expected_warp_dtype)
        assert a.dtype == expected_warp_dtype
        assert a.shape == tuple(t.shape)

    wrap_scalar_tensor_explicit(torch.float64, wp.float64)
    wrap_scalar_tensor_explicit(torch.float32, wp.float32)
    wrap_scalar_tensor_explicit(torch.float16, wp.float16)
    wrap_scalar_tensor_explicit(torch.int64, wp.int64)
    wrap_scalar_tensor_explicit(torch.int64, wp.uint64)
    wrap_scalar_tensor_explicit(torch.int32, wp.int32)
    wrap_scalar_tensor_explicit(torch.int32, wp.uint32)
    wrap_scalar_tensor_explicit(torch.int16, wp.int16)
    wrap_scalar_tensor_explicit(torch.int16, wp.uint16)
    wrap_scalar_tensor_explicit(torch.int8, wp.int8)
    wrap_scalar_tensor_explicit(torch.int8, wp.uint8)
    wrap_scalar_tensor_explicit(torch.uint8, wp.uint8)
    wrap_scalar_tensor_explicit(torch.uint8, wp.int8)
    wrap_scalar_tensor_explicit(torch.bool, wp.uint8)
    wrap_scalar_tensor_explicit(torch.bool, wp.int8)
    wrap_scalar_tensor_explicit(torch.bool, wp.bool)

    def wrap_vec_tensor(n, desired_warp_dtype):
        t = torch.zeros((10, n), dtype=torch.float32, device=torch_device)
        a = wp.from_torch(t, desired_warp_dtype)
        assert a.dtype == desired_warp_dtype
        assert a.shape == (10,)

    wrap_vec_tensor(2, wp.vec2)
    wrap_vec_tensor(3, wp.vec3)
    wrap_vec_tensor(4, wp.vec4)
    wrap_vec_tensor(6, wp.spatial_vector)
    wrap_vec_tensor(7, wp.transform)

    def wrap_mat_tensor(n, m, desired_warp_dtype):
        t = torch.zeros((10, n, m), dtype=torch.float32, device=torch_device)
        a = wp.from_torch(t, desired_warp_dtype)
        assert a.dtype == desired_warp_dtype
        assert a.shape == (10,)

    wrap_mat_tensor(2, 2, wp.mat22)
    wrap_mat_tensor(3, 3, wp.mat33)
    wrap_mat_tensor(4, 4, wp.mat44)
    wrap_mat_tensor(6, 6, wp.spatial_matrix)

    def wrap_vec_tensor_with_grad(n, desired_warp_dtype):
        t = torch.zeros((10, n), dtype=torch.float32, device=torch_device)
        a = wp.from_torch(t, desired_warp_dtype, requires_grad=True)
        assert a.dtype == desired_warp_dtype
        assert a.shape == (10,)

    wrap_vec_tensor_with_grad(2, wp.vec2)
    wrap_vec_tensor_with_grad(3, wp.vec3)
    wrap_vec_tensor_with_grad(4, wp.vec4)
    wrap_vec_tensor_with_grad(6, wp.spatial_vector)
    wrap_vec_tensor_with_grad(7, wp.transform)

    def wrap_mat_tensor_with_grad(n, m, desired_warp_dtype):
        t = torch.zeros((10, n, m), dtype=torch.float32, device=torch_device)
        a = wp.from_torch(t, desired_warp_dtype, requires_grad=True)
        assert a.dtype == desired_warp_dtype
        assert a.shape == (10,)

    wrap_mat_tensor_with_grad(2, 2, wp.mat22)
    wrap_mat_tensor_with_grad(3, 3, wp.mat33)
    wrap_mat_tensor_with_grad(4, 4, wp.mat44)
    wrap_mat_tensor_with_grad(6, 6, wp.spatial_matrix)


def test_to_torch(test, device):
    import torch

    def wrap_scalar_array(warp_dtype, expected_torch_dtype):
        a = wp.zeros(10, dtype=warp_dtype, device=device)
        t = wp.to_torch(a)
        assert t.dtype == expected_torch_dtype
        assert tuple(t.shape) == a.shape

    wrap_scalar_array(wp.float64, torch.float64)
    wrap_scalar_array(wp.float32, torch.float32)
    wrap_scalar_array(wp.float16, torch.float16)
    wrap_scalar_array(wp.int64, torch.int64)
    wrap_scalar_array(wp.int32, torch.int32)
    wrap_scalar_array(wp.int16, torch.int16)
    wrap_scalar_array(wp.int8, torch.int8)
    wrap_scalar_array(wp.uint8, torch.uint8)
    wrap_scalar_array(wp.bool, torch.bool)

    # not supported by torch
    # wrap_scalar_array(wp.uint64, torch.int64)
    # wrap_scalar_array(wp.uint32, torch.int32)
    # wrap_scalar_array(wp.uint16, torch.int16)

    def wrap_vec_array(n, warp_dtype):
        a = wp.zeros(10, dtype=warp_dtype, device=device)
        t = wp.to_torch(a)
        assert t.dtype == torch.float32
        assert tuple(t.shape) == (10, n)

    wrap_vec_array(2, wp.vec2)
    wrap_vec_array(3, wp.vec3)
    wrap_vec_array(4, wp.vec4)
    wrap_vec_array(6, wp.spatial_vector)
    wrap_vec_array(7, wp.transform)

    def wrap_mat_array(n, m, warp_dtype):
        a = wp.zeros(10, dtype=warp_dtype, device=device)
        t = wp.to_torch(a)
        assert t.dtype == torch.float32
        assert tuple(t.shape) == (10, n, m)

    wrap_mat_array(2, 2, wp.mat22)
    wrap_mat_array(3, 3, wp.mat33)
    wrap_mat_array(4, 4, wp.mat44)
    wrap_mat_array(6, 6, wp.spatial_matrix)


def test_from_torch_slices(test, device):
    import torch

    torch_device = wp.device_to_torch(device)

    # 1D slice, contiguous
    t_base = torch.arange(10, dtype=torch.float32, device=torch_device)
    t = t_base[2:9]
    a = wp.from_torch(t)
    assert a.ptr == t.data_ptr()
    assert a.is_contiguous
    assert a.shape == tuple(t.shape)
    assert_np_equal(a.numpy(), t.cpu().numpy())

    # 1D slice with non-contiguous stride
    t_base = torch.arange(10, dtype=torch.float32, device=torch_device)
    t = t_base[2:9:2]
    a = wp.from_torch(t)
    assert a.ptr == t.data_ptr()
    assert not a.is_contiguous
    assert a.shape == tuple(t.shape)
    # copy contents to contiguous array
    a_contiguous = wp.empty_like(a)
    wp.launch(copy1d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
    assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())

    # 2D slices (non-contiguous)
    t_base = torch.arange(24, dtype=torch.float32, device=torch_device).reshape((4, 6))
    t = t_base[1:3, 2:5]
    a = wp.from_torch(t)
    assert a.ptr == t.data_ptr()
    assert not a.is_contiguous
    assert a.shape == tuple(t.shape)
    # copy contents to contiguous array
    a_contiguous = wp.empty_like(a)
    wp.launch(copy2d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
    assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())

    # 3D slices (non-contiguous)
    t_base = torch.arange(36, dtype=torch.float32, device=torch_device).reshape((4, 3, 3))
    t = t_base[::2, 0:1, 1:2]
    a = wp.from_torch(t)
    assert a.ptr == t.data_ptr()
    assert not a.is_contiguous
    assert a.shape == tuple(t.shape)
    # copy contents to contiguous array
    a_contiguous = wp.empty_like(a)
    wp.launch(copy3d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
    assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())

    # 2D slices of vec3 (inner contiguous, outer non-contiguous)
    t_base = torch.arange(150, dtype=torch.float32, device=torch_device).reshape((10, 5, 3))
    t = t_base[1:7:2, 2:5]
    a = wp.from_torch(t, dtype=wp.vec3)
    assert a.ptr == t.data_ptr()
    assert not a.is_contiguous
    assert a.shape == tuple(t.shape[:-1])
    # copy contents to contiguous array
    a_contiguous = wp.empty_like(a)
    wp.launch(copy2d_vec3_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
    assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())

    # 2D slices of mat22 (inner contiguous, outer non-contiguous)
    t_base = torch.arange(200, dtype=torch.float32, device=torch_device).reshape((10, 5, 2, 2))
    t = t_base[1:7:2, 2:5]
    a = wp.from_torch(t, dtype=wp.mat22)
    assert a.ptr == t.data_ptr()
    assert not a.is_contiguous
    assert a.shape == tuple(t.shape[:-2])
    # copy contents to contiguous array
    a_contiguous = wp.empty_like(a)
    wp.launch(copy2d_mat22_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
    assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())


def test_from_torch_zero_strides(test, device):
    import torch

    torch_device = wp.device_to_torch(device)

    t_base = torch.arange(9, dtype=torch.float32, device=torch_device).reshape((3, 3))

    # expand outermost dimension
    t = t_base.unsqueeze(0).expand(3, -1, -1)
    a = wp.from_torch(t)
    assert a.ptr == t.data_ptr()
    assert not a.is_contiguous
    assert a.shape == tuple(t.shape)
    a_contiguous = wp.empty_like(a)
    wp.launch(copy3d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
    assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())

    # expand middle dimension
    t = t_base.unsqueeze(1).expand(-1, 3, -1)
    a = wp.from_torch(t)
    assert a.ptr == t.data_ptr()
    assert not a.is_contiguous
    assert a.shape == tuple(t.shape)
    a_contiguous = wp.empty_like(a)
    wp.launch(copy3d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
    assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())

    # expand innermost dimension
    t = t_base.unsqueeze(2).expand(-1, -1, 3)
    a = wp.from_torch(t)
    assert a.ptr == t.data_ptr()
    assert not a.is_contiguous
    assert a.shape == tuple(t.shape)
    a_contiguous = wp.empty_like(a)
    wp.launch(copy3d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
    assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())


def test_torch_mgpu_from_torch(test, device):
    import torch

    n = 32

    t0 = torch.arange(0, n, 1, dtype=torch.int32, device="cuda:0")
    t1 = torch.arange(0, n * 2, 2, dtype=torch.int32, device="cuda:1")

    a0 = wp.from_torch(t0, dtype=wp.int32)
    a1 = wp.from_torch(t1, dtype=wp.int32)

    assert a0.device == "cuda:0"
    assert a1.device == "cuda:1"

    expected0 = np.arange(0, n, 1)
    expected1 = np.arange(0, n * 2, 2)

    assert_np_equal(a0.numpy(), expected0)
    assert_np_equal(a1.numpy(), expected1)


def test_torch_mgpu_to_torch(test, device):
    n = 32

    with wp.ScopedDevice("cuda:0"):
        a0 = wp.empty(n, dtype=wp.int32)
        wp.launch(arange, dim=a0.size, inputs=[0, 1, a0])

    with wp.ScopedDevice("cuda:1"):
        a1 = wp.empty(n, dtype=wp.int32)
        wp.launch(arange, dim=a1.size, inputs=[0, 2, a1])

    t0 = wp.to_torch(a0)
    t1 = wp.to_torch(a1)

    assert str(t0.device) == "cuda:0"
    assert str(t1.device) == "cuda:1"

    expected0 = np.arange(0, n, 1, dtype=np.int32)
    expected1 = np.arange(0, n * 2, 2, dtype=np.int32)

    assert_np_equal(t0.cpu().numpy(), expected0)
    assert_np_equal(t1.cpu().numpy(), expected1)


def test_torch_mgpu_interop(test, device):
    import torch

    n = 1024 * 1024

    with torch.cuda.device(0):
        t0 = torch.arange(n, dtype=torch.float32, device="cuda")
        a0 = wp.from_torch(t0)
        wp.launch(inc, dim=a0.size, inputs=[a0], stream=wp.stream_from_torch())

    with torch.cuda.device(1):
        t1 = torch.arange(n, dtype=torch.float32, device="cuda")
        a1 = wp.from_torch(t1)
        wp.launch(inc, dim=a1.size, inputs=[a1], stream=wp.stream_from_torch())

    assert a0.device == "cuda:0"
    assert a1.device == "cuda:1"

    expected = np.arange(n, dtype=int) + 1

    # ensure the torch tensors were modified by warp
    assert_np_equal(t0.cpu().numpy(), expected)
    assert_np_equal(t1.cpu().numpy(), expected)


def test_torch_autograd(test, device):
    """Test torch autograd with a custom Warp op"""

    import torch

    # custom autograd op
    class TestFunc(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # allocate output array
            y = torch.empty_like(x)

            ctx.x = x
            ctx.y = y

            wp.launch(kernel=op_kernel, dim=len(x), inputs=[wp.from_torch(x)], outputs=[wp.from_torch(y)])

            return y

        @staticmethod
        def backward(ctx, adj_y):
            # adjoints should be allocated as zero initialized
            adj_x = torch.zeros_like(ctx.x).contiguous()
            adj_y = adj_y.contiguous()

            wp_x = wp.from_torch(ctx.x, grad=adj_x)
            wp_y = wp.from_torch(ctx.y, grad=adj_y)

            wp.launch(
                kernel=op_kernel,
                dim=len(ctx.x),
                # fwd inputs
                inputs=[wp_x],
                outputs=[wp_y],
                # adj inputs (already stored in input/output arrays, passing null pointers)
                adj_inputs=[None],
                adj_outputs=[None],
                adjoint=True,
            )

            return adj_x

    # run autograd on given device
    with wp.ScopedDevice(device):
        torch_device = wp.device_to_torch(device)

        # input data
        x = torch.ones(16, dtype=torch.float32, device=torch_device, requires_grad=True)

        # execute op
        y = TestFunc.apply(x)

        # compute grads
        l = y.sum()
        l.backward()

        passed = (x.grad == -2.0).all()
        assert passed.item()


def test_torch_graph_torch_stream(test, device):
    """Capture Torch graph on Torch stream"""

    wp.load_module(device=device)

    import torch

    torch_device = wp.device_to_torch(device)

    n = 1024 * 1024
    t = torch.zeros(n, dtype=torch.float32, device=torch_device)
    a = wp.from_torch(t)

    g = torch.cuda.CUDAGraph()

    # create a device-specific torch stream to use for capture
    # (otherwise torch.cuda.graph reuses its capture stream, which can be problematic if it's from a different device)
    torch_stream = torch.cuda.Stream(device=torch_device)

    # make warp use the same stream
    warp_stream = wp.stream_from_torch(torch_stream)

    # capture graph
    with wp.ScopedStream(warp_stream), torch.cuda.graph(g, stream=torch_stream):
        wp.capture_begin(force_module_load=False, external=True)
        try:
            t += 1.0
            wp.launch(inc, dim=n, inputs=[a])
            t += 1.0
            wp.launch(inc, dim=n, inputs=[a])
        finally:
            wp.capture_end()

    # replay graph
    num_iters = 10
    for _ in range(num_iters):
        g.replay()

    passed = (t == num_iters * 4.0).all()
    assert passed.item()


def test_torch_graph_warp_stream(test, device):
    """Capture Torch graph on Warp stream"""

    import torch

    torch_device = wp.device_to_torch(device)

    n = 1024 * 1024
    t = torch.zeros(n, dtype=torch.float32, device=torch_device)
    a = wp.from_torch(t)

    g = torch.cuda.CUDAGraph()

    # make torch use the warp stream from the given device
    torch_stream = wp.stream_to_torch(device)

    # capture graph
    with wp.ScopedDevice(device), torch.cuda.graph(g, stream=torch_stream):
        wp.capture_begin(force_module_load=False, external=True)
        try:
            t += 1.0
            wp.launch(inc, dim=n, inputs=[a])
            t += 1.0
            wp.launch(inc, dim=n, inputs=[a])
        finally:
            wp.capture_end()

    # replay graph
    num_iters = 10
    for _ in range(num_iters):
        g.replay()

    passed = (t == num_iters * 4.0).all()
    assert passed.item()


def test_warp_graph_warp_stream(test, device):
    """Capture Warp graph on Warp stream"""

    import torch

    torch_device = wp.device_to_torch(device)

    n = 1024 * 1024
    t = torch.zeros(n, dtype=torch.float32, device=torch_device)
    a = wp.from_torch(t)

    # make torch use the warp stream from the given device
    torch_stream = wp.stream_to_torch(device)

    # capture graph
    with wp.ScopedDevice(device), torch.cuda.stream(torch_stream):
        wp.capture_begin(force_module_load=False)
        try:
            t += 1.0
            wp.launch(inc, dim=n, inputs=[a])
            t += 1.0
            wp.launch(inc, dim=n, inputs=[a])
        finally:
            g = wp.capture_end()

    # replay graph
    num_iters = 10
    for _ in range(num_iters):
        wp.capture_launch(g)

    passed = (t == num_iters * 4.0).all()
    assert passed.item()


def test_warp_graph_torch_stream(test, device):
    """Capture Warp graph on Torch stream"""

    wp.load_module(device=device)

    import torch

    torch_device = wp.device_to_torch(device)

    n = 1024 * 1024
    t = torch.zeros(n, dtype=torch.float32, device=torch_device)
    a = wp.from_torch(t)

    # create a device-specific torch stream to use for capture
    # (the default torch stream is not suitable for graph capture)
    torch_stream = torch.cuda.Stream(device=torch_device)

    # make warp use the same stream
    warp_stream = wp.stream_from_torch(torch_stream)

    # capture graph
    with wp.ScopedStream(warp_stream), torch.cuda.stream(torch_stream):
        wp.capture_begin(force_module_load=False)
        try:
            t += 1.0
            wp.launch(inc, dim=n, inputs=[a])
            t += 1.0
            wp.launch(inc, dim=n, inputs=[a])
        finally:
            g = wp.capture_end()

    # replay graph
    num_iters = 10
    for _ in range(num_iters):
        wp.capture_launch(g)

    passed = (t == num_iters * 4.0).all()
    assert passed.item()


class TestTorch(unittest.TestCase):
    pass


test_devices = get_test_devices()

try:
    import torch

    # check which Warp devices work with Torch
    # CUDA devices may fail if Torch was not compiled with CUDA support
    torch_compatible_devices = []
    torch_compatible_cuda_devices = []

    for d in test_devices:
        try:
            t = torch.arange(10, device=wp.device_to_torch(d))
            t += 1
            torch_compatible_devices.append(d)
            if d.is_cuda:
                torch_compatible_cuda_devices.append(d)
        except Exception as e:
            print(f"Skipping Torch tests on device '{d}' due to exception: {e}")

    add_function_test(TestTorch, "test_dtype_from_torch", test_dtype_from_torch, devices=None)
    add_function_test(TestTorch, "test_dtype_to_torch", test_dtype_to_torch, devices=None)

    if torch_compatible_devices:
        add_function_test(TestTorch, "test_device_conversion", test_device_conversion, devices=torch_compatible_devices)
        add_function_test(TestTorch, "test_from_torch", test_from_torch, devices=torch_compatible_devices)
        add_function_test(TestTorch, "test_from_torch_slices", test_from_torch_slices, devices=torch_compatible_devices)
        add_function_test(
            TestTorch,
            "test_from_torch_zero_strides",
            test_from_torch_zero_strides,
            devices=torch_compatible_devices,
        )
        add_function_test(TestTorch, "test_to_torch", test_to_torch, devices=torch_compatible_devices)
        add_function_test(TestTorch, "test_torch_zerocopy", test_torch_zerocopy, devices=torch_compatible_devices)
        add_function_test(TestTorch, "test_torch_autograd", test_torch_autograd, devices=torch_compatible_devices)

    if torch_compatible_cuda_devices:
        add_function_test(
            TestTorch,
            "test_torch_graph_torch_stream",
            test_torch_graph_torch_stream,
            devices=torch_compatible_cuda_devices,
        )
        add_function_test(
            TestTorch,
            "test_torch_graph_warp_stream",
            test_torch_graph_warp_stream,
            devices=torch_compatible_cuda_devices,
        )
        add_function_test(
            TestTorch,
            "test_warp_graph_warp_stream",
            test_warp_graph_warp_stream,
            devices=torch_compatible_cuda_devices,
        )
        add_function_test(
            TestTorch,
            "test_warp_graph_torch_stream",
            test_warp_graph_torch_stream,
            devices=torch_compatible_cuda_devices,
        )

    # multi-GPU tests
    if len(torch_compatible_cuda_devices) > 1:
        add_function_test(TestTorch, "test_torch_mgpu_from_torch", test_torch_mgpu_from_torch)
        add_function_test(TestTorch, "test_torch_mgpu_to_torch", test_torch_mgpu_to_torch)
        add_function_test(TestTorch, "test_torch_mgpu_interop", test_torch_mgpu_interop)

except Exception as e:
    print(f"Skipping Torch tests due to exception: {e}")


if __name__ == "__main__":
    wp.build.clear_kernel_cache()
    unittest.main(verbosity=2)
|
1
|
+
# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
3
|
+
# and proprietary rights in and to this software, related documentation
|
|
4
|
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
5
|
+
# distribution of this software and related documentation without an express
|
|
6
|
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
|
+
|
|
8
|
+
import unittest
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
import warp as wp
|
|
13
|
+
from warp.tests.unittest_utils import *
|
|
14
|
+
|
|
15
|
+
wp.init()
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@wp.kernel
|
|
19
|
+
def op_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
|
|
20
|
+
tid = wp.tid()
|
|
21
|
+
y[tid] = 0.5 - x[tid] * 2.0
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@wp.kernel
|
|
25
|
+
def inc(a: wp.array(dtype=float)):
|
|
26
|
+
tid = wp.tid()
|
|
27
|
+
a[tid] = a[tid] + 1.0
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@wp.kernel
|
|
31
|
+
def arange(start: int, step: int, a: wp.array(dtype=int)):
|
|
32
|
+
tid = wp.tid()
|
|
33
|
+
a[tid] = start + step * tid
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# copy elements between non-contiguous 1d arrays of float
|
|
37
|
+
@wp.kernel
|
|
38
|
+
def copy1d_float_kernel(dst: wp.array(dtype=float), src: wp.array(dtype=float)):
|
|
39
|
+
i = wp.tid()
|
|
40
|
+
dst[i] = src[i]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# copy elements between non-contiguous 2d arrays of float
|
|
44
|
+
@wp.kernel
|
|
45
|
+
def copy2d_float_kernel(dst: wp.array2d(dtype=float), src: wp.array2d(dtype=float)):
|
|
46
|
+
i, j = wp.tid()
|
|
47
|
+
dst[i, j] = src[i, j]
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# copy elements between non-contiguous 3d arrays of float
|
|
51
|
+
@wp.kernel
|
|
52
|
+
def copy3d_float_kernel(dst: wp.array3d(dtype=float), src: wp.array3d(dtype=float)):
|
|
53
|
+
i, j, k = wp.tid()
|
|
54
|
+
dst[i, j, k] = src[i, j, k]
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# copy elements between non-contiguous 2d arrays of vec3
|
|
58
|
+
@wp.kernel
|
|
59
|
+
def copy2d_vec3_kernel(dst: wp.array2d(dtype=wp.vec3), src: wp.array2d(dtype=wp.vec3)):
|
|
60
|
+
i, j = wp.tid()
|
|
61
|
+
dst[i, j] = src[i, j]
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# copy elements between non-contiguous 2d arrays of mat22
|
|
65
|
+
@wp.kernel
|
|
66
|
+
def copy2d_mat22_kernel(dst: wp.array2d(dtype=wp.mat22), src: wp.array2d(dtype=wp.mat22)):
|
|
67
|
+
i, j = wp.tid()
|
|
68
|
+
dst[i, j] = src[i, j]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def test_dtype_from_torch(test, device):
|
|
72
|
+
import torch
|
|
73
|
+
|
|
74
|
+
def test_conversions(torch_type, warp_type):
|
|
75
|
+
test.assertEqual(wp.dtype_from_torch(torch_type), warp_type)
|
|
76
|
+
|
|
77
|
+
test_conversions(torch.float16, wp.float16)
|
|
78
|
+
test_conversions(torch.float32, wp.float32)
|
|
79
|
+
test_conversions(torch.float64, wp.float64)
|
|
80
|
+
test_conversions(torch.int8, wp.int8)
|
|
81
|
+
test_conversions(torch.int16, wp.int16)
|
|
82
|
+
test_conversions(torch.int32, wp.int32)
|
|
83
|
+
test_conversions(torch.int64, wp.int64)
|
|
84
|
+
test_conversions(torch.uint8, wp.uint8)
|
|
85
|
+
test_conversions(torch.bool, wp.bool)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def test_dtype_to_torch(test, device):
|
|
89
|
+
import torch
|
|
90
|
+
|
|
91
|
+
def test_conversions(warp_type, torch_type):
|
|
92
|
+
test.assertEqual(wp.dtype_to_torch(warp_type), torch_type)
|
|
93
|
+
|
|
94
|
+
test_conversions(wp.float16, torch.float16)
|
|
95
|
+
test_conversions(wp.float32, torch.float32)
|
|
96
|
+
test_conversions(wp.float64, torch.float64)
|
|
97
|
+
test_conversions(wp.int8, torch.int8)
|
|
98
|
+
test_conversions(wp.int16, torch.int16)
|
|
99
|
+
test_conversions(wp.int32, torch.int32)
|
|
100
|
+
test_conversions(wp.int64, torch.int64)
|
|
101
|
+
test_conversions(wp.uint8, torch.uint8)
|
|
102
|
+
test_conversions(wp.uint16, torch.int16)
|
|
103
|
+
test_conversions(wp.uint32, torch.int32)
|
|
104
|
+
test_conversions(wp.uint64, torch.int64)
|
|
105
|
+
test_conversions(wp.bool, torch.bool)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def test_device_conversion(test, device):
|
|
109
|
+
torch_device = wp.device_to_torch(device)
|
|
110
|
+
warp_device = wp.device_from_torch(torch_device)
|
|
111
|
+
test.assertEqual(warp_device, device)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def test_torch_zerocopy(test, device):
|
|
115
|
+
import torch
|
|
116
|
+
|
|
117
|
+
a = wp.zeros(10, dtype=wp.float32, device=device)
|
|
118
|
+
t = wp.to_torch(a)
|
|
119
|
+
assert a.ptr == t.data_ptr()
|
|
120
|
+
|
|
121
|
+
torch_device = wp.device_to_torch(device)
|
|
122
|
+
|
|
123
|
+
t = torch.zeros(10, dtype=torch.float32, device=torch_device)
|
|
124
|
+
a = wp.from_torch(t)
|
|
125
|
+
assert a.ptr == t.data_ptr()
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def test_from_torch(test, device):
|
|
129
|
+
import torch
|
|
130
|
+
|
|
131
|
+
torch_device = wp.device_to_torch(device)
|
|
132
|
+
|
|
133
|
+
# automatically determine warp dtype
|
|
134
|
+
def wrap_scalar_tensor_implicit(torch_dtype, expected_warp_dtype):
|
|
135
|
+
t = torch.zeros(10, dtype=torch_dtype, device=torch_device)
|
|
136
|
+
a = wp.from_torch(t)
|
|
137
|
+
assert a.dtype == expected_warp_dtype
|
|
138
|
+
assert a.shape == tuple(t.shape)
|
|
139
|
+
|
|
140
|
+
wrap_scalar_tensor_implicit(torch.float64, wp.float64)
|
|
141
|
+
wrap_scalar_tensor_implicit(torch.float32, wp.float32)
|
|
142
|
+
wrap_scalar_tensor_implicit(torch.float16, wp.float16)
|
|
143
|
+
wrap_scalar_tensor_implicit(torch.int64, wp.int64)
|
|
144
|
+
wrap_scalar_tensor_implicit(torch.int32, wp.int32)
|
|
145
|
+
wrap_scalar_tensor_implicit(torch.int16, wp.int16)
|
|
146
|
+
wrap_scalar_tensor_implicit(torch.int8, wp.int8)
|
|
147
|
+
wrap_scalar_tensor_implicit(torch.uint8, wp.uint8)
|
|
148
|
+
wrap_scalar_tensor_implicit(torch.bool, wp.bool)
|
|
149
|
+
|
|
150
|
+
# explicitly specify warp dtype
|
|
151
|
+
def wrap_scalar_tensor_explicit(torch_dtype, expected_warp_dtype):
|
|
152
|
+
t = torch.zeros(10, dtype=torch_dtype, device=torch_device)
|
|
153
|
+
a = wp.from_torch(t, expected_warp_dtype)
|
|
154
|
+
assert a.dtype == expected_warp_dtype
|
|
155
|
+
assert a.shape == tuple(t.shape)
|
|
156
|
+
|
|
157
|
+
wrap_scalar_tensor_explicit(torch.float64, wp.float64)
|
|
158
|
+
wrap_scalar_tensor_explicit(torch.float32, wp.float32)
|
|
159
|
+
wrap_scalar_tensor_explicit(torch.float16, wp.float16)
|
|
160
|
+
wrap_scalar_tensor_explicit(torch.int64, wp.int64)
|
|
161
|
+
wrap_scalar_tensor_explicit(torch.int64, wp.uint64)
|
|
162
|
+
wrap_scalar_tensor_explicit(torch.int32, wp.int32)
|
|
163
|
+
wrap_scalar_tensor_explicit(torch.int32, wp.uint32)
|
|
164
|
+
wrap_scalar_tensor_explicit(torch.int16, wp.int16)
|
|
165
|
+
wrap_scalar_tensor_explicit(torch.int16, wp.uint16)
|
|
166
|
+
wrap_scalar_tensor_explicit(torch.int8, wp.int8)
|
|
167
|
+
wrap_scalar_tensor_explicit(torch.int8, wp.uint8)
|
|
168
|
+
wrap_scalar_tensor_explicit(torch.uint8, wp.uint8)
|
|
169
|
+
wrap_scalar_tensor_explicit(torch.uint8, wp.int8)
|
|
170
|
+
wrap_scalar_tensor_explicit(torch.bool, wp.uint8)
|
|
171
|
+
wrap_scalar_tensor_explicit(torch.bool, wp.int8)
|
|
172
|
+
wrap_scalar_tensor_explicit(torch.bool, wp.bool)
|
|
173
|
+
|
|
174
|
+
def wrap_vec_tensor(n, desired_warp_dtype):
|
|
175
|
+
t = torch.zeros((10, n), dtype=torch.float32, device=torch_device)
|
|
176
|
+
a = wp.from_torch(t, desired_warp_dtype)
|
|
177
|
+
assert a.dtype == desired_warp_dtype
|
|
178
|
+
assert a.shape == (10,)
|
|
179
|
+
|
|
180
|
+
wrap_vec_tensor(2, wp.vec2)
|
|
181
|
+
wrap_vec_tensor(3, wp.vec3)
|
|
182
|
+
wrap_vec_tensor(4, wp.vec4)
|
|
183
|
+
wrap_vec_tensor(6, wp.spatial_vector)
|
|
184
|
+
wrap_vec_tensor(7, wp.transform)
|
|
185
|
+
|
|
186
|
+
def wrap_mat_tensor(n, m, desired_warp_dtype):
|
|
187
|
+
t = torch.zeros((10, n, m), dtype=torch.float32, device=torch_device)
|
|
188
|
+
a = wp.from_torch(t, desired_warp_dtype)
|
|
189
|
+
assert a.dtype == desired_warp_dtype
|
|
190
|
+
assert a.shape == (10,)
|
|
191
|
+
|
|
192
|
+
wrap_mat_tensor(2, 2, wp.mat22)
|
|
193
|
+
wrap_mat_tensor(3, 3, wp.mat33)
|
|
194
|
+
wrap_mat_tensor(4, 4, wp.mat44)
|
|
195
|
+
wrap_mat_tensor(6, 6, wp.spatial_matrix)
|
|
196
|
+
|
|
197
|
+
def wrap_vec_tensor_with_grad(n, desired_warp_dtype):
|
|
198
|
+
t = torch.zeros((10, n), dtype=torch.float32, device=torch_device)
|
|
199
|
+
a = wp.from_torch(t, desired_warp_dtype, requires_grad=True)
|
|
200
|
+
assert a.dtype == desired_warp_dtype
|
|
201
|
+
assert a.shape == (10,)
|
|
202
|
+
|
|
203
|
+
wrap_vec_tensor_with_grad(2, wp.vec2)
|
|
204
|
+
wrap_vec_tensor_with_grad(3, wp.vec3)
|
|
205
|
+
wrap_vec_tensor_with_grad(4, wp.vec4)
|
|
206
|
+
wrap_vec_tensor_with_grad(6, wp.spatial_vector)
|
|
207
|
+
wrap_vec_tensor_with_grad(7, wp.transform)
|
|
208
|
+
|
|
209
|
+
def wrap_mat_tensor_with_grad(n, m, desired_warp_dtype):
|
|
210
|
+
t = torch.zeros((10, n, m), dtype=torch.float32, device=torch_device)
|
|
211
|
+
a = wp.from_torch(t, desired_warp_dtype, requires_grad=True)
|
|
212
|
+
assert a.dtype == desired_warp_dtype
|
|
213
|
+
assert a.shape == (10,)
|
|
214
|
+
|
|
215
|
+
wrap_mat_tensor_with_grad(2, 2, wp.mat22)
|
|
216
|
+
wrap_mat_tensor_with_grad(3, 3, wp.mat33)
|
|
217
|
+
wrap_mat_tensor_with_grad(4, 4, wp.mat44)
|
|
218
|
+
wrap_mat_tensor_with_grad(6, 6, wp.spatial_matrix)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def test_to_torch(test, device):
|
|
222
|
+
import torch
|
|
223
|
+
|
|
224
|
+
def wrap_scalar_array(warp_dtype, expected_torch_dtype):
|
|
225
|
+
a = wp.zeros(10, dtype=warp_dtype, device=device)
|
|
226
|
+
t = wp.to_torch(a)
|
|
227
|
+
assert t.dtype == expected_torch_dtype
|
|
228
|
+
assert tuple(t.shape) == a.shape
|
|
229
|
+
|
|
230
|
+
wrap_scalar_array(wp.float64, torch.float64)
|
|
231
|
+
wrap_scalar_array(wp.float32, torch.float32)
|
|
232
|
+
wrap_scalar_array(wp.float16, torch.float16)
|
|
233
|
+
wrap_scalar_array(wp.int64, torch.int64)
|
|
234
|
+
wrap_scalar_array(wp.int32, torch.int32)
|
|
235
|
+
wrap_scalar_array(wp.int16, torch.int16)
|
|
236
|
+
wrap_scalar_array(wp.int8, torch.int8)
|
|
237
|
+
wrap_scalar_array(wp.uint8, torch.uint8)
|
|
238
|
+
wrap_scalar_array(wp.bool, torch.bool)
|
|
239
|
+
|
|
240
|
+
# not supported by torch
|
|
241
|
+
# wrap_scalar_array(wp.uint64, torch.int64)
|
|
242
|
+
# wrap_scalar_array(wp.uint32, torch.int32)
|
|
243
|
+
# wrap_scalar_array(wp.uint16, torch.int16)
|
|
244
|
+
|
|
245
|
+
def wrap_vec_array(n, warp_dtype):
|
|
246
|
+
a = wp.zeros(10, dtype=warp_dtype, device=device)
|
|
247
|
+
t = wp.to_torch(a)
|
|
248
|
+
assert t.dtype == torch.float32
|
|
249
|
+
assert tuple(t.shape) == (10, n)
|
|
250
|
+
|
|
251
|
+
wrap_vec_array(2, wp.vec2)
|
|
252
|
+
wrap_vec_array(3, wp.vec3)
|
|
253
|
+
wrap_vec_array(4, wp.vec4)
|
|
254
|
+
wrap_vec_array(6, wp.spatial_vector)
|
|
255
|
+
wrap_vec_array(7, wp.transform)
|
|
256
|
+
|
|
257
|
+
def wrap_mat_array(n, m, warp_dtype):
|
|
258
|
+
a = wp.zeros(10, dtype=warp_dtype, device=device)
|
|
259
|
+
t = wp.to_torch(a)
|
|
260
|
+
assert t.dtype == torch.float32
|
|
261
|
+
assert tuple(t.shape) == (10, n, m)
|
|
262
|
+
|
|
263
|
+
wrap_mat_array(2, 2, wp.mat22)
|
|
264
|
+
wrap_mat_array(3, 3, wp.mat33)
|
|
265
|
+
wrap_mat_array(4, 4, wp.mat44)
|
|
266
|
+
wrap_mat_array(6, 6, wp.spatial_matrix)
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
+def test_from_torch_slices(test, device):
+    import torch
+
+    torch_device = wp.device_to_torch(device)
+
+    # 1D slice, contiguous
+    t_base = torch.arange(10, dtype=torch.float32, device=torch_device)
+    t = t_base[2:9]
+    a = wp.from_torch(t)
+    assert a.ptr == t.data_ptr()
+    assert a.is_contiguous
+    assert a.shape == tuple(t.shape)
+    assert_np_equal(a.numpy(), t.cpu().numpy())
+
+    # 1D slice with non-contiguous stride
+    t_base = torch.arange(10, dtype=torch.float32, device=torch_device)
+    t = t_base[2:9:2]
+    a = wp.from_torch(t)
+    assert a.ptr == t.data_ptr()
+    assert not a.is_contiguous
+    assert a.shape == tuple(t.shape)
+    # copy contents to contiguous array
+    a_contiguous = wp.empty_like(a)
+    wp.launch(copy1d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
+    assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())
+
+    # 2D slices (non-contiguous)
+    t_base = torch.arange(24, dtype=torch.float32, device=torch_device).reshape((4, 6))
+    t = t_base[1:3, 2:5]
+    a = wp.from_torch(t)
+    assert a.ptr == t.data_ptr()
+    assert not a.is_contiguous
+    assert a.shape == tuple(t.shape)
+    # copy contents to contiguous array
+    a_contiguous = wp.empty_like(a)
+    wp.launch(copy2d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
+    assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())
+
+    # 3D slices (non-contiguous)
+    t_base = torch.arange(36, dtype=torch.float32, device=torch_device).reshape((4, 3, 3))
+    t = t_base[::2, 0:1, 1:2]
+    a = wp.from_torch(t)
+    assert a.ptr == t.data_ptr()
+    assert not a.is_contiguous
+    assert a.shape == tuple(t.shape)
+    # copy contents to contiguous array
+    a_contiguous = wp.empty_like(a)
+    wp.launch(copy3d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
+    assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())
+
+    # 2D slices of vec3 (inner contiguous, outer non-contiguous)
+    t_base = torch.arange(150, dtype=torch.float32, device=torch_device).reshape((10, 5, 3))
+    t = t_base[1:7:2, 2:5]
+    a = wp.from_torch(t, dtype=wp.vec3)
+    assert a.ptr == t.data_ptr()
+    assert not a.is_contiguous
+    assert a.shape == tuple(t.shape[:-1])
+    # copy contents to contiguous array
+    a_contiguous = wp.empty_like(a)
+    wp.launch(copy2d_vec3_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
+    assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())
+
+    # 2D slices of mat22 (inner contiguous, outer non-contiguous)
+    t_base = torch.arange(200, dtype=torch.float32, device=torch_device).reshape((10, 5, 2, 2))
+    t = t_base[1:7:2, 2:5]
+    a = wp.from_torch(t, dtype=wp.mat22)
+    assert a.ptr == t.data_ptr()
+    assert not a.is_contiguous
+    assert a.shape == tuple(t.shape[:-2])
+    # copy contents to contiguous array
+    a_contiguous = wp.empty_like(a)
+    wp.launch(copy2d_mat22_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
+    assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())
+
+
+def test_from_torch_zero_strides(test, device):
+    import torch
+
+    torch_device = wp.device_to_torch(device)
+
+    t_base = torch.arange(9, dtype=torch.float32, device=torch_device).reshape((3, 3))
+
+    # expand outermost dimension
+    t = t_base.unsqueeze(0).expand(3, -1, -1)
+    a = wp.from_torch(t)
+    assert a.ptr == t.data_ptr()
+    assert not a.is_contiguous
+    assert a.shape == tuple(t.shape)
+    a_contiguous = wp.empty_like(a)
+    wp.launch(copy3d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
+    assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())
+
+    # expand middle dimension
+    t = t_base.unsqueeze(1).expand(-1, 3, -1)
+    a = wp.from_torch(t)
+    assert a.ptr == t.data_ptr()
+    assert not a.is_contiguous
+    assert a.shape == tuple(t.shape)
+    a_contiguous = wp.empty_like(a)
+    wp.launch(copy3d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
+    assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())
+
+    # expand innermost dimension
+    t = t_base.unsqueeze(2).expand(-1, -1, 3)
+    a = wp.from_torch(t)
+    assert a.ptr == t.data_ptr()
+    assert not a.is_contiguous
+    assert a.shape == tuple(t.shape)
+    a_contiguous = wp.empty_like(a)
+    wp.launch(copy3d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
+    assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())
+
+
+def test_torch_mgpu_from_torch(test, device):
+    import torch
+
+    n = 32
+
+    t0 = torch.arange(0, n, 1, dtype=torch.int32, device="cuda:0")
+    t1 = torch.arange(0, n * 2, 2, dtype=torch.int32, device="cuda:1")
+
+    a0 = wp.from_torch(t0, dtype=wp.int32)
+    a1 = wp.from_torch(t1, dtype=wp.int32)
+
+    assert a0.device == "cuda:0"
+    assert a1.device == "cuda:1"
+
+    expected0 = np.arange(0, n, 1)
+    expected1 = np.arange(0, n * 2, 2)
+
+    assert_np_equal(a0.numpy(), expected0)
+    assert_np_equal(a1.numpy(), expected1)
+
+
+def test_torch_mgpu_to_torch(test, device):
+    n = 32
+
+    with wp.ScopedDevice("cuda:0"):
+        a0 = wp.empty(n, dtype=wp.int32)
+        wp.launch(arange, dim=a0.size, inputs=[0, 1, a0])
+
+    with wp.ScopedDevice("cuda:1"):
+        a1 = wp.empty(n, dtype=wp.int32)
+        wp.launch(arange, dim=a1.size, inputs=[0, 2, a1])
+
+    t0 = wp.to_torch(a0)
+    t1 = wp.to_torch(a1)
+
+    assert str(t0.device) == "cuda:0"
+    assert str(t1.device) == "cuda:1"
+
+    expected0 = np.arange(0, n, 1, dtype=np.int32)
+    expected1 = np.arange(0, n * 2, 2, dtype=np.int32)
+
+    assert_np_equal(t0.cpu().numpy(), expected0)
+    assert_np_equal(t1.cpu().numpy(), expected1)
+
+
+def test_torch_mgpu_interop(test, device):
+    import torch
+
+    n = 1024 * 1024
+
+    with torch.cuda.device(0):
+        t0 = torch.arange(n, dtype=torch.float32, device="cuda")
+        a0 = wp.from_torch(t0)
+        wp.launch(inc, dim=a0.size, inputs=[a0], stream=wp.stream_from_torch())
+
+    with torch.cuda.device(1):
+        t1 = torch.arange(n, dtype=torch.float32, device="cuda")
+        a1 = wp.from_torch(t1)
+        wp.launch(inc, dim=a1.size, inputs=[a1], stream=wp.stream_from_torch())
+
+    assert a0.device == "cuda:0"
+    assert a1.device == "cuda:1"
+
+    expected = np.arange(n, dtype=int) + 1
+
+    # ensure the torch tensors were modified by warp
+    assert_np_equal(t0.cpu().numpy(), expected)
+    assert_np_equal(t1.cpu().numpy(), expected)
+
+
+def test_torch_autograd(test, device):
+    """Test torch autograd with a custom Warp op"""
+
+    import torch
+
+    # custom autograd op
+    class TestFunc(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, x):
+            # allocate output array
+            y = torch.empty_like(x)
+
+            ctx.x = x
+            ctx.y = y
+
+            wp.launch(kernel=op_kernel, dim=len(x), inputs=[wp.from_torch(x)], outputs=[wp.from_torch(y)])
+
+            return y
+
+        @staticmethod
+        def backward(ctx, adj_y):
+            # adjoints should be allocated as zero initialized
+            adj_x = torch.zeros_like(ctx.x).contiguous()
+            adj_y = adj_y.contiguous()
+
+            wp_x = wp.from_torch(ctx.x, grad=adj_x)
+            wp_y = wp.from_torch(ctx.y, grad=adj_y)
+
+            wp.launch(
+                kernel=op_kernel,
+                dim=len(ctx.x),
+                # fwd inputs
+                inputs=[wp_x],
+                outputs=[wp_y],
+                # adj inputs (already stored in input/output arrays, passing null pointers)
+                adj_inputs=[None],
+                adj_outputs=[None],
+                adjoint=True,
+            )
+
+            return adj_x
+
+    # run autograd on given device
+    with wp.ScopedDevice(device):
+        torch_device = wp.device_to_torch(device)
+
+        # input data
+        x = torch.ones(16, dtype=torch.float32, device=torch_device, requires_grad=True)
+
+        # execute op
+        y = TestFunc.apply(x)
+
+        # compute grads
+        l = y.sum()
+        l.backward()
+
+        passed = (x.grad == -2.0).all()
+        assert passed.item()
+
+
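For context on the test above: op_kernel is defined earlier in this module, and the backward pass simply re-launches the same kernel with adjoint=True, so gradients accumulate into the buffers attached through wp.from_torch(..., grad=...). A condensed, self-contained sketch of the same pattern, with a hypothetical scale_kernel standing in for op_kernel and the CPU device chosen so it runs without a GPU:

import torch
import warp as wp

wp.init()

@wp.kernel
def scale_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    y[tid] = 2.0 * x[tid]

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = torch.empty_like(x)
        ctx.x, ctx.y = x, y
        wp.launch(scale_kernel, dim=len(x), inputs=[wp.from_torch(x)], outputs=[wp.from_torch(y)])
        return y

    @staticmethod
    def backward(ctx, adj_y):
        # zero-initialized adjoint buffer shared with Warp via grad=...
        adj_x = torch.zeros_like(ctx.x).contiguous()
        wp_x = wp.from_torch(ctx.x, grad=adj_x)
        wp_y = wp.from_torch(ctx.y, grad=adj_y.contiguous())
        # replay the kernel in adjoint mode; the input gradient lands in adj_x
        wp.launch(scale_kernel, dim=len(ctx.x), inputs=[wp_x], outputs=[wp_y],
                  adj_inputs=[None], adj_outputs=[None], adjoint=True)
        return adj_x

with wp.ScopedDevice("cpu"):
    x = torch.ones(16, requires_grad=True)
    Scale.apply(x).sum().backward()
    print(x.grad)  # 2.0 everywhere, the derivative of y = 2 * x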
+def test_torch_graph_torch_stream(test, device):
+    """Capture Torch graph on Torch stream"""
+
+    wp.load_module(device=device)
+
+    import torch
+
+    torch_device = wp.device_to_torch(device)
+
+    n = 1024 * 1024
+    t = torch.zeros(n, dtype=torch.float32, device=torch_device)
+    a = wp.from_torch(t)
+
+    g = torch.cuda.CUDAGraph()
+
+    # create a device-specific torch stream to use for capture
+    # (otherwise torch.cuda.graph reuses its capture stream, which can be problematic if it's from a different device)
+    torch_stream = torch.cuda.Stream(device=torch_device)
+
+    # make warp use the same stream
+    warp_stream = wp.stream_from_torch(torch_stream)
+
+    # capture graph
+    with wp.ScopedStream(warp_stream), torch.cuda.graph(g, stream=torch_stream):
+        wp.capture_begin(force_module_load=False, external=True)
+        try:
+            t += 1.0
+            wp.launch(inc, dim=n, inputs=[a])
+            t += 1.0
+            wp.launch(inc, dim=n, inputs=[a])
+        finally:
+            wp.capture_end()
+
+    # replay graph
+    num_iters = 10
+    for _i in range(num_iters):
+        g.replay()
+
+    passed = (t == num_iters * 4.0).all()
+    assert passed.item()
+
+
+def test_torch_graph_warp_stream(test, device):
+    """Capture Torch graph on Warp stream"""
+
+    import torch
+
+    torch_device = wp.device_to_torch(device)
+
+    n = 1024 * 1024
+    t = torch.zeros(n, dtype=torch.float32, device=torch_device)
+    a = wp.from_torch(t)
+
+    g = torch.cuda.CUDAGraph()
+
+    # make torch use the warp stream from the given device
+    torch_stream = wp.stream_to_torch(device)
+
+    # capture graph
+    with wp.ScopedDevice(device), torch.cuda.graph(g, stream=torch_stream):
+        wp.capture_begin(force_module_load=False, external=True)
+        try:
+            t += 1.0
+            wp.launch(inc, dim=n, inputs=[a])
+            t += 1.0
+            wp.launch(inc, dim=n, inputs=[a])
+        finally:
+            wp.capture_end()
+
+    # replay graph
+    num_iters = 10
+    for _i in range(num_iters):
+        g.replay()
+
+    passed = (t == num_iters * 4.0).all()
+    assert passed.item()
+
+
+def test_warp_graph_warp_stream(test, device):
+    """Capture Warp graph on Warp stream"""
+
+    import torch
+
+    torch_device = wp.device_to_torch(device)
+
+    n = 1024 * 1024
+    t = torch.zeros(n, dtype=torch.float32, device=torch_device)
+    a = wp.from_torch(t)
+
+    # make torch use the warp stream from the given device
+    torch_stream = wp.stream_to_torch(device)
+
+    # capture graph
+    with wp.ScopedDevice(device), torch.cuda.stream(torch_stream):
+        wp.capture_begin(force_module_load=False)
+        try:
+            t += 1.0
+            wp.launch(inc, dim=n, inputs=[a])
+            t += 1.0
+            wp.launch(inc, dim=n, inputs=[a])
+        finally:
+            g = wp.capture_end()
+
+    # replay graph
+    num_iters = 10
+    for _i in range(num_iters):
+        wp.capture_launch(g)
+
+    passed = (t == num_iters * 4.0).all()
+    assert passed.item()
+
+
+def test_warp_graph_torch_stream(test, device):
+    """Capture Warp graph on Torch stream"""
+
+    wp.load_module(device=device)
+
+    import torch
+
+    torch_device = wp.device_to_torch(device)
+
+    n = 1024 * 1024
+    t = torch.zeros(n, dtype=torch.float32, device=torch_device)
+    a = wp.from_torch(t)
+
+    # create a device-specific torch stream to use for capture
+    # (the default torch stream is not suitable for graph capture)
+    torch_stream = torch.cuda.Stream(device=torch_device)
+
+    # make warp use the same stream
+    warp_stream = wp.stream_from_torch(torch_stream)
+
+    # capture graph
+    with wp.ScopedStream(warp_stream), torch.cuda.stream(torch_stream):
+        wp.capture_begin(force_module_load=False)
+        try:
+            t += 1.0
+            wp.launch(inc, dim=n, inputs=[a])
+            t += 1.0
+            wp.launch(inc, dim=n, inputs=[a])
+        finally:
+            g = wp.capture_end()
+
+    # replay graph
+    num_iters = 10
+    for _i in range(num_iters):
+        wp.capture_launch(g)
+
+    passed = (t == num_iters * 4.0).all()
+    assert passed.item()
+
+
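The four stream tests above exercise the interaction between Torch's capture utilities and Warp's own graph API (wp.capture_begin / wp.capture_end / wp.capture_launch). Stripped of the Torch side, a minimal sketch of that Warp-only capture/replay cycle (illustrative only; inc_kernel and the sizes are placeholder choices, and a CUDA device is required):

import warp as wp

wp.init()

@wp.kernel
def inc_kernel(a: wp.array(dtype=float)):
    i = wp.tid()
    a[i] = a[i] + 1.0

device = "cuda:0"  # CUDA graph capture is not available on the CPU
a = wp.zeros(8, dtype=wp.float32, device=device)

with wp.ScopedDevice(device):
    # compile/load the module up front so nothing needs to be built during capture
    wp.load_module(device=device)

    wp.capture_begin(force_module_load=False)
    try:
        wp.launch(inc_kernel, dim=a.size, inputs=[a])
    finally:
        g = wp.capture_end()

    # replay the captured graph several times
    for _ in range(10):
        wp.capture_launch(g)

print(a.numpy())  # every element has been incremented 10 times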
+class TestTorch(unittest.TestCase):
+    pass
+
+
+test_devices = get_test_devices()
+
+try:
+    import torch
+
+    # check which Warp devices work with Torch
+    # CUDA devices may fail if Torch was not compiled with CUDA support
+    torch_compatible_devices = []
+    torch_compatible_cuda_devices = []
+
+    for d in test_devices:
+        try:
+            t = torch.arange(10, device=wp.device_to_torch(d))
+            t += 1
+            torch_compatible_devices.append(d)
+            if d.is_cuda:
+                torch_compatible_cuda_devices.append(d)
+        except Exception as e:
+            print(f"Skipping Torch tests on device '{d}' due to exception: {e}")
+
+    add_function_test(TestTorch, "test_dtype_from_torch", test_dtype_from_torch, devices=None)
+    add_function_test(TestTorch, "test_dtype_to_torch", test_dtype_to_torch, devices=None)
+
+    if torch_compatible_devices:
+        add_function_test(TestTorch, "test_device_conversion", test_device_conversion, devices=torch_compatible_devices)
+        add_function_test(TestTorch, "test_from_torch", test_from_torch, devices=torch_compatible_devices)
+        add_function_test(TestTorch, "test_from_torch_slices", test_from_torch_slices, devices=torch_compatible_devices)
+        add_function_test(
+            TestTorch,
+            "test_from_torch_zero_strides",
+            test_from_torch_zero_strides,
+            devices=torch_compatible_devices,
+        )
+        add_function_test(TestTorch, "test_to_torch", test_to_torch, devices=torch_compatible_devices)
+        add_function_test(TestTorch, "test_torch_zerocopy", test_torch_zerocopy, devices=torch_compatible_devices)
+        add_function_test(TestTorch, "test_torch_autograd", test_torch_autograd, devices=torch_compatible_devices)
+
+    if torch_compatible_cuda_devices:
+        add_function_test(
+            TestTorch,
+            "test_torch_graph_torch_stream",
+            test_torch_graph_torch_stream,
+            devices=torch_compatible_cuda_devices,
+        )
+        add_function_test(
+            TestTorch,
+            "test_torch_graph_warp_stream",
+            test_torch_graph_warp_stream,
+            devices=torch_compatible_cuda_devices,
+        )
+        add_function_test(
+            TestTorch,
+            "test_warp_graph_warp_stream",
+            test_warp_graph_warp_stream,
+            devices=torch_compatible_cuda_devices,
+        )
+        add_function_test(
+            TestTorch,
+            "test_warp_graph_torch_stream",
+            test_warp_graph_torch_stream,
+            devices=torch_compatible_cuda_devices,
+        )
+
+    # multi-GPU tests
+    if len(torch_compatible_cuda_devices) > 1:
+        add_function_test(TestTorch, "test_torch_mgpu_from_torch", test_torch_mgpu_from_torch)
+        add_function_test(TestTorch, "test_torch_mgpu_to_torch", test_torch_mgpu_to_torch)
+        add_function_test(TestTorch, "test_torch_mgpu_interop", test_torch_mgpu_interop)
+
+except Exception as e:
+    print(f"Skipping Torch tests due to exception: {e}")
+
+
+if __name__ == "__main__":
+    wp.build.clear_kernel_cache()
+    unittest.main(verbosity=2)