warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang has been flagged by the registry as potentially problematic; see the registry page for details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/autograd.py ADDED
@@ -0,0 +1,1142 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import itertools
18
+ from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
19
+
20
+ import numpy as np
21
+
22
+ import warp as wp
23
+
24
+ __all__ = [
25
+ "jacobian",
26
+ "jacobian_fd",
27
+ "gradcheck",
28
+ "gradcheck_tape",
29
+ "jacobian_plot",
30
+ ]
31
+
32
+
33
def gradcheck(
    function: Union[wp.Kernel, Callable],
    dim: Tuple[int] = None,
    inputs: Sequence = None,
    outputs: Sequence = None,
    *,
    eps: float = 1e-4,
    atol: float = 1e-3,
    rtol: float = 1e-2,
    raise_exception: bool = True,
    input_output_mask: List[Tuple[Union[str, int], Union[str, int]]] = None,
    device: wp.context.Devicelike = None,
    max_blocks: int = 0,
    block_dim: int = 256,
    max_inputs_per_var: int = -1,
    max_outputs_per_var: int = -1,
    plot_relative_error: bool = False,
    plot_absolute_error: bool = False,
    show_summary: bool = True,
) -> bool:
    """
    Checks whether the autodiff gradient of a Warp kernel matches finite differences.
    Fails if the relative or absolute errors between the autodiff and finite difference gradients exceed the specified tolerance, or if the autodiff gradients contain NaN values.

    The kernel function and its adjoint version are launched with the given inputs and outputs, as well as the provided
    ``dim``, ``max_blocks``, and ``block_dim`` arguments (see :func:`warp.launch` for more details).

    Note:
        This function only supports Warp kernels whose input arguments precede the output arguments.

        Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.

        Structs arguments are not yet supported by this function to compute Jacobians.

    Args:
        function: The Warp kernel function, decorated with the ``@wp.kernel`` decorator, or any function that involves Warp kernel launches.
        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints. Only required if the function is a Warp kernel.
        inputs: List of input variables.
        outputs: List of output variables. Only required if the function is a Warp kernel.
        eps: The finite-difference step size.
        atol: The absolute tolerance for the gradient check.
        rtol: The relative tolerance for the gradient check.
        raise_exception: If True, raises a `ValueError` if the gradient check fails.
        input_output_mask: List of tuples specifying the input-output pairs to compute the Jacobian for. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
        device: The device to launch on (optional)
        max_blocks: The maximum number of CUDA thread blocks to use.
        block_dim: The number of threads per block.
        max_inputs_per_var: Maximum number of input dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all input dimensions if value <= 0.
        max_outputs_per_var: Maximum number of output dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all output dimensions if value <= 0.
        plot_relative_error: If True, visualizes the relative error of the Jacobians in a plot (requires ``matplotlib``).
        plot_absolute_error: If True, visualizes the absolute error of the Jacobians in a plot (requires ``matplotlib``).
        show_summary: If True, prints a summary table of the gradient check results.

    Returns:
        True if the gradient check passes, False otherwise.

    Raises:
        ValueError: If ``inputs`` is None, or (when ``raise_exception`` is True) if any
            input-output pair mismatches finite differences or contains NaN gradients.
    """

    if inputs is None:
        raise ValueError("The inputs argument must be provided")

    # Shared metadata populated by the jacobian calls (kernel key, argument labels).
    metadata = FunctionMetadata()

    # Autodiff Jacobians; rows are output dimensions, hence max_outputs_per_var applies here.
    jacs_ad = jacobian(
        function,
        dim=dim,
        inputs=inputs,
        outputs=outputs,
        input_output_mask=input_output_mask,
        device=device,
        max_blocks=max_blocks,
        block_dim=block_dim,
        max_outputs_per_var=max_outputs_per_var,
        plot_jacobians=False,
        metadata=metadata,
    )
    # Finite-difference Jacobians; columns are input dimensions, hence max_inputs_per_var applies here.
    jacs_fd = jacobian_fd(
        function,
        dim=dim,
        inputs=inputs,
        outputs=outputs,
        input_output_mask=input_output_mask,
        device=device,
        max_blocks=max_blocks,
        block_dim=block_dim,
        max_inputs_per_var=max_inputs_per_var,
        eps=eps,
        plot_jacobians=False,
        metadata=metadata,
    )

    relative_error_jacs = {}
    absolute_error_jacs = {}

    if show_summary:
        summary = []
        summary_header = ["Input", "Output", "Max Abs Error", "AD at MAE", "FD at MAE", "Max Rel Error", "Pass"]

    class FontColors:
        # ANSI escape codes for colorized terminal output.
        OKGREEN = "\033[92m"
        WARNING = "\033[93m"
        FAIL = "\033[91m"
        ENDC = "\033[0m"

    success = True
    # Fix: record *which* input-output pairs failed so the exceptions raised below
    # report the actual offenders. Previously the messages interpolated the loop
    # variables `input_i`/`output_i`, which always referred to the last iterated
    # pair rather than the failing one(s).
    mismatched_pairs = []
    nan_pairs = []
    for (input_i, output_i), jac_fd in jacs_fd.items():
        jac_ad = jacs_ad[input_i, output_i]
        if plot_relative_error or plot_absolute_error:
            # Compute elementwise error arrays on-device for later plotting.
            jac_rel_error = wp.empty_like(jac_fd)
            jac_abs_error = wp.empty_like(jac_fd)
            flat_jac_fd = scalarize_array_1d(jac_fd)
            flat_jac_ad = scalarize_array_1d(jac_ad)
            flat_jac_rel_error = scalarize_array_1d(jac_rel_error)
            flat_jac_abs_error = scalarize_array_1d(jac_abs_error)
            wp.launch(
                compute_error_kernel,
                dim=len(flat_jac_fd),
                inputs=[flat_jac_ad, flat_jac_fd, flat_jac_rel_error, flat_jac_abs_error],
                device=jac_fd.device,
            )
            relative_error_jacs[(input_i, output_i)] = jac_rel_error
            absolute_error_jacs[(input_i, output_i)] = jac_abs_error
        # Restrict the comparison to the evaluated sub-block of the Jacobian.
        cut_jac_fd = jac_fd.numpy()
        cut_jac_ad = jac_ad.numpy()
        if max_outputs_per_var > 0:
            cut_jac_fd = cut_jac_fd[:max_outputs_per_var]
            cut_jac_ad = cut_jac_ad[:max_outputs_per_var]
        if max_inputs_per_var > 0:
            cut_jac_fd = cut_jac_fd[:, :max_inputs_per_var]
            cut_jac_ad = cut_jac_ad[:, :max_inputs_per_var]
        grad_matches = np.allclose(cut_jac_ad, cut_jac_fd, atol=atol, rtol=rtol)
        if not grad_matches:
            mismatched_pairs.append((input_i, output_i))
            success = False
        isnan = bool(np.any(np.isnan(cut_jac_ad)))
        if isnan:
            nan_pairs.append((input_i, output_i))
            success = False

        if show_summary:
            # Hoisted: the absolute difference was previously computed twice.
            abs_diff = np.abs(cut_jac_ad - cut_jac_fd)
            max_abs_error = abs_diff.max()
            arg_max_abs_error = np.unravel_index(np.argmax(abs_diff), cut_jac_ad.shape)
            # 1e-8 guards against division by zero where the FD gradient vanishes.
            max_rel_error = np.abs((cut_jac_ad - cut_jac_fd) / (cut_jac_fd + 1e-8)).max()
            if isnan:
                pass_str = FontColors.FAIL + "NaN" + FontColors.ENDC
            elif grad_matches:
                pass_str = FontColors.OKGREEN + "PASS" + FontColors.ENDC
            else:
                pass_str = FontColors.FAIL + "FAIL" + FontColors.ENDC
            input_name = metadata.input_labels[input_i]
            output_name = metadata.output_labels[output_i]
            summary.append(
                [
                    input_name,
                    output_name,
                    f"{max_abs_error:.3e} at {tuple(int(i) for i in arg_max_abs_error)}",
                    f"{cut_jac_ad[arg_max_abs_error]:.3e}",
                    f"{cut_jac_fd[arg_max_abs_error]:.3e}",
                    f"{max_rel_error:.3e}",
                    pass_str,
                ]
            )

    if show_summary:
        print_table(summary_header, summary)
        if not success:
            print(FontColors.FAIL + f"Gradient check for kernel {metadata.key} failed" + FontColors.ENDC)
        else:
            print(FontColors.OKGREEN + f"Gradient check for kernel {metadata.key} passed" + FontColors.ENDC)
    if plot_relative_error:
        jacobian_plot(
            relative_error_jacs,
            metadata,
            inputs,
            outputs,
            title=f"{metadata.key} kernel Jacobian relative error",
        )
    if plot_absolute_error:
        jacobian_plot(
            absolute_error_jacs,
            metadata,
            inputs,
            outputs,
            title=f"{metadata.key} kernel Jacobian absolute error",
        )

    if raise_exception:
        if mismatched_pairs:
            raise ValueError(
                f"Gradient check failed for kernel {metadata.key}, input-output pairs {mismatched_pairs}: "
                f"finite difference and autodiff gradients do not match"
            )
        if nan_pairs:
            raise ValueError(
                f"Gradient check failed for kernel {metadata.key}, input-output pairs {nan_pairs}: "
                f"gradient contains NaN values"
            )

    return success
