warp-lang 1.9.0__py3-none-manylinux_2_34_aarch64.whl → 1.10.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2302 -307
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/{builtins.py → _src/builtins.py} +1546 -224
  8. warp/_src/codegen.py +4361 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/{fem → _src/fem}/domain.py +42 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +32 -15
  22. warp/{fem → _src/fem}/field/restriction.py +3 -1
  23. warp/{fem → _src/fem}/field/virtual.py +55 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
  28. warp/{fem → _src/fem}/geometry/element.py +34 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +50 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
  34. warp/{fem → _src/fem}/geometry/partition.py +123 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
  38. warp/{fem → _src/fem}/integrate.py +166 -158
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
  46. warp/_src/fem/space/basis_space.py +681 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
  48. warp/{fem → _src/fem}/space/function_space.py +16 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
  53. warp/{fem → _src/fem}/space/partition.py +119 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
  55. warp/{fem → _src/fem}/space/restriction.py +68 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
  63. warp/_src/fem/space/topology.py +461 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
  65. warp/_src/fem/types.py +114 -0
  66. warp/_src/fem/utils.py +488 -0
  67. warp/_src/jax.py +188 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +389 -0
  70. warp/_src/jax_experimental/ffi.py +1286 -0
  71. warp/_src/jax_experimental/xla_ffi.py +658 -0
  72. warp/_src/marching_cubes.py +710 -0
  73. warp/_src/math.py +416 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +165 -0
  76. warp/_src/optim/linear.py +1608 -0
  77. warp/_src/optim/sgd.py +114 -0
  78. warp/_src/paddle.py +408 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +291 -0
  81. warp/_src/render/render_opengl.py +3638 -0
  82. warp/_src/render/render_usd.py +939 -0
  83. warp/_src/render/utils.py +162 -0
  84. warp/_src/sparse.py +2718 -0
  85. warp/_src/tape.py +1208 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +393 -0
  88. warp/_src/types.py +5888 -0
  89. warp/_src/utils.py +1695 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +3 -3
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +581 -280
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +18 -17
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +580 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
warp/_src/autograd.py ADDED
@@ -0,0 +1,1077 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import inspect
19
+ import itertools
20
+ from typing import Any, Callable, Sequence
21
+
22
+ import numpy as np
23
+
24
+ import warp as wp
25
+
26
+ _wp_module_name_ = "warp.autograd"
27
+
28
+ __all__ = [
29
+ "gradcheck",
30
+ "gradcheck_tape",
31
+ "jacobian",
32
+ "jacobian_fd",
33
+ "jacobian_plot",
34
+ ]
35
+
36
+
37
def gradcheck(
    function: wp.Kernel | Callable,
    dim: int | tuple[int, ...] | None = None,
    inputs: Sequence | None = None,
    outputs: Sequence | None = None,
    *,
    eps: float = 1e-4,
    atol: float = 1e-3,
    rtol: float = 1e-2,
    raise_exception: bool = True,
    input_output_mask: list[tuple[str | int, str | int]] | None = None,
    device: wp.context.Devicelike = None,
    max_blocks: int = 0,
    block_dim: int = 256,
    max_inputs_per_var: int = -1,
    max_outputs_per_var: int = -1,
    plot_relative_error: bool = False,
    plot_absolute_error: bool = False,
    show_summary: bool = True,
) -> bool:
    """
    Checks whether the autodiff gradient of a Warp kernel matches finite differences.

    Given the autodiff (:math:`\\nabla_\\text{AD}`) and finite difference gradients (:math:`\\nabla_\\text{FD}`), the check succeeds if the autodiff gradients contain no NaN values and the following condition holds:

    .. math::

        |\\nabla_\\text{AD} - \\nabla_\\text{FD}| \\leq atol + rtol \\cdot |\\nabla_\\text{FD}|.

    The kernel function and its adjoint version are launched with the given inputs and outputs, as well as the provided
    ``dim``, ``max_blocks``, and ``block_dim`` arguments (see :func:`warp.launch` for more details).

    Note:
        This function only supports Warp kernels whose input arguments precede the output arguments.

        Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.

        Struct arguments are not yet supported by this function to compute Jacobians.

    Args:
        function: The Warp kernel function, decorated with the ``@wp.kernel`` decorator, or any function that involves Warp kernel launches.
        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints. Only required if the function is a Warp kernel.
        inputs: List of input variables.
        outputs: List of output variables. Only required if the function is a Warp kernel.
        eps: The finite-difference step size.
        atol: The absolute tolerance for the gradient check.
        rtol: The relative tolerance for the gradient check.
        raise_exception: If True, raises a `ValueError` naming the first failing input/output pair if the gradient check fails.
        input_output_mask: List of tuples specifying the input-output pairs to compute the Jacobian for. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
        device: The device to launch on (optional)
        max_blocks: The maximum number of CUDA thread blocks to use.
        block_dim: The number of threads per block.
        max_inputs_per_var: Maximum number of input dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all input dimensions if value <= 0.
        max_outputs_per_var: Maximum number of output dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all output dimensions if value <= 0.
        plot_relative_error: If True, visualizes the relative error of the Jacobians in a plot (requires ``matplotlib``).
        plot_absolute_error: If True, visualizes the absolute error of the Jacobians in a plot (requires ``matplotlib``).
        show_summary: If True, prints a summary table of the gradient check results.

    Returns:
        True if the gradient check passes, False otherwise.

    Raises:
        ValueError: If ``inputs`` is None, or if ``raise_exception`` is True and an autodiff Jacobian
            does not match its finite-difference counterpart or contains NaN values.
    """

    if inputs is None:
        raise ValueError("The inputs argument must be provided")

    metadata = FunctionMetadata()

    # Autodiff Jacobians (rows limited by max_outputs_per_var) ...
    jacs_ad = jacobian(
        function,
        dim=dim,
        inputs=inputs,
        outputs=outputs,
        input_output_mask=input_output_mask,
        device=device,
        max_blocks=max_blocks,
        block_dim=block_dim,
        max_outputs_per_var=max_outputs_per_var,
        plot_jacobians=False,
        metadata=metadata,
    )
    # ... and finite-difference Jacobians (columns limited by max_inputs_per_var).
    jacs_fd = jacobian_fd(
        function,
        dim=dim,
        inputs=inputs,
        outputs=outputs,
        input_output_mask=input_output_mask,
        device=device,
        max_blocks=max_blocks,
        block_dim=block_dim,
        max_inputs_per_var=max_inputs_per_var,
        eps=eps,
        plot_jacobians=False,
        metadata=metadata,
    )

    relative_error_jacs = {}
    absolute_error_jacs = {}

    if show_summary:
        summary = []
        summary_header = ["Input", "Output", "Max Abs Error", "AD at MAE", "FD at MAE", "Max Rel Error", "Pass"]

        class FontColors:
            OKGREEN = "\033[92m"
            WARNING = "\033[93m"
            FAIL = "\033[91m"
            ENDC = "\033[0m"

    success = True
    # Record the failing (input, output) pairs so that the raised error names the pair that
    # actually failed rather than whichever pair the loop happened to process last.
    mismatch_pairs = []
    nan_pairs = []
    for (input_i, output_i), jac_fd in jacs_fd.items():
        jac_ad = jacs_ad[input_i, output_i]
        if plot_relative_error or plot_absolute_error:
            jac_rel_error = wp.empty_like(jac_fd)
            jac_abs_error = wp.empty_like(jac_fd)
            flat_jac_fd = scalarize_array_1d(jac_fd)
            flat_jac_ad = scalarize_array_1d(jac_ad)
            flat_jac_rel_error = scalarize_array_1d(jac_rel_error)
            flat_jac_abs_error = scalarize_array_1d(jac_abs_error)
            wp.launch(
                compute_error_kernel,
                dim=len(flat_jac_fd),
                inputs=[flat_jac_ad, flat_jac_fd, flat_jac_rel_error, flat_jac_abs_error],
                device=jac_fd.device,
            )
            relative_error_jacs[(input_i, output_i)] = jac_rel_error
            absolute_error_jacs[(input_i, output_i)] = jac_abs_error

        # Compare only the sub-block that both Jacobians actually evaluated.
        cut_jac_fd = jac_fd.numpy()
        cut_jac_ad = jac_ad.numpy()
        if max_outputs_per_var > 0:
            cut_jac_fd = cut_jac_fd[:max_outputs_per_var]
            cut_jac_ad = cut_jac_ad[:max_outputs_per_var]
        if max_inputs_per_var > 0:
            cut_jac_fd = cut_jac_fd[:, :max_inputs_per_var]
            cut_jac_ad = cut_jac_ad[:, :max_inputs_per_var]

        grad_matches = bool(np.allclose(cut_jac_ad, cut_jac_fd, atol=atol, rtol=rtol))
        isnan = bool(np.any(np.isnan(cut_jac_ad)))
        if not grad_matches:
            mismatch_pairs.append((input_i, output_i))
        if isnan:
            nan_pairs.append((input_i, output_i))
        success = success and grad_matches and not isnan

        if show_summary:
            abs_error = np.abs(cut_jac_ad - cut_jac_fd)
            if abs_error.size > 0:
                arg_max_abs_error = np.unravel_index(np.argmax(abs_error), abs_error.shape)
                # Divide by |FD| (plus a small epsilon) so negative FD entries cannot produce a
                # zero or sign-flipped denominator.
                max_rel_error = (abs_error / (np.abs(cut_jac_fd) + 1e-8)).max()
                max_abs_str = f"{abs_error.max():.3e} at {tuple(int(i) for i in arg_max_abs_error)}"
                ad_str = f"{cut_jac_ad[arg_max_abs_error]:.3e}"
                fd_str = f"{cut_jac_fd[arg_max_abs_error]:.3e}"
                rel_str = f"{max_rel_error:.3e}"
            else:
                # Empty Jacobian block (e.g. a zero-sized array): argmax/max are undefined.
                max_abs_str = ad_str = fd_str = rel_str = "n/a"

            if isnan:
                pass_str = FontColors.FAIL + "NaN" + FontColors.ENDC
            elif grad_matches:
                pass_str = FontColors.OKGREEN + "PASS" + FontColors.ENDC
            else:
                pass_str = FontColors.FAIL + "FAIL" + FontColors.ENDC
            input_name = metadata.input_labels[input_i]
            output_name = metadata.output_labels[output_i]
            summary.append([input_name, output_name, max_abs_str, ad_str, fd_str, rel_str, pass_str])

    if show_summary:
        print_table(summary_header, summary)
        if not success:
            print(FontColors.FAIL + f"Gradient check for kernel {metadata.key} failed" + FontColors.ENDC)
        else:
            print(FontColors.OKGREEN + f"Gradient check for kernel {metadata.key} passed" + FontColors.ENDC)
    if plot_relative_error:
        jacobian_plot(
            relative_error_jacs,
            metadata,
            inputs,
            outputs,
            title=f"{metadata.key} kernel Jacobian relative error",
        )
    if plot_absolute_error:
        jacobian_plot(
            absolute_error_jacs,
            metadata,
            inputs,
            outputs,
            title=f"{metadata.key} kernel Jacobian absolute error",
        )

    if raise_exception:
        if mismatch_pairs:
            failed_input, failed_output = mismatch_pairs[0]
            raise ValueError(
                f"Gradient check failed for kernel {metadata.key}, input {failed_input}, output {failed_output}: "
                f"finite difference and autodiff gradients do not match"
            )
        if nan_pairs:
            failed_input, failed_output = nan_pairs[0]
            raise ValueError(
                f"Gradient check failed for kernel {metadata.key}, input {failed_input}, output {failed_output}: "
                f"gradient contains NaN values"
            )

    return success
