warp-lang 1.9.0__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
@@ -1,2101 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import warp as wp
17
-
18
- from .articulation import (
19
- compute_2d_rotational_dofs,
20
- compute_3d_rotational_dofs,
21
- eval_fk,
22
- )
23
- from .integrator import Integrator
24
- from .integrator_euler import (
25
- eval_bending_forces,
26
- eval_joint_force,
27
- eval_muscle_forces,
28
- eval_particle_body_contact_forces,
29
- eval_particle_forces,
30
- eval_particle_ground_contact_forces,
31
- eval_rigid_contacts,
32
- eval_spring_forces,
33
- eval_tetrahedral_forces,
34
- eval_triangle_contact_forces,
35
- eval_triangle_forces,
36
- )
37
- from .model import Control, Model, State
38
-
39
-
# Frank & Park definition 3.20, pg 100
@wp.func
def transform_twist(t: wp.transform, x: wp.spatial_vector):
    """Apply the adjoint of the rigid transform ``t`` to the twist ``x``.

    The twist stores its angular part ``w`` in the top three components and
    its linear part ``v`` in the bottom three. With ``R`` the rotation and
    ``p`` the translation of ``t``, the result is::

        (R w, R v + p x (R w))

    Args:
        t: Rigid transform whose adjoint is applied (presumably mapping the
            twist's source frame into the target frame -- confirm at call sites).
        x: Twist laid out as ``(angular, linear)``.

    Returns:
        The transformed twist, with the same ``(angular, linear)`` layout.
    """
    rot = wp.transform_get_rotation(t)
    trans = wp.transform_get_translation(t)

    # rotate the angular half first: the linear half's lever-arm term uses the rotated value
    ang = wp.quat_rotate(rot, wp.spatial_top(x))
    lin = wp.quat_rotate(rot, wp.spatial_bottom(x)) + wp.cross(trans, ang)

    return wp.spatial_vector(ang, lin)
53
-
54
-
@wp.func
def transform_wrench(t: wp.transform, x: wp.spatial_vector):
    """Change the frame of the spatial wrench ``x`` under the rigid transform ``t``.

    The wrench stores its angular (moment) half ``w`` in the top three
    components and its linear (force) half ``v`` in the bottom three. With
    ``R`` the rotation and ``p`` the translation of ``t``, the result is::

        (R w + p x (R v), R v)

    Args:
        t: Rigid transform defining the change of frame.
        x: Wrench laid out as ``(moment, force)``.

    Returns:
        The transformed wrench, with the same ``(moment, force)`` layout.
    """
    rot = wp.transform_get_rotation(t)
    trans = wp.transform_get_translation(t)

    # rotate the force half first: the moment half picks up a lever-arm term from it
    force = wp.quat_rotate(rot, wp.spatial_bottom(x))
    moment = wp.quat_rotate(rot, wp.spatial_top(x)) + wp.cross(trans, force)

    return wp.spatial_vector(moment, force)
67
-
68
-
@wp.func
def spatial_adjoint(R: wp.mat33, S: wp.mat33):
    """Assemble the 6x6 block matrix ``[[R, 0], [S, R]]``.

    ``spatial_transform_inertia`` calls this with a rotation ``R`` and
    ``S = skew(p) @ R`` to build the adjoint matrix of a rigid transform.

    Args:
        R: 3x3 block placed on both the upper-left and lower-right diagonal.
        S: 3x3 block placed lower-left; the upper-right block is zero.

    Returns:
        The assembled ``wp.spatial_matrix``.
    """
    z = 0.0
    # fmt: off
    return wp.spatial_matrix(
        R[0, 0], R[0, 1], R[0, 2], z,       z,       z,
        R[1, 0], R[1, 1], R[1, 2], z,       z,       z,
        R[2, 0], R[2, 1], R[2, 2], z,       z,       z,
        S[0, 0], S[0, 1], S[0, 2], R[0, 0], R[0, 1], R[0, 2],
        S[1, 0], S[1, 1], S[1, 2], R[1, 0], R[1, 1], R[1, 2],
        S[2, 0], S[2, 1], S[2, 2], R[2, 0], R[2, 1], R[2, 2],
    )
    # fmt: on
84
-
85
-
@wp.kernel
def compute_spatial_inertia(
    body_inertia: wp.array(dtype=wp.mat33),
    body_mass: wp.array(dtype=float),
    # outputs
    body_I_m: wp.array(dtype=wp.spatial_matrix),
):
    """Build a 6x6 spatial inertia per body.

    Thread ``i`` writes ``body_I_m[i] = [[body_inertia[i], 0], [0, body_mass[i] * I3]]``:
    the 3x3 inertia in the upper-left block and the mass on the lower-right
    diagonal. Presumably launched with one thread per body -- confirm at the
    launch site.
    """
    body = wp.tid()
    J = body_inertia[body]
    mass = body_mass[body]
    z = 0.0
    # fmt: off
    spatial_I = wp.spatial_matrix(
        J[0, 0], J[0, 1], J[0, 2], z,    z,    z,
        J[1, 0], J[1, 1], J[1, 2], z,    z,    z,
        J[2, 0], J[2, 1], J[2, 2], z,    z,    z,
        z,       z,       z,       mass, z,    z,
        z,       z,       z,       z,    mass, z,
        z,       z,       z,       z,    z,    mass,
    )
    # fmt: on
    body_I_m[body] = spatial_I
106
-
107
-
@wp.kernel
def compute_com_transforms(
    body_com: wp.array(dtype=wp.vec3),
    # outputs
    body_X_com: wp.array(dtype=wp.transform),
):
    """Convert each body's center-of-mass offset into a pure-translation transform.

    Thread ``i`` writes ``body_X_com[i]`` as a transform with translation
    ``body_com[i]`` and identity rotation.
    """
    i = wp.tid()
    body_X_com[i] = wp.transform(body_com[i], wp.quat_identity())
117
-
118
-
# computes adj_t^-T*I*adj_t^-1 (tensor change of coordinates), Frank & Park, section 8.2.3, pg 290
@wp.func
def spatial_transform_inertia(t: wp.transform, I: wp.spatial_matrix):
    """Change the coordinates of a 6x6 spatial inertia under the transform ``t``.

    With ``R`` and ``p`` the rotation matrix and translation of ``inverse(t)``,
    builds ``T = [[R, 0], [skew(p) @ R, R]]`` (the adjoint of ``inverse(t)``)
    and returns ``T^T @ I @ T``.

    Args:
        t: Rigid transform defining the change of coordinates.
        I: Spatial inertia to transform.

    Returns:
        The transformed spatial inertia.
    """
    t_inv = wp.transform_inverse(t)
    rot = wp.transform_get_rotation(t_inv)

    # the columns of the rotation matrix are the rotated unit basis vectors
    R = wp.matrix_from_cols(
        wp.quat_rotate(rot, wp.vec3(1.0, 0.0, 0.0)),
        wp.quat_rotate(rot, wp.vec3(0.0, 1.0, 0.0)),
        wp.quat_rotate(rot, wp.vec3(0.0, 0.0, 1.0)),
    )
    S = wp.skew(wp.transform_get_translation(t_inv)) @ R

    adj = spatial_adjoint(R, S)
    return wp.transpose(adj) @ I @ adj
137
-
138
-
# compute transform across a joint
@wp.func
def jcalc_transform(
    type: int,
    joint_axis: wp.array(dtype=wp.vec3),
    axis_start: int,
    lin_axis_count: int,
    ang_axis_count: int,
    joint_q: wp.array(dtype=float),
    start: int,
):
    """Compute the joint transform ``X_jc`` of one joint from its coordinates.

    Args:
        type: Joint type, compared against the ``wp.sim.JOINT_*`` constants.
        joint_axis: Joint axis directions; this joint's axes begin at ``axis_start``.
        axis_start: Index of this joint's first axis in ``joint_axis``.
        lin_axis_count: Number of linear axes (only read by the ``JOINT_D6`` branch).
        ang_axis_count: Number of angular axes (only read by the ``JOINT_D6`` branch).
        joint_q: Joint coordinates; this joint's coordinates begin at ``start``.
        start: Index of this joint's first coordinate in ``joint_q``.

    Returns:
        The joint transform. Coordinates read per joint type:

        - ``JOINT_PRISMATIC``: 1 coordinate, translation along the axis.
        - ``JOINT_REVOLUTE``: 1 coordinate, rotation angle about the axis.
        - ``JOINT_BALL``: 4 coordinates, quaternion ``(x, y, z, w)``.
        - ``JOINT_FIXED``: none; identity transform.
        - ``JOINT_FREE`` / ``JOINT_DISTANCE``: 7 coordinates, position
          ``(px, py, pz)`` followed by quaternion ``(x, y, z, w)``.
        - ``JOINT_COMPOUND``: 3 angles about 3 axes.
        - ``JOINT_UNIVERSAL``: 2 angles about 2 axes.
        - ``JOINT_D6``: ``lin_axis_count`` translations followed by
          ``ang_axis_count`` angles.

        Any other joint type yields the identity transform.
    """
    if type == wp.sim.JOINT_PRISMATIC:
        q = joint_q[start]
        axis = joint_axis[axis_start]
        X_jc = wp.transform(axis * q, wp.quat_identity())
        return X_jc

    if type == wp.sim.JOINT_REVOLUTE:
        q = joint_q[start]
        axis = joint_axis[axis_start]
        X_jc = wp.transform(wp.vec3(), wp.quat_from_axis_angle(axis, q))
        return X_jc

    if type == wp.sim.JOINT_BALL:
        # the 4 coordinates are read as quaternion components in (x, y, z, w) order
        qx = joint_q[start + 0]
        qy = joint_q[start + 1]
        qz = joint_q[start + 2]
        qw = joint_q[start + 3]

        X_jc = wp.transform(wp.vec3(), wp.quat(qx, qy, qz, qw))
        return X_jc

    if type == wp.sim.JOINT_FIXED:
        X_jc = wp.transform_identity()
        return X_jc

    if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
        # position first ...
        px = joint_q[start + 0]
        py = joint_q[start + 1]
        pz = joint_q[start + 2]

        # ... then the orientation quaternion in (x, y, z, w) order
        qx = joint_q[start + 3]
        qy = joint_q[start + 4]
        qz = joint_q[start + 5]
        qw = joint_q[start + 6]

        X_jc = wp.transform(wp.vec3(px, py, pz), wp.quat(qx, qy, qz, qw))
        return X_jc

    if type == wp.sim.JOINT_COMPOUND:
        # the trailing 0.0 arguments presumably fill velocity inputs whose
        # result (`_`) is discarded here -- confirm against compute_3d_rotational_dofs
        rot, _ = compute_3d_rotational_dofs(
            joint_axis[axis_start],
            joint_axis[axis_start + 1],
            joint_axis[axis_start + 2],
            joint_q[start + 0],
            joint_q[start + 1],
            joint_q[start + 2],
            0.0,
            0.0,
            0.0,
        )

        X_jc = wp.transform(wp.vec3(), rot)
        return X_jc

    if type == wp.sim.JOINT_UNIVERSAL:
        # as above: trailing zeros presumably fill unused velocity inputs
        rot, _ = compute_2d_rotational_dofs(
            joint_axis[axis_start],
            joint_axis[axis_start + 1],
            joint_q[start + 0],
            joint_q[start + 1],
            0.0,
            0.0,
        )

        X_jc = wp.transform(wp.vec3(), rot)
        return X_jc

    if type == wp.sim.JOINT_D6:
        pos = wp.vec3(0.0)
        rot = wp.quat_identity()

        # unroll for loop to ensure joint actions remain differentiable
        # (since differentiating through a for loop that updates a local variable is not supported)

        # linear axes: translations along each axis accumulate into pos
        if lin_axis_count > 0:
            axis = joint_axis[axis_start + 0]
            pos += axis * joint_q[start + 0]
        if lin_axis_count > 1:
            axis = joint_axis[axis_start + 1]
            pos += axis * joint_q[start + 1]
        if lin_axis_count > 2:
            axis = joint_axis[axis_start + 2]
            pos += axis * joint_q[start + 2]

        # angular axes and coordinates follow the linear ones
        ia = axis_start + lin_axis_count
        iq = start + lin_axis_count
        if ang_axis_count == 1:
            axis = joint_axis[ia]
            rot = wp.quat_from_axis_angle(axis, joint_q[iq])
        if ang_axis_count == 2:
            rot, _ = compute_2d_rotational_dofs(
                joint_axis[ia + 0],
                joint_axis[ia + 1],
                joint_q[iq + 0],
                joint_q[iq + 1],
                0.0,
                0.0,
            )
        if ang_axis_count == 3:
            rot, _ = compute_3d_rotational_dofs(
                joint_axis[ia + 0],
                joint_axis[ia + 1],
                joint_axis[ia + 2],
                joint_q[iq + 0],
                joint_q[iq + 1],
                joint_q[iq + 2],
                0.0,
                0.0,
                0.0,
            )

        X_jc = wp.transform(pos, rot)
        return X_jc

    # default case
    return wp.transform_identity()