232
+
233
def gradcheck_tape(
    tape: wp.Tape,
    *,
    eps=1e-4,
    atol=1e-3,
    rtol=1e-2,
    raise_exception=True,
    input_output_masks: Dict[str, List[Tuple[Union[str, int], Union[str, int]]]] = None,
    blacklist_kernels: List[str] = None,
    whitelist_kernels: List[str] = None,
    max_inputs_per_var=-1,
    max_outputs_per_var=-1,
    plot_relative_error=False,
    plot_absolute_error=False,
    show_summary: bool = True,
    reverse_launches: bool = False,
    skip_to_launch_index: int = 0,
) -> bool:
    """
    Runs :func:`gradcheck` on every eligible Warp kernel launch recorded on the tape,
    verifying that autodiff gradients agree with finite differences. A check fails when
    the relative or absolute error exceeds the given tolerances or when the autodiff
    gradients contain NaN values.

    Note:
        Only Warp kernels recorded on the tape are checked but not arbitrary functions that have been recorded, e.g. via :meth:`Tape.record_func`.

        Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.

        Structs arguments are not yet supported by this function to compute Jacobians.

    Args:
        tape: The Warp tape to perform the gradient check on.
        eps: The finite-difference step size.
        atol: The absolute tolerance for the gradient check.
        rtol: The relative tolerance for the gradient check.
        raise_exception: If True, raises a `ValueError` if the gradient check fails.
        input_output_masks: Dictionary of input-output masks for each kernel in the tape, mapping from kernel keys to input-output masks. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
        blacklist_kernels: List of kernel keys to exclude from the gradient check.
        whitelist_kernels: List of kernel keys to include in the gradient check. If not empty or None, only kernels in this list are checked.
        max_inputs_per_var: Maximum number of input dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all input dimensions if value <= 0.
        max_outputs_per_var: Maximum number of output dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all output dimensions if value <= 0.
        plot_relative_error: If True, visualizes the relative error of the Jacobians in a plot (requires ``matplotlib``).
        plot_absolute_error: If True, visualizes the absolute error of the Jacobians in a plot (requires ``matplotlib``).
        show_summary: If True, prints a summary table of the gradient check results.
        reverse_launches: If True, reverses the order of the kernel launches on the tape to check.
        skip_to_launch_index: Index of the first recorded launch to check; earlier launches are skipped.

    Returns:
        True if the gradient check passes for all kernels on the tape, False otherwise.
    """
    # Normalize filter arguments; sets give O(1) key membership tests below.
    input_output_masks = {} if input_output_masks is None else input_output_masks
    blacklist_kernels = [] if blacklist_kernels is None else set(blacklist_kernels)
    whitelist_kernels = [] if whitelist_kernels is None else set(whitelist_kernels)

    all_passed = True
    ordered_launches = reversed(tape.launches) if reverse_launches else tape.launches
    for launch_index, entry in enumerate(ordered_launches):
        # Guard clauses: skip everything that is not a checkable kernel launch.
        if launch_index < skip_to_launch_index:
            continue
        if not isinstance(entry, (tuple, list)):
            continue
        if not isinstance(entry[0], wp.Kernel):
            continue
        kernel, dim, max_blocks, inputs, outputs, device, block_dim = entry[:7]
        if whitelist_kernels and kernel.key not in whitelist_kernels:
            continue
        if kernel.key in blacklist_kernels:
            continue
        if not kernel.options.get("enable_backward", True):
            # Kernels compiled without a backward pass cannot be checked.
            continue

        passed = gradcheck(
            kernel,
            dim,
            inputs,
            outputs,
            eps=eps,
            atol=atol,
            rtol=rtol,
            raise_exception=raise_exception,
            input_output_mask=input_output_masks.get(kernel.key),
            device=device,
            max_blocks=max_blocks,
            block_dim=block_dim,
            max_inputs_per_var=max_inputs_per_var,
            max_outputs_per_var=max_outputs_per_var,
            plot_relative_error=plot_relative_error,
            plot_absolute_error=plot_absolute_error,
            show_summary=show_summary,
        )
        all_passed = all_passed and passed

    return all_passed
