warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429) hide show
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,497 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example Fluid Checkpoint
18
+ #
19
+ # Shows how to implement a differentiable 2D stable-fluids solver and
20
+ # optimize the initial velocity field to form the NVIDIA logo at the end
21
+ # of the simulation. Gradient checkpointing to reduce memory usage
22
+ # is manually implemented.
23
+ #
24
+ # References:
25
+ # https://github.com/HIPS/autograd/blob/master/examples/fluidsim/fluidsim.py
26
+ #
27
+ ###########################################################################
28
+
29
+ import math
30
+ import os
31
+
32
+ import numpy as np
33
+
34
+ import warp as wp
35
+ import warp.examples
36
+ import warp.optim
37
+
38
+ try:
39
+ from PIL import Image
40
+ except ImportError as err:
41
+ raise ImportError("This example requires the Pillow package. Please install it with 'pip install Pillow'.") from err
42
+
43
+
44
+ N_GRID = wp.constant(512)  # Cells per side; every simulation array below is allocated as N_GRID x N_GRID
45
+ DH = 1.0 / N_GRID # Grid spacing used by the finite-difference kernels (divergence, jacobi_iter, update_velocities)
46
+ FLUID_COLUMN_WIDTH = N_GRID / 10.0  # Length unit, in cells, for the band limits tested in fill_initial_density
47
+
48
+
@wp.func
def cyclic_index(idx: wp.int32):
    """Wrap ``idx`` into the range ``[0, N_GRID)`` for periodic boundary lookups."""
    wrapped = idx % N_GRID
    # Guard against a negative remainder so the result is always a valid index.
    if wrapped < 0:
        wrapped = wrapped + N_GRID
    return wrapped
56
+
57
+
@wp.kernel
def fill_initial_density(density: wp.array2d(dtype=wp.float32)):
    """Write 1.0 into three bands along the first grid axis and 0.0 everywhere else.

    With ``W = FLUID_COLUMN_WIDTH`` the bands cover first-axis indices in
    ``[W, 2W)``, ``[4.5W, 5.5W)`` and ``[8W, 9W)``.
    """
    i, j = wp.tid()

    row = wp.float32(i)

    in_first_band = FLUID_COLUMN_WIDTH <= row < 2.0 * FLUID_COLUMN_WIDTH
    in_second_band = 4.5 * FLUID_COLUMN_WIDTH <= row < 5.5 * FLUID_COLUMN_WIDTH
    in_third_band = 8.0 * FLUID_COLUMN_WIDTH <= row < 9.0 * FLUID_COLUMN_WIDTH

    if in_first_band or in_second_band or in_third_band:
        density[i, j] = 1.0
    else:
        density[i, j] = 0.0
73
+
74
+
@wp.kernel
def advect(
    dt: float,
    vx: wp.array2d(dtype=float),
    vy: wp.array2d(dtype=float),
    f0: wp.array2d(dtype=float),
    f1: wp.array2d(dtype=float),
):
    """Transport field ``f0`` along the velocity field ``(vx, vy)`` and store the result in ``f1``.

    Each cell looks up the position ``(i, j) - v * dt`` and samples ``f0`` there
    with bilinear interpolation; indices wrap around periodically.
    """

    i, j = wp.tid()

    # Location (in cell units) that this cell's value is taken from.
    src_x = wp.float32(i) - vx[i, j] * dt
    src_y = wp.float32(j) - vy[i, j] * dt

    # Lower corner of the cell containing the source location.
    base_i = wp.int32(wp.floor(src_x))
    base_j = wp.int32(wp.floor(src_y))

    # Interpolation weights: w_hi_* belongs to the upper neighbor, w_lo_* to the lower one.
    w_hi_x = src_x - wp.float32(base_i)
    w_lo_x = 1.0 - w_hi_x
    w_hi_y = src_y - wp.float32(base_j)
    w_lo_y = 1.0 - w_hi_y

    # Periodic wrap of the four corner indices.
    i_lo = cyclic_index(base_i)
    i_hi = cyclic_index(base_i + 1)
    j_lo = cyclic_index(base_j)
    j_hi = cyclic_index(base_j + 1)

    # Interpolate along the second axis first, then blend along the first axis.
    along_lo = w_lo_y * f0[i_lo, j_lo] + w_hi_y * f0[i_lo, j_hi]
    along_hi = w_lo_y * f0[i_hi, j_lo] + w_hi_y * f0[i_hi, j_hi]
    f1[i, j] = w_lo_x * along_lo + w_hi_x * along_hi
