warp-lang 1.9.0__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/_src/autograd.py ADDED
@@ -0,0 +1,1075 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import inspect
19
+ import itertools
20
+ from typing import Any, Callable, Sequence
21
+
22
+ import numpy as np
23
+
24
+ import warp as wp
25
+
26
# Public API of this module: the names exported by ``from <module> import *``.
# Other module-level helpers defined in this file (e.g. ``infer_device``,
# ``get_struct_vars``, ``FunctionMetadata``) are intentionally not listed.
__all__ = [
    "gradcheck",
    "gradcheck_tape",
    "jacobian",
    "jacobian_fd",
    "jacobian_plot",
]
33
+
34
+
35
def gradcheck(
    function: wp.Kernel | Callable,
    dim: tuple[int] | None = None,
    inputs: Sequence | None = None,
    outputs: Sequence | None = None,
    *,
    eps: float = 1e-4,
    atol: float = 1e-3,
    rtol: float = 1e-2,
    raise_exception: bool = True,
    input_output_mask: list[tuple[str | int, str | int]] | None = None,
    device: wp.context.Devicelike = None,
    max_blocks: int = 0,
    block_dim: int = 256,
    max_inputs_per_var: int = -1,
    max_outputs_per_var: int = -1,
    plot_relative_error: bool = False,
    plot_absolute_error: bool = False,
    show_summary: bool = True,
) -> bool:
    """
    Checks whether the autodiff gradient of a Warp kernel matches finite differences.

    Given the autodiff (:math:`\\nabla_\\text{AD}`) and finite difference gradients (:math:`\\nabla_\\text{FD}`), the check succeeds if the autodiff gradients contain no NaN values and the following condition holds:

    .. math::

        |\\nabla_\\text{AD} - \\nabla_\\text{FD}| \\leq atol + rtol \\cdot |\\nabla_\\text{FD}|.

    The kernel function and its adjoint version are launched with the given inputs and outputs, as well as the provided
    ``dim``, ``max_blocks``, and ``block_dim`` arguments (see :func:`warp.launch` for more details).

    Note:
        This function only supports Warp kernels whose input arguments precede the output arguments.

        Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.

        Structs arguments are not yet supported by this function to compute Jacobians.

    Args:
        function: The Warp kernel function, decorated with the ``@wp.kernel`` decorator, or any function that involves Warp kernel launches.
        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints. Only required if the function is a Warp kernel.
        inputs: List of input variables.
        outputs: List of output variables. Only required if the function is a Warp kernel.
        eps: The finite-difference step size.
        atol: The absolute tolerance for the gradient check.
        rtol: The relative tolerance for the gradient check.
        raise_exception: If True, raises a `ValueError` if the gradient check fails.
        input_output_mask: List of tuples specifying the input-output pairs to compute the Jacobian for. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
        device: The device to launch on (optional)
        max_blocks: The maximum number of CUDA thread blocks to use.
        block_dim: The number of threads per block.
        max_inputs_per_var: Maximum number of input dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all input dimensions if value <= 0.
        max_outputs_per_var: Maximum number of output dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all output dimensions if value <= 0.
        plot_relative_error: If True, visualizes the relative error of the Jacobians in a plot (requires ``matplotlib``).
        plot_absolute_error: If True, visualizes the absolute error of the Jacobians in a plot (requires ``matplotlib``).
        show_summary: If True, prints a summary table of the gradient check results.

    Returns:
        True if the gradient check passes, False otherwise.

    Raises:
        ValueError: If ``inputs`` is None, or if ``raise_exception`` is True and the check fails. A failure caused by
            NaN values in the autodiff Jacobian is reported as such; otherwise a gradient mismatch is reported. The
            message names the first input/output pair that failed.
    """

    if inputs is None:
        raise ValueError("The inputs argument must be provided")

    metadata = FunctionMetadata()

    # Only the output limit is forwarded to the AD Jacobian and only the input limit to the FD Jacobian;
    # both limits are applied again when the two Jacobians are compared below.
    jacs_ad = jacobian(
        function,
        dim=dim,
        inputs=inputs,
        outputs=outputs,
        input_output_mask=input_output_mask,
        device=device,
        max_blocks=max_blocks,
        block_dim=block_dim,
        max_outputs_per_var=max_outputs_per_var,
        plot_jacobians=False,
        metadata=metadata,
    )
    jacs_fd = jacobian_fd(
        function,
        dim=dim,
        inputs=inputs,
        outputs=outputs,
        input_output_mask=input_output_mask,
        device=device,
        max_blocks=max_blocks,
        block_dim=block_dim,
        max_inputs_per_var=max_inputs_per_var,
        eps=eps,
        plot_jacobians=False,
        metadata=metadata,
    )

    relative_error_jacs = {}
    absolute_error_jacs = {}

    if show_summary:
        summary = []
        summary_header = ["Input", "Output", "Max Abs Error", "AD at MAE", "FD at MAE", "Max Rel Error", "Pass"]

        class FontColors:
            OKGREEN = "\033[92m"
            WARNING = "\033[93m"
            FAIL = "\033[91m"
            ENDC = "\033[0m"

    success = True
    # Remember the first failing (input, output) pair of each kind so that the raised error
    # names the pair that actually failed rather than whichever pair happened to be checked last.
    first_mismatch = None
    first_nan = None
    for (input_i, output_i), jac_fd in jacs_fd.items():
        jac_ad = jacs_ad[input_i, output_i]
        if plot_relative_error or plot_absolute_error:
            jac_rel_error = wp.empty_like(jac_fd)
            jac_abs_error = wp.empty_like(jac_fd)
            flat_jac_fd = scalarize_array_1d(jac_fd)
            flat_jac_ad = scalarize_array_1d(jac_ad)
            flat_jac_rel_error = scalarize_array_1d(jac_rel_error)
            flat_jac_abs_error = scalarize_array_1d(jac_abs_error)
            wp.launch(
                compute_error_kernel,
                dim=len(flat_jac_fd),
                inputs=[flat_jac_ad, flat_jac_fd, flat_jac_rel_error, flat_jac_abs_error],
                device=jac_fd.device,
            )
            relative_error_jacs[(input_i, output_i)] = jac_rel_error
            absolute_error_jacs[(input_i, output_i)] = jac_abs_error

        cut_jac_fd = jac_fd.numpy()
        cut_jac_ad = jac_ad.numpy()
        if max_outputs_per_var > 0:
            cut_jac_fd = cut_jac_fd[:max_outputs_per_var]
            cut_jac_ad = cut_jac_ad[:max_outputs_per_var]
        if max_inputs_per_var > 0:
            cut_jac_fd = cut_jac_fd[:, :max_inputs_per_var]
            cut_jac_ad = cut_jac_ad[:, :max_inputs_per_var]

        grad_matches = bool(np.allclose(cut_jac_ad, cut_jac_fd, atol=atol, rtol=rtol))
        isnan = bool(np.any(np.isnan(cut_jac_ad)))
        if not grad_matches and first_mismatch is None:
            first_mismatch = (input_i, output_i)
        if isnan and first_nan is None:
            first_nan = (input_i, output_i)
        success = success and grad_matches and not isnan

        if show_summary:
            if isnan:
                pass_str = FontColors.FAIL + "NaN" + FontColors.ENDC
            elif grad_matches:
                pass_str = FontColors.OKGREEN + "PASS" + FontColors.ENDC
            else:
                pass_str = FontColors.FAIL + "FAIL" + FontColors.ENDC
            input_name = metadata.input_labels[input_i]
            output_name = metadata.output_labels[output_i]

            if cut_jac_ad.size == 0:
                # Nothing to compare; reducing an empty array with .max()/argmax would raise.
                summary.append([input_name, output_name, "n/a", "n/a", "n/a", "n/a", pass_str])
            else:
                abs_error = np.abs(cut_jac_ad - cut_jac_fd)
                max_abs_error = abs_error.max()
                arg_max_abs_error = np.unravel_index(np.argmax(abs_error), cut_jac_ad.shape)
                # Use |FD| in the denominator so a negative FD value cannot cancel the epsilon.
                max_rel_error = (abs_error / (np.abs(cut_jac_fd) + 1e-8)).max()
                summary.append(
                    [
                        input_name,
                        output_name,
                        f"{max_abs_error:.3e} at {tuple(int(i) for i in arg_max_abs_error)}",
                        f"{cut_jac_ad[arg_max_abs_error]:.3e}",
                        f"{cut_jac_fd[arg_max_abs_error]:.3e}",
                        f"{max_rel_error:.3e}",
                        pass_str,
                    ]
                )

    if show_summary:
        print_table(summary_header, summary)
        if not success:
            print(FontColors.FAIL + f"Gradient check for kernel {metadata.key} failed" + FontColors.ENDC)
        else:
            print(FontColors.OKGREEN + f"Gradient check for kernel {metadata.key} passed" + FontColors.ENDC)
    if plot_relative_error:
        jacobian_plot(
            relative_error_jacs,
            metadata,
            inputs,
            outputs,
            title=f"{metadata.key} kernel Jacobian relative error",
        )
    if plot_absolute_error:
        jacobian_plot(
            absolute_error_jacs,
            metadata,
            inputs,
            outputs,
            title=f"{metadata.key} kernel Jacobian absolute error",
        )

    if raise_exception:
        # The NaN case is checked first: np.allclose treats NaN as unequal (equal_nan=False), so every
        # NaN pair is also a mismatch, and checking mismatches first would make this message unreachable.
        if first_nan is not None:
            input_i, output_i = first_nan
            raise ValueError(
                f"Gradient check failed for kernel {metadata.key}, input {input_i}, output {output_i}: "
                f"gradient contains NaN values"
            )
        if first_mismatch is not None:
            input_i, output_i = first_mismatch
            raise ValueError(
                f"Gradient check failed for kernel {metadata.key}, input {input_i}, output {output_i}: "
                f"finite difference and autodiff gradients do not match"
            )

    return success