332
+
333
+
334
def get_struct_vars(x: wp.codegen.StructInstance):
    """Return a dict mapping each ctype field name of a Warp struct instance to its current value."""
    members = {}
    for field_name, _field_ctype in x._cls.ctype._fields_:
        members[field_name] = getattr(x, field_name)
    return members
336
+
337
+
338
def infer_device(xs: list):
    """Return the Warp device of the first ``wp.array`` found in *xs* (struct members
    are searched as well); fall back to the preferred device when none is found."""
    for item in xs:
        if isinstance(item, wp.array):
            return item.device
        if isinstance(item, wp.codegen.StructInstance):
            # look through the struct's members for an array carrying a device
            member_arrays = (v for v in get_struct_vars(item).values() if isinstance(v, wp.array))
            found = next(member_arrays, None)
            if found is not None:
                return found.device
    return wp.get_preferred_device()
348
+
349
+
350
class FunctionMetadata:
    """
    Metadata holder for kernel functions or functions with Warp arrays as inputs/outputs.

    For every input and output argument the label (name), strides, and dtype are
    recorded. Non-array arguments are stored with ``None`` strides/dtype so that
    :meth:`input_is_array` / :meth:`output_is_array` can tell them apart.
    """

    def __init__(
        self,
        key: str = None,
        input_labels: List[str] = None,
        output_labels: List[str] = None,
        input_strides: List[tuple] = None,
        output_strides: List[tuple] = None,
        input_dtypes: list = None,
        output_dtypes: list = None,
    ):
        self.key = key
        self.input_labels = input_labels
        self.output_labels = output_labels
        self.input_strides = input_strides
        self.output_strides = output_strides
        self.input_dtypes = input_dtypes
        self.output_dtypes = output_dtypes

    @property
    def is_empty(self):
        # metadata counts as unpopulated until a key has been assigned
        return self.key is None

    def input_is_array(self, i: int):
        return self.input_strides[i] is not None

    def output_is_array(self, i: int):
        return self.output_strides[i] is not None

    @staticmethod
    def _record_arg(strides: list, dtypes: list, var):
        # Append stride/dtype info for an array(-typed) variable, or None placeholders
        # for non-array arguments.
        if isinstance(var, wp.array):
            strides.append(var.strides)
            dtypes.append(var.dtype)
        else:
            strides.append(None)
            dtypes.append(None)

    def update_from_kernel(self, kernel: wp.Kernel, inputs: Sequence):
        """Populate the metadata from a Warp kernel, treating the first ``len(inputs)``
        kernel arguments as inputs and the remaining arguments as outputs."""
        self.key = kernel.key
        self.input_labels = [arg.label for arg in kernel.adj.args[: len(inputs)]]
        self.output_labels = [arg.label for arg in kernel.adj.args[len(inputs) :]]
        self.input_strides = []
        self.output_strides = []
        self.input_dtypes = []
        self.output_dtypes = []
        # NOTE(review): kernel argument annotations are *instances* of wp.array for
        # array arguments; the previous identity check `arg.type is wp.array` only
        # matched the bare class, so array metadata was never recorded. An isinstance
        # check is required here.
        for arg in kernel.adj.args[: len(inputs)]:
            self._record_arg(self.input_strides, self.input_dtypes, arg.type)
        for arg in kernel.adj.args[len(inputs) :]:
            self._record_arg(self.output_strides, self.output_dtypes, arg.type)

    def update_from_function(self, function: Callable, inputs: Sequence, outputs: Sequence = None):
        """Populate the metadata from a regular Python function.

        If *outputs* is None, the function is called once with *inputs* to
        discover its outputs (note: this executes the function's side effects).
        """
        self.key = function.__name__
        self.input_labels = list(inspect.signature(function).parameters.keys())
        if outputs is None:
            outputs = function(*inputs)
        if isinstance(outputs, wp.array):
            # normalize a single returned array to a one-element list
            outputs = [outputs]
        self.output_labels = [f"output_{i}" for i in range(len(outputs))]
        self.input_strides = []
        self.output_strides = []
        self.input_dtypes = []
        self.output_dtypes = []
        for value in inputs:
            self._record_arg(self.input_strides, self.input_dtypes, value)
        for value in outputs:
            self._record_arg(self.output_strides, self.output_dtypes, value)