106
+
107
+
@wp.kernel
def divergence(wx: wp.array2d(dtype=float), wy: wp.array2d(dtype=float), div: wp.array2d(dtype=float)):
    """Store the central-difference divergence of ``(wx, wy)`` in ``div`` (periodic grid)."""

    i, j = wp.tid()

    wx_next = wx[cyclic_index(i + 1), j]
    wx_prev = wx[cyclic_index(i - 1), j]
    wy_next = wy[i, cyclic_index(j + 1)]
    wy_prev = wy[i, cyclic_index(j - 1)]

    # Keep the left-to-right summation order so results stay bit-identical.
    div[i, j] = 0.5 * (wx_next - wx_prev + wy_next - wy_prev) / DH
124
+
125
+
@wp.kernel
def jacobi_iter(div: wp.array2d(dtype=float), p0: wp.array2d(dtype=float), p1: wp.array2d(dtype=float)):
    """Perform one Jacobi sweep of the pressure Poisson solve, reading ``p0`` and writing ``p1``."""

    i, j = wp.tid()

    # Four periodic neighbors of the current cell in the previous iterate.
    p_i_minus = p0[cyclic_index(i - 1), j]
    p_i_plus = p0[cyclic_index(i + 1), j]
    p_j_minus = p0[i, cyclic_index(j - 1)]
    p_j_plus = p0[i, cyclic_index(j + 1)]

    p1[i, j] = 0.25 * (-DH * DH * div[i, j] + p_i_minus + p_i_plus + p_j_minus + p_j_plus)
139
+
140
+
@wp.kernel
def update_velocities(
    p: wp.array2d(dtype=float),
    wx: wp.array2d(dtype=float),
    wy: wp.array2d(dtype=float),
    vx: wp.array2d(dtype=float),
    vy: wp.array2d(dtype=float),
):
    """Subtract the central-difference pressure gradient from ``(wx, wy)`` to obtain ``(vx, vy)``."""

    i, j = wp.tid()

    dp_di = 0.5 * (p[cyclic_index(i + 1), j] - p[cyclic_index(i - 1), j]) / DH
    dp_dj = 0.5 * (p[i, cyclic_index(j + 1)] - p[i, cyclic_index(j - 1)]) / DH

    vx[i, j] = wx[i, j] - dp_di
    vy[i, j] = wy[i, j] - dp_dj
155
+
156
+
@wp.kernel
def compute_loss(
    actual_state: wp.array2d(dtype=float), target_state: wp.array2d(dtype=float), loss: wp.array(dtype=float)
):
    """Accumulate the mean squared difference between two grids into ``loss[0]``.

    Each thread adds its cell's squared error divided by ``N_GRID * N_GRID``.
    """
    i, j = wp.tid()

    error = actual_state[i, j] - target_state[i, j]
    contribution = error * error / wp.float32(N_GRID * N_GRID)

    wp.atomic_add(loss, 0, contribution)