240
+
241
+
242
def gradcheck_tape(
    tape: wp.Tape,
    *,
    eps: float = 1e-4,
    atol: float = 1e-3,
    rtol: float = 1e-2,
    raise_exception: bool = True,
    input_output_masks: dict[str, list[tuple[str | int, str | int]]] | None = None,
    blacklist_kernels: list[str] | None = None,
    whitelist_kernels: list[str] | None = None,
    max_inputs_per_var: int = -1,
    max_outputs_per_var: int = -1,
    plot_relative_error: bool = False,
    plot_absolute_error: bool = False,
    show_summary: bool = True,
    reverse_launches: bool = False,
    skip_to_launch_index: int = 0,
) -> bool:
    """
    Checks whether the autodiff gradients for kernels recorded on the Warp tape match finite differences.

    Given the autodiff (:math:`\\nabla_\\text{AD}`) and finite difference gradients (:math:`\\nabla_\\text{FD}`), the check succeeds if the autodiff gradients contain no NaN values and the following condition holds:

    .. math::

        |\\nabla_\\text{AD} - \\nabla_\\text{FD}| \\leq atol + rtol \\cdot |\\nabla_\\text{FD}|.

    Note:
        Only Warp kernels recorded on the tape are checked but not arbitrary functions that have been recorded, e.g. via :meth:`Tape.record_func`.

        Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.

        Struct arguments are not yet supported by this function to compute Jacobians.

        Kernels compiled with ``enable_backward=False`` are skipped.

    Args:
        tape: The Warp tape to perform the gradient check on.
        eps: The finite-difference step size.
        atol: The absolute tolerance for the gradient check.
        rtol: The relative tolerance for the gradient check.
        raise_exception: If True, raises a `ValueError` if the gradient check fails. The first failing kernel raises, so later kernels are not checked.
        input_output_masks: Dictionary of input-output masks for each kernel in the tape, mapping from kernel keys to input-output masks. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
        blacklist_kernels: List of kernel keys to exclude from the gradient check.
        whitelist_kernels: List of kernel keys to include in the gradient check. If not empty or None, only kernels in this list are checked.
        max_inputs_per_var: Maximum number of input dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all input dimensions if value <= 0.
        max_outputs_per_var: Maximum number of output dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all output dimensions if value <= 0.
        plot_relative_error: If True, visualizes the relative error of the Jacobians in a plot (requires ``matplotlib``).
        plot_absolute_error: If True, visualizes the absolute error of the Jacobians in a plot (requires ``matplotlib``).
        show_summary: If True, prints a summary table of the gradient check results.
        reverse_launches: If True, reverses the order of the kernel launches on the tape to check.
        skip_to_launch_index: Entries of ``tape.launches`` whose position is below this index are skipped.
            Positions are counted over every recorded entry (not only kernel launches), in iteration
            order, i.e. after reversal when ``reverse_launches`` is True.

    Returns:
        True if the gradient check passes for all kernels on the tape, False otherwise.
    """
    if input_output_masks is None:
        input_output_masks = {}
    # Sets give O(1) membership tests; an empty whitelist means "check every kernel".
    blacklist = set(blacklist_kernels) if blacklist_kernels else set()
    whitelist = set(whitelist_kernels) if whitelist_kernels else set()

    overall_success = True
    launches = reversed(tape.launches) if reverse_launches else tape.launches
    for i, launch in enumerate(launches):
        if i < skip_to_launch_index:
            continue
        # Only tuple/list records whose first item is a wp.Kernel are kernel launches.
        if not isinstance(launch, (tuple, list)):
            continue
        if not isinstance(launch[0], wp.Kernel):
            continue
        kernel, dim, max_blocks, inputs, outputs, device, block_dim = launch[:7]
        if whitelist and kernel.key not in whitelist:
            continue
        if kernel.key in blacklist:
            continue
        # Kernels without an adjoint cannot be gradient-checked.
        if not kernel.options.get("enable_backward", True):
            continue

        input_output_mask = input_output_masks.get(kernel.key)
        success = gradcheck(
            kernel,
            dim,
            inputs,
            outputs,
            eps=eps,
            atol=atol,
            rtol=rtol,
            raise_exception=raise_exception,
            input_output_mask=input_output_mask,
            device=device,
            max_blocks=max_blocks,
            block_dim=block_dim,
            max_inputs_per_var=max_inputs_per_var,
            max_outputs_per_var=max_outputs_per_var,
            plot_relative_error=plot_relative_error,
            plot_absolute_error=plot_absolute_error,
            show_summary=show_summary,
        )
        overall_success = overall_success and success

    return overall_success