432
+
433
+
434
def jacobian_plot(
    jacobians: Dict[Tuple[int, int], wp.array],
    kernel: Union[FunctionMetadata, wp.Kernel],
    inputs: Sequence = None,
    outputs: Sequence = None,
    show_plot=True,
    show_colorbar=True,
    scale_colors_per_submatrix=False,
    title: str = None,
    colormap: str = "coolwarm",
    log_scale=False,
):
    """
    Visualizes the Jacobians computed by :func:`jacobian` or :func:`jacobian_fd` in a combined image plot.
    Requires the ``matplotlib`` package to be installed.

    Args:
        jacobians: A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.
        kernel: The Warp kernel function, decorated with the ``@wp.kernel`` decorator, or a :class:`FunctionMetadata` instance with the kernel/function attributes.
        inputs: List of input variables. Required if ``kernel`` is a Warp kernel.
        outputs: List of output variables. Deprecated and will be removed in a future Warp version.
        show_plot: If True, displays the plot via ``plt.show()``.
        show_colorbar: If True, displays a colorbar next to the plot (or a colorbar next to every submatrix if ``scale_colors_per_submatrix`` is True).
        scale_colors_per_submatrix: If True, considers the minimum and maximum of each Jacobian submatrix separately for color scaling. Otherwise, uses the global minimum and maximum of all Jacobians.
        title: The title of the plot (optional).
        colormap: The colormap to use for the plot.
        log_scale: If True, uses a logarithmic scale for the matrix values shown in the image plot.

    Returns:
        The created Matplotlib figure, or None if there are no Jacobians to plot.
    """

    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator

    # resolve the metadata describing the argument names/strides/dtypes
    if isinstance(kernel, wp.Kernel):
        assert inputs is not None
        metadata = FunctionMetadata()
        metadata.update_from_kernel(kernel, inputs)
    elif isinstance(kernel, FunctionMetadata):
        metadata = kernel
    else:
        raise ValueError("Invalid kernel argument: must be a Warp kernel or a FunctionMetadata object")
    if outputs is not None:
        wp.utils.warn(
            "The `outputs` argument to `jacobian_plot` is no longer needed and will be removed in a future Warp version.",
            DeprecationWarning,
            stacklevel=3,
        )

    # sort by (output index, input index) so subplots appear in argument order
    jacobians = sorted(jacobians.items(), key=lambda x: (x[0][1], x[0][0]))
    jacobians = dict(jacobians)

    # map input/output argument indices to subplot column/row indices and back
    input_to_ax = {}
    output_to_ax = {}
    ax_to_input = {}
    ax_to_output = {}
    for i, j in jacobians.keys():
        if i not in input_to_ax:
            input_to_ax[i] = len(input_to_ax)
            ax_to_input[input_to_ax[i]] = i
        if j not in output_to_ax:
            output_to_ax[j] = len(output_to_ax)
            ax_to_output[output_to_ax[j]] = j

    num_rows = len(output_to_ax)
    num_cols = len(input_to_ax)
    if num_rows == 0 or num_cols == 0:
        return

    # determine the width and height ratios for the subplots based on the
    # dimensions of the Jacobians
    width_ratios = []
    height_ratios = []
    for i in range(len(metadata.input_labels)):
        if not metadata.input_is_array(i):
            continue
        input_stride = metadata.input_strides[i][0]
        for j in range(len(metadata.output_labels)):
            if (i, j) not in jacobians:
                continue
            jac_wp = jacobians[(i, j)]
            width_ratios.append(jac_wp.shape[1] * input_stride)
            break

    for i in range(len(metadata.output_labels)):
        if not metadata.output_is_array(i):
            continue
        # fix: iterate over the metadata input labels instead of `inputs`, which may
        # legitimately be None when a FunctionMetadata instance is passed in
        for j in range(len(metadata.input_labels)):
            if (j, i) not in jacobians:
                continue
            jac_wp = jacobians[(j, i)]
            height_ratios.append(jac_wp.shape[0])
            break

    fig, axs = plt.subplots(
        ncols=num_cols,
        nrows=num_rows,
        figsize=(7, 7),
        sharex="col",
        sharey="row",
        gridspec_kw={
            "wspace": 0.1,
            "hspace": 0.1,
            "width_ratios": width_ratios,
            "height_ratios": height_ratios,
        },
        subplot_kw={"aspect": 1},
        squeeze=False,
    )
    if title is None:
        # fix: FunctionMetadata has no `get` method; the key is available on the
        # resolved metadata for both kernels and plain functions
        key = metadata.key if metadata.key is not None else "unknown"
        title = f"{key} kernel Jacobian"
    fig.suptitle(title)
    fig.canvas.manager.set_window_title(title)

    if not scale_colors_per_submatrix:
        # compute one global color range over all finite Jacobian entries
        safe_jacobians = [jac.numpy().flatten() for jac in jacobians.values()]
        safe_jacobians = [jac[~np.isnan(jac)] for jac in safe_jacobians]
        safe_jacobians = [jac for jac in safe_jacobians if len(jac) > 0]
        if len(safe_jacobians) == 0:
            vmin = 0
            vmax = 0
        else:
            vmin = min([jac.min() for jac in safe_jacobians])
            vmax = max([jac.max() for jac in safe_jacobians])

    # disable axes for grid cells that have no Jacobian to show
    has_plot = np.ones((num_rows, num_cols), dtype=bool)
    for i in range(num_rows):
        for j in range(num_cols):
            if (ax_to_input[j], ax_to_output[i]) not in jacobians:
                ax = axs[i, j]
                ax.axis("off")
                has_plot[i, j] = False

    for (input_i, output_i), jac_wp in jacobians.items():
        input_name = metadata.input_labels[input_i]
        output_name = metadata.output_labels[output_i]

        ax_i, ax_j = output_to_ax[output_i], input_to_ax[input_i]
        ax = axs[ax_i, ax_j]
        ax.tick_params(which="major", width=1, length=7)
        ax.tick_params(which="minor", width=1, length=4, color="gray")

        input_stride = metadata.input_dtypes[input_i]._length_
        # the output stride is already folded into the Jacobian's first dimension

        jac = jac_wp.numpy()
        # Jacobian matrix has output stride already multiplied to first dimension
        jac = jac.reshape(jac_wp.shape[0], jac_wp.shape[1] * input_stride)

        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))

        if scale_colors_per_submatrix:
            # per-submatrix color range over finite entries
            safe_jac = jac[~np.isnan(jac)]
            vmin = safe_jac.min()
            vmax = safe_jac.max()
        img = ax.imshow(
            np.log10(np.abs(jac) + 1e-8) if log_scale else jac,
            cmap=colormap,
            aspect="auto",
            interpolation="nearest",
            extent=[0, jac.shape[1], 0, jac.shape[0]],
            vmin=vmin,
            vmax=vmax,
        )
        if ax_i == num_rows - 1 or not has_plot[ax_i + 1 :, ax_j].any():
            # last plot of this column
            ax.set_xlabel(input_name)
        if ax_j == 0 or not has_plot[ax_i, :ax_j].any():
            # first plot of this row
            ax.set_ylabel(output_name)
        ax.grid(color="gray", which="minor", linestyle="--", linewidth=0.5)
        ax.grid(color="black", which="major", linewidth=1.0)

        if show_colorbar and scale_colors_per_submatrix:
            plt.colorbar(img, ax=ax, orientation="vertical", pad=0.02)

    if show_colorbar and not scale_colors_per_submatrix:
        m = plt.cm.ScalarMappable(cmap=colormap)
        m.set_array([vmin, vmax])
        m.set_clim(vmin, vmax)
        plt.colorbar(m, ax=axs, orientation="vertical", pad=0.02)

    plt.tight_layout()
    if show_plot:
        plt.show()
    return fig
