warp-lang 1.5.1-py3-none-manylinux2014_aarch64.whl → 1.6.0-py3-none-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of warp-lang might be problematic.
- warp/__init__.py +5 -0
- warp/autograd.py +414 -191
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +40 -12
- warp/build_dll.py +13 -6
- warp/builtins.py +1076 -480
- warp/codegen.py +240 -119
- warp/config.py +1 -1
- warp/context.py +298 -84
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_gemm.py +27 -18
- warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
- warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
- warp/examples/core/example_torch.py +18 -34
- warp/examples/fem/example_apic_fluid.py +1 -0
- warp/examples/fem/example_mixed_elasticity.py +1 -1
- warp/examples/optim/example_bounce.py +1 -1
- warp/examples/optim/example_cloth_throw.py +1 -1
- warp/examples/optim/example_diffray.py +4 -15
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/optim/example_softbody_properties.py +392 -0
- warp/examples/optim/example_trajectory.py +1 -3
- warp/examples/optim/example_walker.py +5 -0
- warp/examples/sim/example_cartpole.py +0 -2
- warp/examples/sim/example_cloth_self_contact.py +260 -0
- warp/examples/sim/example_granular_collision_sdf.py +4 -5
- warp/examples/sim/example_jacobian_ik.py +0 -2
- warp/examples/sim/example_quadruped.py +5 -2
- warp/examples/tile/example_tile_cholesky.py +79 -0
- warp/examples/tile/example_tile_convolution.py +2 -2
- warp/examples/tile/example_tile_fft.py +2 -2
- warp/examples/tile/example_tile_filtering.py +3 -3
- warp/examples/tile/example_tile_matmul.py +4 -4
- warp/examples/tile/example_tile_mlp.py +12 -12
- warp/examples/tile/example_tile_nbody.py +180 -0
- warp/examples/tile/example_tile_walker.py +319 -0
- warp/math.py +147 -0
- warp/native/array.h +12 -0
- warp/native/builtin.h +0 -1
- warp/native/bvh.cpp +149 -70
- warp/native/bvh.cu +287 -68
- warp/native/bvh.h +195 -85
- warp/native/clang/clang.cpp +5 -1
- warp/native/cuda_util.cpp +35 -0
- warp/native/cuda_util.h +5 -0
- warp/native/exports.h +40 -40
- warp/native/intersect.h +17 -0
- warp/native/mat.h +41 -0
- warp/native/mathdx.cpp +19 -0
- warp/native/mesh.cpp +25 -8
- warp/native/mesh.cu +153 -101
- warp/native/mesh.h +482 -403
- warp/native/quat.h +40 -0
- warp/native/solid_angle.h +7 -0
- warp/native/sort.cpp +85 -0
- warp/native/sort.cu +34 -0
- warp/native/sort.h +3 -1
- warp/native/spatial.h +11 -0
- warp/native/tile.h +1185 -664
- warp/native/tile_reduce.h +8 -6
- warp/native/vec.h +41 -0
- warp/native/warp.cpp +8 -1
- warp/native/warp.cu +263 -40
- warp/native/warp.h +19 -5
- warp/optim/linear.py +22 -4
- warp/render/render_opengl.py +124 -59
- warp/sim/__init__.py +6 -1
- warp/sim/collide.py +270 -26
- warp/sim/integrator_euler.py +25 -7
- warp/sim/integrator_featherstone.py +154 -35
- warp/sim/integrator_vbd.py +842 -40
- warp/sim/model.py +111 -53
- warp/stubs.py +248 -115
- warp/tape.py +28 -30
- warp/tests/aux_test_module_unload.py +15 -0
- warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
- warp/tests/test_array.py +74 -0
- warp/tests/test_assert.py +242 -0
- warp/tests/test_codegen.py +14 -61
- warp/tests/test_collision.py +2 -2
- warp/tests/test_examples.py +9 -0
- warp/tests/test_grad_debug.py +87 -2
- warp/tests/test_hash_grid.py +1 -1
- warp/tests/test_ipc.py +116 -0
- warp/tests/test_mat.py +138 -167
- warp/tests/test_math.py +47 -1
- warp/tests/test_matmul.py +11 -7
- warp/tests/test_matmul_lite.py +4 -4
- warp/tests/test_mesh.py +84 -60
- warp/tests/test_mesh_query_aabb.py +165 -0
- warp/tests/test_mesh_query_point.py +328 -286
- warp/tests/test_mesh_query_ray.py +134 -121
- warp/tests/test_mlp.py +2 -2
- warp/tests/test_operators.py +43 -0
- warp/tests/test_overwrite.py +2 -2
- warp/tests/test_quat.py +77 -0
- warp/tests/test_reload.py +29 -0
- warp/tests/test_sim_grad_bounce_linear.py +204 -0
- warp/tests/test_static.py +16 -0
- warp/tests/test_tape.py +25 -0
- warp/tests/test_tile.py +134 -191
- warp/tests/test_tile_load.py +356 -0
- warp/tests/test_tile_mathdx.py +61 -8
- warp/tests/test_tile_mlp.py +17 -17
- warp/tests/test_tile_reduce.py +24 -18
- warp/tests/test_tile_shared_memory.py +66 -17
- warp/tests/test_tile_view.py +165 -0
- warp/tests/test_torch.py +35 -0
- warp/tests/test_utils.py +36 -24
- warp/tests/test_vec.py +110 -0
- warp/tests/unittest_suites.py +29 -4
- warp/tests/unittest_utils.py +30 -11
- warp/thirdparty/unittest_parallel.py +2 -2
- warp/types.py +409 -99
- warp/utils.py +9 -5
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/METADATA +68 -44
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/RECORD +121 -110
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/WHEEL +1 -1
- warp/examples/benchmarks/benchmark_tile.py +0 -179
- warp/native/tile_gemm.h +0 -341
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/top_level.txt +0 -0
warp/autograd.py
CHANGED
@@ -5,8 +5,9 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
+import inspect
 import itertools