346
+
347
+
348
def get_struct_vars(x: wp._src.codegen.StructInstance):
    """Map each field name of a Warp struct instance to the value currently stored on it.

    Field names come from the struct's ctypes layout (``x._cls.ctype._fields_``); the
    field type entries are ignored.
    """
    field_names = (field[0] for field in x._cls.ctype._fields_)
    return {name: getattr(x, name) for name in field_names}
350
+
351
+
352
def infer_device(xs: list):
    """Pick the Warp device to use for a list of variables.

    Returns the device of the first Warp array in ``xs``. A struct instance contributes the
    device of its first array-valued field, if any. Falls back to ``wp.get_preferred_device()``
    when no array is found.
    """
    for item in xs:
        if isinstance(item, wp.array):
            return item.device
        if not isinstance(item, wp._src.codegen.StructInstance):
            continue
        # Scan the struct's fields, in the order get_struct_vars reports them, for an array.
        field_arrays = (value for value in get_struct_vars(item).values() if isinstance(value, wp.array))
        first_array = next(field_arrays, None)
        if first_array is not None:
            return first_array.device
    return wp.get_preferred_device()
362
+
363
+
364
class FunctionMetadata:
    """
    Metadata holder for kernel functions or functions with Warp arrays as inputs/outputs.

    Stores a key (the kernel key or function name) and per-argument labels. For every input/output,
    it also stores the array strides and dtype when the argument is a Warp array, or ``None`` otherwise.
    An instance whose ``key`` is ``None`` is considered empty (see :attr:`is_empty`).
    """

    def __init__(
        self,
        key: str | None = None,
        input_labels: list[str] | None = None,
        output_labels: list[str] | None = None,
        input_strides: list[tuple] | None = None,
        output_strides: list[tuple] | None = None,
        input_dtypes: list | None = None,
        output_dtypes: list | None = None,
    ):
        self.key = key
        self.input_labels = input_labels
        self.output_labels = output_labels
        self.input_strides = input_strides
        self.output_strides = output_strides
        self.input_dtypes = input_dtypes
        self.output_dtypes = output_dtypes

    @property
    def is_empty(self) -> bool:
        """Whether no kernel/function has been associated with this metadata yet (``key`` is ``None``)."""
        return self.key is None

    def input_is_array(self, i: int) -> bool:
        """Whether the ``i``-th input was recorded as a Warp array."""
        return self.input_strides[i] is not None

    def output_is_array(self, i: int) -> bool:
        """Whether the ``i``-th output was recorded as a Warp array."""
        return self.output_strides[i] is not None

    @staticmethod
    def _array_layouts(items: Sequence) -> tuple[list, list]:
        """Collect the ``(strides, dtypes)`` lists of ``items``, using ``None`` for entries that are not Warp arrays."""
        strides = []
        dtypes = []
        for item in items:
            if isinstance(item, wp.array):
                strides.append(item.strides)
                dtypes.append(item.dtype)
            else:
                strides.append(None)
                dtypes.append(None)
        return strides, dtypes

    def update_from_kernel(self, kernel: wp.Kernel, inputs: Sequence):
        """
        Populates the metadata from the argument list of a Warp kernel.

        The first ``len(inputs)`` kernel arguments are treated as inputs and the remaining ones as outputs.

        Args:
            kernel: The Warp kernel.
            inputs: The input values the kernel is launched with (only their count is used).
        """
        num_inputs = len(inputs)
        input_args = kernel.adj.args[:num_inputs]
        output_args = kernel.adj.args[num_inputs:]

        self.key = kernel.key
        self.input_labels = [arg.label for arg in input_args]
        self.output_labels = [arg.label for arg in output_args]
        # Array arguments are annotated with ``wp.array`` *instances* (e.g. ``wp.array(dtype=float)``),
        # so an identity check against the ``wp.array`` class does not match them; use isinstance.
        self.input_strides, self.input_dtypes = self._array_layouts([arg.type for arg in input_args])
        self.output_strides, self.output_dtypes = self._array_layouts([arg.type for arg in output_args])

    def update_from_function(self, function: Callable, inputs: Sequence, outputs: Sequence | None = None):
        """
        Populates the metadata from a regular Python function.

        Input labels are the parameter names of ``function``. Outputs are labeled ``output_0``, ``output_1``, ...

        Args:
            function: The Python function.
            inputs: The arguments the function is called with.
            outputs: The function's outputs. If None, ``function(*inputs)`` is evaluated once to obtain them.
        """
        self.key = function.__name__
        self.input_labels = list(inspect.signature(function).parameters.keys())
        if outputs is None:
            outputs = function(*inputs)
        if isinstance(outputs, wp.array):
            outputs = [outputs]
        self.output_labels = [f"output_{i}" for i in range(len(outputs))]
        self.input_strides, self.input_dtypes = self._array_layouts(inputs)
        self.output_strides, self.output_dtypes = self._array_layouts(outputs)