626
+
627
+
628
def plot_kernel_jacobians(
    jacobians: Dict[Tuple[int, int], wp.array],
    kernel: wp.Kernel,
    inputs: Sequence,
    outputs: Sequence,
    show_plot=True,
    show_colorbar=True,
    scale_colors_per_submatrix=False,
    title: str = None,
    colormap: str = "coolwarm",
    log_scale=False,
):
    """
    Deprecated alias of :func:`jacobian_plot`.

    Visualizes the Jacobians computed by :func:`jacobian` or :func:`jacobian_fd` in a combined image plot.
    Requires the ``matplotlib`` package to be installed.

    Note:
        This function is deprecated and will be removed in a future Warp version. Please call :func:`jacobian_plot` instead.

    Args:
        jacobians: A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.
        kernel: The Warp kernel function, decorated with the ``@wp.kernel`` decorator.
        inputs: List of input variables.
        outputs: List of output variables.
        show_plot: If True, displays the plot via ``plt.show()``.
        show_colorbar: If True, displays a colorbar next to the plot (or one per submatrix when ``scale_colors_per_submatrix`` is True).
        scale_colors_per_submatrix: If True, considers the minimum and maximum of each Jacobian submatrix separately for color scaling. Otherwise, uses the global minimum and maximum of all Jacobians.
        title: The title of the plot (optional).
        colormap: The colormap to use for the plot.
        log_scale: If True, uses a logarithmic scale for the matrix values shown in the image plot.

    Returns:
        The created Matplotlib figure.
    """
    wp.utils.warn(
        "The function `plot_kernel_jacobians` is deprecated and will be removed in a future Warp version. Please call `jacobian_plot` instead.",
        DeprecationWarning,
        stacklevel=3,
    )
    # forward everything unchanged to the replacement API
    return jacobian_plot(
        jacobians,
        kernel,
        inputs,
        outputs,
        show_plot=show_plot,
        show_colorbar=show_colorbar,
        scale_colors_per_submatrix=scale_colors_per_submatrix,
        title=title,
        colormap=colormap,
        log_scale=log_scale,
    )
679
+
680
+
681
def scalarize_array_1d(arr):
    """Reinterpret *arr* as a flat 1D array with a scalar dtype.

    Scalar-typed arrays are flattened; vector/matrix-typed arrays are aliased
    in place (no copy) with one scalar entry per component.
    """
    if arr.dtype in wp.types.scalar_types:
        return arr.flatten()
    if arr.dtype in wp.types.vector_types:
        scalar_count = arr.size * arr.dtype._length_
        # alias the same memory as a flat scalar array
        return wp.array(
            ptr=arr.ptr,
            shape=(scalar_count,),
            dtype=arr.dtype._wp_scalar_type_,
            device=arr.device,
        )
    raise ValueError(
        f"Unsupported array dtype {arr.dtype}: array to be flattened must be a scalar/vector/matrix array"
    )
696
+
697
+
698
def scalarize_array_2d(arr):
    """Reinterpret a 2D vector/matrix-typed array as a 2D scalar array (no copy).

    Scalar-typed arrays are returned unchanged; for vector/matrix dtypes the
    second dimension is expanded by the component count of the dtype.
    """
    assert arr.ndim == 2
    if arr.dtype in wp.types.scalar_types:
        return arr
    if arr.dtype in wp.types.vector_types:
        rows, cols = arr.shape
        # alias the same memory with the components unpacked along the columns
        return wp.array(
            ptr=arr.ptr,
            shape=(rows, cols * arr.dtype._length_),
            dtype=arr.dtype._wp_scalar_type_,
            device=arr.device,
        )
    raise ValueError(
        f"Unsupported array dtype {arr.dtype}: array to be flattened must be a scalar/vector/matrix array"
    )
