warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,793 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import warp as wp
17
+
18
+ from .utils import quat_decompose, quat_twist
19
+
20
+
21
@wp.func
def compute_2d_rotational_dofs(
    axis_0: wp.vec3,
    axis_1: wp.vec3,
    q0: float,
    q1: float,
    qd0: float,
    qd1: float,
):
    """
    Computes the rotation quaternion and 3D angular velocity given the joint axes, coordinates and velocities.

    The two joint axes plus their cross product form the columns of a matrix that is converted into a
    reference-frame quaternion. The first rotation (angle ``q0``) is about that frame's x axis. The second
    rotation (angle ``q1``) is about the frame's y axis after it has been carried along by the first rotation.

    Returns:
        A tuple ``(rot, vel)``: the composed rotation ``r1 * r0`` and the angular velocity
        ``a0 * qd0 + a1 * qd1`` built from the two (rotated) axes.
    """
    # reference frame spanned by the two axes and their cross product
    frame = wp.quat_from_matrix(wp.matrix_from_cols(axis_0, axis_1, wp.cross(axis_0, axis_1)))

    # first rotation: about the frame's x axis
    a0 = wp.quat_rotate(frame, wp.vec3(1.0, 0.0, 0.0))
    r0 = wp.quat_from_axis_angle(a0, q0)

    # second rotation: about the frame's y axis, carried along by the first rotation
    a1 = wp.quat_rotate(r0, wp.quat_rotate(frame, wp.vec3(0.0, 1.0, 0.0)))
    r1 = wp.quat_from_axis_angle(a1, q1)

    combined_rot = r1 * r0
    combined_vel = a0 * qd0 + a1 * qd1

    return combined_rot, combined_vel
50
+
51
+
52
@wp.func
def invert_2d_rotational_dofs(
    axis_0: wp.vec3,
    axis_1: wp.vec3,
    q_p: wp.quat,
    q_c: wp.quat,
    w_err: wp.vec3,
):
    """
    Computes generalized joint position and velocity coordinates for a 2D rotational joint given the joint axes, relative orientations and angular velocity differences between the two bodies the joint connects.

    Returns:
        A tuple of two ``wp.vec2`` values:
        - the first two angles produced by ``quat_decompose`` of the relative rotation
          ``q_p^-1 * q_c``, expressed in the frame built from the joint axes;
        - the two joint velocities solved from ``w_err`` after rotating it by ``q_p^-1``.
    """
    frame = wp.quat_from_matrix(wp.matrix_from_cols(axis_0, axis_1, wp.cross(axis_0, axis_1)))
    # relative parent->child rotation, expressed in the joint-axis frame
    rel = wp.quat_inverse(frame) * wp.quat_inverse(q_p) * q_c * frame

    # split into a compound rotation about each frame axis
    angles = quat_decompose(rel)

    # rebuild the rotated joint axes from the first two decomposed angles
    a0 = wp.quat_rotate(frame, wp.vec3(1.0, 0.0, 0.0))
    r0 = wp.quat_from_axis_angle(a0, angles[0])
    a1 = wp.quat_rotate(r0, wp.quat_rotate(frame, wp.vec3(0.0, 1.0, 0.0)))
    r1 = wp.quat_from_axis_angle(a1, angles[1])
    a2 = wp.quat_rotate(r1 * r0, wp.quat_rotate(frame, wp.vec3(0.0, 0.0, 1.0)))

    # angular velocity difference expressed in the parent's frame
    w_local = wp.quat_rotate_inv(q_p, w_err)

    # Solve w_local = a0*v0 + a1*v1 + a2*v2 one unknown at a time: dotting with the cross
    # product of the other two axes eliminates their contributions.
    n12 = wp.cross(a1, a2)
    n02 = wp.cross(a0, a2)
    v0 = wp.dot(w_local, n12) / wp.dot(a0, n12)
    v1 = wp.dot(w_local, n02) / wp.dot(a1, n02)

    return wp.vec2(angles[0], angles[1]), wp.vec2(v0, v1)
