warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429) hide show
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,339 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
+
24
# Atomic-add helper that records, per thread, which counter slot it obtained.
# The recorded slot lets the replay function (replay_reversible_increment) hand
# back the same slot during the backward pass instead of incrementing again.
@wp.func
def reversible_increment(
    counter: wp.array(dtype=int), counter_index: int, value: int, thread_values: wp.array(dtype=int), tid: int
):
    """This is a docstring"""
    # wp.atomic_add returns the value stored before the addition
    slot = wp.atomic_add(counter, counter_index, value)
    # remember which slot this thread received so the replay can reuse it
    thread_values[tid] = slot
    return slot
35
+
36
+
37
# Replay counterpart of reversible_increment. Instead of touching the counter
# again, it returns the slot that was recorded for this thread.
@wp.func_replay(reversible_increment)
def replay_reversible_increment(
    counter: wp.array(dtype=int), counter_index: int, value: int, thread_values: wp.array(dtype=int), tid: int
):
    """This is a docstring"""
    recorded_slot = thread_values[tid]
    return recorded_slot
43
+
44
+
45
def test_custom_replay_grad(test, device):
    """Check that a func_replay override keeps atomically obtained indices stable in the backward pass.

    Each thread claims an output slot through ``reversible_increment``. The gradient
    of ``input ** 2`` must then be ``2 * input`` for every element.
    """
    n = 128

    @wp.kernel
    def run_atomic_add(
        input: wp.array(dtype=float),
        counter: wp.array(dtype=int),
        thread_values: wp.array(dtype=int),
        output: wp.array(dtype=float),
    ):
        tid = wp.tid()
        idx = reversible_increment(counter, 0, 1, thread_values, tid)
        output[idx] = input[idx] ** 2.0

    counter = wp.zeros(1, dtype=wp.int32, device=device)
    thread_ids = wp.zeros(n, dtype=wp.int32, device=device)
    inputs = wp.array(np.arange(n, dtype=np.float32), device=device, requires_grad=True)
    outputs = wp.zeros_like(inputs)

    tape = wp.Tape()
    with tape:
        wp.launch(run_atomic_add, dim=n, inputs=[inputs, counter, thread_ids], outputs=[outputs], device=device)

    seed = wp.ones(n, dtype=wp.float32, device=device)
    tape.backward(grads={outputs: seed})
    assert_np_equal(inputs.grad.numpy(), 2.0 * inputs.numpy(), tol=1e-4)
71
+
72
+
73
@wp.func
def overload_fn(x: float, y: float):
    """This is a docstring"""
    # Scalar overload with two outputs. Its gradient is replaced by the
    # func_grad registered for this signature.
    linear_part = x * 3.0 + y / 3.0
    power_part = y**2.5
    return linear_part, power_part
77
+
78
+
79
@wp.func_grad(overload_fn)
def overload_fn_grad(x: float, y: float, adj_ret0: float, adj_ret1: float):
    """This is a docstring"""
    # These adjoints intentionally differ from the analytic derivatives of the
    # forward function. The test expects exactly these formulas, which shows that
    # the custom gradient (not autodiff) was used.
    wp.adjoint[y] += y * adj_ret1 * 3.0
    wp.adjoint[x] += x * adj_ret0 * 42.0 + y * adj_ret1 * 10.0
84
+
85
+
86
@wp.struct
class MyStruct:
    """This is a docstring"""

    # Field order matters: the tests index numpy records of this struct
    # positionally (0 -> scalar, 1 -> vec).
    scalar: float
    vec: wp.vec3
92
+
93
+
94
@wp.func
def overload_fn(x: MyStruct):
    """This is a docstring"""
    # Struct overload of overload_fn that produces three scalar outputs.
    scaled_product = x.vec[0] * x.vec[1] * x.vec[2] * 4.0
    vec_length = wp.length(x.vec)
    scalar_root = x.scalar**0.5
    return scaled_product, vec_length, scalar_root