238
+
239
+
240
def gradcheck_tape(
    tape: wp.Tape,
    *,
    eps=1e-4,
    atol=1e-3,
    rtol=1e-2,
    raise_exception=True,
    input_output_masks: dict[str, list[tuple[str | int, str | int]]] | None = None,
    blacklist_kernels: list[str] | None = None,
    whitelist_kernels: list[str] | None = None,
    max_inputs_per_var=-1,
    max_outputs_per_var=-1,
    plot_relative_error=False,
    plot_absolute_error=False,
    show_summary: bool = True,
    reverse_launches: bool = False,
    skip_to_launch_index: int = 0,
) -> bool:
    """
    Checks whether the autodiff gradients for kernels recorded on the Warp tape match finite differences.

    Given the autodiff (:math:`\\nabla_\\text{AD}`) and finite difference gradients (:math:`\\nabla_\\text{FD}`), the check succeeds if the autodiff gradients contain no NaN values and the following condition holds:

    .. math::

        |\\nabla_\\text{AD} - \\nabla_\\text{FD}| \\leq atol + rtol \\cdot |\\nabla_\\text{FD}|.

    Note:
        Only Warp kernels recorded on the tape are checked but not arbitrary functions that have been recorded, e.g. via :meth:`Tape.record_func`.

        Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.

        Structs arguments are not yet supported by this function to compute Jacobians.

        Kernels whose ``enable_backward`` option is False are skipped.

    Args:
        tape: The Warp tape to perform the gradient check on.
        eps: The finite-difference step size.
        atol: The absolute tolerance for the gradient check.
        rtol: The relative tolerance for the gradient check.
        raise_exception: If True, raises a `ValueError` if the gradient check fails.
        input_output_masks: Dictionary of input-output masks for each kernel in the tape, mapping from kernel keys to input-output masks. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
        blacklist_kernels: List of kernel keys to exclude from the gradient check.
        whitelist_kernels: List of kernel keys to include in the gradient check. If not empty or None, only kernels in this list are checked.
        max_inputs_per_var: Maximum number of input dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all input dimensions if value <= 0.
        max_outputs_per_var: Maximum number of output dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all output dimensions if value <= 0.
        plot_relative_error: If True, visualizes the relative error of the Jacobians in a plot (requires ``matplotlib``).
        plot_absolute_error: If True, visualizes the absolute error of the Jacobians in a plot (requires ``matplotlib``).
        show_summary: If True, prints a summary table of the gradient check results.
        reverse_launches: If True, reverses the order of the kernel launches on the tape to check.
        skip_to_launch_index: Number of tape launches to skip before checking begins. The index is counted in
            iteration order (i.e. after ``reverse_launches`` has been applied) and includes every recorded launch,
            also those that are not kernel launches. Defaults to 0, which checks from the first launch.

    Returns:
        True if the gradient check passes for all kernels on the tape, False otherwise.

    Raises:
        ValueError: If ``raise_exception`` is True and the check fails for a kernel; checking stops at that kernel.
    """
    masks = {} if input_output_masks is None else input_output_masks
    # Normalize both filters to sets; an empty whitelist means "check every kernel".
    blacklist = set() if blacklist_kernels is None else set(blacklist_kernels)
    whitelist = set() if whitelist_kernels is None else set(whitelist_kernels)

    overall_success = True
    launches = reversed(tape.launches) if reverse_launches else tape.launches
    for launch_index, launch in enumerate(launches):
        # The skip counter advances over every recorded entry, not only kernel launches.
        if launch_index < skip_to_launch_index:
            continue
        # Only tuple/list records whose first element is a Warp kernel are checked.
        if not isinstance(launch, (tuple, list)) or not isinstance(launch[0], wp.Kernel):
            continue
        kernel, dim, max_blocks, inputs, outputs, device, block_dim = launch[:7]
        if whitelist and kernel.key not in whitelist:
            continue
        if kernel.key in blacklist:
            continue
        # Skip kernels built with backward generation disabled; there is no adjoint to check.
        if not kernel.options.get("enable_backward", True):
            continue

        success = gradcheck(
            kernel,
            dim,
            inputs,
            outputs,
            eps=eps,
            atol=atol,
            rtol=rtol,
            raise_exception=raise_exception,
            input_output_mask=masks.get(kernel.key),
            device=device,
            max_blocks=max_blocks,
            block_dim=block_dim,
            max_inputs_per_var=max_inputs_per_var,
            max_outputs_per_var=max_outputs_per_var,
            plot_relative_error=plot_relative_error,
            plot_absolute_error=plot_absolute_error,
            show_summary=show_summary,
        )
        overall_success = overall_success and success

    return overall_success