170
+
171
+
172
+ class Example:
def __init__(self, sim_steps=1000):
    """Allocate the simulation state, load the target image and set up the optimizer.

    On a CUDA device the forward pass, the backward pass and the tape
    zeroing are additionally captured as CUDA graphs.

    Args:
        sim_steps: Requested number of time steps. It is rounded up so the
            simulation splits into equally sized checkpoint segments; the
            rounded value is stored in ``self.sim_steps``.

    Raises:
        ValueError: If ``sim_steps`` is not positive.
    """
    # A step count of zero would produce a segment size of zero and an opaque
    # ZeroDivisionError below; reject non-positive values with a clear message.
    if sim_steps <= 0:
        raise ValueError(f"sim_steps must be positive, got {sim_steps}")

    def new_grid(requires_grad=False):
        """Return a zero-initialized N_GRID x N_GRID float array."""
        return wp.zeros((N_GRID, N_GRID), dtype=float, requires_grad=requires_grad)

    # Memory usage is minimized when the segment size is approx. sqrt(sim_steps)
    self.segment_size = math.ceil(math.sqrt(sim_steps))

    # TODO: For now, let's just round up sim_steps so each segment is the same size
    self.num_segments = math.ceil(sim_steps / self.segment_size)
    self.sim_steps = self.segment_size * self.num_segments

    self.pressure_iterations = 50
    self.dt = 1.0

    # Store enough arrays to step through a segment without overwriting arrays.
    # NOTE: One extra slot holds the final time-advanced velocities and densities.
    self.vx_arrays = [new_grid(True) for _ in range(self.segment_size + 1)]
    self.vy_arrays = [new_grid(True) for _ in range(self.segment_size + 1)]
    self.density_arrays = [new_grid(True) for _ in range(self.segment_size + 1)]

    # Intermediate (pre-projection) velocities and divergence, one per step of a segment.
    self.wx_arrays = [new_grid(True) for _ in range(self.segment_size)]
    self.wy_arrays = [new_grid(True) for _ in range(self.segment_size)]
    self.div_arrays = [new_grid(True) for _ in range(self.segment_size)]

    # step() writes a fresh pressure array for every Jacobi iteration of every
    # step in a segment (index pressure_iterations * (step_index - 1) + k + 1),
    # and index 0 holds the initial guess, hence the additional array.
    num_pressure_arrays = self.segment_size * self.pressure_iterations + 1
    self.pressure_arrays = [new_grid(True) for _ in range(num_pressure_arrays)]

    # Checkpoints: fluid state saved at the start of each segment.
    self.segment_start_vx_arrays = [new_grid() for _ in range(self.num_segments)]
    self.segment_start_vy_arrays = [new_grid() for _ in range(self.num_segments)]
    self.segment_start_density_arrays = [new_grid() for _ in range(self.num_segments)]
    self.segment_start_pressure_arrays = [new_grid() for _ in range(self.num_segments)]

    # To restore previously computed gradients before calling tape.backward()
    self.vx_array_grad_saved = new_grid()
    self.vy_array_grad_saved = new_grid()
    self.density_array_grad_saved = new_grid()
    self.pressure_array_grad_saved = new_grid()

    wp.launch(fill_initial_density, (N_GRID, N_GRID), inputs=[self.density_arrays[0]])

    # Load the optimization target. The context manager closes the image file
    # once the resized copy has been created.
    logo_path = os.path.join(warp.examples.get_asset_directory(), "nvidia_logo.png")
    with Image.open(logo_path) as target_base:
        target_resized = target_base.resize((N_GRID, N_GRID))

    # Use the first channel only, scaled to [0, 1].
    target_np = np.array(target_resized)[:, :, 0] / 255.0
    self.target_wp = wp.array(target_np, dtype=float)

    self.loss = wp.zeros((1,), dtype=float, requires_grad=True)

    # The optimizer updates the initial velocity field only.
    self.train_rate = 0.01
    self.optimizer = warp.optim.Adam([self.vx_arrays[0].flatten(), self.vy_arrays[0].flatten()], lr=self.train_rate)

    # Capture forward/backward passes and tape.zero()
    self.use_cuda_graph = wp.get_device().is_cuda
    self.forward_graph = None
    self.backward_graph = None
    self.zero_tape_graph = None

    if self.use_cuda_graph:
        with wp.ScopedCapture() as capture:
            self.forward()
        self.forward_graph = capture.graph

        with wp.ScopedCapture() as capture:
            self.backward()
        self.backward_graph = capture.graph

        # tape.zero() launches many memsets, which can be a significant overhead for smaller problems
        # NOTE(review): self.tape is not assigned in this method; presumably backward() creates it - confirm.
        with wp.ScopedCapture() as capture:
            self.tape.zero()
        self.zero_tape_graph = capture.graph
