warp-lang 1.0.1-py3-none-manylinux2014_aarch64.whl → 1.1.0-py3-none-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +108 -97
- warp/__init__.pyi +1 -1
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +115 -113
- warp/build_dll.py +383 -375
- warp/builtins.py +3425 -3354
- warp/codegen.py +2878 -2792
- warp/config.py +40 -36
- warp/constants.py +45 -45
- warp/context.py +5194 -5102
- warp/dlpack.py +442 -442
- warp/examples/__init__.py +16 -16
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -92
- warp/examples/assets/nv_humanoid.xml +183 -183
- warp/examples/assets/quadruped.urdf +267 -267
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +383 -383
- warp/examples/benchmarks/benchmark_cloth.py +278 -279
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
- warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
- warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
- warp/examples/benchmarks/benchmark_cloth_warp.py +146 -146
- warp/examples/benchmarks/benchmark_launches.py +295 -295
- warp/examples/browse.py +29 -28
- warp/examples/core/example_dem.py +234 -221
- warp/examples/core/example_fluid.py +293 -267
- warp/examples/core/example_graph_capture.py +144 -129
- warp/examples/core/example_marching_cubes.py +188 -176
- warp/examples/core/example_mesh.py +174 -154
- warp/examples/core/example_mesh_intersect.py +205 -193
- warp/examples/core/example_nvdb.py +176 -169
- warp/examples/core/example_raycast.py +105 -89
- warp/examples/core/example_raymarch.py +199 -178
- warp/examples/core/example_render_opengl.py +185 -141
- warp/examples/core/example_sph.py +405 -389
- warp/examples/core/example_torch.py +222 -181
- warp/examples/core/example_wave.py +263 -249
- warp/examples/fem/bsr_utils.py +378 -380
- warp/examples/fem/example_apic_fluid.py +407 -391
- warp/examples/fem/example_convection_diffusion.py +182 -168
- warp/examples/fem/example_convection_diffusion_dg.py +219 -209
- warp/examples/fem/example_convection_diffusion_dg0.py +204 -194
- warp/examples/fem/example_deformed_geometry.py +177 -159
- warp/examples/fem/example_diffusion.py +201 -173
- warp/examples/fem/example_diffusion_3d.py +177 -152
- warp/examples/fem/example_diffusion_mgpu.py +221 -214
- warp/examples/fem/example_mixed_elasticity.py +244 -222
- warp/examples/fem/example_navier_stokes.py +259 -243
- warp/examples/fem/example_stokes.py +220 -192
- warp/examples/fem/example_stokes_transfer.py +265 -249
- warp/examples/fem/mesh_utils.py +133 -109
- warp/examples/fem/plot_utils.py +292 -287
- warp/examples/optim/example_bounce.py +260 -248
- warp/examples/optim/example_cloth_throw.py +222 -210
- warp/examples/optim/example_diffray.py +566 -535
- warp/examples/optim/example_drone.py +864 -835
- warp/examples/optim/example_inverse_kinematics.py +176 -169
- warp/examples/optim/example_inverse_kinematics_torch.py +185 -170
- warp/examples/optim/example_spring_cage.py +239 -234
- warp/examples/optim/example_trajectory.py +223 -201
- warp/examples/optim/example_walker.py +306 -292
- warp/examples/sim/example_cartpole.py +139 -128
- warp/examples/sim/example_cloth.py +196 -184
- warp/examples/sim/example_granular.py +124 -113
- warp/examples/sim/example_granular_collision_sdf.py +197 -185
- warp/examples/sim/example_jacobian_ik.py +236 -213
- warp/examples/sim/example_particle_chain.py +118 -106
- warp/examples/sim/example_quadruped.py +193 -179
- warp/examples/sim/example_rigid_chain.py +197 -189
- warp/examples/sim/example_rigid_contact.py +189 -176
- warp/examples/sim/example_rigid_force.py +127 -126
- warp/examples/sim/example_rigid_gyroscopic.py +109 -97
- warp/examples/sim/example_rigid_soft_contact.py +134 -124
- warp/examples/sim/example_soft_body.py +190 -178
- warp/fabric.py +337 -335
- warp/fem/__init__.py +60 -27
- warp/fem/cache.py +401 -388
- warp/fem/dirichlet.py +178 -179
- warp/fem/domain.py +262 -263
- warp/fem/field/__init__.py +100 -101
- warp/fem/field/field.py +148 -149
- warp/fem/field/nodal_field.py +298 -299
- warp/fem/field/restriction.py +22 -21
- warp/fem/field/test.py +180 -181
- warp/fem/field/trial.py +183 -183
- warp/fem/geometry/__init__.py +15 -19
- warp/fem/geometry/closest_point.py +69 -70
- warp/fem/geometry/deformed_geometry.py +270 -271
- warp/fem/geometry/element.py +744 -744
- warp/fem/geometry/geometry.py +184 -186
- warp/fem/geometry/grid_2d.py +380 -373
- warp/fem/geometry/grid_3d.py +441 -435
- warp/fem/geometry/hexmesh.py +953 -953
- warp/fem/geometry/partition.py +374 -376
- warp/fem/geometry/quadmesh_2d.py +532 -532
- warp/fem/geometry/tetmesh.py +840 -840
- warp/fem/geometry/trimesh_2d.py +577 -577
- warp/fem/integrate.py +1630 -1615
- warp/fem/operator.py +190 -191
- warp/fem/polynomial.py +214 -213
- warp/fem/quadrature/__init__.py +2 -2
- warp/fem/quadrature/pic_quadrature.py +243 -245
- warp/fem/quadrature/quadrature.py +295 -294
- warp/fem/space/__init__.py +294 -292
- warp/fem/space/basis_space.py +488 -489
- warp/fem/space/collocated_function_space.py +100 -105
- warp/fem/space/dof_mapper.py +236 -236
- warp/fem/space/function_space.py +148 -145
- warp/fem/space/grid_2d_function_space.py +267 -267
- warp/fem/space/grid_3d_function_space.py +305 -306
- warp/fem/space/hexmesh_function_space.py +350 -352
- warp/fem/space/partition.py +350 -350
- warp/fem/space/quadmesh_2d_function_space.py +368 -369
- warp/fem/space/restriction.py +158 -160
- warp/fem/space/shape/__init__.py +13 -15
- warp/fem/space/shape/cube_shape_function.py +738 -738
- warp/fem/space/shape/shape_function.py +102 -103
- warp/fem/space/shape/square_shape_function.py +611 -611
- warp/fem/space/shape/tet_shape_function.py +565 -567
- warp/fem/space/shape/triangle_shape_function.py +429 -429
- warp/fem/space/tetmesh_function_space.py +294 -292
- warp/fem/space/topology.py +297 -295
- warp/fem/space/trimesh_2d_function_space.py +223 -221
- warp/fem/types.py +77 -77
- warp/fem/utils.py +495 -495
- warp/jax.py +166 -141
- warp/jax_experimental.py +341 -339
- warp/native/array.h +1072 -1025
- warp/native/builtin.h +1560 -1560
- warp/native/bvh.cpp +398 -398
- warp/native/bvh.cu +525 -525
- warp/native/bvh.h +429 -429
- warp/native/clang/clang.cpp +495 -464
- warp/native/crt.cpp +31 -31
- warp/native/crt.h +334 -334
- warp/native/cuda_crt.h +1049 -1049
- warp/native/cuda_util.cpp +549 -540
- warp/native/cuda_util.h +288 -203
- warp/native/cutlass_gemm.cpp +34 -34
- warp/native/cutlass_gemm.cu +372 -372
- warp/native/error.cpp +66 -66
- warp/native/error.h +27 -27
- warp/native/fabric.h +228 -228
- warp/native/hashgrid.cpp +301 -278
- warp/native/hashgrid.cu +78 -77
- warp/native/hashgrid.h +227 -227
- warp/native/initializer_array.h +32 -32
- warp/native/intersect.h +1204 -1204
- warp/native/intersect_adj.h +365 -365
- warp/native/intersect_tri.h +322 -322
- warp/native/marching.cpp +2 -2
- warp/native/marching.cu +497 -497
- warp/native/marching.h +2 -2
- warp/native/mat.h +1498 -1498
- warp/native/matnn.h +333 -333
- warp/native/mesh.cpp +203 -203
- warp/native/mesh.cu +293 -293
- warp/native/mesh.h +1887 -1887
- warp/native/nanovdb/NanoVDB.h +4782 -4782
- warp/native/nanovdb/PNanoVDB.h +2553 -2553
- warp/native/nanovdb/PNanoVDBWrite.h +294 -294
- warp/native/noise.h +850 -850
- warp/native/quat.h +1084 -1084
- warp/native/rand.h +299 -299
- warp/native/range.h +108 -108
- warp/native/reduce.cpp +156 -156
- warp/native/reduce.cu +348 -348
- warp/native/runlength_encode.cpp +61 -61
- warp/native/runlength_encode.cu +46 -46
- warp/native/scan.cpp +30 -30
- warp/native/scan.cu +36 -36
- warp/native/scan.h +7 -7
- warp/native/solid_angle.h +442 -442
- warp/native/sort.cpp +94 -94
- warp/native/sort.cu +97 -97
- warp/native/sort.h +14 -14
- warp/native/sparse.cpp +337 -337
- warp/native/sparse.cu +544 -544
- warp/native/spatial.h +630 -630
- warp/native/svd.h +562 -562
- warp/native/temp_buffer.h +30 -30
- warp/native/vec.h +1132 -1132
- warp/native/volume.cpp +297 -297
- warp/native/volume.cu +32 -32
- warp/native/volume.h +538 -538
- warp/native/volume_builder.cu +425 -425
- warp/native/volume_builder.h +19 -19
- warp/native/warp.cpp +1057 -1052
- warp/native/warp.cu +2943 -2828
- warp/native/warp.h +313 -305
- warp/optim/__init__.py +9 -9
- warp/optim/adam.py +120 -120
- warp/optim/linear.py +1104 -939
- warp/optim/sgd.py +104 -92
- warp/render/__init__.py +10 -10
- warp/render/render_opengl.py +3217 -3204
- warp/render/render_usd.py +768 -749
- warp/render/utils.py +152 -150
- warp/sim/__init__.py +52 -59
- warp/sim/articulation.py +685 -685
- warp/sim/collide.py +1594 -1590
- warp/sim/import_mjcf.py +489 -481
- warp/sim/import_snu.py +220 -221
- warp/sim/import_urdf.py +536 -516
- warp/sim/import_usd.py +887 -881
- warp/sim/inertia.py +316 -317
- warp/sim/integrator.py +234 -233
- warp/sim/integrator_euler.py +1956 -1956
- warp/sim/integrator_featherstone.py +1910 -1991
- warp/sim/integrator_xpbd.py +3294 -3312
- warp/sim/model.py +4473 -4314
- warp/sim/particles.py +113 -112
- warp/sim/render.py +417 -403
- warp/sim/utils.py +413 -410
- warp/sparse.py +1227 -1227
- warp/stubs.py +2109 -2469
- warp/tape.py +1162 -225
- warp/tests/__init__.py +1 -1
- warp/tests/__main__.py +4 -4
- warp/tests/assets/torus.usda +105 -105
- warp/tests/aux_test_class_kernel.py +26 -26
- warp/tests/aux_test_compile_consts_dummy.py +10 -10
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
- warp/tests/aux_test_dependent.py +22 -22
- warp/tests/aux_test_grad_customs.py +23 -23
- warp/tests/aux_test_reference.py +11 -11
- warp/tests/aux_test_reference_reference.py +10 -10
- warp/tests/aux_test_square.py +17 -17
- warp/tests/aux_test_unresolved_func.py +14 -14
- warp/tests/aux_test_unresolved_symbol.py +14 -14
- warp/tests/disabled_kinematics.py +239 -239
- warp/tests/run_coverage_serial.py +31 -31
- warp/tests/test_adam.py +157 -157
- warp/tests/test_arithmetic.py +1124 -1124
- warp/tests/test_array.py +2417 -2326
- warp/tests/test_array_reduce.py +150 -150
- warp/tests/test_async.py +668 -656
- warp/tests/test_atomic.py +141 -141
- warp/tests/test_bool.py +204 -149
- warp/tests/test_builtins_resolution.py +1292 -1292
- warp/tests/test_bvh.py +164 -171
- warp/tests/test_closest_point_edge_edge.py +228 -228
- warp/tests/test_codegen.py +566 -553
- warp/tests/test_compile_consts.py +97 -101
- warp/tests/test_conditional.py +246 -246
- warp/tests/test_copy.py +232 -215
- warp/tests/test_ctypes.py +632 -632
- warp/tests/test_dense.py +67 -67
- warp/tests/test_devices.py +91 -98
- warp/tests/test_dlpack.py +530 -529
- warp/tests/test_examples.py +400 -378
- warp/tests/test_fabricarray.py +955 -955
- warp/tests/test_fast_math.py +62 -54
- warp/tests/test_fem.py +1277 -1278
- warp/tests/test_fp16.py +130 -130
- warp/tests/test_func.py +338 -337
- warp/tests/test_generics.py +571 -571
- warp/tests/test_grad.py +746 -640
- warp/tests/test_grad_customs.py +333 -336
- warp/tests/test_hash_grid.py +210 -164
- warp/tests/test_import.py +39 -39
- warp/tests/test_indexedarray.py +1134 -1134
- warp/tests/test_intersect.py +67 -67
- warp/tests/test_jax.py +307 -307
- warp/tests/test_large.py +167 -164
- warp/tests/test_launch.py +354 -354
- warp/tests/test_lerp.py +261 -261
- warp/tests/test_linear_solvers.py +191 -171
- warp/tests/test_lvalue.py +421 -493
- warp/tests/test_marching_cubes.py +65 -65
- warp/tests/test_mat.py +1801 -1827
- warp/tests/test_mat_lite.py +115 -115
- warp/tests/test_mat_scalar_ops.py +2907 -2889
- warp/tests/test_math.py +126 -193
- warp/tests/test_matmul.py +500 -499
- warp/tests/test_matmul_lite.py +410 -410
- warp/tests/test_mempool.py +188 -190
- warp/tests/test_mesh.py +284 -324
- warp/tests/test_mesh_query_aabb.py +228 -241
- warp/tests/test_mesh_query_point.py +692 -702
- warp/tests/test_mesh_query_ray.py +292 -303
- warp/tests/test_mlp.py +276 -276
- warp/tests/test_model.py +110 -110
- warp/tests/test_modules_lite.py +39 -39
- warp/tests/test_multigpu.py +163 -163
- warp/tests/test_noise.py +248 -248
- warp/tests/test_operators.py +250 -250
- warp/tests/test_options.py +123 -125
- warp/tests/test_peer.py +133 -137
- warp/tests/test_pinned.py +78 -78
- warp/tests/test_print.py +54 -54
- warp/tests/test_quat.py +2086 -2086
- warp/tests/test_rand.py +288 -288
- warp/tests/test_reload.py +217 -217
- warp/tests/test_rounding.py +179 -179
- warp/tests/test_runlength_encode.py +190 -190
- warp/tests/test_sim_grad.py +243 -0
- warp/tests/test_sim_kinematics.py +91 -97
- warp/tests/test_smoothstep.py +168 -168
- warp/tests/test_snippet.py +305 -266
- warp/tests/test_sparse.py +468 -460
- warp/tests/test_spatial.py +2148 -2148
- warp/tests/test_streams.py +486 -473
- warp/tests/test_struct.py +710 -675
- warp/tests/test_tape.py +173 -148
- warp/tests/test_torch.py +743 -743
- warp/tests/test_transient_module.py +87 -87
- warp/tests/test_types.py +556 -659
- warp/tests/test_utils.py +490 -499
- warp/tests/test_vec.py +1264 -1268
- warp/tests/test_vec_lite.py +73 -73
- warp/tests/test_vec_scalar_ops.py +2099 -2099
- warp/tests/test_verify_fp.py +94 -94
- warp/tests/test_volume.py +737 -736
- warp/tests/test_volume_write.py +255 -265
- warp/tests/unittest_serial.py +37 -37
- warp/tests/unittest_suites.py +363 -359
- warp/tests/unittest_utils.py +603 -578
- warp/tests/unused_test_misc.py +71 -71
- warp/tests/walkthrough_debug.py +85 -85
- warp/thirdparty/appdirs.py +598 -598
- warp/thirdparty/dlpack.py +143 -143
- warp/thirdparty/unittest_parallel.py +566 -561
- warp/torch.py +321 -295
- warp/types.py +4504 -4450
- warp/utils.py +1008 -821
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/LICENSE.md +126 -126
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/METADATA +338 -400
- warp_lang-1.1.0.dist-info/RECORD +352 -0
- warp/examples/assets/cube.usda +0 -42
- warp/examples/assets/sphere.usda +0 -56
- warp/examples/assets/torus.usda +0 -105
- warp_lang-1.0.1.dist-info/RECORD +0 -352
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/top_level.txt +0 -0
warp/tests/test_grad.py
CHANGED
@@ -1,640 +1,746 @@
# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import unittest
from typing import Any

import numpy as np

import warp as wp
from warp.tests.unittest_utils import *

wp.init()


@wp.kernel
def scalar_grad(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    y[0] = x[0] ** 2.0


def test_scalar_grad(test, device):
    x = wp.array([3.0], dtype=float, device=device, requires_grad=True)
    y = wp.zeros_like(x)

    tape = wp.Tape()
    with tape:
        wp.launch(scalar_grad, dim=1, inputs=[x, y], device=device)

    tape.backward(y)

    assert_np_equal(tape.gradients[x].numpy(), np.array(6.0))


@wp.kernel
def for_loop_grad(n: int, x: wp.array(dtype=float), s: wp.array(dtype=float)):
    sum = float(0.0)

    for i in range(n):
        sum = sum + x[i] * 2.0

    s[0] = sum


def test_for_loop_grad(test, device):
    n = 32
    val = np.ones(n, dtype=np.float32)

    x = wp.array(val, device=device, requires_grad=True)
    sum = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)

    tape = wp.Tape()
    with tape:
        wp.launch(for_loop_grad, dim=1, inputs=[n, x, sum], device=device)

    # ensure forward pass outputs correct
    assert_np_equal(sum.numpy(), 2.0 * np.sum(x.numpy()))

    tape.backward(loss=sum)

    # ensure forward pass outputs persist
    assert_np_equal(sum.numpy(), 2.0 * np.sum(x.numpy()))
    # ensure gradients correct
    assert_np_equal(tape.gradients[x].numpy(), 2.0 * val)


def test_for_loop_graph_grad(test, device):
    wp.load_module(device=device)

    n = 32
    val = np.ones(n, dtype=np.float32)

    x = wp.array(val, device=device, requires_grad=True)
    sum = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)

    wp.capture_begin(device, force_module_load=False)
    try:
        tape = wp.Tape()
        with tape:
            wp.launch(for_loop_grad, dim=1, inputs=[n, x, sum], device=device)

        tape.backward(loss=sum)
    finally:
        graph = wp.capture_end(device)

    wp.capture_launch(graph)
    wp.synchronize_device(device)

    # ensure forward pass outputs persist
    assert_np_equal(sum.numpy(), 2.0 * np.sum(x.numpy()))
    # ensure gradients correct
    assert_np_equal(x.grad.numpy(), 2.0 * val)

    wp.capture_launch(graph)
    wp.synchronize_device(device)


@wp.kernel
def for_loop_nested_if_grad(n: int, x: wp.array(dtype=float), s: wp.array(dtype=float)):
    sum = float(0.0)

    for i in range(n):
        if i < 16:
            if i < 8:
                sum = sum + x[i] * 2.0
            else:
                sum = sum + x[i] * 4.0
        else:
            if i < 24:
                sum = sum + x[i] * 6.0
            else:
                sum = sum + x[i] * 8.0

    s[0] = sum


def test_for_loop_nested_if_grad(test, device):
    n = 32
    val = np.ones(n, dtype=np.float32)
    # fmt: off
    expected_val = [
        2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
        4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0,
        6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0,
        8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0,
    ]
    expected_grad = [
        2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
        4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0,
        6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0,
        8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0,
    ]
    # fmt: on

    x = wp.array(val, device=device, requires_grad=True)
    sum = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)

    tape = wp.Tape()
    with tape:
        wp.launch(for_loop_nested_if_grad, dim=1, inputs=[n, x, sum], device=device)

    assert_np_equal(sum.numpy(), np.sum(expected_val))

    tape.backward(loss=sum)

    assert_np_equal(sum.numpy(), np.sum(expected_val))
    assert_np_equal(tape.gradients[x].numpy(), np.array(expected_grad))


@wp.kernel
def for_loop_grad_nested(n: int, x: wp.array(dtype=float), s: wp.array(dtype=float)):
    sum = float(0.0)

    for i in range(n):
        for j in range(n):
            sum = sum + x[i * n + j] * float(i * n + j) + 1.0

    s[0] = sum


def test_for_loop_nested_for_grad(test, device):
    x = wp.zeros(9, dtype=float, device=device, requires_grad=True)
    s = wp.zeros(1, dtype=float, device=device, requires_grad=True)

    tape = wp.Tape()
    with tape:
        wp.launch(for_loop_grad_nested, dim=1, inputs=[3, x, s], device=device)

    tape.backward(s)

    assert_np_equal(s.numpy(), np.array([9.0]))
    assert_np_equal(tape.gradients[x].numpy(), np.arange(0.0, 9.0, 1.0))


# differentiating thought most while loops is not supported
# since doing things like i = i + 1 breaks adjointing

# @wp.kernel
# def while_loop_grad(n: int,
#                     x: wp.array(dtype=float),
#                     c: wp.array(dtype=int),
#                     s: wp.array(dtype=float)):

#     tid = wp.tid()

#     while i < n:
#         s[0] = s[0] + x[i]*2.0
#         i = i + 1


# def test_while_loop_grad(test, device):

#     n = 32
#     x = wp.array(np.ones(n, dtype=np.float32), device=device, requires_grad=True)
#     c = wp.zeros(1, dtype=int, device=device)
#     sum = wp.zeros(1, dtype=wp.float32, device=device)

#     tape = wp.Tape()
#     with tape:
#         wp.launch(while_loop_grad, dim=1, inputs=[n, x, c, sum], device=device)

#     tape.backward(loss=sum)

#     assert_np_equal(sum.numpy(), 2.0*np.sum(x.numpy()))
#     assert_np_equal(tape.gradients[x].numpy(), 2.0*np.ones_like(x.numpy()))


@wp.kernel
def preserve_outputs(
    n: int, x: wp.array(dtype=float), c: wp.array(dtype=float), s1: wp.array(dtype=float), s2: wp.array(dtype=float)
):
    tid = wp.tid()

    # plain store
    c[tid] = x[tid] * 2.0

    # atomic stores
    wp.atomic_add(s1, 0, x[tid] * 3.0)
    wp.atomic_sub(s2, 0, x[tid] * 2.0)


# tests that outputs from the forward pass are
# preserved by the backward pass, i.e.: stores
# are omitted during the forward reply
def test_preserve_outputs_grad(test, device):
    n = 32

    val = np.ones(n, dtype=np.float32)

    x = wp.array(val, device=device, requires_grad=True)
    c = wp.zeros_like(x)

    s1 = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
    s2 = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)

    tape = wp.Tape()
    with tape:
        wp.launch(preserve_outputs, dim=n, inputs=[n, x, c, s1, s2], device=device)

    # ensure forward pass results are correct
    assert_np_equal(x.numpy(), val)
    assert_np_equal(c.numpy(), val * 2.0)
    assert_np_equal(s1.numpy(), np.array(3.0 * n))
    assert_np_equal(s2.numpy(), np.array(-2.0 * n))

    # run backward on first loss
    tape.backward(loss=s1)

    # ensure inputs, copy and sum are unchanged by backwards pass
    assert_np_equal(x.numpy(), val)
    assert_np_equal(c.numpy(), val * 2.0)
    assert_np_equal(s1.numpy(), np.array(3.0 * n))
    assert_np_equal(s2.numpy(), np.array(-2.0 * n))

    # ensure gradients are correct
    assert_np_equal(tape.gradients[x].numpy(), 3.0 * val)

    # run backward on second loss
    tape.zero()
    tape.backward(loss=s2)

    assert_np_equal(x.numpy(), val)
    assert_np_equal(c.numpy(), val * 2.0)
    assert_np_equal(s1.numpy(), np.array(3.0 * n))
    assert_np_equal(s2.numpy(), np.array(-2.0 * n))

    # ensure gradients are correct
    assert_np_equal(tape.gradients[x].numpy(), -2.0 * val)


def gradcheck(func, func_name, inputs, device, eps=1e-4, tol=1e-2):
    """
    Checks that the gradient of the Warp kernel is correct by comparing it to the
    numerical gradient computed using finite differences.
    """

    kernel = wp.Kernel(func=func, key=func_name)

    def f(xs):
        # call the kernel without taping for finite differences
        wp_xs = [wp.array(xs[i], ndim=1, dtype=inputs[i].dtype, device=device) for i in range(len(inputs))]
        output = wp.zeros(1, dtype=wp.float32, device=device)
        wp.launch(kernel, dim=1, inputs=wp_xs, outputs=[output], device=device)
        return output.numpy()[0]

    # compute numerical gradient
    numerical_grad = []
    np_xs = []
    for i in range(len(inputs)):
        np_xs.append(inputs[i].numpy().flatten().copy())
        numerical_grad.append(np.zeros_like(np_xs[-1]))
        inputs[i].requires_grad = True

    for i in range(len(np_xs)):
        for j in range(len(np_xs[i])):
            np_xs[i][j] += eps
            y1 = f(np_xs)
            np_xs[i][j] -= 2 * eps
            y2 = f(np_xs)
            np_xs[i][j] += eps
            numerical_grad[i][j] = (y1 - y2) / (2 * eps)

    # compute analytical gradient
    tape = wp.Tape()
    output = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
    with tape:
        wp.launch(kernel, dim=1, inputs=inputs, outputs=[output], device=device)

    tape.backward(loss=output)

    # compare gradients
    for i in range(len(inputs)):
        grad = tape.gradients[inputs[i]]
        assert_np_equal(grad.numpy(), numerical_grad[i], tol=tol)

    tape.zero()


def test_vector_math_grad(test, device):
    rng = np.random.default_rng(123)

    # test unary operations
    for dim, vec_type in [(2, wp.vec2), (3, wp.vec3), (4, wp.vec4), (4, wp.quat)]:

        def check_length(vs: wp.array(dtype=vec_type), out: wp.array(dtype=float)):
            out[0] = wp.length(vs[0])

        def check_length_sq(vs: wp.array(dtype=vec_type), out: wp.array(dtype=float)):
            out[0] = wp.length_sq(vs[0])

        def check_normalize(vs: wp.array(dtype=vec_type), out: wp.array(dtype=float)):
            out[0] = wp.length_sq(wp.normalize(vs[0]))  # compress to scalar output

        # run the tests with 5 different random inputs
        for _ in range(5):
            x = wp.array(rng.random(size=(1, dim), dtype=np.float32), dtype=vec_type, device=device)
            gradcheck(check_length, f"check_length_{vec_type.__name__}", [x], device)
            gradcheck(check_length_sq, f"check_length_sq_{vec_type.__name__}", [x], device)
            gradcheck(check_normalize, f"check_normalize_{vec_type.__name__}", [x], device)


def test_matrix_math_grad(test, device):
    rng = np.random.default_rng(123)

    # test unary operations
    for dim, mat_type in [(2, wp.mat22), (3, wp.mat33), (4, wp.mat44)]:

        def check_determinant(vs: wp.array(dtype=mat_type), out: wp.array(dtype=float)):
            out[0] = wp.determinant(vs[0])

        def check_trace(vs: wp.array(dtype=mat_type), out: wp.array(dtype=float)):
            out[0] = wp.trace(vs[0])

        # run the tests with 5 different random inputs
        for _ in range(5):
            x = wp.array(rng.random(size=(1, dim, dim), dtype=np.float32), ndim=1, dtype=mat_type, device=device)
            gradcheck(check_determinant, f"check_length_{mat_type.__name__}", [x], device)
            gradcheck(check_trace, f"check_length_sq_{mat_type.__name__}", [x], device)


def test_3d_math_grad(test, device):
    rng = np.random.default_rng(123)

    # test binary operations
    def check_cross(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        out[0] = wp.length(wp.cross(vs[0], vs[1]))

    def check_dot(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        out[0] = wp.dot(vs[0], vs[1])

    def check_mat33(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        a = vs[0]
        b = vs[1]
        c = wp.cross(a, b)
        m = wp.mat33(a[0], b[0], c[0], a[1], b[1], c[1], a[2], b[2], c[2])
        out[0] = wp.determinant(m)

    def check_trace_diagonal(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        a = vs[0]
        b = vs[1]
        c = wp.cross(a, b)
        m = wp.mat33(
            1.0 / (a[0] + 10.0),
            0.0,
            0.0,
            0.0,
            1.0 / (b[1] + 10.0),
            0.0,
            0.0,
            0.0,
            1.0 / (c[2] + 10.0),
        )
        out[0] = wp.trace(m)

    def check_rot_rpy(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        v = vs[0]
        q = wp.quat_rpy(v[0], v[1], v[2])
        out[0] = wp.length(wp.quat_rotate(q, vs[1]))

    def check_rot_axis_angle(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        v = wp.normalize(vs[0])
        q = wp.quat_from_axis_angle(v, 0.5)
        out[0] = wp.length(wp.quat_rotate(q, vs[1]))

    def check_rot_quat_inv(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        v = vs[0]
        q = wp.normalize(wp.quat(v[0], v[1], v[2], 1.0))
        out[0] = wp.length(wp.quat_rotate_inv(q, vs[1]))

    # run the tests with 5 different random inputs
    for _ in range(5):
        x = wp.array(
            rng.standard_normal(size=(2, 3), dtype=np.float32), dtype=wp.vec3, device=device, requires_grad=True
        )
        gradcheck(check_cross, "check_cross_3d", [x], device)
        gradcheck(check_dot, "check_dot_3d", [x], device)
        gradcheck(check_mat33, "check_mat33_3d", [x], device, eps=2e-2)
        gradcheck(check_trace_diagonal, "check_trace_diagonal_3d", [x], device)
        gradcheck(check_rot_rpy, "check_rot_rpy_3d", [x], device)
        gradcheck(check_rot_axis_angle, "check_rot_axis_angle_3d", [x], device)
        gradcheck(check_rot_quat_inv, "check_rot_quat_inv_3d", [x], device)


def test_multi_valued_function_grad(test, device):
    rng = np.random.default_rng(123)

    @wp.func
    def multi_valued(x: float, y: float, z: float):
        return wp.sin(x), wp.cos(y) * z, wp.sqrt(wp.abs(z)) / wp.abs(x)

    # test multi-valued functions
    def check_multi_valued(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        tid = wp.tid()
        v = vs[tid]
        a, b, c = multi_valued(v[0], v[1], v[2])
        out[tid] = a + b + c

    # run the tests with 5 different random inputs
    for _ in range(5):
        x = wp.array(
            rng.standard_normal(size=(2, 3), dtype=np.float32), dtype=wp.vec3, device=device, requires_grad=True
        )
        gradcheck(check_multi_valued, "check_multi_valued_3d", [x], device)


def test_mesh_grad(test, device):
    pos = wp.array(
        [
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ],
        dtype=wp.vec3,
        device=device,
        requires_grad=True,
    )
    indices = wp.array(
        [0, 1, 2, 0, 2, 3, 0, 3, 1, 1, 3, 2],
        dtype=wp.int32,
        device=device,
    )

    mesh = wp.Mesh(points=pos, indices=indices)

    @wp.func
    def compute_triangle_area(mesh_id: wp.uint64, tri_id: int):
        mesh = wp.mesh_get(mesh_id)
        i, j, k = mesh.indices[tri_id * 3 + 0], mesh.indices[tri_id * 3 + 1], mesh.indices[tri_id * 3 + 2]
        a = mesh.points[i]
        b = mesh.points[j]
        c = mesh.points[k]
        return wp.length(wp.cross(b - a, c - a)) * 0.5

    @wp.kernel
    def compute_area(mesh_id: wp.uint64, out: wp.array(dtype=wp.float32)):
        wp.atomic_add(out, 0, compute_triangle_area(mesh_id, wp.tid()))

    num_tris = int(len(indices) / 3)

    # compute analytical gradient
    tape = wp.Tape()
    output = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
    with tape:
        wp.launch(compute_area, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)

    tape.backward(loss=output)

    ad_grad = mesh.points.grad.numpy()

    # compute finite differences
    eps = 1e-3
    pos_np = pos.numpy()
    fd_grad = np.zeros_like(ad_grad)

    for i in range(len(pos)):
        for j in range(3):
            pos_np[i, j] += eps
            pos = wp.array(pos_np, dtype=wp.vec3, device=device)
            mesh = wp.Mesh(points=pos, indices=indices)
            output.zero_()
            wp.launch(compute_area, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
            f1 = output.numpy()[0]
            pos_np[i, j] -= 2 * eps
            pos = wp.array(pos_np, dtype=wp.vec3, device=device)
            mesh = wp.Mesh(points=pos, indices=indices)
            output.zero_()
            wp.launch(compute_area, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
            f2 = output.numpy()[0]
            pos_np[i, j] += eps
            fd_grad[i, j] = (f1 - f2) / (2 * eps)

    assert np.allclose(ad_grad, fd_grad, atol=1e-3)


@wp.func
def name_clash(a: float, b: float) -> float:
    return a + b


@wp.func_grad(name_clash)
def adj_name_clash(a: float, b: float, adj_ret: float):
    # names `adj_a` and `adj_b` must not clash with function args of generated function
    adj_a = 0.0
    adj_b = 0.0
    if a < 0.0:
        adj_a = adj_ret
    if b > 0.0:
        adj_b = adj_ret

    wp.adjoint[a] += adj_a
    wp.adjoint[b] += adj_b


@wp.kernel
def name_clash_kernel(
    input_a: wp.array(dtype=float),
    input_b: wp.array(dtype=float),
    output: wp.array(dtype=float),
):
    tid = wp.tid()
    output[tid] = name_clash(input_a[tid], input_b[tid])


def test_name_clash(test, device):
    # tests that no name clashes occur when variable names such as `adj_a` are used in custom gradient code
    with wp.ScopedDevice(device):
        input_a = wp.array([1.0, -2.0, 3.0], dtype=wp.float32, requires_grad=True)
        input_b = wp.array([4.0, 5.0, -6.0], dtype=wp.float32, requires_grad=True)
        output = wp.zeros(3, dtype=wp.float32, requires_grad=True)

        tape = wp.Tape()
        with tape:
            wp.launch(name_clash_kernel, dim=len(input_a), inputs=[input_a, input_b], outputs=[output])

        tape.backward(grads={output: wp.array(np.ones(len(input_a), dtype=np.float32))})

        assert_np_equal(input_a.grad.numpy(), np.array([0.0, 1.0, 0.0]))
        assert_np_equal(input_b.grad.numpy(), np.array([1.0, 1.0, 0.0]))


@wp.struct
class NestedStruct:
    v: wp.vec2


@wp.struct
class ParentStruct:
    a: float
    n: NestedStruct


@wp.func
def noop(a: Any):
    pass


@wp.func
def sum2(v: wp.vec2):
    return v[0] + v[1]


@wp.kernel
def test_struct_attribute_gradient_kernel(src: wp.array(dtype=float), res: wp.array(dtype=float)):
    tid = wp.tid()

    p = ParentStruct(src[tid], NestedStruct(wp.vec2(2.0 * src[tid])))

    # test that we are not losing gradients when accessing attributes
    noop(p.a)
    noop(p.n)
    noop(p.n.v)

    res[tid] = p.a + sum2(p.n.v)


def test_struct_attribute_gradient(
1
|
+
# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
3
|
+
# and proprietary rights in and to this software, related documentation
|
|
4
|
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
5
|
+
# distribution of this software and related documentation without an express
|
|
6
|
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
|
+
|
|
8
|
+
import unittest
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
import warp as wp
|
|
14
|
+
from warp.tests.unittest_utils import *
|
|
15
|
+
|
|
16
|
+
wp.init()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@wp.kernel
|
|
20
|
+
def scalar_grad(x: wp.array(dtype=float), y: wp.array(dtype=float)):
|
|
21
|
+
y[0] = x[0] ** 2.0
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def test_scalar_grad(test, device):
|
|
25
|
+
x = wp.array([3.0], dtype=float, device=device, requires_grad=True)
|
|
26
|
+
y = wp.zeros_like(x)
|
|
27
|
+
|
|
28
|
+
tape = wp.Tape()
|
|
29
|
+
with tape:
|
|
30
|
+
wp.launch(scalar_grad, dim=1, inputs=[x, y], device=device)
|
|
31
|
+
|
|
32
|
+
tape.backward(y)
|
|
33
|
+
|
|
34
|
+
assert_np_equal(tape.gradients[x].numpy(), np.array(6.0))
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@wp.kernel
|
|
38
|
+
def for_loop_grad(n: int, x: wp.array(dtype=float), s: wp.array(dtype=float)):
|
|
39
|
+
sum = float(0.0)
|
|
40
|
+
|
|
41
|
+
for i in range(n):
|
|
42
|
+
sum = sum + x[i] * 2.0
|
|
43
|
+
|
|
44
|
+
s[0] = sum
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def test_for_loop_grad(test, device):
|
|
48
|
+
n = 32
|
|
49
|
+
val = np.ones(n, dtype=np.float32)
|
|
50
|
+
|
|
51
|
+
x = wp.array(val, device=device, requires_grad=True)
|
|
52
|
+
sum = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
|
|
53
|
+
|
|
54
|
+
tape = wp.Tape()
|
|
55
|
+
with tape:
|
|
56
|
+
wp.launch(for_loop_grad, dim=1, inputs=[n, x, sum], device=device)
|
|
57
|
+
|
|
58
|
+
# ensure forward pass outputs correct
|
|
59
|
+
assert_np_equal(sum.numpy(), 2.0 * np.sum(x.numpy()))
|
|
60
|
+
|
|
61
|
+
tape.backward(loss=sum)
|
|
62
|
+
|
|
63
|
+
# ensure forward pass outputs persist
|
|
64
|
+
assert_np_equal(sum.numpy(), 2.0 * np.sum(x.numpy()))
|
|
65
|
+
# ensure gradients correct
|
|
66
|
+
assert_np_equal(tape.gradients[x].numpy(), 2.0 * val)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def test_for_loop_graph_grad(test, device):
|
|
70
|
+
wp.load_module(device=device)
|
|
71
|
+
|
|
72
|
+
n = 32
|
|
73
|
+
val = np.ones(n, dtype=np.float32)
|
|
74
|
+
|
|
75
|
+
x = wp.array(val, device=device, requires_grad=True)
|
|
76
|
+
sum = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
|
|
77
|
+
|
|
78
|
+
wp.capture_begin(device, force_module_load=False)
|
|
79
|
+
try:
|
|
80
|
+
tape = wp.Tape()
|
|
81
|
+
with tape:
|
|
82
|
+
wp.launch(for_loop_grad, dim=1, inputs=[n, x, sum], device=device)
|
|
83
|
+
|
|
84
|
+
tape.backward(loss=sum)
|
|
85
|
+
finally:
|
|
86
|
+
graph = wp.capture_end(device)
|
|
87
|
+
|
|
88
|
+
wp.capture_launch(graph)
|
|
89
|
+
wp.synchronize_device(device)
|
|
90
|
+
|
|
91
|
+
# ensure forward pass outputs persist
|
|
92
|
+
assert_np_equal(sum.numpy(), 2.0 * np.sum(x.numpy()))
|
|
93
|
+
# ensure gradients correct
|
|
94
|
+
assert_np_equal(x.grad.numpy(), 2.0 * val)
|
|
95
|
+
|
|
96
|
+
wp.capture_launch(graph)
|
|
97
|
+
wp.synchronize_device(device)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@wp.kernel
|
|
101
|
+
def for_loop_nested_if_grad(n: int, x: wp.array(dtype=float), s: wp.array(dtype=float)):
|
|
102
|
+
sum = float(0.0)
|
|
103
|
+
|
|
104
|
+
for i in range(n):
|
|
105
|
+
if i < 16:
|
|
106
|
+
if i < 8:
|
|
107
|
+
sum = sum + x[i] * 2.0
|
|
108
|
+
else:
|
|
109
|
+
sum = sum + x[i] * 4.0
|
|
110
|
+
else:
|
|
111
|
+
if i < 24:
|
|
112
|
+
sum = sum + x[i] * 6.0
|
|
113
|
+
else:
|
|
114
|
+
sum = sum + x[i] * 8.0
|
|
115
|
+
|
|
116
|
+
s[0] = sum
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def test_for_loop_nested_if_grad(test, device):
|
|
120
|
+
n = 32
|
|
121
|
+
val = np.ones(n, dtype=np.float32)
|
|
122
|
+
# fmt: off
|
|
123
|
+
expected_val = [
|
|
124
|
+
2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
|
|
125
|
+
4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0,
|
|
126
|
+
6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0,
|
|
127
|
+
8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0,
|
|
128
|
+
]
|
|
129
|
+
expected_grad = [
|
|
130
|
+
2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
|
|
131
|
+
4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0,
|
|
132
|
+
6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0,
|
|
133
|
+
8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0,
|
|
134
|
+
]
|
|
135
|
+
# fmt: on
|
|
136
|
+
|
|
137
|
+
x = wp.array(val, device=device, requires_grad=True)
|
|
138
|
+
sum = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
|
|
139
|
+
|
|
140
|
+
tape = wp.Tape()
|
|
141
|
+
with tape:
|
|
142
|
+
wp.launch(for_loop_nested_if_grad, dim=1, inputs=[n, x, sum], device=device)
|
|
143
|
+
|
|
144
|
+
assert_np_equal(sum.numpy(), np.sum(expected_val))
|
|
145
|
+
|
|
146
|
+
tape.backward(loss=sum)
|
|
147
|
+
|
|
148
|
+
assert_np_equal(sum.numpy(), np.sum(expected_val))
|
|
149
|
+
assert_np_equal(tape.gradients[x].numpy(), np.array(expected_grad))
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
@wp.kernel
|
|
153
|
+
def for_loop_grad_nested(n: int, x: wp.array(dtype=float), s: wp.array(dtype=float)):
|
|
154
|
+
sum = float(0.0)
|
|
155
|
+
|
|
156
|
+
for i in range(n):
|
|
157
|
+
for j in range(n):
|
|
158
|
+
sum = sum + x[i * n + j] * float(i * n + j) + 1.0
|
|
159
|
+
|
|
160
|
+
s[0] = sum
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def test_for_loop_nested_for_grad(test, device):
|
|
164
|
+
x = wp.zeros(9, dtype=float, device=device, requires_grad=True)
|
|
165
|
+
s = wp.zeros(1, dtype=float, device=device, requires_grad=True)
|
|
166
|
+
|
|
167
|
+
tape = wp.Tape()
|
|
168
|
+
with tape:
|
|
169
|
+
wp.launch(for_loop_grad_nested, dim=1, inputs=[3, x, s], device=device)
|
|
170
|
+
|
|
171
|
+
tape.backward(s)
|
|
172
|
+
|
|
173
|
+
assert_np_equal(s.numpy(), np.array([9.0]))
|
|
174
|
+
assert_np_equal(tape.gradients[x].numpy(), np.arange(0.0, 9.0, 1.0))
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
# differentiating thought most while loops is not supported
|
|
178
|
+
# since doing things like i = i + 1 breaks adjointing
|
|
179
|
+
|
|
180
|
+
# @wp.kernel
|
|
181
|
+
# def while_loop_grad(n: int,
|
|
182
|
+
# x: wp.array(dtype=float),
|
|
183
|
+
# c: wp.array(dtype=int),
|
|
184
|
+
# s: wp.array(dtype=float)):
|
|
185
|
+
|
|
186
|
+
# tid = wp.tid()
|
|
187
|
+
|
|
188
|
+
# while i < n:
|
|
189
|
+
# s[0] = s[0] + x[i]*2.0
|
|
190
|
+
# i = i + 1
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
# def test_while_loop_grad(test, device):
|
|
194
|
+
|
|
195
|
+
# n = 32
|
|
196
|
+
# x = wp.array(np.ones(n, dtype=np.float32), device=device, requires_grad=True)
|
|
197
|
+
# c = wp.zeros(1, dtype=int, device=device)
|
|
198
|
+
# sum = wp.zeros(1, dtype=wp.float32, device=device)
|
|
199
|
+
|
|
200
|
+
# tape = wp.Tape()
|
|
201
|
+
# with tape:
|
|
202
|
+
# wp.launch(while_loop_grad, dim=1, inputs=[n, x, c, sum], device=device)
|
|
203
|
+
|
|
204
|
+
# tape.backward(loss=sum)
|
|
205
|
+
|
|
206
|
+
# assert_np_equal(sum.numpy(), 2.0*np.sum(x.numpy()))
|
|
207
|
+
# assert_np_equal(tape.gradients[x].numpy(), 2.0*np.ones_like(x.numpy()))
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
@wp.kernel
|
|
211
|
+
def preserve_outputs(
|
|
212
|
+
n: int, x: wp.array(dtype=float), c: wp.array(dtype=float), s1: wp.array(dtype=float), s2: wp.array(dtype=float)
|
|
213
|
+
):
|
|
214
|
+
tid = wp.tid()
|
|
215
|
+
|
|
216
|
+
# plain store
|
|
217
|
+
c[tid] = x[tid] * 2.0
|
|
218
|
+
|
|
219
|
+
# atomic stores
|
|
220
|
+
wp.atomic_add(s1, 0, x[tid] * 3.0)
|
|
221
|
+
wp.atomic_sub(s2, 0, x[tid] * 2.0)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
# tests that outputs from the forward pass are
|
|
225
|
+
# preserved by the backward pass, i.e.: stores
|
|
226
|
+
# are omitted during the forward reply
|
|
227
|
+
def test_preserve_outputs_grad(test, device):
|
|
228
|
+
n = 32
|
|
229
|
+
|
|
230
|
+
val = np.ones(n, dtype=np.float32)
|
|
231
|
+
|
|
232
|
+
x = wp.array(val, device=device, requires_grad=True)
|
|
233
|
+
c = wp.zeros_like(x)
|
|
234
|
+
|
|
235
|
+
s1 = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
|
|
236
|
+
s2 = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
|
|
237
|
+
|
|
238
|
+
tape = wp.Tape()
|
|
239
|
+
with tape:
|
|
240
|
+
wp.launch(preserve_outputs, dim=n, inputs=[n, x, c, s1, s2], device=device)
|
|
241
|
+
|
|
242
|
+
# ensure forward pass results are correct
|
|
243
|
+
assert_np_equal(x.numpy(), val)
|
|
244
|
+
assert_np_equal(c.numpy(), val * 2.0)
|
|
245
|
+
assert_np_equal(s1.numpy(), np.array(3.0 * n))
|
|
246
|
+
assert_np_equal(s2.numpy(), np.array(-2.0 * n))
|
|
247
|
+
|
|
248
|
+
# run backward on first loss
|
|
249
|
+
tape.backward(loss=s1)
|
|
250
|
+
|
|
251
|
+
# ensure inputs, copy and sum are unchanged by backwards pass
|
|
252
|
+
assert_np_equal(x.numpy(), val)
|
|
253
|
+
assert_np_equal(c.numpy(), val * 2.0)
|
|
254
|
+
assert_np_equal(s1.numpy(), np.array(3.0 * n))
|
|
255
|
+
assert_np_equal(s2.numpy(), np.array(-2.0 * n))
|
|
256
|
+
|
|
257
|
+
# ensure gradients are correct
|
|
258
|
+
assert_np_equal(tape.gradients[x].numpy(), 3.0 * val)
|
|
259
|
+
|
|
260
|
+
# run backward on second loss
|
|
261
|
+
tape.zero()
|
|
262
|
+
tape.backward(loss=s2)
|
|
263
|
+
|
|
264
|
+
assert_np_equal(x.numpy(), val)
|
|
265
|
+
assert_np_equal(c.numpy(), val * 2.0)
|
|
266
|
+
assert_np_equal(s1.numpy(), np.array(3.0 * n))
|
|
267
|
+
assert_np_equal(s2.numpy(), np.array(-2.0 * n))
|
|
268
|
+
|
|
269
|
+
# ensure gradients are correct
|
|
270
|
+
assert_np_equal(tape.gradients[x].numpy(), -2.0 * val)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def gradcheck(func, func_name, inputs, device, eps=1e-4, tol=1e-2):
|
|
274
|
+
"""
|
|
275
|
+
Checks that the gradient of the Warp kernel is correct by comparing it to the
|
|
276
|
+
numerical gradient computed using finite differences.
|
|
277
|
+
"""
|
|
278
|
+
|
|
279
|
+
kernel = wp.Kernel(func=func, key=func_name)
|
|
280
|
+
|
|
281
|
+
def f(xs):
|
|
282
|
+
# call the kernel without taping for finite differences
|
|
283
|
+
wp_xs = [wp.array(xs[i], ndim=1, dtype=inputs[i].dtype, device=device) for i in range(len(inputs))]
|
|
284
|
+
output = wp.zeros(1, dtype=wp.float32, device=device)
|
|
285
|
+
wp.launch(kernel, dim=1, inputs=wp_xs, outputs=[output], device=device)
|
|
286
|
+
return output.numpy()[0]
|
|
287
|
+
|
|
288
|
+
# compute numerical gradient
|
|
289
|
+
numerical_grad = []
|
|
290
|
+
np_xs = []
|
|
291
|
+
for i in range(len(inputs)):
|
|
292
|
+
np_xs.append(inputs[i].numpy().flatten().copy())
|
|
293
|
+
numerical_grad.append(np.zeros_like(np_xs[-1]))
|
|
294
|
+
inputs[i].requires_grad = True
|
|
295
|
+
|
|
296
|
+
for i in range(len(np_xs)):
|
|
297
|
+
for j in range(len(np_xs[i])):
|
|
298
|
+
np_xs[i][j] += eps
|
|
299
|
+
y1 = f(np_xs)
|
|
300
|
+
np_xs[i][j] -= 2 * eps
|
|
301
|
+
y2 = f(np_xs)
|
|
302
|
+
np_xs[i][j] += eps
|
|
303
|
+
numerical_grad[i][j] = (y1 - y2) / (2 * eps)
|
|
304
|
+
|
|
305
|
+
# compute analytical gradient
|
|
306
|
+
tape = wp.Tape()
|
|
307
|
+
output = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
|
|
308
|
+
with tape:
|
|
309
|
+
wp.launch(kernel, dim=1, inputs=inputs, outputs=[output], device=device)
|
|
310
|
+
|
|
311
|
+
tape.backward(loss=output)
|
|
312
|
+
|
|
313
|
+
# compare gradients
|
|
314
|
+
for i in range(len(inputs)):
|
|
315
|
+
grad = tape.gradients[inputs[i]]
|
|
316
|
+
assert_np_equal(grad.numpy(), numerical_grad[i], tol=tol)
|
|
317
|
+
|
|
318
|
+
tape.zero()
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def test_vector_math_grad(test, device):
|
|
322
|
+
rng = np.random.default_rng(123)
|
|
323
|
+
|
|
324
|
+
# test unary operations
|
|
325
|
+
for dim, vec_type in [(2, wp.vec2), (3, wp.vec3), (4, wp.vec4), (4, wp.quat)]:
|
|
326
|
+
|
|
327
|
+
def check_length(vs: wp.array(dtype=vec_type), out: wp.array(dtype=float)):
|
|
328
|
+
out[0] = wp.length(vs[0])
|
|
329
|
+
|
|
330
|
+
def check_length_sq(vs: wp.array(dtype=vec_type), out: wp.array(dtype=float)):
|
|
331
|
+
out[0] = wp.length_sq(vs[0])
|
|
332
|
+
|
|
333
|
+
def check_normalize(vs: wp.array(dtype=vec_type), out: wp.array(dtype=float)):
|
|
334
|
+
out[0] = wp.length_sq(wp.normalize(vs[0])) # compress to scalar output
|
|
335
|
+
|
|
336
|
+
# run the tests with 5 different random inputs
|
|
337
|
+
for _ in range(5):
|
|
338
|
+
x = wp.array(rng.random(size=(1, dim), dtype=np.float32), dtype=vec_type, device=device)
|
|
339
|
+
gradcheck(check_length, f"check_length_{vec_type.__name__}", [x], device)
|
|
340
|
+
gradcheck(check_length_sq, f"check_length_sq_{vec_type.__name__}", [x], device)
|
|
341
|
+
gradcheck(check_normalize, f"check_normalize_{vec_type.__name__}", [x], device)
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def test_matrix_math_grad(test, device):
|
|
345
|
+
rng = np.random.default_rng(123)
|
|
346
|
+
|
|
347
|
+
# test unary operations
|
|
348
|
+
for dim, mat_type in [(2, wp.mat22), (3, wp.mat33), (4, wp.mat44)]:
|
|
349
|
+
|
|
350
|
+
def check_determinant(vs: wp.array(dtype=mat_type), out: wp.array(dtype=float)):
|
|
351
|
+
out[0] = wp.determinant(vs[0])
|
|
352
|
+
|
|
353
|
+
def check_trace(vs: wp.array(dtype=mat_type), out: wp.array(dtype=float)):
|
|
354
|
+
out[0] = wp.trace(vs[0])
|
|
355
|
+
|
|
356
|
+
# run the tests with 5 different random inputs
|
|
357
|
+
for _ in range(5):
|
|
358
|
+
x = wp.array(rng.random(size=(1, dim, dim), dtype=np.float32), ndim=1, dtype=mat_type, device=device)
|
|
359
|
+
gradcheck(check_determinant, f"check_length_{mat_type.__name__}", [x], device)
|
|
360
|
+
gradcheck(check_trace, f"check_length_sq_{mat_type.__name__}", [x], device)
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def test_3d_math_grad(test, device):
|
|
364
|
+
rng = np.random.default_rng(123)
|
|
365
|
+
|
|
366
|
+
# test binary operations
|
|
367
|
+
def check_cross(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
|
|
368
|
+
out[0] = wp.length(wp.cross(vs[0], vs[1]))
|
|
369
|
+
|
|
370
|
+
def check_dot(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
|
|
371
|
+
out[0] = wp.dot(vs[0], vs[1])
|
|
372
|
+
|
|
373
|
+
def check_mat33(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
|
|
374
|
+
a = vs[0]
|
|
375
|
+
b = vs[1]
|
|
376
|
+
c = wp.cross(a, b)
|
|
377
|
+
m = wp.mat33(a[0], b[0], c[0], a[1], b[1], c[1], a[2], b[2], c[2])
|
|
378
|
+
out[0] = wp.determinant(m)
|
|
379
|
+
|
|
380
|
+
def check_trace_diagonal(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
|
|
381
|
+
a = vs[0]
|
|
382
|
+
b = vs[1]
|
|
383
|
+
c = wp.cross(a, b)
|
|
384
|
+
m = wp.mat33(
|
|
385
|
+
1.0 / (a[0] + 10.0),
|
|
386
|
+
0.0,
|
|
387
|
+
0.0,
|
|
388
|
+
0.0,
|
|
389
|
+
1.0 / (b[1] + 10.0),
|
|
390
|
+
0.0,
|
|
391
|
+
0.0,
|
|
392
|
+
0.0,
|
|
393
|
+
1.0 / (c[2] + 10.0),
|
|
394
|
+
)
|
|
395
|
+
out[0] = wp.trace(m)
|
|
396
|
+
|
|
397
|
+
def check_rot_rpy(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
|
|
398
|
+
v = vs[0]
|
|
399
|
+
q = wp.quat_rpy(v[0], v[1], v[2])
|
|
400
|
+
out[0] = wp.length(wp.quat_rotate(q, vs[1]))
|
|
401
|
+
|
|
402
|
+
def check_rot_axis_angle(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
|
|
403
|
+
v = wp.normalize(vs[0])
|
|
404
|
+
q = wp.quat_from_axis_angle(v, 0.5)
|
|
405
|
+
out[0] = wp.length(wp.quat_rotate(q, vs[1]))
|
|
406
|
+
|
|
407
|
+
def check_rot_quat_inv(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
|
|
408
|
+
v = vs[0]
|
|
409
|
+
q = wp.normalize(wp.quat(v[0], v[1], v[2], 1.0))
|
|
410
|
+
out[0] = wp.length(wp.quat_rotate_inv(q, vs[1]))
|
|
411
|
+
|
|
412
|
+
# run the tests with 5 different random inputs
|
|
413
|
+
for _ in range(5):
|
|
414
|
+
x = wp.array(
|
|
415
|
+
rng.standard_normal(size=(2, 3), dtype=np.float32), dtype=wp.vec3, device=device, requires_grad=True
|
|
416
|
+
)
|
|
417
|
+
gradcheck(check_cross, "check_cross_3d", [x], device)
|
|
418
|
+
gradcheck(check_dot, "check_dot_3d", [x], device)
|
|
419
|
+
gradcheck(check_mat33, "check_mat33_3d", [x], device, eps=2e-2)
|
|
420
|
+
gradcheck(check_trace_diagonal, "check_trace_diagonal_3d", [x], device)
|
|
421
|
+
gradcheck(check_rot_rpy, "check_rot_rpy_3d", [x], device)
|
|
422
|
+
gradcheck(check_rot_axis_angle, "check_rot_axis_angle_3d", [x], device)
|
|
423
|
+
gradcheck(check_rot_quat_inv, "check_rot_quat_inv_3d", [x], device)
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def test_multi_valued_function_grad(test, device):
    rng = np.random.default_rng(123)

    @wp.func
    def multi_valued(x: float, y: float, z: float):
        return wp.sin(x), wp.cos(y) * z, wp.sqrt(wp.abs(z)) / wp.abs(x)

    # test multi-valued functions
    def check_multi_valued(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        tid = wp.tid()
        v = vs[tid]
        a, b, c = multi_valued(v[0], v[1], v[2])
        out[tid] = a + b + c

    # run the tests with 5 different random inputs
    for _ in range(5):
        x = wp.array(
            rng.standard_normal(size=(2, 3), dtype=np.float32), dtype=wp.vec3, device=device, requires_grad=True
        )
        gradcheck(check_multi_valued, "check_multi_valued_3d", [x], device)


def test_mesh_grad(test, device):
    pos = wp.array(
        [
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ],
        dtype=wp.vec3,
        device=device,
        requires_grad=True,
    )
    indices = wp.array(
        [0, 1, 2, 0, 2, 3, 0, 3, 1, 1, 3, 2],
        dtype=wp.int32,
        device=device,
    )

    mesh = wp.Mesh(points=pos, indices=indices)

    @wp.func
    def compute_triangle_area(mesh_id: wp.uint64, tri_id: int):
        mesh = wp.mesh_get(mesh_id)
        i, j, k = mesh.indices[tri_id * 3 + 0], mesh.indices[tri_id * 3 + 1], mesh.indices[tri_id * 3 + 2]
        a = mesh.points[i]
        b = mesh.points[j]
        c = mesh.points[k]
        return wp.length(wp.cross(b - a, c - a)) * 0.5

    @wp.kernel
    def compute_area(mesh_id: wp.uint64, out: wp.array(dtype=wp.float32)):
        wp.atomic_add(out, 0, compute_triangle_area(mesh_id, wp.tid()))

    num_tris = int(len(indices) / 3)

    # compute analytical gradient
    tape = wp.Tape()
    output = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
    with tape:
        wp.launch(compute_area, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)

    tape.backward(loss=output)

    ad_grad = mesh.points.grad.numpy()

    # compute finite differences
    eps = 1e-3
    pos_np = pos.numpy()
    fd_grad = np.zeros_like(ad_grad)

    for i in range(len(pos)):
        for j in range(3):
            pos_np[i, j] += eps
            pos = wp.array(pos_np, dtype=wp.vec3, device=device)
            mesh = wp.Mesh(points=pos, indices=indices)
            output.zero_()
            wp.launch(compute_area, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
            f1 = output.numpy()[0]
            pos_np[i, j] -= 2 * eps
            pos = wp.array(pos_np, dtype=wp.vec3, device=device)
            mesh = wp.Mesh(points=pos, indices=indices)
            output.zero_()
            wp.launch(compute_area, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
            f2 = output.numpy()[0]
            pos_np[i, j] += eps
            fd_grad[i, j] = (f1 - f2) / (2 * eps)

    assert np.allclose(ad_grad, fd_grad, atol=1e-3)

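The nested loop in test_mesh_grad evaluates the central difference (f(x + eps) - f(x - eps)) / (2 * eps) once per vertex coordinate. As a sketch only, with a hypothetical helper name that does not appear in the package, the same computation can be factored into a reusable NumPy routine:

import numpy as np


def central_difference_grad(f, x, eps=1e-3):
    # Numerically differentiate a scalar function f of a NumPy array x.
    # f receives the (perturbed) array and must return a Python float;
    # the result has the same shape as x and holds df/dx for each entry.
    x = np.array(x, dtype=np.float64)
    grad = np.zeros_like(x)
    flat = x.reshape(-1)
    flat_grad = grad.reshape(-1)
    for i in range(flat.size):
        orig = flat[i]
        flat[i] = orig + eps
        f1 = f(x)
        flat[i] = orig - eps
        f2 = f(x)
        flat[i] = orig
        flat_grad[i] = (f1 - f2) / (2.0 * eps)
    return grad

In test_mesh_grad, f would rebuild the wp.Mesh from the perturbed positions, relaunch compute_area, and return output.numpy()[0].
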
@wp.func
def name_clash(a: float, b: float) -> float:
    return a + b


@wp.func_grad(name_clash)
def adj_name_clash(a: float, b: float, adj_ret: float):
    # names `adj_a` and `adj_b` must not clash with function args of generated function
    adj_a = 0.0
    adj_b = 0.0
    if a < 0.0:
        adj_a = adj_ret
    if b > 0.0:
        adj_b = adj_ret

    wp.adjoint[a] += adj_a
    wp.adjoint[b] += adj_b


@wp.kernel
def name_clash_kernel(
    input_a: wp.array(dtype=float),
    input_b: wp.array(dtype=float),
    output: wp.array(dtype=float),
):
    tid = wp.tid()
    output[tid] = name_clash(input_a[tid], input_b[tid])


def test_name_clash(test, device):
    # tests that no name clashes occur when variable names such as `adj_a` are used in custom gradient code
    with wp.ScopedDevice(device):
        input_a = wp.array([1.0, -2.0, 3.0], dtype=wp.float32, requires_grad=True)
        input_b = wp.array([4.0, 5.0, -6.0], dtype=wp.float32, requires_grad=True)
        output = wp.zeros(3, dtype=wp.float32, requires_grad=True)

        tape = wp.Tape()
        with tape:
            wp.launch(name_clash_kernel, dim=len(input_a), inputs=[input_a, input_b], outputs=[output])

        tape.backward(grads={output: wp.array(np.ones(len(input_a), dtype=np.float32))})

        assert_np_equal(input_a.grad.numpy(), np.array([0.0, 1.0, 0.0]))
        assert_np_equal(input_b.grad.numpy(), np.array([1.0, 1.0, 0.0]))

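Note that adj_name_clash is deliberately not the true derivative of a + b (which is 1 with respect to both arguments); the test only verifies that the custom adjoint runs and that its local adj_a and adj_b names do not collide with the arguments of the generated adjoint function. For comparison, a sketch of a custom gradient that does match its forward function, using the same @wp.func_grad and wp.adjoint mechanism (function names here are illustrative, not from the package):

@wp.func
def my_square(x: float) -> float:
    return x * x


@wp.func_grad(my_square)
def adj_my_square(x: float, adj_ret: float):
    # true derivative: d(x * x)/dx = 2 * x, scaled by the incoming adjoint
    wp.adjoint[x] += 2.0 * x * adj_ret
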
@wp.struct
class NestedStruct:
    v: wp.vec2


@wp.struct
class ParentStruct:
    a: float
    n: NestedStruct


@wp.func
def noop(a: Any):
    pass


@wp.func
def sum2(v: wp.vec2):
    return v[0] + v[1]


@wp.kernel
def test_struct_attribute_gradient_kernel(src: wp.array(dtype=float), res: wp.array(dtype=float)):
    tid = wp.tid()

    p = ParentStruct(src[tid], NestedStruct(wp.vec2(2.0 * src[tid])))

    # test that we are not losing gradients when accessing attributes
    noop(p.a)
    noop(p.n)
    noop(p.n.v)

    res[tid] = p.a + sum2(p.n.v)


def test_struct_attribute_gradient(test, device):
    with wp.ScopedDevice(device):
        src = wp.array([1], dtype=float, requires_grad=True)
        res = wp.empty_like(src)

        tape = wp.Tape()
        with tape:
            wp.launch(test_struct_attribute_gradient_kernel, dim=1, inputs=[src, res])

        res.grad.fill_(1.0)
        tape.backward()

        test.assertEqual(src.grad.numpy()[0], 5.0)

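The expected value of 5.0 follows directly from the kernel above: p.a = src and p.n.v = (2 * src, 2 * src), so res = src + 2 * src + 2 * src = 5 * src and d(res)/d(src) = 5.
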
@wp.kernel
def copy_kernel(a: wp.array(dtype=wp.float32), b: wp.array(dtype=wp.float32)):
    tid = wp.tid()
    ai = a[tid]
    bi = ai
    b[tid] = bi


def test_copy(test, device):
    with wp.ScopedDevice(device):
        a = wp.array([-1.0, 2.0, 3.0], dtype=wp.float32, requires_grad=True)
        b = wp.array([0.0, 0.0, 0.0], dtype=wp.float32, requires_grad=True)

        wp.launch(copy_kernel, 1, inputs=[a, b])

        b.grad = wp.array([1.0, 1.0, 1.0], dtype=wp.float32)
        wp.launch(copy_kernel, a.shape[0], inputs=[a, b], adjoint=True, adj_inputs=[None, None])

        assert_np_equal(a.grad.numpy(), np.array([1.0, 1.0, 1.0]))


@wp.kernel
def aliasing_kernel(a: wp.array(dtype=wp.float32), b: wp.array(dtype=wp.float32)):
    tid = wp.tid()
    x = a[tid]

    y = x
    if y > 0.0:
        y = x * x
    else:
        y = x * x * x

    b[tid] = y


def test_aliasing(test, device):
    with wp.ScopedDevice(device):
        a = wp.array([-1.0, 2.0, 3.0], dtype=wp.float32, requires_grad=True)
        b = wp.array([0.0, 0.0, 0.0], dtype=wp.float32, requires_grad=True)

        wp.launch(aliasing_kernel, 1, inputs=[a, b])

        b.grad = wp.array([1.0, 1.0, 1.0], dtype=wp.float32)
        wp.launch(aliasing_kernel, a.shape[0], inputs=[a, b], adjoint=True, adj_inputs=[None, None])

        assert_np_equal(a.grad.numpy(), np.array([3.0, 4.0, 6.0]))

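The expected gradient [3.0, 4.0, 6.0] can be read off the two branches of aliasing_kernel: a[0] = -1 takes the else branch (y = x**3, dy/dx = 3 * x**2 = 3), while a[1] = 2 and a[2] = 3 take the if branch (y = x**2, dy/dx = 2 * x, giving 4 and 6).
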
@wp.kernel
def square_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    y[tid] = x[tid] ** 2.0


def test_gradient_internal(test, device):
    with wp.ScopedDevice(device):
        a = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True)
        b = wp.array([0.0, 0.0, 0.0], dtype=float, requires_grad=True)

        wp.launch(square_kernel, a.size, inputs=[a, b])

        # use internal gradients (.grad), adj_inputs are None
        b.grad = wp.array([1.0, 1.0, 1.0], dtype=float)
        wp.launch(square_kernel, a.shape[0], inputs=[a, b], adjoint=True, adj_inputs=[None, None])

        assert_np_equal(a.grad.numpy(), np.array([2.0, 4.0, 6.0]))


def test_gradient_external(test, device):
    with wp.ScopedDevice(device):
        a = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=False)
        b = wp.array([0.0, 0.0, 0.0], dtype=float, requires_grad=False)

        wp.launch(square_kernel, a.size, inputs=[a, b])

        # use external gradients passed in adj_inputs
        a_grad = wp.array([0.0, 0.0, 0.0], dtype=float)
        b_grad = wp.array([1.0, 1.0, 1.0], dtype=float)
        wp.launch(square_kernel, a.shape[0], inputs=[a, b], adjoint=True, adj_inputs=[a_grad, b_grad])

        assert_np_equal(a_grad.numpy(), np.array([2.0, 4.0, 6.0]))


def test_gradient_precedence(test, device):
    with wp.ScopedDevice(device):
        a = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True)
        b = wp.array([0.0, 0.0, 0.0], dtype=float, requires_grad=True)

        wp.launch(square_kernel, a.size, inputs=[a, b])

        # if both internal and external gradients are present, the external one takes precedence,
        # because it's explicitly passed by the user in adj_inputs
        a_grad = wp.array([0.0, 0.0, 0.0], dtype=float)
        b_grad = wp.array([1.0, 1.0, 1.0], dtype=float)
        wp.launch(square_kernel, a.shape[0], inputs=[a, b], adjoint=True, adj_inputs=[a_grad, b_grad])

        assert_np_equal(a_grad.numpy(), np.array([2.0, 4.0, 6.0]))  # used
        assert_np_equal(a.grad.numpy(), np.array([0.0, 0.0, 0.0]))  # unused

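The three tests above drive the backward pass manually via wp.launch(..., adjoint=True, adj_inputs=...). The same square_kernel gradient can also be obtained with the wp.Tape API used elsewhere in this file; a minimal sketch (variable names are illustrative, and the expected result is inferred from the tests above):

a = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True)
b = wp.zeros(3, dtype=float, requires_grad=True)

tape = wp.Tape()
with tape:
    wp.launch(square_kernel, dim=3, inputs=[a, b])

# seed the output adjoint and propagate it back to a.grad
b.grad.fill_(1.0)
tape.backward()
# a.grad should now hold [2.0, 4.0, 6.0], matching the manual adjoint launches
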
devices = get_test_devices()


class TestGrad(unittest.TestCase):
    pass


# add_function_test(TestGrad, "test_while_loop_grad", test_while_loop_grad, devices=devices)
add_function_test(TestGrad, "test_for_loop_nested_for_grad", test_for_loop_nested_for_grad, devices=devices)
add_function_test(TestGrad, "test_scalar_grad", test_scalar_grad, devices=devices)
add_function_test(TestGrad, "test_for_loop_grad", test_for_loop_grad, devices=devices)
add_function_test(
    TestGrad, "test_for_loop_graph_grad", test_for_loop_graph_grad, devices=get_selected_cuda_test_devices()
)
add_function_test(TestGrad, "test_for_loop_nested_if_grad", test_for_loop_nested_if_grad, devices=devices)
add_function_test(TestGrad, "test_preserve_outputs_grad", test_preserve_outputs_grad, devices=devices)
add_function_test(TestGrad, "test_vector_math_grad", test_vector_math_grad, devices=devices)
add_function_test(TestGrad, "test_matrix_math_grad", test_matrix_math_grad, devices=devices)
add_function_test(TestGrad, "test_3d_math_grad", test_3d_math_grad, devices=devices)
add_function_test(TestGrad, "test_multi_valued_function_grad", test_multi_valued_function_grad, devices=devices)
add_function_test(TestGrad, "test_mesh_grad", test_mesh_grad, devices=devices)
add_function_test(TestGrad, "test_name_clash", test_name_clash, devices=devices)
add_function_test(TestGrad, "test_struct_attribute_gradient", test_struct_attribute_gradient, devices=devices)
add_function_test(TestGrad, "test_copy", test_copy, devices=devices)
add_function_test(TestGrad, "test_aliasing", test_aliasing, devices=devices)
add_function_test(TestGrad, "test_gradient_internal", test_gradient_internal, devices=devices)
add_function_test(TestGrad, "test_gradient_external", test_gradient_external, devices=devices)
add_function_test(TestGrad, "test_gradient_precedence", test_gradient_precedence, devices=devices)


if __name__ == "__main__":
    wp.build.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=False)