92
+
93
+
94
@wp.func
def compute_3d_rotational_dofs(
    axis_0: wp.vec3,
    axis_1: wp.vec3,
    axis_2: wp.vec3,
    q0: float,
    q1: float,
    q2: float,
    qd0: float,
    qd1: float,
    qd2: float,
):
    """
    Computes the rotation quaternion and 3D angular velocity given the joint axes, coordinates and velocities.

    The three joint axes form the columns of a matrix that is converted into a reference-frame quaternion.
    Rotations by ``q0``, ``q1`` and ``q2`` are chained about that frame's x, y and z axes. Each later
    axis is first carried along by the rotations that precede it.

    Returns:
        A tuple ``(rot, vel)``: the composed rotation ``r2 * r1 * r0`` and the angular velocity
        ``a0 * qd0 + a1 * qd1 + a2 * qd2`` built from the (rotated) axes.
    """
    frame = wp.quat_from_matrix(wp.matrix_from_cols(axis_0, axis_1, axis_2))

    # first rotation: about the frame's x axis
    a0 = wp.quat_rotate(frame, wp.vec3(1.0, 0.0, 0.0))
    r0 = wp.quat_from_axis_angle(a0, q0)

    # second rotation: about the frame's y axis, carried by the first rotation
    a1 = wp.quat_rotate(r0, wp.quat_rotate(frame, wp.vec3(0.0, 1.0, 0.0)))
    r1 = wp.quat_from_axis_angle(a1, q1)

    # third rotation: about the frame's z axis, carried by the first two rotations
    a2 = wp.quat_rotate(r1 * r0, wp.quat_rotate(frame, wp.vec3(0.0, 0.0, 1.0)))
    r2 = wp.quat_from_axis_angle(a2, q2)

    combined_rot = r2 * r1 * r0
    combined_vel = a0 * qd0 + a1 * qd1 + a2 * qd2

    return combined_rot, combined_vel
130
+
131
+
132
@wp.func
def invert_3d_rotational_dofs(
    axis_0: wp.vec3, axis_1: wp.vec3, axis_2: wp.vec3, q_p: wp.quat, q_c: wp.quat, w_err: wp.vec3
):
    """
    Computes generalized joint position and velocity coordinates for a 3D rotational joint given the joint axes, relative orientations and angular velocity differences between the two bodies the joint connects.

    Returns:
        A tuple ``(angles, velocities)``:
        - ``angles`` is the output of ``quat_decompose`` applied to the relative rotation
          ``q_p^-1 * q_c``, expressed in the frame built from the joint axes;
        - ``velocities`` is a ``wp.vec3`` of joint velocities solved from ``w_err`` after
          rotating it by ``q_p^-1``.
    """
    frame = wp.quat_from_matrix(wp.matrix_from_cols(axis_0, axis_1, axis_2))
    # relative parent->child rotation, expressed in the joint-axis frame
    rel = wp.quat_inverse(frame) * wp.quat_inverse(q_p) * q_c * frame

    # split into a compound rotation about each frame axis
    angles = quat_decompose(rel)

    # rebuild the rotated joint axes from the decomposed angles
    a0 = wp.quat_rotate(frame, wp.vec3(1.0, 0.0, 0.0))
    r0 = wp.quat_from_axis_angle(a0, angles[0])
    a1 = wp.quat_rotate(r0, wp.quat_rotate(frame, wp.vec3(0.0, 1.0, 0.0)))
    r1 = wp.quat_from_axis_angle(a1, angles[1])
    a2 = wp.quat_rotate(r1 * r0, wp.quat_rotate(frame, wp.vec3(0.0, 0.0, 1.0)))

    # angular velocity difference expressed in the parent's frame
    w_local = wp.quat_rotate_inv(q_p, w_err)

    # Solve w_local = a0*v0 + a1*v1 + a2*v2 one unknown at a time: dotting with the cross
    # product of the other two axes eliminates their contributions.
    n12 = wp.cross(a1, a2)
    n02 = wp.cross(a0, a2)
    n01 = wp.cross(a0, a1)

    v0 = wp.dot(w_local, n12) / wp.dot(a0, n12)
    v1 = wp.dot(w_local, n02) / wp.dot(a1, n02)
    v2 = wp.dot(w_local, n01) / wp.dot(a2, n01)

    return angles, wp.vec3(v0, v1, v2)