344
+
345
+
346
def get_struct_vars(x: wp._src.codegen.StructInstance):
    """Map every field name of the struct instance ``x`` to its current value.

    Field names come from the ``_fields_`` of the struct's ``ctype``. They are
    visited in the order they appear there, and that order is preserved in the
    returned dict.
    """
    field_names = (name for name, _field_ctype in x._cls.ctype._fields_)
    return {name: getattr(x, name) for name in field_names}
348
+
349
+
350
def infer_device(xs: list):
    """Pick the Warp device for a list of variables.

    Scans ``xs`` in order and returns the device of the first Warp array found.
    A top-level array is used directly. For a struct instance, the values of its
    fields (one level deep) are searched for an array. If no array is found
    anywhere, :func:`warp.get_preferred_device` is returned.
    """
    for item in xs:
        if isinstance(item, wp.array):
            return item.device
        if not isinstance(item, wp._src.codegen.StructInstance):
            continue
        field_arrays = (value for value in get_struct_vars(item).values() if isinstance(value, wp.array))
        first_array = next(field_arrays, None)
        if first_array is not None:
            return first_array.device
    return wp.get_preferred_device()
360
+
361
+
362
class FunctionMetadata:
    """
    Metadata holder for kernel functions or functions with Warp arrays as inputs/outputs.

    Besides the function ``key``, one entry per input/output argument is kept in each of the
    ``*_labels``, ``*_strides`` and ``*_dtypes`` lists. The stride and dtype entries are
    ``None`` for arguments that were not recorded as Warp arrays.
    """

    def __init__(
        self,
        key: str | None = None,
        input_labels: list[str] | None = None,
        output_labels: list[str] | None = None,
        input_strides: list[tuple] | None = None,
        output_strides: list[tuple] | None = None,
        input_dtypes: list | None = None,
        output_dtypes: list | None = None,
    ):
        # identifier of the kernel/function; None means "not populated yet" (see is_empty)
        self.key = key
        # argument names, in call order
        self.input_labels, self.output_labels = input_labels, output_labels
        # per-argument array strides (None for non-array arguments)
        self.input_strides, self.output_strides = input_strides, output_strides
        # per-argument array dtypes (None for non-array arguments)
        self.input_dtypes, self.output_dtypes = input_dtypes, output_dtypes

    @property
    def is_empty(self):
        """True as long as no key has been assigned to this metadata."""
        return self.key is None

    def input_is_array(self, i: int):
        """Returns whether strides were recorded for the ``i``-th input argument."""
        return self.input_strides[i] is not None

    def output_is_array(self, i: int):
        """Returns whether strides were recorded for the ``i``-th output argument."""
        return self.output_strides[i] is not None

    @staticmethod
    def _kernel_arg_layout(args):
        # Returns (strides, dtypes) lists for a sequence of kernel argument descriptors.
        # NOTE(review): ``arg.type`` of an argument annotated as ``wp.array(dtype=...)`` is an
        # array *instance*, so this identity test against the ``wp.array`` class may never
        # match, leaving every kernel argument recorded as a non-array. Confirm against
        # Warp's annotation handling before relying on input_is_array/output_is_array here.
        layout = [(arg.type.strides, arg.type.dtype) if arg.type is wp.array else (None, None) for arg in args]
        return [strides for strides, _ in layout], [dtype for _, dtype in layout]

    @staticmethod
    def _value_layout(values):
        # Returns (strides, dtypes) lists for concrete argument values.
        strides, dtypes = [], []
        for value in values:
            is_array = isinstance(value, wp.array)
            strides.append(value.strides if is_array else None)
            dtypes.append(value.dtype if is_array else None)
        return strides, dtypes

    def update_from_kernel(self, kernel: wp.Kernel, inputs: Sequence):
        """
        Populates the metadata from a Warp kernel.

        The first ``len(inputs)`` kernel arguments are treated as inputs and the remaining
        ones as outputs. Only the length of ``inputs`` is used.
        """
        self.key = kernel.key
        num_inputs = len(inputs)
        input_args = kernel.adj.args[:num_inputs]
        output_args = kernel.adj.args[num_inputs:]
        self.input_labels = [arg.label for arg in input_args]
        self.output_labels = [arg.label for arg in output_args]
        self.input_strides, self.input_dtypes = self._kernel_arg_layout(input_args)
        self.output_strides, self.output_dtypes = self._kernel_arg_layout(output_args)

    def update_from_function(self, function: Callable, inputs: Sequence, outputs: Sequence | None = None):
        """
        Populates the metadata from a regular Python function and concrete arguments.

        If ``outputs`` is None, the function is called once with ``inputs`` to obtain them. A
        single returned array is wrapped in a list. Output labels are generated as ``output_<i>``.
        """
        self.key = function.__name__
        self.input_labels = list(inspect.signature(function).parameters)
        if outputs is None:
            outputs = function(*inputs)
        if isinstance(outputs, wp.array):
            outputs = [outputs]
        self.output_labels = [f"output_{i}" for i in range(len(outputs))]
        self.input_strides, self.input_dtypes = self._value_layout(inputs)
        self.output_strides, self.output_dtypes = self._value_layout(outputs)