714
+
715
+
716
def jacobian(
    function: Union[wp.Kernel, Callable],
    dim: Tuple[int] = None,
    inputs: Sequence = None,
    outputs: Sequence = None,
    input_output_mask: List[Tuple[Union[str, int], Union[str, int]]] = None,
    device: wp.context.Devicelike = None,
    max_blocks=0,
    block_dim=256,
    max_outputs_per_var=-1,
    plot_jacobians=False,
    metadata: FunctionMetadata = None,
    kernel: wp.Kernel = None,
) -> Dict[Tuple[int, int], wp.array]:
    """
    Computes the Jacobians of a function or Warp kernel for the provided selection of differentiable inputs to differentiable outputs.

    The input function can be either a Warp kernel (e.g. a function decorated by ``@wp.kernel``) or a regular Python function that accepts arguments (of which some must be Warp arrays) and returns a Warp array or a list of Warp arrays.

    In case ``function`` is a Warp kernel, its adjoint kernel is launched with the given inputs and outputs, as well as the provided ``dim``,
    ``max_blocks``, and ``block_dim`` arguments (see :func:`warp.launch` for more details).

    Note:
        If ``function`` is a Warp kernel, the input arguments must precede the output arguments in the kernel code definition.

        Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.

        Function arguments of type :ref:`Struct <structs>` are not yet supported.

    Args:
        function: The Warp kernel function, or a regular Python function that returns a Warp array or a list of Warp arrays.
        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints. Only required if ``function`` is a Warp kernel.
        inputs: List of input variables. At least one of the arguments must be a Warp array with ``requires_grad=True``.
        outputs: List of output variables. Optional if the function is a regular Python function that returns a Warp array or a list of Warp arrays. Only required if ``function`` is a Warp kernel.
        input_output_mask: List of tuples specifying the input-output pairs to compute the Jacobian for. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
        device: The device to launch on (optional). Only used if ``function`` is a Warp kernel.
        max_blocks: The maximum number of CUDA thread blocks to use. Only used if ``function`` is a Warp kernel.
        block_dim: The number of threads per block. Only used if ``function`` is a Warp kernel.
        max_outputs_per_var: Maximum number of output dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all output dimensions if value <= 0.
        plot_jacobians: If True, visualizes the computed Jacobians in a plot (requires ``matplotlib``).
        metadata: The metadata of the kernel function, containing the input and output labels, strides, and dtypes. If None or empty, the metadata is inferred from the kernel or function.
        kernel: Deprecated argument. Use the ``function`` argument instead.

    Returns:
        A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.
    """
    if input_output_mask is None:
        input_output_mask = []
    if kernel is not None:
        # legacy keyword: treat `kernel` as `function`
        wp.utils.warn(
            "The argument `kernel` to the function `wp.autograd.jacobian` is deprecated in favor of the `function` argument and will be removed in a future Warp version.",
            DeprecationWarning,
            stacklevel=3,
        )
        function = kernel

    if metadata is None:
        metadata = FunctionMetadata()

    if isinstance(function, wp.Kernel):
        if not function.options.get("enable_backward", True):
            raise ValueError("Kernel must have backward pass enabled to compute Jacobians")
        if outputs is None or len(outputs) == 0:
            raise ValueError("A list of output arguments must be provided to compute kernel Jacobians")
        if device is None:
            device = infer_device(inputs + outputs)
        if metadata.is_empty:
            metadata.update_from_kernel(function, inputs)

        # record a single launch so we can replay its adjoint repeatedly below
        tape = wp.Tape()
        tape.record_launch(
            kernel=function,
            dim=dim,
            inputs=inputs,
            outputs=outputs,
            device=device,
            max_blocks=max_blocks,
            block_dim=block_dim,
        )
    else:
        # regular Python function: run it once under the tape to record its operations
        tape = wp.Tape()
        with tape:
            outputs = function(*inputs)
        if isinstance(outputs, wp.array):
            outputs = [outputs]
        if metadata.is_empty:
            metadata.update_from_function(function, inputs, outputs)

    arg_names = metadata.input_labels + metadata.output_labels

    def resolve_arg(name, offset: int = 0):
        # accept either an integer index or an argument name; `offset` shifts
        # output-name indices back to output-local indexing
        if isinstance(name, int):
            return name
        return arg_names.index(name) + offset

    input_output_mask = [
        (resolve_arg(input_name), resolve_arg(output_name, -len(inputs)))
        for input_name, output_name in input_output_mask
    ]
    input_output_mask = set(input_output_mask)

    # start from clean gradients so backward passes are not polluted
    zero_grads(inputs)
    zero_grads(outputs)

    jacobians = {}

    for input_i, output_i in itertools.product(range(len(inputs)), range(len(outputs))):
        if len(input_output_mask) > 0 and (input_i, output_i) not in input_output_mask:
            continue
        input = inputs[input_i]
        output = outputs[output_i]
        if not isinstance(input, wp.array) or not input.requires_grad:
            continue
        if not isinstance(output, wp.array) or not output.requires_grad:
            continue
        # view the output gradient as a flat scalar vector to seed one dimension at a time
        out_grad = scalarize_array_1d(output.grad)
        output_num = out_grad.shape[0]
        # rows beyond `max_outputs_per_var` remain NaN to mark them as unevaluated
        jacobian = wp.empty((output_num, input.size), dtype=input.dtype, device=input.device)
        jacobian.fill_(wp.nan)
        if max_outputs_per_var > 0:
            output_num = min(output_num, max_outputs_per_var)
        for i in range(output_num):
            # seed the i-th output dimension with a unit gradient, then one
            # backward pass yields row i of the Jacobian in input.grad
            output.grad.zero_()
            if i > 0:
                set_element(out_grad, i - 1, 0.0)
            set_element(out_grad, i, 1.0)
            tape.backward()
            jacobian[i].assign(input.grad)

            # clear accumulated gradients before seeding the next row
            zero_grads(inputs)
            zero_grads(outputs)
        jacobians[input_i, output_i] = jacobian

    if plot_jacobians:
        jacobian_plot(
            jacobians,
            metadata,
            inputs,
            outputs,
        )

    return jacobians