173
+
174
+
175
@wp.func
def eval_single_articulation_fk(
    joint_start: int,
    joint_end: int,
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_X_p: wp.array(dtype=wp.transform),
    joint_X_c: wp.array(dtype=wp.transform),
    joint_axis: wp.array(dtype=wp.vec3),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    body_com: wp.array(dtype=wp.vec3),
    # outputs
    body_q: wp.array(dtype=wp.transform),
    body_qd: wp.array(dtype=wp.spatial_vector),
):
    """
    Evaluates forward kinematics for the joints in the half-open range ``[joint_start, joint_end)``.

    For each joint, the child body's world transform and spatial velocity are computed and written to
    ``body_q[child]`` and ``body_qd[child]``. The inputs are:
    - the parent body's state (``body_q[parent]``, ``body_qd[parent]``), or only the joint's
      parent frame ``joint_X_p[i]`` when ``joint_parent[i] < 0``;
    - the joint's generalized coordinates from ``joint_q`` / ``joint_qd``.

    Joints are processed sequentially and parent state is read inside the loop. The joint range
    presumably has to be ordered so that every parent is written before its children.
    TODO: confirm against the model builder.

    Args:
        joint_start, joint_end: Range of joint indices to process.
        joint_q, joint_qd: Flat generalized position / velocity arrays, indexed from
            ``joint_q_start[i]`` / ``joint_qd_start[i]``.
        joint_X_p, joint_X_c: Joint frames relative to the parent and child body.
        joint_axis, joint_axis_start: Flat joint-axis array and per-joint offset into it.
        joint_axis_dim: Per joint, ``[i, 0]`` is the linear-axis count and ``[i, 1]`` the
            angular-axis count (only used by D6 joints here).
        body_com: Per-body center of mass, transformed into world space with the body transform.
        body_q, body_qd: Outputs; the entries of each processed joint's child body are overwritten.
    """
    for i in range(joint_start, joint_end):
        parent = joint_parent[i]
        child = joint_child[i]

        # compute transform across the joint
        # (note: the local name `type` shadows the Python builtin inside this function)
        type = joint_type[i]

        X_pj = joint_X_p[i]
        X_cj = joint_X_c[i]

        # parent anchor frame in world space
        X_wpj = X_pj
        # velocity of parent anchor point in world space
        v_wpj = wp.spatial_vector()
        if parent >= 0:
            X_wp = body_q[parent]
            X_wpj = X_wp * X_wpj
            # lever arm from the parent's world-space center of mass to the joint anchor
            r_p = wp.transform_get_translation(X_wpj) - wp.transform_point(X_wp, body_com[parent])

            # spatial vectors are (angular, linear): shift the parent's linear velocity
            # from its center of mass to the anchor point by adding w x r
            v_wp = body_qd[parent]
            w_p = wp.spatial_top(v_wp)
            v_p = wp.spatial_bottom(v_wp) + wp.cross(w_p, r_p)
            v_wpj = wp.spatial_vector(w_p, v_p)

        q_start = joint_q_start[i]
        qd_start = joint_qd_start[i]
        axis_start = joint_axis_start[i]
        lin_axis_count = joint_axis_dim[i, 0]
        ang_axis_count = joint_axis_dim[i, 1]

        # joint-space transform and velocity; a joint type not matched by any branch
        # below keeps the identity transform and zero velocity
        X_j = wp.transform_identity()
        v_j = wp.spatial_vector(wp.vec3(), wp.vec3())

        if type == wp.sim.JOINT_PRISMATIC:
            # one coordinate: translation along a single axis
            axis = joint_axis[axis_start]

            q = joint_q[q_start]
            qd = joint_qd[qd_start]

            X_j = wp.transform(axis * q, wp.quat_identity())
            v_j = wp.spatial_vector(wp.vec3(), axis * qd)

        if type == wp.sim.JOINT_REVOLUTE:
            # one coordinate: rotation angle about a single axis
            axis = joint_axis[axis_start]

            q = joint_q[q_start]
            qd = joint_qd[qd_start]

            X_j = wp.transform(wp.vec3(), wp.quat_from_axis_angle(axis, q))
            v_j = wp.spatial_vector(axis * qd, wp.vec3())

        if type == wp.sim.JOINT_BALL:
            # 4 position coordinates passed to wp.quat, 3 angular-velocity coordinates
            r = wp.quat(joint_q[q_start + 0], joint_q[q_start + 1], joint_q[q_start + 2], joint_q[q_start + 3])

            w = wp.vec3(joint_qd[qd_start + 0], joint_qd[qd_start + 1], joint_qd[qd_start + 2])

            X_j = wp.transform(wp.vec3(), r)
            v_j = wp.spatial_vector(w, wp.vec3())

        if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
            # 7 position coordinates: translation (3) followed by quaternion (4)
            t = wp.transform(
                wp.vec3(joint_q[q_start + 0], joint_q[q_start + 1], joint_q[q_start + 2]),
                wp.quat(joint_q[q_start + 3], joint_q[q_start + 4], joint_q[q_start + 5], joint_q[q_start + 6]),
            )

            # 6 velocity coordinates: the first 3 go to the angular (top) part,
            # the last 3 to the linear (bottom) part
            v = wp.spatial_vector(
                wp.vec3(joint_qd[qd_start + 0], joint_qd[qd_start + 1], joint_qd[qd_start + 2]),
                wp.vec3(joint_qd[qd_start + 3], joint_qd[qd_start + 4], joint_qd[qd_start + 5]),
            )

            X_j = t
            v_j = v

        if type == wp.sim.JOINT_COMPOUND:
            # three chained rotations about three joint axes
            rot, vel_w = compute_3d_rotational_dofs(
                joint_axis[axis_start],
                joint_axis[axis_start + 1],
                joint_axis[axis_start + 2],
                joint_q[q_start + 0],
                joint_q[q_start + 1],
                joint_q[q_start + 2],
                joint_qd[qd_start + 0],
                joint_qd[qd_start + 1],
                joint_qd[qd_start + 2],
            )

            t = wp.transform(wp.vec3(0.0, 0.0, 0.0), rot)
            v = wp.spatial_vector(vel_w, wp.vec3(0.0, 0.0, 0.0))

            X_j = t
            v_j = v

        if type == wp.sim.JOINT_UNIVERSAL:
            # two chained rotations about two joint axes
            rot, vel_w = compute_2d_rotational_dofs(
                joint_axis[axis_start],
                joint_axis[axis_start + 1],
                joint_q[q_start + 0],
                joint_q[q_start + 1],
                joint_qd[qd_start + 0],
                joint_qd[qd_start + 1],
            )

            t = wp.transform(wp.vec3(0.0, 0.0, 0.0), rot)
            v = wp.spatial_vector(vel_w, wp.vec3(0.0, 0.0, 0.0))

            X_j = t
            v_j = v

        if type == wp.sim.JOINT_D6:
            # up to 3 linear axes first, then up to 3 angular axes, laid out
            # consecutively in joint_axis / joint_q / joint_qd
            pos = wp.vec3(0.0)
            rot = wp.quat_identity()
            vel_v = wp.vec3(0.0)
            vel_w = wp.vec3(0.0)

            # unroll for loop to ensure joint actions remain differentiable
            # (since differentiating through a for loop that updates a local variable is not supported)

            if lin_axis_count > 0:
                axis = joint_axis[axis_start + 0]
                pos += axis * joint_q[q_start + 0]
                vel_v += axis * joint_qd[qd_start + 0]
            if lin_axis_count > 1:
                axis = joint_axis[axis_start + 1]
                pos += axis * joint_q[q_start + 1]
                vel_v += axis * joint_qd[qd_start + 1]
            if lin_axis_count > 2:
                axis = joint_axis[axis_start + 2]
                pos += axis * joint_q[q_start + 2]
                vel_v += axis * joint_qd[qd_start + 2]

            # angular axes/coordinates start right after the linear ones
            ia = axis_start + lin_axis_count
            iq = q_start + lin_axis_count
            iqd = qd_start + lin_axis_count
            if ang_axis_count == 1:
                axis = joint_axis[ia]
                rot = wp.quat_from_axis_angle(axis, joint_q[iq])
                vel_w = joint_qd[iqd] * axis
            if ang_axis_count == 2:
                rot, vel_w = compute_2d_rotational_dofs(
                    joint_axis[ia + 0],
                    joint_axis[ia + 1],
                    joint_q[iq + 0],
                    joint_q[iq + 1],
                    joint_qd[iqd + 0],
                    joint_qd[iqd + 1],
                )
            if ang_axis_count == 3:
                rot, vel_w = compute_3d_rotational_dofs(
                    joint_axis[ia + 0],
                    joint_axis[ia + 1],
                    joint_axis[ia + 2],
                    joint_q[iq + 0],
                    joint_q[iq + 1],
                    joint_q[iq + 2],
                    joint_qd[iqd + 0],
                    joint_qd[iqd + 1],
                    joint_qd[iqd + 2],
                )

            X_j = wp.transform(pos, rot)
            v_j = wp.spatial_vector(vel_w, vel_v)

        # transform from world to joint anchor frame at child body
        X_wcj = X_wpj * X_j
        # transform from world to child body frame
        X_wc = X_wcj * wp.transform_inverse(X_cj)

        # transform velocity across the joint to world space
        angular_vel = wp.transform_vector(X_wpj, wp.spatial_top(v_j))
        linear_vel = wp.transform_vector(X_wpj, wp.spatial_bottom(v_j))

        # NOTE(review): the parent velocity above was shifted from its COM to the joint anchor,
        # but this sum is stored without shifting back to the child's COM. Confirm this matches
        # the body_qd convention that callers expect.
        v_wc = v_wpj + wp.spatial_vector(angular_vel, linear_vel)

        body_q[child] = X_wc
        body_qd[child] = v_wc