98
+
99
+
100
@wp.func_grad(overload_fn)
def overload_fn_grad(x: MyStruct, adj_ret0: float, adj_ret1: float, adj_ret2: float):
    """This is a docstring"""
    # Custom, intentionally non-analytic adjoints for the struct overload. Each
    # vector component is driven by a different output's adjoint so the test can
    # check that they were routed individually.
    wp.adjoint[x.vec][0] += adj_ret0 * x.vec[1] * x.vec[2] * 20.0
    wp.adjoint[x.vec][1] += adj_ret1 * x.vec[0] * x.vec[2] * 30.0
    wp.adjoint[x.vec][2] += adj_ret2 * x.vec[0] * x.vec[1] * 40.0
    wp.adjoint[x.scalar] += x.scalar * adj_ret0 * 10.0
107
+
108
+
109
@wp.kernel
def run_overload_float_fn(
    xs: wp.array(dtype=float), ys: wp.array(dtype=float), output0: wp.array(dtype=float), output1: wp.array(dtype=float)
):
    """This is a docstring"""
    # One thread per element: evaluate the scalar overload and store both results.
    idx = wp.tid()
    first, second = overload_fn(xs[idx], ys[idx])
    output0[idx] = first
    output1[idx] = second
118
+
119
+
120
@wp.kernel
def run_overload_struct_fn(xs: wp.array(dtype=MyStruct), output: wp.array(dtype=float)):
    """Evaluate the struct overload of overload_fn and store the sum of its three outputs."""
    idx = wp.tid()
    r0, r1, r2 = overload_fn(xs[idx])
    output[idx] = r0 + r1 + r2
125
+
126
+
127
def test_custom_overload_grad(test, device):
    """Check that func_grad overrides are dispatched per overload (scalar and struct)."""
    dim = 3

    def ones():
        return wp.ones(dim, dtype=wp.float32, device=device)

    # scalar overload
    xs_float = wp.array(np.arange(1.0, dim + 1.0), dtype=wp.float32, requires_grad=True, device=device)
    ys_float = wp.array(np.arange(10.0, dim + 10.0), dtype=wp.float32, requires_grad=True, device=device)
    out0_float = wp.zeros(dim, device=device)
    out1_float = wp.zeros(dim, device=device)

    float_tape = wp.Tape()
    with float_tape:
        wp.launch(
            run_overload_float_fn,
            dim=dim,
            inputs=[xs_float, ys_float],
            outputs=[out0_float, out1_float],
            device=device,
        )
    float_tape.backward(grads={out0_float: ones(), out1_float: ones()})

    assert_np_equal(xs_float.grad.numpy(), xs_float.numpy() * 42.0 + ys_float.numpy() * 10.0)
    assert_np_equal(ys_float.grad.numpy(), ys_float.numpy() * 3.0)

    # struct overload
    structs = []
    for components, scalar in (((1.0, 2.0, 3.0), 4.0), ((5.0, 6.0, 7.0), -1.0), ((8.0, 9.0, 10.0), 19.0)):
        item = MyStruct()
        item.vec = wp.vec3(*components)
        item.scalar = scalar
        structs.append(item)
    xs_struct = wp.array(structs, dtype=MyStruct, requires_grad=True, device=device)
    out_struct = wp.zeros(dim, device=device)

    struct_tape = wp.Tape()
    with struct_tape:
        wp.launch(run_overload_struct_fn, dim=dim, inputs=[xs_struct], outputs=[out_struct], device=device)
    struct_tape.backward(grads={out_struct: ones()})

    values = xs_struct.numpy()
    grads = xs_struct.grad.numpy()
    # Each pair is (field of the gradient record, expected value from the input record).
    # Records are indexed positionally: 0 -> scalar, 1 -> vec.
    checks = (
        (lambda g: g[0], lambda v: v[0] * 10.0),
        (lambda g: g[1][0], lambda v: v[1][1] * v[1][2] * 20.0),
        (lambda g: g[1][1], lambda v: v[1][0] * v[1][2] * 30.0),
        (lambda g: g[1][2], lambda v: v[1][0] * v[1][1] * 40.0),
    )
    for pick, expect in checks:
        assert_np_equal(np.array([pick(g) for g in grads]), np.array([expect(v) for v in values]))