446
+
447
+
448
def jacobian_plot(
    jacobians: dict[tuple[int, int], wp.array],
    kernel: FunctionMetadata | wp.Kernel,
    inputs: Sequence | None = None,
    show_plot: bool = True,
    show_colorbar: bool = True,
    scale_colors_per_submatrix: bool = False,
    title: str | None = None,
    colormap: str = "coolwarm",
    log_scale: bool = False,
):
    """
    Visualizes the Jacobians computed by :func:`jacobian` or :func:`jacobian_fd` in a combined image plot.
    Requires the ``matplotlib`` package to be installed.

    Each Jacobian is drawn as one submatrix of a grid. Grid columns correspond to the inputs and grid rows
    to the outputs that appear in ``jacobians``. Entries of vector/matrix dtype are expanded into their
    scalar components along the column axis.

    Args:
        jacobians: A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.
        kernel: The Warp kernel function, decorated with the ``@wp.kernel`` decorator, or a :class:`FunctionMetadata` instance with the kernel/function attributes.
        inputs: List of input variables. Required if ``kernel`` is a Warp kernel, to split its arguments into inputs and outputs.
        show_plot: If True, displays the plot via ``plt.show()``.
        show_colorbar: If True, displays a colorbar next to the plot (or a colorbar next to every submatrix if ``scale_colors_per_submatrix`` is True).
        scale_colors_per_submatrix: If True, considers the minimum and maximum of each Jacobian submatrix separately for color scaling. Otherwise, uses the global minimum and maximum of all Jacobians.
        title: The title of the plot (optional). Defaults to ``"<key> kernel Jacobian"``.
        colormap: The colormap to use for the plot.
        log_scale: If True, shows ``log10(|J| + 1e-8)`` instead of the raw values. The color range is computed on these transformed values.

    Returns:
        The created Matplotlib figure, or ``None`` if ``jacobians`` is empty.

    Raises:
        ValueError: If ``kernel`` is neither a Warp kernel nor a :class:`FunctionMetadata` object, or if
            ``kernel`` is a Warp kernel and ``inputs`` is not provided.
    """

    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator

    if isinstance(kernel, wp.Kernel):
        if inputs is None:
            raise ValueError("The kernel inputs must be provided to plot the Jacobians of a Warp kernel")
        metadata = FunctionMetadata()
        metadata.update_from_kernel(kernel, inputs)
    elif isinstance(kernel, FunctionMetadata):
        metadata = kernel
    else:
        raise ValueError("Invalid kernel argument: must be a Warp kernel or a FunctionMetadata object")

    def to_display(values):
        # values as they are drawn in the image (and used for color scaling)
        return np.log10(np.abs(values) + 1e-8) if log_scale else values

    def value_range(values):
        # (min, max) ignoring NaNs (= unevaluated entries); (0, 0) if nothing was evaluated
        finite = values[~np.isnan(values)]
        if finite.size == 0:
            return 0.0, 0.0
        return finite.min(), finite.max()

    # Order submatrices by output first, then input. Copy each Jacobian to the host once and expand
    # vector/matrix entries into scalar columns (the rows already enumerate scalar outputs).
    matrices = {}
    for key, jac_wp in sorted(jacobians.items(), key=lambda item: (item[0][1], item[0][0])):
        jac = jac_wp.numpy()
        num_scalar_cols = int(np.prod(jac.shape[1:], dtype=np.int64))
        matrices[key] = to_display(jac.reshape(jac.shape[0], num_scalar_cols))

    if not matrices:
        return None

    # assign grid columns to inputs and grid rows to outputs in order of first appearance
    input_to_ax = {}
    output_to_ax = {}
    for input_i, output_i in matrices:
        input_to_ax.setdefault(input_i, len(input_to_ax))
        output_to_ax.setdefault(output_i, len(output_to_ax))
    num_rows = len(output_to_ax)
    num_cols = len(input_to_ax)

    # Size every grid cell after the submatrix it shows. Ratios are indexed by axis position so they
    # line up with the columns/rows assigned above.
    width_ratios = [1] * num_cols
    height_ratios = [1] * num_rows
    has_plot = np.zeros((num_rows, num_cols), dtype=bool)
    for (input_i, output_i), jac in matrices.items():
        ax_i, ax_j = output_to_ax[output_i], input_to_ax[input_i]
        width_ratios[ax_j] = max(jac.shape[1], 1)
        height_ratios[ax_i] = max(jac.shape[0], 1)
        has_plot[ax_i, ax_j] = True

    fig, axs = plt.subplots(
        ncols=num_cols,
        nrows=num_rows,
        figsize=(7, 7),
        sharex="col",
        sharey="row",
        gridspec_kw={
            "wspace": 0.1,
            "hspace": 0.1,
            "width_ratios": width_ratios,
            "height_ratios": height_ratios,
        },
        subplot_kw={"aspect": 1},
        squeeze=False,
    )
    if title is None:
        key = metadata.key if metadata.key is not None else "unknown"
        title = f"{key} kernel Jacobian"
    fig.suptitle(title)
    if fig.canvas.manager is not None:
        fig.canvas.manager.set_window_title(title)

    if not scale_colors_per_submatrix:
        vmin, vmax = value_range(np.concatenate([jac.ravel() for jac in matrices.values()]))

    # hide grid cells without a Jacobian
    for ax_i, ax_j in zip(*np.nonzero(~has_plot)):
        axs[ax_i, ax_j].axis("off")

    for (input_i, output_i), jac in matrices.items():
        ax_i, ax_j = output_to_ax[output_i], input_to_ax[input_i]
        ax = axs[ax_i, ax_j]
        ax.tick_params(which="major", width=1, length=7)
        ax.tick_params(which="minor", width=1, length=4, color="gray")
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))

        if scale_colors_per_submatrix:
            vmin, vmax = value_range(jac)
        img = ax.imshow(
            jac,
            cmap=colormap,
            aspect="auto",
            interpolation="nearest",
            extent=[0, jac.shape[1], 0, jac.shape[0]],
            vmin=vmin,
            vmax=vmax,
        )
        if ax_i == num_rows - 1 or not has_plot[ax_i + 1 :, ax_j].any():
            # last plot of this column
            ax.set_xlabel(metadata.input_labels[input_i])
        if ax_j == 0 or not has_plot[ax_i, :ax_j].any():
            # first plot of this row
            ax.set_ylabel(metadata.output_labels[output_i])
        ax.grid(color="gray", which="minor", linestyle="--", linewidth=0.5)
        ax.grid(color="black", which="major", linewidth=1.0)

        if show_colorbar and scale_colors_per_submatrix:
            plt.colorbar(img, ax=ax, orientation="vertical", pad=0.02)

    if show_colorbar and not scale_colors_per_submatrix:
        m = plt.cm.ScalarMappable(cmap=colormap)
        m.set_array([vmin, vmax])
        m.set_clim(vmin, vmax)
        plt.colorbar(m, ax=axs, orientation="vertical", pad=0.02)

    plt.tight_layout()
    if show_plot:
        plt.show()
    return fig