858
+
859
+
860
def jacobian_fd(
    function: Union[wp.Kernel, Callable],
    dim: Tuple[int] = None,
    inputs: Sequence = None,
    outputs: Sequence = None,
    input_output_mask: List[Tuple[Union[str, int], Union[str, int]]] = None,
    device: wp.context.Devicelike = None,
    max_blocks=0,
    block_dim=256,
    max_inputs_per_var=-1,
    eps: float = 1e-4,
    plot_jacobians=False,
    metadata: FunctionMetadata = None,
    kernel: wp.Kernel = None,
) -> Dict[Tuple[int, int], wp.array]:
    """
    Computes the finite-difference Jacobian of a function or Warp kernel for the provided selection of differentiable inputs to differentiable outputs.
    The method uses a central difference scheme to approximate the Jacobian.

    The input function can be either a Warp kernel (e.g. a function decorated by ``@wp.kernel``) or a regular Python function that accepts arguments (of which some must be Warp arrays) and returns a Warp array or a list of Warp arrays.

    The function is launched multiple times in forward-only mode with the given inputs. If ``function`` is a Warp kernel, the provided inputs and outputs,
    as well as the other parameters ``dim``, ``max_blocks``, and ``block_dim`` are provided to the kernel launch (see :func:`warp.launch`).

    Note:
        If ``function`` is a Warp kernel, the input arguments must precede the output arguments in the kernel code definition.

        Only Warp arrays with ``requires_grad=True`` are considered for the Jacobian computation.

        Function arguments of type :ref:`Struct <structs>` are not yet supported.

    Args:
        function: The Warp kernel function, or a regular Python function that returns a Warp array or a list of Warp arrays.
        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints. Only required if ``function`` is a Warp kernel.
        inputs: List of input variables. At least one of the arguments must be a Warp array with ``requires_grad=True``.
        outputs: List of output variables. Optional if the function is a regular Python function that returns a Warp array or a list of Warp arrays. Only required if ``function`` is a Warp kernel.
        input_output_mask: List of tuples specifying the input-output pairs to compute the Jacobian for. Inputs and outputs can be identified either by their integer indices of where they appear in the kernel input/output arguments, or by the respective argument names as strings. If None, computes the Jacobian for all input-output pairs.
        device: The device to launch on (optional). Only used if ``function`` is a Warp kernel.
        max_blocks: The maximum number of CUDA thread blocks to use. Only used if ``function`` is a Warp kernel.
        block_dim: The number of threads per block. Only used if ``function`` is a Warp kernel.
        max_inputs_per_var: Maximum number of input dimensions over which to evaluate the Jacobians for the input-output pairs. Evaluates all input dimensions if value <= 0.
        eps: The finite-difference step size.
        plot_jacobians: If True, visualizes the computed Jacobians in a plot (requires ``matplotlib``).
        metadata: The metadata of the kernel function, containing the input and output labels, strides, and dtypes. If None or empty, the metadata is inferred from the kernel or function.
        kernel: Deprecated argument. Use the ``function`` argument instead.

    Returns:
        A dictionary of Jacobians, where the keys are tuples of input and output indices, and the values are the Jacobian matrices.
    """
    if input_output_mask is None:
        input_output_mask = []
    if kernel is not None:
        # legacy keyword: treat `kernel` as `function`
        wp.utils.warn(
            "The argument `kernel` to the function `wp.autograd.jacobian` is deprecated in favor of the `function` argument and will be removed in a future Warp version.",
            DeprecationWarning,
            stacklevel=3,
        )
        function = kernel

    if metadata is None:
        metadata = FunctionMetadata()

    if isinstance(function, wp.Kernel):
        if not function.options.get("enable_backward", True):
            raise ValueError("Kernel must have backward pass enabled to compute Jacobians")
        if outputs is None or len(outputs) == 0:
            raise ValueError("A list of output arguments must be provided to compute kernel Jacobians")
        if device is None:
            device = infer_device(inputs + outputs)
        if metadata.is_empty:
            metadata.update_from_kernel(function, inputs)

        tape = wp.Tape()
        tape.record_launch(
            kernel=function,
            dim=dim,
            inputs=inputs,
            outputs=outputs,
            device=device,
            max_blocks=max_blocks,
            block_dim=block_dim,
        )
    else:
        # regular Python function: run it once to populate `outputs` and the metadata
        tape = wp.Tape()
        with tape:
            outputs = function(*inputs)
        if isinstance(outputs, wp.array):
            outputs = [outputs]
        if metadata.is_empty:
            metadata.update_from_function(function, inputs, outputs)

    arg_names = metadata.input_labels + metadata.output_labels

    def resolve_arg(name, offset: int = 0):
        # accept either an integer index or an argument name; `offset` shifts
        # output-name indices back to output-local indexing
        if isinstance(name, int):
            return name
        return arg_names.index(name) + offset

    input_output_mask = [
        (resolve_arg(input_name), resolve_arg(output_name, -len(inputs)))
        for input_name, output_name in input_output_mask
    ]
    input_output_mask = set(input_output_mask)

    jacobians = {}

    def conditional_clone(obj):
        # deep-copy arrays so forward launches do not overwrite user data;
        # pass everything else through unchanged
        if isinstance(obj, wp.array):
            return wp.clone(obj)
        return obj

    outputs_copy = [conditional_clone(output) for output in outputs]

    for input_i, output_i in itertools.product(range(len(inputs)), range(len(outputs))):
        if len(input_output_mask) > 0 and (input_i, output_i) not in input_output_mask:
            continue
        input = inputs[input_i]
        output = outputs[output_i]
        if not isinstance(input, wp.array) or not input.requires_grad:
            continue
        if not isinstance(output, wp.array) or not output.requires_grad:
            continue

        # flat scalar view of the input so single components can be perturbed
        flat_input = scalarize_array_1d(input)

        # scratch buffers holding f(x - eps) and f(x + eps), plus pristine copies
        # used to reset them between input dimensions
        left = wp.clone(output)
        right = wp.clone(output)
        left_copy = wp.clone(output)
        right_copy = wp.clone(output)
        flat_left = scalarize_array_1d(left)
        flat_right = scalarize_array_1d(right)

        # full output lists for the forward launches, with the probed output
        # substituted by the left/right scratch buffers
        outputs_until_left = [conditional_clone(output) for output in outputs_copy[:output_i]]
        outputs_until_right = [conditional_clone(output) for output in outputs_copy[:output_i]]
        outputs_after_left = [conditional_clone(output) for output in outputs_copy[output_i + 1 :]]
        outputs_after_right = [conditional_clone(output) for output in outputs_copy[output_i + 1 :]]
        left_outputs = outputs_until_left + [left] + outputs_after_left
        right_outputs = outputs_until_right + [right] + outputs_after_right

        input_num = flat_input.shape[0]
        flat_input_copy = wp.clone(flat_input)
        # columns beyond `max_inputs_per_var` remain NaN to mark them as unevaluated
        jacobian = wp.empty((flat_left.size, input.size), dtype=input.dtype, device=input.device)
        jacobian.fill_(wp.nan)

        jacobian_scalar = scalarize_array_2d(jacobian)
        jacobian_t = jacobian_scalar.transpose()
        if max_inputs_per_var > 0:
            input_num = min(input_num, max_inputs_per_var)
        for i in range(input_num):
            # evaluate f(x - eps * e_i) into `left`
            set_element(flat_input, i, -eps, relative=True)
            if isinstance(function, wp.Kernel):
                wp.launch(
                    function,
                    dim=dim,
                    max_blocks=max_blocks,
                    block_dim=block_dim,
                    inputs=inputs,
                    outputs=left_outputs,
                    device=device,
                )
            else:
                outputs = function(*inputs)
                if isinstance(outputs, wp.array):
                    outputs = [outputs]
                left.assign(outputs[output_i])

            # shift by +2*eps to evaluate f(x + eps * e_i) into `right`
            set_element(flat_input, i, 2 * eps, relative=True)
            if isinstance(function, wp.Kernel):
                wp.launch(
                    function,
                    dim=dim,
                    max_blocks=max_blocks,
                    block_dim=block_dim,
                    inputs=inputs,
                    outputs=right_outputs,
                    device=device,
                )
            else:
                outputs = function(*inputs)
                if isinstance(outputs, wp.array):
                    outputs = [outputs]
                right.assign(outputs[output_i])

            # restore input
            flat_input.assign(flat_input_copy)

            # central difference: column i = (right - left) / (2 * eps)
            compute_fd(
                flat_left,
                flat_right,
                eps,
                jacobian_t[i],
            )

            if i < input_num - 1:
                # reset output buffers
                left.assign(left_copy)
                right.assign(right_copy)
                flat_left = scalarize_array_1d(left)
                flat_right = scalarize_array_1d(right)

        jacobians[input_i, output_i] = jacobian

    if plot_jacobians:
        jacobian_plot(
            jacobians,
            metadata,
            inputs,
            outputs,
        )

    return jacobians