266
-
267
-
# compute motion subspace and velocity for a joint
@wp.func
def jcalc_motion(
    type: int,
    joint_axis: wp.array(dtype=wp.vec3),
    axis_start: int,
    lin_axis_count: int,
    ang_axis_count: int,
    X_sc: wp.transform,
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    q_start: int,
    qd_start: int,
    # outputs
    joint_S_s: wp.array(dtype=wp.spatial_vector),
):
    """Compute a joint's motion-subspace columns and its spatial velocity.

    For each degree of freedom ``i`` of the joint, a unit twist is mapped
    through ``transform_twist(X_sc, ...)`` and stored in
    ``joint_S_s[qd_start + i]``; the return value is
    ``sum_i joint_S_s[qd_start + i] * joint_qd[qd_start + i]``. Twists are laid
    out as ``(angular, linear)``: rotational axes go in the top half,
    translational axes in the bottom half.

    DOF counts per type: PRISMATIC/REVOLUTE 1, UNIVERSAL 2, COMPOUND 3,
    BALL 3, FREE/DISTANCE 6, D6 ``lin_axis_count + ang_axis_count``, FIXED 0.

    Args:
        type: Joint type, compared against the ``wp.sim.JOINT_*`` constants.
        joint_axis: Joint axis directions; this joint's axes begin at ``axis_start``.
        axis_start: Index of this joint's first axis in ``joint_axis``.
        lin_axis_count: Number of linear axes (only read by the ``JOINT_D6`` branch).
        ang_axis_count: Number of angular axes (only read by the ``JOINT_D6`` branch).
        X_sc: Transform applied to every subspace column via ``transform_twist``
            (presumably the joint's child frame in spatial coordinates -- confirm
            at the call site).
        joint_q: Joint coordinates; only the UNIVERSAL and COMPOUND branches
            read them, to rotate later axes by earlier joint angles.
        joint_qd: Joint velocities; this joint's begin at ``qd_start``.
        q_start: Index of this joint's first coordinate in ``joint_q``.
        qd_start: Index of this joint's first velocity in ``joint_qd`` and of
            its first column in ``joint_S_s``.
        joint_S_s: Output motion-subspace columns, one spatial vector per DOF.

    Returns:
        The joint's spatial velocity. ``JOINT_FIXED`` returns zero without
        writing ``joint_S_s``; an unrecognized type prints a message and
        returns zero.
    """
    if type == wp.sim.JOINT_PRISMATIC:
        axis = joint_axis[axis_start]
        S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis))
        v_j_s = S_s * joint_qd[qd_start]
        joint_S_s[qd_start] = S_s
        return v_j_s

    if type == wp.sim.JOINT_REVOLUTE:
        axis = joint_axis[axis_start]
        S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3()))
        v_j_s = S_s * joint_qd[qd_start]
        joint_S_s[qd_start] = S_s
        return v_j_s

    if type == wp.sim.JOINT_UNIVERSAL:
        axis_0 = joint_axis[axis_start + 0]
        axis_1 = joint_axis[axis_start + 1]
        # q_off orients a frame whose first two columns are the joint axes; rotating
        # unit x/y by it recovers the axes. NOTE(review): presumably meant to
        # orthonormalize user axes -- for non-orthogonal input the result depends
        # on wp.quat_from_matrix; confirm.
        q_off = wp.quat_from_matrix(wp.matrix_from_cols(axis_0, axis_1, wp.cross(axis_0, axis_1)))
        local_0 = wp.quat_rotate(q_off, wp.vec3(1.0, 0.0, 0.0))
        local_1 = wp.quat_rotate(q_off, wp.vec3(0.0, 1.0, 0.0))

        axis_0 = local_0
        q_0 = wp.quat_from_axis_angle(axis_0, joint_q[q_start + 0])

        # the second axis is carried along by the first joint angle
        axis_1 = wp.quat_rotate(q_0, local_1)

        S_0 = transform_twist(X_sc, wp.spatial_vector(axis_0, wp.vec3()))
        S_1 = transform_twist(X_sc, wp.spatial_vector(axis_1, wp.vec3()))

        joint_S_s[qd_start + 0] = S_0
        joint_S_s[qd_start + 1] = S_1

        return S_0 * joint_qd[qd_start + 0] + S_1 * joint_qd[qd_start + 1]

    if type == wp.sim.JOINT_COMPOUND:
        axis_0 = joint_axis[axis_start + 0]
        axis_1 = joint_axis[axis_start + 1]
        axis_2 = joint_axis[axis_start + 2]
        q_off = wp.quat_from_matrix(wp.matrix_from_cols(axis_0, axis_1, axis_2))
        local_0 = wp.quat_rotate(q_off, wp.vec3(1.0, 0.0, 0.0))
        local_1 = wp.quat_rotate(q_off, wp.vec3(0.0, 1.0, 0.0))
        local_2 = wp.quat_rotate(q_off, wp.vec3(0.0, 0.0, 1.0))

        axis_0 = local_0
        q_0 = wp.quat_from_axis_angle(axis_0, joint_q[q_start + 0])

        # axis_1 is rotated by the first angle ...
        axis_1 = wp.quat_rotate(q_0, local_1)
        q_1 = wp.quat_from_axis_angle(axis_1, joint_q[q_start + 1])

        # ... and axis_2 by the composition of the first two
        axis_2 = wp.quat_rotate(q_1 * q_0, local_2)

        S_0 = transform_twist(X_sc, wp.spatial_vector(axis_0, wp.vec3()))
        S_1 = transform_twist(X_sc, wp.spatial_vector(axis_1, wp.vec3()))
        S_2 = transform_twist(X_sc, wp.spatial_vector(axis_2, wp.vec3()))

        joint_S_s[qd_start + 0] = S_0
        joint_S_s[qd_start + 1] = S_1
        joint_S_s[qd_start + 2] = S_2

        return S_0 * joint_qd[qd_start + 0] + S_1 * joint_qd[qd_start + 1] + S_2 * joint_qd[qd_start + 2]

    if type == wp.sim.JOINT_D6:
        # unrolled per axis (like jcalc_transform) rather than looped
        v_j_s = wp.spatial_vector()
        if lin_axis_count > 0:
            axis = joint_axis[axis_start + 0]
            S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis))
            v_j_s += S_s * joint_qd[qd_start + 0]
            joint_S_s[qd_start + 0] = S_s
        if lin_axis_count > 1:
            axis = joint_axis[axis_start + 1]
            S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis))
            v_j_s += S_s * joint_qd[qd_start + 1]
            joint_S_s[qd_start + 1] = S_s
        if lin_axis_count > 2:
            axis = joint_axis[axis_start + 2]
            S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis))
            v_j_s += S_s * joint_qd[qd_start + 2]
            joint_S_s[qd_start + 2] = S_s
        # NOTE(review): unlike the UNIVERSAL/COMPOUND branches above, these angular
        # axes are not rotated by the preceding joint angles, while jcalc_transform
        # composes D6 rotations via compute_2d/3d_rotational_dofs -- confirm intended.
        if ang_axis_count > 0:
            axis = joint_axis[axis_start + lin_axis_count + 0]
            S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3()))
            v_j_s += S_s * joint_qd[qd_start + lin_axis_count + 0]
            joint_S_s[qd_start + lin_axis_count + 0] = S_s
        if ang_axis_count > 1:
            axis = joint_axis[axis_start + lin_axis_count + 1]
            S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3()))
            v_j_s += S_s * joint_qd[qd_start + lin_axis_count + 1]
            joint_S_s[qd_start + lin_axis_count + 1] = S_s
        if ang_axis_count > 2:
            axis = joint_axis[axis_start + lin_axis_count + 2]
            S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3()))
            v_j_s += S_s * joint_qd[qd_start + lin_axis_count + 2]
            joint_S_s[qd_start + lin_axis_count + 2] = S_s

        return v_j_s

    if type == wp.sim.JOINT_BALL:
        # three rotational DOFs about the unit x, y, z axes
        S_0 = transform_twist(X_sc, wp.spatial_vector(1.0, 0.0, 0.0, 0.0, 0.0, 0.0))
        S_1 = transform_twist(X_sc, wp.spatial_vector(0.0, 1.0, 0.0, 0.0, 0.0, 0.0))
        S_2 = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 1.0, 0.0, 0.0, 0.0))

        joint_S_s[qd_start + 0] = S_0
        joint_S_s[qd_start + 1] = S_1
        joint_S_s[qd_start + 2] = S_2

        return S_0 * joint_qd[qd_start + 0] + S_1 * joint_qd[qd_start + 1] + S_2 * joint_qd[qd_start + 2]

    if type == wp.sim.JOINT_FIXED:
        return wp.spatial_vector()

    if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
        # the six velocities are read directly as an (angular, linear) twist
        v_j_s = transform_twist(
            X_sc,
            wp.spatial_vector(
                joint_qd[qd_start + 0],
                joint_qd[qd_start + 1],
                joint_qd[qd_start + 2],
                joint_qd[qd_start + 3],
                joint_qd[qd_start + 4],
                joint_qd[qd_start + 5],
            ),
        )

        # subspace columns: unit angular axes, then unit linear axes
        joint_S_s[qd_start + 0] = transform_twist(X_sc, wp.spatial_vector(1.0, 0.0, 0.0, 0.0, 0.0, 0.0))
        joint_S_s[qd_start + 1] = transform_twist(X_sc, wp.spatial_vector(0.0, 1.0, 0.0, 0.0, 0.0, 0.0))
        joint_S_s[qd_start + 2] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 1.0, 0.0, 0.0, 0.0))
        joint_S_s[qd_start + 3] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 1.0, 0.0, 0.0))
        joint_S_s[qd_start + 4] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 1.0, 0.0))
        joint_S_s[qd_start + 5] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 0.0, 1.0))

        return v_j_s

    wp.printf("jcalc_motion not implemented for joint type %d\n", type)

    # default case
    return wp.spatial_vector()