371
+
372
+
373
# Forward-kinematics kernel for an integer-valued `articulation_mask`:
# an entry equal to 0 disables FK for that articulation, any other value enables it.
# NOTE(review): the bool-mask kernel defined after this one reuses the Python name
# `eval_articulation_fk`; confirm Warp keeps both variants reachable rather than the
# later definition shadowing this one.
@wp.kernel
def eval_articulation_fk(
    articulation_start: wp.array(dtype=int),
    articulation_mask: wp.array(dtype=int),  # per-articulation enable flag; a None (null) array enables every articulation
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_X_p: wp.array(dtype=wp.transform),
    joint_X_c: wp.array(dtype=wp.transform),
    joint_axis: wp.array(dtype=wp.vec3),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    body_com: wp.array(dtype=wp.vec3),
    # outputs
    body_q: wp.array(dtype=wp.transform),
    body_qd: wp.array(dtype=wp.spatial_vector),
):
    # each thread handles one articulation
    articulation = wp.tid()

    # the outer test guards against a null mask array (meaning "all enabled");
    # the inner test skips articulations whose flag is zero
    if articulation_mask:
        if articulation_mask[articulation] == 0:
            return

    # consecutive entries of articulation_start bound this articulation's joint range
    # (end bound presumably exclusive -- see eval_single_articulation_fk)
    first_joint = articulation_start[articulation]
    end_joint = articulation_start[articulation + 1]

    eval_single_articulation_fk(
        first_joint,
        end_joint,
        joint_q, joint_qd,
        joint_q_start, joint_qd_start,
        joint_type, joint_parent, joint_child,
        joint_X_p, joint_X_c,
        joint_axis, joint_axis_start, joint_axis_dim,
        body_com,
        # outputs
        body_q, body_qd,
    )
