warp-lang 1.0.0b2-py3-none-win_amd64.whl → 1.0.0b6-py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of warp-lang has been flagged as potentially problematic.
- docs/conf.py +17 -5
- examples/env/env_ant.py +1 -1
- examples/env/env_cartpole.py +1 -1
- examples/env/env_humanoid.py +1 -1
- examples/env/env_usd.py +4 -1
- examples/env/environment.py +8 -9
- examples/example_dem.py +34 -33
- examples/example_diffray.py +364 -337
- examples/example_fluid.py +32 -23
- examples/example_jacobian_ik.py +97 -93
- examples/example_marching_cubes.py +6 -16
- examples/example_mesh.py +6 -16
- examples/example_mesh_intersect.py +16 -14
- examples/example_nvdb.py +14 -16
- examples/example_raycast.py +14 -13
- examples/example_raymarch.py +16 -23
- examples/example_render_opengl.py +19 -10
- examples/example_sim_cartpole.py +82 -78
- examples/example_sim_cloth.py +45 -48
- examples/example_sim_fk_grad.py +51 -44
- examples/example_sim_fk_grad_torch.py +47 -40
- examples/example_sim_grad_bounce.py +108 -133
- examples/example_sim_grad_cloth.py +99 -113
- examples/example_sim_granular.py +5 -6
- examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
- examples/example_sim_neo_hookean.py +51 -55
- examples/example_sim_particle_chain.py +4 -4
- examples/example_sim_quadruped.py +126 -81
- examples/example_sim_rigid_chain.py +54 -61
- examples/example_sim_rigid_contact.py +66 -70
- examples/example_sim_rigid_fem.py +3 -3
- examples/example_sim_rigid_force.py +1 -1
- examples/example_sim_rigid_gyroscopic.py +3 -4
- examples/example_sim_rigid_kinematics.py +28 -39
- examples/example_sim_trajopt.py +112 -110
- examples/example_sph.py +9 -8
- examples/example_wave.py +7 -7
- examples/fem/bsr_utils.py +30 -17
- examples/fem/example_apic_fluid.py +85 -69
- examples/fem/example_convection_diffusion.py +97 -93
- examples/fem/example_convection_diffusion_dg.py +142 -149
- examples/fem/example_convection_diffusion_dg0.py +141 -136
- examples/fem/example_deformed_geometry.py +146 -0
- examples/fem/example_diffusion.py +115 -84
- examples/fem/example_diffusion_3d.py +116 -86
- examples/fem/example_diffusion_mgpu.py +102 -79
- examples/fem/example_mixed_elasticity.py +139 -100
- examples/fem/example_navier_stokes.py +175 -162
- examples/fem/example_stokes.py +143 -111
- examples/fem/example_stokes_transfer.py +186 -157
- examples/fem/mesh_utils.py +59 -97
- examples/fem/plot_utils.py +138 -17
- tools/ci/publishing/build_nodes_info.py +54 -0
- warp/__init__.py +4 -3
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +836 -492
- warp/codegen.py +864 -553
- warp/config.py +3 -1
- warp/context.py +389 -172
- warp/fem/__init__.py +24 -6
- warp/fem/cache.py +318 -25
- warp/fem/dirichlet.py +7 -3
- warp/fem/domain.py +14 -0
- warp/fem/field/__init__.py +30 -38
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +244 -138
- warp/fem/field/restriction.py +8 -6
- warp/fem/field/test.py +127 -59
- warp/fem/field/trial.py +117 -60
- warp/fem/geometry/__init__.py +5 -1
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +24 -1
- warp/fem/geometry/geometry.py +86 -14
- warp/fem/geometry/grid_2d.py +112 -54
- warp/fem/geometry/grid_3d.py +134 -65
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +85 -33
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +451 -115
- warp/fem/geometry/trimesh_2d.py +197 -92
- warp/fem/integrate.py +534 -268
- warp/fem/operator.py +58 -31
- warp/fem/polynomial.py +11 -0
- warp/fem/quadrature/__init__.py +1 -1
- warp/fem/quadrature/pic_quadrature.py +150 -58
- warp/fem/quadrature/quadrature.py +209 -57
- warp/fem/space/__init__.py +230 -53
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +49 -2
- warp/fem/space/function_space.py +90 -39
- warp/fem/space/grid_2d_function_space.py +149 -496
- warp/fem/space/grid_3d_function_space.py +173 -538
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +129 -76
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +46 -34
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +132 -1039
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +104 -742
- warp/fem/types.py +13 -11
- warp/fem/utils.py +335 -60
- warp/native/array.h +120 -34
- warp/native/builtin.h +101 -72
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +22 -40
- warp/native/clang/clang.cpp +1 -0
- warp/native/crt.h +2 -0
- warp/native/cuda_util.cpp +8 -3
- warp/native/cuda_util.h +1 -0
- warp/native/exports.h +1522 -1243
- warp/native/intersect.h +19 -4
- warp/native/intersect_adj.h +8 -8
- warp/native/mat.h +76 -17
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -18
- warp/native/mesh.h +395 -40
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +44 -34
- warp/native/reduce.cpp +1 -1
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +163 -155
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +18 -14
- warp/native/vec.h +103 -21
- warp/native/warp.cpp +2 -1
- warp/native/warp.cu +28 -3
- warp/native/warp.h +4 -3
- warp/render/render_opengl.py +261 -109
- warp/sim/__init__.py +1 -2
- warp/sim/articulation.py +385 -185
- warp/sim/import_mjcf.py +59 -48
- warp/sim/import_urdf.py +15 -15
- warp/sim/import_usd.py +174 -102
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_xpbd.py +4 -3
- warp/sim/model.py +330 -250
- warp/sim/render.py +1 -1
- warp/sparse.py +625 -152
- warp/stubs.py +341 -309
- warp/tape.py +9 -6
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +94 -74
- warp/tests/test_array.py +82 -101
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +22 -12
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +18 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +165 -134
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +237 -0
- warp/tests/test_fabricarray.py +22 -24
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1034 -124
- warp/tests/test_fp16.py +23 -16
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +123 -181
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +35 -34
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +24 -25
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +304 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +60 -22
- warp/tests/test_mesh_query_aabb.py +21 -25
- warp/tests/test_mesh_query_point.py +111 -22
- warp/tests/test_mesh_query_ray.py +12 -24
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +90 -86
- warp/tests/test_transient_module.py +10 -12
- warp/tests/test_types.py +363 -0
- warp/tests/test_utils.py +451 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +418 -376
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +291 -0
- warp/tests/unittest_utils.py +342 -0
- warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +589 -0
- warp/types.py +622 -211
- warp/utils.py +54 -393
- warp_lang-1.0.0b6.dist-info/METADATA +238 -0
- warp_lang-1.0.0b6.dist-info/RECORD +409 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
- examples/example_cache_management.py +0 -40
- examples/example_multigpu.py +0 -54
- examples/example_struct.py +0 -65
- examples/fem/example_stokes_transfer_3d.py +0 -210
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/fem/field/discrete_field.py +0 -80
- warp/fem/space/nodal_function_space.py +0 -233
- warp/tests/test_all.py +0 -223
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-1.0.0b2.dist-info/METADATA +0 -26
- warp_lang-1.0.0b2.dist-info/RECORD +0 -380
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/tests/test_grad.py
CHANGED

@@ -5,9 +5,13 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
+import unittest
+from typing import Any
+
 import numpy as np
+
 import warp as wp
-from warp.tests.test_base import *
+from warp.tests.unittest_utils import *
 
 wp.init()
 
@@ -63,26 +67,26 @@ def test_for_loop_grad(test, device):
 
 
 def test_for_loop_graph_grad(test, device):
+    wp.load_module(device=device)
+
     n = 32
     val = np.ones(n, dtype=np.float32)
 
     x = wp.array(val, device=device, requires_grad=True)
     sum = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
 
-    wp.capture_begin()
+    wp.capture_begin(device, force_module_load=False)
+    try:
+        tape = wp.Tape()
+        with tape:
+            wp.launch(for_loop_grad, dim=1, inputs=[n, x, sum], device=device)
 
-
-    tape = wp.Tape()
-
-    with tape:
-        wp.launch(for_loop_grad, dim=1, inputs=[n, x, sum], device=device)
-
-    tape.backward(loss=sum)
-
-    graph = wp.capture_end()
+        tape.backward(loss=sum)
+    finally:
+        graph = wp.capture_end(device)
 
     wp.capture_launch(graph)
-    wp.synchronize()
+    wp.synchronize_device(device)
 
     # ensure forward pass outputs persist
     assert_np_equal(sum.numpy(), 2.0 * np.sum(x.numpy()))
@@ -90,7 +94,7 @@ def test_for_loop_graph_grad(test, device):
     assert_np_equal(x.grad.numpy(), 2.0 * val)
 
     wp.capture_launch(graph)
-    wp.synchronize()
+    wp.synchronize_device(device)
 
 
 @wp.kernel
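The hunks above move CUDA graph capture from the old implicit wp.capture_begin() / wp.capture_end() calls to explicit-device calls wrapped in try/finally. Below is a minimal sketch of the updated pattern, assuming the wp.capture_begin(device, force_module_load=False), wp.capture_end(device), and wp.synchronize_device(device) signatures shown in the diff; the `double` kernel is a hypothetical stand-in, not part of the test suite.

import warp as wp

wp.init()

@wp.kernel
def double(x: wp.array(dtype=float)):
    tid = wp.tid()
    x[tid] = 2.0 * x[tid]

device = "cuda:0"
wp.load_module(device=device)  # compile modules up front so capture does not trigger a build

x = wp.zeros(32, dtype=wp.float32, device=device)

wp.capture_begin(device, force_module_load=False)
try:
    wp.launch(double, dim=32, inputs=[x], device=device)
finally:
    # always end the capture, even if a launch raises,
    # so the stream is not left stuck in capture mode
    graph = wp.capture_end(device)

wp.capture_launch(graph)
wp.synchronize_device(device)

The try/finally is the point of the change: an exception raised between begin and end previously leaked an open capture.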
@@ -272,8 +276,7 @@ def gradcheck(func, func_name, inputs, device, eps=1e-4, tol=1e-2):
     numerical gradient computed using finite differences.
     """
 
-    module = wp.get_module(func.__module__)
-    kernel = wp.Kernel(func=func, key=func_name, module=module)
+    kernel = wp.Kernel(func=func, key=func_name)
 
     def f(xs):
         # call the kernel without taping for finite differences
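As the gradcheck hunk shows, kernels built at runtime from plain Python functions no longer need an explicit module argument: wp.Kernel(func=..., key=...) now resolves the module itself. A minimal sketch under that assumption follows; `scale` is a hypothetical function, not part of the package.

import warp as wp

wp.init()

def scale(xs: wp.array(dtype=float), s: float):
    tid = wp.tid()
    xs[tid] = xs[tid] * s

# the owning module is now inferred from the function itself
kernel = wp.Kernel(func=scale, key="scale")

xs = wp.array([1.0, 2.0, 3.0], dtype=float)
wp.launch(kernel, dim=len(xs), inputs=[xs, 2.0])
print(xs.numpy())  # expected: [2. 4. 6.]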
@@ -316,7 +319,7 @@
 
 
 def test_vector_math_grad(test, device):
-    np.random.seed(123)
+    rng = np.random.default_rng(123)
 
     # test unary operations
     for dim, vec_type in [(2, wp.vec2), (3, wp.vec3), (4, wp.vec4), (4, wp.quat)]:
@@ -332,14 +335,14 @@ def test_vector_math_grad(test, device):
 
     # run the tests with 5 different random inputs
     for _ in range(5):
-        x = wp.array(
+        x = wp.array(rng.random(size=(1, dim), dtype=np.float32), dtype=vec_type, device=device)
         gradcheck(check_length, f"check_length_{vec_type.__name__}", [x], device)
         gradcheck(check_length_sq, f"check_length_sq_{vec_type.__name__}", [x], device)
         gradcheck(check_normalize, f"check_normalize_{vec_type.__name__}", [x], device)
 
 
 def test_matrix_math_grad(test, device):
-    np.random.seed(123)
+    rng = np.random.default_rng(123)
 
     # test unary operations
     for dim, mat_type in [(2, wp.mat22), (3, wp.mat33), (4, wp.mat44)]:
@@ -352,13 +355,13 @@ def test_matrix_math_grad(test, device):
 
     # run the tests with 5 different random inputs
     for _ in range(5):
-        x = wp.array(
+        x = wp.array(rng.random(size=(1, dim, dim), dtype=np.float32), ndim=1, dtype=mat_type, device=device)
         gradcheck(check_determinant, f"check_length_{mat_type.__name__}", [x], device)
         gradcheck(check_trace, f"check_length_sq_{mat_type.__name__}", [x], device)
 
 
 def test_3d_math_grad(test, device):
-    np.random.seed(123)
+    rng = np.random.default_rng(123)
 
     # test binary operations
     def check_cross(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
@@ -408,7 +411,9 @@ def test_3d_math_grad(test, device):
 
     # run the tests with 5 different random inputs
     for _ in range(5):
-        x = wp.array(
+        x = wp.array(
+            rng.standard_normal(size=(2, 3), dtype=np.float32), dtype=wp.vec3, device=device, requires_grad=True
+        )
         gradcheck(check_cross, "check_cross_3d", [x], device)
         gradcheck(check_dot, "check_dot_3d", [x], device)
         gradcheck(check_mat33, "check_mat33_3d", [x], device, eps=2e-2)
@@ -419,7 +424,7 @@ def test_3d_math_grad(test, device):
 
 
 def test_multi_valued_function_grad(test, device):
-    np.random.seed(123)
+    rng = np.random.default_rng(123)
 
     @wp.func
     def multi_valued(x: float, y: float, z: float):
@@ -434,7 +439,9 @@ def test_multi_valued_function_grad(test, device):
 
     # run the tests with 5 different random inputs
     for _ in range(5):
-        x = wp.array(
+        x = wp.array(
+            rng.standard_normal(size=(2, 3), dtype=np.float32), dtype=wp.vec3, device=device, requires_grad=True
+        )
        gradcheck(check_multi_valued, "check_multi_valued_3d", [x], device)
 
 
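Across these hunks the tests drop NumPy's legacy global seeding in favor of an isolated Generator, which keeps each test's random stream independent of any other code that touches np.random. A short illustration of the two APIs side by side (the drawn values themselves differ between them):

import numpy as np

# old: seeds the shared global RNG; any other np.random call perturbs the stream
np.random.seed(123)
a = np.random.rand(1, 3).astype(np.float32)

# new: a self-contained Generator instance, seeded per test
rng = np.random.default_rng(123)
b = rng.random(size=(1, 3), dtype=np.float32)          # uniform in [0, 1)
c = rng.standard_normal(size=(2, 3), dtype=np.float32)  # normal draws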
@@ -467,19 +474,17 @@ def test_mesh_grad(test, device):
         c = mesh.points[k]
         return wp.length(wp.cross(b - a, c - a)) * 0.5
 
+    @wp.kernel
     def compute_area(mesh_id: wp.uint64, out: wp.array(dtype=wp.float32)):
         wp.atomic_add(out, 0, compute_triangle_area(mesh_id, wp.tid()))
 
-    module = wp.get_module(compute_area.__module__)
-    kernel = wp.Kernel(func=compute_area, key="compute_area", module=module)
-
     num_tris = int(len(indices) / 3)
 
     # compute analytical gradient
     tape = wp.Tape()
     output = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
     with tape:
-        wp.launch(kernel, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
+        wp.launch(compute_area, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
 
     tape.backward(loss=output)
 
@@ -496,13 +501,13 @@ def test_mesh_grad(test, device):
            pos = wp.array(pos_np, dtype=wp.vec3, device=device)
            mesh = wp.Mesh(points=pos, indices=indices)
            output.zero_()
-            wp.launch(kernel, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
+            wp.launch(compute_area, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
            f1 = output.numpy()[0]
            pos_np[i, j] -= 2 * eps
            pos = wp.array(pos_np, dtype=wp.vec3, device=device)
            mesh = wp.Mesh(points=pos, indices=indices)
            output.zero_()
-            wp.launch(kernel, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
+            wp.launch(compute_area, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
            f2 = output.numpy()[0]
            pos_np[i, j] += eps
            fd_grad[i, j] = (f1 - f2) / (2 * eps)
@@ -510,189 +515,126 @@ def test_mesh_grad(test, device):
     assert np.allclose(ad_grad, fd_grad, atol=1e-3)
 
 
-# atomic add function that memorizes which thread incremented the counter
-# so that the correct counter value per thread can be used in the replay
-# phase of the backward pass
 @wp.func
-def reversible_increment(
-    counter: wp.array(dtype=int),
-    counter_index: int,
-    value: int,
-    thread_values: wp.array(dtype=int),
-    tid: int
-):
-    next_index = wp.atomic_add(counter, counter_index, value)
-    thread_values[tid] = next_index
-    return next_index
-
-
-@wp.func_replay(reversible_increment)
-def replay_reversible_increment(
-    counter: wp.array(dtype=int),
-    counter_index: int,
-    value: int,
-    thread_values: wp.array(dtype=int),
-    tid: int
-):
-    return thread_values[tid]
+def name_clash(a: float, b: float) -> float:
+    return a + b
 
 
-def test_custom_replay_grad(test, device):
-    num_threads = 128
-    counter = wp.zeros(1, dtype=wp.int32, device=device)
-    thread_ids = wp.zeros(num_threads, dtype=wp.int32, device=device)
-    inputs = wp.array(np.arange(num_threads, dtype=np.float32), device=device, requires_grad=True)
-    outputs = wp.zeros_like(inputs)
+@wp.func_grad(name_clash)
+def adj_name_clash(a: float, b: float, adj_ret: float):
+    # names `adj_a` and `adj_b` must not clash with function args of generated function
+    adj_a = 0.0
+    adj_b = 0.0
+    if a < 0.0:
+        adj_a = adj_ret
+    if b > 0.0:
+        adj_b = adj_ret
 
-    @wp.kernel
-    def run_atomic_add(
-        input: wp.array(dtype=float),
-        counter: wp.array(dtype=int),
-        thread_values: wp.array(dtype=int),
-        output: wp.array(dtype=float)
-    ):
-        tid = wp.tid()
-        idx = reversible_increment(counter, 0, 1, thread_values, tid)
-        output[idx] = input[idx] ** 2.0
+    wp.adjoint[a] += adj_a
+    wp.adjoint[b] += adj_b
 
-    tape = wp.Tape()
-    with tape:
-        wp.launch(run_atomic_add, dim=num_threads, inputs=[inputs, counter, thread_ids], outputs=[outputs], device=device)
 
-    tape.backward(grads={outputs: wp.array(np.ones(num_threads, dtype=np.float32), device=device)})
-    assert_np_equal(inputs.grad.numpy(), 2.0 * inputs.numpy(), tol=1e-4)
+@wp.kernel
+def name_clash_kernel(
+    input_a: wp.array(dtype=float),
+    input_b: wp.array(dtype=float),
+    output: wp.array(dtype=float),
+):
+    tid = wp.tid()
+    output[tid] = name_clash(input_a[tid], input_b[tid])
 
 
-@wp.func
-def overload_fn(x: float, y: float):
-    return x * 3.0 + y / 3.0, y**2.5
+def test_name_clash(test, device):
+    # tests that no name clashes occur when variable names such as `adj_a` are used in custom gradient code
+    with wp.ScopedDevice(device):
+        input_a = wp.array([1.0, -2.0, 3.0], dtype=wp.float32, requires_grad=True)
+        input_b = wp.array([4.0, 5.0, -6.0], dtype=wp.float32, requires_grad=True)
+        output = wp.zeros(3, dtype=wp.float32, requires_grad=True)
+
+        tape = wp.Tape()
+        with tape:
+            wp.launch(name_clash_kernel, dim=len(input_a), inputs=[input_a, input_b], outputs=[output])
 
+        tape.backward(grads={output: wp.array(np.ones(len(input_a), dtype=np.float32))})
 
-@wp.func_grad(overload_fn)
-def overload_fn_grad(x: float, y: float, adj_ret0: float, adj_ret1: float):
-    wp.adjoint[x] += x * adj_ret0 * 42.0 + y * adj_ret1 * 10.0
-    wp.adjoint[y] += y * adj_ret1 * 3.0
+        assert_np_equal(input_a.grad.numpy(), np.array([0.0, 1.0, 0.0]))
+        assert_np_equal(input_b.grad.numpy(), np.array([1.0, 1.0, 0.0]))
+
+
+@wp.struct
+class NestedStruct:
+    v: wp.vec2
 
 
 @wp.struct
-class MyStruct:
-    scalar: float
-    vec: wp.vec3
+class ParentStruct:
+    a: float
+    n: NestedStruct
 
 
 @wp.func
-def overload_fn(x: MyStruct):
-    return x.vec[0] * x.vec[1] * x.vec[2] * 4.0, wp.length(x.vec), x.scalar**0.5
+def noop(a: Any):
+    pass
 
 
-@wp.func_grad(overload_fn)
-def overload_fn_grad(x: MyStruct, adj_ret0: float, adj_ret1: float, adj_ret2: float):
-    wp.adjoint[x.scalar] += x.scalar * adj_ret0 * 10.0
-    wp.adjoint[x.vec][0] += adj_ret0 * x.vec[1] * x.vec[2] * 20.0
-    wp.adjoint[x.vec][1] += adj_ret1 * x.vec[0] * x.vec[2] * 30.0
-    wp.adjoint[x.vec][2] += adj_ret2 * x.vec[0] * x.vec[1] * 40.0
+@wp.func
+def sum2(v: wp.vec2):
+    return v[0] + v[1]
 
 
 @wp.kernel
-def run_overload_float_fn(
-    xs: wp.array(dtype=float),
-    ys: wp.array(dtype=float),
-    output0: wp.array(dtype=float),
-    output1: wp.array(dtype=float)
-):
-    i = wp.tid()
-    out0, out1 = overload_fn(xs[i], ys[i])
-    output0[i] = out0
-    output1[i] = out1
+def test_struct_attribute_gradient_kernel(src: wp.array(dtype=float), res: wp.array(dtype=float)):
+    tid = wp.tid()
 
+    p = ParentStruct(src[tid], NestedStruct(wp.vec2(2.0 * src[tid])))
+
+    # test that we are not losing gradients when accessing attributes
+    noop(p.a)
+    noop(p.n)
+    noop(p.n.v)
+
+    res[tid] = p.a + sum2(p.n.v)
+
+
+def test_struct_attribute_gradient(test_case, device):
+    src = wp.array([1], dtype=float, requires_grad=True)
+    res = wp.empty_like(src)
 
-@wp.kernel
-def run_overload_struct_fn(xs: wp.array(dtype=MyStruct), output: wp.array(dtype=float)):
-    i = wp.tid()
-    out0, out1, out2 = overload_fn(xs[i])
-    output[i] = out0 + out1 + out2
-
-
-def test_custom_overload_grad(test, device):
-    dim = 3
-    xs_float = wp.array(np.arange(1.0, dim + 1.0), dtype=wp.float32, requires_grad=True)
-    ys_float = wp.array(np.arange(10.0, dim + 10.0), dtype=wp.float32, requires_grad=True)
-    out0_float = wp.zeros(dim)
-    out1_float = wp.zeros(dim)
-    tape = wp.Tape()
-    with tape:
-        wp.launch(
-            run_overload_float_fn,
-            dim=dim,
-            inputs=[xs_float, ys_float],
-            outputs=[out0_float, out1_float])
-    tape.backward(grads={
-        out0_float: wp.array(np.ones(dim), dtype=wp.float32),
-        out1_float: wp.array(np.ones(dim), dtype=wp.float32)})
-    assert_np_equal(xs_float.grad.numpy(), xs_float.numpy() * 42.0 + ys_float.numpy() * 10.0)
-    assert_np_equal(ys_float.grad.numpy(), ys_float.numpy() * 3.0)
-
-    x0 = MyStruct()
-    x0.vec = wp.vec3(1., 2., 3.)
-    x0.scalar = 4.
-    x1 = MyStruct()
-    x1.vec = wp.vec3(5., 6., 7.)
-    x1.scalar = -1.0
-    x2 = MyStruct()
-    x2.vec = wp.vec3(8., 9., 10.)
-    x2.scalar = 19.0
-    xs_struct = wp.array([x0, x1, x2], dtype=MyStruct, requires_grad=True)
-    out_struct = wp.zeros(dim)
     tape = wp.Tape()
     with tape:
-        wp.launch(
-            run_overload_struct_fn,
-            dim=dim,
-            inputs=[xs_struct],
-            outputs=[out_struct])
-    tape.backward(grads={out_struct: wp.array(np.ones(dim), dtype=wp.float32)})
-    xs_struct_np = xs_struct.numpy()
-    struct_grads = xs_struct.grad.numpy()
-    # fmt: off
-    assert_np_equal(
-        np.array([g[0] for g in struct_grads]),
-        np.array([g[0] * 10.0 for g in xs_struct_np]))
-    assert_np_equal(
-        np.array([g[1][0] for g in struct_grads]),
-        np.array([g[1][1] * g[1][2] * 20.0 for g in xs_struct_np]))
-    assert_np_equal(
-        np.array([g[1][1] for g in struct_grads]),
-        np.array([g[1][0] * g[1][2] * 30.0 for g in xs_struct_np]))
-    assert_np_equal(
-        np.array([g[1][2] for g in struct_grads]),
-        np.array([g[1][0] * g[1][1] * 40.0 for g in xs_struct_np]))
-    # fmt: on
+        wp.launch(test_struct_attribute_gradient_kernel, dim=1, inputs=[src, res])
+
+    res.grad.fill_(1.0)
+    tape.backward()
+
+    test_case.assertEqual(src.grad.numpy()[0], 5.0)
 
 
+devices = get_test_devices()
 
-def register(parent):
-    devices = get_test_devices()
 
-    class TestGrad(parent):
-        pass
+class TestGrad(unittest.TestCase):
+    pass
 
-    # add_function_test(TestGrad, "test_while_loop_grad", test_while_loop_grad, devices=devices)
-    add_function_test(TestGrad, "test_for_loop_nested_for_grad", test_for_loop_nested_for_grad, devices=devices)
-    add_function_test(TestGrad, "test_scalar_grad", test_scalar_grad, devices=devices)
-    add_function_test(TestGrad, "test_for_loop_grad", test_for_loop_grad, devices=devices)
-    add_function_test(TestGrad, "test_for_loop_graph_grad", test_for_loop_graph_grad, devices=wp.get_cuda_devices())
-    add_function_test(TestGrad, "test_for_loop_nested_if_grad", test_for_loop_nested_if_grad, devices=devices)
-    add_function_test(TestGrad, "test_preserve_outputs_grad", test_preserve_outputs_grad, devices=devices)
-    add_function_test(TestGrad, "test_vector_math_grad", test_vector_math_grad, devices=devices)
-    add_function_test(TestGrad, "test_matrix_math_grad", test_matrix_math_grad, devices=devices)
-    add_function_test(TestGrad, "test_3d_math_grad", test_3d_math_grad, devices=devices)
-    add_function_test(TestGrad, "test_multi_valued_function_grad", test_multi_valued_function_grad, devices=devices)
-    add_function_test(TestGrad, "test_mesh_grad", test_mesh_grad, devices=devices)
-    add_function_test(TestGrad, "test_custom_replay_grad", test_custom_replay_grad, devices=devices)
-    add_function_test(TestGrad, "test_custom_overload_grad", test_custom_overload_grad, devices=devices)
 
-    return TestGrad
+# add_function_test(TestGrad, "test_while_loop_grad", test_while_loop_grad, devices=devices)
+add_function_test(TestGrad, "test_for_loop_nested_for_grad", test_for_loop_nested_for_grad, devices=devices)
+add_function_test(TestGrad, "test_scalar_grad", test_scalar_grad, devices=devices)
+add_function_test(TestGrad, "test_for_loop_grad", test_for_loop_grad, devices=devices)
+add_function_test(
+    TestGrad, "test_for_loop_graph_grad", test_for_loop_graph_grad, devices=get_unique_cuda_test_devices()
+)
+add_function_test(TestGrad, "test_for_loop_nested_if_grad", test_for_loop_nested_if_grad, devices=devices)
+add_function_test(TestGrad, "test_preserve_outputs_grad", test_preserve_outputs_grad, devices=devices)
+add_function_test(TestGrad, "test_vector_math_grad", test_vector_math_grad, devices=devices)
+add_function_test(TestGrad, "test_matrix_math_grad", test_matrix_math_grad, devices=devices)
+add_function_test(TestGrad, "test_3d_math_grad", test_3d_math_grad, devices=devices)
+add_function_test(TestGrad, "test_multi_valued_function_grad", test_multi_valued_function_grad, devices=devices)
+add_function_test(TestGrad, "test_mesh_grad", test_mesh_grad, devices=devices)
+add_function_test(TestGrad, "test_name_clash", test_name_clash, devices=devices)
+add_function_test(TestGrad, "test_struct_attribute_gradient", test_struct_attribute_gradient, devices=devices)
 
 
 if __name__ == "__main__":
-    c = register(unittest.TestCase)
+    wp.build.clear_kernel_cache()
     unittest.main(verbosity=2, failfast=False)
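The test_name_clash addition above exercises Warp's custom-gradient hook: @wp.func_grad(fn) registers a hand-written adjoint for fn, and wp.adjoint[arg] accumulates the gradient flowing into each argument. A condensed sketch of the same mechanism follows, using a hypothetical `square` function rather than anything from the test suite:

import warp as wp

wp.init()

@wp.func
def square(x: float) -> float:
    return x * x

@wp.func_grad(square)
def adj_square(x: float, adj_ret: float):
    # d(x*x)/dx = 2x, scaled by the incoming adjoint
    wp.adjoint[x] += 2.0 * x * adj_ret

@wp.kernel
def apply_square(xs: wp.array(dtype=float), out: wp.array(dtype=float)):
    tid = wp.tid()
    out[tid] = square(xs[tid])

xs = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True)
out = wp.zeros_like(xs)

tape = wp.Tape()
with tape:
    wp.launch(apply_square, dim=len(xs), inputs=[xs], outputs=[out])

tape.backward(grads={out: wp.array([1.0, 1.0, 1.0], dtype=float)})
print(xs.grad.numpy())  # expected: [2. 4. 6.]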
warp/tests/test_grad_customs.py
ADDED

@@ -0,0 +1,176 @@
+# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import unittest
+
+import numpy as np
+
+import warp as wp
+from warp.tests.unittest_utils import *
+
+wp.init()
+
+
+# atomic add function that memorizes which thread incremented the counter
+# so that the correct counter value per thread can be used in the replay
+# phase of the backward pass
+@wp.func
+def reversible_increment(
+    counter: wp.array(dtype=int), counter_index: int, value: int, thread_values: wp.array(dtype=int), tid: int
+):
+    next_index = wp.atomic_add(counter, counter_index, value)
+    thread_values[tid] = next_index
+    return next_index
+
+
+@wp.func_replay(reversible_increment)
+def replay_reversible_increment(
+    counter: wp.array(dtype=int), counter_index: int, value: int, thread_values: wp.array(dtype=int), tid: int
+):
+    return thread_values[tid]
+
+
+def test_custom_replay_grad(test, device):
+    num_threads = 128
+    counter = wp.zeros(1, dtype=wp.int32, device=device)
+    thread_ids = wp.zeros(num_threads, dtype=wp.int32, device=device)
+    inputs = wp.array(np.arange(num_threads, dtype=np.float32), device=device, requires_grad=True)
+    outputs = wp.zeros_like(inputs)
+
+    @wp.kernel
+    def run_atomic_add(
+        input: wp.array(dtype=float),
+        counter: wp.array(dtype=int),
+        thread_values: wp.array(dtype=int),
+        output: wp.array(dtype=float),
+    ):
+        tid = wp.tid()
+        idx = reversible_increment(counter, 0, 1, thread_values, tid)
+        output[idx] = input[idx] ** 2.0
+
+    tape = wp.Tape()
+    with tape:
+        wp.launch(
+            run_atomic_add, dim=num_threads, inputs=[inputs, counter, thread_ids], outputs=[outputs], device=device
+        )
+
+    tape.backward(grads={outputs: wp.array(np.ones(num_threads, dtype=np.float32), device=device)})
+    assert_np_equal(inputs.grad.numpy(), 2.0 * inputs.numpy(), tol=1e-4)
+
+
+@wp.func
+def overload_fn(x: float, y: float):
+    return x * 3.0 + y / 3.0, y**2.5
+
+
+@wp.func_grad(overload_fn)
+def overload_fn_grad(x: float, y: float, adj_ret0: float, adj_ret1: float):
+    wp.adjoint[x] += x * adj_ret0 * 42.0 + y * adj_ret1 * 10.0
+    wp.adjoint[y] += y * adj_ret1 * 3.0
+
+
+@wp.struct
+class MyStruct:
+    scalar: float
+    vec: wp.vec3
+
+
+@wp.func
+def overload_fn(x: MyStruct):
+    return x.vec[0] * x.vec[1] * x.vec[2] * 4.0, wp.length(x.vec), x.scalar**0.5
+
+
+@wp.func_grad(overload_fn)
+def overload_fn_grad(x: MyStruct, adj_ret0: float, adj_ret1: float, adj_ret2: float):
+    wp.adjoint[x.scalar] += x.scalar * adj_ret0 * 10.0
+    wp.adjoint[x.vec][0] += adj_ret0 * x.vec[1] * x.vec[2] * 20.0
+    wp.adjoint[x.vec][1] += adj_ret1 * x.vec[0] * x.vec[2] * 30.0
+    wp.adjoint[x.vec][2] += adj_ret2 * x.vec[0] * x.vec[1] * 40.0
+
+
+@wp.kernel
+def run_overload_float_fn(
+    xs: wp.array(dtype=float), ys: wp.array(dtype=float), output0: wp.array(dtype=float), output1: wp.array(dtype=float)
+):
+    i = wp.tid()
+    out0, out1 = overload_fn(xs[i], ys[i])
+    output0[i] = out0
+    output1[i] = out1
+
+
+@wp.kernel
+def run_overload_struct_fn(xs: wp.array(dtype=MyStruct), output: wp.array(dtype=float)):
+    i = wp.tid()
+    out0, out1, out2 = overload_fn(xs[i])
+    output[i] = out0 + out1 + out2
+
+
+def test_custom_overload_grad(test, device):
+    dim = 3
+    xs_float = wp.array(np.arange(1.0, dim + 1.0), dtype=wp.float32, requires_grad=True)
+    ys_float = wp.array(np.arange(10.0, dim + 10.0), dtype=wp.float32, requires_grad=True)
+    out0_float = wp.zeros(dim)
+    out1_float = wp.zeros(dim)
+    tape = wp.Tape()
+    with tape:
+        wp.launch(run_overload_float_fn, dim=dim, inputs=[xs_float, ys_float], outputs=[out0_float, out1_float])
+    tape.backward(
+        grads={
+            out0_float: wp.array(np.ones(dim), dtype=wp.float32),
+            out1_float: wp.array(np.ones(dim), dtype=wp.float32),
+        }
+    )
+    assert_np_equal(xs_float.grad.numpy(), xs_float.numpy() * 42.0 + ys_float.numpy() * 10.0)
+    assert_np_equal(ys_float.grad.numpy(), ys_float.numpy() * 3.0)
+
+    x0 = MyStruct()
+    x0.vec = wp.vec3(1.0, 2.0, 3.0)
+    x0.scalar = 4.0
+    x1 = MyStruct()
+    x1.vec = wp.vec3(5.0, 6.0, 7.0)
+    x1.scalar = -1.0
+    x2 = MyStruct()
+    x2.vec = wp.vec3(8.0, 9.0, 10.0)
+    x2.scalar = 19.0
+    xs_struct = wp.array([x0, x1, x2], dtype=MyStruct, requires_grad=True)
+    out_struct = wp.zeros(dim)
+    tape = wp.Tape()
+    with tape:
+        wp.launch(run_overload_struct_fn, dim=dim, inputs=[xs_struct], outputs=[out_struct])
+    tape.backward(grads={out_struct: wp.array(np.ones(dim), dtype=wp.float32)})
+    xs_struct_np = xs_struct.numpy()
+    struct_grads = xs_struct.grad.numpy()
+    # fmt: off
+    assert_np_equal(
+        np.array([g[0] for g in struct_grads]),
+        np.array([g[0] * 10.0 for g in xs_struct_np]))
+    assert_np_equal(
+        np.array([g[1][0] for g in struct_grads]),
+        np.array([g[1][1] * g[1][2] * 20.0 for g in xs_struct_np]))
+    assert_np_equal(
+        np.array([g[1][1] for g in struct_grads]),
+        np.array([g[1][0] * g[1][2] * 30.0 for g in xs_struct_np]))
+    assert_np_equal(
+        np.array([g[1][2] for g in struct_grads]),
+        np.array([g[1][0] * g[1][1] * 40.0 for g in xs_struct_np]))
+    # fmt: on
+
+
+devices = get_test_devices()
+
+
+class TestGradCustoms(unittest.TestCase):
+    pass
+
+
+add_function_test(TestGradCustoms, "test_custom_replay_grad", test_custom_replay_grad, devices=devices)
+add_function_test(TestGradCustoms, "test_custom_overload_grad", test_custom_overload_grad, devices=devices)
+
+
+if __name__ == "__main__":
+    wp.build.clear_kernel_cache()
+    unittest.main(verbosity=2, failfast=False)
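The comment at the top of the new file describes the idea behind @wp.func_replay: the tape re-runs forward kernels during the backward pass, so side-effecting operations such as atomic increments must not execute a second time. A stripped-down sketch of the forward/replay pairing, assuming the decorator semantics shown above (`bump` and `replay_bump` are hypothetical names, not part of the package):

import warp as wp

wp.init()

@wp.func
def bump(counter: wp.array(dtype=int), slots: wp.array(dtype=int), tid: int):
    # forward pass: claim a unique slot via an atomic add and remember it
    idx = wp.atomic_add(counter, 0, 1)
    slots[tid] = idx
    return idx

@wp.func_replay(bump)
def replay_bump(counter: wp.array(dtype=int), slots: wp.array(dtype=int), tid: int):
    # backward pass: return the slot recorded in the forward pass instead of
    # re-running the atomic add, keeping the replay deterministic
    return slots[tid]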