260
+
def step(self, step_index) -> None:
    """Advance the fluid state from slot ``step_index - 1`` to slot ``step_index``.

    The update has three stages:

    1. Advect both velocity components by the velocity field itself
       (results go to ``wx_arrays`` / ``wy_arrays``).
    2. Project onto a divergence-free field with Jacobi iterations on the
       pressure (results go to ``vx_arrays`` / ``vy_arrays``).
    3. Advect the density with the projected velocities.
    """
    grid = (N_GRID, N_GRID)
    prev = step_index - 1

    vx_prev = self.vx_arrays[prev]
    vy_prev = self.vy_arrays[prev]
    wx = self.wx_arrays[prev]
    wy = self.wy_arrays[prev]
    div = self.div_arrays[prev]

    # Stage 1: self-advection of each velocity component.
    for component, advected in ((vx_prev, wx), (vy_prev, wy)):
        wp.launch(advect, grid, inputs=[self.dt, vx_prev, vy_prev, component], outputs=[advected])

    # Stage 2: pressure projection using a few Jacobi iterations.
    wp.launch(divergence, grid, inputs=[wx, wy], outputs=[div])

    # The first iteration reads the previous step's final pressure as its initial guess;
    # every iteration writes to its own array so nothing is overwritten.
    first_pressure = self.pressure_iterations * prev
    for k in range(self.pressure_iterations):
        src = first_pressure + k
        dst = src + 1
        wp.launch(
            jacobi_iter,
            grid,
            inputs=[div, self.pressure_arrays[src]],
            outputs=[self.pressure_arrays[dst]],
        )

    # After the loop, dst equals self.pressure_iterations * step_index.
    vx_next = self.vx_arrays[step_index]
    vy_next = self.vy_arrays[step_index]
    wp.launch(
        update_velocities,
        grid,
        inputs=[self.pressure_arrays[dst], wx, wy],
        outputs=[vx_next, vy_next],
    )

    # Stage 3: move the density with the projected velocity field.
    wp.launch(
        advect,
        grid,
        inputs=[self.dt, vx_next, vy_next, self.density_arrays[prev]],
        outputs=[self.density_arrays[step_index]],
    )
331
+
332
def forward(self) -> None:
    """Simulate all segments once, checkpointing each segment's starting state.

    Before a segment runs, the velocity, density and pressure fields held in
    slot 0 are copied into that segment's checkpoint arrays. After every
    segment except the last, the final fields are copied back into slot 0 so
    the next segment continues from them. The loss is evaluated once all
    segments have been simulated.
    """
    self.loss.zero_()

    # (checkpoint storage, live per-step storage) for every field that is saved
    checkpointed_fields = (
        (self.segment_start_vx_arrays, self.vx_arrays),
        (self.segment_start_vy_arrays, self.vy_arrays),
        (self.segment_start_density_arrays, self.density_arrays),
        (self.segment_start_pressure_arrays, self.pressure_arrays),
    )

    for seg in range(self.num_segments):
        # Record the fields this segment starts from
        for saved, live in checkpointed_fields:
            wp.copy(saved[seg], live[0])

        for local_step in range(1, self.segment_size + 1):
            self.step(local_step)

        # Carry the end-of-segment fields over as the next segment's start
        if seg < self.num_segments - 1:
            for _, live in checkpointed_fields:
                wp.copy(live[0], live[-1])

    wp.launch(
        compute_loss,
        (N_GRID, N_GRID),
        inputs=[self.density_arrays[self.segment_size], self.target_wp],
        outputs=[self.loss],
    )