427
+
428
+
429
# Forward-kinematics kernel for a bool-valued `articulation_mask`:
# a False entry disables FK for that articulation, True enables it.
# NOTE(review): this reuses the Python name of the int-mask kernel defined just before it;
# confirm Warp dispatches on the mask dtype rather than this definition shadowing the other.
@wp.kernel
def eval_articulation_fk(
    articulation_start: wp.array(dtype=int),
    articulation_mask: wp.array(dtype=bool),  # per-articulation enable flag; a None (null) array enables every articulation
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_X_p: wp.array(dtype=wp.transform),
    joint_X_c: wp.array(dtype=wp.transform),
    joint_axis: wp.array(dtype=wp.vec3),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    body_com: wp.array(dtype=wp.vec3),
    # outputs
    body_q: wp.array(dtype=wp.transform),
    body_qd: wp.array(dtype=wp.spatial_vector),
):
    # each thread handles one articulation
    articulation = wp.tid()

    # the outer test guards against a null mask array (meaning "all enabled");
    # the inner test skips articulations whose flag is False
    if articulation_mask:
        if not articulation_mask[articulation]:
            return

    # consecutive entries of articulation_start bound this articulation's joint range
    # (end bound presumably exclusive -- see eval_single_articulation_fk)
    first_joint = articulation_start[articulation]
    end_joint = articulation_start[articulation + 1]

    eval_single_articulation_fk(
        first_joint,
        end_joint,
        joint_q, joint_qd,
        joint_q_start, joint_qd_start,
        joint_type, joint_parent, joint_child,
        joint_X_p, joint_X_c,
        joint_axis, joint_axis_start, joint_axis_dim,
        body_com,
        # outputs
        body_q, body_qd,
    )