178
+
179
+
180
def test_custom_import_grad(test, device):
    """Check that a kernel picks up the custom gradient of a function imported from another module."""
    from warp.tests.aux_test_grad_customs import aux_custom_fn

    @wp.kernel
    def run_defined_float_fn(
        xs: wp.array(dtype=float),
        ys: wp.array(dtype=float),
        output0: wp.array(dtype=float),
        output1: wp.array(dtype=float),
    ):
        tid = wp.tid()
        first, second = aux_custom_fn(xs[tid], ys[tid])
        output0[tid] = first
        output1[tid] = second

    dim = 3
    xs_float = wp.array(np.arange(1.0, dim + 1.0), dtype=wp.float32, requires_grad=True, device=device)
    ys_float = wp.array(np.arange(10.0, dim + 10.0), dtype=wp.float32, requires_grad=True, device=device)
    out0_float, out1_float = (wp.zeros(dim, device=device) for _ in range(2))

    tape = wp.Tape()
    with tape:
        wp.launch(
            run_defined_float_fn,
            dim=dim,
            inputs=[xs_float, ys_float],
            outputs=[out0_float, out1_float],
            device=device,
        )
    tape.backward(grads={out: wp.ones(dim, dtype=wp.float32, device=device) for out in (out0_float, out1_float)})

    assert_np_equal(xs_float.grad.numpy(), xs_float.numpy() * 42.0 + ys_float.numpy() * 10.0)
    assert_np_equal(ys_float.grad.numpy(), ys_float.numpy() * 3.0)
213
+
214
+
215
@wp.func
def sigmoid(x: float):
    # logistic function 1 / (1 + e^-x)
    denominator = 1.0 + wp.exp(-x)
    return 1.0 / denominator
218
+
219
+
220
@wp.func_grad(sigmoid)
def adj_sigmoid(x: float, adj: float):
    # Unused gradient function. It exists to verify that code generation does not
    # recurse forever when a gradient function calls its own forward function.
    s = sigmoid(x)
    wp.adjoint[x] += adj * s * (1.0 - s)
