warp-lang 1.5.1-py3-none-manylinux2014_x86_64.whl → 1.6.0-py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of warp-lang might be problematic.

Files changed (123)
  1. warp/__init__.py +5 -0
  2. warp/autograd.py +414 -191
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +40 -12
  6. warp/build_dll.py +13 -6
  7. warp/builtins.py +1076 -480
  8. warp/codegen.py +240 -119
  9. warp/config.py +1 -1
  10. warp/context.py +298 -84
  11. warp/examples/assets/square_cloth.usd +0 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +27 -18
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
  14. warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
  15. warp/examples/core/example_torch.py +18 -34
  16. warp/examples/fem/example_apic_fluid.py +1 -0
  17. warp/examples/fem/example_mixed_elasticity.py +1 -1
  18. warp/examples/optim/example_bounce.py +1 -1
  19. warp/examples/optim/example_cloth_throw.py +1 -1
  20. warp/examples/optim/example_diffray.py +4 -15
  21. warp/examples/optim/example_drone.py +1 -1
  22. warp/examples/optim/example_softbody_properties.py +392 -0
  23. warp/examples/optim/example_trajectory.py +1 -3
  24. warp/examples/optim/example_walker.py +5 -0
  25. warp/examples/sim/example_cartpole.py +0 -2
  26. warp/examples/sim/example_cloth_self_contact.py +260 -0
  27. warp/examples/sim/example_granular_collision_sdf.py +4 -5
  28. warp/examples/sim/example_jacobian_ik.py +0 -2
  29. warp/examples/sim/example_quadruped.py +5 -2
  30. warp/examples/tile/example_tile_cholesky.py +79 -0
  31. warp/examples/tile/example_tile_convolution.py +2 -2
  32. warp/examples/tile/example_tile_fft.py +2 -2
  33. warp/examples/tile/example_tile_filtering.py +3 -3
  34. warp/examples/tile/example_tile_matmul.py +4 -4
  35. warp/examples/tile/example_tile_mlp.py +12 -12
  36. warp/examples/tile/example_tile_nbody.py +180 -0
  37. warp/examples/tile/example_tile_walker.py +319 -0
  38. warp/math.py +147 -0
  39. warp/native/array.h +12 -0
  40. warp/native/builtin.h +0 -1
  41. warp/native/bvh.cpp +149 -70
  42. warp/native/bvh.cu +287 -68
  43. warp/native/bvh.h +195 -85
  44. warp/native/clang/clang.cpp +5 -1
  45. warp/native/cuda_util.cpp +35 -0
  46. warp/native/cuda_util.h +5 -0
  47. warp/native/exports.h +40 -40
  48. warp/native/intersect.h +17 -0
  49. warp/native/mat.h +41 -0
  50. warp/native/mathdx.cpp +19 -0
  51. warp/native/mesh.cpp +25 -8
  52. warp/native/mesh.cu +153 -101
  53. warp/native/mesh.h +482 -403
  54. warp/native/quat.h +40 -0
  55. warp/native/solid_angle.h +7 -0
  56. warp/native/sort.cpp +85 -0
  57. warp/native/sort.cu +34 -0
  58. warp/native/sort.h +3 -1
  59. warp/native/spatial.h +11 -0
  60. warp/native/tile.h +1185 -664
  61. warp/native/tile_reduce.h +8 -6
  62. warp/native/vec.h +41 -0
  63. warp/native/warp.cpp +8 -1
  64. warp/native/warp.cu +263 -40
  65. warp/native/warp.h +19 -5
  66. warp/optim/linear.py +22 -4
  67. warp/render/render_opengl.py +124 -59
  68. warp/sim/__init__.py +6 -1
  69. warp/sim/collide.py +270 -26
  70. warp/sim/integrator_euler.py +25 -7
  71. warp/sim/integrator_featherstone.py +154 -35
  72. warp/sim/integrator_vbd.py +842 -40
  73. warp/sim/model.py +111 -53
  74. warp/stubs.py +248 -115
  75. warp/tape.py +28 -30
  76. warp/tests/aux_test_module_unload.py +15 -0
  77. warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
  78. warp/tests/test_array.py +74 -0
  79. warp/tests/test_assert.py +242 -0
  80. warp/tests/test_codegen.py +14 -61
  81. warp/tests/test_collision.py +2 -2
  82. warp/tests/test_examples.py +9 -0
  83. warp/tests/test_grad_debug.py +87 -2
  84. warp/tests/test_hash_grid.py +1 -1
  85. warp/tests/test_ipc.py +116 -0
  86. warp/tests/test_mat.py +138 -167
  87. warp/tests/test_math.py +47 -1
  88. warp/tests/test_matmul.py +11 -7
  89. warp/tests/test_matmul_lite.py +4 -4
  90. warp/tests/test_mesh.py +84 -60
  91. warp/tests/test_mesh_query_aabb.py +165 -0
  92. warp/tests/test_mesh_query_point.py +328 -286
  93. warp/tests/test_mesh_query_ray.py +134 -121
  94. warp/tests/test_mlp.py +2 -2
  95. warp/tests/test_operators.py +43 -0
  96. warp/tests/test_overwrite.py +2 -2
  97. warp/tests/test_quat.py +77 -0
  98. warp/tests/test_reload.py +29 -0
  99. warp/tests/test_sim_grad_bounce_linear.py +204 -0
  100. warp/tests/test_static.py +16 -0
  101. warp/tests/test_tape.py +25 -0
  102. warp/tests/test_tile.py +134 -191
  103. warp/tests/test_tile_load.py +356 -0
  104. warp/tests/test_tile_mathdx.py +61 -8
  105. warp/tests/test_tile_mlp.py +17 -17
  106. warp/tests/test_tile_reduce.py +24 -18
  107. warp/tests/test_tile_shared_memory.py +66 -17
  108. warp/tests/test_tile_view.py +165 -0
  109. warp/tests/test_torch.py +35 -0
  110. warp/tests/test_utils.py +36 -24
  111. warp/tests/test_vec.py +110 -0
  112. warp/tests/unittest_suites.py +29 -4
  113. warp/tests/unittest_utils.py +30 -11
  114. warp/thirdparty/unittest_parallel.py +2 -2
  115. warp/types.py +409 -99
  116. warp/utils.py +9 -5
  117. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/METADATA +68 -44
  118. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/RECORD +121 -110
  119. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/WHEEL +1 -1
  120. warp/examples/benchmarks/benchmark_tile.py +0 -179
  121. warp/native/tile_gemm.h +0 -341
  122. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/LICENSE.md +0 -0
  123. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/top_level.txt +0 -0
warp/autograd.py CHANGED
@@ -5,8 +5,9 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
+import inspect
 import itertools
-from typing import Any, Dict, List, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
 
 import numpy as np
 
@@ -22,23 +23,23 @@ __all__ = [
 
 
 def gradcheck(
-    function: wp.Kernel,
-    dim: Tuple[int],
-    inputs: Sequence,
-    outputs: Sequence,
+    function: Union[wp.Kernel, Callable],
+    dim: Tuple[int] = None,
+    inputs: Sequence = None,
+    outputs: Sequence = None,
     *,
-    eps=1e-4,
-    atol=1e-3,
-    rtol=1e-2,
-    raise_exception=True,
+    eps: float = 1e-4,
+    atol: float = 1e-3,
+    rtol: float = 1e-2,
+    raise_exception: bool = True,
     input_output_mask: List[Tuple[Union[str, int], Union[str, int]]] = None,
     device: wp.context.Devicelike = None,
-    max_blocks=0,
-    block_dim=256,
-    max_inputs_per_var=-1,
-    max_outputs_per_var=-1,
-    plot_relative_error=False,
-    plot_absolute_error=False,
+    max_blocks: int = 0,
+    block_dim: int = 256,
+    max_inputs_per_var: int = -1,
+    max_outputs_per_var: int = -1,
+    plot_relative_error: bool = False,
+    plot_absolute_error: bool = False,
     show_summary: bool = True,
 ) -> bool:
     """
@@ -56,10 +57,10 @@ def gradcheck(
     Structs arguments are not yet supported by this function to compute Jacobians.
 
     Args:
-        function: The Warp kernel function, decorated with the ``@wp.kernel`` decorator.
-        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints.
+        function: The Warp kernel function, decorated with the ``@wp.kernel`` decorator, or any function that involves Warp kernel launches.
+        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints. Only required if the function is a Warp kernel.
         inputs: List of input variables.
-        outputs: List of output variables.
+        outputs: List of output variables. Only required if the function is a Warp kernel.
         eps: The finite-difference step size.
         atol: The absolute tolerance for the gradient check.
         rtol: The relative tolerance for the gradient check.
@@ -78,9 +79,12 @@ def gradcheck(
         True if the gradient check passes, False otherwise.
     """
 
-    assert isinstance(function, wp.Kernel), "The function argument must be a Warp kernel"
+    if inputs is None:
+        raise ValueError("The inputs argument must be provided")
 
-    jacs_fd = jacobian_fd(
+    metadata = FunctionMetadata()
+
+    jacs_ad = jacobian(
         function,
         dim=dim,
         inputs=inputs,
@@ -89,12 +93,11 @@ def gradcheck(
         device=device,
         max_blocks=max_blocks,
         block_dim=block_dim,
-        max_inputs_per_var=max_inputs_per_var,
-        eps=eps,
+        max_outputs_per_var=max_outputs_per_var,
         plot_jacobians=False,
+        metadata=metadata,
     )
-
-    jacs_ad = jacobian(
+    jacs_fd = jacobian_fd(
         function,
         dim=dim,
         inputs=inputs,
@@ -103,8 +106,10 @@ def gradcheck(
         device=device,
         max_blocks=max_blocks,
         block_dim=block_dim,
-        max_outputs_per_var=max_outputs_per_var,
+        max_inputs_per_var=max_inputs_per_var,
+        eps=eps,
         plot_jacobians=False,
+        metadata=metadata,
     )
 
     relative_error_jacs = {}
@@ -112,7 +117,7 @@ def gradcheck(
 
     if show_summary:
         summary = []
-        summary_header = ["Input", "Output", "Max Abs Error", "Max Rel Error", "Pass"]
+        summary_header = ["Input", "Output", "Max Abs Error", "AD at MAE", "FD at MAE", "Max Rel Error", "Pass"]
 
         class FontColors:
             OKGREEN = "\033[92m"
@@ -121,6 +126,8 @@
             ENDC = "\033[0m"
 
     success = True
+    any_grad_mismatch = False
+    any_grad_nan = False
     for (input_i, output_i), jac_fd in jacs_fd.items():
        jac_ad = jacs_ad[input_i, output_i]
        if plot_relative_error or plot_absolute_error:
@@ -147,28 +154,15 @@ def gradcheck(
             cut_jac_fd = cut_jac_fd[:, :max_inputs_per_var]
             cut_jac_ad = cut_jac_ad[:, :max_inputs_per_var]
         grad_matches = np.allclose(cut_jac_ad, cut_jac_fd, atol=atol, rtol=rtol)
+        any_grad_mismatch = any_grad_mismatch or not grad_matches
         success = success and grad_matches
-        if not grad_matches:
-            if raise_exception:
-                raise ValueError(
-                    f"Gradient check failed for kernel {function.key}, input {input_i}, output {output_i}: "
-                    f"finite difference and autodiff gradients do not match"
-                )
-            elif not show_summary:
-                return False
         isnan = np.any(np.isnan(cut_jac_ad))
+        any_grad_nan = any_grad_nan or isnan
         success = success and not isnan
-        if isnan:
-            if raise_exception:
-                raise ValueError(
-                    f"Gradient check failed for kernel {function.key}, input {input_i}, output {output_i}: "
-                    f"gradient contains NaN values"
-                )
-            elif not show_summary:
-                return False
 
         if show_summary:
             max_abs_error = np.abs(cut_jac_ad - cut_jac_fd).max()
+            arg_max_abs_error = np.unravel_index(np.argmax(np.abs(cut_jac_ad - cut_jac_fd)), cut_jac_ad.shape)
             max_rel_error = np.abs((cut_jac_ad - cut_jac_fd) / (cut_jac_fd + 1e-8)).max()
             if isnan:
                 pass_str = FontColors.FAIL + "NaN" + FontColors.ENDC
@@ -176,33 +170,55 @@ def gradcheck(
                 pass_str = FontColors.OKGREEN + "PASS" + FontColors.ENDC
             else:
                 pass_str = FontColors.FAIL + "FAIL" + FontColors.ENDC
-            input_name = function.adj.args[input_i].label
-            output_name = function.adj.args[len(inputs) + output_i].label
-            summary.append([input_name, output_name, f"{max_abs_error:.7e}", f"{max_rel_error:.7e}", pass_str])
+            input_name = metadata.input_labels[input_i]
+            output_name = metadata.output_labels[output_i]
+            summary.append(
+                [
+                    input_name,
+                    output_name,
+                    f"{max_abs_error:.3e} at {tuple(int(i) for i in arg_max_abs_error)}",
+                    f"{cut_jac_ad[arg_max_abs_error]:.3e}",
+                    f"{cut_jac_fd[arg_max_abs_error]:.3e}",
+                    f"{max_rel_error:.3e}",
+                    pass_str,
+                ]
+            )
 
     if show_summary:
         print_table(summary_header, summary)
         if not success:
-            print(FontColors.FAIL + f"Gradient check for kernel {function.key} failed" + FontColors.ENDC)
+            print(FontColors.FAIL + f"Gradient check for kernel {metadata.key} failed" + FontColors.ENDC)
         else:
-            print(FontColors.OKGREEN + f"Gradient check for kernel {function.key} passed" + FontColors.ENDC)
+            print(FontColors.OKGREEN + f"Gradient check for kernel {metadata.key} passed" + FontColors.ENDC)
     if plot_relative_error:
         jacobian_plot(
             relative_error_jacs,
-            function,
+            metadata,
             inputs,
             outputs,
-            title=f"{function.key} kernel Jacobian relative error",
+            title=f"{metadata.key} kernel Jacobian relative error",
         )
     if plot_absolute_error:
         jacobian_plot(
             absolute_error_jacs,
-            function,
+            metadata,
             inputs,
             outputs,
-            title=f"{function.key} kernel Jacobian absolute error",
+            title=f"{metadata.key} kernel Jacobian absolute error",
         )
 
+    if raise_exception:
+        if any_grad_mismatch:
+            raise ValueError(
+                f"Gradient check failed for kernel {metadata.key}, input {input_i}, output {output_i}: "
+                f"finite difference and autodiff gradients do not match"
+            )
+        if any_grad_nan:
+            raise ValueError(
+                f"Gradient check failed for kernel {metadata.key}, input {input_i}, output {output_i}: "
+                f"gradient contains NaN values"
+            )
+
     return success
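
The reworked `gradcheck` no longer asserts that `function` is a kernel: it accepts either a kernel (as before) or any Python function that launches Warp kernels internally, and it defers `raise_exception` until all input/output pairs have been checked. A minimal sketch of both call forms, assuming the 1.6.0 API shown in this diff (the kernel, function, and array names are illustrative, not part of the package):

    import warp as wp
    import warp.autograd

    @wp.kernel
    def square_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
        tid = wp.tid()
        y[tid] = x[tid] * x[tid]

    x = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True)
    y = wp.zeros(3, dtype=float, requires_grad=True)

    # Kernel form: dim and outputs are still required.
    wp.autograd.gradcheck(square_kernel, dim=3, inputs=[x], outputs=[y], raise_exception=False)

    # Function form (new in 1.6.0): outputs are taken from the return value.
    def run(x):
        out = wp.zeros_like(x)  # zeros_like inherits requires_grad from x
        wp.launch(square_kernel, dim=x.size, inputs=[x], outputs=[out])
        return out

    wp.autograd.gradcheck(run, inputs=[x], raise_exception=False)
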
 
 
@@ -221,6 +237,8 @@ def gradcheck_tape(
     plot_relative_error=False,
     plot_absolute_error=False,
     show_summary: bool = True,
+    reverse_launches: bool = False,
+    skip_to_launch_index: int = 0,
 ) -> bool:
     """
     Checks whether the autodiff gradients for kernels recorded on the Warp tape match finite differences.
@@ -247,6 +265,7 @@ def gradcheck_tape(
         plot_relative_error: If True, visualizes the relative error of the Jacobians in a plot (requires ``matplotlib``).
         plot_absolute_error: If True, visualizes the absolute error of the Jacobians in a plot (requires ``matplotlib``).
         show_summary: If True, prints a summary table of the gradient check results.
+        reverse_launches: If True, reverses the order of the kernel launches on the tape to check.
 
     Returns:
         True if the gradient check passes for all kernels on the tape, False otherwise.
@@ -263,7 +282,12 @@ def gradcheck_tape(
         whitelist_kernels = set(whitelist_kernels)
 
     overall_success = True
-    for launch in tape.launches:
+    launches = reversed(tape.launches) if reverse_launches else tape.launches
+    for i, launch in enumerate(launches):
+        if i < skip_to_launch_index:
+            continue
+        if not isinstance(launch, tuple) and not isinstance(launch, list):
+            continue
         if not isinstance(launch[0], wp.Kernel):
             continue
         kernel, dim, max_blocks, inputs, outputs, device, block_dim = launch[:7]
@@ -271,6 +295,9 @@ def gradcheck_tape(
             continue
         if kernel.key in blacklist_kernels:
             continue
+        if not kernel.options.get("enable_backward", True):
+            continue
+
         input_output_mask = input_output_masks.get(kernel.key)
         success = gradcheck(
             kernel,
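
`gradcheck_tape` gains `reverse_launches` and `skip_to_launch_index` to narrow which recorded launches get checked, and now skips non-launch tape entries and kernels compiled with `enable_backward=False`. A short sketch, reusing the illustrative kernel and arrays from the example above:

    tape = wp.Tape()
    with tape:
        wp.launch(square_kernel, dim=3, inputs=[x], outputs=[y])

    # Check the recorded launches in reverse order, starting at index 0.
    wp.autograd.gradcheck_tape(tape, reverse_launches=True, skip_to_launch_index=0)
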
@@ -312,11 +339,95 @@ def infer_device(xs: list):
     return wp.get_preferred_device()
 
 
+class FunctionMetadata:
+    """
+    Metadata holder for kernel functions or functions with Warp arrays as inputs/outputs.
+    """
+
+    def __init__(
+        self,
+        key: str = None,
+        input_labels: List[str] = None,
+        output_labels: List[str] = None,
+        input_strides: List[tuple] = None,
+        output_strides: List[tuple] = None,
+        input_dtypes: list = None,
+        output_dtypes: list = None,
+    ):
+        self.key = key
+        self.input_labels = input_labels
+        self.output_labels = output_labels
+        self.input_strides = input_strides
+        self.output_strides = output_strides
+        self.input_dtypes = input_dtypes
+        self.output_dtypes = output_dtypes
+
+    @property
+    def is_empty(self):
+        return self.key is None
+
+    def input_is_array(self, i: int):
+        return self.input_strides[i] is not None
+
+    def output_is_array(self, i: int):
+        return self.output_strides[i] is not None
+
+    def update_from_kernel(self, kernel: wp.Kernel, inputs: Sequence):
+        self.key = kernel.key
+        self.input_labels = [arg.label for arg in kernel.adj.args[: len(inputs)]]
+        self.output_labels = [arg.label for arg in kernel.adj.args[len(inputs) :]]
+        self.input_strides = []
+        self.output_strides = []
+        self.input_dtypes = []
+        self.output_dtypes = []
+        for arg in kernel.adj.args[: len(inputs)]:
+            if arg.type is wp.array:
+                self.input_strides.append(arg.type.strides)
+                self.input_dtypes.append(arg.type.dtype)
+            else:
+                self.input_strides.append(None)
+                self.input_dtypes.append(None)
+        for arg in kernel.adj.args[len(inputs) :]:
+            if arg.type is wp.array:
+                self.output_strides.append(arg.type.strides)
+                self.output_dtypes.append(arg.type.dtype)
+            else:
+                self.output_strides.append(None)
+                self.output_dtypes.append(None)
+
+    def update_from_function(self, function: Callable, inputs: Sequence, outputs: Sequence = None):
+        self.key = function.__name__
+        self.input_labels = list(inspect.signature(function).parameters.keys())
+        if outputs is None:
+            outputs = function(*inputs)
+            if isinstance(outputs, wp.array):
+                outputs = [outputs]
+        self.output_labels = [f"output_{i}" for i in range(len(outputs))]
+        self.input_strides = []
+        self.output_strides = []
+        self.input_dtypes = []
+        self.output_dtypes = []
+        for input in inputs:
+            if isinstance(input, wp.array):
+                self.input_strides.append(input.strides)
+                self.input_dtypes.append(input.dtype)
+            else:
+                self.input_strides.append(None)
+                self.input_dtypes.append(None)
+        for output in outputs:
+            if isinstance(output, wp.array):
+                self.output_strides.append(output.strides)
+                self.output_dtypes.append(output.dtype)
+            else:
+                self.output_strides.append(None)
+                self.output_dtypes.append(None)
+
+
 def jacobian_plot(
     jacobians: Dict[Tuple[int, int], wp.array],
-    kernel: wp.Kernel,
-    inputs: Sequence,
-    outputs: Sequence,
+    kernel: Union[FunctionMetadata, wp.Kernel],
+    inputs: Sequence = None,
+    outputs: Sequence = None,
     show_plot=True,
     show_colorbar=True,
     scale_colors_per_submatrix=False,
@@ -330,9 +441,9 @@ def jacobian_plot(
 
     Args:
         jacobians: A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.
-        kernel: The Warp kernel function, decorated with the ``@wp.kernel`` decorator.
+        kernel: The Warp kernel function, decorated with the ``@wp.kernel`` decorator, or a :class:`FunctionMetadata` instance with the kernel/function attributes.
         inputs: List of input variables.
-        outputs: List of output variables.
+        outputs: List of output variables. Deprecated and will be removed in a future Warp version.
         show_plot: If True, displays the plot via ``plt.show()``.
         show_colorbar: If True, displays a colorbar next to the plot (or a colorbar next to every submatrix if ).
         scale_colors_per_submatrix: If True, considers the minimum and maximum of each Jacobian submatrix separately for color scaling. Otherwise, uses the global minimum and maximum of all Jacobians.
@@ -343,19 +454,39 @@ def jacobian_plot(
     Returns:
         The created Matplotlib figure.
     """
+
     import matplotlib.pyplot as plt
-    from matplotlib.ticker import FuncFormatter, MaxNLocator, MultipleLocator
+    from matplotlib.ticker import MaxNLocator
+
+    if isinstance(kernel, wp.Kernel):
+        assert inputs is not None
+        metadata = FunctionMetadata()
+        metadata.update_from_kernel(kernel, inputs)
+    elif isinstance(kernel, FunctionMetadata):
+        metadata = kernel
+    else:
+        raise ValueError("Invalid kernel argument: must be a Warp kernel or a FunctionMetadata object")
+    if outputs is not None:
+        wp.utils.warn(
+            "The `outputs` argument to `jacobian_plot` is no longer needed and will be removed in a future Warp version.",
+            DeprecationWarning,
+            stacklevel=3,
+        )
 
     jacobians = sorted(jacobians.items(), key=lambda x: (x[0][1], x[0][0]))
     jacobians = dict(jacobians)
 
     input_to_ax = {}
     output_to_ax = {}
+    ax_to_input = {}
+    ax_to_output = {}
     for i, j in jacobians.keys():
         if i not in input_to_ax:
             input_to_ax[i] = len(input_to_ax)
+            ax_to_input[input_to_ax[i]] = i
         if j not in output_to_ax:
             output_to_ax[j] = len(output_to_ax)
+            ax_to_output[output_to_ax[j]] = j
 
     num_rows = len(output_to_ax)
     num_cols = len(input_to_ax)
@@ -366,19 +497,19 @@ def jacobian_plot(
     # dimensions of the Jacobians
     width_ratios = []
     height_ratios = []
-    for i, input in enumerate(inputs):
-        if not isinstance(input, wp.array) or not input.requires_grad:
+    for i in range(len(metadata.input_labels)):
+        if not metadata.input_is_array(i):
             continue
-        input_stride = input.dtype._length_
-        for j in range(len(outputs)):
+        input_stride = metadata.input_strides[i][0]
+        for j in range(len(metadata.output_labels)):
             if (i, j) not in jacobians:
                 continue
             jac_wp = jacobians[(i, j)]
             width_ratios.append(jac_wp.shape[1] * input_stride)
             break
 
-    for i, output in enumerate(outputs):
-        if not isinstance(output, wp.array) or not output.requires_grad:
+    for i in range(len(metadata.output_labels)):
+        if not metadata.output_is_array(i):
             continue
         for j in range(len(inputs)):
             if (j, i) not in jacobians:
@@ -403,7 +534,8 @@ def jacobian_plot(
         squeeze=False,
     )
     if title is None:
-        title = f"{kernel.key} kernel Jacobian"
+        key = kernel.key if isinstance(kernel, wp.Kernel) else kernel.get("key", "unknown")
+        title = f"{key} kernel Jacobian"
     fig.suptitle(title)
     fig.canvas.manager.set_window_title(title)
 
@@ -421,66 +553,31 @@ def jacobian_plot(
     has_plot = np.ones((num_rows, num_cols), dtype=bool)
     for i in range(num_rows):
         for j in range(num_cols):
-            if (j, i) not in jacobians:
+            if (ax_to_input[j], ax_to_output[i]) not in jacobians:
                 ax = axs[i, j]
                 ax.axis("off")
                 has_plot[i, j] = False
 
     jac_i = 0
     for (input_i, output_i), jac_wp in jacobians.items():
-        input = inputs[input_i]
-        output = outputs[output_i]
-        if not isinstance(input, wp.array) or not input.requires_grad:
-            continue
-        if not isinstance(output, wp.array) or not output.requires_grad:
-            continue
-
-        input_name = kernel.adj.args[input_i].label
-        output_name = kernel.adj.args[len(inputs) + output_i].label
+        input_name = metadata.input_labels[input_i]
+        output_name = metadata.output_labels[output_i]
 
         ax_i, ax_j = output_to_ax[output_i], input_to_ax[input_i]
         ax = axs[ax_i, ax_j]
         ax.tick_params(which="major", width=1, length=7)
         ax.tick_params(which="minor", width=1, length=4, color="gray")
-        # ax.yaxis.set_minor_formatter('{x:.0f}')
 
-        input_stride = input.dtype._length_
-        output_stride = output.dtype._length_
+        input_stride = metadata.input_dtypes[input_i]._length_
+        # output_stride = metadata.output_dtypes[output_i]._length_
 
         jac = jac_wp.numpy()
         # Jacobian matrix has output stride already multiplied to first dimension
         jac = jac.reshape(jac_wp.shape[0], jac_wp.shape[1] * input_stride)
-        ax.xaxis.set_minor_formatter("")
-        ax.yaxis.set_minor_formatter("")
-        ax.xaxis.set_minor_locator(MultipleLocator(1))
-        ax.yaxis.set_minor_locator(MultipleLocator(1))
-        # ax.set_xticks(np.arange(jac.shape[0]))
-        # stride = jac.shape[1] // jacobians[jac_i].shape[1]
-        # ax.xaxis.set_major_locator(MultipleLocator(input_stride))
-        if input_stride > 1:
-            ax.xaxis.set_major_locator(MaxNLocator(integer=True, nbins=1, steps=[input_stride]))
-            ticks = FuncFormatter(lambda x, pos, input_stride=input_stride: "{0:g}".format(x // input_stride))
-            ax.xaxis.set_major_formatter(ticks)
-        # ax.xaxis.set_major_locator(FixedLocator(np.arange(0, jac.shape[1] + 1, input_stride)))
-        # ax.xaxis.set_major_formatter('{x:.0f}')
-        # ticks = np.arange(jac_wp.shape[1] + 1)
-        # ax.set_xticklabels(ticks)
-
-        # ax.yaxis.set_major_locator(FixedLocator(np.arange(0, jac.shape[0] + 1, output_stride)))
-        # ax.yaxis.set_major_formatter('{x:.0f}')
-        # ax.yaxis.set_major_locator(MultipleLocator(output_stride))
-
-        if output_stride > 1:
-            ax.yaxis.set_major_locator(MaxNLocator(integer=True, nbins=1, steps=[output_stride]))
-            max_y = jac_wp.shape[0]
-            ticks = FuncFormatter(
-                lambda y, pos, max_y=max_y, output_stride=output_stride: "{0:g}".format((max_y - y) // output_stride)
-            )
-            ax.yaxis.set_major_formatter(ticks)
-            # divide by output stride to get the correct number of rows
-            ticks = np.arange(jac_wp.shape[0] // output_stride + 1)
-            # flip y labels to match the order of matrix rows starting from the top
-            # ax.set_yticklabels(ticks[::-1])
+
+        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
+        ax.yaxis.set_major_locator(MaxNLocator(integer=True))
+
         if scale_colors_per_submatrix:
             safe_jac = jac[~np.isnan(jac)]
             vmin = safe_jac.min()
@@ -494,7 +591,7 @@ def jacobian_plot(
             vmin=vmin,
             vmax=vmax,
         )
-        if ax_i == len(outputs) - 1 or not has_plot[ax_i + 1 :, ax_j].any():
+        if ax_i == num_rows - 1 or not has_plot[ax_i + 1 :, ax_j].any():
             # last plot of this column
             ax.set_xlabel(input_name)
         if ax_j == 0 or not has_plot[ax_i, :ax_j].any():
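
`jacobian_plot` now resolves its `kernel` argument into a `FunctionMetadata` record, so it can be called with either a kernel or the metadata of a plain function, and passing `outputs` only triggers a deprecation warning. A sketch of the kernel form (requires matplotlib; names as in the earlier illustrative example):

    jacs = wp.autograd.jacobian(square_kernel, dim=3, inputs=[x], outputs=[y])
    wp.autograd.jacobian_plot(jacs, square_kernel, inputs=[x])  # no outputs argument needed
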
@@ -609,9 +706,9 @@ def scalarize_array_2d(arr):
 
 
 def jacobian(
-    kernel: wp.Kernel,
-    dim: Tuple[int],
-    inputs: Sequence,
+    function: Union[wp.Kernel, Callable],
+    dim: Tuple[int] = None,
+    inputs: Sequence = None,
     outputs: Sequence = None,
     input_output_mask: List[Tuple[Union[str, int], Union[str, int]]] = None,
     device: wp.context.Devicelike = None,
@@ -619,40 +716,84 @@ def jacobian(
     block_dim=256,
     max_outputs_per_var=-1,
     plot_jacobians=False,
+    metadata: FunctionMetadata = None,
+    kernel: wp.Kernel = None,
 ) -> Dict[Tuple[int, int], wp.array]:
     """
-    Computes the Jacobians of a Warp kernel launch for the provided selection of differentiable inputs to differentiable outputs.
+    Computes the Jacobians of a function or Warp kernel for the provided selection of differentiable inputs to differentiable outputs.
+
+    The input function can be either a Warp kernel (e.g. a function decorated by ``@wp.kernel``) or a regular Python function that accepts arguments (of which some must be Warp arrays) and returns a Warp array or a list of Warp arrays.
 
-    The kernel adjoint function is launched with the given inputs and outputs, as well as the provided ``dim``,
+    In case ``function`` is a Warp kernel, its adjoint kernel is launched with the given inputs and outputs, as well as the provided ``dim``,
     ``max_blocks``, and ``block_dim`` arguments (see :func:`warp.launch` for more details).
 
     Note:
-        This function only supports Warp kernels whose input arguments precede the output arguments.
+        If ``function`` is a Warp kernel, the input arguments must precede the output arguments in the kernel code definition.
 
         Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.
 
-        Structs arguments are not yet supported by this function to compute Jacobians.
+        Function arguments of type :ref:`Struct <structs>` are not yet supported.
 
     Args:
-        kernel: The Warp kernel function, decorated with the ``@wp.kernel`` decorator
-        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints
-        inputs: List of input variables.
-        outputs: List of output variables. If None, the outputs are inferred from the kernel argument flags.
+        function: The Warp kernel function, or a regular Python function that returns a Warp array or a list of Warp arrays.
+        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints. Only required if ``function`` is a Warp kernel.
+        inputs: List of input variables. At least one of the arguments must be a Warp array with ``requires_grad=True``.
+        outputs: List of output variables. Optional if the function is a regular Python function that returns a Warp array or a list of Warp arrays. Only required if ``function`` is a Warp kernel.
         input_output_mask: List of tuples specifying the input-output pairs to compute the Jacobian for. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
-        device: The device to launch on (optional)
-        max_blocks: The maximum number of CUDA thread blocks to use.
-        block_dim: The number of threads per block.
+        device: The device to launch on (optional). Only used if ``function`` is a Warp kernel.
+        max_blocks: The maximum number of CUDA thread blocks to use. Only used if ``function`` is a Warp kernel.
+        block_dim: The number of threads per block. Only used if ``function`` is a Warp kernel.
         max_outputs_per_var: Maximum number of output dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all output dimensions if value <= 0.
         plot_jacobians: If True, visualizes the computed Jacobians in a plot (requires ``matplotlib``).
+        metadata: The metadata of the kernel function, containing the input and output labels, strides, and dtypes. If None or empty, the metadata is inferred from the kernel or function.
+        kernel: Deprecated argument. Use the ``function`` argument instead.
 
     Returns:
         A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.
     """
-    if outputs is None:
-        outputs = []
     if input_output_mask is None:
         input_output_mask = []
-    arg_names = [arg.label for arg in kernel.adj.args]
+    if kernel is not None:
+        wp.utils.warn(
+            "The argument `kernel` to the function `wp.autograd.jacobian` is deprecated in favor of the `function` argument and will be removed in a future Warp version.",
+            DeprecationWarning,
+            stacklevel=3,
+        )
+        function = kernel
+
+    if metadata is None:
+        metadata = FunctionMetadata()
+
+    if isinstance(function, wp.Kernel):
+        if not function.options.get("enable_backward", True):
+            raise ValueError("Kernel must have backward pass enabled to compute Jacobians")
+        if outputs is None or len(outputs) == 0:
+            raise ValueError("A list of output arguments must be provided to compute kernel Jacobians")
+        if device is None:
+            device = infer_device(inputs + outputs)
+        if metadata.is_empty:
+            metadata.update_from_kernel(function, inputs)
+
+        tape = wp.Tape()
+        tape.record_launch(
+            kernel=function,
+            dim=dim,
+            inputs=inputs,
+            outputs=outputs,
+            device=device,
+            max_blocks=max_blocks,
+            block_dim=block_dim,
+        )
+    else:
+        tape = wp.Tape()
+        with tape:
+            outputs = function(*inputs)
+        if isinstance(outputs, wp.array):
+            outputs = [outputs]
+        if metadata.is_empty:
+            metadata.update_from_function(function, inputs, outputs)
+
+    arg_names = metadata.input_labels + metadata.output_labels
 
     def resolve_arg(name, offset: int = 0):
         if isinstance(name, int):
@@ -665,19 +806,8 @@ def jacobian(
     ]
     input_output_mask = set(input_output_mask)
 
-    if device is None:
-        device = infer_device(inputs + outputs)
-
-    tape = wp.Tape()
-    tape.record_launch(
-        kernel=kernel,
-        dim=dim,
-        inputs=inputs,
-        outputs=outputs,
-        device=device,
-        max_blocks=max_blocks,
-        block_dim=block_dim,
-    )
+    zero_grads(inputs)
+    zero_grads(outputs)
 
     jacobians = {}
 
@@ -697,19 +827,21 @@ def jacobian(
         if max_outputs_per_var > 0:
             output_num = min(output_num, max_outputs_per_var)
         for i in range(output_num):
-            tape.zero()
+            output.grad.zero_()
             if i > 0:
                 set_element(out_grad, i - 1, 0.0)
             set_element(out_grad, i, 1.0)
             tape.backward()
             jacobian[i].assign(input.grad)
-            output.grad.zero_()
+
+        zero_grads(inputs)
+        zero_grads(outputs)
         jacobians[input_i, output_i] = jacobian
 
     if plot_jacobians:
         jacobian_plot(
             jacobians,
-            kernel,
+            metadata,
             inputs,
             outputs,
         )
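
As restructured here, `jacobian` records the function (or the kernel launch) on a `wp.Tape`, then seeds one output element at a time with a unit gradient and runs `tape.backward()`, so each backward pass fills one row of the Jacobian. A sketch of the new function form (illustrative names from the earlier example; result keys are `(input_index, output_index)` pairs):

    jacs = wp.autograd.jacobian(run, inputs=[x])
    J = jacs[(0, 0)].numpy()  # dense Jacobian of output 0 with respect to input 0
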
@@ -718,53 +850,97 @@ def jacobian(
 
 
 def jacobian_fd(
-    kernel: wp.Kernel,
-    dim: Tuple[int],
-    inputs: Sequence,
+    function: Union[wp.Kernel, Callable],
+    dim: Tuple[int] = None,
+    inputs: Sequence = None,
     outputs: Sequence = None,
     input_output_mask: List[Tuple[Union[str, int], Union[str, int]]] = None,
     device: wp.context.Devicelike = None,
     max_blocks=0,
     block_dim=256,
     max_inputs_per_var=-1,
-    eps=1e-4,
+    eps: float = 1e-4,
     plot_jacobians=False,
+    metadata: FunctionMetadata = None,
+    kernel: wp.Kernel = None,
 ) -> Dict[Tuple[int, int], wp.array]:
     """
-    Computes the finite-difference Jacobian of a Warp kernel launch for the provided selection of differentiable inputs to differentiable outputs.
+    Computes the finite-difference Jacobian of a function or Warp kernel for the provided selection of differentiable inputs to differentiable outputs.
     The method uses a central difference scheme to approximate the Jacobian.
 
-    The kernel is launched multiple times in forward-only mode with the given inputs and outputs, as well as the
-    provided ``dim``, ``max_blocks``, and ``block_dim`` arguments (see :func:`warp.launch` for more details).
+    The input function can be either a Warp kernel (e.g. a function decorated by ``@wp.kernel``) or a regular Python function that accepts arguments (of which some must be Warp arrays) and returns a Warp array or a list of Warp arrays.
+
+    The function is launched multiple times in forward-only mode with the given inputs. If ``function`` is a Warp kernel, the provided inputs and outputs,
+    as well as the other parameters ``dim``, ``max_blocks``, and ``block_dim`` are provided to the kernel launch (see :func:`warp.launch`).
 
     Note:
-        This function only supports Warp kernels whose input arguments precede the output arguments.
+        If ``function`` is a Warp kernel, the input arguments must precede the output arguments in the kernel code definition.
 
         Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.
 
-        Structs arguments are not yet supported by this function to compute Jacobians.
+        Function arguments of type :ref:`Struct <structs>` are not yet supported.
 
     Args:
-        kernel: The Warp kernel function, decorated with the ``@wp.kernel`` decorator
-        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints
-        inputs: List of input variables.
-        outputs: List of output variables. If None, the outputs are inferred from the kernel argument flags.
+        function: The Warp kernel function, or a regular Python function that returns a Warp array or a list of Warp arrays.
+        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints. Only required if ``function`` is a Warp kernel.
+        inputs: List of input variables. At least one of the arguments must be a Warp array with ``requires_grad=True``.
+        outputs: List of output variables. Optional if the function is a regular Python function that returns a Warp array or a list of Warp arrays. Only required if ``function`` is a Warp kernel.
         input_output_mask: List of tuples specifying the input-output pairs to compute the Jacobian for. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
-        device: The device to launch on (optional)
-        max_blocks: The maximum number of CUDA thread blocks to use.
-        block_dim: The number of threads per block.
+        device: The device to launch on (optional). Only used if ``function`` is a Warp kernel.
+        max_blocks: The maximum number of CUDA thread blocks to use. Only used if ``function`` is a Warp kernel.
+        block_dim: The number of threads per block. Only used if ``function`` is a Warp kernel.
         max_inputs_per_var: Maximum number of input dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all input dimensions if value <= 0.
         eps: The finite-difference step size.
         plot_jacobians: If True, visualizes the computed Jacobians in a plot (requires ``matplotlib``).
+        metadata: The metadata of the kernel function, containing the input and output labels, strides, and dtypes. If None or empty, the metadata is inferred from the kernel or function.
+        kernel: Deprecated argument. Use the ``function`` argument instead.
 
     Returns:
         A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.
     """
-    if outputs is None:
-        outputs = []
     if input_output_mask is None:
         input_output_mask = []
-    arg_names = [arg.label for arg in kernel.adj.args]
+    if kernel is not None:
+        wp.utils.warn(
+            "The argument `kernel` to the function `wp.autograd.jacobian` is deprecated in favor of the `function` argument and will be removed in a future Warp version.",
+            DeprecationWarning,
+            stacklevel=3,
+        )
+        function = kernel
+
+    if metadata is None:
+        metadata = FunctionMetadata()
+
+    if isinstance(function, wp.Kernel):
+        if not function.options.get("enable_backward", True):
+            raise ValueError("Kernel must have backward pass enabled to compute Jacobians")
+        if outputs is None or len(outputs) == 0:
+            raise ValueError("A list of output arguments must be provided to compute kernel Jacobians")
+        if device is None:
+            device = infer_device(inputs + outputs)
+        if metadata.is_empty:
+            metadata.update_from_kernel(function, inputs)
+
+        tape = wp.Tape()
+        tape.record_launch(
+            kernel=function,
+            dim=dim,
+            inputs=inputs,
+            outputs=outputs,
+            device=device,
+            max_blocks=max_blocks,
+            block_dim=block_dim,
+        )
+    else:
+        tape = wp.Tape()
+        with tape:
+            outputs = function(*inputs)
+        if isinstance(outputs, wp.array):
+            outputs = [outputs]
+        if metadata.is_empty:
+            metadata.update_from_function(function, inputs, outputs)
+
+    arg_names = metadata.input_labels + metadata.output_labels
 
     def resolve_arg(name, offset: int = 0):
         if isinstance(name, int):
@@ -777,11 +953,15 @@ def jacobian_fd(
     ]
     input_output_mask = set(input_output_mask)
 
-    if device is None:
-        device = infer_device(inputs + outputs)
-
     jacobians = {}
 
+    def conditional_clone(obj):
+        if isinstance(obj, wp.array):
+            return wp.clone(obj)
+        return obj
+
+    outputs_copy = [conditional_clone(output) for output in outputs]
+
     for input_i, output_i in itertools.product(range(len(inputs)), range(len(outputs))):
         if len(input_output_mask) > 0 and (input_i, output_i) not in input_output_mask:
             continue
@@ -796,13 +976,20 @@ def jacobian_fd(
 
         left = wp.clone(output)
         right = wp.clone(output)
+        left_copy = wp.clone(output)
+        right_copy = wp.clone(output)
         flat_left = scalarize_array_1d(left)
         flat_right = scalarize_array_1d(right)
 
-        left_outputs = outputs[:output_i] + [left] + outputs[output_i + 1 :]
-        right_outputs = outputs[:output_i] + [right] + outputs[output_i + 1 :]
+        outputs_until_left = [conditional_clone(output) for output in outputs_copy[:output_i]]
+        outputs_until_right = [conditional_clone(output) for output in outputs_copy[:output_i]]
+        outputs_after_left = [conditional_clone(output) for output in outputs_copy[output_i + 1 :]]
+        outputs_after_right = [conditional_clone(output) for output in outputs_copy[output_i + 1 :]]
+        left_outputs = outputs_until_left + [left] + outputs_after_left
+        right_outputs = outputs_until_right + [right] + outputs_after_right
 
         input_num = flat_input.shape[0]
+        flat_input_copy = wp.clone(flat_input)
         jacobian = wp.empty((flat_left.size, input.size), dtype=input.dtype, device=input.device)
         jacobian.fill_(wp.nan)
 
@@ -812,38 +999,62 @@ def jacobian_fd(
             input_num = min(input_num, max_inputs_per_var)
         for i in range(input_num):
             set_element(flat_input, i, -eps, relative=True)
-            wp.launch(
-                kernel,
-                dim=dim,
-                inputs=inputs,
-                outputs=left_outputs,
-                device=device,
-                max_blocks=max_blocks,
-                block_dim=block_dim,
-            )
+            if isinstance(function, wp.Kernel):
+                wp.launch(
+                    function,
+                    dim=dim,
+                    max_blocks=max_blocks,
+                    block_dim=block_dim,
+                    inputs=inputs,
+                    outputs=left_outputs,
+                    device=device,
+                )
+            else:
+                outputs = function(*inputs)
+                if isinstance(outputs, wp.array):
+                    outputs = [outputs]
+                left.assign(outputs[output_i])
 
             set_element(flat_input, i, 2 * eps, relative=True)
-            wp.launch(
-                kernel,
-                dim=dim,
-                inputs=inputs,
-                outputs=right_outputs,
-                device=device,
-                max_blocks=max_blocks,
-                block_dim=block_dim,
+            if isinstance(function, wp.Kernel):
+                wp.launch(
+                    function,
+                    dim=dim,
+                    max_blocks=max_blocks,
+                    block_dim=block_dim,
+                    inputs=inputs,
+                    outputs=right_outputs,
+                    device=device,
+                )
+            else:
+                outputs = function(*inputs)
+                if isinstance(outputs, wp.array):
+                    outputs = [outputs]
+                right.assign(outputs[output_i])
+
+            # restore input
+            flat_input.assign(flat_input_copy)
+
+            compute_fd(
+                flat_left,
+                flat_right,
+                eps,
+                jacobian_t[i],
             )
 
-            set_element(flat_input, i, -eps, relative=True)
-
-            compute_fd(flat_left, flat_right, eps, jacobian_t[i])
+            if i < input_num - 1:
+                # reset output buffers
+                left.assign(left_copy)
+                right.assign(right_copy)
+                flat_left = scalarize_array_1d(left)
+                flat_right = scalarize_array_1d(right)
 
-        output.grad.zero_()
         jacobians[input_i, output_i] = jacobian
 
     if plot_jacobians:
         jacobian_plot(
             jacobians,
-            kernel,
+            metadata,
            inputs,
             outputs,
         )
@@ -864,7 +1075,7 @@ def set_element(a: wp.array(dtype=Any), i: int, val: Any, relative: bool = False
 
 
 @wp.kernel(enable_backward=False)
-def compute_fd_kernel(left: wp.array(dtype=Any), right: wp.array(dtype=Any), eps: Any, fd: wp.array(dtype=Any)):
+def compute_fd_kernel(left: wp.array(dtype=float), right: wp.array(dtype=float), eps: float, fd: wp.array(dtype=float)):
     tid = wp.tid()
     fd[tid] = (right[tid] - left[tid]) / (2.0 * eps)
 
@@ -883,7 +1094,10 @@ def compute_error_kernel(
     tid = wp.tid()
     ad = jacobian_ad[tid]
     fd = jacobian_fd[tid]
-    relative_error[tid] = (ad - fd) / (ad + 1e-8)
+    denom = ad
+    if abs(ad) < 1e-8:
+        denom = (type(ad))(1e-8)
+    relative_error[tid] = (ad - fd) / denom
    absolute_error[tid] = wp.abs(ad - fd)
 
 
@@ -909,3 +1123,12 @@ def print_table(headers, cells):
         for cell, col_width in zip(cell_row, col_widths):
             print(f"{cell:{col_width}}", end=" | ")
         print()
+
+
+def zero_grads(arrays: list):
+    """
+    Zeros the gradients of all Warp arrays in the given list.
+    """
+    for array in arrays:
+        if isinstance(array, wp.array) and array.requires_grad:
+            array.grad.zero_()
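
`jacobian_fd` complements the autodiff path with the central-difference estimate computed by `compute_fd_kernel`, J[:, i] ≈ (f(x + eps·e_i) − f(x − eps·e_i)) / (2·eps), and the 1.6.0 version now restores the perturbed input and the output buffers between iterations instead of relying on gradient zeroing. A quick cross-check of the two estimates, in the spirit of what `gradcheck` does internally (illustrative, continuing the example from earlier):

    import numpy as np

    jacs_ad = wp.autograd.jacobian(run, inputs=[x])
    jacs_fd = wp.autograd.jacobian_fd(run, inputs=[x], eps=1e-4)
    assert np.allclose(jacs_ad[(0, 0)].numpy(), jacs_fd[(0, 0)].numpy(), atol=1e-3, rtol=1e-2)
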