483
+
484
+
485
# Host-side launcher: recompute body poses/velocities of a State from generalized joint coordinates.
def eval_fk(model, joint_q, joint_qd, mask, state):
    """
    Run forward kinematics over all articulations of ``model`` and store the resulting body
    transforms and spatial velocities in ``state`` (:attr:`State.body_q` and :attr:`State.body_qd`).

    Args:
        model (Model): The model whose articulations are evaluated.
        joint_q (array): Generalized joint position coordinates, shape [joint_coord_count], float
        joint_qd (array): Generalized joint velocity coordinates, shape [joint_dof_count], float
        mask (array): Per-articulation flags choosing which articulations are evaluated; pass None to evaluate all of them. Shape [articulation_count], int or bool
        state (State): Receives the updated body transforms and velocities.
    """
    fk_inputs = [
        model.articulation_start,
        mask,
        joint_q,
        joint_qd,
        model.joint_q_start,
        model.joint_qd_start,
        model.joint_type,
        model.joint_parent,
        model.joint_child,
        model.joint_X_p,
        model.joint_X_c,
        model.joint_axis,
        model.joint_axis_start,
        model.joint_axis_dim,
        model.body_com,
    ]
    fk_outputs = [state.body_q, state.body_qd]

    # one kernel thread per articulation
    wp.launch(
        kernel=eval_articulation_fk,
        dim=model.articulation_count,
        inputs=fk_inputs,
        outputs=fk_outputs,
        device=model.device,
    )
523
+
524
+
525
@wp.func
def reconstruct_angular_q_qd(q_pc: wp.quat, w_err: wp.vec3, X_wp: wp.transform, axis: wp.vec3):
    """
    Recover a single rotational joint coordinate and its rate about ``axis``.

    Args:
        q_pc (quat): Rotation of the child joint anchor relative to the parent joint anchor.
        w_err (vec3): Relative angular velocity (child minus parent); in-file callers compute it
            from world-frame body velocities.
        X_wp (transform): Frame whose rotation maps ``axis`` into the frame of ``w_err``; in-file
            callers pass the parent joint anchor frame in world space.
        axis (vec3): The joint axis, expressed in the parent joint anchor frame.

    Returns:
        q (float): Signed rotation angle about ``axis`` in radians. It is computed as
            ``2 * acos(w)`` of the twist quaternion and is not wrapped to (-pi, pi].
        qd (float): Component of ``w_err`` along the transformed axis.
    """
    # quat_twist (defined elsewhere in this module) presumably isolates the part of q_pc that rotates about `axis`
    twist = quat_twist(axis, q_pc)
    twist_vec = wp.vec3(twist[0], twist[1], twist[2])

    # twist[3] gives the unsigned angle; the direction of the vector part along `axis` gives its sign
    direction = wp.sign(wp.dot(axis, twist_vec))
    q = wp.acos(twist[3]) * 2.0 * direction

    # joint rate = relative angular velocity projected onto the axis mapped through X_wp
    axis_world = wp.transform_vector(X_wp, axis)
    qd = wp.dot(w_err, axis_world)

    return q, qd