225
+
226
+
227
@wp.func
def sigmoid_no_return(i: int, xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
    # Writes its result into ys[i] instead of returning it. This exercises custom
    # gradients on functions without a return value.
    value = sigmoid(xs[i])
    ys[i] = value
231
+
232
+
233
@wp.func_grad(sigmoid_no_return)
def adj_sigmoid_no_return(i: int, xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
    # Custom adjoint of ys[i] = sigmoid(xs[i]).
    # The forward function returns nothing, so the upstream gradient arrives through
    # the adjoint of the output array ys and must be multiplied in (chain rule):
    #   d xs[i] += d ys[i] * s * (1 - s),   with s = ys[i] as written by the forward pass.
    # Previously the upstream term was dropped, which is only correct when d ys[i] == 1.
    wp.adjoint[xs][i] += wp.adjoint[ys][i] * ys[i] * (1.0 - ys[i])
236
+
237
+
238
@wp.kernel
def eval_sigmoid(xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
    # One thread per element. The result is written into ys by the helper.
    thread = wp.tid()
    sigmoid_no_return(thread, xs, ys)
242
+
243
+
244
def test_custom_grad_no_return(test, device):
    """Backpropagate through a custom-gradient function that returns nothing."""
    xs = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32, requires_grad=True, device=device)
    ys = wp.zeros_like(xs, device=device)
    # seed the upstream gradient directly on the output array
    ys.grad.fill_(1.0)

    tape = wp.Tape()
    with tape:
        wp.launch(eval_sigmoid, dim=len(xs), inputs=[xs], outputs=[ys], device=device)
    tape.backward()

    s = ys.numpy()
    assert_np_equal(xs.grad.numpy(), s * (1.0 - s))
257
+
258
+
259
@wp.func
def dense_gemm(
    m: int,
    n: int,
    p: int,
    transpose_A: bool,
    transpose_B: bool,
    add_to_C: bool,
    A: wp.array(dtype=float),
    B: wp.array(dtype=float),
    # outputs
    C: wp.array(dtype=float),
):
    """Row-major dense product C (m x n) = op(A) (m x p) @ op(B) (p x n).

    Never called by the tests; it exists as a code-generation test.
    If ``transpose_A`` is set, A is read as if stored p x m. If ``transpose_B`` is set,
    B is read as if stored n x p. With ``add_to_C`` the product is accumulated into C;
    otherwise C is overwritten.
    """
    for row in range(m):
        for col in range(n):
            # float() constructor makes this a mutable (dynamic) variable,
            # as Warp requires for values updated inside loops
            acc = float(0.0)
            for k in range(p):
                if transpose_A:
                    a_idx = k * m + row
                else:
                    a_idx = row * p + k
                if transpose_B:
                    b_idx = col * p + k
                else:
                    b_idx = k * n + col
                acc += A[a_idx] * B[b_idx]

            if add_to_C:
                C[row * n + col] += acc
            else:
                C[row * n + col] = acc
292
+
293
+
294
@wp.func_grad(dense_gemm)
def adj_dense_gemm(
    m: int,
    n: int,
    p: int,
    transpose_A: bool,
    transpose_B: bool,
    add_to_C: bool,
    A: wp.array(dtype=float),
    B: wp.array(dtype=float),
    # outputs
    C: wp.array(dtype=float),
):
    """Custom adjoint of ``dense_gemm``, accumulated in each operand's storage layout.

    With op(A) of shape m x p and op(B) of shape p x n:
        adj(op(A)) = adj(C) @ op(B)^T
        adj(op(B)) = op(A)^T @ adj(C)
    When an operand was passed transposed, the transpose of its adjoint is accumulated
    so that the gradient matches the operand's memory layout. Previously ``transpose_B``
    was ignored for the B gradient (and for the A gradient when ``transpose_A`` was set).
    """
    # code generation would break here if we didn't defer building the custom grad
    # function until after the forward functions + kernels of the module have been built
    add_to_C = True  # gradients always accumulate into the adjoint arrays
    if transpose_A:
        # A stored p x m: adj(A) = op(B) @ adj(C)^T -> (p x n) @ (n x m)
        dense_gemm(p, m, n, transpose_B, True, add_to_C, B, wp.adjoint[C], wp.adjoint[A])
        if transpose_B:
            # B stored n x p: adj(B) = adj(C)^T @ op(A)^T -> (n x m) @ (m x p)
            dense_gemm(n, p, m, True, True, add_to_C, wp.adjoint[C], A, wp.adjoint[B])
        else:
            # B stored p x n: adj(B) = op(A)^T @ adj(C) -> (p x m) @ (m x n)
            dense_gemm(p, n, m, False, False, add_to_C, A, wp.adjoint[C], wp.adjoint[B])
    else:
        # A stored m x p: adj(A) = adj(C) @ op(B)^T -> (m x n) @ (n x p)
        dense_gemm(m, p, n, False, not transpose_B, add_to_C, wp.adjoint[C], B, wp.adjoint[A])
        if transpose_B:
            # B stored n x p: adj(B) = adj(C)^T @ A -> (n x m) @ (m x p)
            dense_gemm(n, p, m, True, False, add_to_C, wp.adjoint[C], A, wp.adjoint[B])
        else:
            # B stored p x n: adj(B) = A^T @ adj(C) -> (p x m) @ (m x n)
            dense_gemm(p, n, m, True, False, add_to_C, A, wp.adjoint[C], wp.adjoint[B])
316
+
317
+
318
# Devices returned by get_test_devices(); every add_function_test registration
# in this file runs on each of them.
devices = get_test_devices()
319
+
320
+
321
class TestGradCustoms(unittest.TestCase):
    """Custom-gradient tests; device-parametrized cases are attached via add_function_test."""

    def test_wrapped_docstring(self):
        """Warp decorators must keep the docstring of the wrapped Python object."""
        wrapped = (
            reversible_increment,
            replay_reversible_increment,
            overload_fn,
            overload_fn_grad,
            run_overload_float_fn,
            MyStruct,
        )
        # assertIn + subTest reports which object lost its docstring, and checks all of them,
        # instead of a bare "False is not true" on the first failure
        for position, obj in enumerate(wrapped):
            with self.subTest(position=position):
                self.assertIn("This is a docstring", obj.__doc__)
329
+
330
+
331
# Attach each device-parametrized test to TestGradCustoms under its own function name.
for _func in (
    test_custom_replay_grad,
    test_custom_overload_grad,
    test_custom_import_grad,
    test_custom_grad_no_return,
):
    add_function_test(TestGradCustoms, _func.__name__, _func, devices=devices)
del _func
335
+
336
+
337
if __name__ == "__main__":
    # clear the kernel cache before running the tests
    wp.clear_kernel_cache()
    unittest.main(failfast=False, verbosity=2)
@@ -0,0 +1,341 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import warp as wp
19
+ from warp.autograd import (
20
+ gradcheck,
21
+ gradcheck_tape,
22
+ jacobian,
23
+ jacobian_fd,
24
+ )
25
+ from warp.tests.unittest_utils import *
26
+
27
+
28
# Element-wise kernel over 3D arrays with two outputs, used to test Jacobians of
# multi-dimensional inputs/outputs. One thread per (i, j, k) element computes:
#   out1 = a * b + c
#   out2 = -a * b - c
@wp.kernel
def kernel_3d(
    a: wp.array(dtype=float, ndim=3),
    b: wp.array(dtype=float, ndim=3),
    c: wp.array(dtype=float, ndim=3),
    out1: wp.array(dtype=float, ndim=3),
    out2: wp.array(dtype=float, ndim=3),
):
    i, j, k = wp.tid()
    out1[i, j, k] = a[i, j, k] * b[i, j, k] + c[i, j, k]
    out2[i, j, k] = -a[i, j, k] * b[i, j, k] - c[i, j, k]
39
+
40
+
41
# Kernel mixing scalar, vec3, vec2 and quat dtypes, used to test Jacobians across
# composite types. For each thread index tid (with ai = a[tid], bi = b[tid]):
#   out1[tid] = vec2(ai * length(bi), -ai * dot(bi, (0.1, 1.0, -0.1)))
#   out2[tid] = normalize(quat(ai, bi[0], bi[1], bi[2]))   # wp.quat argument order
@wp.kernel
def kernel_mixed(
    a: wp.array(dtype=float),
    b: wp.array(dtype=wp.vec3),
    out1: wp.array(dtype=wp.vec2),
    out2: wp.array(dtype=wp.quat),
):
    tid = wp.tid()
    ai, bi = a[tid], b[tid]
    out1[tid] = wp.vec2(ai * wp.length(bi), -ai * wp.dot(bi, wp.vec3(0.1, 1.0, -0.1)))
    out2[tid] = wp.normalize(wp.quat(ai, bi[0], bi[1], bi[2]))
52
+
53
+
54
# Writes the Euclidean norm of each vec3 in ``a`` to ``out``. The sqrt-of-sum-of-
# squares form is intentional: for a zero vector the sqrt adjoint divides by zero,
# which test_gradcheck_nan relies on to make gradcheck raise.
@wp.kernel
def vec_length_kernel(a: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
    tid = wp.tid()
    v = a[tid]
    # instead of wp.length(v), we use a trivial implementation that
    # fails when a division by zero occurs in the backward pass of sqrt
    out[tid] = wp.sqrt(v[0] ** 2.0 + v[1] ** 2.0 + v[2] ** 2.0)
61
+
62
+
63
# Forward function returning x * x. Its custom adjoint (adj_wrong_grad_func) is
# deliberately incorrect so gradcheck can be shown to detect the mismatch.
@wp.func
def wrong_grad_func(x: float):
    return x * x
66
+
67
+
68
# Intentionally WRONG custom gradient for wrong_grad_func. The correct adjoint of
# x * x would add 2 * x * adj; this subtracts it instead. test_gradcheck_incorrect
# expects gradcheck to raise because of this sign error, so do not "fix" it.
@wp.func_grad(wrong_grad_func)
def adj_wrong_grad_func(x: float, adj: float):
    wp.adjoint[x] -= 2.0 * x * adj
71
+
72
+
73
# Applies wrong_grad_func element-wise (out[tid] = a[tid] ** 2 in the forward pass),
# so its backward pass goes through the deliberately wrong custom adjoint.
@wp.kernel
def wrong_grad_kernel(a: wp.array(dtype=float), out: wp.array(dtype=float)):
    tid = wp.tid()
    out[tid] = wrong_grad_func(a[tid])
77
+
78
+
79
# Applies transforms[tid] to points[tid] with wp.transform_point and stores the
# result in out[tid]; one thread per (transform, point) pair.
@wp.kernel
def transform_point_kernel(
    transforms: wp.array(dtype=wp.transform),
    points: wp.array(dtype=wp.vec3),
    out: wp.array(dtype=wp.vec3),
):
    tid = wp.tid()
    out[tid] = wp.transform_point(transforms[tid], points[tid])
87
+
88
+
89
def test_gradcheck_3d(test, device):
    """Compare AD and finite-difference Jacobians of ``kernel_3d`` with masking and truncation.

    Only the ``a -> out1`` and ``b -> out2`` pairs are requested (first by variable
    name, then by positional index). The AD Jacobians are truncated to 4 rows via
    ``max_outputs_per_var`` and the FD Jacobians to 4 columns via
    ``max_inputs_per_var``; the truncated entries must be NaN. Checks go through
    the TestCase API so they still run under ``python -O``, where bare ``assert``
    statements are stripped.
    """
    a_3d = wp.array([((2.0, 0.0), (1.0, 0.0), (2.0, 0.0))], dtype=float, requires_grad=True, device=device)
    b_3d = wp.array([((3.0, 0.0), (1.0, 0.0), (2.0, 0.0))], dtype=float, requires_grad=True, device=device)
    c_3d = wp.array([((4.0, 0.0), (1.0, 0.0), (2.0, 0.0))], dtype=float, requires_grad=True, device=device)

    out1_3d = wp.array([((3.0, 0.0), (1.0, 0.0), (2.0, 0.0))], dtype=float, requires_grad=True, device=device)
    out2_3d = wp.array([((4.0, 0.0), (1.0, 0.0), (2.0, 0.0))], dtype=float, requires_grad=True, device=device)

    # only the masked (input, output) pairs are computed
    expected_keys = [(0, 0), (1, 1)]
    # each array has shape (1, 3, 2) -> 6 elements, so every Jacobian block is 6 x 6
    expected_shape = (6, 6)

    jacs_ad = jacobian(
        kernel_3d,
        dim=a_3d.shape,
        inputs=[a_3d, b_3d, c_3d],
        outputs=[out1_3d, out2_3d],
        max_outputs_per_var=4,
        input_output_mask=[("a", "out1"), ("b", "out2")],
    )

    test.assertEqual(sorted(jacs_ad.keys()), expected_keys)
    for key in expected_keys:
        test.assertEqual(jacs_ad[key].shape, expected_shape)
        # all entries beyond the max_outputs_per_var are NaN
        test.assertTrue(np.all(np.isnan(jacs_ad[key].numpy()[4:])))

    jacs_fd = jacobian_fd(
        kernel_3d,
        dim=a_3d.shape,
        inputs=[a_3d, b_3d, c_3d],
        outputs=[out1_3d, out2_3d],
        max_inputs_per_var=4,
        # use integer indices instead of variable names
        input_output_mask=[(0, 0), (1, 1)],
        eps=1e-4,
    )

    test.assertEqual(sorted(jacs_fd.keys()), expected_keys)
    for key in expected_keys:
        test.assertEqual(jacs_fd[key].shape, expected_shape)
        # all entries beyond the max_inputs_per_var are NaN
        test.assertTrue(np.all(np.isnan(jacs_fd[key].numpy()[:, 4:])))

    # manual gradcheck on the 4 x 4 block that both Jacobians actually computed
    for key in expected_keys:
        test.assertTrue(
            np.allclose(jacs_ad[key].numpy()[:4, :4], jacs_fd[key].numpy()[:4, :4], atol=1e-2, rtol=1e-2),
            msg=f"AD and FD Jacobians differ for key {key}",
        )

    passed = gradcheck(
        kernel_3d,
        dim=a_3d.shape,
        inputs=[a_3d, b_3d, c_3d],
        outputs=[out1_3d, out2_3d],
        max_inputs_per_var=4,
        max_outputs_per_var=4,
        input_output_mask=[("a", "out1"), ("b", "out2")],
        show_summary=False,
    )
    test.assertTrue(passed)
146
+
147
+
148
def test_gradcheck_mixed(test, device):
    """Compare AD and finite-difference Jacobians of ``kernel_mixed`` and run ``gradcheck``.

    The inputs and outputs mix float, vec3, vec2 and quat dtypes; all 2 x 2
    (input, output) Jacobian blocks are compared. Checks use the TestCase API so
    they are not stripped under ``python -O``.
    """
    a = wp.array([2.0, -1.0], dtype=wp.float32, requires_grad=True, device=device)
    b = wp.array([wp.vec3(3.0, 1.0, 2.0), wp.vec3(-4.0, -1.0, 0.0)], dtype=wp.vec3, requires_grad=True, device=device)
    out1 = wp.zeros(2, dtype=wp.vec2, requires_grad=True, device=device)
    out2 = wp.zeros(2, dtype=wp.quat, requires_grad=True, device=device)

    jacs_ad = jacobian(
        kernel_mixed,
        dim=len(a),
        inputs=[a, b],
        outputs=[out1, out2],
    )
    jacs_fd = jacobian_fd(
        kernel_mixed,
        dim=len(a),
        inputs=[a, b],
        outputs=[out1, out2],
        eps=1e-4,
    )

    # manual gradcheck over every Jacobian block
    for i in range(2):
        for j in range(2):
            test.assertTrue(
                np.allclose(jacs_ad[(i, j)].numpy(), jacs_fd[(i, j)].numpy(), atol=1e-2, rtol=1e-2),
                msg=f"AD and FD Jacobians differ for key {(i, j)}",
            )

    passed = gradcheck(
        kernel_mixed,
        dim=len(a),
        inputs=[a, b],
        outputs=[out1, out2],
        raise_exception=False,
        show_summary=False,
    )
    test.assertTrue(passed)
183
+
184
+
185
def test_gradcheck_nan(test, device):
    """Expect ``gradcheck`` to raise ValueError when the backward pass is not finite.

    The second input vector is zero, so the hand-written sqrt norm in
    ``vec_length_kernel`` divides by zero in its adjoint.
    """
    vectors = wp.array([wp.vec3(1.0, 2.0, 3.0), wp.vec3(0.0, 0.0, 0.0)], dtype=wp.vec3, requires_grad=True, device=device)
    lengths = wp.array([0.0, 0.0], dtype=float, requires_grad=True, device=device)

    # callable form of assertRaises: the keyword arguments are forwarded to gradcheck
    test.assertRaises(
        ValueError,
        gradcheck,
        vec_length_kernel,
        dim=vectors.shape,
        inputs=[vectors],
        outputs=[lengths],
        raise_exception=True,
        show_summary=False,
    )
198
+
199
+
200
def test_gradcheck_incorrect(test, device):
    """Expect ``gradcheck`` to raise ValueError for a kernel with a wrong custom adjoint.

    ``wrong_grad_kernel`` routes its backward pass through ``adj_wrong_grad_func``,
    whose sign is deliberately flipped.
    """
    values = wp.array([1.0, 2.0, 3.0], dtype=wp.float32, requires_grad=True, device=device)
    results = wp.zeros_like(values)

    # callable form of assertRaises: the keyword arguments are forwarded to gradcheck
    test.assertRaises(
        ValueError,
        gradcheck,
        wrong_grad_kernel,
        dim=values.shape,
        inputs=[values],
        outputs=[results],
        raise_exception=True,
        show_summary=False,
    )
213
+
214
+
215
def test_gradcheck_tape(test, device):
    """Record ``kernel_mixed`` and ``kernel_3d`` launches on a tape and run ``gradcheck_tape``.

    The result is checked through the TestCase API so it is not stripped under
    ``python -O``.
    """
    a = wp.array([2.0, -1.0], dtype=wp.float32, requires_grad=True, device=device)
    b = wp.array([wp.vec3(3.0, 1.0, 2.0), wp.vec3(-4.0, -1.0, 0.0)], dtype=wp.vec3, requires_grad=True, device=device)
    out1 = wp.zeros(2, dtype=wp.vec2, requires_grad=True, device=device)
    out2 = wp.zeros(2, dtype=wp.quat, requires_grad=True, device=device)

    a_3d = wp.array([((2.0, 0.0), (1.0, 0.0), (2.0, 0.0))], dtype=float, requires_grad=True, device=device)
    b_3d = wp.array([((3.0, 0.0), (1.0, 0.0), (2.0, 0.0))], dtype=float, requires_grad=True, device=device)
    c_3d = wp.array([((4.0, 0.0), (1.0, 0.0), (2.0, 0.0))], dtype=float, requires_grad=True, device=device)

    out1_3d = wp.array([((3.0, 0.0), (1.0, 0.0), (2.0, 0.0))], dtype=float, requires_grad=True, device=device)
    out2_3d = wp.array([((4.0, 0.0), (1.0, 0.0), (2.0, 0.0))], dtype=float, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel_mixed,
            dim=len(a),
            inputs=[a, b],
            outputs=[out1, out2],
            device=device,
        )

        wp.launch(
            kernel_3d,
            dim=a_3d.shape,
            inputs=[a_3d, b_3d, c_3d],
            outputs=[out1_3d, out2_3d],
            device=device,
        )

    passed = gradcheck_tape(
        tape,
        raise_exception=False,
        show_summary=False,
    )
    test.assertTrue(passed)
253
+
254
+
255
def test_gradcheck_function(test, device):
    """Gradcheck a plain Python function that chains two kernel launches.

    ``compute_transformed_point_norms`` transforms each point and then takes the
    norm of the result. Its Jacobians with respect to ``transforms`` and ``points``
    are compared between AD and finite differences, and ``gradcheck`` is run on it.
    The original version defined this helper but then differentiated
    ``kernel_mixed``, whose float/vec3 signature cannot accept the
    (transform, vec3) inputs used here.
    """

    def compute_transformed_point_norms(transforms, points):
        """Return ``(tf_points, norms)``: the transformed points and their lengths."""
        tf_points = wp.empty_like(points)
        norms = wp.empty(len(points), dtype=float, requires_grad=points.requires_grad, device=points.device)

        # Launch on the inputs' device, matching the allocation of ``norms`` above,
        # so the helper does not depend on the enclosing test's ``device``.
        wp.launch(
            transform_point_kernel,
            dim=len(points),
            inputs=[transforms, points],
            outputs=[tf_points],
            device=points.device,
        )
        wp.launch(
            vec_length_kernel,
            dim=len(points),
            inputs=[tf_points],
            outputs=[norms],
            device=points.device,
        )
        return tf_points, norms

    transforms = wp.array(
        [
            wp.transform(wp.vec3(1.0, 0.6, -2.0), wp.quat_rpy(-0.5, 0.1, 0.8)),
            wp.transform(wp.vec3(0.2, 1.4, -0.4), wp.quat_rpy(0.5, 0.65, -0.3)),
            wp.transform(wp.vec3(0.5, 0.2, 0.0), wp.quat_rpy(-0.5, -0.3, 0.4)),
        ],
        dtype=wp.transform,
        requires_grad=True,
        device=device,
    )
    points = wp.array(
        [
            (1.0, -0.5, 2.0),
            (-0.95, -0.1, 0.0),
            (9.1, 9.7, 3.8),
        ],
        dtype=wp.vec3,
        requires_grad=True,
        device=device,
    )

    jacs_ad = jacobian(
        compute_transformed_point_norms,
        dim=len(points),
        inputs=[transforms, points],
    )
    jacs_fd = jacobian_fd(
        compute_transformed_point_norms,
        dim=len(points),
        inputs=[transforms, points],
        eps=1e-4,
    )

    # manual gradcheck over the 2 inputs x 2 returned outputs
    for i in range(2):
        for j in range(2):
            test.assertTrue(
                np.allclose(jacs_ad[(i, j)].numpy(), jacs_fd[(i, j)].numpy(), atol=1e-2, rtol=1e-2),
                msg=f"AD and FD Jacobians differ for key {(i, j)}",
            )

    passed = gradcheck(
        compute_transformed_point_norms,
        dim=len(points),
        inputs=[transforms, points],
        raise_exception=False,
        show_summary=False,
    )
    test.assertTrue(passed)
323
+
324
+
325
# Devices that each registered gradcheck test is instantiated for
# (passed as ``devices=`` to the add_function_test calls in this module).
devices = get_test_devices()
326
+
327
+
328
class TestGradDebug(unittest.TestCase):
    """Container for the device-parametrized gradcheck tests attached via add_function_test."""

    pass
330
+
331
+
332
# Attach each device-parametrized gradcheck test to TestGradDebug,
# in a fixed order, once per entry in ``devices``.
for _test_name, _test_fn in (
    ("test_gradcheck_3d", test_gradcheck_3d),
    ("test_gradcheck_mixed", test_gradcheck_mixed),
    ("test_gradcheck_nan", test_gradcheck_nan),
    ("test_gradcheck_incorrect", test_gradcheck_incorrect),
    ("test_gradcheck_tape", test_gradcheck_tape),
):
    add_function_test(TestGradDebug, _test_name, _test_fn, devices=devices)
# keep the module namespace identical to the unrolled form
del _test_name, _test_fn
337
+
338
+
339
if __name__ == "__main__":
    # Use the public top-level ``wp.clear_kernel_cache()`` (as the other test module
    # in this package does) instead of reaching into ``wp.build``; presumably this
    # forces the module's kernels to be recompiled for this run.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=False)