444
+
445
+
446
def jacobian_plot(
    jacobians: dict[tuple[int, int], wp.array],
    kernel: FunctionMetadata | wp.Kernel,
    inputs: Sequence | None = None,
    show_plot: bool = True,
    show_colorbar: bool = True,
    scale_colors_per_submatrix: bool = False,
    title: str | None = None,
    colormap: str = "coolwarm",
    log_scale: bool = False,
):
    """
    Visualizes the Jacobians computed by :func:`jacobian` or :func:`jacobian_fd` in a combined image plot.
    Requires the ``matplotlib`` package to be installed.

    Each Jacobian submatrix is drawn in its own subplot. Subplot columns correspond to inputs and
    rows to outputs. Vector/matrix-typed input entries are expanded into their scalar components
    along the column axis. NaN entries (e.g. rows or columns skipped via ``max_outputs_per_var``
    or ``max_inputs_per_var``) are ignored when determining the color range.

    Args:
        jacobians: A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.
        kernel: The Warp kernel function, decorated with the ``@wp.kernel`` decorator, or a :class:`FunctionMetadata` instance with the kernel/function attributes.
        inputs: List of input variables. Required if ``kernel`` is a Warp kernel (its length determines which kernel arguments are inputs). Unused otherwise.
        show_plot: If True, displays the plot via ``plt.show()``.
        show_colorbar: If True, displays a colorbar next to the plot (or a colorbar next to every submatrix if ``scale_colors_per_submatrix`` is True).
        scale_colors_per_submatrix: If True, considers the minimum and maximum of each Jacobian submatrix separately for color scaling. Otherwise, uses the global minimum and maximum of all Jacobians.
        title: The title of the plot (optional).
        colormap: The colormap to use for the plot.
        log_scale: If True, uses a logarithmic scale (``log10(|J| + 1e-8)``) for the matrix values shown in the image plot. The color range is computed on the displayed values.

    Returns:
        The created Matplotlib figure, or None if ``jacobians`` is empty.

    Raises:
        ValueError: If ``kernel`` is neither a Warp kernel nor a :class:`FunctionMetadata`, or if ``kernel`` is a Warp kernel and ``inputs`` is None.
    """

    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator

    if isinstance(kernel, wp.Kernel):
        if inputs is None:
            raise ValueError("The kernel inputs must be provided to plot the Jacobians of a Warp kernel")
        metadata = FunctionMetadata()
        metadata.update_from_kernel(kernel, inputs)
    elif isinstance(kernel, FunctionMetadata):
        metadata = kernel
    else:
        raise ValueError("Invalid kernel argument: must be a Warp kernel or a FunctionMetadata object")

    # order the submatrices by output index first, then by input index
    jacobians = dict(sorted(jacobians.items(), key=lambda item: (item[0][1], item[0][0])))
    if not jacobians:
        return None

    def to_matrix(jac_wp):
        # One row per scalar output component (already laid out along the first dimension),
        # one column per scalar input component (vector/matrix dtypes are expanded here).
        jac = jac_wp.numpy()
        jac = jac.reshape(jac.shape[0], int(np.prod(jac.shape[1:])))
        if log_scale:
            jac = np.log10(np.abs(jac) + 1e-8)
        return jac

    def value_range(values):
        # NaN marks entries that were not evaluated; ignore them when scaling the colors.
        valid = values[~np.isnan(values)]
        if valid.size == 0:
            return 0.0, 0.0
        return valid.min(), valid.max()

    # convert every submatrix once (device -> host copy)
    matrices = {key: to_matrix(jac_wp) for key, jac_wp in jacobians.items()}

    # assign subplot columns to inputs and rows to outputs in order of first appearance
    input_to_ax = {}
    output_to_ax = {}
    for input_i, output_i in matrices:
        input_to_ax.setdefault(input_i, len(input_to_ax))
        output_to_ax.setdefault(output_i, len(output_to_ax))
    num_rows = len(output_to_ax)
    num_cols = len(input_to_ax)

    # Size each subplot column/row by the number of scalar entries of its submatrices. The
    # ratios are indexed by subplot position so they line up with the axes they describe.
    width_ratios = [1] * num_cols
    height_ratios = [1] * num_rows
    has_plot = np.zeros((num_rows, num_cols), dtype=bool)
    for (input_i, output_i), jac in matrices.items():
        ax_i, ax_j = output_to_ax[output_i], input_to_ax[input_i]
        width_ratios[ax_j] = max(1, jac.shape[1])
        height_ratios[ax_i] = max(1, jac.shape[0])
        has_plot[ax_i, ax_j] = True

    fig, axs = plt.subplots(
        ncols=num_cols,
        nrows=num_rows,
        figsize=(7, 7),
        sharex="col",
        sharey="row",
        gridspec_kw={
            "wspace": 0.1,
            "hspace": 0.1,
            "width_ratios": width_ratios,
            "height_ratios": height_ratios,
        },
        subplot_kw={"aspect": 1},
        squeeze=False,
    )
    if title is None:
        key = metadata.key if metadata.key is not None else "unknown"
        title = f"{key} kernel Jacobian"
    fig.suptitle(title)
    # non-interactive backends may not provide a figure manager
    if fig.canvas.manager is not None:
        fig.canvas.manager.set_window_title(title)

    if not scale_colors_per_submatrix:
        vmin, vmax = value_range(np.concatenate([jac.ravel() for jac in matrices.values()]))

    # hide the grid cells that have no Jacobian submatrix
    for ax_i, ax_j in zip(*np.nonzero(~has_plot)):
        axs[ax_i, ax_j].axis("off")

    for (input_i, output_i), jac in matrices.items():
        ax_i, ax_j = output_to_ax[output_i], input_to_ax[input_i]
        ax = axs[ax_i, ax_j]
        ax.tick_params(which="major", width=1, length=7)
        ax.tick_params(which="minor", width=1, length=4, color="gray")
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))

        if scale_colors_per_submatrix:
            vmin, vmax = value_range(jac)
        img = ax.imshow(
            jac,
            cmap=colormap,
            aspect="auto",
            interpolation="nearest",
            extent=[0, jac.shape[1], 0, jac.shape[0]],
            vmin=vmin,
            vmax=vmax,
        )
        if ax_i == num_rows - 1 or not has_plot[ax_i + 1 :, ax_j].any():
            # last plot of this column
            ax.set_xlabel(metadata.input_labels[input_i])
        if ax_j == 0 or not has_plot[ax_i, :ax_j].any():
            # first plot of this row
            ax.set_ylabel(metadata.output_labels[output_i])
        ax.grid(color="gray", which="minor", linestyle="--", linewidth=0.5)
        ax.grid(color="black", which="major", linewidth=1.0)

        if show_colorbar and scale_colors_per_submatrix:
            plt.colorbar(img, ax=ax, orientation="vertical", pad=0.02)

    if show_colorbar and not scale_colors_per_submatrix:
        m = plt.cm.ScalarMappable(cmap=colormap)
        m.set_array([vmin, vmax])
        m.set_clim(vmin, vmax)
        plt.colorbar(m, ax=axs, orientation="vertical", pad=0.02)

    plt.tight_layout()
    if show_plot:
        plt.show()
    return fig