1071
+
1072
+
1073
@wp.kernel(enable_backward=False)
def set_element_kernel(a: wp.array(dtype=Any), i: int, val: Any, relative: bool):
    # write `val` into a[i]; when `relative` is set, add it to the existing value instead
    if relative:
        a[i] += val
    else:
        a[i] = val
1079
+
1080
+
1081
def set_element(a: wp.array(dtype=Any), i: int, val: Any, relative: bool = False):
    """Set (or, when *relative* is True, increment) element *i* of array *a* to *val*,
    launched on the array's own device."""
    typed_val = a.dtype(val)
    wp.launch(set_element_kernel, dim=1, inputs=[a, i, typed_val, relative], device=a.device)
1083
+
1084
+
1085
@wp.kernel(enable_backward=False)
def compute_fd_kernel(left: wp.array(dtype=float), right: wp.array(dtype=float), eps: float, fd: wp.array(dtype=float)):
    # central difference quotient per element: fd = (right - left) / (2 * eps)
    tid = wp.tid()
    fd[tid] = (right[tid] - left[tid]) / (2.0 * eps)
1089
+
1090
+
1091
def compute_fd(left: wp.array(dtype=Any), right: wp.array(dtype=Any), eps: float, fd: wp.array(dtype=Any)):
    """Compute the elementwise central-difference quotient (right - left) / (2*eps) into *fd*."""
    count = len(left)
    wp.launch(compute_fd_kernel, dim=count, inputs=[left, right, eps], outputs=[fd], device=left.device)
1093
+
1094
+
1095
@wp.kernel(enable_backward=False)
def compute_error_kernel(
    jacobian_ad: wp.array(dtype=Any),
    jacobian_fd: wp.array(dtype=Any),
    relative_error: wp.array(dtype=Any),
    absolute_error: wp.array(dtype=Any),
):
    # elementwise comparison of autodiff vs. finite-difference Jacobian entries
    tid = wp.tid()
    ad = jacobian_ad[tid]
    fd = jacobian_fd[tid]
    denom = ad
    # guard against division by (near-)zero autodiff values
    if abs(ad) < 1e-8:
        denom = (type(ad))(1e-8)
    relative_error[tid] = (ad - fd) / denom
    absolute_error[tid] = wp.abs(ad - fd)
1110
+
1111
+
1112
def print_table(headers, cells):
    """
    Prints a table with the given headers and cells.

    Column widths are computed from the *visible* length of each cell, i.e.
    ANSI color escape sequences are ignored, and padding is applied based on
    that visible length so colored cells line up with plain ones.

    Args:
        headers: List of header strings.
        cells: List of lists of cell strings (one inner list per row).
    """
    import re

    # matches SGR color sequences, including compound ones such as "\033[1;31m"
    ansi_pattern = re.compile(r"\033\[[0-9;]*m")

    def sanitized_len(s):
        # visible width: length of the string with ANSI color codes stripped
        return len(ansi_pattern.sub("", str(s)))

    def pad(s, width):
        # pad strings manually based on visible width; str.format would count
        # the invisible escape codes and under-pad colored cells
        if isinstance(s, str):
            return s + " " * max(0, width - sanitized_len(s))
        # preserve default format alignment for non-string cells
        return f"{s:{width}}"

    col_widths = [max(sanitized_len(cell) for cell in col) for col in zip(headers, *cells)]
    for header, col_width in zip(headers, col_widths):
        print(pad(header, col_width), end=" | ")
    print()
    print("-" * (sum(col_widths) + 3 * len(col_widths) - 1))
    for cell_row in cells:
        for cell, col_width in zip(cell_row, col_widths):
            print(pad(cell, col_width), end=" | ")
        print()
1134
+
1135
+
1136
def zero_grads(arrays: list):
    """
    Zeros the gradients of all Warp arrays in the given list.

    Entries that are not Warp arrays, or that do not require gradients, are skipped.
    """
    differentiable = (a for a in arrays if isinstance(a, wp.array) and a.requires_grad)
    for arr in differentiable:
        arr.grad.zero_()