363
+
364
def backward(self) -> None:
    """Compute the adjoints using a checkpointing approach.

    Starting from the final segment, the forward pass for the segment is
    repeated, this time recording the kernel launches onto a tape. Any
    previously computed adjoints are restored prior to evaluating the
    backward pass for the segment. This process is repeated until the
    adjoints of the initial state have been calculated.

    The tape of the most recently processed segment is left in ``self.tape``.
    """

    # Visit the segments in reverse order: num_segments - 1 down to 0
    for segment_index in range(self.num_segments - 1, -1, -1):
        # Restore state at the start of the segment
        wp.copy(self.vx_arrays[0], self.segment_start_vx_arrays[segment_index])
        wp.copy(self.vy_arrays[0], self.segment_start_vy_arrays[segment_index])
        wp.copy(self.density_arrays[0], self.segment_start_density_arrays[segment_index])
        wp.copy(self.pressure_arrays[0], self.segment_start_pressure_arrays[segment_index])

        # Record operations on tape
        # (the tape is bound to self.tape, replacing the one from the previous segment)
        with wp.Tape() as self.tape:
            for t in range(1, self.segment_size + 1):
                self.step(t)

        if segment_index == self.num_segments - 1:
            # Final segment: seed the loss adjoint with 1 and run compute_loss
            # in adjoint mode so the gradient lands on the final density array.
            self.loss.grad.fill_(1.0)

            wp.launch(
                compute_loss,
                (N_GRID, N_GRID),
                inputs=[self.density_arrays[self.segment_size], self.target_wp],
                outputs=[self.loss],
                adj_inputs=[self.density_arrays[self.segment_size].grad, None],
                adj_outputs=[self.loss.grad],
                adjoint=True,
            )
        else:
            # Fill in previously computed gradients from the last segment
            # (they were saved from slot 0 of the segment processed just before this one)
            wp.copy(self.vx_arrays[-1].grad, self.vx_array_grad_saved)
            wp.copy(self.vy_arrays[-1].grad, self.vy_array_grad_saved)
            wp.copy(self.density_arrays[-1].grad, self.density_array_grad_saved)
            wp.copy(self.pressure_arrays[-1].grad, self.pressure_array_grad_saved)

        # Back-propagate through the steps recorded for this segment
        self.tape.backward()

        if segment_index > 0:
            # Save the gradients to variables and zero-out the gradients for the next segment
            wp.copy(self.vx_array_grad_saved, self.vx_arrays[0].grad)
            wp.copy(self.vy_array_grad_saved, self.vy_arrays[0].grad)
            wp.copy(self.density_array_grad_saved, self.density_arrays[0].grad)
            wp.copy(self.pressure_array_grad_saved, self.pressure_arrays[0].grad)

            # Not done for segment 0: its gradients are the result of this method
            self.tape.zero()

    # Done with backward pass, we're interested in self.vx_arrays[0].grad and self.vy_arrays[0].grad
417
+
418
+
419
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Override the default Warp device.",
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        default=1000,
        help="Number of frames to simulate before computing loss.",
    )
    parser.add_argument(
        "--train_iters",
        type=int,
        default=50,
        help="Total number of training iterations.",
    )
    parser.add_argument(
        "--headless",
        action="store_true",
        help="Run in headless mode, suppressing the opening of any graphical windows.",
    )

    # Unrecognized command-line options are ignored
    args, _unparsed = parser.parse_known_args()

    with wp.ScopedDevice(args.device):
        example = Example(sim_steps=args.num_frames)

        wp.synchronize_device()

        device = wp.get_device()
        if device.is_cuda:
            print(f"Current memory usage: {wp.get_mempool_used_mem_current(device) / (1024 * 1024 * 1024):.4f} GiB")

        # Optimization loop: forward pass, backward pass, optimizer update, gradient reset.
        # Each phase replays the example's captured graph when one is available and
        # otherwise calls the corresponding method directly.
        for train_iter in range(args.train_iters):
            if not example.forward_graph:
                example.forward()
            else:
                wp.capture_launch(example.forward_graph)

            if not example.backward_graph:
                example.backward()
            else:
                wp.capture_launch(example.backward_graph)

            initial_velocity_grads = [
                example.vx_arrays[0].grad.flatten(),
                example.vy_arrays[0].grad.flatten(),
            ]
            example.optimizer.step(initial_velocity_grads)

            # Reset the gradient arrays before the next iteration
            if not example.zero_tape_graph:
                example.tape.zero()
            else:
                wp.capture_launch(example.zero_tape_graph)

            print(f"Iteration {train_iter:05d} loss: {example.loss.numpy()[0]:.6f}")

        if not args.headless:
            import matplotlib
            import matplotlib.pyplot as plt

            if matplotlib.rcParams["figure.raise_window"]:
                matplotlib.rcParams["figure.raise_window"] = False

            figure, axes = plt.subplots()
            density_image = axes.imshow(
                example.density_arrays[-1].numpy(),
                cmap="viridis",
                origin="lower",
                vmin=0,
                vmax=1,
            )
            axes.set_xticks([])
            axes.set_yticks([])
            axes.set_title("Fluid Density")

            # Replay the simulation for display. Every frame runs step(1), which
            # reads slot 0 and writes slot 1 of the velocity/density lists and slot
            # `pressure_iterations` of the pressure list; swapping the entries
            # afterwards turns that output into the next frame's input.
            final_pressure = example.pressure_iterations
            for _ in range(args.num_frames):
                example.step(1)

                for field in (example.vx_arrays, example.vy_arrays, example.density_arrays):
                    field[0], field[1] = field[1], field[0]

                pressure = example.pressure_arrays
                pressure[0], pressure[final_pressure] = pressure[final_pressure], pressure[0]

                density_image.set_data(example.density_arrays[0].numpy())
                plt.pause(0.001)

            plt.show()