420
-
421
-
422
- # computes joint space forces/torques in tau
423
- @wp.func
424
- def jcalc_tau(
425
- type: int,
426
- joint_target_ke: wp.array(dtype=float),
427
- joint_target_kd: wp.array(dtype=float),
428
- joint_limit_ke: wp.array(dtype=float),
429
- joint_limit_kd: wp.array(dtype=float),
430
- joint_S_s: wp.array(dtype=wp.spatial_vector),
431
- joint_q: wp.array(dtype=float),
432
- joint_qd: wp.array(dtype=float),
433
- joint_act: wp.array(dtype=float),
434
- joint_axis_mode: wp.array(dtype=int),
435
- joint_limit_lower: wp.array(dtype=float),
436
- joint_limit_upper: wp.array(dtype=float),
437
- coord_start: int,
438
- dof_start: int,
439
- axis_start: int,
440
- lin_axis_count: int,
441
- ang_axis_count: int,
442
- body_f_s: wp.spatial_vector,
443
- # outputs
444
- tau: wp.array(dtype=float),
445
- ):
446
- if type == wp.sim.JOINT_PRISMATIC or type == wp.sim.JOINT_REVOLUTE:
447
- S_s = joint_S_s[dof_start]
448
-
449
- q = joint_q[coord_start]
450
- qd = joint_qd[dof_start]
451
- act = joint_act[axis_start]
452
-
453
- lower = joint_limit_lower[axis_start]
454
- upper = joint_limit_upper[axis_start]
455
-
456
- limit_ke = joint_limit_ke[axis_start]
457
- limit_kd = joint_limit_kd[axis_start]
458
- target_ke = joint_target_ke[axis_start]
459
- target_kd = joint_target_kd[axis_start]
460
- mode = joint_axis_mode[axis_start]
461
-
462
- # total torque / force on the joint
463
- t = -wp.dot(S_s, body_f_s) + eval_joint_force(
464
- q, qd, act, target_ke, target_kd, lower, upper, limit_ke, limit_kd, mode
465
- )
466
-
467
- tau[dof_start] = t
468
-
469
- return
470
-
471
- if type == wp.sim.JOINT_BALL:
472
- # target_ke = joint_target_ke[axis_start]
473
- # target_kd = joint_target_kd[axis_start]
474
-
475
- for i in range(3):
476
- S_s = joint_S_s[dof_start + i]
477
-
478
- # w = joint_qd[dof_start + i]
479
- # r = joint_q[coord_start + i]
480
-
481
- tau[dof_start + i] = -wp.dot(S_s, body_f_s) # - w * target_kd - r * target_ke
482
-
483
- return
484
-
485
- if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
486
- for i in range(6):
487
- S_s = joint_S_s[dof_start + i]
488
- tau[dof_start + i] = -wp.dot(S_s, body_f_s)
489
-
490
- return
491
-
492
- if type == wp.sim.JOINT_COMPOUND or type == wp.sim.JOINT_UNIVERSAL or type == wp.sim.JOINT_D6:
493
- axis_count = lin_axis_count + ang_axis_count
494
-
495
- for i in range(axis_count):
496
- S_s = joint_S_s[dof_start + i]
497
-
498
- q = joint_q[coord_start + i]
499
- qd = joint_qd[dof_start + i]
500
- act = joint_act[axis_start + i]
501
-
502
- lower = joint_limit_lower[axis_start + i]
503
- upper = joint_limit_upper[axis_start + i]
504
- limit_ke = joint_limit_ke[axis_start + i]
505
- limit_kd = joint_limit_kd[axis_start + i]
506
- target_ke = joint_target_ke[axis_start + i]
507
- target_kd = joint_target_kd[axis_start + i]
508
- mode = joint_axis_mode[axis_start + i]
509
-
510
- f = eval_joint_force(q, qd, act, target_ke, target_kd, lower, upper, limit_ke, limit_kd, mode)
511
-
512
- # total torque / force on the joint
513
- t = -wp.dot(S_s, body_f_s) + f
514
-
515
- tau[dof_start + i] = t
516
-
517
- return
518
-
519
-
520
- @wp.func
521
- def jcalc_integrate(
522
- type: int,
523
- joint_q: wp.array(dtype=float),
524
- joint_qd: wp.array(dtype=float),
525
- joint_qdd: wp.array(dtype=float),
526
- coord_start: int,
527
- dof_start: int,
528
- lin_axis_count: int,
529
- ang_axis_count: int,
530
- dt: float,
531
- # outputs
532
- joint_q_new: wp.array(dtype=float),
533
- joint_qd_new: wp.array(dtype=float),
534
- ):
535
- if type == wp.sim.JOINT_FIXED:
536
- return
537
-
538
- # prismatic / revolute
539
- if type == wp.sim.JOINT_PRISMATIC or type == wp.sim.JOINT_REVOLUTE:
540
- qdd = joint_qdd[dof_start]
541
- qd = joint_qd[dof_start]
542
- q = joint_q[coord_start]
543
-
544
- qd_new = qd + qdd * dt
545
- q_new = q + qd_new * dt
546
-
547
- joint_qd_new[dof_start] = qd_new
548
- joint_q_new[coord_start] = q_new
549
-
550
- return
551
-
552
- # ball
553
- if type == wp.sim.JOINT_BALL:
554
- m_j = wp.vec3(joint_qdd[dof_start + 0], joint_qdd[dof_start + 1], joint_qdd[dof_start + 2])
555
- w_j = wp.vec3(joint_qd[dof_start + 0], joint_qd[dof_start + 1], joint_qd[dof_start + 2])
556
-
557
- r_j = wp.quat(
558
- joint_q[coord_start + 0], joint_q[coord_start + 1], joint_q[coord_start + 2], joint_q[coord_start + 3]
559
- )
560
-
561
- # symplectic Euler
562
- w_j_new = w_j + m_j * dt
563
-
564
- drdt_j = wp.quat(w_j_new, 0.0) * r_j * 0.5
565
-
566
- # new orientation (normalized)
567
- r_j_new = wp.normalize(r_j + drdt_j * dt)
568
-
569
- # update joint coords
570
- joint_q_new[coord_start + 0] = r_j_new[0]
571
- joint_q_new[coord_start + 1] = r_j_new[1]
572
- joint_q_new[coord_start + 2] = r_j_new[2]
573
- joint_q_new[coord_start + 3] = r_j_new[3]
574
-
575
- # update joint vel
576
- joint_qd_new[dof_start + 0] = w_j_new[0]
577
- joint_qd_new[dof_start + 1] = w_j_new[1]
578
- joint_qd_new[dof_start + 2] = w_j_new[2]
579
-
580
- return
581
-
582
- # free joint
583
- if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
584
- # dofs: qd = (omega_x, omega_y, omega_z, vel_x, vel_y, vel_z)
585
- # coords: q = (trans_x, trans_y, trans_z, quat_x, quat_y, quat_z, quat_w)
586
-
587
- # angular and linear acceleration
588
- m_s = wp.vec3(joint_qdd[dof_start + 0], joint_qdd[dof_start + 1], joint_qdd[dof_start + 2])
589
- a_s = wp.vec3(joint_qdd[dof_start + 3], joint_qdd[dof_start + 4], joint_qdd[dof_start + 5])
590
-
591
- # angular and linear velocity
592
- w_s = wp.vec3(joint_qd[dof_start + 0], joint_qd[dof_start + 1], joint_qd[dof_start + 2])
593
- v_s = wp.vec3(joint_qd[dof_start + 3], joint_qd[dof_start + 4], joint_qd[dof_start + 5])
594
-
595
- # symplectic Euler
596
- w_s = w_s + m_s * dt
597
- v_s = v_s + a_s * dt
598
-
599
- # translation of origin
600
- p_s = wp.vec3(joint_q[coord_start + 0], joint_q[coord_start + 1], joint_q[coord_start + 2])
601
-
602
- # linear vel of origin (note q/qd switch order of linear angular elements)
603
- # note we are converting the body twist in the space frame (w_s, v_s) to compute center of mass velocity
604
- dpdt_s = v_s + wp.cross(w_s, p_s)
605
-
606
- # quat and quat derivative
607
- r_s = wp.quat(
608
- joint_q[coord_start + 3], joint_q[coord_start + 4], joint_q[coord_start + 5], joint_q[coord_start + 6]
609
- )
610
-
611
- drdt_s = wp.quat(w_s, 0.0) * r_s * 0.5
612
-
613
- # new orientation (normalized)
614
- p_s_new = p_s + dpdt_s * dt
615
- r_s_new = wp.normalize(r_s + drdt_s * dt)
616
-
617
- # update transform
618
- joint_q_new[coord_start + 0] = p_s_new[0]
619
- joint_q_new[coord_start + 1] = p_s_new[1]
620
- joint_q_new[coord_start + 2] = p_s_new[2]
621
-
622
- joint_q_new[coord_start + 3] = r_s_new[0]
623
- joint_q_new[coord_start + 4] = r_s_new[1]
624
- joint_q_new[coord_start + 5] = r_s_new[2]
625
- joint_q_new[coord_start + 6] = r_s_new[3]
626
-
627
- # update joint_twist
628
- joint_qd_new[dof_start + 0] = w_s[0]
629
- joint_qd_new[dof_start + 1] = w_s[1]
630
- joint_qd_new[dof_start + 2] = w_s[2]
631
- joint_qd_new[dof_start + 3] = v_s[0]
632
- joint_qd_new[dof_start + 4] = v_s[1]
633
- joint_qd_new[dof_start + 5] = v_s[2]
634
-
635
- return
636
-
637
- # other joint types (compound, universal, D6)
638
- if type == wp.sim.JOINT_COMPOUND or type == wp.sim.JOINT_UNIVERSAL or type == wp.sim.JOINT_D6:
639
- axis_count = lin_axis_count + ang_axis_count
640
-
641
- for i in range(axis_count):
642
- qdd = joint_qdd[dof_start + i]
643
- qd = joint_qd[dof_start + i]
644
- q = joint_q[coord_start + i]
645
-
646
- qd_new = qd + qdd * dt
647
- q_new = q + qd_new * dt
648
-
649
- joint_qd_new[dof_start + i] = qd_new
650
- joint_q_new[coord_start + i] = q_new
651
-
652
- return
653
-
654
-
655
- @wp.func
656
- def compute_link_transform(
657
- i: int,
658
- joint_type: wp.array(dtype=int),
659
- joint_parent: wp.array(dtype=int),
660
- joint_child: wp.array(dtype=int),
661
- joint_q_start: wp.array(dtype=int),
662
- joint_q: wp.array(dtype=float),
663
- joint_X_p: wp.array(dtype=wp.transform),
664
- joint_X_c: wp.array(dtype=wp.transform),
665
- body_X_com: wp.array(dtype=wp.transform),
666
- joint_axis: wp.array(dtype=wp.vec3),
667
- joint_axis_start: wp.array(dtype=int),
668
- joint_axis_dim: wp.array(dtype=int, ndim=2),
669
- # outputs
670
- body_q: wp.array(dtype=wp.transform),
671
- body_q_com: wp.array(dtype=wp.transform),
672
- ):
673
- # parent transform
674
- parent = joint_parent[i]
675
- child = joint_child[i]
676
-
677
- # parent transform in spatial coordinates
678
- X_pj = joint_X_p[i]
679
- X_cj = joint_X_c[i]
680
- # parent anchor frame in world space
681
- X_wpj = X_pj
682
- if parent >= 0:
683
- X_wp = body_q[parent]
684
- X_wpj = X_wp * X_wpj
685
-
686
- type = joint_type[i]
687
- axis_start = joint_axis_start[i]
688
- lin_axis_count = joint_axis_dim[i, 0]
689
- ang_axis_count = joint_axis_dim[i, 1]
690
- coord_start = joint_q_start[i]
691
-
692
- # compute transform across joint
693
- X_j = jcalc_transform(type, joint_axis, axis_start, lin_axis_count, ang_axis_count, joint_q, coord_start)
694
-
695
- # transform from world to joint anchor frame at child body
696
- X_wcj = X_wpj * X_j
697
- # transform from world to child body frame
698
- X_wc = X_wcj * wp.transform_inverse(X_cj)
699
-
700
- # compute transform of center of mass
701
- X_cm = body_X_com[child]
702
- X_sm = X_wc * X_cm
703
-
704
- # store geometry transforms
705
- body_q[child] = X_wc
706
- body_q_com[child] = X_sm
707
-
708
-
709
- @wp.kernel
710
- def eval_rigid_fk(
711
- articulation_start: wp.array(dtype=int),
712
- joint_type: wp.array(dtype=int),
713
- joint_parent: wp.array(dtype=int),
714
- joint_child: wp.array(dtype=int),
715
- joint_q_start: wp.array(dtype=int),
716
- joint_q: wp.array(dtype=float),
717
- joint_X_p: wp.array(dtype=wp.transform),
718
- joint_X_c: wp.array(dtype=wp.transform),
719
- body_X_com: wp.array(dtype=wp.transform),
720
- joint_axis: wp.array(dtype=wp.vec3),
721
- joint_axis_start: wp.array(dtype=int),
722
- joint_axis_dim: wp.array(dtype=int, ndim=2),
723
- # outputs
724
- body_q: wp.array(dtype=wp.transform),
725
- body_q_com: wp.array(dtype=wp.transform),
726
- ):
727
- # one thread per-articulation
728
- index = wp.tid()
729
-
730
- start = articulation_start[index]
731
- end = articulation_start[index + 1]
732
-
733
- for i in range(start, end):
734
- compute_link_transform(
735
- i,
736
- joint_type,
737
- joint_parent,
738
- joint_child,
739
- joint_q_start,
740
- joint_q,
741
- joint_X_p,
742
- joint_X_c,
743
- body_X_com,
744
- joint_axis,
745
- joint_axis_start,
746
- joint_axis_dim,
747
- body_q,
748
- body_q_com,
749
- )
750
-
751
-
752
- @wp.func
753
- def spatial_cross(a: wp.spatial_vector, b: wp.spatial_vector):
754
- w_a = wp.spatial_top(a)
755
- v_a = wp.spatial_bottom(a)
756
-
757
- w_b = wp.spatial_top(b)
758
- v_b = wp.spatial_bottom(b)
759
-
760
- w = wp.cross(w_a, w_b)
761
- v = wp.cross(w_a, v_b) + wp.cross(v_a, w_b)
762
-
763
- return wp.spatial_vector(w, v)
764
-
765
-
766
- @wp.func
767
- def spatial_cross_dual(a: wp.spatial_vector, b: wp.spatial_vector):
768
- w_a = wp.spatial_top(a)
769
- v_a = wp.spatial_bottom(a)
770
-
771
- w_b = wp.spatial_top(b)
772
- v_b = wp.spatial_bottom(b)
773
-
774
- w = wp.cross(w_a, w_b) + wp.cross(v_a, v_b)
775
- v = wp.cross(w_a, v_b)
776
-
777
- return wp.spatial_vector(w, v)
778
-
779
-
780
- @wp.func
781
- def dense_index(stride: int, i: int, j: int):
782
- return i * stride + j
783
-
784
-
785
- @wp.func
786
- def compute_link_velocity(
787
- i: int,
788
- joint_type: wp.array(dtype=int),
789
- joint_parent: wp.array(dtype=int),
790
- joint_child: wp.array(dtype=int),
791
- joint_q_start: wp.array(dtype=int),
792
- joint_qd_start: wp.array(dtype=int),
793
- joint_q: wp.array(dtype=float),
794
- joint_qd: wp.array(dtype=float),
795
- joint_axis: wp.array(dtype=wp.vec3),
796
- joint_axis_start: wp.array(dtype=int),
797
- joint_axis_dim: wp.array(dtype=int, ndim=2),
798
- body_I_m: wp.array(dtype=wp.spatial_matrix),
799
- body_q: wp.array(dtype=wp.transform),
800
- body_q_com: wp.array(dtype=wp.transform),
801
- joint_X_p: wp.array(dtype=wp.transform),
802
- joint_X_c: wp.array(dtype=wp.transform),
803
- gravity: wp.vec3,
804
- # outputs
805
- joint_S_s: wp.array(dtype=wp.spatial_vector),
806
- body_I_s: wp.array(dtype=wp.spatial_matrix),
807
- body_v_s: wp.array(dtype=wp.spatial_vector),
808
- body_f_s: wp.array(dtype=wp.spatial_vector),
809
- body_a_s: wp.array(dtype=wp.spatial_vector),
810
- ):
811
- type = joint_type[i]
812
- child = joint_child[i]
813
- parent = joint_parent[i]
814
- q_start = joint_q_start[i]
815
- qd_start = joint_qd_start[i]
816
-
817
- X_pj = joint_X_p[i]
818
- # X_cj = joint_X_c[i]
819
-
820
- # parent anchor frame in world space
821
- X_wpj = X_pj
822
- if parent >= 0:
823
- X_wp = body_q[parent]
824
- X_wpj = X_wp * X_wpj
825
-
826
- # compute motion subspace and velocity across the joint (also stores S_s to global memory)
827
- axis_start = joint_axis_start[i]
828
- lin_axis_count = joint_axis_dim[i, 0]
829
- ang_axis_count = joint_axis_dim[i, 1]
830
- v_j_s = jcalc_motion(
831
- type,
832
- joint_axis,
833
- axis_start,
834
- lin_axis_count,
835
- ang_axis_count,
836
- X_wpj,
837
- joint_q,
838
- joint_qd,
839
- q_start,
840
- qd_start,
841
- joint_S_s,
842
- )
843
-
844
- # parent velocity
845
- v_parent_s = wp.spatial_vector()
846
- a_parent_s = wp.spatial_vector()
847
-
848
- if parent >= 0:
849
- v_parent_s = body_v_s[parent]
850
- a_parent_s = body_a_s[parent]
851
-
852
- # body velocity, acceleration
853
- v_s = v_parent_s + v_j_s
854
- a_s = a_parent_s + spatial_cross(v_s, v_j_s) # + joint_S_s[i]*self.joint_qdd[i]
855
-
856
- # compute body forces
857
- X_sm = body_q_com[child]
858
- I_m = body_I_m[child]
859
-
860
- # gravity and external forces (expressed in frame aligned with s but centered at body mass)
861
- m = I_m[3, 3]
862
-
863
- f_g = m * gravity
864
- r_com = wp.transform_get_translation(X_sm)
865
- f_g_s = wp.spatial_vector(wp.cross(r_com, f_g), f_g)
866
-
867
- # body forces
868
- I_s = spatial_transform_inertia(X_sm, I_m)
869
-
870
- f_b_s = I_s * a_s + spatial_cross_dual(v_s, I_s * v_s)
871
-
872
- body_v_s[child] = v_s
873
- body_a_s[child] = a_s
874
- body_f_s[child] = f_b_s - f_g_s
875
- body_I_s[child] = I_s
876
-
877
-
878
- # Inverse dynamics via Recursive Newton-Euler algorithm (Featherstone Table 5.1)
879
- @wp.kernel
880
- def eval_rigid_id(
881
- articulation_start: wp.array(dtype=int),
882
- joint_type: wp.array(dtype=int),
883
- joint_parent: wp.array(dtype=int),
884
- joint_child: wp.array(dtype=int),
885
- joint_q_start: wp.array(dtype=int),
886
- joint_qd_start: wp.array(dtype=int),
887
- joint_q: wp.array(dtype=float),
888
- joint_qd: wp.array(dtype=float),
889
- joint_axis: wp.array(dtype=wp.vec3),
890
- joint_axis_start: wp.array(dtype=int),
891
- joint_axis_dim: wp.array(dtype=int, ndim=2),
892
- body_I_m: wp.array(dtype=wp.spatial_matrix),
893
- body_q: wp.array(dtype=wp.transform),
894
- body_q_com: wp.array(dtype=wp.transform),
895
- joint_X_p: wp.array(dtype=wp.transform),
896
- joint_X_c: wp.array(dtype=wp.transform),
897
- gravity: wp.vec3,
898
- # outputs
899
- joint_S_s: wp.array(dtype=wp.spatial_vector),
900
- body_I_s: wp.array(dtype=wp.spatial_matrix),
901
- body_v_s: wp.array(dtype=wp.spatial_vector),
902
- body_f_s: wp.array(dtype=wp.spatial_vector),
903
- body_a_s: wp.array(dtype=wp.spatial_vector),
904
- ):
905
- # one thread per-articulation
906
- index = wp.tid()
907
-
908
- start = articulation_start[index]
909
- end = articulation_start[index + 1]
910
-
911
- # compute link velocities and coriolis forces
912
- for i in range(start, end):
913
- compute_link_velocity(
914
- i,
915
- joint_type,
916
- joint_parent,
917
- joint_child,
918
- joint_q_start,
919
- joint_qd_start,
920
- joint_q,
921
- joint_qd,
922
- joint_axis,
923
- joint_axis_start,
924
- joint_axis_dim,
925
- body_I_m,
926
- body_q,
927
- body_q_com,
928
- joint_X_p,
929
- joint_X_c,
930
- gravity,
931
- joint_S_s,
932
- body_I_s,
933
- body_v_s,
934
- body_f_s,
935
- body_a_s,
936
- )
937
-
938
-
939
- @wp.kernel
940
- def eval_rigid_tau(
941
- articulation_start: wp.array(dtype=int),
942
- joint_type: wp.array(dtype=int),
943
- joint_parent: wp.array(dtype=int),
944
- joint_child: wp.array(dtype=int),
945
- joint_q_start: wp.array(dtype=int),
946
- joint_qd_start: wp.array(dtype=int),
947
- joint_axis_start: wp.array(dtype=int),
948
- joint_axis_dim: wp.array(dtype=int, ndim=2),
949
- joint_axis_mode: wp.array(dtype=int),
950
- joint_q: wp.array(dtype=float),
951
- joint_qd: wp.array(dtype=float),
952
- joint_act: wp.array(dtype=float),
953
- joint_target_ke: wp.array(dtype=float),
954
- joint_target_kd: wp.array(dtype=float),
955
- joint_limit_lower: wp.array(dtype=float),
956
- joint_limit_upper: wp.array(dtype=float),
957
- joint_limit_ke: wp.array(dtype=float),
958
- joint_limit_kd: wp.array(dtype=float),
959
- joint_S_s: wp.array(dtype=wp.spatial_vector),
960
- body_fb_s: wp.array(dtype=wp.spatial_vector),
961
- body_f_ext: wp.array(dtype=wp.spatial_vector),
962
- # outputs
963
- body_ft_s: wp.array(dtype=wp.spatial_vector),
964
- tau: wp.array(dtype=float),
965
- ):
966
- # one thread per-articulation
967
- index = wp.tid()
968
-
969
- start = articulation_start[index]
970
- end = articulation_start[index + 1]
971
- count = end - start
972
-
973
- # compute joint forces
974
- for offset in range(count):
975
- # for backwards traversal
976
- i = end - offset - 1
977
-
978
- type = joint_type[i]
979
- parent = joint_parent[i]
980
- child = joint_child[i]
981
- dof_start = joint_qd_start[i]
982
- coord_start = joint_q_start[i]
983
- axis_start = joint_axis_start[i]
984
- lin_axis_count = joint_axis_dim[i, 0]
985
- ang_axis_count = joint_axis_dim[i, 1]
986
-
987
- # total forces on body
988
- f_b_s = body_fb_s[child]
989
- f_t_s = body_ft_s[child]
990
- f_ext = body_f_ext[child]
991
- f_s = f_b_s + f_t_s + f_ext
992
-
993
- # compute joint-space forces, writes out tau
994
- jcalc_tau(
995
- type,
996
- joint_target_ke,
997
- joint_target_kd,
998
- joint_limit_ke,
999
- joint_limit_kd,
1000
- joint_S_s,
1001
- joint_q,
1002
- joint_qd,
1003
- joint_act,
1004
- joint_axis_mode,
1005
- joint_limit_lower,
1006
- joint_limit_upper,
1007
- coord_start,
1008
- dof_start,
1009
- axis_start,
1010
- lin_axis_count,
1011
- ang_axis_count,
1012
- f_s,
1013
- tau,
1014
- )
1015
-
1016
- # update parent forces, todo: check that this is valid for the backwards pass
1017
- if parent >= 0:
1018
- wp.atomic_add(body_ft_s, parent, f_s)
1019
-
1020
-
1021
- # builds spatial Jacobian J which is an (joint_count*6)x(dof_count) matrix
1022
- @wp.kernel
1023
- def eval_rigid_jacobian(
1024
- articulation_start: wp.array(dtype=int),
1025
- articulation_J_start: wp.array(dtype=int),
1026
- joint_ancestor: wp.array(dtype=int),
1027
- joint_qd_start: wp.array(dtype=int),
1028
- joint_S_s: wp.array(dtype=wp.spatial_vector),
1029
- # outputs
1030
- J: wp.array(dtype=float),
1031
- ):
1032
- # one thread per-articulation
1033
- index = wp.tid()
1034
-
1035
- joint_start = articulation_start[index]
1036
- joint_end = articulation_start[index + 1]
1037
- joint_count = joint_end - joint_start
1038
-
1039
- J_offset = articulation_J_start[index]
1040
-
1041
- articulation_dof_start = joint_qd_start[joint_start]
1042
- articulation_dof_end = joint_qd_start[joint_end]
1043
- articulation_dof_count = articulation_dof_end - articulation_dof_start
1044
-
1045
- for i in range(joint_count):
1046
- row_start = i * 6
1047
-
1048
- j = joint_start + i
1049
- while j != -1:
1050
- joint_dof_start = joint_qd_start[j]
1051
- joint_dof_end = joint_qd_start[j + 1]
1052
- joint_dof_count = joint_dof_end - joint_dof_start
1053
-
1054
- # fill out each row of the Jacobian walking up the tree
1055
- for dof in range(joint_dof_count):
1056
- col = (joint_dof_start - articulation_dof_start) + dof
1057
- S = joint_S_s[joint_dof_start + dof]
1058
-
1059
- for k in range(6):
1060
- J[J_offset + dense_index(articulation_dof_count, row_start + k, col)] = S[k]
1061
-
1062
- j = joint_ancestor[j]
1063
-
1064
-
1065
- @wp.func
1066
- def spatial_mass(
1067
- body_I_s: wp.array(dtype=wp.spatial_matrix),
1068
- joint_start: int,
1069
- joint_count: int,
1070
- M_start: int,
1071
- # outputs
1072
- M: wp.array(dtype=float),
1073
- ):
1074
- stride = joint_count * 6
1075
- for l in range(joint_count):
1076
- I = body_I_s[joint_start + l]
1077
- for i in range(6):
1078
- for j in range(6):
1079
- M[M_start + dense_index(stride, l * 6 + i, l * 6 + j)] = I[i, j]
1080
-
1081
-
1082
- @wp.kernel
1083
- def eval_rigid_mass(
1084
- articulation_start: wp.array(dtype=int),
1085
- articulation_M_start: wp.array(dtype=int),
1086
- body_I_s: wp.array(dtype=wp.spatial_matrix),
1087
- # outputs
1088
- M: wp.array(dtype=float),
1089
- ):
1090
- # one thread per-articulation
1091
- index = wp.tid()
1092
-
1093
- joint_start = articulation_start[index]
1094
- joint_end = articulation_start[index + 1]
1095
- joint_count = joint_end - joint_start
1096
-
1097
- M_offset = articulation_M_start[index]
1098
-
1099
- spatial_mass(body_I_s, joint_start, joint_count, M_offset, M)
1100
-
1101
-
1102
- @wp.func
1103
- def dense_gemm(
1104
- m: int,
1105
- n: int,
1106
- p: int,
1107
- transpose_A: bool,
1108
- transpose_B: bool,
1109
- add_to_C: bool,
1110
- A_start: int,
1111
- B_start: int,
1112
- C_start: int,
1113
- A: wp.array(dtype=float),
1114
- B: wp.array(dtype=float),
1115
- # outputs
1116
- C: wp.array(dtype=float),
1117
- ):
1118
- # multiply a `m x p` matrix A by a `p x n` matrix B to produce a `m x n` matrix C
1119
- for i in range(m):
1120
- for j in range(n):
1121
- sum = float(0.0)
1122
- for k in range(p):
1123
- if transpose_A:
1124
- a_i = k * m + i
1125
- else:
1126
- a_i = i * p + k
1127
- if transpose_B:
1128
- b_j = j * p + k
1129
- else:
1130
- b_j = k * n + j
1131
- sum += A[A_start + a_i] * B[B_start + b_j]
1132
-
1133
- if add_to_C:
1134
- C[C_start + i * n + j] += sum
1135
- else:
1136
- C[C_start + i * n + j] = sum
1137
-
1138
-
1139
- # @wp.func_grad(dense_gemm)
1140
- # def adj_dense_gemm(
1141
- # m: int,
1142
- # n: int,
1143
- # p: int,
1144
- # transpose_A: bool,
1145
- # transpose_B: bool,
1146
- # add_to_C: bool,
1147
- # A_start: int,
1148
- # B_start: int,
1149
- # C_start: int,
1150
- # A: wp.array(dtype=float),
1151
- # B: wp.array(dtype=float),
1152
- # # outputs
1153
- # C: wp.array(dtype=float),
1154
- # ):
1155
- # add_to_C = True
1156
- # if transpose_A:
1157
- # dense_gemm(p, m, n, False, True, add_to_C, A_start, B_start, C_start, B, wp.adjoint[C], wp.adjoint[A])
1158
- # dense_gemm(p, n, m, False, False, add_to_C, A_start, B_start, C_start, A, wp.adjoint[C], wp.adjoint[B])
1159
- # else:
1160
- # dense_gemm(
1161
- # m, p, n, False, not transpose_B, add_to_C, A_start, B_start, C_start, wp.adjoint[C], B, wp.adjoint[A]
1162
- # )
1163
- # dense_gemm(p, n, m, True, False, add_to_C, A_start, B_start, C_start, A, wp.adjoint[C], wp.adjoint[B])
1164
-
1165
-
1166
- def create_inertia_matrix_kernel(num_joints, num_dofs):
1167
- @wp.kernel
1168
- def eval_dense_gemm_tile(
1169
- J_arr: wp.array3d(dtype=float), M_arr: wp.array3d(dtype=float), H_arr: wp.array3d(dtype=float)
1170
- ):
1171
- articulation = wp.tid()
1172
-
1173
- J = wp.tile_load(J_arr[articulation], shape=(wp.static(6 * num_joints), num_dofs))
1174
- P = wp.tile_zeros(shape=(wp.static(6 * num_joints), num_dofs), dtype=float)
1175
-
1176
- # compute P = M*J where M is a 6x6 block diagonal mass matrix
1177
- for i in range(int(num_joints)):
1178
- # 6x6 block matrices are on the diagonal
1179
- M_body = wp.tile_load(M_arr[articulation], shape=(6, 6), offset=(i * 6, i * 6))
1180
-
1181
- # load a 6xN row from the Jacobian
1182
- J_body = wp.tile_view(J, offset=(i * 6, 0), shape=(6, num_dofs))
1183
-
1184
- # compute weighted row
1185
- P_body = wp.tile_matmul(M_body, J_body)
1186
-
1187
- # assign to the P slice
1188
- wp.tile_assign(P, P_body, offset=(i * 6, 0))
1189
-
1190
- # compute H = J^T*P
1191
- H = wp.tile_matmul(wp.tile_transpose(J), P)
1192
-
1193
- wp.tile_store(H_arr[articulation], H)
1194
-
1195
- return eval_dense_gemm_tile
1196
-
1197
-
1198
- def create_batched_cholesky_kernel(num_dofs):
1199
- assert num_dofs == 18
1200
-
1201
- @wp.kernel
1202
- def eval_tiled_dense_cholesky_batched(
1203
- A: wp.array3d(dtype=float), R: wp.array2d(dtype=float), L: wp.array3d(dtype=float)
1204
- ):
1205
- articulation = wp.tid()
1206
-
1207
- a = wp.tile_load(A[articulation], shape=(num_dofs, num_dofs), storage="shared")
1208
- r = wp.tile_load(R[articulation], shape=num_dofs, storage="shared")
1209
- a_r = wp.tile_diag_add(a, r)
1210
- l = wp.tile_cholesky(a_r)
1211
- wp.tile_store(L[articulation], wp.tile_transpose(l))
1212
-
1213
- return eval_tiled_dense_cholesky_batched
1214
-
1215
-
1216
- def create_inertia_matrix_cholesky_kernel(num_joints, num_dofs):
1217
- @wp.kernel
1218
- def eval_dense_gemm_and_cholesky_tile(
1219
- J_arr: wp.array3d(dtype=float),
1220
- M_arr: wp.array3d(dtype=float),
1221
- R_arr: wp.array2d(dtype=float),
1222
- H_arr: wp.array3d(dtype=float),
1223
- L_arr: wp.array3d(dtype=float),
1224
- ):
1225
- articulation = wp.tid()
1226
-
1227
- J = wp.tile_load(J_arr[articulation], shape=(wp.static(6 * num_joints), num_dofs))
1228
- P = wp.tile_zeros(shape=(wp.static(6 * num_joints), num_dofs), dtype=float)
1229
-
1230
- # compute P = M*J where M is a 6x6 block diagonal mass matrix
1231
- for i in range(int(num_joints)):
1232
- # 6x6 block matrices are on the diagonal
1233
- M_body = wp.tile_load(M_arr[articulation], shape=(6, 6), offset=(i * 6, i * 6))
1234
-
1235
- # load a 6xN row from the Jacobian
1236
- J_body = wp.tile_view(J, offset=(i * 6, 0), shape=(6, num_dofs))
1237
-
1238
- # compute weighted row
1239
- P_body = wp.tile_matmul(M_body, J_body)
1240
-
1241
- # assign to the P slice
1242
- wp.tile_assign(P, P_body, offset=(i * 6, 0))
1243
-
1244
- # compute H = J^T*P
1245
- H = wp.tile_matmul(wp.tile_transpose(J), P)
1246
- wp.tile_store(H_arr[articulation], H)
1247
-
1248
- # cholesky L L^T = (H + diag(R))
1249
- R = wp.tile_load(R_arr[articulation], shape=num_dofs, storage="shared")
1250
- H_R = wp.tile_diag_add(H, R)
1251
- L = wp.tile_cholesky(H_R)
1252
- wp.tile_store(L_arr[articulation], L)
1253
-
1254
- return eval_dense_gemm_and_cholesky_tile
1255
-
1256
-
1257
- @wp.kernel
1258
- def eval_dense_gemm_batched(
1259
- m: wp.array(dtype=int),
1260
- n: wp.array(dtype=int),
1261
- p: wp.array(dtype=int),
1262
- transpose_A: bool,
1263
- transpose_B: bool,
1264
- A_start: wp.array(dtype=int),
1265
- B_start: wp.array(dtype=int),
1266
- C_start: wp.array(dtype=int),
1267
- A: wp.array(dtype=float),
1268
- B: wp.array(dtype=float),
1269
- C: wp.array(dtype=float),
1270
- ):
1271
- # on the CPU each thread computes the whole matrix multiply
1272
- # on the GPU each block computes the multiply with one output per-thread
1273
- batch = wp.tid() # /kNumThreadsPerBlock;
1274
- add_to_C = False
1275
-
1276
- dense_gemm(
1277
- m[batch],
1278
- n[batch],
1279
- p[batch],
1280
- transpose_A,
1281
- transpose_B,
1282
- add_to_C,
1283
- A_start[batch],
1284
- B_start[batch],
1285
- C_start[batch],
1286
- A,
1287
- B,
1288
- C,
1289
- )
1290
-
1291
-
1292
- @wp.func
1293
- def dense_cholesky(
1294
- n: int,
1295
- A: wp.array(dtype=float),
1296
- R: wp.array(dtype=float),
1297
- A_start: int,
1298
- R_start: int,
1299
- # outputs
1300
- L: wp.array(dtype=float),
1301
- ):
1302
- # compute the Cholesky factorization of A = L L^T with diagonal regularization R
1303
- for j in range(n):
1304
- s = A[A_start + dense_index(n, j, j)] + R[R_start + j]
1305
-
1306
- for k in range(j):
1307
- r = L[A_start + dense_index(n, j, k)]
1308
- s -= r * r
1309
-
1310
- s = wp.sqrt(s)
1311
- invS = 1.0 / s
1312
-
1313
- L[A_start + dense_index(n, j, j)] = s
1314
-
1315
- for i in range(j + 1, n):
1316
- s = A[A_start + dense_index(n, i, j)]
1317
-
1318
- for k in range(j):
1319
- s -= L[A_start + dense_index(n, i, k)] * L[A_start + dense_index(n, j, k)]
1320
-
1321
- L[A_start + dense_index(n, i, j)] = s * invS
1322
-
1323
-
1324
- @wp.func_grad(dense_cholesky)
1325
- def adj_dense_cholesky(
1326
- n: int,
1327
- A: wp.array(dtype=float),
1328
- R: wp.array(dtype=float),
1329
- A_start: int,
1330
- R_start: int,
1331
- # outputs
1332
- L: wp.array(dtype=float),
1333
- ):
1334
- # nop, use dense_solve to differentiate through (A^-1)b = x
1335
- pass
1336
-
1337
-
1338
- @wp.kernel
1339
- def eval_dense_cholesky_batched(
1340
- A_starts: wp.array(dtype=int),
1341
- A_dim: wp.array(dtype=int),
1342
- A: wp.array(dtype=float),
1343
- R: wp.array(dtype=float),
1344
- L: wp.array(dtype=float),
1345
- ):
1346
- batch = wp.tid()
1347
-
1348
- n = A_dim[batch]
1349
- A_start = A_starts[batch]
1350
- R_start = n * batch
1351
-
1352
- dense_cholesky(n, A, R, A_start, R_start, L)
1353
-
1354
-
1355
- @wp.func
1356
- def dense_subs(
1357
- n: int,
1358
- L_start: int,
1359
- b_start: int,
1360
- L: wp.array(dtype=float),
1361
- b: wp.array(dtype=float),
1362
- # outputs
1363
- x: wp.array(dtype=float),
1364
- ):
1365
- # Solves (L L^T) x = b for x given the Cholesky factor L
1366
- # forward substitution solves the lower triangular system L y = b for y
1367
- for i in range(n):
1368
- s = b[b_start + i]
1369
-
1370
- for j in range(i):
1371
- s -= L[L_start + dense_index(n, i, j)] * x[b_start + j]
1372
-
1373
- x[b_start + i] = s / L[L_start + dense_index(n, i, i)]
1374
-
1375
- # backward substitution solves the upper triangular system L^T x = y for x
1376
- for i in range(n - 1, -1, -1):
1377
- s = x[b_start + i]
1378
-
1379
- for j in range(i + 1, n):
1380
- s -= L[L_start + dense_index(n, j, i)] * x[b_start + j]
1381
-
1382
- x[b_start + i] = s / L[L_start + dense_index(n, i, i)]
1383
-
1384
-
1385
- @wp.func
1386
- def dense_solve(
1387
- n: int,
1388
- L_start: int,
1389
- b_start: int,
1390
- A: wp.array(dtype=float),
1391
- L: wp.array(dtype=float),
1392
- b: wp.array(dtype=float),
1393
- # outputs
1394
- x: wp.array(dtype=float),
1395
- tmp: wp.array(dtype=float),
1396
- ):
1397
- # helper function to include tmp argument for backward pass
1398
- dense_subs(n, L_start, b_start, L, b, x)
1399
-
1400
-
1401
- @wp.func_grad(dense_solve)
1402
- def adj_dense_solve(
1403
- n: int,
1404
- L_start: int,
1405
- b_start: int,
1406
- A: wp.array(dtype=float),
1407
- L: wp.array(dtype=float),
1408
- b: wp.array(dtype=float),
1409
- # outputs
1410
- x: wp.array(dtype=float),
1411
- tmp: wp.array(dtype=float),
1412
- ):
1413
- if not tmp or not wp.adjoint[x] or not wp.adjoint[A] or not wp.adjoint[L]:
1414
- return
1415
- for i in range(n):
1416
- tmp[b_start + i] = 0.0
1417
-
1418
- dense_subs(n, L_start, b_start, L, wp.adjoint[x], tmp)
1419
-
1420
- for i in range(n):
1421
- wp.adjoint[b][b_start + i] += tmp[b_start + i]
1422
-
1423
- # A* = -adj_b*x^T
1424
- for i in range(n):
1425
- for j in range(n):
1426
- wp.adjoint[L][L_start + dense_index(n, i, j)] += -tmp[b_start + i] * x[b_start + j]
1427
-
1428
- for i in range(n):
1429
- for j in range(n):
1430
- wp.adjoint[A][L_start + dense_index(n, i, j)] += -tmp[b_start + i] * x[b_start + j]
1431
-
1432
-
1433
- @wp.kernel
1434
- def eval_dense_solve_batched(
1435
- L_start: wp.array(dtype=int),
1436
- L_dim: wp.array(dtype=int),
1437
- b_start: wp.array(dtype=int),
1438
- A: wp.array(dtype=float),
1439
- L: wp.array(dtype=float),
1440
- b: wp.array(dtype=float),
1441
- # outputs
1442
- x: wp.array(dtype=float),
1443
- tmp: wp.array(dtype=float),
1444
- ):
1445
- batch = wp.tid()
1446
-
1447
- dense_solve(L_dim[batch], L_start[batch], b_start[batch], A, L, b, x, tmp)
1448
-
1449
-
1450
- @wp.kernel
1451
- def integrate_generalized_joints(
1452
- joint_type: wp.array(dtype=int),
1453
- joint_q_start: wp.array(dtype=int),
1454
- joint_qd_start: wp.array(dtype=int),
1455
- joint_axis_dim: wp.array(dtype=int, ndim=2),
1456
- joint_q: wp.array(dtype=float),
1457
- joint_qd: wp.array(dtype=float),
1458
- joint_qdd: wp.array(dtype=float),
1459
- dt: float,
1460
- # outputs
1461
- joint_q_new: wp.array(dtype=float),
1462
- joint_qd_new: wp.array(dtype=float),
1463
- ):
1464
- # one thread per-articulation
1465
- index = wp.tid()
1466
-
1467
- type = joint_type[index]
1468
- coord_start = joint_q_start[index]
1469
- dof_start = joint_qd_start[index]
1470
- lin_axis_count = joint_axis_dim[index, 0]
1471
- ang_axis_count = joint_axis_dim[index, 1]
1472
-
1473
- jcalc_integrate(
1474
- type,
1475
- joint_q,
1476
- joint_qd,
1477
- joint_qdd,
1478
- coord_start,
1479
- dof_start,
1480
- lin_axis_count,
1481
- ang_axis_count,
1482
- dt,
1483
- joint_q_new,
1484
- joint_qd_new,
1485
- )
1486
-
1487
-
1488
class FeatherstoneIntegrator(Integrator):
    """A semi-implicit integrator using symplectic Euler that operates
    on reduced (also called generalized) coordinates to simulate articulated rigid body dynamics
    based on Featherstone's composite rigid body algorithm (CRBA).

    See: Featherstone, Roy. Rigid Body Dynamics Algorithms. Springer US, 2014.

    Instead of maximal coordinates :attr:`State.body_q` (rigid body positions) and :attr:`State.body_qd`
    (rigid body velocities) as is the case with :class:`SemiImplicitIntegrator`, :class:`FeatherstoneIntegrator`
    uses :attr:`State.joint_q` and :attr:`State.joint_qd` to represent the positions and velocities of
    joints without allowing any redundant degrees of freedom.

    After constructing :class:`Model` and :class:`State` objects this time-integrator
    may be used to advance the simulation state forward in time.

    Note:
        Unlike :class:`SemiImplicitIntegrator` and :class:`XPBDIntegrator`, :class:`FeatherstoneIntegrator` does not simulate rigid bodies with nonzero mass as floating bodies if they are not connected through any joints. Floating-base systems require an explicit free joint with which the body is connected to the world, see :meth:`ModelBuilder.add_joint_free`.

    Semi-implicit time integration is a variational integrator that
    preserves energy; however, it is not unconditionally stable and requires a time-step
    small enough to support the required stiffness and damping forces.

    See: https://en.wikipedia.org/wiki/Semi-implicit_Euler_method

    Example
    -------

    .. code-block:: python

        integrator = wp.sim.FeatherstoneIntegrator(model)

        # simulation loop
        for i in range(100):
            integrator.simulate(model, state_in, state_out, dt)
            state_in, state_out = state_out, state_in

    Note:
        The :class:`FeatherstoneIntegrator` requires the :class:`Model` to be passed in as a constructor argument.

    """

    def __init__(
        self,
        model,
        angular_damping=0.05,
        update_mass_matrix_every=1,
        friction_smoothing=1.0,
        use_tile_gemm=False,
        fuse_cholesky=True,
    ):
        """
        Args:
            model (Model): the model to be simulated.
            angular_damping (float, optional): Angular damping factor. Defaults to 0.05.
                NOTE(review): stored on the integrator but not referenced by :meth:`simulate`.
            update_mass_matrix_every (int, optional): How often to update the mass matrix (every n-th time the :meth:`simulate` function gets called). Defaults to 1.
            friction_smoothing (float, optional): The delta value for the Huber norm (see :func:`warp.math.norm_huber`) used for the friction velocity normalization. Defaults to 1.0.
            use_tile_gemm (bool, optional): Form the joint-space inertia matrix with tile-based kernels
                specialized on the articulation size. Requires all articulations to have the same number
                of joints and degrees of freedom. Defaults to False.
            fuse_cholesky (bool, optional): Only used when ``use_tile_gemm`` is enabled. If True, the
                Cholesky factorization is computed inside the same tiled kernel that forms the inertia
                matrix; otherwise it is computed by a separate batched kernel. Defaults to True.

        Raises:
            ValueError: If ``use_tile_gemm`` is enabled and the articulations do not all share the same
                number of joints and degrees of freedom.
        """
        self.angular_damping = angular_damping
        self.update_mass_matrix_every = update_mass_matrix_every
        self.friction_smoothing = friction_smoothing
        self.use_tile_gemm = use_tile_gemm
        self.fuse_cholesky = fuse_cholesky

        self._step = 0

        self.compute_articulation_indices(model)
        self.allocate_model_aux_vars(model)

        # The tiled kernels are specialized on the articulation size, which is only known when the
        # model has articulated joints; without them the articulation code path never runs.
        if self.use_tile_gemm and model.joint_count and model.articulation_count:
            # create a custom kernel to evaluate the system matrix for this articulation size
            if self.fuse_cholesky:
                self.eval_inertia_matrix_cholesky_kernel = create_inertia_matrix_cholesky_kernel(
                    int(self.joint_count), int(self.dof_count)
                )
            else:
                self.eval_inertia_matrix_kernel = create_inertia_matrix_kernel(
                    int(self.joint_count), int(self.dof_count)
                )

            # ensure the module is reloaded since otherwise an unload can happen during graph capture
            # todo: should not be necessary?
            wp.load_module(device=wp.get_device())

    def compute_articulation_indices(self, model):
        """Compute per-articulation offsets and dimensions of the Jacobian (J), spatial mass (M),
        and joint-space inertia (H) matrices, stored in flat arrays for batched processing.

        Sets ``J_size``/``M_size``/``H_size`` (total flat sizes) and the ``articulation_*`` device
        arrays. When ``use_tile_gemm`` is enabled, also sets ``joint_count``/``dof_count`` to the
        (uniform) per-articulation joint and DOF counts.

        Raises:
            ValueError: If ``use_tile_gemm`` is enabled and the articulations differ in their number
                of joints or degrees of freedom.
        """
        if not model.joint_count:
            return

        self.J_size = 0
        self.M_size = 0
        self.H_size = 0

        articulation_J_start = []
        articulation_M_start = []
        articulation_H_start = []

        articulation_M_rows = []
        articulation_H_rows = []
        articulation_J_rows = []
        articulation_J_cols = []

        articulation_dof_start = []
        articulation_coord_start = []

        articulation_start = model.articulation_start.numpy()
        joint_q_start = model.joint_q_start.numpy()
        joint_qd_start = model.joint_qd_start.numpy()

        for i in range(model.articulation_count):
            first_joint = articulation_start[i]
            last_joint = articulation_start[i + 1]

            first_coord = joint_q_start[first_joint]

            first_dof = joint_qd_start[first_joint]
            last_dof = joint_qd_start[last_joint]

            joint_count = last_joint - first_joint
            dof_count = last_dof - first_dof

            articulation_J_start.append(self.J_size)
            articulation_M_start.append(self.M_size)
            articulation_H_start.append(self.H_size)
            articulation_dof_start.append(first_dof)
            articulation_coord_start.append(first_coord)

            # bit of data duplication here, but will leave it as such for clarity
            articulation_M_rows.append(joint_count * 6)
            articulation_H_rows.append(dof_count)
            articulation_J_rows.append(joint_count * 6)
            articulation_J_cols.append(dof_count)

            if self.use_tile_gemm:
                # The tiled kernels are built for a single articulation size, so every articulation
                # must share the joint and DOF counts of the first one.
                if i > 0 and (joint_count != self.joint_count or dof_count != self.dof_count):
                    raise ValueError(
                        "FeatherstoneIntegrator(use_tile_gemm=True) requires all articulations to have the same "
                        f"number of joints and degrees of freedom, but articulation {i} has {joint_count} joints "
                        f"and {dof_count} DOFs while articulation 0 has {self.joint_count} joints and "
                        f"{self.dof_count} DOFs"
                    )
                self.joint_count = joint_count
                self.dof_count = dof_count

            self.J_size += 6 * joint_count * dof_count
            self.M_size += 6 * joint_count * 6 * joint_count
            self.H_size += dof_count * dof_count

        # matrix offsets for batched gemm
        self.articulation_J_start = wp.array(articulation_J_start, dtype=wp.int32, device=model.device)
        self.articulation_M_start = wp.array(articulation_M_start, dtype=wp.int32, device=model.device)
        self.articulation_H_start = wp.array(articulation_H_start, dtype=wp.int32, device=model.device)

        self.articulation_M_rows = wp.array(articulation_M_rows, dtype=wp.int32, device=model.device)
        self.articulation_H_rows = wp.array(articulation_H_rows, dtype=wp.int32, device=model.device)
        self.articulation_J_rows = wp.array(articulation_J_rows, dtype=wp.int32, device=model.device)
        self.articulation_J_cols = wp.array(articulation_J_cols, dtype=wp.int32, device=model.device)

        self.articulation_dof_start = wp.array(articulation_dof_start, dtype=wp.int32, device=model.device)
        self.articulation_coord_start = wp.array(articulation_coord_start, dtype=wp.int32, device=model.device)

    def allocate_model_aux_vars(self, model):
        """Allocate the system matrices (M, J, P, H, L) and per-body model quantities
        (spatial inertia ``body_I_m`` and center-of-mass transforms ``body_X_com``)."""
        if model.joint_count:
            # system matrices
            self.M = wp.zeros((self.M_size,), dtype=wp.float32, device=model.device, requires_grad=model.requires_grad)
            self.J = wp.zeros((self.J_size,), dtype=wp.float32, device=model.device, requires_grad=model.requires_grad)
            self.P = wp.empty_like(self.J, requires_grad=model.requires_grad)
            self.H = wp.empty((self.H_size,), dtype=wp.float32, device=model.device, requires_grad=model.requires_grad)

            # zero-initialized: the factorization writes only one triangle, and uninitialized
            # entries in the other one can trigger NaN detection
            self.L = wp.zeros_like(self.H)

        if model.body_count:
            self.body_I_m = wp.empty(
                (model.body_count,), dtype=wp.spatial_matrix, device=model.device, requires_grad=model.requires_grad
            )
            wp.launch(
                compute_spatial_inertia,
                model.body_count,
                inputs=[model.body_inertia, model.body_mass],
                outputs=[self.body_I_m],
                device=model.device,
            )
            self.body_X_com = wp.empty(
                (model.body_count,), dtype=wp.transform, device=model.device, requires_grad=model.requires_grad
            )
            wp.launch(
                compute_com_transforms,
                model.body_count,
                inputs=[model.body_com],
                outputs=[self.body_X_com],
                device=model.device,
            )

    def allocate_state_aux_vars(self, model, target, requires_grad):
        """Allocate the auxiliary buffers that vary with the state on ``target`` and mark it
        as augmented (``target._featherstone_augmented``)."""
        if model.body_count:
            # joints
            target.joint_qdd = wp.zeros_like(model.joint_qd, requires_grad=requires_grad)
            target.joint_tau = wp.empty_like(model.joint_qd, requires_grad=requires_grad)
            if requires_grad:
                # used in the custom grad implementation of eval_dense_solve_batched
                target.joint_solve_tmp = wp.zeros_like(model.joint_qd, requires_grad=True)
            else:
                target.joint_solve_tmp = None
            target.joint_S_s = wp.empty(
                (model.joint_dof_count,),
                dtype=wp.spatial_vector,
                device=model.device,
                requires_grad=requires_grad,
            )

            # derived rigid body data (maximal coordinates)
            target.body_q_com = wp.empty_like(model.body_q, requires_grad=requires_grad)
            target.body_I_s = wp.empty(
                (model.body_count,), dtype=wp.spatial_matrix, device=model.device, requires_grad=requires_grad
            )
            target.body_v_s = wp.empty(
                (model.body_count,), dtype=wp.spatial_vector, device=model.device, requires_grad=requires_grad
            )
            target.body_a_s = wp.empty(
                (model.body_count,), dtype=wp.spatial_vector, device=model.device, requires_grad=requires_grad
            )
            target.body_f_s = wp.zeros(
                (model.body_count,), dtype=wp.spatial_vector, device=model.device, requires_grad=requires_grad
            )
            target.body_ft_s = wp.zeros(
                (model.body_count,), dtype=wp.spatial_vector, device=model.device, requires_grad=requires_grad
            )

            target._featherstone_augmented = True

    def _update_mass_matrix(self, model, state_aug):
        """Rebuild the articulation Jacobians ``J``, spatial mass matrices ``M``, joint-space inertia
        matrices ``H = J^T M J``, and the Cholesky factors ``L`` of ``H`` plus the joint armature on
        the diagonal."""
        # build J
        wp.launch(
            eval_rigid_jacobian,
            dim=model.articulation_count,
            inputs=[
                model.articulation_start,
                self.articulation_J_start,
                model.joint_ancestor,
                model.joint_qd_start,
                state_aug.joint_S_s,
            ],
            outputs=[self.J],
            device=model.device,
        )

        # build M
        wp.launch(
            eval_rigid_mass,
            dim=model.articulation_count,
            inputs=[
                model.articulation_start,
                self.articulation_M_start,
                state_aug.body_I_s,
            ],
            outputs=[self.M],
            device=model.device,
        )

        if self.use_tile_gemm:
            joint_count = int(self.joint_count)
            dof_count = int(self.dof_count)

            # view the flat buffers as one tile per articulation
            M_tiled = self.M.reshape((-1, 6 * joint_count, 6 * joint_count))
            J_tiled = self.J.reshape((-1, 6 * joint_count, dof_count))
            R_tiled = model.joint_armature.reshape((-1, dof_count))
            H_tiled = self.H.reshape((-1, dof_count, dof_count))
            L_tiled = self.L.reshape((-1, dof_count, dof_count))
            assert H_tiled.shape == (model.articulation_count, dof_count, dof_count)
            assert L_tiled.shape == (model.articulation_count, dof_count, dof_count)
            assert R_tiled.shape == (model.articulation_count, dof_count)

            if self.fuse_cholesky:
                # forms H and its factorization L in a single tiled kernel
                wp.launch_tiled(
                    self.eval_inertia_matrix_cholesky_kernel,
                    dim=model.articulation_count,
                    inputs=[J_tiled, M_tiled, R_tiled],
                    outputs=[H_tiled, L_tiled],
                    device=model.device,
                    block_dim=64,
                )
            else:
                wp.launch_tiled(
                    self.eval_inertia_matrix_kernel,
                    dim=model.articulation_count,
                    inputs=[J_tiled, M_tiled],
                    outputs=[H_tiled],
                    device=model.device,
                    block_dim=256,
                )

                wp.launch(
                    eval_dense_cholesky_batched,
                    dim=model.articulation_count,
                    inputs=[
                        self.articulation_H_start,
                        self.articulation_H_rows,
                        self.H,
                        model.joint_armature,
                    ],
                    outputs=[self.L],
                    device=model.device,
                )
            return

        # form P = M*J
        wp.launch(
            eval_dense_gemm_batched,
            dim=model.articulation_count,
            inputs=[
                self.articulation_M_rows,
                self.articulation_J_cols,
                self.articulation_J_rows,
                False,
                False,
                self.articulation_M_start,
                self.articulation_J_start,
                # P start is the same as J start since it has the same dims as J
                self.articulation_J_start,
                self.M,
                self.J,
            ],
            outputs=[self.P],
            device=model.device,
        )

        # form H = J^T*P
        wp.launch(
            eval_dense_gemm_batched,
            dim=model.articulation_count,
            inputs=[
                self.articulation_J_cols,
                self.articulation_J_cols,
                # P rows is the same as J rows
                self.articulation_J_rows,
                True,
                False,
                self.articulation_J_start,
                # P start is the same as J start since it has the same dims as J
                self.articulation_J_start,
                self.articulation_H_start,
                self.J,
                self.P,
            ],
            outputs=[self.H],
            device=model.device,
        )

        # compute decomposition
        wp.launch(
            eval_dense_cholesky_batched,
            dim=model.articulation_count,
            inputs=[
                self.articulation_H_start,
                self.articulation_H_rows,
                self.H,
                model.joint_armature,
            ],
            outputs=[self.L],
            device=model.device,
        )

    def simulate(self, model: Model, state_in: State, state_out: State, dt: float, control: Control = None):
        """Advance the simulation by one time step of length ``dt``.

        Args:
            model (Model): The model to simulate.
            state_in (State): The input state. Its ``body_q`` is overwritten with the forward
                kinematics of ``state_in.joint_q``.
            state_out (State): The output state receiving the integrated joint coordinates/velocities,
                the resulting maximal coordinates, and the integrated particles.
            dt (float): The time step.
            control (Control, optional): Control inputs. If None, ``model.control(clone_variables=False)``
                is used.

        Returns:
            State: ``state_out``.
        """
        requires_grad = state_in.requires_grad

        # With gradients enabled the auxiliary buffers are stored on state_out (one set per step);
        # otherwise a single set is allocated on the integrator and reused.
        state_aug = state_out if requires_grad else self

        if not getattr(state_aug, "_featherstone_augmented", False):
            self.allocate_state_aux_vars(model, state_aug, requires_grad)
        if control is None:
            control = model.control(clone_variables=False)

        with wp.ScopedTimer("simulate", False):
            particle_f = state_in.particle_f if state_in.particle_count else None
            body_f = state_in.body_f if state_in.body_count else None

            # damped springs
            eval_spring_forces(model, state_in, particle_f)

            # triangle elastic and lift/drag forces
            eval_triangle_forces(model, state_in, control, particle_f)

            # triangle/triangle contacts
            eval_triangle_contact_forces(model, state_in, particle_f)

            # triangle bending
            eval_bending_forces(model, state_in, particle_f)

            # tetrahedral FEM
            eval_tetrahedral_forces(model, state_in, control, particle_f)

            # particle-particle interactions
            eval_particle_forces(model, state_in, particle_f)

            # particle ground contacts
            eval_particle_ground_contact_forces(model, state_in, particle_f)

            # particle shape contact
            eval_particle_body_contact_forces(model, state_in, particle_f, body_f, body_f_in_world_frame=True)

            # note: muscle forces are intentionally not evaluated by this integrator

            # ----------------------------
            # articulations

            if model.joint_count:
                # evaluate body transforms
                wp.launch(
                    eval_rigid_fk,
                    dim=model.articulation_count,
                    inputs=[
                        model.articulation_start,
                        model.joint_type,
                        model.joint_parent,
                        model.joint_child,
                        model.joint_q_start,
                        state_in.joint_q,
                        model.joint_X_p,
                        model.joint_X_c,
                        self.body_X_com,
                        model.joint_axis,
                        model.joint_axis_start,
                        model.joint_axis_dim,
                    ],
                    outputs=[state_in.body_q, state_aug.body_q_com],
                    device=model.device,
                )

                # evaluate joint inertias, motion vectors, and forces
                state_aug.body_f_s.zero_()
                wp.launch(
                    eval_rigid_id,
                    dim=model.articulation_count,
                    inputs=[
                        model.articulation_start,
                        model.joint_type,
                        model.joint_parent,
                        model.joint_child,
                        model.joint_q_start,
                        model.joint_qd_start,
                        state_in.joint_q,
                        state_in.joint_qd,
                        model.joint_axis,
                        model.joint_axis_start,
                        model.joint_axis_dim,
                        self.body_I_m,
                        state_in.body_q,
                        state_aug.body_q_com,
                        model.joint_X_p,
                        model.joint_X_c,
                        model.gravity,
                    ],
                    outputs=[
                        state_aug.joint_S_s,
                        state_aug.body_I_s,
                        state_aug.body_v_s,
                        state_aug.body_f_s,
                        state_aug.body_a_s,
                    ],
                    device=model.device,
                )

                if model.rigid_contact_max and (
                    (model.ground and model.shape_ground_contact_pair_count) or model.shape_contact_pair_count
                ):
                    wp.launch(
                        kernel=eval_rigid_contacts,
                        dim=model.rigid_contact_max,
                        inputs=[
                            state_in.body_q,
                            state_aug.body_v_s,
                            model.body_com,
                            model.shape_materials,
                            model.shape_geo,
                            model.shape_body,
                            model.rigid_contact_count,
                            model.rigid_contact_point0,
                            model.rigid_contact_point1,
                            model.rigid_contact_normal,
                            model.rigid_contact_shape0,
                            model.rigid_contact_shape1,
                            True,
                            self.friction_smoothing,
                        ],
                        outputs=[body_f],
                        device=model.device,
                    )

                if model.articulation_count:
                    # evaluate joint torques
                    state_aug.body_ft_s.zero_()
                    wp.launch(
                        eval_rigid_tau,
                        dim=model.articulation_count,
                        inputs=[
                            model.articulation_start,
                            model.joint_type,
                            model.joint_parent,
                            model.joint_child,
                            model.joint_q_start,
                            model.joint_qd_start,
                            model.joint_axis_start,
                            model.joint_axis_dim,
                            model.joint_axis_mode,
                            state_in.joint_q,
                            state_in.joint_qd,
                            control.joint_act,
                            model.joint_target_ke,
                            model.joint_target_kd,
                            model.joint_limit_lower,
                            model.joint_limit_upper,
                            model.joint_limit_ke,
                            model.joint_limit_kd,
                            state_aug.joint_S_s,
                            state_aug.body_f_s,
                            body_f,
                        ],
                        outputs=[
                            state_aug.body_ft_s,
                            state_aug.joint_tau,
                        ],
                        device=model.device,
                    )

                    # the factorized mass matrix is reused between updates
                    if self._step % self.update_mass_matrix_every == 0:
                        self._update_mass_matrix(model, state_aug)

                    # solve for qdd
                    state_aug.joint_qdd.zero_()
                    wp.launch(
                        eval_dense_solve_batched,
                        dim=model.articulation_count,
                        inputs=[
                            self.articulation_H_start,
                            self.articulation_H_rows,
                            self.articulation_dof_start,
                            self.H,
                            self.L,
                            state_aug.joint_tau,
                        ],
                        outputs=[
                            state_aug.joint_qdd,
                            state_aug.joint_solve_tmp,
                        ],
                        device=model.device,
                    )

            # -------------------------------------
            # integrate bodies

            if model.joint_count:
                wp.launch(
                    kernel=integrate_generalized_joints,
                    dim=model.joint_count,
                    inputs=[
                        model.joint_type,
                        model.joint_q_start,
                        model.joint_qd_start,
                        model.joint_axis_dim,
                        state_in.joint_q,
                        state_in.joint_qd,
                        state_aug.joint_qdd,
                        dt,
                    ],
                    outputs=[state_out.joint_q, state_out.joint_qd],
                    device=model.device,
                )

                # update maximal coordinates
                eval_fk(model, state_out.joint_q, state_out.joint_qd, None, state_out)

            self.integrate_particles(model, state_in, state_out, dt)

            self._step += 1

            return state_out