630
+
631
+
632
def scalarize_array_1d(arr):
    """
    Converts a Warp array into a 1D array of its scalar components.

    Scalar-typed arrays are flattened via ``arr.flatten()``. Arrays whose dtype is one of
    Warp's ``vector_types`` are reinterpreted through a new array pointing at the same memory
    (``ptr=arr.ptr``), so writes to the result are visible in ``arr``.

    NOTE(review): the pointer-based view presumably does not own or retain ``arr``; keep ``arr``
    alive while the returned array is in use.

    Args:
        arr: Warp array with a scalar or vector-like dtype.

    Returns:
        A 1D Warp array of ``arr.dtype``'s scalar type.

    Raises:
        ValueError: If the dtype is unsupported, or if a vector-typed array is not contiguous.
            The pointer-based reinterpretation assumes densely packed elements.
    """
    if arr.dtype in wp._src.types.scalar_types:
        return arr.flatten()
    if arr.dtype in wp._src.types.vector_types:
        if not arr.is_contiguous:
            # a strided array cannot be reinterpreted as a dense scalar buffer starting at arr.ptr
            raise ValueError("Unsupported array: array to be flattened must be contiguous")
        return wp.array(
            ptr=arr.ptr,
            shape=(arr.size * arr.dtype._length_,),
            dtype=arr.dtype._wp_scalar_type_,
            device=arr.device,
        )
    raise ValueError(
        f"Unsupported array dtype {arr.dtype}: array to be flattened must be a scalar/vector/matrix array"
    )
647
+
648
+
649
def scalarize_array_2d(arr):
    """
    Converts a 2D Warp array into a 2D array of its scalar components.

    Scalar-typed arrays are returned unchanged. For dtypes in Warp's ``vector_types``, the
    components of each element are laid out along the second dimension. The result shares
    memory with ``arr`` (``ptr=arr.ptr``).

    NOTE(review): the pointer-based view presumably does not own or retain ``arr``; keep ``arr``
    alive while the returned array is in use.

    Args:
        arr: 2D Warp array with a scalar or vector-like dtype.

    Returns:
        A 2D Warp array with shape ``(arr.shape[0], arr.shape[1] * components)``.

    Raises:
        ValueError: If ``arr`` is not 2D, if the dtype is unsupported, or if a vector-typed
            array is not contiguous (the reinterpretation assumes densely packed elements).
    """
    # raise instead of assert so the check survives ``python -O``
    if arr.ndim != 2:
        raise ValueError(f"Expected a 2D array, got an array with {arr.ndim} dimensions")
    if arr.dtype in wp._src.types.scalar_types:
        return arr
    if arr.dtype in wp._src.types.vector_types:
        if not arr.is_contiguous:
            raise ValueError("Unsupported array: array to be flattened must be contiguous")
        return wp.array(
            ptr=arr.ptr,
            shape=(arr.shape[0], arr.shape[1] * arr.dtype._length_),
            dtype=arr.dtype._wp_scalar_type_,
            device=arr.device,
        )
    raise ValueError(
        f"Unsupported array dtype {arr.dtype}: array to be flattened must be a scalar/vector/matrix array"
    )