@@ -0,0 +1,182 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example Sim Rigid Kinematics
18
+ #
19
+ # Tests rigid body forward and inverse kinematics through the
20
+ # wp.sim.eval_ik() and wp.sim.eval_fk() methods.
21
+ #
22
+ ###########################################################################
23
+
24
+ import numpy as np
25
+
26
+ import warp as wp
27
+ import warp.sim
28
+ import warp.sim.render
29
+
30
# Position that compute_loss measures the selected body's translation against;
# Example.render() also draws a sphere named "target" at this position.
TARGET = wp.constant(wp.vec3(2.0, 1.0, 0.0))
31
+
32
+
33
@wp.kernel
def compute_loss(body_q: wp.array(dtype=wp.transform), body_index: int, loss: wp.array(dtype=float)):
    # Writes to loss[0] the squared distance between TARGET and the
    # translation part of the transform stored at body_q[body_index].
    body_pos = wp.transform_get_translation(body_q[body_index])
    offset = body_pos - TARGET

    loss[0] = wp.dot(offset, offset)
39
+
40
+
41
@wp.kernel
def step_kernel(x: wp.array(dtype=float), grad: wp.array(dtype=float), alpha: float):
    # Each thread updates the entry of x at its own thread index by moving it
    # against the matching gradient entry, scaled by the step size alpha.
    i = wp.tid()

    update = grad[i] * alpha
    x[i] = x[i] - update