-from typing import Any, Dict, List, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
 
 import numpy as np
 
@@ -22,23 +23,23 @@ __all__ = [
 
 
 def gradcheck(
-    function: wp.Kernel,
-    dim: Tuple[int],
-    inputs: Sequence,
-    outputs: Sequence,
+    function: Union[wp.Kernel, Callable],
+    dim: Tuple[int] = None,
+    inputs: Sequence = None,
+    outputs: Sequence = None,
     *,
-    eps=1e-4,
-    atol=1e-3,
-    rtol=1e-2,
-    raise_exception=True,
+    eps: float = 1e-4,
+    atol: float = 1e-3,
+    rtol: float = 1e-2,
+    raise_exception: bool = True,
     input_output_mask: List[Tuple[Union[str, int], Union[str, int]]] = None,
     device: wp.context.Devicelike = None,
-    max_blocks=0,
-    block_dim=256,
-    max_inputs_per_var=-1,
-    max_outputs_per_var=-1,
-    plot_relative_error=False,
-    plot_absolute_error=False,
+    max_blocks: int = 0,
+    block_dim: int = 256,
+    max_inputs_per_var: int = -1,
+    max_outputs_per_var: int = -1,
+    plot_relative_error: bool = False,
+    plot_absolute_error: bool = False,
     show_summary: bool = True,
 ) -> bool:
     """
@@ -56,10 +57,10 @@ def gradcheck(
     Structs arguments are not yet supported by this function to compute Jacobians.
 
     Args:
-        function: The Warp kernel function, decorated with the ``@wp.kernel`` decorator.
-        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints.
+        function: The Warp kernel function, decorated with the ``@wp.kernel`` decorator, or any function that involves Warp kernel launches.
+        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints. Only required if the function is a Warp kernel.
         inputs: List of input variables.
-        outputs: List of output variables.
+        outputs: List of output variables. Only required if the function is a Warp kernel.
         eps: The finite-difference step size.
         atol: The absolute tolerance for the gradient check.
         rtol: The relative tolerance for the gradient check.
@@ -78,9 +79,12 @@ def gradcheck(
         True if the gradient check passes, False otherwise.
     """
 
-
+    if inputs is None:
+        raise ValueError("The inputs argument must be provided")
 
-    jacs_fd = jacobian_fd(
+    metadata = FunctionMetadata()
+
+    jacs_ad = jacobian(
         function,
         dim=dim,
         inputs=inputs,
@@ -89,12 +93,11 @@ def gradcheck(
         device=device,
         max_blocks=max_blocks,
         block_dim=block_dim,
-        max_inputs_per_var=max_inputs_per_var,
-        eps=eps,
+        max_outputs_per_var=max_outputs_per_var,
         plot_jacobians=False,
+        metadata=metadata,
     )
-
-    jacs_ad = jacobian(
+    jacs_fd = jacobian_fd(
         function,
         dim=dim,
         inputs=inputs,
@@ -103,8 +106,10 @@ def gradcheck(
         device=device,
         max_blocks=max_blocks,
         block_dim=block_dim,
-        max_outputs_per_var=max_outputs_per_var,
+        max_inputs_per_var=max_inputs_per_var,
+        eps=eps,
         plot_jacobians=False,
+        metadata=metadata,
     )
 
     relative_error_jacs = {}
@@ -112,7 +117,7 @@ def gradcheck(
 
     if show_summary:
         summary = []
-        summary_header = ["Input", "Output", "Max Abs Error", "Max Rel Error", "Pass"]
+        summary_header = ["Input", "Output", "Max Abs Error", "AD at MAE", "FD at MAE", "Max Rel Error", "Pass"]
 
         class FontColors:
             OKGREEN = "\033[92m"
@@ -121,6 +126,8 @@ def gradcheck(
             ENDC = "\033[0m"
 
     success = True
+    any_grad_mismatch = False
+    any_grad_nan = False
     for (input_i, output_i), jac_fd in jacs_fd.items():
         jac_ad = jacs_ad[input_i, output_i]
         if plot_relative_error or plot_absolute_error:
@@ -147,28 +154,15 @@ def gradcheck(
             cut_jac_fd = cut_jac_fd[:, :max_inputs_per_var]
             cut_jac_ad = cut_jac_ad[:, :max_inputs_per_var]
         grad_matches = np.allclose(cut_jac_ad, cut_jac_fd, atol=atol, rtol=rtol)
+        any_grad_mismatch = any_grad_mismatch or not grad_matches
         success = success and grad_matches
-        if not grad_matches:
-            if raise_exception:
-                raise ValueError(
-                    f"Gradient check failed for kernel {function.key}, input {input_i}, output {output_i}: "
-                    f"finite difference and autodiff gradients do not match"
-                )
-            elif not show_summary:
-                return False
         isnan = np.any(np.isnan(cut_jac_ad))
+        any_grad_nan = any_grad_nan or isnan
         success = success and not isnan
-        if isnan:
-            if raise_exception:
-                raise ValueError(
-                    f"Gradient check failed for kernel {function.key}, input {input_i}, output {output_i}: "
-                    f"gradient contains NaN values"
-                )
-            elif not show_summary:
-                return False
 
         if show_summary:
             max_abs_error = np.abs(cut_jac_ad - cut_jac_fd).max()
+            arg_max_abs_error = np.unravel_index(np.argmax(np.abs(cut_jac_ad - cut_jac_fd)), cut_jac_ad.shape)
             max_rel_error = np.abs((cut_jac_ad - cut_jac_fd) / (cut_jac_fd + 1e-8)).max()
             if isnan:
                 pass_str = FontColors.FAIL + "NaN" + FontColors.ENDC
@@ -176,33 +170,55 @@ def gradcheck(
                 pass_str = FontColors.OKGREEN + "PASS" + FontColors.ENDC
             else:
                 pass_str = FontColors.FAIL + "FAIL" + FontColors.ENDC
-            input_name = function.adj.args[input_i].label
-            output_name = function.adj.args[len(inputs) + output_i].label
-            summary.append([input_name, output_name, f"{max_abs_error:.3e}", f"{max_rel_error:.3e}", pass_str])
+            input_name = metadata.input_labels[input_i]
+            output_name = metadata.output_labels[output_i]
+            summary.append(
+                [
+                    input_name,
+                    output_name,
+                    f"{max_abs_error:.3e} at {tuple(int(i) for i in arg_max_abs_error)}",
+                    f"{cut_jac_ad[arg_max_abs_error]:.3e}",
+                    f"{cut_jac_fd[arg_max_abs_error]:.3e}",
+                    f"{max_rel_error:.3e}",
+                    pass_str,
+                ]
+            )
 
     if show_summary:
         print_table(summary_header, summary)
         if not success:
-            print(FontColors.FAIL + f"Gradient check for kernel {function.key} failed" + FontColors.ENDC)
+            print(FontColors.FAIL + f"Gradient check for kernel {metadata.key} failed" + FontColors.ENDC)
         else:
-            print(FontColors.OKGREEN + f"Gradient check for kernel {function.key} passed" + FontColors.ENDC)
+            print(FontColors.OKGREEN + f"Gradient check for kernel {metadata.key} passed" + FontColors.ENDC)
     if plot_relative_error:
         jacobian_plot(
             relative_error_jacs,
-            function,
+            metadata,
             inputs,
             outputs,
-            title=f"{function.key} kernel Jacobian relative error",
+            title=f"{metadata.key} kernel Jacobian relative error",
         )
     if plot_absolute_error:
         jacobian_plot(
             absolute_error_jacs,
-            function,
+            metadata,
             inputs,
             outputs,
-            title=f"{function.key} kernel Jacobian absolute error",
+            title=f"{metadata.key} kernel Jacobian absolute error",
         )
 
+    if raise_exception:
+        if any_grad_mismatch:
+            raise ValueError(
+                f"Gradient check failed for kernel {metadata.key}, input {input_i}, output {output_i}: "
+                f"finite difference and autodiff gradients do not match"
+            )
+        if any_grad_nan:
+            raise ValueError(
+                f"Gradient check failed for kernel {metadata.key}, input {input_i}, output {output_i}: "
+                f"gradient contains NaN values"
+            )
+
     return success
 
 
@@ -221,6 +237,8 @@ def gradcheck_tape(
     plot_relative_error=False,
     plot_absolute_error=False,
     show_summary: bool = True,
+    reverse_launches: bool = False,
+    skip_to_launch_index: int = 0,
 ) -> bool:
     """
     Checks whether the autodiff gradients for kernels recorded on the Warp tape match finite differences.
@@ -247,6 +265,7 @@ def gradcheck_tape(
         plot_relative_error: If True, visualizes the relative error of the Jacobians in a plot (requires ``matplotlib``).
         plot_absolute_error: If True, visualizes the absolute error of the Jacobians in a plot (requires ``matplotlib``).
         show_summary: If True, prints a summary table of the gradient check results.
+        reverse_launches: If True, reverses the order of the kernel launches on the tape to check.
 
     Returns:
         True if the gradient check passes for all kernels on the tape, False otherwise.
@@ -263,7 +282,12 @@ def gradcheck_tape(
     whitelist_kernels = set(whitelist_kernels)
 
     overall_success = True
-    for launch in tape.launches:
+    launches = reversed(tape.launches) if reverse_launches else tape.launches
+    for i, launch in enumerate(launches):
+        if i < skip_to_launch_index:
+            continue
+        if not isinstance(launch, tuple) and not isinstance(launch, list):
+            continue
         if not isinstance(launch[0], wp.Kernel):
             continue
         kernel, dim, max_blocks, inputs, outputs, device, block_dim = launch[:7]
@@ -271,6 +295,9 @@ def gradcheck_tape(
             continue
         if kernel.key in blacklist_kernels:
             continue
+        if not kernel.options.get("enable_backward", True):
+            continue
+
         input_output_mask = input_output_masks.get(kernel.key)
         success = gradcheck(
             kernel,
@@ -312,11 +339,95 @@ def infer_device(xs: list):
     return wp.get_preferred_device()
 
 
+class FunctionMetadata:
+    """
+    Metadata holder for kernel functions or functions with Warp arrays as inputs/outputs.
+    """
+
+    def __init__(
+        self,
+        key: str = None,
+        input_labels: List[str] = None,
+        output_labels: List[str] = None,
+        input_strides: List[tuple] = None,
+        output_strides: List[tuple] = None,
+        input_dtypes: list = None,
+        output_dtypes: list = None,
+    ):
+        self.key = key
+        self.input_labels = input_labels
+        self.output_labels = output_labels
+        self.input_strides = input_strides
+        self.output_strides = output_strides
+        self.input_dtypes = input_dtypes
+        self.output_dtypes = output_dtypes
+
+    @property
+    def is_empty(self):
+        return self.key is None
+
+    def input_is_array(self, i: int):
+        return self.input_strides[i] is not None
+
+    def output_is_array(self, i: int):
+        return self.output_strides[i] is not None
+
+    def update_from_kernel(self, kernel: wp.Kernel, inputs: Sequence):
+        self.key = kernel.key
+        self.input_labels = [arg.label for arg in kernel.adj.args[: len(inputs)]]
+        self.output_labels = [arg.label for arg in kernel.adj.args[len(inputs) :]]
+        self.input_strides = []
+        self.output_strides = []
+        self.input_dtypes = []
+        self.output_dtypes = []
+        for arg in kernel.adj.args[: len(inputs)]:
+            if arg.type is wp.array:
+                self.input_strides.append(arg.type.strides)
+                self.input_dtypes.append(arg.type.dtype)
+            else:
+                self.input_strides.append(None)
+                self.input_dtypes.append(None)
+        for arg in kernel.adj.args[len(inputs) :]:
+            if arg.type is wp.array:
+                self.output_strides.append(arg.type.strides)
+                self.output_dtypes.append(arg.type.dtype)
+            else:
+                self.output_strides.append(None)
+                self.output_dtypes.append(None)
+
+    def update_from_function(self, function: Callable, inputs: Sequence, outputs: Sequence = None):
+        self.key = function.__name__
+        self.input_labels = list(inspect.signature(function).parameters.keys())
+        if outputs is None:
+            outputs = function(*inputs)
+            if isinstance(outputs, wp.array):
+                outputs = [outputs]
+        self.output_labels = [f"output_{i}" for i in range(len(outputs))]
+        self.input_strides = []
+        self.output_strides = []
+        self.input_dtypes = []
+        self.output_dtypes = []
+        for input in inputs:
+            if isinstance(input, wp.array):
+                self.input_strides.append(input.strides)
+                self.input_dtypes.append(input.dtype)
+            else:
+                self.input_strides.append(None)
+                self.input_dtypes.append(None)
+        for output in outputs:
+            if isinstance(output, wp.array):
+                self.output_strides.append(output.strides)
+                self.output_dtypes.append(output.dtype)
+            else:
+                self.output_strides.append(None)
+                self.output_dtypes.append(None)
+
+
 def jacobian_plot(
     jacobians: Dict[Tuple[int, int], wp.array],
-    kernel: wp.Kernel,
-    inputs: Sequence,
-    outputs: Sequence,
+    kernel: Union[FunctionMetadata, wp.Kernel],
+    inputs: Sequence = None,
+    outputs: Sequence = None,
     show_plot=True,
     show_colorbar=True,
     scale_colors_per_submatrix=False,
@@ -330,9 +441,9 @@ def jacobian_plot(
 
     Args:
         jacobians: A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.
-        kernel: The Warp kernel function, decorated with the ``@wp.kernel`` decorator.
+        kernel: The Warp kernel function, decorated with the ``@wp.kernel`` decorator, or a :class:`FunctionMetadata` instance with the kernel/function attributes.
         inputs: List of input variables.
-        outputs: List of output variables.
+        outputs: List of output variables. Deprecated and will be removed in a future Warp version.
        show_plot: If True, displays the plot via ``plt.show()``.
        show_colorbar: If True, displays a colorbar next to the plot (or a colorbar next to every submatrix if ``scale_colors_per_submatrix`` is True).
        scale_colors_per_submatrix: If True, considers the minimum and maximum of each Jacobian submatrix separately for color scaling. Otherwise, uses the global minimum and maximum of all Jacobians.
@@ -343,19 +454,39 @@ def jacobian_plot(
     Returns:
         The created Matplotlib figure.
     """
+
     import matplotlib.pyplot as plt
-    from matplotlib.ticker import FuncFormatter, MaxNLocator
+    from matplotlib.ticker import MaxNLocator
+
+    if isinstance(kernel, wp.Kernel):
+        assert inputs is not None
+        metadata = FunctionMetadata()
+        metadata.update_from_kernel(kernel, inputs)
+    elif isinstance(kernel, FunctionMetadata):
+        metadata = kernel
+    else:
+        raise ValueError("Invalid kernel argument: must be a Warp kernel or a FunctionMetadata object")
+    if outputs is not None:
+        wp.utils.warn(
+            "The `outputs` argument to `jacobian_plot` is no longer needed and will be removed in a future Warp version.",
+            DeprecationWarning,
+            stacklevel=3,
+        )
 
     jacobians = sorted(jacobians.items(), key=lambda x: (x[0][1], x[0][0]))
     jacobians = dict(jacobians)
 
     input_to_ax = {}
     output_to_ax = {}
+    ax_to_input = {}
+    ax_to_output = {}
     for i, j in jacobians.keys():
         if i not in input_to_ax:
             input_to_ax[i] = len(input_to_ax)
+            ax_to_input[input_to_ax[i]] = i
         if j not in output_to_ax:
             output_to_ax[j] = len(output_to_ax)
+            ax_to_output[output_to_ax[j]] = j
 
     num_rows = len(output_to_ax)
     num_cols = len(input_to_ax)
@@ -366,19 +497,19 @@ def jacobian_plot(
     # dimensions of the Jacobians
     width_ratios = []
     height_ratios = []
-    for i, input in enumerate(inputs):
-        if not isinstance(input, wp.array) or not input.requires_grad:
+    for i in range(len(metadata.input_labels)):
+        if not metadata.input_is_array(i):
             continue
-        input_stride =
-        for j in range(len(outputs)):
+        input_stride = metadata.input_strides[i][0]
+        for j in range(len(metadata.output_labels)):
            if (i, j) not in jacobians:
                continue
            jac_wp = jacobians[(i, j)]
            width_ratios.append(jac_wp.shape[1] * input_stride)
            break
 
-    for i, output in enumerate(outputs):
-        if not isinstance(output, wp.array) or not output.requires_grad:
+    for i in range(len(metadata.output_labels)):
+        if not metadata.output_is_array(i):
            continue
        for j in range(len(inputs)):
            if (j, i) not in jacobians:
@@ -403,7 +534,8 @@ def jacobian_plot(
         squeeze=False,
     )
     if title is None:
-        title = f"{kernel.key} kernel Jacobian"
+        key = kernel.key if isinstance(kernel, wp.Kernel) else kernel.get("key", "unknown")
+        title = f"{key} kernel Jacobian"
     fig.suptitle(title)
     fig.canvas.manager.set_window_title(title)
 
@@ -421,66 +553,31 @@ def jacobian_plot(
     has_plot = np.ones((num_rows, num_cols), dtype=bool)
     for i in range(num_rows):
         for j in range(num_cols):
-            if (j, i) not in jacobians:
+            if (ax_to_input[j], ax_to_output[i]) not in jacobians:
                 ax = axs[i, j]
                 ax.axis("off")
                 has_plot[i, j] = False
 
     jac_i = 0
     for (input_i, output_i), jac_wp in jacobians.items():
-        input = inputs[input_i]
-        output = outputs[output_i]
-        if not isinstance(input, wp.array) or not input.requires_grad:
-            continue
-        if not isinstance(output, wp.array) or not output.requires_grad:
-            continue
-
-        input_name = kernel.adj.args[input_i].label
-        output_name = kernel.adj.args[len(inputs) + output_i].label
+        input_name = metadata.input_labels[input_i]
+        output_name = metadata.output_labels[output_i]
 
         ax_i, ax_j = output_to_ax[output_i], input_to_ax[input_i]
         ax = axs[ax_i, ax_j]
         ax.tick_params(which="major", width=1, length=7)
         ax.tick_params(which="minor", width=1, length=4, color="gray")
-        # ax.yaxis.set_minor_formatter('{x:.0f}')
 
-        input_stride = input.dtype._length_
-        output_stride = output.dtype._length_
+        input_stride = metadata.input_dtypes[input_i]._length_
+        # output_stride = metadata.output_dtypes[output_i]._length_
 
         jac = jac_wp.numpy()
         # Jacobian matrix has output stride already multiplied to first dimension
         jac = jac.reshape(jac_wp.shape[0], jac_wp.shape[1] * input_stride)
-
-        ax.
-        ax.
-
-        # ax.set_xticks(np.arange(jac.shape[0]))
-        # stride = jac.shape[1] // jacobians[jac_i].shape[1]
-        # ax.xaxis.set_major_locator(MultipleLocator(input_stride))
-        if input_stride > 1:
-            ax.xaxis.set_major_locator(MaxNLocator(integer=True, nbins=1, steps=[input_stride]))
-            ticks = FuncFormatter(lambda x, pos, input_stride=input_stride: "{0:g}".format(x // input_stride))
-            ax.xaxis.set_major_formatter(ticks)
-        # ax.xaxis.set_major_locator(FixedLocator(np.arange(0, jac.shape[1] + 1, input_stride)))
-        # ax.xaxis.set_major_formatter('{x:.0f}')
-        # ticks = np.arange(jac_wp.shape[1] + 1)
-        # ax.set_xticklabels(ticks)
-
-        # ax.yaxis.set_major_locator(FixedLocator(np.arange(0, jac.shape[0] + 1, output_stride)))
-        # ax.yaxis.set_major_formatter('{x:.0f}')
-        # ax.yaxis.set_major_locator(MultipleLocator(output_stride))
-
-        if output_stride > 1:
-            ax.yaxis.set_major_locator(MaxNLocator(integer=True, nbins=1, steps=[output_stride]))
-            max_y = jac_wp.shape[0]
-            ticks = FuncFormatter(
-                lambda y, pos, max_y=max_y, output_stride=output_stride: "{0:g}".format((max_y - y) // output_stride)
-            )
-            ax.yaxis.set_major_formatter(ticks)
-            # divide by output stride to get the correct number of rows
-            ticks = np.arange(jac_wp.shape[0] // output_stride + 1)
-            # flip y labels to match the order of matrix rows starting from the top
-            # ax.set_yticklabels(ticks[::-1])
+
+        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
+        ax.yaxis.set_major_locator(MaxNLocator(integer=True))
 
         if scale_colors_per_submatrix:
             safe_jac = jac[~np.isnan(jac)]
            vmin = safe_jac.min()
@@ -494,7 +591,7 @@ def jacobian_plot(
             vmin=vmin,
             vmax=vmax,
         )
-        if ax_i == num_rows - 1:
+        if ax_i == num_rows - 1 or not has_plot[ax_i + 1 :, ax_j].any():
            # last plot of this column
            ax.set_xlabel(input_name)
        if ax_j == 0 or not has_plot[ax_i, :ax_j].any():
@@ -609,9 +706,9 @@ def scalarize_array_2d(arr):
 
 
 def jacobian(
-    kernel: wp.Kernel,
-    dim: Tuple[int],
-    inputs: Sequence,
+    function: Union[wp.Kernel, Callable],
+    dim: Tuple[int] = None,
+    inputs: Sequence = None,
     outputs: Sequence = None,
     input_output_mask: List[Tuple[Union[str, int], Union[str, int]]] = None,
     device: wp.context.Devicelike = None,
@@ -619,40 +716,84 @@ def jacobian(
     block_dim=256,
     max_outputs_per_var=-1,
     plot_jacobians=False,
+    metadata: FunctionMetadata = None,
+    kernel: wp.Kernel = None,
 ) -> Dict[Tuple[int, int], wp.array]:
     """
-    Computes the Jacobians of a Warp kernel launch for the provided selection of differentiable inputs to differentiable outputs.
+    Computes the Jacobians of a function or Warp kernel for the provided selection of differentiable inputs to differentiable outputs.
+
+    The input function can be either a Warp kernel (e.g. a function decorated by ``@wp.kernel``) or a regular Python function that accepts arguments (of which some must be Warp arrays) and returns a Warp array or a list of Warp arrays.
 
-
+    In case ``function`` is a Warp kernel, its adjoint kernel is launched with the given inputs and outputs, as well as the provided ``dim``,
     ``max_blocks``, and ``block_dim`` arguments (see :func:`warp.launch` for more details).
 
     Note:
-        The kernel input arguments must precede the output arguments in the kernel code definition.
+        If ``function`` is a Warp kernel, the input arguments must precede the output arguments in the kernel code definition.
 
         Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.
 
-        Kernel arguments of type :ref:`Struct <structs>` are not yet supported.
+        Function arguments of type :ref:`Struct <structs>` are not yet supported.
 
     Args:
-        kernel: The Warp kernel function, decorated with the ``@wp.kernel`` decorator.
-        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints
-        inputs: List of input variables.
-        outputs: List of output variables.
+        function: The Warp kernel function, or a regular Python function that returns a Warp array or a list of Warp arrays.
+        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints. Only required if ``function`` is a Warp kernel.
+        inputs: List of input variables. At least one of the arguments must be a Warp array with ``requires_grad=True``.
+        outputs: List of output variables. Optional if the function is a regular Python function that returns a Warp array or a list of Warp arrays. Only required if ``function`` is a Warp kernel.
         input_output_mask: List of tuples specifying the input-output pairs to compute the Jacobian for. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
-        device: The device to launch on (optional)
-        max_blocks: The maximum number of CUDA thread blocks to use.
-        block_dim: The number of threads per block.
+        device: The device to launch on (optional). Only used if ``function`` is a Warp kernel.
+        max_blocks: The maximum number of CUDA thread blocks to use. Only used if ``function`` is a Warp kernel.
+        block_dim: The number of threads per block. Only used if ``function`` is a Warp kernel.
         max_outputs_per_var: Maximum number of output dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all output dimensions if value <= 0.
         plot_jacobians: If True, visualizes the computed Jacobians in a plot (requires ``matplotlib``).
+        metadata: The metadata of the kernel function, containing the input and output labels, strides, and dtypes. If None or empty, the metadata is inferred from the kernel or function.
+        kernel: Deprecated argument. Use the ``function`` argument instead.
 
     Returns:
         A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.
     """
-    if outputs is None:
-        outputs = []
     if input_output_mask is None:
         input_output_mask = []
-
+    if kernel is not None:
+        wp.utils.warn(
+            "The argument `kernel` to the function `wp.autograd.jacobian` is deprecated in favor of the `function` argument and will be removed in a future Warp version.",
+            DeprecationWarning,
+            stacklevel=3,
+        )
+        function = kernel
+
+    if metadata is None:
+        metadata = FunctionMetadata()
+
+    if isinstance(function, wp.Kernel):
+        if not function.options.get("enable_backward", True):
+            raise ValueError("Kernel must have backward pass enabled to compute Jacobians")
+        if outputs is None or len(outputs) == 0:
+            raise ValueError("A list of output arguments must be provided to compute kernel Jacobians")
+        if device is None:
+            device = infer_device(inputs + outputs)
+        if metadata.is_empty:
+            metadata.update_from_kernel(function, inputs)
+
+        tape = wp.Tape()
+        tape.record_launch(
+            kernel=function,
+            dim=dim,
+            inputs=inputs,
+            outputs=outputs,
+            device=device,
+            max_blocks=max_blocks,
+            block_dim=block_dim,
+        )
+    else:
+        tape = wp.Tape()
+        with tape:
+            outputs = function(*inputs)
+        if isinstance(outputs, wp.array):
+            outputs = [outputs]
+        if metadata.is_empty:
+            metadata.update_from_function(function, inputs, outputs)
+
+    arg_names = metadata.input_labels + metadata.output_labels
 
     def resolve_arg(name, offset: int = 0):
         if isinstance(name, int):
@@ -665,19 +806,8 @@ def jacobian(
     ]
     input_output_mask = set(input_output_mask)
 
-    if device is None:
-        device = infer_device(inputs + outputs)
-
-    tape = wp.Tape()
-    tape.record_launch(
-        kernel=kernel,
-        dim=dim,
-        inputs=inputs,
-        outputs=outputs,
-        device=device,
-        max_blocks=max_blocks,
-        block_dim=block_dim,
-    )
+    zero_grads(inputs)
+    zero_grads(outputs)
 
     jacobians = {}
 
@@ -697,19 +827,21 @@ def jacobian(
         if max_outputs_per_var > 0:
             output_num = min(output_num, max_outputs_per_var)
         for i in range(output_num):
-
+            output.grad.zero_()
             if i > 0:
                 set_element(out_grad, i - 1, 0.0)
             set_element(out_grad, i, 1.0)
             tape.backward()
             jacobian[i].assign(input.grad)
-
+
+            zero_grads(inputs)
+            zero_grads(outputs)
         jacobians[input_i, output_i] = jacobian
 
     if plot_jacobians:
         jacobian_plot(
             jacobians,
-            kernel,
+            metadata,
             inputs,
             outputs,
         )
@@ -718,53 +850,97 @@ def jacobian(
 
 
 def jacobian_fd(
-    kernel: wp.Kernel,
-    dim: Tuple[int],
-    inputs: Sequence,
+    function: Union[wp.Kernel, Callable],
+    dim: Tuple[int] = None,
+    inputs: Sequence = None,
     outputs: Sequence = None,
     input_output_mask: List[Tuple[Union[str, int], Union[str, int]]] = None,
     device: wp.context.Devicelike = None,
     max_blocks=0,
     block_dim=256,
     max_inputs_per_var=-1,
-    eps=1e-4,
+    eps: float = 1e-4,
     plot_jacobians=False,
+    metadata: FunctionMetadata = None,
+    kernel: wp.Kernel = None,
 ) -> Dict[Tuple[int, int], wp.array]:
     """
-    Computes the finite-difference Jacobian of a Warp kernel launch for the provided selection of differentiable inputs to differentiable outputs.
+    Computes the finite-difference Jacobian of a function or Warp kernel for the provided selection of differentiable inputs to differentiable outputs.
     The method uses a central difference scheme to approximate the Jacobian.
 
-    The kernel
-
+    The input function can be either a Warp kernel (e.g. a function decorated by ``@wp.kernel``) or a regular Python function that accepts arguments (of which some must be Warp arrays) and returns a Warp array or a list of Warp arrays.
+
+    The function is launched multiple times in forward-only mode with the given inputs. If ``function`` is a Warp kernel, the provided inputs and outputs,
+    as well as the other parameters ``dim``, ``max_blocks``, and ``block_dim`` are provided to the kernel launch (see :func:`warp.launch`).
 
     Note:
-        The kernel input arguments must precede the output arguments in the kernel code definition.
+        If ``function`` is a Warp kernel, the input arguments must precede the output arguments in the kernel code definition.
 
         Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.
 
-        Kernel arguments of type :ref:`Struct <structs>` are not yet supported.
+        Function arguments of type :ref:`Struct <structs>` are not yet supported.
 
     Args:
-        kernel: The Warp kernel function, decorated with the ``@wp.kernel`` decorator.
-        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints
-        inputs: List of input variables.
-        outputs: List of output variables.
+        function: The Warp kernel function, or a regular Python function that returns a Warp array or a list of Warp arrays.
+        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints. Only required if ``function`` is a Warp kernel.
+        inputs: List of input variables. At least one of the arguments must be a Warp array with ``requires_grad=True``.
+        outputs: List of output variables. Optional if the function is a regular Python function that returns a Warp array or a list of Warp arrays. Only required if ``function`` is a Warp kernel.
         input_output_mask: List of tuples specifying the input-output pairs to compute the Jacobian for. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
-        device: The device to launch on (optional)
-        max_blocks: The maximum number of CUDA thread blocks to use.
-        block_dim: The number of threads per block.
+        device: The device to launch on (optional). Only used if ``function`` is a Warp kernel.
+        max_blocks: The maximum number of CUDA thread blocks to use. Only used if ``function`` is a Warp kernel.
+        block_dim: The number of threads per block. Only used if ``function`` is a Warp kernel.
         max_inputs_per_var: Maximum number of input dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all input dimensions if value <= 0.
         eps: The finite-difference step size.
         plot_jacobians: If True, visualizes the computed Jacobians in a plot (requires ``matplotlib``).
+        metadata: The metadata of the kernel function, containing the input and output labels, strides, and dtypes. If None or empty, the metadata is inferred from the kernel or function.
+        kernel: Deprecated argument. Use the ``function`` argument instead.
 
     Returns:
         A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.
     """
-    if outputs is None:
-        outputs = []
     if input_output_mask is None:
         input_output_mask = []
-
+    if kernel is not None:
+        wp.utils.warn(
+            "The argument `kernel` to the function `wp.autograd.jacobian` is deprecated in favor of the `function` argument and will be removed in a future Warp version.",
+            DeprecationWarning,
+            stacklevel=3,
+        )
+        function = kernel
+
+    if metadata is None:
+        metadata = FunctionMetadata()
+
+    if isinstance(function, wp.Kernel):
+        if not function.options.get("enable_backward", True):
+            raise ValueError("Kernel must have backward pass enabled to compute Jacobians")
+        if outputs is None or len(outputs) == 0:
+            raise ValueError("A list of output arguments must be provided to compute kernel Jacobians")
+        if device is None:
+            device = infer_device(inputs + outputs)
+        if metadata.is_empty:
+            metadata.update_from_kernel(function, inputs)
+
+        tape = wp.Tape()
+        tape.record_launch(
+            kernel=function,
+            dim=dim,
+            inputs=inputs,
+            outputs=outputs,
+            device=device,
+            max_blocks=max_blocks,
+            block_dim=block_dim,
+        )
+    else:
+        tape = wp.Tape()
+        with tape:
+            outputs = function(*inputs)
+        if isinstance(outputs, wp.array):
+            outputs = [outputs]
+        if metadata.is_empty:
+            metadata.update_from_function(function, inputs, outputs)
+
+    arg_names = metadata.input_labels + metadata.output_labels
 
     def resolve_arg(name, offset: int = 0):
         if isinstance(name, int):
@@ -777,11 +953,15 @@ def jacobian_fd(
     ]
     input_output_mask = set(input_output_mask)
 
-    if device is None:
-        device = infer_device(inputs + outputs)
-
     jacobians = {}
 
+    def conditional_clone(obj):
+        if isinstance(obj, wp.array):
+            return wp.clone(obj)
+        return obj
+
+    outputs_copy = [conditional_clone(output) for output in outputs]
+
     for input_i, output_i in itertools.product(range(len(inputs)), range(len(outputs))):
         if len(input_output_mask) > 0 and (input_i, output_i) not in input_output_mask:
             continue
@@ -796,13 +976,20 @@ def jacobian_fd(
 
         left = wp.clone(output)
         right = wp.clone(output)
+        left_copy = wp.clone(output)
+        right_copy = wp.clone(output)
         flat_left = scalarize_array_1d(left)
         flat_right = scalarize_array_1d(right)
 
-
-
+        outputs_until_left = [conditional_clone(output) for output in outputs_copy[:output_i]]
+        outputs_until_right = [conditional_clone(output) for output in outputs_copy[:output_i]]
+        outputs_after_left = [conditional_clone(output) for output in outputs_copy[output_i + 1 :]]
+        outputs_after_right = [conditional_clone(output) for output in outputs_copy[output_i + 1 :]]
+        left_outputs = outputs_until_left + [left] + outputs_after_left
+        right_outputs = outputs_until_right + [right] + outputs_after_right
 
         input_num = flat_input.shape[0]
+        flat_input_copy = wp.clone(flat_input)
         jacobian = wp.empty((flat_left.size, input.size), dtype=input.dtype, device=input.device)
         jacobian.fill_(wp.nan)
 
@@ -812,38 +999,62 @@ def jacobian_fd(
             input_num = min(input_num, max_inputs_per_var)
         for i in range(input_num):
             set_element(flat_input, i, -eps, relative=True)
-            wp.launch(
-                kernel,
-                dim=dim,
-                max_blocks=max_blocks,
-                block_dim=block_dim,
-                inputs=inputs,
-                outputs=left_outputs,
-                device=device,
-            )
+            if isinstance(function, wp.Kernel):
+                wp.launch(
+                    function,
+                    dim=dim,
+                    max_blocks=max_blocks,
+                    block_dim=block_dim,
+                    inputs=inputs,
+                    outputs=left_outputs,
+                    device=device,
+                )
+            else:
+                outputs = function(*inputs)
+                if isinstance(outputs, wp.array):
+                    outputs = [outputs]
+                left.assign(outputs[output_i])
 
             set_element(flat_input, i, 2 * eps, relative=True)
-            wp.launch(
-                kernel,
-                dim=dim,
-                max_blocks=max_blocks,
-                block_dim=block_dim,
-                inputs=inputs,
-                outputs=right_outputs,
-                device=device,
+            if isinstance(function, wp.Kernel):
+                wp.launch(
+                    function,
+                    dim=dim,
+                    max_blocks=max_blocks,
+                    block_dim=block_dim,
+                    inputs=inputs,
+                    outputs=right_outputs,
+                    device=device,
+                )
+            else:
+                outputs = function(*inputs)
+                if isinstance(outputs, wp.array):
+                    outputs = [outputs]
+                right.assign(outputs[output_i])
+
+            # restore input
+            flat_input.assign(flat_input_copy)
+
+            compute_fd(
+                flat_left,
+                flat_right,
+                eps,
+                jacobian_t[i],
             )
 
-
-
-
+            if i < input_num - 1:
+                # reset output buffers
+                left.assign(left_copy)
+                right.assign(right_copy)
+                flat_left = scalarize_array_1d(left)
+                flat_right = scalarize_array_1d(right)
 
-        output.grad.zero_()
         jacobians[input_i, output_i] = jacobian
 
     if plot_jacobians:
         jacobian_plot(
             jacobians,
-            kernel,
+            metadata,
             inputs,
             outputs,
         )
@@ -864,7 +1075,7 @@ def set_element(a: wp.array(dtype=Any), i: int, val: Any, relative: bool = False):
 
 
 @wp.kernel(enable_backward=False)
-def compute_fd_kernel(left: wp.array(dtype=
+def compute_fd_kernel(left: wp.array(dtype=float), right: wp.array(dtype=float), eps: float, fd: wp.array(dtype=float)):
     tid = wp.tid()
     fd[tid] = (right[tid] - left[tid]) / (2.0 * eps)
 
@@ -883,7 +1094,10 @@ def compute_error_kernel(
     tid = wp.tid()
     ad = jacobian_ad[tid]
     fd = jacobian_fd[tid]
-
+    denom = ad
+    if abs(ad) < 1e-8:
+        denom = (type(ad))(1e-8)
+    relative_error[tid] = (ad - fd) / denom
     absolute_error[tid] = wp.abs(ad - fd)
 
 
@@ -909,3 +1123,12 @@ def print_table(headers, cells):
         for cell, col_width in zip(cell_row, col_widths):
             print(f"{cell:{col_width}}", end=" | ")
         print()
+
+
+def zero_grads(arrays: list):
+    """
+    Zeros the gradients of all Warp arrays in the given list.
+    """
+    for array in arrays:
+        if isinstance(array, wp.array) and array.requires_grad:
+            array.grad.zero_()