665
+
666
+
667
def jacobian(
    function: wp.Kernel | Callable,
    dim: tuple[int] | None = None,
    inputs: Sequence | None = None,
    outputs: Sequence | None = None,
    input_output_mask: list[tuple[str | int, str | int]] | None = None,
    device: wp.context.Devicelike = None,
    max_blocks=0,
    block_dim=256,
    max_outputs_per_var=-1,
    plot_jacobians=False,
    metadata: FunctionMetadata | None = None,
) -> dict[tuple[int, int], wp.array]:
    """
    Computes the Jacobians of a function or Warp kernel for the provided selection of differentiable inputs to differentiable outputs.

    The input function can be either a Warp kernel (e.g. a function decorated by ``@wp.kernel``) or a regular Python function that accepts arguments (of which some must be Warp arrays) and returns a Warp array or a list of Warp arrays.

    In case ``function`` is a Warp kernel, its adjoint kernel is launched with the given inputs and outputs, as well as the provided ``dim``,
    ``max_blocks``, and ``block_dim`` arguments (see :func:`warp.launch` for more details).

    Each Jacobian row is obtained by seeding a unit adjoint on one scalar component of the output and running the backward pass.
    All gradients recorded on the tape are cleared between rows.

    Note:
        If ``function`` is a Warp kernel, the input arguments must precede the output arguments in the kernel code definition.

        Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.

        Function arguments of type :ref:`Struct <structs>` are not yet supported.

    Args:
        function: The Warp kernel function, or a regular Python function that returns a Warp array or a list of Warp arrays.
        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints. Only required if ``function`` is a Warp kernel.
        inputs: List of input variables. At least one of the arguments must be a Warp array with ``requires_grad=True``.
        outputs: List of output variables. Optional if the function is a regular Python function that returns a Warp array or a list of Warp arrays. Only required if ``function`` is a Warp kernel.
        input_output_mask: List of tuples specifying the input-output pairs to compute the Jacobian for. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
        device: The device to launch on (optional). Only used if ``function`` is a Warp kernel.
        max_blocks: The maximum number of CUDA thread blocks to use. Only used if ``function`` is a Warp kernel.
        block_dim: The number of threads per block. Only used if ``function`` is a Warp kernel.
        max_outputs_per_var: Maximum number of output dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all output dimensions if value <= 0. Rows of skipped output dimensions are left as NaN.
        plot_jacobians: If True, visualizes the computed Jacobians in a plot (requires ``matplotlib``).
        metadata: The metadata of the kernel function, containing the input and output labels, strides, and dtypes. If None or empty, the metadata is inferred from the kernel or function.

    Returns:
        A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.

    Raises:
        ValueError: If ``inputs`` is None, if ``function`` is a Warp kernel with its backward pass disabled or without outputs,
            or if a name in ``input_output_mask`` does not match any input/output label.
    """
    if inputs is None:
        raise ValueError("A list of input arguments must be provided to compute Jacobians")
    if input_output_mask is None:
        input_output_mask = []
    if metadata is None:
        metadata = FunctionMetadata()

    tape = wp.Tape()
    if isinstance(function, wp.Kernel):
        if not function.options.get("enable_backward", True):
            raise ValueError("Kernel must have backward pass enabled to compute Jacobians")
        if outputs is None or len(outputs) == 0:
            raise ValueError("A list of output arguments must be provided to compute kernel Jacobians")
        if device is None:
            # unpack instead of ``inputs + outputs`` so that mixed list/tuple sequences work
            device = infer_device([*inputs, *outputs])
        if metadata.is_empty:
            metadata.update_from_kernel(function, inputs)

        tape.record_launch(
            kernel=function,
            dim=dim,
            inputs=inputs,
            outputs=outputs,
            device=device,
            max_blocks=max_blocks,
            block_dim=block_dim,
        )
    else:
        with tape:
            outputs = function(*inputs)
        if isinstance(outputs, wp.array):
            outputs = [outputs]
        if metadata.is_empty:
            metadata.update_from_function(function, inputs, outputs)

    def resolve_index(name, labels):
        # Arguments may be referenced by position or by label. Labels are looked up in the
        # input or output label list respectively, which stays correct for functions whose
        # signature has more parameters than provided inputs.
        return name if isinstance(name, int) else labels.index(name)

    selected_pairs = {
        (resolve_index(in_name, metadata.input_labels), resolve_index(out_name, metadata.output_labels))
        for in_name, out_name in input_output_mask
    }

    zero_grads(inputs)
    zero_grads(outputs)

    jacobians = {}

    for input_i, output_i in itertools.product(range(len(inputs)), range(len(outputs))):
        if selected_pairs and (input_i, output_i) not in selected_pairs:
            continue
        inp = inputs[input_i]
        output = outputs[output_i]
        if not isinstance(inp, wp.array) or not inp.requires_grad:
            continue
        if not isinstance(output, wp.array) or not output.requires_grad:
            continue

        # scalar view onto the output gradient: one Jacobian row per scalar output component
        out_grad = scalarize_array_1d(output.grad)
        output_num = out_grad.shape[0]
        jac = wp.empty((output_num, inp.size), dtype=inp.dtype, device=inp.device)
        # rows that are not evaluated (see max_outputs_per_var) remain NaN
        jac.fill_(wp.nan)
        if max_outputs_per_var > 0:
            output_num = min(output_num, max_outputs_per_var)
        for i in range(output_num):
            # seed a unit adjoint on the i-th scalar output component only
            output.grad.zero_()
            set_element(out_grad, i, 1.0)
            tape.backward()
            jac[i].assign(inp.grad)

            # backward() accumulates into gradients: clear every gradient recorded on the tape
            # (including intermediates of a multi-kernel function) before the next row
            tape.zero()
            zero_grads(inputs)
            zero_grads(outputs)
        jacobians[input_i, output_i] = jac

    if plot_jacobians:
        jacobian_plot(jacobians, metadata, inputs)

    return jacobians
