warp-lang 1.2.2__py3-none-win_amd64.whl → 1.3.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +8 -6
- warp/autograd.py +823 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +6 -2
- warp/builtins.py +1410 -886
- warp/codegen.py +503 -166
- warp/config.py +48 -18
- warp/context.py +400 -198
- warp/dlpack.py +8 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
- warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
- warp/examples/benchmarks/benchmark_launches.py +1 -1
- warp/examples/core/example_cupy.py +78 -0
- warp/examples/fem/example_apic_fluid.py +17 -36
- warp/examples/fem/example_burgers.py +9 -18
- warp/examples/fem/example_convection_diffusion.py +7 -17
- warp/examples/fem/example_convection_diffusion_dg.py +27 -47
- warp/examples/fem/example_deformed_geometry.py +11 -22
- warp/examples/fem/example_diffusion.py +7 -18
- warp/examples/fem/example_diffusion_3d.py +24 -28
- warp/examples/fem/example_diffusion_mgpu.py +7 -14
- warp/examples/fem/example_magnetostatics.py +190 -0
- warp/examples/fem/example_mixed_elasticity.py +111 -80
- warp/examples/fem/example_navier_stokes.py +30 -34
- warp/examples/fem/example_nonconforming_contact.py +290 -0
- warp/examples/fem/example_stokes.py +17 -32
- warp/examples/fem/example_stokes_transfer.py +12 -21
- warp/examples/fem/example_streamlines.py +350 -0
- warp/examples/fem/utils.py +936 -0
- warp/fabric.py +5 -2
- warp/fem/__init__.py +13 -3
- warp/fem/cache.py +161 -11
- warp/fem/dirichlet.py +37 -28
- warp/fem/domain.py +105 -14
- warp/fem/field/__init__.py +14 -3
- warp/fem/field/field.py +454 -11
- warp/fem/field/nodal_field.py +33 -18
- warp/fem/geometry/deformed_geometry.py +50 -15
- warp/fem/geometry/hexmesh.py +12 -24
- warp/fem/geometry/nanogrid.py +106 -31
- warp/fem/geometry/quadmesh_2d.py +6 -11
- warp/fem/geometry/tetmesh.py +103 -61
- warp/fem/geometry/trimesh_2d.py +98 -47
- warp/fem/integrate.py +231 -186
- warp/fem/operator.py +14 -9
- warp/fem/quadrature/pic_quadrature.py +35 -9
- warp/fem/quadrature/quadrature.py +119 -32
- warp/fem/space/basis_space.py +98 -22
- warp/fem/space/collocated_function_space.py +3 -1
- warp/fem/space/function_space.py +7 -2
- warp/fem/space/grid_2d_function_space.py +3 -3
- warp/fem/space/grid_3d_function_space.py +4 -4
- warp/fem/space/hexmesh_function_space.py +3 -2
- warp/fem/space/nanogrid_function_space.py +12 -14
- warp/fem/space/partition.py +45 -47
- warp/fem/space/restriction.py +19 -16
- warp/fem/space/shape/cube_shape_function.py +91 -3
- warp/fem/space/shape/shape_function.py +7 -0
- warp/fem/space/shape/square_shape_function.py +32 -0
- warp/fem/space/shape/tet_shape_function.py +11 -7
- warp/fem/space/shape/triangle_shape_function.py +10 -1
- warp/fem/space/topology.py +116 -42
- warp/fem/types.py +8 -1
- warp/fem/utils.py +301 -83
- warp/native/array.h +16 -0
- warp/native/builtin.h +0 -15
- warp/native/cuda_util.cpp +14 -6
- warp/native/exports.h +1348 -1308
- warp/native/quat.h +79 -0
- warp/native/rand.h +27 -4
- warp/native/sparse.cpp +83 -81
- warp/native/sparse.cu +381 -453
- warp/native/vec.h +64 -0
- warp/native/volume.cpp +40 -49
- warp/native/volume_builder.cu +2 -3
- warp/native/volume_builder.h +12 -17
- warp/native/warp.cu +3 -3
- warp/native/warp.h +69 -59
- warp/render/render_opengl.py +17 -9
- warp/sim/articulation.py +117 -17
- warp/sim/collide.py +35 -29
- warp/sim/model.py +123 -18
- warp/sim/render.py +3 -1
- warp/sparse.py +867 -203
- warp/stubs.py +312 -541
- warp/tape.py +29 -1
- warp/tests/disabled_kinematics.py +1 -1
- warp/tests/test_adam.py +1 -1
- warp/tests/test_arithmetic.py +1 -1
- warp/tests/test_array.py +58 -1
- warp/tests/test_array_reduce.py +1 -1
- warp/tests/test_async.py +1 -1
- warp/tests/test_atomic.py +1 -1
- warp/tests/test_bool.py +1 -1
- warp/tests/test_builtins_resolution.py +1 -1
- warp/tests/test_bvh.py +6 -1
- warp/tests/test_closest_point_edge_edge.py +1 -1
- warp/tests/test_codegen.py +66 -1
- warp/tests/test_compile_consts.py +1 -1
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_copy.py +1 -1
- warp/tests/test_ctypes.py +1 -1
- warp/tests/test_dense.py +1 -1
- warp/tests/test_devices.py +1 -1
- warp/tests/test_dlpack.py +1 -1
- warp/tests/test_examples.py +33 -4
- warp/tests/test_fabricarray.py +5 -2
- warp/tests/test_fast_math.py +1 -1
- warp/tests/test_fem.py +213 -6
- warp/tests/test_fp16.py +1 -1
- warp/tests/test_func.py +1 -1
- warp/tests/test_future_annotations.py +90 -0
- warp/tests/test_generics.py +1 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_grad_customs.py +1 -1
- warp/tests/test_grad_debug.py +247 -0
- warp/tests/test_hash_grid.py +6 -1
- warp/tests/test_implicit_init.py +354 -0
- warp/tests/test_import.py +1 -1
- warp/tests/test_indexedarray.py +1 -1
- warp/tests/test_intersect.py +1 -1
- warp/tests/test_jax.py +1 -1
- warp/tests/test_large.py +1 -1
- warp/tests/test_launch.py +1 -1
- warp/tests/test_lerp.py +1 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_lvalue.py +1 -1
- warp/tests/test_marching_cubes.py +5 -2
- warp/tests/test_mat.py +34 -35
- warp/tests/test_mat_lite.py +2 -1
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_math.py +1 -1
- warp/tests/test_matmul.py +20 -16
- warp/tests/test_matmul_lite.py +1 -1
- warp/tests/test_mempool.py +1 -1
- warp/tests/test_mesh.py +5 -2
- warp/tests/test_mesh_query_aabb.py +1 -1
- warp/tests/test_mesh_query_point.py +1 -1
- warp/tests/test_mesh_query_ray.py +1 -1
- warp/tests/test_mlp.py +1 -1
- warp/tests/test_model.py +1 -1
- warp/tests/test_module_hashing.py +77 -1
- warp/tests/test_modules_lite.py +1 -1
- warp/tests/test_multigpu.py +1 -1
- warp/tests/test_noise.py +1 -1
- warp/tests/test_operators.py +1 -1
- warp/tests/test_options.py +1 -1
- warp/tests/test_overwrite.py +542 -0
- warp/tests/test_peer.py +1 -1
- warp/tests/test_pinned.py +1 -1
- warp/tests/test_print.py +1 -1
- warp/tests/test_quat.py +15 -1
- warp/tests/test_rand.py +1 -1
- warp/tests/test_reload.py +1 -1
- warp/tests/test_rounding.py +1 -1
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +95 -0
- warp/tests/test_sim_grad.py +1 -1
- warp/tests/test_sim_kinematics.py +1 -1
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +82 -15
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_special_values.py +2 -11
- warp/tests/test_streams.py +11 -1
- warp/tests/test_struct.py +1 -1
- warp/tests/test_tape.py +1 -1
- warp/tests/test_torch.py +194 -1
- warp/tests/test_transient_module.py +1 -1
- warp/tests/test_types.py +1 -1
- warp/tests/test_utils.py +1 -1
- warp/tests/test_vec.py +15 -63
- warp/tests/test_vec_lite.py +2 -1
- warp/tests/test_vec_scalar_ops.py +65 -1
- warp/tests/test_verify_fp.py +1 -1
- warp/tests/test_volume.py +28 -2
- warp/tests/test_volume_write.py +1 -1
- warp/tests/unittest_serial.py +1 -1
- warp/tests/unittest_suites.py +9 -1
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +2 -5
- warp/torch.py +103 -41
- warp/types.py +341 -224
- warp/utils.py +11 -2
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/METADATA +99 -46
- warp_lang-1.3.0.dist-info/RECORD +368 -0
- warp/examples/fem/bsr_utils.py +0 -378
- warp/examples/fem/mesh_utils.py +0 -133
- warp/examples/fem/plot_utils.py +0 -292
- warp_lang-1.2.2.dist-info/RECORD +0 -359
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/top_level.txt +0 -0
warp/autograd.py
ADDED
|
@@ -0,0 +1,823 @@
|
|
|
1
|
+
# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
3
|
+
# and proprietary rights in and to this software, related documentation
|
|
4
|
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
5
|
+
# distribution of this software and related documentation without an express
|
|
6
|
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
|
+
|
|
8
|
+
import itertools
|
|
9
|
+
from typing import Any, Dict, List, Sequence, Tuple, Union
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
import warp as wp
|
|
14
|
+
|
|
15
|
+
# Public API of this module; every other helper defined here is considered internal.
__all__ = ["jacobian", "jacobian_fd", "gradcheck", "gradcheck_tape", "plot_kernel_jacobians"]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def gradcheck(
    function: wp.Kernel,
    dim: Tuple[int],
    inputs: Sequence,
    outputs: Sequence,
    *,
    eps=1e-4,
    atol=1e-3,
    rtol=1e-2,
    raise_exception=True,
    input_output_mask: List[Tuple[Union[str, int], Union[str, int]]] = None,
    device: wp.context.Devicelike = None,
    max_blocks=0,
    max_inputs_per_var=-1,
    max_outputs_per_var=-1,
    plot_relative_error=False,
    plot_absolute_error=False,
    show_summary: bool = True,
) -> bool:
    """
    Checks whether the autodiff gradient of a Warp kernel matches finite differences.
    Fails if the relative or absolute errors between the autodiff and finite difference gradients exceed the specified tolerance, or if the autodiff gradients contain NaN values.

    The kernel function and its adjoint version are launched with the given inputs and outputs, as well as the provided ``dim`` and ``max_blocks`` arguments (see :func:`warp.launch` for more details).

    Note:
        This function only supports Warp kernels whose input arguments precede the output arguments.

        Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.

        Struct arguments are not yet supported by this function to compute Jacobians.

    Args:
        function: The Warp kernel function, decorated with the ``@wp.kernel`` decorator.
        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints.
        inputs: List of input variables.
        outputs: List of output variables.
        eps: The finite-difference step size.
        atol: The absolute tolerance for the gradient check.
        rtol: The relative tolerance for the gradient check.
        raise_exception: If True, raises a `ValueError` for the first input-output pair whose gradient check fails.
        input_output_mask: List of tuples specifying the input-output pairs to compute the Jacobian for. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
        device: The device to launch on (optional)
        max_blocks: The maximum number of CUDA thread blocks to use.
        max_inputs_per_var: Maximum number of input dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all input dimensions if value <= 0.
        max_outputs_per_var: Maximum number of output dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all output dimensions if value <= 0.
        plot_relative_error: If True, visualizes the relative error of the Jacobians in a plot (requires ``matplotlib``).
        plot_absolute_error: If True, visualizes the absolute error of the Jacobians in a plot (requires ``matplotlib``).
        show_summary: If True, prints a summary table of the gradient check results.

    Returns:
        True if the gradient check passes, False otherwise.

    Raises:
        ValueError: If ``raise_exception`` is True and the gradients of an input-output pair contain NaNs or do not match.
    """

    assert isinstance(function, wp.Kernel), "The function argument must be a Warp kernel"

    jacs_fd = jacobian_fd(
        function,
        dim=dim,
        inputs=inputs,
        outputs=outputs,
        input_output_mask=input_output_mask,
        device=device,
        max_blocks=max_blocks,
        max_inputs_per_var=max_inputs_per_var,
        eps=eps,
        plot_jacobians=False,
    )

    jacs_ad = jacobian(
        function,
        dim=dim,
        inputs=inputs,
        outputs=outputs,
        input_output_mask=input_output_mask,
        device=device,
        max_blocks=max_blocks,
        max_outputs_per_var=max_outputs_per_var,
        plot_jacobians=False,
    )

    relative_error_jacs = {}
    absolute_error_jacs = {}

    if show_summary:
        summary = []
        summary_header = ["Input", "Output", "Max Abs Error", "Max Rel Error", "Pass"]

    class FontColors:
        OKGREEN = "\033[92m"
        WARNING = "\033[93m"
        FAIL = "\033[91m"
        ENDC = "\033[0m"

    success = True
    for (input_i, output_i), jac_fd in jacs_fd.items():
        jac_ad = jacs_ad[input_i, output_i]
        if plot_relative_error or plot_absolute_error:
            jac_rel_error = wp.empty_like(jac_fd)
            jac_abs_error = wp.empty_like(jac_fd)
            flat_jac_fd = scalarize_array_1d(jac_fd)
            flat_jac_ad = scalarize_array_1d(jac_ad)
            flat_jac_rel_error = scalarize_array_1d(jac_rel_error)
            flat_jac_abs_error = scalarize_array_1d(jac_abs_error)
            wp.launch(
                compute_error_kernel,
                dim=len(flat_jac_fd),
                inputs=[flat_jac_ad, flat_jac_fd, flat_jac_rel_error, flat_jac_abs_error],
                device=jac_fd.device,
            )
            relative_error_jacs[(input_i, output_i)] = jac_rel_error
            absolute_error_jacs[(input_i, output_i)] = jac_abs_error

        # jacobian_fd() only honors max_inputs_per_var and jacobian() only max_outputs_per_var,
        # so crop both to the same region before comparing them.
        cut_jac_fd = jac_fd.numpy()
        cut_jac_ad = jac_ad.numpy()
        if max_outputs_per_var > 0:
            cut_jac_fd = cut_jac_fd[:max_outputs_per_var]
            cut_jac_ad = cut_jac_ad[:max_outputs_per_var]
        if max_inputs_per_var > 0:
            cut_jac_fd = cut_jac_fd[:, :max_inputs_per_var]
            cut_jac_ad = cut_jac_ad[:, :max_inputs_per_var]

        # Check for NaNs before the tolerance test: np.allclose() is always False when NaNs are
        # present, so a NaN-contaminated gradient would otherwise be reported as a plain mismatch.
        isnan = bool(np.any(np.isnan(cut_jac_ad)))
        grad_matches = bool(np.allclose(cut_jac_ad, cut_jac_fd, atol=atol, rtol=rtol))
        success = success and grad_matches and not isnan

        if isnan or not grad_matches:
            if raise_exception:
                if isnan:
                    reason = "gradient contains NaN values"
                else:
                    reason = "finite difference and autodiff gradients do not match"
                raise ValueError(
                    f"Gradient check failed for kernel {function.key}, input {input_i}, output {output_i}: {reason}"
                )
            if not show_summary:
                # nothing left to report, so stop at the first failure
                return False

        if show_summary:
            if cut_jac_ad.size > 0:
                abs_diff = np.abs(cut_jac_ad - cut_jac_fd)
                max_abs_error = abs_diff.max()
                # Normalize by |fd| so that small negative reference values cannot cancel the
                # epsilon and cause a division by (near) zero.
                max_rel_error = (abs_diff / (np.abs(cut_jac_fd) + 1e-8)).max()
            else:
                # ndarray.max() raises on empty arrays
                max_abs_error = max_rel_error = 0.0
            if isnan:
                pass_str = FontColors.FAIL + "NaN" + FontColors.ENDC
            elif grad_matches:
                pass_str = FontColors.OKGREEN + "PASS" + FontColors.ENDC
            else:
                pass_str = FontColors.FAIL + "FAIL" + FontColors.ENDC
            # kernel arguments are laid out as all inputs followed by all outputs
            input_name = function.adj.args[input_i].label
            output_name = function.adj.args[len(inputs) + output_i].label
            summary.append([input_name, output_name, f"{max_abs_error:.7e}", f"{max_rel_error:.7e}", pass_str])

    if show_summary:
        print_table(summary_header, summary)
        if not success:
            print(FontColors.FAIL + f"Gradient check for kernel {function.key} failed" + FontColors.ENDC)
        else:
            print(FontColors.OKGREEN + f"Gradient check for kernel {function.key} passed" + FontColors.ENDC)
    if plot_relative_error:
        plot_kernel_jacobians(
            relative_error_jacs,
            function,
            inputs,
            outputs,
            title=f"{function.key} kernel Jacobian relative error",
        )
    if plot_absolute_error:
        plot_kernel_jacobians(
            absolute_error_jacs,
            function,
            inputs,
            outputs,
            title=f"{function.key} kernel Jacobian absolute error",
        )

    return success
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def gradcheck_tape(
    tape: wp.Tape,
    *,
    eps=1e-4,
    atol=1e-3,
    rtol=1e-2,
    raise_exception=True,
    input_output_masks: Dict[str, List[Tuple[Union[str, int], Union[str, int]]]] = None,
    blacklist_kernels: List[str] = None,
    whitelist_kernels: List[str] = None,
    max_inputs_per_var=-1,
    max_outputs_per_var=-1,
    plot_relative_error=False,
    plot_absolute_error=False,
    show_summary: bool = True,
) -> bool:
    """
    Checks whether the autodiff gradients for kernels recorded on the Warp tape match finite differences.
    Fails if the relative or absolute errors between the autodiff and finite difference gradients exceed the specified tolerance, or if the autodiff gradients contain NaN values.

    Each recorded kernel launch is re-checked via :func:`gradcheck` using the launch dimensions,
    ``max_blocks`` setting, inputs, outputs and device that were recorded on the tape.

    Note:
        Only Warp kernels recorded on the tape are checked but not arbitrary functions that have been recorded, e.g. via :meth:`Tape.record_func`.

        Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.

        Struct arguments are not yet supported by this function to compute Jacobians.

    Args:
        tape: The Warp tape to perform the gradient check on.
        eps: The finite-difference step size.
        atol: The absolute tolerance for the gradient check.
        rtol: The relative tolerance for the gradient check.
        raise_exception: If True, raises a `ValueError` if the gradient check fails.
        input_output_masks: Dictionary of input-output masks for each kernel in the tape, mapping from kernel keys to input-output masks. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
        blacklist_kernels: List of kernel keys to exclude from the gradient check.
        whitelist_kernels: List of kernel keys to include in the gradient check. If neither empty nor None, only kernels in this list are checked.
        max_inputs_per_var: Maximum number of input dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all input dimensions if value <= 0.
        max_outputs_per_var: Maximum number of output dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all output dimensions if value <= 0.
        plot_relative_error: If True, visualizes the relative error of the Jacobians in a plot (requires ``matplotlib``).
        plot_absolute_error: If True, visualizes the absolute error of the Jacobians in a plot (requires ``matplotlib``).
        show_summary: If True, prints a summary table of the gradient check results.

    Returns:
        True if the gradient check passes for all kernels on the tape, False otherwise.

    Raises:
        ValueError: If ``raise_exception`` is True and the gradient check of a kernel fails.
    """
    if input_output_masks is None:
        input_output_masks = {}
    # use sets in both branches so membership tests behave the same regardless of the argument
    blacklist_kernels = set() if blacklist_kernels is None else set(blacklist_kernels)
    whitelist_kernels = set() if whitelist_kernels is None else set(whitelist_kernels)

    overall_success = True
    for launch in tape.launches:
        # the tape may also contain recorded functions (e.g. via Tape.record_func), which are skipped
        if not isinstance(launch[0], wp.Kernel):
            continue
        kernel, dim, max_blocks, inputs, outputs, device = launch[:6]
        if whitelist_kernels and kernel.key not in whitelist_kernels:
            continue
        if kernel.key in blacklist_kernels:
            continue
        input_output_mask = input_output_masks.get(kernel.key)
        success = gradcheck(
            kernel,
            dim,
            inputs,
            outputs,
            eps=eps,
            atol=atol,
            rtol=rtol,
            raise_exception=raise_exception,
            input_output_mask=input_output_mask,
            device=device,
            max_blocks=max_blocks,
            max_inputs_per_var=max_inputs_per_var,
            max_outputs_per_var=max_outputs_per_var,
            plot_relative_error=plot_relative_error,
            plot_absolute_error=plot_absolute_error,
            show_summary=show_summary,
        )
        overall_success = overall_success and success

    return overall_success
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def get_struct_vars(x: wp.codegen.StructInstance):
    """Map every field name of the Warp struct instance ``x`` to its current attribute value.

    Field names are read, in declaration order, from the ``_fields_`` of the ctypes structure
    attached to the struct class (``x._cls.ctype``).
    """
    values = {}
    for field_name, _field_ctype in x._cls.ctype._fields_:
        values[field_name] = getattr(x, field_name)
    return values
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def infer_device(xs: list):
    """Return the device of the first Warp array found in ``xs``.

    Plain Warp arrays are checked directly; Warp struct instances are searched (one level deep)
    for an array-valued field. If no array is found, ``wp.get_preferred_device()`` is returned.
    """
    for candidate in xs:
        if isinstance(candidate, wp.array):
            return candidate.device
        if not isinstance(candidate, wp.codegen.StructInstance):
            continue
        struct_arrays = (value for value in get_struct_vars(candidate).values() if isinstance(value, wp.array))
        first_array = next(struct_arrays, None)
        if first_array is not None:
            return first_array.device
    return wp.get_preferred_device()
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def plot_kernel_jacobians(
    jacobians: Dict[Tuple[int, int], wp.array],
    kernel: wp.Kernel,
    inputs: Sequence,
    outputs: Sequence,
    show_plot=True,
    show_colorbar=True,
    scale_colors_per_submatrix=False,
    title: str = None,
    colormap: str = "coolwarm",
    log_scale=False,
):
    """
    Visualizes the Jacobians computed by :func:`jacobian` or :func:`jacobian_fd` in a combined image plot.
    Requires the ``matplotlib`` package to be installed.

    Each Jacobian is drawn in its own subplot. Columns correspond to kernel inputs and rows to kernel
    outputs, both ordered by argument index. Jacobians whose input or output is not a Warp array with
    ``requires_grad=True`` are skipped.

    Args:
        jacobians: A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.
        kernel: The Warp kernel function, decorated with the ``@wp.kernel`` decorator.
        inputs: List of input variables.
        outputs: List of output variables.
        show_plot: If True, displays the plot via ``plt.show()``.
        show_colorbar: If True, displays a colorbar next to the plot (or a colorbar next to every submatrix if ``scale_colors_per_submatrix`` is True).
        scale_colors_per_submatrix: If True, considers the minimum and maximum of each Jacobian submatrix separately for color scaling. Otherwise, uses the global minimum and maximum of all Jacobians.
        title: The title of the plot (optional).
        colormap: The colormap to use for the plot.
        log_scale: If True, plots ``log10(|J| + 1e-8)`` instead of the raw Jacobian values. Color limits are computed on these transformed values.

    Returns:
        The created Matplotlib figure, or None if there is no Jacobian to plot.
    """
    import matplotlib.pyplot as plt
    from matplotlib.ticker import FuncFormatter, MaxNLocator, MultipleLocator

    def _is_differentiable(x):
        # only arrays that track gradients have a Jacobian worth plotting
        return isinstance(x, wp.array) and x.requires_grad

    def _major_locator(stride):
        # MaxNLocator only accepts ``steps`` within [1, 10]; wider element types (e.g. 4x4 or
        # spatial matrices) would make it raise, so place one major tick per element instead.
        if stride <= 10:
            return MaxNLocator(integer=True, nbins=1, steps=[stride])
        return MultipleLocator(stride)

    def _nan_free_range(values):
        # (min, max) over the non-NaN entries, or (0, 0) if there are none
        values = values[~np.isnan(values)]
        if values.size == 0:
            return 0.0, 0.0
        return values.min(), values.max()

    # Keep only plottable Jacobians, ordered by (output index, input index).
    plottable = {
        (input_i, output_i): jac_wp
        for (input_i, output_i), jac_wp in sorted(jacobians.items(), key=lambda item: (item[0][1], item[0][0]))
        if _is_differentiable(inputs[input_i]) and _is_differentiable(outputs[output_i])
    }

    # Assign subplot columns to inputs and rows to outputs in increasing argument order.
    input_to_ax = {input_i: col for col, input_i in enumerate(sorted({key[0] for key in plottable}))}
    output_to_ax = {output_i: row for row, output_i in enumerate(sorted({key[1] for key in plottable}))}

    num_rows = len(output_to_ax)
    num_cols = len(input_to_ax)
    if num_rows == 0 or num_cols == 0:
        return None

    # Build the 2D image drawn for every Jacobian and size each grid column/row after the
    # Jacobians it contains, indexed by subplot position (not by argument index).
    images = {}
    width_ratios = [1] * num_cols
    height_ratios = [1] * num_rows
    for (input_i, output_i), jac_wp in plottable.items():
        input_stride = inputs[input_i].dtype._length_
        # the first Jacobian dimension already has the output stride folded in
        jac = jac_wp.numpy().reshape(jac_wp.shape[0], jac_wp.shape[1] * input_stride)
        images[(input_i, output_i)] = np.log10(np.abs(jac) + 1e-8) if log_scale else jac
        width_ratios[input_to_ax[input_i]] = jac.shape[1]
        height_ratios[output_to_ax[output_i]] = jac.shape[0]

    has_plot = np.zeros((num_rows, num_cols), dtype=bool)
    for input_i, output_i in plottable:
        has_plot[output_to_ax[output_i], input_to_ax[input_i]] = True

    fig, axs = plt.subplots(
        ncols=num_cols,
        nrows=num_rows,
        figsize=(7, 7),
        sharex="col",
        sharey="row",
        gridspec_kw={
            "wspace": 0.1,
            "hspace": 0.1,
            "width_ratios": width_ratios,
            "height_ratios": height_ratios,
        },
        subplot_kw={"aspect": 1},
        squeeze=False,
    )
    if title is None:
        title = f"{kernel.key} kernel Jacobian"
    fig.suptitle(title)
    # non-interactive backends may not provide a window manager
    if fig.canvas.manager is not None:
        fig.canvas.manager.set_window_title(title)

    if not scale_colors_per_submatrix:
        # color limits are computed on the values that are actually drawn (log-transformed if requested)
        vmin, vmax = _nan_free_range(np.concatenate([image.ravel() for image in images.values()]))

    for ax_i, ax_j in np.argwhere(~has_plot):
        axs[ax_i, ax_j].axis("off")

    for (input_i, output_i), jac in images.items():
        input_stride = inputs[input_i].dtype._length_
        output_stride = outputs[output_i].dtype._length_
        # kernel arguments are laid out as all inputs followed by all outputs
        input_name = kernel.adj.args[input_i].label
        output_name = kernel.adj.args[len(inputs) + output_i].label

        ax_i, ax_j = output_to_ax[output_i], input_to_ax[input_i]
        ax = axs[ax_i, ax_j]
        ax.tick_params(which="major", width=1, length=7)
        ax.tick_params(which="minor", width=1, length=4, color="gray")
        # unlabeled minor ticks (and grid lines) at every scalar entry
        ax.xaxis.set_minor_formatter("")
        ax.yaxis.set_minor_formatter("")
        ax.xaxis.set_minor_locator(MultipleLocator(1))
        ax.yaxis.set_minor_locator(MultipleLocator(1))

        if input_stride > 1:
            # label x ticks by input element index rather than by scalar column index
            ax.xaxis.set_major_locator(_major_locator(input_stride))
            ax.xaxis.set_major_formatter(
                FuncFormatter(lambda x, pos, input_stride=input_stride: "{0:g}".format(x // input_stride))
            )

        if output_stride > 1:
            ax.yaxis.set_major_locator(_major_locator(output_stride))
            max_y = jac.shape[0]
            # imshow draws row 0 at the top, so element indices are counted down from max_y
            ax.yaxis.set_major_formatter(
                FuncFormatter(
                    lambda y, pos, max_y=max_y, output_stride=output_stride: "{0:g}".format(
                        (max_y - y) // output_stride
                    )
                )
            )

        if scale_colors_per_submatrix:
            vmin, vmax = _nan_free_range(jac)
        img = ax.imshow(
            jac,
            cmap=colormap,
            aspect="auto",
            interpolation="nearest",
            extent=[0, jac.shape[1], 0, jac.shape[0]],
            vmin=vmin,
            vmax=vmax,
        )
        if not has_plot[ax_i + 1 :, ax_j].any():
            # last plot of this column
            ax.set_xlabel(input_name)
        if not has_plot[ax_i, :ax_j].any():
            # first plot of this row
            ax.set_ylabel(output_name)
        ax.grid(color="gray", which="minor", linestyle="--", linewidth=0.5)
        ax.grid(color="black", which="major", linewidth=1.0)

        if show_colorbar and scale_colors_per_submatrix:
            plt.colorbar(img, ax=ax, orientation="vertical", pad=0.02)

    if show_colorbar and not scale_colors_per_submatrix:
        m = plt.cm.ScalarMappable(cmap=colormap)
        m.set_array([vmin, vmax])
        m.set_clim(vmin, vmax)
        plt.colorbar(m, ax=axs, orientation="vertical", pad=0.02)

    plt.tight_layout()
    if show_plot:
        plt.show()
    return fig
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
def scalarize_array_1d(arr):
    """
    Converts ``arr`` into a 1D array whose dtype is a scalar type.

    Scalar arrays are returned via ``arr.flatten()``. Arrays with a dtype listed in
    ``wp.types.vector_types`` are reinterpreted, starting at ``arr.ptr``, as a flat array
    of ``arr.size * arr.dtype._length_`` elements of the dtype's scalar type.

    Args:
        arr: A Warp array with a scalar or vector-like (vector/matrix) dtype.

    Returns:
        A 1D Warp array with a scalar dtype, on the same device as ``arr``.

    Raises:
        RuntimeError: If ``arr`` has a vector-like dtype but is not contiguous; the raw
            pointer reinterpretation is only valid for densely packed memory.
        ValueError: If the dtype of ``arr`` is neither a scalar nor a vector-like type.
    """
    if arr.dtype in wp.types.scalar_types:
        return arr.flatten()

    if arr.dtype in wp.types.vector_types:
        # The new array is built directly from the base pointer, so a strided view
        # (e.g. a slice with a step) would silently address the wrong elements.
        if not arr.is_contiguous:
            raise RuntimeError("Scalarizing non-contiguous arrays of vector/matrix type is unsupported.")
        return wp.array(
            ptr=arr.ptr,
            shape=(arr.size * arr.dtype._length_,),
            dtype=arr.dtype._wp_scalar_type_,
            device=arr.device,
        )

    raise ValueError(
        f"Unsupported array dtype {arr.dtype}: array to be flattened must be a scalar/vector/matrix array"
    )
|
|
533
|
+
|
|
534
|
+
|
|
535
|
+
def scalarize_array_2d(arr):
    """
    Converts the 2D array ``arr`` into a 2D array whose dtype is a scalar type.

    Scalar arrays are returned unchanged. Arrays with a dtype listed in
    ``wp.types.vector_types`` are reinterpreted, starting at ``arr.ptr``, as an array of
    shape ``(arr.shape[0], arr.shape[1] * arr.dtype._length_)`` with the dtype's scalar type.

    Args:
        arr: A 2D Warp array with a scalar or vector-like (vector/matrix) dtype.

    Returns:
        A 2D Warp array with a scalar dtype, on the same device as ``arr``.

    Raises:
        AssertionError: If ``arr`` is not 2-dimensional.
        RuntimeError: If ``arr`` has a vector-like dtype but is not contiguous; the raw
            pointer reinterpretation is only valid for densely packed memory.
        ValueError: If the dtype of ``arr`` is neither a scalar nor a vector-like type.
    """
    assert arr.ndim == 2

    if arr.dtype in wp.types.scalar_types:
        return arr

    if arr.dtype in wp.types.vector_types:
        # The new array is built directly from the base pointer, so a strided view
        # would silently address the wrong elements.
        if not arr.is_contiguous:
            raise RuntimeError("Scalarizing non-contiguous arrays of vector/matrix type is unsupported.")
        return wp.array(
            ptr=arr.ptr,
            shape=(arr.shape[0], arr.shape[1] * arr.dtype._length_),
            dtype=arr.dtype._wp_scalar_type_,
            device=arr.device,
        )

    raise ValueError(
        f"Unsupported array dtype {arr.dtype}: array to be flattened must be a scalar/vector/matrix array"
    )
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
def jacobian(
    kernel: wp.Kernel,
    dim: Tuple[int],
    inputs: Sequence,
    outputs: Sequence = None,
    input_output_mask: List[Tuple[Union[str, int], Union[str, int]]] = None,
    device: wp.context.Devicelike = None,
    max_blocks=0,
    max_outputs_per_var=-1,
    plot_jacobians=False,
) -> Dict[Tuple[int, int], wp.array]:
    """
    Computes the Jacobians of a Warp kernel launch for the provided selection of differentiable inputs to differentiable outputs.

    The kernel adjoint function is launched with the given inputs and outputs, as well as the provided ``dim`` and ``max_blocks`` arguments (see :func:`warp.launch` for more details).

    Note:
        This function only supports Warp kernels whose input arguments precede the output arguments.

        Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.

        Structs arguments are not yet supported by this function to compute Jacobians.

        The gradients of all array inputs and outputs are overwritten: they are zeroed before every backward pass, so any gradient values held before the call are discarded.

    Args:
        kernel: The Warp kernel function, decorated with the ``@wp.kernel`` decorator
        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints
        inputs: List of input variables.
        outputs: List of output variables. If None, an empty list is used, so no Jacobians are computed.
        input_output_mask: List of tuples specifying the input-output pairs to compute the Jacobian for. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
        device: The device to launch on (optional)
        max_blocks: The maximum number of CUDA thread blocks to use.
        max_outputs_per_var: Maximum number of output dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all output dimensions if value <= 0. Rows that are not evaluated are left as NaN.
        plot_jacobians: If True, visualizes the computed Jacobians in a plot (requires ``matplotlib``).

    Returns:
        A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.

    Raises:
        ValueError: If a name in ``input_output_mask`` does not match any kernel argument.
    """
    # Normalize to lists so tuples work for concatenation below.
    inputs = list(inputs)
    outputs = [] if outputs is None else list(outputs)
    if input_output_mask is None:
        input_output_mask = []
    arg_names = [arg.label for arg in kernel.adj.args]

    def resolve_arg(name):
        # Map an argument name to its position in the kernel signature; ints pass through.
        if isinstance(name, int):
            return name
        try:
            return arg_names.index(name)
        except ValueError:
            raise ValueError(f"Unknown kernel argument name '{name}', expected one of {arg_names}") from None

    # Output positions are made relative to the first output argument.
    input_output_mask = {
        (resolve_arg(input_name), resolve_arg(output_name) - len(inputs))
        for input_name, output_name in input_output_mask
    }

    if device is None:
        device = infer_device(inputs + outputs)

    tape = wp.Tape()
    tape.record_launch(kernel=kernel, dim=dim, max_blocks=max_blocks, inputs=inputs, outputs=outputs, device=device)

    def zero_grads(arrays):
        # Clear gradients explicitly: tape.zero() only knows about gradients after the
        # first backward pass, so stale values would otherwise leak into the first row.
        for a in arrays:
            if isinstance(a, wp.array) and a.requires_grad and a.grad is not None:
                a.grad.zero_()

    jacobians = {}

    for input_i, output_i in itertools.product(range(len(inputs)), range(len(outputs))):
        if input_output_mask and (input_i, output_i) not in input_output_mask:
            continue
        in_arr = inputs[input_i]
        out_arr = outputs[output_i]
        if not isinstance(in_arr, wp.array) or not in_arr.requires_grad:
            continue
        if not isinstance(out_arr, wp.array) or not out_arr.requires_grad:
            continue

        # Scalar view of the output gradient, used to seed one output entry at a time.
        out_grad = scalarize_array_1d(out_arr.grad)
        output_num = out_grad.shape[0]
        jac = wp.empty((output_num, in_arr.size), dtype=in_arr.dtype, device=in_arr.device)
        jac.fill_(wp.nan)
        if max_outputs_per_var > 0:
            output_num = min(output_num, max_outputs_per_var)

        for i in range(output_num):
            tape.zero()
            zero_grads(inputs)
            zero_grads(outputs)
            # Seed d(out_i) = 1 so the backward pass yields row i of the Jacobian.
            set_element(out_grad, i, 1.0)
            tape.backward()
            jac[i].assign(in_arr.grad)

        out_arr.grad.zero_()
        jacobians[input_i, output_i] = jac

    if plot_jacobians:
        plot_kernel_jacobians(
            jacobians,
            kernel,
            inputs,
            outputs,
        )

    return jacobians
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
def jacobian_fd(
    kernel: wp.Kernel,
    dim: Tuple[int],
    inputs: Sequence,
    outputs: Sequence = None,
    input_output_mask: List[Tuple[Union[str, int], Union[str, int]]] = None,
    device: wp.context.Devicelike = None,
    max_blocks=0,
    max_inputs_per_var=-1,
    eps=1e-4,
    plot_jacobians=False,
) -> Dict[Tuple[int, int], wp.array]:
    """
    Computes the finite-difference Jacobian of a Warp kernel launch for the provided selection of differentiable inputs to differentiable outputs.
    The method uses a central difference scheme to approximate the Jacobian.

    The kernel is launched multiple times in forward-only mode with the given inputs and outputs, as well as the provided ``dim`` and ``max_blocks`` arguments (see :func:`warp.launch` for more details).

    Note:
        This function only supports Warp kernels whose input arguments precede the output arguments.

        Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.

        Structs arguments are not yet supported by this function to compute Jacobians.

        The output under study is evaluated into scratch copies that are reset before every launch, so kernels that accumulate into their outputs are handled correctly. Each perturbed input entry is restored to its original value afterwards.

    Args:
        kernel: The Warp kernel function, decorated with the ``@wp.kernel`` decorator
        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints
        inputs: List of input variables.
        outputs: List of output variables. If None, an empty list is used, so no Jacobians are computed.
        input_output_mask: List of tuples specifying the input-output pairs to compute the Jacobian for. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
        device: The device to launch on (optional)
        max_blocks: The maximum number of CUDA thread blocks to use.
        max_inputs_per_var: Maximum number of input dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all input dimensions if value <= 0. Columns that are not evaluated are left as NaN.
        eps: The finite-difference step size.
        plot_jacobians: If True, visualizes the computed Jacobians in a plot (requires ``matplotlib``).

    Returns:
        A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.

    Raises:
        ValueError: If a name in ``input_output_mask`` does not match any kernel argument.
    """
    # Normalize to lists so tuples work for concatenation/slicing below.
    inputs = list(inputs)
    outputs = [] if outputs is None else list(outputs)
    if input_output_mask is None:
        input_output_mask = []
    arg_names = [arg.label for arg in kernel.adj.args]

    def resolve_arg(name):
        # Map an argument name to its position in the kernel signature; ints pass through.
        if isinstance(name, int):
            return name
        try:
            return arg_names.index(name)
        except ValueError:
            raise ValueError(f"Unknown kernel argument name '{name}', expected one of {arg_names}") from None

    # Output positions are made relative to the first output argument.
    input_output_mask = {
        (resolve_arg(input_name), resolve_arg(output_name) - len(inputs))
        for input_name, output_name in input_output_mask
    }

    if device is None:
        device = infer_device(inputs + outputs)

    jacobians = {}

    for input_i, output_i in itertools.product(range(len(inputs)), range(len(outputs))):
        if input_output_mask and (input_i, output_i) not in input_output_mask:
            continue
        in_arr = inputs[input_i]
        out_arr = outputs[output_i]
        if not isinstance(in_arr, wp.array) or not in_arr.requires_grad:
            continue
        if not isinstance(out_arr, wp.array) or not out_arr.requires_grad:
            continue

        flat_input = scalarize_array_1d(in_arr)
        # Host copy of the unperturbed input, used to restore every entry exactly
        # (x - eps + 2 * eps - eps is not guaranteed to round back to x).
        original_values = flat_input.numpy()

        # The output under study is written into scratch copies instead of out_arr.
        left = wp.clone(out_arr)
        right = wp.clone(out_arr)
        flat_left = scalarize_array_1d(left)
        flat_right = scalarize_array_1d(right)

        left_outputs = outputs[:output_i] + [left] + outputs[output_i + 1 :]
        right_outputs = outputs[:output_i] + [right] + outputs[output_i + 1 :]

        input_num = flat_input.shape[0]
        jac = wp.empty((flat_left.size, in_arr.size), dtype=in_arr.dtype, device=in_arr.device)
        jac.fill_(wp.nan)

        # Row i of the transposed scalar Jacobian is the column for input entry i.
        jac_t = scalarize_array_2d(jac).transpose()
        if max_inputs_per_var > 0:
            input_num = min(input_num, max_inputs_per_var)

        for i in range(input_num):
            x_i = float(original_values[i])

            # Reset the scratch output so kernels that accumulate (+=, atomics)
            # do not carry results over from earlier perturbations.
            left.assign(out_arr)
            set_element(flat_input, i, -eps, relative=True)
            wp.launch(kernel, dim=dim, max_blocks=max_blocks, inputs=inputs, outputs=left_outputs, device=device)
            set_element(flat_input, i, x_i)

            right.assign(out_arr)
            set_element(flat_input, i, eps, relative=True)
            wp.launch(kernel, dim=dim, max_blocks=max_blocks, inputs=inputs, outputs=right_outputs, device=device)
            set_element(flat_input, i, x_i)

            compute_fd(flat_left, flat_right, eps, jac_t[i])

        out_arr.grad.zero_()
        jacobians[input_i, output_i] = jac

    if plot_jacobians:
        plot_kernel_jacobians(
            jacobians,
            kernel,
            inputs,
            outputs,
        )

    return jacobians
|
|
764
|
+
|
|
765
|
+
|
|
766
|
+
@wp.kernel(enable_backward=False)
def set_element_kernel(a: wp.array(dtype=Any), i: int, val: Any, relative: bool):
    # Single-element update of the 1D array `a`, launched by `set_element` with dim=1.
    # relative=True  -> a[i] += val (offset the current value)
    # relative=False -> a[i] = val  (overwrite)
    # Compiled without an adjoint (enable_backward=False).
    if relative:
        a[i] += val
    else:
        a[i] = val
|
|
772
|
+
|
|
773
|
+
|
|
774
|
+
def set_element(a: wp.array(dtype=Any), i: int, val: Any, relative: bool = False):
    """
    Writes ``val`` into element ``i`` of the 1D array ``a`` (or adds it when ``relative`` is True).

    ``val`` is converted to ``a.dtype`` before the single-thread launch of
    ``set_element_kernel`` on ``a.device``.
    """
    typed_val = a.dtype(val)
    launch_args = [a, i, typed_val, relative]
    wp.launch(set_element_kernel, dim=1, inputs=launch_args, device=a.device)
|
|
776
|
+
|
|
777
|
+
|
|
778
|
+
@wp.kernel(enable_backward=False)
def compute_fd_kernel(left: wp.array(dtype=Any), right: wp.array(dtype=Any), eps: Any, fd: wp.array(dtype=Any)):
    # Central difference per element: fd = (right - left) / (2 * eps).
    # `eps + eps` doubles the step without a Python float literal; such a literal is
    # typed float32 in kernels and does not combine with float64/float16 `eps`.
    # Doubling is exact in floating point, so the result equals 2.0 * eps bit-for-bit.
    tid = wp.tid()
    fd[tid] = (right[tid] - left[tid]) / (eps + eps)
|
|
782
|
+
|
|
783
|
+
|
|
784
|
+
def compute_fd(left: wp.array(dtype=Any), right: wp.array(dtype=Any), eps: float, fd: wp.array(dtype=Any)):
    """
    Launches ``compute_fd_kernel`` over ``len(left)`` elements on ``left.device``,
    writing ``(right - left) / (2 * eps)`` into ``fd``.

    A Python float passed to a generic kernel argument is inferred as float32, so ``eps``
    is first converted to the scalar dtype of ``left``. This is consistent with
    ``set_element``, which casts to ``a.dtype``, and lets float64/float16 arrays work.
    """
    if left.dtype in wp.types.scalar_types:
        eps = left.dtype(eps)
    wp.launch(compute_fd_kernel, dim=len(left), inputs=[left, right, eps], outputs=[fd], device=left.device)
|
|
786
|
+
|
|
787
|
+
|
|
788
|
+
@wp.kernel(enable_backward=False)
def compute_error_kernel(
    jacobian_ad: wp.array(dtype=Any),
    jacobian_fd: wp.array(dtype=Any),
    relative_error: wp.array(dtype=Any),
    absolute_error: wp.array(dtype=Any),
):
    # Element-wise comparison of an autodiff Jacobian entry with its finite-difference estimate:
    #   relative_error = (ad - fd) / (ad + 1e-8)   (signed)
    #   absolute_error = |ad - fd|
    # NOTE(review): the denominator uses `ad` rather than `|ad|`, so negative entries near -1e-8
    # can blow up; the 1e-8 literal may also not combine with non-float32 dtypes - confirm.
    idx = wp.tid()
    ad_value = jacobian_ad[idx]
    delta = ad_value - jacobian_fd[idx]
    relative_error[idx] = delta / (ad_value + 1e-8)
    absolute_error[idx] = wp.abs(delta)
|
|
800
|
+
|
|
801
|
+
|
|
802
|
+
def print_table(headers, cells):
    """
    Prints a table with the given headers and cells.

    Column widths are measured on the visible text: ANSI color escape sequences
    (``ESC[<n>m``) are ignored. String entries are padded by their visible length,
    so colored cells stay aligned with the rest of the column. Non-string entries
    keep Python's default ``format`` alignment (e.g. numbers are right-aligned).

    Args:
        headers: List of header strings.
        cells: List of lists of cell strings.
    """
    import re

    ansi_escape = re.compile(r"\033\[\d+m")

    def sanitized_len(s):
        # Length of the text as displayed, with color escape sequences stripped.
        return len(ansi_escape.sub("", str(s)))

    def pad(value, width):
        if isinstance(value, str):
            # f-string padding would count the invisible escape characters and
            # under-pad colored cells, so pad by the visible length instead.
            return value + " " * max(width - sanitized_len(value), 0)
        return f"{value:{width}}"

    col_widths = [max(sanitized_len(cell) for cell in col) for col in zip(headers, *cells)]
    for header, col_width in zip(headers, col_widths):
        print(pad(header, col_width), end=" | ")
    print()
    # Each column occupies its width plus the 3-character " | " separator; drop the trailing space.
    print("-" * (sum(col_widths) + 3 * len(col_widths) - 1))
    for cell_row in cells:
        for cell, col_width in zip(cell_row, col_widths):
            print(pad(cell, col_width), end=" | ")
        print()
|