632
+
633
+
634
def scalarize_array_1d(arr):
    """
    Converts ``arr`` to a 1D array whose dtype is the underlying scalar type.

    Scalar arrays are flattened. Arrays of the types in ``wp._src.types.vector_types`` are reinterpreted
    through ``arr.ptr`` so that every component becomes one scalar element. Callers in this module write
    through the result to modify ``arr`` (e.g. to perturb inputs), so the result must alias ``arr``.

    Raises:
        RuntimeError: If ``arr`` has a vector-like dtype and is not contiguous. The pointer reinterpretation
            assumes densely packed elements and would otherwise alias the wrong memory.
        ValueError: If the dtype of ``arr`` is not supported.
    """
    if arr.dtype in wp._src.types.scalar_types:
        return arr.flatten()
    if arr.dtype in wp._src.types.vector_types:
        if not arr.is_contiguous:
            raise RuntimeError("Scalarizing non-contiguous arrays is unsupported.")
        return wp.array(
            ptr=arr.ptr,
            shape=(arr.size * arr.dtype._length_,),
            dtype=arr.dtype._wp_scalar_type_,
            device=arr.device,
        )
    raise ValueError(
        f"Unsupported array dtype {arr.dtype}: array to be flattened must be a scalar/vector/matrix array"
    )
649
+
650
+
651
def scalarize_array_2d(arr):
    """
    Converts the 2D array ``arr`` to a 2D array whose dtype is the underlying scalar type.

    Scalar arrays are returned unchanged. For arrays of the types in ``wp._src.types.vector_types``,
    each row's components are laid out along the second axis, giving shape
    ``(rows, cols * components)``. The result aliases ``arr`` through its pointer.

    Raises:
        RuntimeError: If ``arr`` has a vector-like dtype and is not contiguous. The pointer reinterpretation
            assumes densely packed elements.
        ValueError: If the dtype of ``arr`` is not supported.
    """
    assert arr.ndim == 2
    if arr.dtype in wp._src.types.scalar_types:
        return arr
    if arr.dtype in wp._src.types.vector_types:
        if not arr.is_contiguous:
            raise RuntimeError("Scalarizing non-contiguous arrays is unsupported.")
        return wp.array(
            ptr=arr.ptr,
            shape=(arr.shape[0], arr.shape[1] * arr.dtype._length_),
            dtype=arr.dtype._wp_scalar_type_,
            device=arr.device,
        )
    raise ValueError(
        f"Unsupported array dtype {arr.dtype}: array to be flattened must be a scalar/vector/matrix array"
    )