800
+
801
+
802
def jacobian_fd(
    function: wp.Kernel | Callable,
    dim: tuple[int] | None | None = None,
    inputs: Sequence | None = None,
    outputs: Sequence | None = None,
    input_output_mask: list[tuple[str | int, str | int]] | None = None,
    device: wp.context.Devicelike = None,
    max_blocks=0,
    block_dim=256,
    max_inputs_per_var=-1,
    eps: float = 1e-4,
    plot_jacobians=False,
    metadata: FunctionMetadata | None = None,
) -> dict[tuple[int, int], wp.array]:
    """
    Computes the finite-difference Jacobian of a function or Warp kernel for the provided selection of differentiable inputs to differentiable outputs.
    The method uses a central difference scheme to approximate the Jacobian.

    The input function can be either a Warp kernel (e.g. a function decorated by ``@wp.kernel``) or a regular Python function that accepts arguments (of which some must be Warp arrays) and returns a Warp array or a list of Warp arrays.

    The function is launched multiple times in forward-only mode with the given inputs. If ``function`` is a Warp kernel, the provided inputs and outputs,
    as well as the other parameters ``dim``, ``max_blocks``, and ``block_dim`` are provided to the kernel launch (see :func:`warp.launch`).

    Each scalar input component is perturbed by ``-eps`` and ``+eps`` in place and restored afterwards, even if evaluating the function fails.
    Kernel launches write into cloned output buffers, so the caller's output arrays are not modified.

    Note:
        If ``function`` is a Warp kernel, the input arguments must precede the output arguments in the kernel code definition.

        Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.

        Function arguments of type :ref:`Struct <structs>` are not yet supported.

    Args:
        function: The Warp kernel function, or a regular Python function that returns a Warp array or a list of Warp arrays.
        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints. Only required if ``function`` is a Warp kernel.
        inputs: List of input variables. At least one of the arguments must be a Warp array with ``requires_grad=True``.
        outputs: List of output variables. Optional if the function is a regular Python function that returns a Warp array or a list of Warp arrays. Only required if ``function`` is a Warp kernel.
        input_output_mask: List of tuples specifying the input-output pairs to compute the Jacobian for. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
        device: The device to launch on (optional). Only used if ``function`` is a Warp kernel.
        max_blocks: The maximum number of CUDA thread blocks to use. Only used if ``function`` is a Warp kernel.
        block_dim: The number of threads per block. Only used if ``function`` is a Warp kernel.
        max_inputs_per_var: Maximum number of input dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all input dimensions if value <= 0. Columns of skipped input dimensions are left as NaN.
        eps: The finite-difference step size.
        plot_jacobians: If True, visualizes the computed Jacobians in a plot (requires ``matplotlib``).
        metadata: The metadata of the kernel function, containing the input and output labels, strides, and dtypes. If None or empty, the metadata is inferred from the kernel or function.

    Returns:
        A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.

    Raises:
        ValueError: If ``inputs`` is None, if ``function`` is a Warp kernel with its backward pass disabled or without outputs,
            or if a name in ``input_output_mask`` does not match any input/output label.
    """
    if inputs is None:
        raise ValueError("A list of input arguments must be provided to compute Jacobians")
    if input_output_mask is None:
        input_output_mask = []
    if metadata is None:
        metadata = FunctionMetadata()

    is_kernel = isinstance(function, wp.Kernel)
    if is_kernel:
        if not function.options.get("enable_backward", True):
            raise ValueError("Kernel must have backward pass enabled to compute Jacobians")
        if outputs is None or len(outputs) == 0:
            raise ValueError("A list of output arguments must be provided to compute kernel Jacobians")
        if device is None:
            # unpack instead of ``inputs + outputs`` so that mixed list/tuple sequences work
            device = infer_device([*inputs, *outputs])
        if metadata.is_empty:
            metadata.update_from_kernel(function, inputs)
    else:
        # one unperturbed evaluation to discover the outputs (no tape is needed for finite differences)
        outputs = function(*inputs)
        if isinstance(outputs, wp.array):
            outputs = [outputs]
        if metadata.is_empty:
            metadata.update_from_function(function, inputs, outputs)

    def resolve_index(name, labels):
        # Arguments may be referenced by position or by label. Labels are looked up in the
        # input or output label list respectively.
        return name if isinstance(name, int) else labels.index(name)

    selected_pairs = {
        (resolve_index(in_name, metadata.input_labels), resolve_index(out_name, metadata.output_labels))
        for in_name, out_name in input_output_mask
    }

    def conditional_clone(obj):
        return wp.clone(obj) if isinstance(obj, wp.array) else obj

    def evaluate(target, target_index, launch_outputs):
        # Runs the function once on the current (perturbed) inputs; afterwards ``target`` holds
        # the output with index ``target_index``.
        if is_kernel:
            # the kernel writes straight into ``target``, which is part of ``launch_outputs``
            wp.launch(
                function,
                dim=dim,
                max_blocks=max_blocks,
                block_dim=block_dim,
                inputs=inputs,
                outputs=launch_outputs,
                device=device,
            )
        else:
            # use a local name so the caller-visible ``outputs`` list is never rebound
            results = function(*inputs)
            if isinstance(results, wp.array):
                results = [results]
            target.assign(results[target_index])

    outputs_copy = [conditional_clone(output) for output in outputs]
    jacobians = {}

    for input_i, output_i in itertools.product(range(len(inputs)), range(len(outputs))):
        if selected_pairs and (input_i, output_i) not in selected_pairs:
            continue
        inp = inputs[input_i]
        output = outputs[output_i]
        if not isinstance(inp, wp.array) or not inp.requires_grad:
            continue
        if not isinstance(output, wp.array) or not output.requires_grad:
            continue

        flat_input = scalarize_array_1d(inp)

        # left/right receive f(x - eps) and f(x + eps); output_init restores them between steps
        left = wp.clone(output)
        right = wp.clone(output)
        output_init = wp.clone(output)
        flat_left = scalarize_array_1d(left)
        flat_right = scalarize_array_1d(right)

        def outputs_with(target):
            # fresh clones for all other outputs, with ``target`` at the position of interest
            return [target if k == output_i else conditional_clone(o) for k, o in enumerate(outputs_copy)]

        left_outputs = outputs_with(left)
        right_outputs = outputs_with(right)

        input_num = flat_input.shape[0]
        flat_input_copy = wp.clone(flat_input)
        jac = wp.empty((flat_left.size, inp.size), dtype=inp.dtype, device=inp.device)
        # columns that are not evaluated (see max_inputs_per_var) remain NaN
        jac.fill_(wp.nan)

        # row i of the transposed scalar view is column i of the Jacobian (one scalar input component)
        jacobian_t = scalarize_array_2d(jac).transpose()
        if max_inputs_per_var > 0:
            input_num = min(input_num, max_inputs_per_var)
        for i in range(input_num):
            try:
                set_element(flat_input, i, -eps, relative=True)
                evaluate(left, output_i, left_outputs)
                set_element(flat_input, i, 2 * eps, relative=True)
                evaluate(right, output_i, right_outputs)
            finally:
                # restore the unperturbed input, even if evaluating the function failed
                flat_input.assign(flat_input_copy)

            compute_fd(flat_left, flat_right, eps, jacobian_t[i])

            if i < input_num - 1:
                # reset output buffers (flat_left/flat_right are views and stay valid)
                left.assign(output_init)
                right.assign(output_init)

        jacobians[input_i, output_i] = jac

    if plot_jacobians:
        jacobian_plot(jacobians, metadata, inputs)

    return jacobians
