warp-lang 1.9.1__py3-none-manylinux_2_34_aarch64.whl → 1.10.0rc2__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +301 -287
- warp/__init__.pyi +794 -305
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1075 -0
- warp/_src/build.py +618 -0
- warp/_src/build_dll.py +640 -0
- warp/{builtins.py → _src/builtins.py} +1382 -377
- warp/_src/codegen.py +4359 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +57 -0
- warp/_src/context.py +8294 -0
- warp/_src/dlpack.py +462 -0
- warp/_src/fabric.py +355 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +508 -0
- warp/_src/fem/cache.py +687 -0
- warp/_src/fem/dirichlet.py +188 -0
- warp/{fem → _src/fem}/domain.py +40 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +701 -0
- warp/{fem → _src/fem}/field/nodal_field.py +30 -15
- warp/{fem → _src/fem}/field/restriction.py +1 -1
- warp/{fem → _src/fem}/field/virtual.py +53 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
- warp/_src/fem/geometry/closest_point.py +97 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
- warp/{fem → _src/fem}/geometry/element.py +32 -10
- warp/{fem → _src/fem}/geometry/geometry.py +48 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
- warp/{fem → _src/fem}/geometry/partition.py +121 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
- warp/{fem → _src/fem}/integrate.py +164 -158
- warp/_src/fem/linalg.py +383 -0
- warp/_src/fem/operator.py +396 -0
- warp/_src/fem/polynomial.py +229 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
- warp/_src/fem/space/basis_space.py +679 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
- warp/{fem → _src/fem}/space/function_space.py +14 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
- warp/{fem → _src/fem}/space/partition.py +117 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/restriction.py +66 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
- warp/_src/fem/space/topology.py +459 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
- warp/_src/fem/types.py +112 -0
- warp/_src/fem/utils.py +486 -0
- warp/_src/jax.py +186 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +387 -0
- warp/_src/jax_experimental/ffi.py +1284 -0
- warp/_src/jax_experimental/xla_ffi.py +656 -0
- warp/_src/marching_cubes.py +708 -0
- warp/_src/math.py +414 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +163 -0
- warp/_src/optim/linear.py +1606 -0
- warp/_src/optim/sgd.py +112 -0
- warp/_src/paddle.py +406 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +289 -0
- warp/_src/render/render_opengl.py +3636 -0
- warp/_src/render/render_usd.py +937 -0
- warp/_src/render/utils.py +160 -0
- warp/_src/sparse.py +2716 -0
- warp/_src/tape.py +1206 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +391 -0
- warp/_src/types.py +5870 -0
- warp/_src/utils.py +1693 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -721
- warp/codegen.py +6 -4251
- warp/constants.py +6 -39
- warp/context.py +12 -8062
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +1 -1
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -365
- warp/jax_experimental/ffi.py +17 -873
- warp/jax_experimental/xla_ffi.py +5 -605
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +314 -37
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +681 -89
- warp/native/tile_radix_sort.h +1 -1
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +57 -29
- warp/native/warp.cu +253 -171
- warp/native/warp.h +11 -8
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3618
- warp/render/render_usd.py +6 -919
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +14 -14
- warp/tests/interop/test_jax.py +772 -49
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +527 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +33 -14
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +56 -10
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +35 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_types.py +27 -15
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +178 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5764
- warp/utils.py +10 -1659
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
- warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.1.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/autograd.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c)
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -13,1063 +13,21 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
|
|
16
|
+
# isort: skip_file
|
|
17
17
|
|
|
18
|
-
import
|
|
19
|
-
import
|
|
20
|
-
from
|
|
18
|
+
from warp._src.autograd import gradcheck as gradcheck
|
|
19
|
+
from warp._src.autograd import gradcheck_tape as gradcheck_tape
|
|
20
|
+
from warp._src.autograd import jacobian as jacobian
|
|
21
|
+
from warp._src.autograd import jacobian_fd as jacobian_fd
|
|
22
|
+
from warp._src.autograd import jacobian_plot as jacobian_plot
|
|
21
23
|
|
|
22
|
-
import numpy as np
|
|
23
24
|
|
|
24
|
-
|
|
25
|
+
# TODO: Remove after cleaning up the public API.
|
|
25
26
|
|
|
26
|
-
|
|
27
|
-
"gradcheck",
|
|
28
|
-
"gradcheck_tape",
|
|
29
|
-
"jacobian",
|
|
30
|
-
"jacobian_fd",
|
|
31
|
-
"jacobian_plot",
|
|
32
|
-
]
|
|
27
|
+
from warp._src import autograd as _autograd
|
|
33
28
|
|
|
34
29
|
|
|
35
|
-
def
|
|
36
|
-
|
|
37
|
-
dim: tuple[int] | None = None,
|
|
38
|
-
inputs: Sequence | None = None,
|
|
39
|
-
outputs: Sequence | None = None,
|
|
40
|
-
*,
|
|
41
|
-
eps: float = 1e-4,
|
|
42
|
-
atol: float = 1e-3,
|
|
43
|
-
rtol: float = 1e-2,
|
|
44
|
-
raise_exception: bool = True,
|
|
45
|
-
input_output_mask: list[tuple[str | int, str | int]] | None = None,
|
|
46
|
-
device: wp.context.Devicelike = None,
|
|
47
|
-
max_blocks: int = 0,
|
|
48
|
-
block_dim: int = 256,
|
|
49
|
-
max_inputs_per_var: int = -1,
|
|
50
|
-
max_outputs_per_var: int = -1,
|
|
51
|
-
plot_relative_error: bool = False,
|
|
52
|
-
plot_absolute_error: bool = False,
|
|
53
|
-
show_summary: bool = True,
|
|
54
|
-
) -> bool:
|
|
55
|
-
"""
|
|
56
|
-
Checks whether the autodiff gradient of a Warp kernel matches finite differences.
|
|
30
|
+
def __getattr__(name):
|
|
31
|
+
from warp._src.utils import get_deprecated_api
|
|
57
32
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
.. math::
|
|
61
|
-
|
|
62
|
-
|\\nabla_\\text{AD} - \\nabla_\\text{FD}| \\leq atol + rtol \\cdot |\\nabla_\\text{FD}|.
|
|
63
|
-
|
|
64
|
-
The kernel function and its adjoint version are launched with the given inputs and outputs, as well as the provided
|
|
65
|
-
``dim``, ``max_blocks``, and ``block_dim`` arguments (see :func:`warp.launch` for more details).
|
|
66
|
-
|
|
67
|
-
Note:
|
|
68
|
-
This function only supports Warp kernels whose input arguments precede the output arguments.
|
|
69
|
-
|
|
70
|
-
Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.
|
|
71
|
-
|
|
72
|
-
Structs arguments are not yet supported by this function to compute Jacobians.
|
|
73
|
-
|
|
74
|
-
Args:
|
|
75
|
-
function: The Warp kernel function, decorated with the ``@wp.kernel`` decorator, or any function that involves Warp kernel launches.
|
|
76
|
-
dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints. Only required if the function is a Warp kernel.
|
|
77
|
-
inputs: List of input variables.
|
|
78
|
-
outputs: List of output variables. Only required if the function is a Warp kernel.
|
|
79
|
-
eps: The finite-difference step size.
|
|
80
|
-
atol: The absolute tolerance for the gradient check.
|
|
81
|
-
rtol: The relative tolerance for the gradient check.
|
|
82
|
-
raise_exception: If True, raises a `ValueError` if the gradient check fails.
|
|
83
|
-
input_output_mask: List of tuples specifying the input-output pairs to compute the Jacobian for. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
|
|
84
|
-
device: The device to launch on (optional)
|
|
85
|
-
max_blocks: The maximum number of CUDA thread blocks to use.
|
|
86
|
-
block_dim: The number of threads per block.
|
|
87
|
-
max_inputs_per_var: Maximum number of input dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all input dimensions if value <= 0.
|
|
88
|
-
max_outputs_per_var: Maximum number of output dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all output dimensions if value <= 0.
|
|
89
|
-
plot_relative_error: If True, visualizes the relative error of the Jacobians in a plot (requires ``matplotlib``).
|
|
90
|
-
plot_absolute_error: If True, visualizes the absolute error of the Jacobians in a plot (requires ``matplotlib``).
|
|
91
|
-
show_summary: If True, prints a summary table of the gradient check results.
|
|
92
|
-
|
|
93
|
-
Returns:
|
|
94
|
-
True if the gradient check passes, False otherwise.
|
|
95
|
-
"""
|
|
96
|
-
|
|
97
|
-
if inputs is None:
|
|
98
|
-
raise ValueError("The inputs argument must be provided")
|
|
99
|
-
|
|
100
|
-
metadata = FunctionMetadata()
|
|
101
|
-
|
|
102
|
-
jacs_ad = jacobian(
|
|
103
|
-
function,
|
|
104
|
-
dim=dim,
|
|
105
|
-
inputs=inputs,
|
|
106
|
-
outputs=outputs,
|
|
107
|
-
input_output_mask=input_output_mask,
|
|
108
|
-
device=device,
|
|
109
|
-
max_blocks=max_blocks,
|
|
110
|
-
block_dim=block_dim,
|
|
111
|
-
max_outputs_per_var=max_outputs_per_var,
|
|
112
|
-
plot_jacobians=False,
|
|
113
|
-
metadata=metadata,
|
|
114
|
-
)
|
|
115
|
-
jacs_fd = jacobian_fd(
|
|
116
|
-
function,
|
|
117
|
-
dim=dim,
|
|
118
|
-
inputs=inputs,
|
|
119
|
-
outputs=outputs,
|
|
120
|
-
input_output_mask=input_output_mask,
|
|
121
|
-
device=device,
|
|
122
|
-
max_blocks=max_blocks,
|
|
123
|
-
block_dim=block_dim,
|
|
124
|
-
max_inputs_per_var=max_inputs_per_var,
|
|
125
|
-
eps=eps,
|
|
126
|
-
plot_jacobians=False,
|
|
127
|
-
metadata=metadata,
|
|
128
|
-
)
|
|
129
|
-
|
|
130
|
-
relative_error_jacs = {}
|
|
131
|
-
absolute_error_jacs = {}
|
|
132
|
-
|
|
133
|
-
if show_summary:
|
|
134
|
-
summary = []
|
|
135
|
-
summary_header = ["Input", "Output", "Max Abs Error", "AD at MAE", "FD at MAE", "Max Rel Error", "Pass"]
|
|
136
|
-
|
|
137
|
-
class FontColors:
|
|
138
|
-
OKGREEN = "\033[92m"
|
|
139
|
-
WARNING = "\033[93m"
|
|
140
|
-
FAIL = "\033[91m"
|
|
141
|
-
ENDC = "\033[0m"
|
|
142
|
-
|
|
143
|
-
success = True
|
|
144
|
-
any_grad_mismatch = False
|
|
145
|
-
any_grad_nan = False
|
|
146
|
-
for (input_i, output_i), jac_fd in jacs_fd.items():
|
|
147
|
-
jac_ad = jacs_ad[input_i, output_i]
|
|
148
|
-
if plot_relative_error or plot_absolute_error:
|
|
149
|
-
jac_rel_error = wp.empty_like(jac_fd)
|
|
150
|
-
jac_abs_error = wp.empty_like(jac_fd)
|
|
151
|
-
flat_jac_fd = scalarize_array_1d(jac_fd)
|
|
152
|
-
flat_jac_ad = scalarize_array_1d(jac_ad)
|
|
153
|
-
flat_jac_rel_error = scalarize_array_1d(jac_rel_error)
|
|
154
|
-
flat_jac_abs_error = scalarize_array_1d(jac_abs_error)
|
|
155
|
-
wp.launch(
|
|
156
|
-
compute_error_kernel,
|
|
157
|
-
dim=len(flat_jac_fd),
|
|
158
|
-
inputs=[flat_jac_ad, flat_jac_fd, flat_jac_rel_error, flat_jac_abs_error],
|
|
159
|
-
device=jac_fd.device,
|
|
160
|
-
)
|
|
161
|
-
relative_error_jacs[(input_i, output_i)] = jac_rel_error
|
|
162
|
-
absolute_error_jacs[(input_i, output_i)] = jac_abs_error
|
|
163
|
-
cut_jac_fd = jac_fd.numpy()
|
|
164
|
-
cut_jac_ad = jac_ad.numpy()
|
|
165
|
-
if max_outputs_per_var > 0:
|
|
166
|
-
cut_jac_fd = cut_jac_fd[:max_outputs_per_var]
|
|
167
|
-
cut_jac_ad = cut_jac_ad[:max_outputs_per_var]
|
|
168
|
-
if max_inputs_per_var > 0:
|
|
169
|
-
cut_jac_fd = cut_jac_fd[:, :max_inputs_per_var]
|
|
170
|
-
cut_jac_ad = cut_jac_ad[:, :max_inputs_per_var]
|
|
171
|
-
grad_matches = np.allclose(cut_jac_ad, cut_jac_fd, atol=atol, rtol=rtol)
|
|
172
|
-
any_grad_mismatch = any_grad_mismatch or not grad_matches
|
|
173
|
-
success = success and grad_matches
|
|
174
|
-
isnan = np.any(np.isnan(cut_jac_ad))
|
|
175
|
-
any_grad_nan = any_grad_nan or isnan
|
|
176
|
-
success = success and not isnan
|
|
177
|
-
|
|
178
|
-
if show_summary:
|
|
179
|
-
max_abs_error = np.abs(cut_jac_ad - cut_jac_fd).max()
|
|
180
|
-
arg_max_abs_error = np.unravel_index(np.argmax(np.abs(cut_jac_ad - cut_jac_fd)), cut_jac_ad.shape)
|
|
181
|
-
max_rel_error = np.abs((cut_jac_ad - cut_jac_fd) / (cut_jac_fd + 1e-8)).max()
|
|
182
|
-
if isnan:
|
|
183
|
-
pass_str = FontColors.FAIL + "NaN" + FontColors.ENDC
|
|
184
|
-
elif grad_matches:
|
|
185
|
-
pass_str = FontColors.OKGREEN + "PASS" + FontColors.ENDC
|
|
186
|
-
else:
|
|
187
|
-
pass_str = FontColors.FAIL + "FAIL" + FontColors.ENDC
|
|
188
|
-
input_name = metadata.input_labels[input_i]
|
|
189
|
-
output_name = metadata.output_labels[output_i]
|
|
190
|
-
summary.append(
|
|
191
|
-
[
|
|
192
|
-
input_name,
|
|
193
|
-
output_name,
|
|
194
|
-
f"{max_abs_error:.3e} at {tuple(int(i) for i in arg_max_abs_error)}",
|
|
195
|
-
f"{cut_jac_ad[arg_max_abs_error]:.3e}",
|
|
196
|
-
f"{cut_jac_fd[arg_max_abs_error]:.3e}",
|
|
197
|
-
f"{max_rel_error:.3e}",
|
|
198
|
-
pass_str,
|
|
199
|
-
]
|
|
200
|
-
)
|
|
201
|
-
|
|
202
|
-
if show_summary:
|
|
203
|
-
print_table(summary_header, summary)
|
|
204
|
-
if not success:
|
|
205
|
-
print(FontColors.FAIL + f"Gradient check for kernel {metadata.key} failed" + FontColors.ENDC)
|
|
206
|
-
else:
|
|
207
|
-
print(FontColors.OKGREEN + f"Gradient check for kernel {metadata.key} passed" + FontColors.ENDC)
|
|
208
|
-
if plot_relative_error:
|
|
209
|
-
jacobian_plot(
|
|
210
|
-
relative_error_jacs,
|
|
211
|
-
metadata,
|
|
212
|
-
inputs,
|
|
213
|
-
outputs,
|
|
214
|
-
title=f"{metadata.key} kernel Jacobian relative error",
|
|
215
|
-
)
|
|
216
|
-
if plot_absolute_error:
|
|
217
|
-
jacobian_plot(
|
|
218
|
-
absolute_error_jacs,
|
|
219
|
-
metadata,
|
|
220
|
-
inputs,
|
|
221
|
-
outputs,
|
|
222
|
-
title=f"{metadata.key} kernel Jacobian absolute error",
|
|
223
|
-
)
|
|
224
|
-
|
|
225
|
-
if raise_exception:
|
|
226
|
-
if any_grad_mismatch:
|
|
227
|
-
raise ValueError(
|
|
228
|
-
f"Gradient check failed for kernel {metadata.key}, input {input_i}, output {output_i}: "
|
|
229
|
-
f"finite difference and autodiff gradients do not match"
|
|
230
|
-
)
|
|
231
|
-
if any_grad_nan:
|
|
232
|
-
raise ValueError(
|
|
233
|
-
f"Gradient check failed for kernel {metadata.key}, input {input_i}, output {output_i}: "
|
|
234
|
-
f"gradient contains NaN values"
|
|
235
|
-
)
|
|
236
|
-
|
|
237
|
-
return success
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
def gradcheck_tape(
|
|
241
|
-
tape: wp.Tape,
|
|
242
|
-
*,
|
|
243
|
-
eps=1e-4,
|
|
244
|
-
atol=1e-3,
|
|
245
|
-
rtol=1e-2,
|
|
246
|
-
raise_exception=True,
|
|
247
|
-
input_output_masks: dict[str, list[tuple[str | int, str | int]]] | None = None,
|
|
248
|
-
blacklist_kernels: list[str] | None = None,
|
|
249
|
-
whitelist_kernels: list[str] | None = None,
|
|
250
|
-
max_inputs_per_var=-1,
|
|
251
|
-
max_outputs_per_var=-1,
|
|
252
|
-
plot_relative_error=False,
|
|
253
|
-
plot_absolute_error=False,
|
|
254
|
-
show_summary: bool = True,
|
|
255
|
-
reverse_launches: bool = False,
|
|
256
|
-
skip_to_launch_index: int = 0,
|
|
257
|
-
) -> bool:
|
|
258
|
-
"""
|
|
259
|
-
Checks whether the autodiff gradients for kernels recorded on the Warp tape match finite differences.
|
|
260
|
-
|
|
261
|
-
Given the autodiff (:math:`\\nabla_\\text{AD}`) and finite difference gradients (:math:`\\nabla_\\text{FD}`), the check succeeds if the autodiff gradients contain no NaN values and the following condition holds:
|
|
262
|
-
|
|
263
|
-
.. math::
|
|
264
|
-
|
|
265
|
-
|\\nabla_\\text{AD} - \\nabla_\\text{FD}| \\leq atol + rtol \\cdot |\\nabla_\\text{FD}|.
|
|
266
|
-
|
|
267
|
-
Note:
|
|
268
|
-
Only Warp kernels recorded on the tape are checked but not arbitrary functions that have been recorded, e.g. via :meth:`Tape.record_func`.
|
|
269
|
-
|
|
270
|
-
Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.
|
|
271
|
-
|
|
272
|
-
Structs arguments are not yet supported by this function to compute Jacobians.
|
|
273
|
-
|
|
274
|
-
Args:
|
|
275
|
-
tape: The Warp tape to perform the gradient check on.
|
|
276
|
-
eps: The finite-difference step size.
|
|
277
|
-
atol: The absolute tolerance for the gradient check.
|
|
278
|
-
rtol: The relative tolerance for the gradient check.
|
|
279
|
-
raise_exception: If True, raises a `ValueError` if the gradient check fails.
|
|
280
|
-
input_output_masks: Dictionary of input-output masks for each kernel in the tape, mapping from kernel keys to input-output masks. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
|
|
281
|
-
blacklist_kernels: List of kernel keys to exclude from the gradient check.
|
|
282
|
-
whitelist_kernels: List of kernel keys to include in the gradient check. If not empty or None, only kernels in this list are checked.
|
|
283
|
-
max_inputs_per_var: Maximum number of input dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all input dimensions if value <= 0.
|
|
284
|
-
max_outputs_per_var: Maximum number of output dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all output dimensions if value <= 0.
|
|
285
|
-
plot_relative_error: If True, visualizes the relative error of the Jacobians in a plot (requires ``matplotlib``).
|
|
286
|
-
plot_absolute_error: If True, visualizes the absolute error of the Jacobians in a plot (requires ``matplotlib``).
|
|
287
|
-
show_summary: If True, prints a summary table of the gradient check results.
|
|
288
|
-
reverse_launches: If True, reverses the order of the kernel launches on the tape to check.
|
|
289
|
-
|
|
290
|
-
Returns:
|
|
291
|
-
True if the gradient check passes for all kernels on the tape, False otherwise.
|
|
292
|
-
"""
|
|
293
|
-
if input_output_masks is None:
|
|
294
|
-
input_output_masks = {}
|
|
295
|
-
if blacklist_kernels is None:
|
|
296
|
-
blacklist_kernels = []
|
|
297
|
-
else:
|
|
298
|
-
blacklist_kernels = set(blacklist_kernels)
|
|
299
|
-
if whitelist_kernels is None:
|
|
300
|
-
whitelist_kernels = []
|
|
301
|
-
else:
|
|
302
|
-
whitelist_kernels = set(whitelist_kernels)
|
|
303
|
-
|
|
304
|
-
overall_success = True
|
|
305
|
-
launches = reversed(tape.launches) if reverse_launches else tape.launches
|
|
306
|
-
for i, launch in enumerate(launches):
|
|
307
|
-
if i < skip_to_launch_index:
|
|
308
|
-
continue
|
|
309
|
-
if not isinstance(launch, tuple) and not isinstance(launch, list):
|
|
310
|
-
continue
|
|
311
|
-
if not isinstance(launch[0], wp.Kernel):
|
|
312
|
-
continue
|
|
313
|
-
kernel, dim, max_blocks, inputs, outputs, device, block_dim = launch[:7]
|
|
314
|
-
if len(whitelist_kernels) > 0 and kernel.key not in whitelist_kernels:
|
|
315
|
-
continue
|
|
316
|
-
if kernel.key in blacklist_kernels:
|
|
317
|
-
continue
|
|
318
|
-
if not kernel.options.get("enable_backward", True):
|
|
319
|
-
continue
|
|
320
|
-
|
|
321
|
-
input_output_mask = input_output_masks.get(kernel.key)
|
|
322
|
-
success = gradcheck(
|
|
323
|
-
kernel,
|
|
324
|
-
dim,
|
|
325
|
-
inputs,
|
|
326
|
-
outputs,
|
|
327
|
-
eps=eps,
|
|
328
|
-
atol=atol,
|
|
329
|
-
rtol=rtol,
|
|
330
|
-
raise_exception=raise_exception,
|
|
331
|
-
input_output_mask=input_output_mask,
|
|
332
|
-
device=device,
|
|
333
|
-
max_blocks=max_blocks,
|
|
334
|
-
block_dim=block_dim,
|
|
335
|
-
max_inputs_per_var=max_inputs_per_var,
|
|
336
|
-
max_outputs_per_var=max_outputs_per_var,
|
|
337
|
-
plot_relative_error=plot_relative_error,
|
|
338
|
-
plot_absolute_error=plot_absolute_error,
|
|
339
|
-
show_summary=show_summary,
|
|
340
|
-
)
|
|
341
|
-
overall_success = overall_success and success
|
|
342
|
-
|
|
343
|
-
return overall_success
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
def get_struct_vars(x: wp.codegen.StructInstance):
|
|
347
|
-
return {varname: getattr(x, varname) for varname, _ in x._cls.ctype._fields_}
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
def infer_device(xs: list):
    """
    Returns the best matching Warp device for a list of variables.

    The device of the first Warp array found is returned. Struct instances are searched
    for array members, and the first array member's device is used. If no array is
    found, the preferred Warp device is returned.

    Args:
        xs: Variables (arrays, struct instances, or anything else) to inspect.
    """
    for value in xs:
        if isinstance(value, wp.array):
            return value.device
        if isinstance(value, wp.codegen.StructInstance):
            member_arrays = [member for member in get_struct_vars(value).values() if isinstance(member, wp.array)]
            if member_arrays:
                return member_arrays[0].device
    return wp.get_preferred_device()
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
class FunctionMetadata:
    """
    Metadata holder for kernel functions or functions with Warp arrays as inputs/outputs.

    For every input and output argument it stores a label, the array strides and the
    array dtype. Strides and dtype are ``None`` for arguments that are not Warp arrays.
    """

    def __init__(
        self,
        key: str | None = None,
        input_labels: list[str] | None = None,
        output_labels: list[str] | None = None,
        input_strides: list[tuple] | None = None,
        output_strides: list[tuple] | None = None,
        input_dtypes: list | None = None,
        output_dtypes: list | None = None,
    ):
        self.key = key
        self.input_labels = input_labels
        self.output_labels = output_labels
        self.input_strides = input_strides
        self.output_strides = output_strides
        self.input_dtypes = input_dtypes
        self.output_dtypes = output_dtypes

    @property
    def is_empty(self):
        """True if no kernel/function has been associated with this metadata yet (``key`` is None)."""
        return self.key is None

    def input_is_array(self, i: int):
        """Returns whether the ``i``-th input argument is a Warp array."""
        return self.input_strides[i] is not None

    def output_is_array(self, i: int):
        """Returns whether the ``i``-th output argument is a Warp array."""
        return self.output_strides[i] is not None

    @staticmethod
    def _array_info(value):
        """Returns ``(strides, dtype)`` if ``value`` is a Warp array (instance or annotation), else ``(None, None)``."""
        if isinstance(value, wp.array):
            return value.strides, value.dtype
        return None, None

    def _set_array_info(self, input_values: Sequence, output_values: Sequence):
        """Fills the strides/dtype lists from the given input and output values (or argument types)."""
        input_info = [self._array_info(value) for value in input_values]
        output_info = [self._array_info(value) for value in output_values]
        self.input_strides = [strides for strides, _ in input_info]
        self.input_dtypes = [dtype for _, dtype in input_info]
        self.output_strides = [strides for strides, _ in output_info]
        self.output_dtypes = [dtype for _, dtype in output_info]

    def update_from_kernel(self, kernel: wp.Kernel, inputs: Sequence):
        """
        Fills the metadata from a Warp kernel's argument list.

        The first ``len(inputs)`` kernel arguments are treated as inputs and the
        remaining ones as outputs.
        """
        input_args = kernel.adj.args[: len(inputs)]
        output_args = kernel.adj.args[len(inputs) :]
        self.key = kernel.key
        self.input_labels = [arg.label for arg in input_args]
        self.output_labels = [arg.label for arg in output_args]
        # Array arguments are annotated with array *instances* such as ``wp.array(dtype=float)``,
        # so an isinstance test is required; an identity test against the class never matches.
        self._set_array_info([arg.type for arg in input_args], [arg.type for arg in output_args])

    def update_from_function(self, function: Callable, inputs: Sequence, outputs: Sequence | None = None):
        """
        Fills the metadata from a regular Python function and its arguments.

        If ``outputs`` is None, ``function(*inputs)`` is called to obtain them. A single
        returned Warp array is treated as a one-element output list. Outputs are
        labeled ``output_0``, ``output_1``, ...
        """
        self.key = function.__name__
        self.input_labels = list(inspect.signature(function).parameters.keys())
        if outputs is None:
            outputs = function(*inputs)
        if isinstance(outputs, wp.array):
            outputs = [outputs]
        self.output_labels = [f"output_{i}" for i in range(len(outputs))]
        self._set_array_info(inputs, outputs)
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
def jacobian_plot(
    jacobians: dict[tuple[int, int], wp.array],
    kernel: FunctionMetadata | wp.Kernel,
    inputs: Sequence | None = None,
    show_plot: bool = True,
    show_colorbar: bool = True,
    scale_colors_per_submatrix: bool = False,
    title: str | None = None,
    colormap: str = "coolwarm",
    log_scale: bool = False,
):
    """
    Visualizes the Jacobians computed by :func:`jacobian` or :func:`jacobian_fd` in a combined image plot.
    Requires the ``matplotlib`` package to be installed.

    Each Jacobian submatrix is drawn in its own subplot. Grid columns correspond to inputs and
    grid rows to outputs. Entries of vector/matrix-typed arrays are expanded into one matrix
    column per scalar component. Non-finite entries (e.g. the NaN rows/columns left by
    ``max_outputs_per_var``/``max_inputs_per_var``) are ignored when computing the color range.

    Args:
        jacobians: A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.
        kernel: The Warp kernel function, decorated with the ``@wp.kernel`` decorator, or a :class:`FunctionMetadata` instance with the kernel/function attributes.
        inputs: List of input variables. Required if ``kernel`` is a Warp kernel (used to split its arguments into inputs and outputs).
        show_plot: If True, displays the plot via ``plt.show()``.
        show_colorbar: If True, displays a colorbar next to the plot (or a colorbar next to every submatrix if ``scale_colors_per_submatrix`` is True).
        scale_colors_per_submatrix: If True, considers the minimum and maximum of each Jacobian submatrix separately for color scaling. Otherwise, uses the global minimum and maximum of all Jacobians.
        title: The title of the plot (optional). Defaults to ``"<key> kernel Jacobian"``.
        colormap: The colormap to use for the plot.
        log_scale: If True, shows ``log10(|J| + 1e-8)`` instead of the raw matrix values. The color range is computed on the transformed values.

    Returns:
        The created Matplotlib figure, or ``None`` if ``jacobians`` is empty.

    Raises:
        ValueError: If ``kernel`` is neither a Warp kernel nor a :class:`FunctionMetadata` object,
            or if ``kernel`` is a Warp kernel and ``inputs`` is None.
    """

    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator

    if isinstance(kernel, wp.Kernel):
        if inputs is None:
            raise ValueError("The list of inputs must be provided to plot the Jacobians of a Warp kernel")
        metadata = FunctionMetadata()
        metadata.update_from_kernel(kernel, inputs)
    elif isinstance(kernel, FunctionMetadata):
        metadata = kernel
    else:
        raise ValueError("Invalid kernel argument: must be a Warp kernel or a FunctionMetadata object")

    if not jacobians:
        return None

    def to_matrix(jac_wp):
        # Copy the Jacobian to a NumPy array and view it as a 2D scalar matrix.
        # Vector/matrix dtypes add trailing NumPy dimensions, which are folded into the columns.
        jac = jac_wp.numpy()
        return jac.reshape(jac.shape[0], int(np.prod(jac.shape[1:], dtype=np.int64)))

    def finite_range(values):
        # (min, max) over finite entries; (0, 0) when there are none.
        finite = values[np.isfinite(values)]
        if finite.size == 0:
            return 0.0, 0.0
        return finite.min(), finite.max()

    # Order the submatrices by output index first, then by input index.
    matrices = {}
    for key in sorted(jacobians, key=lambda k: (k[1], k[0])):
        matrix = to_matrix(jacobians[key])
        if log_scale:
            matrix = np.log10(np.abs(matrix) + 1e-8)
        matrices[key] = matrix

    # Assign grid columns to inputs and grid rows to outputs in order of first appearance.
    input_to_col = {}
    output_to_row = {}
    for input_i, output_i in matrices:
        input_to_col.setdefault(input_i, len(input_to_col))
        output_to_row.setdefault(output_i, len(output_to_row))
    num_rows = len(output_to_row)
    num_cols = len(input_to_col)

    # Size each grid row/column after the matrices it shows, indexed by grid position.
    # The minimum of 1 avoids degenerate ratios for empty Jacobians.
    width_ratios = [1] * num_cols
    height_ratios = [1] * num_rows
    has_plot = np.zeros((num_rows, num_cols), dtype=bool)
    for (input_i, output_i), matrix in matrices.items():
        row, col = output_to_row[output_i], input_to_col[input_i]
        has_plot[row, col] = True
        height_ratios[row] = max(height_ratios[row], matrix.shape[0])
        width_ratios[col] = max(width_ratios[col], matrix.shape[1])

    fig, axs = plt.subplots(
        ncols=num_cols,
        nrows=num_rows,
        figsize=(7, 7),
        sharex="col",
        sharey="row",
        gridspec_kw={
            "wspace": 0.1,
            "hspace": 0.1,
            "width_ratios": width_ratios,
            "height_ratios": height_ratios,
        },
        subplot_kw={"aspect": 1},
        squeeze=False,
    )
    if title is None:
        key = metadata.key if metadata.key is not None else "unknown"
        title = f"{key} kernel Jacobian"
    fig.suptitle(title)
    if fig.canvas.manager is not None:
        fig.canvas.manager.set_window_title(title)

    if not scale_colors_per_submatrix:
        vmin, vmax = finite_range(np.concatenate([m.ravel() for m in matrices.values()]))

    for row, col in zip(*np.nonzero(~has_plot)):
        axs[row, col].axis("off")

    for (input_i, output_i), matrix in matrices.items():
        row, col = output_to_row[output_i], input_to_col[input_i]
        ax = axs[row, col]
        ax.tick_params(which="major", width=1, length=7)
        ax.tick_params(which="minor", width=1, length=4, color="gray")
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))

        if scale_colors_per_submatrix:
            vmin, vmax = finite_range(matrix)
        img = ax.imshow(
            matrix,
            cmap=colormap,
            aspect="auto",
            interpolation="nearest",
            extent=[0, matrix.shape[1], 0, matrix.shape[0]],
            vmin=vmin,
            vmax=vmax,
        )
        if not has_plot[row + 1 :, col].any():
            # last plot of this column
            ax.set_xlabel(metadata.input_labels[input_i])
        if not has_plot[row, :col].any():
            # first plot of this row
            ax.set_ylabel(metadata.output_labels[output_i])
        ax.grid(color="gray", which="minor", linestyle="--", linewidth=0.5)
        ax.grid(color="black", which="major", linewidth=1.0)

        if show_colorbar and scale_colors_per_submatrix:
            plt.colorbar(img, ax=ax, orientation="vertical", pad=0.02)

    if show_colorbar and not scale_colors_per_submatrix:
        mappable = plt.cm.ScalarMappable(cmap=colormap)
        mappable.set_array([vmin, vmax])
        mappable.set_clim(vmin, vmax)
        plt.colorbar(mappable, ax=axs, orientation="vertical", pad=0.02)

    plt.tight_layout()
    if show_plot:
        plt.show()
    return fig
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
def scalarize_array_1d(arr):
    """
    Converts a Warp array to a 1D array with a scalar dtype.

    Scalar arrays are flattened via ``arr.flatten()``. Arrays whose dtype is in
    ``wp.types.vector_types`` are reinterpreted, from the same pointer, as
    ``arr.size * dtype._length_`` scalars of ``dtype._wp_scalar_type_``.

    Args:
        arr: The Warp array to convert.

    Returns:
        A 1D Warp array of the underlying scalar type.

    Raises:
        RuntimeError: If ``arr`` has a vector-like dtype but is not contiguous. A pointer
            reinterpretation would read the wrong elements.
        ValueError: If the dtype of ``arr`` is neither a scalar nor a vector-like type.
    """
    if arr.dtype in wp.types.scalar_types:
        return arr.flatten()
    if arr.dtype in wp.types.vector_types:
        if not arr.is_contiguous:
            raise RuntimeError("Scalarizing non-contiguous arrays is unsupported.")
        # NOTE(review): the view is built from a raw pointer and presumably does not keep
        # ``arr`` alive; callers must hold a reference to ``arr`` while using the view.
        return wp.array(
            ptr=arr.ptr,
            shape=(arr.size * arr.dtype._length_,),
            dtype=arr.dtype._wp_scalar_type_,
            device=arr.device,
        )
    raise ValueError(
        f"Unsupported array dtype {arr.dtype}: array to be flattened must be a scalar/vector/matrix array"
    )
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
def scalarize_array_2d(arr):
    """
    Converts a 2D Warp array to a 2D array with a scalar dtype.

    Scalar arrays are returned unchanged. Arrays whose dtype is in ``wp.types.vector_types``
    are reinterpreted, from the same pointer, as ``(shape[0], shape[1] * dtype._length_)``
    scalars, so every vector component becomes its own column.

    Args:
        arr: A 2D Warp array.

    Returns:
        A 2D Warp array of the underlying scalar type.

    Raises:
        AssertionError: If ``arr`` is not 2D.
        RuntimeError: If ``arr`` has a vector-like dtype but is not contiguous. A pointer
            reinterpretation would read the wrong elements.
        ValueError: If the dtype of ``arr`` is neither a scalar nor a vector-like type.
    """
    assert arr.ndim == 2
    if arr.dtype in wp.types.scalar_types:
        return arr
    if arr.dtype in wp.types.vector_types:
        if not arr.is_contiguous:
            raise RuntimeError("Scalarizing non-contiguous arrays is unsupported.")
        # NOTE(review): raw-pointer view; presumably does not keep ``arr`` alive.
        return wp.array(
            ptr=arr.ptr,
            shape=(arr.shape[0], arr.shape[1] * arr.dtype._length_),
            dtype=arr.dtype._wp_scalar_type_,
            device=arr.device,
        )
    raise ValueError(
        f"Unsupported array dtype {arr.dtype}: array to be flattened must be a scalar/vector/matrix array"
    )
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
def jacobian(
    function: wp.Kernel | Callable,
    dim: tuple[int] | None = None,
    inputs: Sequence | None = None,
    outputs: Sequence | None = None,
    input_output_mask: list[tuple[str | int, str | int]] | None = None,
    device: wp.context.Devicelike = None,
    max_blocks=0,
    block_dim=256,
    max_outputs_per_var=-1,
    plot_jacobians=False,
    metadata: FunctionMetadata | None = None,
) -> dict[tuple[int, int], wp.array]:
    """
    Computes the Jacobians of a function or Warp kernel for the provided selection of differentiable inputs to differentiable outputs.

    The input function can be either a Warp kernel (e.g. a function decorated by ``@wp.kernel``) or a regular Python function that accepts arguments (of which some must be Warp arrays) and returns a Warp array or a list of Warp arrays.

    In case ``function`` is a Warp kernel, its adjoint kernel is launched with the given inputs and outputs, as well as the provided ``dim``,
    ``max_blocks``, and ``block_dim`` arguments (see :func:`warp.launch` for more details).

    Note:
        If ``function`` is a Warp kernel, the input arguments must precede the output arguments in the kernel code definition.

        Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.

        Function arguments of type :ref:`Struct <structs>` are not yet supported.

        The gradients of the inputs and outputs (and of every array recorded on the internal tape) are zeroed by this function.

    Args:
        function: The Warp kernel function, or a regular Python function that returns a Warp array or a list of Warp arrays.
        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints. Only required if ``function`` is a Warp kernel.
        inputs: List of input variables. At least one of the arguments must be a Warp array with ``requires_grad=True``.
        outputs: List of output variables. Optional if the function is a regular Python function that returns a Warp array or a list of Warp arrays. Only required if ``function`` is a Warp kernel.
        input_output_mask: List of tuples specifying the input-output pairs to compute the Jacobian for. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
        device: The device to launch on (optional). Only used if ``function`` is a Warp kernel.
        max_blocks: The maximum number of CUDA thread blocks to use. Only used if ``function`` is a Warp kernel.
        block_dim: The number of threads per block. Only used if ``function`` is a Warp kernel.
        max_outputs_per_var: Maximum number of output dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all output dimensions if value <= 0.
        plot_jacobians: If True, visualizes the computed Jacobians in a plot (requires ``matplotlib``).
        metadata: The metadata of the kernel function, containing the input and output labels, strides, and dtypes. If None or empty, the metadata is inferred from the kernel or function.

    Returns:
        A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.

    Raises:
        ValueError: If a kernel has its backward pass disabled, if no outputs are given for a kernel,
            or if ``input_output_mask`` names an unknown argument.
    """
    inputs = [] if inputs is None else list(inputs)
    if metadata is None:
        metadata = FunctionMetadata()

    if isinstance(function, wp.Kernel):
        if not function.options.get("enable_backward", True):
            raise ValueError("Kernel must have backward pass enabled to compute Jacobians")
        if outputs is None or len(outputs) == 0:
            raise ValueError("A list of output arguments must be provided to compute kernel Jacobians")
        outputs = list(outputs)
        if device is None:
            device = infer_device(inputs + outputs)
        if metadata.is_empty:
            metadata.update_from_kernel(function, inputs)

        # Record the launch on a tape so that tape.backward() launches the kernel's adjoint.
        tape = wp.Tape()
        tape.record_launch(
            kernel=function,
            dim=dim,
            inputs=inputs,
            outputs=outputs,
            device=device,
            max_blocks=max_blocks,
            block_dim=block_dim,
        )
    else:
        tape = wp.Tape()
        with tape:
            outputs = function(*inputs)
        outputs = [outputs] if isinstance(outputs, wp.array) else list(outputs)
        if metadata.is_empty:
            metadata.update_from_function(function, inputs, outputs)

    def resolve_arg(name, labels, kind):
        # Integer identifiers are already indices; names are looked up among the input
        # labels or the output labels respectively.
        if isinstance(name, int):
            return name
        try:
            return labels.index(name)
        except ValueError:
            raise ValueError(f"Unknown {kind} argument {name!r}, expected one of {labels}") from None

    mask = {
        (
            resolve_arg(input_name, metadata.input_labels, "input"),
            resolve_arg(output_name, metadata.output_labels, "output"),
        )
        for input_name, output_name in (input_output_mask or [])
    }

    def reset_gradients():
        # tape.zero() also clears the adjoints of intermediate arrays recorded by
        # multi-launch Python functions; otherwise they would accumulate across rows.
        tape.zero()
        zero_grads(inputs)
        zero_grads(outputs)

    reset_gradients()

    jacobians = {}

    for input_i, output_i in itertools.product(range(len(inputs)), range(len(outputs))):
        if mask and (input_i, output_i) not in mask:
            continue
        inp = inputs[input_i]
        out = outputs[output_i]
        if not isinstance(inp, wp.array) or not inp.requires_grad:
            continue
        if not isinstance(out, wp.array) or not out.requires_grad:
            continue

        out_grad = scalarize_array_1d(out.grad)
        output_num = out_grad.shape[0]
        jac = wp.empty((output_num, inp.size), dtype=inp.dtype, device=inp.device)
        # rows beyond max_outputs_per_var stay NaN
        jac.fill_(wp.nan)
        if max_outputs_per_var > 0:
            output_num = min(output_num, max_outputs_per_var)
        for i in range(output_num):
            # seed a one-hot output adjoint; the resulting input gradient is row i of the Jacobian
            reset_gradients()
            set_element(out_grad, i, 1.0)
            tape.backward()
            jac[i].assign(inp.grad)

        reset_gradients()
        jacobians[input_i, output_i] = jac

    if plot_jacobians:
        jacobian_plot(
            jacobians,
            metadata,
            inputs,
            title=f"{metadata.key} kernel Jacobian",
        )

    return jacobians
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
def jacobian_fd(
    function: wp.Kernel | Callable,
    dim: tuple[int] | None = None,
    inputs: Sequence | None = None,
    outputs: Sequence | None = None,
    input_output_mask: list[tuple[str | int, str | int]] | None = None,
    device: wp.context.Devicelike = None,
    max_blocks=0,
    block_dim=256,
    max_inputs_per_var=-1,
    eps: float = 1e-4,
    plot_jacobians=False,
    metadata: FunctionMetadata | None = None,
) -> dict[tuple[int, int], wp.array]:
    """
    Computes the finite-difference Jacobian of a function or Warp kernel for the provided selection of differentiable inputs to differentiable outputs.
    The method uses a central difference scheme to approximate the Jacobian.

    The input function can be either a Warp kernel (e.g. a function decorated by ``@wp.kernel``) or a regular Python function that accepts arguments (of which some must be Warp arrays) and returns a Warp array or a list of Warp arrays.

    The function is launched multiple times in forward-only mode with the given inputs. If ``function`` is a Warp kernel, the provided inputs and outputs,
    as well as the other parameters ``dim``, ``max_blocks``, and ``block_dim`` are provided to the kernel launch (see :func:`warp.launch`).

    Note:
        If ``function`` is a Warp kernel, the input arguments must precede the output arguments in the kernel code definition.

        Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.

        Function arguments of type :ref:`Struct <structs>` are not yet supported.

        The inputs are perturbed in place and restored after each evaluation. When ``function`` is a
        Warp kernel, it writes into copies of the outputs, so the caller's output arrays are not modified.

    Args:
        function: The Warp kernel function, or a regular Python function that returns a Warp array or a list of Warp arrays.
        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints. Only required if ``function`` is a Warp kernel.
        inputs: List of input variables. At least one of the arguments must be a Warp array with ``requires_grad=True``.
        outputs: List of output variables. Optional if the function is a regular Python function that returns a Warp array or a list of Warp arrays. Only required if ``function`` is a Warp kernel.
        input_output_mask: List of tuples specifying the input-output pairs to compute the Jacobian for. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
        device: The device to launch on (optional). Only used if ``function`` is a Warp kernel.
        max_blocks: The maximum number of CUDA thread blocks to use. Only used if ``function`` is a Warp kernel.
        block_dim: The number of threads per block. Only used if ``function`` is a Warp kernel.
        max_inputs_per_var: Maximum number of input dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all input dimensions if value <= 0.
        eps: The finite-difference step size.
        plot_jacobians: If True, visualizes the computed Jacobians in a plot (requires ``matplotlib``).
        metadata: The metadata of the kernel function, containing the input and output labels, strides, and dtypes. If None or empty, the metadata is inferred from the kernel or function.

    Returns:
        A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.

    Raises:
        ValueError: If no outputs are given for a kernel, or if ``input_output_mask`` names an unknown argument.
    """
    inputs = [] if inputs is None else list(inputs)
    if metadata is None:
        metadata = FunctionMetadata()

    is_kernel = isinstance(function, wp.Kernel)

    def as_output_list(results):
        # Python functions may return a single array or a sequence of arrays.
        return [results] if isinstance(results, wp.array) else list(results)

    if is_kernel:
        # Finite differences only run the forward pass, so the backward pass need not be enabled.
        if outputs is None or len(outputs) == 0:
            raise ValueError("A list of output arguments must be provided to compute kernel Jacobians")
        outputs = list(outputs)
        if device is None:
            device = infer_device(inputs + outputs)
        if metadata.is_empty:
            metadata.update_from_kernel(function, inputs)
    else:
        outputs = as_output_list(function(*inputs))
        if metadata.is_empty:
            metadata.update_from_function(function, inputs, outputs)

    def resolve_arg(name, labels, kind):
        # Integer identifiers are already indices; names are looked up among the input
        # labels or the output labels respectively.
        if isinstance(name, int):
            return name
        try:
            return labels.index(name)
        except ValueError:
            raise ValueError(f"Unknown {kind} argument {name!r}, expected one of {labels}") from None

    mask = {
        (
            resolve_arg(input_name, metadata.input_labels, "input"),
            resolve_arg(output_name, metadata.output_labels, "output"),
        )
        for input_name, output_name in (input_output_mask or [])
    }

    def conditional_clone(obj):
        if isinstance(obj, wp.array):
            return wp.clone(obj)
        return obj

    def evaluate(output_index, target, target_outputs):
        # Runs the function once on the current (perturbed) inputs and leaves output
        # ``output_index`` in ``target``. For kernels, ``target`` is the entry of
        # ``target_outputs`` at ``output_index`` and is written by the launch.
        if is_kernel:
            wp.launch(
                function,
                dim=dim,
                max_blocks=max_blocks,
                block_dim=block_dim,
                inputs=inputs,
                outputs=target_outputs,
                device=device,
            )
        else:
            target.assign(as_output_list(function(*inputs))[output_index])

    # Snapshot of the original outputs; per-pair output buffers are cloned from it.
    outputs_copy = [conditional_clone(out) for out in outputs]

    jacobians = {}

    for input_i, output_i in itertools.product(range(len(inputs)), range(len(outputs))):
        if mask and (input_i, output_i) not in mask:
            continue
        inp = inputs[input_i]
        out = outputs[output_i]
        if not isinstance(inp, wp.array) or not inp.requires_grad:
            continue
        if not isinstance(out, wp.array) or not out.requires_grad:
            continue

        flat_input = scalarize_array_1d(inp)
        flat_input_copy = wp.clone(flat_input)

        # Buffers for the -eps / +eps evaluations. ``output_copy`` restores their initial
        # contents so kernels that accumulate into outputs start from the same state each time.
        left = wp.clone(out)
        right = wp.clone(out)
        output_copy = wp.clone(out)
        flat_left = scalarize_array_1d(left)
        flat_right = scalarize_array_1d(right)

        left_outputs = [left if k == output_i else conditional_clone(o) for k, o in enumerate(outputs_copy)]
        right_outputs = [right if k == output_i else conditional_clone(o) for k, o in enumerate(outputs_copy)]

        jac = wp.empty((flat_left.size, inp.size), dtype=inp.dtype, device=inp.device)
        # columns beyond max_inputs_per_var stay NaN
        jac.fill_(wp.nan)
        # row i of the transposed scalar Jacobian is its i-th column (derivative w.r.t. input scalar i)
        jac_t = scalarize_array_2d(jac).transpose()

        input_num = flat_input.shape[0]
        if max_inputs_per_var > 0:
            input_num = min(input_num, max_inputs_per_var)
        for i in range(input_num):
            try:
                set_element(flat_input, i, -eps, relative=True)
                evaluate(output_i, left, left_outputs)
                set_element(flat_input, i, 2 * eps, relative=True)
                evaluate(output_i, right, right_outputs)
            finally:
                # always restore the caller's input, even if an evaluation fails
                flat_input.assign(flat_input_copy)

            compute_fd(flat_left, flat_right, eps, jac_t[i])

            if i < input_num - 1:
                # reset output buffers
                left.assign(output_copy)
                right.assign(output_copy)

        jacobians[input_i, output_i] = jac

    if plot_jacobians:
        jacobian_plot(
            jacobians,
            metadata,
            inputs,
            title=f"{metadata.key} kernel Jacobian",
        )

    return jacobians
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
# Single-thread kernel that writes ``val`` to ``a[i]``, or adds it to ``a[i]`` when ``relative`` is set.
@wp.kernel(enable_backward=False)
def set_element_kernel(a: wp.array(dtype=Any), i: int, val: Any, relative: bool):
    if not relative:
        a[i] = val
    else:
        a[i] += val
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
def set_element(a: wp.array(dtype=Any), i: int, val: Any, relative: bool = False):
    """
    Sets (or, if ``relative`` is True, increments) element ``i`` of the array ``a`` on its device.

    ``val`` is converted to ``a.dtype`` before launching ``set_element_kernel`` with a single thread.
    """
    typed_value = a.dtype(val)
    launch_args = [a, i, typed_value, relative]
    wp.launch(set_element_kernel, dim=1, inputs=launch_args, device=a.device)
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
# Central finite difference fd = (right - left) / (2 * eps), one element per thread.
# Generic over the scalar type so float64 (and float16) Jacobians work, not only float32.
# ``eps`` is cast to the array scalar type, as compute_error_kernel does for its threshold.
@wp.kernel(enable_backward=False)
def compute_fd_kernel(left: wp.array(dtype=Any), right: wp.array(dtype=Any), eps: float, fd: wp.array(dtype=Any)):
    tid = wp.tid()
    diff = right[tid] - left[tid]
    fd[tid] = diff / (type(diff)(2.0) * type(diff)(eps))
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
def compute_fd(left: wp.array(dtype=Any), right: wp.array(dtype=Any), eps: float, fd: wp.array(dtype=Any)):
    """
    Writes the central finite difference ``(right - left) / (2 * eps)`` into ``fd``.

    One thread is launched per element of ``left``, on ``left``'s device.
    """
    num_elements = len(left)
    wp.launch(
        compute_fd_kernel,
        dim=num_elements,
        inputs=[left, right, eps],
        outputs=[fd],
        device=left.device,
    )
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
# Per-element comparison of an autodiff Jacobian against a finite-difference one.
# relative_error = (ad - fd) / ad, with the denominator replaced by +1e-8 when |ad| < 1e-8.
# NOTE(review): the replacement denominator is always positive, so the sign of the relative
# error differs for tiny negative ``ad`` values. Confirm callers only use its magnitude.
# absolute_error = |ad - fd|
@wp.kernel(enable_backward=False)
def compute_error_kernel(
    jacobian_ad: wp.array(dtype=Any),
    jacobian_fd: wp.array(dtype=Any),
    relative_error: wp.array(dtype=Any),
    absolute_error: wp.array(dtype=Any),
):
    tid = wp.tid()
    ad_value = jacobian_ad[tid]
    fd_value = jacobian_fd[tid]
    difference = ad_value - fd_value
    denom = ad_value
    if abs(ad_value) < 1e-8:
        denom = (type(ad_value))(1e-8)
    relative_error[tid] = difference / denom
    absolute_error[tid] = wp.abs(difference)
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
def print_table(headers, cells):
    """
    Prints a table with the given headers and cells.

    Column widths use the *visible* length of each cell: ANSI color escape sequences
    (including multi-parameter ones such as bold red) are ignored. Colored string cells
    are padded to the same visible width, so the columns stay aligned.

    Args:
        headers: List of header strings.
        cells: List of lists of cell strings. Non-string cells are printed with their
            default format spec for the column width.
    """
    import re

    ansi_escape = re.compile(r"\033\[[\d;]*m")

    def sanitized_len(s):
        return len(ansi_escape.sub("", str(s)))

    def format_cell(cell, col_width):
        # Format-spec padding counts the invisible escape characters too, so widen the
        # field by their number to pad colored strings to the intended visible width.
        if isinstance(cell, str):
            col_width += len(cell) - sanitized_len(cell)
        return f"{cell:{col_width}}"

    col_widths = [max(sanitized_len(cell) for cell in col) for col in zip(headers, *cells)]
    for header, col_width in zip(headers, col_widths):
        print(format_cell(header, col_width), end=" | ")
    print()
    print("-" * (sum(col_widths) + 3 * len(col_widths) - 1))
    for cell_row in cells:
        for cell, col_width in zip(cell_row, col_widths):
            print(format_cell(cell, col_width), end=" | ")
        print()
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
def zero_grads(arrays: list):
    """
    Zeros the gradients of all Warp arrays in the given list.

    Entries that are not Warp arrays, or arrays with ``requires_grad=False``, are skipped.
    """
    arrays_with_grad = (arr for arr in arrays if isinstance(arr, wp.array) and arr.requires_grad)
    for arr in arrays_with_grad:
        arr.grad.zero_()
|
|
33
|
+
return get_deprecated_api(_autograd, "wp", name)
|