546
+
547
+
548
@wp.kernel
def eval_articulation_ik(
    body_q: wp.array(dtype=wp.transform),
    body_qd: wp.array(dtype=wp.spatial_vector),
    body_com: wp.array(dtype=wp.vec3),
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_X_p: wp.array(dtype=wp.transform),
    joint_X_c: wp.array(dtype=wp.transform),
    joint_axis: wp.array(dtype=wp.vec3),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    # outputs
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
):
    # Inverse kinematics for one joint per thread.
    # Recovers the joint's generalized coordinates (joint_q) and velocities (joint_qd) from the
    # maximal-coordinate body state by comparing the parent and child joint anchor frames.
    tid = wp.tid()

    parent = joint_parent[tid]
    child = joint_child[tid]

    X_pj = joint_X_p[tid]
    X_cj = joint_X_c[tid]

    # parent angular / linear velocity at the joint anchor; stays zero when parent < 0
    w_p = wp.vec3()
    v_p = wp.vec3()
    v_wp = wp.spatial_vector()

    # parent anchor frame in world space (X_pj is used as-is when there is no parent body)
    X_wpj = X_pj
    if parent >= 0:
        X_wp = body_q[parent]
        X_wpj = X_wp * X_pj
        # moment arm from the parent's world-space center of mass to the joint anchor
        r_p = wp.transform_get_translation(X_wpj) - wp.transform_point(X_wp, body_com[parent])

        v_wp = body_qd[parent]
        w_p = wp.spatial_top(v_wp)
        # shift the parent's linear velocity to the anchor point
        # (assumes spatial_bottom(body_qd) is the center-of-mass velocity -- TODO confirm)
        v_p = wp.spatial_bottom(v_wp) + wp.cross(w_p, r_p)

    # child anchor frame in world space
    X_wc = body_q[child]
    X_wcj = X_wc * X_cj

    v_wc = body_qd[child]

    w_c = wp.spatial_top(v_wc)
    v_c = wp.spatial_bottom(v_wc)

    # joint properties
    type = joint_type[tid]

    # position and orientation of both anchor frames
    x_p = wp.transform_get_translation(X_wpj)
    x_c = wp.transform_get_translation(X_wcj)

    q_p = wp.transform_get_rotation(X_wpj)
    q_c = wp.transform_get_rotation(X_wcj)

    # child-minus-parent errors (world space)
    x_err = x_c - x_p
    v_err = v_c - v_p
    w_err = w_c - w_p

    q_start = joint_q_start[tid]
    qd_start = joint_qd_start[tid]
    axis_start = joint_axis_start[tid]
    lin_axis_count = joint_axis_dim[tid, 0]
    ang_axis_count = joint_axis_dim[tid, 1]

    if type == wp.sim.JOINT_PRISMATIC:
        axis = joint_axis[axis_start]

        # world space joint axis
        axis_p = wp.quat_rotate(q_p, axis)

        # anchor offset and relative velocity projected onto the axis
        q = wp.dot(x_err, axis_p)
        qd = wp.dot(v_err, axis_p)

        joint_q[q_start] = q
        joint_qd[qd_start] = qd

        return

    if type == wp.sim.JOINT_REVOLUTE:
        axis = joint_axis[axis_start]
        q_pc = wp.quat_inverse(q_p) * q_c

        q, qd = reconstruct_angular_q_qd(q_pc, w_err, X_wpj, axis)

        joint_q[q_start] = q
        joint_qd[qd_start] = qd

        return

    if type == wp.sim.JOINT_BALL:
        # joint_q stores the four components of the relative anchor rotation
        q_pc = wp.quat_inverse(q_p) * q_c

        joint_q[q_start + 0] = q_pc[0]
        joint_q[q_start + 1] = q_pc[1]
        joint_q[q_start + 2] = q_pc[2]
        joint_q[q_start + 3] = q_pc[3]

        # relative angular velocity expressed in the parent anchor frame
        ang_vel = wp.transform_vector(wp.transform_inverse(X_wpj), w_err)
        joint_qd[qd_start + 0] = ang_vel[0]
        joint_qd[qd_start + 1] = ang_vel[1]
        joint_qd[qd_start + 2] = ang_vel[2]

        return

    if type == wp.sim.JOINT_FIXED:
        # no coordinates to recover
        return

    if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
        q_pc = wp.quat_inverse(q_p) * q_c

        # errors rotated into the parent anchor frame
        x_err_c = wp.quat_rotate_inv(q_p, x_err)
        v_err_c = wp.quat_rotate_inv(q_p, v_err)
        w_err_c = wp.quat_rotate_inv(q_p, w_err)

        # joint_q layout: [position (3), relative rotation (4)]
        joint_q[q_start + 0] = x_err_c[0]
        joint_q[q_start + 1] = x_err_c[1]
        joint_q[q_start + 2] = x_err_c[2]

        joint_q[q_start + 3] = q_pc[0]
        joint_q[q_start + 4] = q_pc[1]
        joint_q[q_start + 5] = q_pc[2]
        joint_q[q_start + 6] = q_pc[3]

        # joint_qd layout: [angular (3), linear (3)]
        joint_qd[qd_start + 0] = w_err_c[0]
        joint_qd[qd_start + 1] = w_err_c[1]
        joint_qd[qd_start + 2] = w_err_c[2]

        joint_qd[qd_start + 3] = v_err_c[0]
        joint_qd[qd_start + 4] = v_err_c[1]
        joint_qd[qd_start + 5] = v_err_c[2]

        return

    if type == wp.sim.JOINT_COMPOUND:
        axis_0 = joint_axis[axis_start + 0]
        axis_1 = joint_axis[axis_start + 1]
        axis_2 = joint_axis[axis_start + 2]
        qs, qds = invert_3d_rotational_dofs(axis_0, axis_1, axis_2, q_p, q_c, w_err)
        joint_q[q_start + 0] = qs[0]
        joint_q[q_start + 1] = qs[1]
        joint_q[q_start + 2] = qs[2]
        joint_qd[qd_start + 0] = qds[0]
        joint_qd[qd_start + 1] = qds[1]
        joint_qd[qd_start + 2] = qds[2]

        return

    if type == wp.sim.JOINT_UNIVERSAL:
        axis_0 = joint_axis[axis_start + 0]
        axis_1 = joint_axis[axis_start + 1]
        qs2, qds2 = invert_2d_rotational_dofs(axis_0, axis_1, q_p, q_c, w_err)
        joint_q[q_start + 0] = qs2[0]
        joint_q[q_start + 1] = qs2[1]
        joint_qd[qd_start + 0] = qds2[0]
        joint_qd[qd_start + 1] = qds2[1]

        return

    if type == wp.sim.JOINT_D6:
        # linear DOFs: project the parent-frame offset / velocity onto each linear axis
        x_err_c = wp.quat_rotate_inv(q_p, x_err)
        v_err_c = wp.quat_rotate_inv(q_p, v_err)
        if lin_axis_count > 0:
            axis = joint_axis[axis_start + 0]
            joint_q[q_start + 0] = wp.dot(x_err_c, axis)
            joint_qd[qd_start + 0] = wp.dot(v_err_c, axis)

        if lin_axis_count > 1:
            axis = joint_axis[axis_start + 1]
            joint_q[q_start + 1] = wp.dot(x_err_c, axis)
            joint_qd[qd_start + 1] = wp.dot(v_err_c, axis)

        if lin_axis_count > 2:
            axis = joint_axis[axis_start + 2]
            joint_q[q_start + 2] = wp.dot(x_err_c, axis)
            joint_qd[qd_start + 2] = wp.dot(v_err_c, axis)

        # angular DOFs: their axes and coordinates follow the linear ones
        if ang_axis_count == 1:
            # the angular axis sits right after the linear axes (the old code also read an unused,
            # wrongly indexed joint_axis[axis_start] here)
            axis = joint_axis[axis_start + lin_axis_count]
            q_pc = wp.quat_inverse(q_p) * q_c
            q, qd = reconstruct_angular_q_qd(q_pc, w_err, X_wpj, axis)
            joint_q[q_start + lin_axis_count] = q
            joint_qd[qd_start + lin_axis_count] = qd

        if ang_axis_count == 2:
            axis_0 = joint_axis[axis_start + lin_axis_count + 0]
            axis_1 = joint_axis[axis_start + lin_axis_count + 1]
            qs2, qds2 = invert_2d_rotational_dofs(axis_0, axis_1, q_p, q_c, w_err)
            joint_q[q_start + lin_axis_count + 0] = qs2[0]
            joint_q[q_start + lin_axis_count + 1] = qs2[1]
            joint_qd[qd_start + lin_axis_count + 0] = qds2[0]
            joint_qd[qd_start + lin_axis_count + 1] = qds2[1]

        if ang_axis_count == 3:
            axis_0 = joint_axis[axis_start + lin_axis_count + 0]
            axis_1 = joint_axis[axis_start + lin_axis_count + 1]
            axis_2 = joint_axis[axis_start + lin_axis_count + 2]
            qs3, qds3 = invert_3d_rotational_dofs(axis_0, axis_1, axis_2, q_p, q_c, w_err)
            joint_q[q_start + lin_axis_count + 0] = qs3[0]
            joint_q[q_start + lin_axis_count + 1] = qs3[1]
            joint_q[q_start + lin_axis_count + 2] = qs3[2]
            joint_qd[qd_start + lin_axis_count + 0] = qds3[0]
            joint_qd[qd_start + lin_axis_count + 1] = qds3[1]
            joint_qd[qd_start + lin_axis_count + 2] = qds3[2]

        return