667
+
668
+
669
def jacobian(
    function: wp.Kernel | Callable,
    dim: tuple[int] | None = None,
    inputs: Sequence | None = None,
    outputs: Sequence | None = None,
    input_output_mask: list[tuple[str | int, str | int]] | None = None,
    device: wp.context.Devicelike = None,
    max_blocks=0,
    block_dim=256,
    max_outputs_per_var=-1,
    plot_jacobians=False,
    metadata: FunctionMetadata | None = None,
) -> dict[tuple[int, int], wp.array]:
    """
    Computes the Jacobians of a function or Warp kernel for the provided selection of differentiable inputs to differentiable outputs.

    The input function can be either a Warp kernel (e.g. a function decorated by ``@wp.kernel``) or a regular Python function that accepts arguments (of which some must be Warp arrays) and returns a Warp array or a list of Warp arrays.

    In case ``function`` is a Warp kernel, its adjoint kernel is launched with the given inputs and outputs, as well as the provided ``dim``,
    ``max_blocks``, and ``block_dim`` arguments (see :func:`warp.launch` for more details).

    Note:
        If ``function`` is a Warp kernel, the input arguments must precede the output arguments in the kernel code definition.

        Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.

        Function arguments of type :ref:`Struct <structs>` are not yet supported.

        The gradients of the inputs and outputs (and of any array recorded on the internal tape) are reset to zero.

    Args:
        function: The Warp kernel function, or a regular Python function that returns a Warp array or a list of Warp arrays.
        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints. Only required if ``function`` is a Warp kernel.
        inputs: List of input variables. At least one of the arguments must be a Warp array with ``requires_grad=True``.
        outputs: List of output variables. Optional if the function is a regular Python function that returns a Warp array or a list of Warp arrays. Only required if ``function`` is a Warp kernel.
        input_output_mask: List of tuples specifying the input-output pairs to compute the Jacobian for. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
        device: The device to launch on (optional). Only used if ``function`` is a Warp kernel.
        max_blocks: The maximum number of CUDA thread blocks to use. Only used if ``function`` is a Warp kernel.
        block_dim: The number of threads per block. Only used if ``function`` is a Warp kernel.
        max_outputs_per_var: Maximum number of output dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all output dimensions if value <= 0.
        plot_jacobians: If True, visualizes the computed Jacobians in a plot (requires ``matplotlib``).
        metadata: The metadata of the kernel function, containing the input and output labels, strides, and dtypes. If None or empty, the metadata is inferred from the kernel or function.

    Returns:
        A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.

    Raises:
        ValueError: If ``function`` is a Warp kernel without a backward pass or without outputs, or if
            ``input_output_mask`` contains an unknown argument name.
    """
    if inputs is None:
        inputs = []
    if input_output_mask is None:
        input_output_mask = []
    if metadata is None:
        metadata = FunctionMetadata()

    tape = wp.Tape()
    if isinstance(function, wp.Kernel):
        if not function.options.get("enable_backward", True):
            raise ValueError("Kernel must have backward pass enabled to compute Jacobians")
        if outputs is None or len(outputs) == 0:
            raise ValueError("A list of output arguments must be provided to compute kernel Jacobians")
        if device is None:
            # unpack rather than ``inputs + outputs`` so that tuples and lists can be mixed
            device = infer_device([*inputs, *outputs])
        if metadata.is_empty:
            metadata.update_from_kernel(function, inputs)

        # the launch is only recorded (not executed) so that tape.backward() can run its adjoint
        tape.record_launch(
            kernel=function,
            dim=dim,
            inputs=inputs,
            outputs=outputs,
            device=device,
            max_blocks=max_blocks,
            block_dim=block_dim,
        )
    else:
        with tape:
            outputs = function(*inputs)
        if isinstance(outputs, wp.array):
            outputs = [outputs]
        if metadata.is_empty:
            metadata.update_from_function(function, inputs, outputs)

    def resolve_arg(name, labels):
        # integers are positional indices, strings are looked up among the given labels
        if isinstance(name, int):
            return name
        if name not in labels:
            raise ValueError(f"Unknown argument name '{name}' in input_output_mask, expected one of {labels}")
        return labels.index(name)

    # Inputs and outputs are resolved against their own label lists. A Python function's input labels
    # cover all of its parameters, which may outnumber the provided inputs, so an offset of
    # len(inputs) into a combined label list would be wrong.
    mask = {
        (resolve_arg(input_name, metadata.input_labels), resolve_arg(output_name, metadata.output_labels))
        for input_name, output_name in input_output_mask
    }

    zero_grads(inputs)
    zero_grads(outputs)

    jacobians = {}

    for input_i, output_i in itertools.product(range(len(inputs)), range(len(outputs))):
        if mask and (input_i, output_i) not in mask:
            continue
        in_arr = inputs[input_i]
        out_arr = outputs[output_i]
        if not isinstance(in_arr, wp.array) or not in_arr.requires_grad:
            continue
        if not isinstance(out_arr, wp.array) or not out_arr.requires_grad:
            continue

        # Scalar view of the output adjoint. Seeding element ``row`` with 1 makes the backward pass
        # produce Jacobian row ``row`` in ``in_arr.grad``.
        out_grad = scalarize_array_1d(out_arr.grad)
        num_rows = out_grad.shape[0]
        jac = wp.empty((num_rows, in_arr.size), dtype=in_arr.dtype, device=in_arr.device)
        # rows that are not evaluated (see max_outputs_per_var) remain NaN
        jac.fill_(wp.nan)
        if max_outputs_per_var > 0:
            num_rows = min(num_rows, max_outputs_per_var)
        for row in range(num_rows):
            # Backward passes accumulate adjoints. Reset everything recorded on the tape, plus the
            # input/output gradients, so that in_arr.grad holds exactly this row.
            tape.zero()
            zero_grads(inputs)
            zero_grads(outputs)
            set_element(out_grad, row, 1.0)
            tape.backward()
            jac[row].assign(in_arr.grad)

        tape.zero()
        zero_grads(inputs)
        zero_grads(outputs)
        jacobians[input_i, output_i] = jac

    if plot_jacobians:
        jacobian_plot(jacobians, metadata, inputs)

    return jacobians