1004
+
1005
+
1006
# Writes ``val`` into ``a[i]`` when ``relative`` is False, or adds ``val`` to ``a[i]`` when
# ``relative`` is True. The backward pass is disabled for this kernel.
@wp.kernel(enable_backward=False)
def set_element_kernel(a: wp.array(dtype=Any), i: int, val: Any, relative: bool):
    if relative:
        # in-place increment (used to apply finite-difference perturbations)
        a[i] += val
    else:
        # overwrite
        a[i] = val
1012
+
1013
+
1014
def set_element(a: wp.array(dtype=Any), i: int, val: Any, relative: bool = False):
    """
    Overwrites (or, with ``relative=True``, increments) element ``i`` of array ``a`` with ``val``.

    ``val`` is converted to ``a.dtype`` first. The update is performed by a single-thread launch
    of ``set_element_kernel`` on ``a``'s device.
    """
    typed_val = a.dtype(val)
    kernel_args = [a, i, typed_val, relative]
    wp.launch(set_element_kernel, dim=1, inputs=kernel_args, device=a.device)
1016
+
1017
+
1018
# Central finite difference, one thread per element:
#   fd[tid] = (right[tid] - left[tid]) / (2 * eps)
# where ``left``/``right`` hold the function values at x - eps and x + eps.
# The backward pass is disabled for this kernel.
@wp.kernel(enable_backward=False)
def compute_fd_kernel(left: wp.array(dtype=float), right: wp.array(dtype=float), eps: float, fd: wp.array(dtype=float)):
    tid = wp.tid()
    fd[tid] = (right[tid] - left[tid]) / (2.0 * eps)
1022
+
1023
+
1024
def compute_fd(left: wp.array(dtype=Any), right: wp.array(dtype=Any), eps: float, fd: wp.array(dtype=Any)):
    """
    Writes the central difference ``(right - left) / (2 * eps)`` into ``fd``.

    Launches ``compute_fd_kernel`` with one thread per element of ``left`` on ``left``'s device.
    Presumably ``right`` and ``fd`` have at least ``len(left)`` elements (confirm at call sites).
    """
    num_threads = len(left)
    wp.launch(
        compute_fd_kernel,
        dim=num_threads,
        inputs=[left, right, eps],
        outputs=[fd],
        device=left.device,
    )
1026
+
1027
+
1028
# Element-wise comparison of an autodiff Jacobian against a finite-difference Jacobian:
#   relative_error[tid] = (ad - fd) / denom   (signed; denom = ad, or 1e-8 when |ad| < 1e-8)
#   absolute_error[tid] = |ad - fd|
# The backward pass is disabled for this kernel.
@wp.kernel(enable_backward=False)
def compute_error_kernel(
    jacobian_ad: wp.array(dtype=Any),
    jacobian_fd: wp.array(dtype=Any),
    relative_error: wp.array(dtype=Any),
    absolute_error: wp.array(dtype=Any),
):
    tid = wp.tid()
    ad = jacobian_ad[tid]
    fd = jacobian_fd[tid]
    denom = ad
    # avoid dividing by (near) zero; the replacement keeps the scalar type of ``ad``
    if abs(ad) < 1e-8:
        denom = (type(ad))(1e-8)
    relative_error[tid] = (ad - fd) / denom
    absolute_error[tid] = wp.abs(ad - fd)
1043
+
1044
+
1045
def print_table(headers, cells):
    """
    Prints a table with the given headers and cells.

    Column widths are computed from the *visible* length of each entry, i.e. ignoring ANSI
    color escape sequences. String entries are padded by their visible length so colored cells
    stay aligned with uncolored ones. Every entry is followed by the separator ``" | "``.

    Args:
        headers: List of header strings.
        cells: List of lists of cell strings (one inner list per row). Non-string cells are
            formatted with Python's default alignment for their type.
    """
    import re

    # matches SGR escape codes such as ESC[0m, ESC[91m and ESC[1;31m
    ansi_escape = re.compile(r"\033\[[\d;]*m")

    def visible_len(s):
        return len(ansi_escape.sub("", str(s)))

    def pad(entry, width):
        if isinstance(entry, str):
            # escape codes take no space on the terminal, so pad by the visible length
            return entry + " " * max(0, width - visible_len(entry))
        return f"{entry:{width}}"

    col_widths = [max(visible_len(entry) for entry in col) for col in zip(headers, *cells)]
    for header, col_width in zip(headers, col_widths):
        print(pad(header, col_width), end=" | ")
    print()
    print("-" * (sum(col_widths) + 3 * len(col_widths) - 1))
    for cell_row in cells:
        for cell, col_width in zip(cell_row, col_widths):
            print(pad(cell, col_width), end=" | ")
        print()
1067
+
1068
+
1069
def zero_grads(arrays: list):
    """
    Resets the gradient buffer of every differentiable Warp array in ``arrays``.

    Only entries that are Warp arrays with ``requires_grad`` set are touched; any other
    values in the list are ignored.
    """
    differentiable = (arr for arr in arrays if isinstance(arr, wp.array) and arr.requires_grad)
    for arr in differentiable:
        arr.grad.zero_()