760
+
761
+
762
# Host-side launcher: project maximal-coordinate body state back onto generalized joint coordinates.
def eval_ik(model, state, joint_q, joint_qd):
    """
    Run inverse kinematics for every joint of ``model``, deriving generalized joint coordinates
    from the body transforms and velocities stored in ``state`` (:attr:`State.body_q` and
    :attr:`State.body_qd`). Results are written into ``joint_q`` and ``joint_qd``.

    Args:
        model (Model): The model whose joints are evaluated.
        state (State): Source of the body maximal coordinates (positions :attr:`State.body_q` and velocities :attr:`State.body_qd`).
        joint_q (array): Output generalized joint position coordinates, shape [joint_coord_count], float
        joint_qd (array): Output generalized joint velocity coordinates, shape [joint_dof_count], float
    """
    ik_inputs = [
        state.body_q,
        state.body_qd,
        model.body_com,
        model.joint_type,
        model.joint_parent,
        model.joint_child,
        model.joint_X_p,
        model.joint_X_c,
        model.joint_axis,
        model.joint_axis_start,
        model.joint_axis_dim,
        model.joint_q_start,
        model.joint_qd_start,
    ]

    # one kernel thread per joint
    wp.launch(
        kernel=eval_articulation_ik,
        dim=model.joint_count,
        inputs=ik_inputs,
        outputs=[joint_q, joint_qd],
        device=model.device,
    )