802
+
803
+
804
def jacobian_fd(
    function: wp.Kernel | Callable,
    dim: tuple[int] | None = None,
    inputs: Sequence | None = None,
    outputs: Sequence | None = None,
    input_output_mask: list[tuple[str | int, str | int]] | None = None,
    device: wp.context.Devicelike = None,
    max_blocks=0,
    block_dim=256,
    max_inputs_per_var=-1,
    eps: float = 1e-4,
    plot_jacobians=False,
    metadata: FunctionMetadata | None = None,
) -> dict[tuple[int, int], wp.array]:
    """
    Computes the finite-difference Jacobian of a function or Warp kernel for the provided selection of differentiable inputs to differentiable outputs.
    The method uses a central difference scheme to approximate the Jacobian.

    The input function can be either a Warp kernel (e.g. a function decorated by ``@wp.kernel``) or a regular Python function that accepts arguments (of which some must be Warp arrays) and returns a Warp array or a list of Warp arrays.

    The function is launched multiple times in forward-only mode with the given inputs. If ``function`` is a Warp kernel, the provided inputs and outputs,
    as well as the other parameters ``dim``, ``max_blocks``, and ``block_dim`` are provided to the kernel launch (see :func:`warp.launch`).

    Note:
        If ``function`` is a Warp kernel, the input arguments must precede the output arguments in the kernel code definition.

        Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.

        Function arguments of type :ref:`Struct <structs>` are not yet supported.

        Kernels compiled without a backward pass are supported, since only forward launches are performed.
        The caller's output arrays are not modified; the inputs are restored after every perturbation.

    Args:
        function: The Warp kernel function, or a regular Python function that returns a Warp array or a list of Warp arrays.
        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints. Only required if ``function`` is a Warp kernel.
        inputs: List of input variables. At least one of the arguments must be a Warp array with ``requires_grad=True``.
        outputs: List of output variables. Optional if the function is a regular Python function that returns a Warp array or a list of Warp arrays. Only required if ``function`` is a Warp kernel.
        input_output_mask: List of tuples specifying the input-output pairs to compute the Jacobian for. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
        device: The device to launch on (optional). Only used if ``function`` is a Warp kernel.
        max_blocks: The maximum number of CUDA thread blocks to use. Only used if ``function`` is a Warp kernel.
        block_dim: The number of threads per block. Only used if ``function`` is a Warp kernel.
        max_inputs_per_var: Maximum number of input dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all input dimensions if value <= 0.
        eps: The finite-difference step size.
        plot_jacobians: If True, visualizes the computed Jacobians in a plot (requires ``matplotlib``).
        metadata: The metadata of the kernel function, containing the input and output labels, strides, and dtypes. If None or empty, the metadata is inferred from the kernel or function.

    Returns:
        A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.

    Raises:
        ValueError: If ``function`` is a Warp kernel without outputs, or if ``input_output_mask`` contains
            an unknown argument name.
    """
    if inputs is None:
        inputs = []
    if input_output_mask is None:
        input_output_mask = []
    if metadata is None:
        metadata = FunctionMetadata()

    is_kernel = isinstance(function, wp.Kernel)
    if is_kernel:
        if outputs is None or len(outputs) == 0:
            raise ValueError("A list of output arguments must be provided to compute kernel Jacobians")
        if device is None:
            # unpack rather than ``inputs + outputs`` so that tuples and lists can be mixed
            device = infer_device([*inputs, *outputs])
        if metadata.is_empty:
            metadata.update_from_kernel(function, inputs)
    else:
        outputs = function(*inputs)
        if isinstance(outputs, wp.array):
            outputs = [outputs]
        if metadata.is_empty:
            metadata.update_from_function(function, inputs, outputs)

    def resolve_arg(name, labels):
        # integers are positional indices, strings are looked up among the given labels
        if isinstance(name, int):
            return name
        if name not in labels:
            raise ValueError(f"Unknown argument name '{name}' in input_output_mask, expected one of {labels}")
        return labels.index(name)

    # inputs and outputs are resolved against their own label lists (see jacobian())
    mask = {
        (resolve_arg(input_name, metadata.input_labels), resolve_arg(output_name, metadata.output_labels))
        for input_name, output_name in input_output_mask
    }

    def clone_if_array(obj):
        return wp.clone(obj) if isinstance(obj, wp.array) else obj

    # private copies of all outputs so that the perturbed launches never overwrite the caller's arrays
    outputs_copy = [clone_if_array(output) for output in outputs]

    def launch_outputs_with(target, output_i):
        # Output arguments for one side of the central difference: ``target`` at position ``output_i``,
        # fresh clones everywhere else.
        return [
            *(clone_if_array(other) for other in outputs_copy[:output_i]),
            target,
            *(clone_if_array(other) for other in outputs_copy[output_i + 1 :]),
        ]

    def evaluate(target, launch_outputs, output_i):
        # Runs ``function`` on the current (perturbed) inputs so that ``target`` holds output ``output_i``.
        if is_kernel:
            # the kernel writes directly into ``launch_outputs``, which contains ``target``
            wp.launch(
                function,
                dim=dim,
                max_blocks=max_blocks,
                block_dim=block_dim,
                inputs=inputs,
                outputs=launch_outputs,
                device=device,
            )
        else:
            results = function(*inputs)
            if isinstance(results, wp.array):
                results = [results]
            target.assign(results[output_i])

    jacobians = {}

    for input_i, output_i in itertools.product(range(len(inputs)), range(len(outputs))):
        if mask and (input_i, output_i) not in mask:
            continue
        in_arr = inputs[input_i]
        out_arr = outputs[output_i]
        if not isinstance(in_arr, wp.array) or not in_arr.requires_grad:
            continue
        if not isinstance(out_arr, wp.array) or not out_arr.requires_grad:
            continue

        # scalar view aliasing in_arr: perturbing it perturbs the function input
        flat_input = scalarize_array_1d(in_arr)
        flat_input_initial = wp.clone(flat_input)

        left = wp.clone(out_arr)
        right = wp.clone(out_arr)
        # Initial output values. The buffers are reset to them between evaluations so that every launch
        # starts from the same state (relevant for kernels that accumulate into their outputs).
        output_initial = wp.clone(out_arr)
        flat_left = scalarize_array_1d(left)
        flat_right = scalarize_array_1d(right)
        left_outputs = launch_outputs_with(left, output_i) if is_kernel else None
        right_outputs = launch_outputs_with(right, output_i) if is_kernel else None

        jac = wp.empty((flat_left.size, in_arr.size), dtype=in_arr.dtype, device=in_arr.device)
        # columns that are not evaluated (see max_inputs_per_var) remain NaN
        jac.fill_(wp.nan)
        # row ``col`` of the transposed scalar view is column ``col`` of the Jacobian
        jac_columns = scalarize_array_2d(jac).transpose()

        num_cols = flat_input.shape[0]
        if max_inputs_per_var > 0:
            num_cols = min(num_cols, max_inputs_per_var)
        for col in range(num_cols):
            set_element(flat_input, col, -eps, relative=True)
            evaluate(left, left_outputs, output_i)

            set_element(flat_input, col, 2 * eps, relative=True)
            evaluate(right, right_outputs, output_i)

            # restore input
            flat_input.assign(flat_input_initial)

            compute_fd(flat_left, flat_right, eps, jac_columns[col])

            if col < num_cols - 1:
                # reset output buffers (flat_left/flat_right alias them and stay valid)
                left.assign(output_initial)
                right.assign(output_initial)

        jacobians[input_i, output_i] = jac

    if plot_jacobians:
        jacobian_plot(jacobians, metadata, inputs)

    return jacobians