47
+
48
+
49
class Example:
    """Gradient-descent optimization of the joint angles of a four-link chain.

    The chain is built from revolute joints that all rotate about the z axis.
    Each call to step() evaluates forward kinematics, measures the squared
    distance between the last body and TARGET, back-propagates that loss to
    the joint coordinates and applies one gradient-descent update to them.
    """

    def __init__(self, stage_path="example_inverse_kinematics.usd", verbose=False):
        """Build the chain model, its state, the optional renderer and the loss buffer.

        Args:
            stage_path: Passed to the SimRenderer; a falsy value disables rendering.
            verbose: When true, step() prints the loss and the joint gradient.
        """
        self.verbose = verbose

        frames_per_second = 60
        self.frame_dt = 1.0 / frames_per_second
        self.render_time = 0.0

        chain_length = 4
        chain_width = 1.0
        joint_limit = np.deg2rad(60.0)

        builder = wp.sim.ModelBuilder()
        builder.add_articulation()

        for link in range(chain_length):
            is_first = link == 0

            # The first link uses parent index -1 and sits at the origin; every
            # later link hangs off the most recently added joint, offset by one
            # chain_width along x.
            parent = -1 if is_first else builder.joint_count - 1
            joint_offset = [0.0, 0.0, 0.0] if is_first else [chain_width, 0.0, 0.0]
            parent_joint_xform = wp.transform(joint_offset, wp.quat_identity())

            body = builder.add_body(
                origin=wp.transform([link, 0.0, 0.0], wp.quat_identity()),
                armature=0.1,
            )

            builder.add_joint_revolute(
                parent=parent,
                child=body,
                axis=(0.0, 0.0, 1.0),
                parent_xform=parent_joint_xform,
                child_xform=wp.transform_identity(),
                limit_lower=-joint_limit,
                limit_upper=joint_limit,
                target_ke=0.0,
                target_kd=0.0,
                limit_ke=30.0,
                limit_kd=30.0,
            )

            if link == chain_length - 1:
                # The last link carries a sphere marking the end effector
                builder.add_shape_sphere(pos=wp.vec3(0.0, 0.0, 0.0), radius=0.1, density=10.0, body=body)
            else:
                # Every other link is a box spanning the link's width
                half_width = chain_width * 0.5
                builder.add_shape_box(
                    pos=wp.vec3(half_width, 0.0, 0.0),
                    hx=half_width,
                    hy=0.1,
                    hz=0.1,
                    density=10.0,
                    body=body,
                )

        self.model = builder.finalize()
        self.model.ground = False

        self.state = self.model.state()

        self.renderer = wp.sim.render.SimRenderer(self.model, stage_path, scaling=50.0) if stage_path else None

        # Single-element buffer that compute_loss writes into
        self.loss = wp.zeros(1, dtype=float)

        # Arrays that take part in the backward pass
        for differentiable in (self.model.joint_q, self.state.body_q, self.loss):
            differentiable.requires_grad = True

        self.train_rate = 0.01

    def forward(self):
        """Evaluate forward kinematics, then the loss for the last body in the state."""
        wp.sim.eval_fk(self.model, self.model.joint_q, self.model.joint_qd, None, self.state)

        last_body = len(self.state.body_q) - 1
        wp.launch(compute_loss, dim=1, inputs=[self.state.body_q, last_body, self.loss])

    def step(self):
        """Run one optimization iteration: forward, backward, then update joint_q."""
        with wp.ScopedTimer("step"):
            with wp.Tape() as tape:
                self.forward()
            tape.backward(loss=self.loss)

            if self.verbose:
                print(f"loss: {self.loss}")
                print(f"joint_grad: {tape.gradients[self.model.joint_q]}")

            # Apply the gradient-descent update to the joint coordinates in place
            joint_q = self.model.joint_q
            wp.launch(
                step_kernel,
                dim=len(joint_q),
                inputs=[joint_q, tape.gradients[joint_q], self.train_rate],
            )

            # Clear the gradients so the next iteration starts from zero
            tape.zero()

    def render(self):
        """Render the current state and the target sphere as one frame; no-op without a renderer."""
        renderer = self.renderer
        if renderer is None:
            return

        with wp.ScopedTimer("render"):
            renderer.begin_frame(self.render_time)
            renderer.render(self.state)
            renderer.render_sphere(
                name="target",
                pos=TARGET,
                rot=wp.quat_identity(),
                radius=0.1,
                color=(1.0, 0.0, 0.0),
            )
            renderer.end_frame()
            self.render_time += self.frame_dt
156
+
157
+
158
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Override the default Warp device.",
    )
    # The literal text "None" on the command line is converted to None
    parser.add_argument(
        "--stage_path",
        type=lambda x: None if x == "None" else str(x),
        default="example_inverse_kinematics.usd",
        help="Path to the output USD file.",
    )
    parser.add_argument(
        "--train_iters",
        type=int,
        default=512,
        help="Total number of training iterations.",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print out additional status messages during execution.",
    )

    # Unrecognized command-line options are ignored
    args, _unparsed = parser.parse_known_args()

    with wp.ScopedDevice(args.device):
        example = Example(stage_path=args.stage_path, verbose=args.verbose)

        # One optimization update followed by one rendered frame per iteration
        for _iteration in range(args.train_iters):
            example.step()
            example.render()

        if example.renderer:
            example.renderer.save()