1006
+
1007
+
1008
@wp.kernel(enable_backward=False)
def set_element_kernel(a: wp.array(dtype=Any), i: int, val: Any, relative: bool):
    # Writes a single entry of ``a`` (set_element launches this kernel with dim=1):
    # adds ``val`` to ``a[i]`` if ``relative`` is True, otherwise overwrites ``a[i]`` with ``val``.
    # Compiled without a backward pass (enable_backward=False).
    if relative:
        a[i] += val
    else:
        a[i] = val
1014
+
1015
+
1016
def set_element(a: wp.array(dtype=Any), i: int, val: Any, relative: bool = False):
    """Set (or, with ``relative=True``, increment) element ``i`` of the Warp array ``a`` to ``val``.

    ``val`` is first converted to ``a.dtype``. A single-thread kernel is then launched on ``a``'s device.
    """
    typed_val = a.dtype(val)
    wp.launch(
        set_element_kernel,
        dim=1,
        inputs=[a, i, typed_val, relative],
        device=a.device,
    )
1018
+
1019
+
1020
@wp.kernel(enable_backward=False)
def compute_fd_kernel(left: wp.array(dtype=float), right: wp.array(dtype=float), eps: float, fd: wp.array(dtype=float)):
    # Central finite difference, one thread per element:
    #     fd[i] = (right[i] - left[i]) / (2 * eps)
    # where left/right are the outputs evaluated at input - eps and input + eps.
    # NOTE(review): arrays are annotated with dtype=float, so non-float32 scalar types presumably do
    # not match this kernel's signature -- confirm before using with float64 data.
    tid = wp.tid()
    fd[tid] = (right[tid] - left[tid]) / (2.0 * eps)
1024
+
1025
+
1026
def compute_fd(left: wp.array(dtype=Any), right: wp.array(dtype=Any), eps: float, fd: wp.array(dtype=Any)):
    """Write the central difference ``(right - left) / (2 * eps)`` element-wise into ``fd``.

    Launches one thread per entry of ``left`` on ``left``'s device.
    """
    num_threads = len(left)
    wp.launch(
        compute_fd_kernel,
        dim=num_threads,
        inputs=[left, right, eps],
        outputs=[fd],
        device=left.device,
    )
1028
+
1029
+
1030
@wp.kernel(enable_backward=False)
def compute_error_kernel(
    jacobian_ad: wp.array(dtype=Any),
    jacobian_fd: wp.array(dtype=Any),
    relative_error: wp.array(dtype=Any),
    absolute_error: wp.array(dtype=Any),
):
    # Element-wise comparison of an autodiff Jacobian (jacobian_ad) against its finite-difference
    # estimate (jacobian_fd), one thread per element. The arrays are presumably flattened scalar views;
    # confirm at the launch site.
    tid = wp.tid()
    ad = jacobian_ad[tid]
    fd = jacobian_fd[tid]
    # Avoid dividing by (near) zero: entries with |ad| < 1e-8 use +1e-8 as the denominator.
    # This also discards the sign of ad for such entries.
    denom = ad
    if abs(ad) < 1e-8:
        denom = (type(ad))(1e-8)
    # signed relative error, normalized by the autodiff value itself (not its magnitude)
    relative_error[tid] = (ad - fd) / denom
    absolute_error[tid] = wp.abs(ad - fd)
1045
+
1046
+
1047
def print_table(headers, cells):
    """
    Prints a table with the given headers and cells.

    Columns are separated by `` | `` and sized to their widest visible entry. ANSI SGR escape sequences
    (color/style codes such as ``ESC[31m`` or ``ESC[1;31m``) are ignored when measuring and padding.
    This keeps colored cells aligned with uncolored ones.

    Args:
        headers: List of header strings.
        cells: List of lists of cell strings. Non-string cells are formatted with their default format
            spec padded to the column width (so numbers are right-aligned).
    """
    import re

    ansi_escape = re.compile(r"\033\[[\d;]*m")

    def visible_len(value):
        return len(ansi_escape.sub("", str(value)))

    def pad(value, width):
        if isinstance(value, str):
            # pad by the visible width: a format spec would also count the invisible escape characters
            return value + " " * max(width - visible_len(value), 0)
        return f"{value:{width}}"

    col_widths = [max(visible_len(cell) for cell in col) for col in zip(headers, *cells)]
    for header, col_width in zip(headers, col_widths):
        print(pad(header, col_width), end=" | ")
    print()
    print("-" * (sum(col_widths) + 3 * len(col_widths) - 1))
    for cell_row in cells:
        for cell, col_width in zip(cell_row, col_widths):
            print(pad(cell, col_width), end=" | ")
        print()
1069
+
1070
+
1071
def zero_grads(arrays: list):
    """
    Resets the ``.grad`` buffer of every gradient-tracking Warp array in ``arrays`` to zero.

    Entries that are not Warp arrays, or arrays with ``requires_grad=False``, are left untouched.
    """
    tracked = (item for item in arrays if isinstance(item, wp.array) and item.requires_grad)
    for item in tracked:
        item.grad.zero_()