warp-lang 1.0.2__py3-none-macosx_10_13_universal2.whl → 1.1.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (346)
  1. warp/__init__.py +108 -97
  2. warp/__init__.pyi +1 -1
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +115 -113
  6. warp/build_dll.py +383 -375
  7. warp/builtins.py +3425 -3354
  8. warp/codegen.py +2878 -2792
  9. warp/config.py +40 -36
  10. warp/constants.py +45 -45
  11. warp/context.py +5194 -5102
  12. warp/dlpack.py +442 -442
  13. warp/examples/__init__.py +16 -16
  14. warp/examples/assets/bear.usd +0 -0
  15. warp/examples/assets/bunny.usd +0 -0
  16. warp/examples/assets/cartpole.urdf +110 -110
  17. warp/examples/assets/crazyflie.usd +0 -0
  18. warp/examples/assets/cube.usd +0 -0
  19. warp/examples/assets/nv_ant.xml +92 -92
  20. warp/examples/assets/nv_humanoid.xml +183 -183
  21. warp/examples/assets/quadruped.urdf +267 -267
  22. warp/examples/assets/rocks.nvdb +0 -0
  23. warp/examples/assets/rocks.usd +0 -0
  24. warp/examples/assets/sphere.usd +0 -0
  25. warp/examples/benchmarks/benchmark_api.py +383 -383
  26. warp/examples/benchmarks/benchmark_cloth.py +278 -277
  27. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
  28. warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
  29. warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
  30. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
  31. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
  32. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
  33. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -146
  34. warp/examples/benchmarks/benchmark_launches.py +295 -295
  35. warp/examples/browse.py +29 -29
  36. warp/examples/core/example_dem.py +234 -219
  37. warp/examples/core/example_fluid.py +293 -267
  38. warp/examples/core/example_graph_capture.py +144 -126
  39. warp/examples/core/example_marching_cubes.py +188 -174
  40. warp/examples/core/example_mesh.py +174 -155
  41. warp/examples/core/example_mesh_intersect.py +205 -193
  42. warp/examples/core/example_nvdb.py +176 -170
  43. warp/examples/core/example_raycast.py +105 -90
  44. warp/examples/core/example_raymarch.py +199 -178
  45. warp/examples/core/example_render_opengl.py +185 -141
  46. warp/examples/core/example_sph.py +405 -387
  47. warp/examples/core/example_torch.py +222 -181
  48. warp/examples/core/example_wave.py +263 -248
  49. warp/examples/fem/bsr_utils.py +378 -380
  50. warp/examples/fem/example_apic_fluid.py +407 -389
  51. warp/examples/fem/example_convection_diffusion.py +182 -168
  52. warp/examples/fem/example_convection_diffusion_dg.py +219 -209
  53. warp/examples/fem/example_convection_diffusion_dg0.py +204 -194
  54. warp/examples/fem/example_deformed_geometry.py +177 -159
  55. warp/examples/fem/example_diffusion.py +201 -173
  56. warp/examples/fem/example_diffusion_3d.py +177 -152
  57. warp/examples/fem/example_diffusion_mgpu.py +221 -214
  58. warp/examples/fem/example_mixed_elasticity.py +244 -222
  59. warp/examples/fem/example_navier_stokes.py +259 -243
  60. warp/examples/fem/example_stokes.py +220 -192
  61. warp/examples/fem/example_stokes_transfer.py +265 -249
  62. warp/examples/fem/mesh_utils.py +133 -109
  63. warp/examples/fem/plot_utils.py +292 -287
  64. warp/examples/optim/example_bounce.py +260 -246
  65. warp/examples/optim/example_cloth_throw.py +222 -209
  66. warp/examples/optim/example_diffray.py +566 -536
  67. warp/examples/optim/example_drone.py +864 -835
  68. warp/examples/optim/example_inverse_kinematics.py +176 -168
  69. warp/examples/optim/example_inverse_kinematics_torch.py +185 -169
  70. warp/examples/optim/example_spring_cage.py +239 -231
  71. warp/examples/optim/example_trajectory.py +223 -199
  72. warp/examples/optim/example_walker.py +306 -293
  73. warp/examples/sim/example_cartpole.py +139 -129
  74. warp/examples/sim/example_cloth.py +196 -186
  75. warp/examples/sim/example_granular.py +124 -111
  76. warp/examples/sim/example_granular_collision_sdf.py +197 -186
  77. warp/examples/sim/example_jacobian_ik.py +236 -214
  78. warp/examples/sim/example_particle_chain.py +118 -105
  79. warp/examples/sim/example_quadruped.py +193 -180
  80. warp/examples/sim/example_rigid_chain.py +197 -187
  81. warp/examples/sim/example_rigid_contact.py +189 -177
  82. warp/examples/sim/example_rigid_force.py +127 -125
  83. warp/examples/sim/example_rigid_gyroscopic.py +109 -95
  84. warp/examples/sim/example_rigid_soft_contact.py +134 -122
  85. warp/examples/sim/example_soft_body.py +190 -177
  86. warp/fabric.py +337 -335
  87. warp/fem/__init__.py +60 -27
  88. warp/fem/cache.py +401 -388
  89. warp/fem/dirichlet.py +178 -179
  90. warp/fem/domain.py +262 -263
  91. warp/fem/field/__init__.py +100 -101
  92. warp/fem/field/field.py +148 -149
  93. warp/fem/field/nodal_field.py +298 -299
  94. warp/fem/field/restriction.py +22 -21
  95. warp/fem/field/test.py +180 -181
  96. warp/fem/field/trial.py +183 -183
  97. warp/fem/geometry/__init__.py +15 -19
  98. warp/fem/geometry/closest_point.py +69 -70
  99. warp/fem/geometry/deformed_geometry.py +270 -271
  100. warp/fem/geometry/element.py +744 -744
  101. warp/fem/geometry/geometry.py +184 -186
  102. warp/fem/geometry/grid_2d.py +380 -373
  103. warp/fem/geometry/grid_3d.py +441 -435
  104. warp/fem/geometry/hexmesh.py +953 -953
  105. warp/fem/geometry/partition.py +374 -376
  106. warp/fem/geometry/quadmesh_2d.py +532 -532
  107. warp/fem/geometry/tetmesh.py +840 -840
  108. warp/fem/geometry/trimesh_2d.py +577 -577
  109. warp/fem/integrate.py +1630 -1615
  110. warp/fem/operator.py +190 -191
  111. warp/fem/polynomial.py +214 -213
  112. warp/fem/quadrature/__init__.py +2 -2
  113. warp/fem/quadrature/pic_quadrature.py +243 -245
  114. warp/fem/quadrature/quadrature.py +295 -294
  115. warp/fem/space/__init__.py +294 -292
  116. warp/fem/space/basis_space.py +488 -489
  117. warp/fem/space/collocated_function_space.py +100 -105
  118. warp/fem/space/dof_mapper.py +236 -236
  119. warp/fem/space/function_space.py +148 -145
  120. warp/fem/space/grid_2d_function_space.py +267 -267
  121. warp/fem/space/grid_3d_function_space.py +305 -306
  122. warp/fem/space/hexmesh_function_space.py +350 -352
  123. warp/fem/space/partition.py +350 -350
  124. warp/fem/space/quadmesh_2d_function_space.py +368 -369
  125. warp/fem/space/restriction.py +158 -160
  126. warp/fem/space/shape/__init__.py +13 -15
  127. warp/fem/space/shape/cube_shape_function.py +738 -738
  128. warp/fem/space/shape/shape_function.py +102 -103
  129. warp/fem/space/shape/square_shape_function.py +611 -611
  130. warp/fem/space/shape/tet_shape_function.py +565 -567
  131. warp/fem/space/shape/triangle_shape_function.py +429 -429
  132. warp/fem/space/tetmesh_function_space.py +294 -292
  133. warp/fem/space/topology.py +297 -295
  134. warp/fem/space/trimesh_2d_function_space.py +223 -221
  135. warp/fem/types.py +77 -77
  136. warp/fem/utils.py +495 -495
  137. warp/jax.py +166 -141
  138. warp/jax_experimental.py +341 -339
  139. warp/native/array.h +1072 -1025
  140. warp/native/builtin.h +1560 -1560
  141. warp/native/bvh.cpp +398 -398
  142. warp/native/bvh.cu +525 -525
  143. warp/native/bvh.h +429 -429
  144. warp/native/clang/clang.cpp +495 -464
  145. warp/native/crt.cpp +31 -31
  146. warp/native/crt.h +334 -334
  147. warp/native/cuda_crt.h +1049 -1049
  148. warp/native/cuda_util.cpp +549 -540
  149. warp/native/cuda_util.h +288 -203
  150. warp/native/cutlass_gemm.cpp +34 -34
  151. warp/native/cutlass_gemm.cu +372 -372
  152. warp/native/error.cpp +66 -66
  153. warp/native/error.h +27 -27
  154. warp/native/fabric.h +228 -228
  155. warp/native/hashgrid.cpp +301 -278
  156. warp/native/hashgrid.cu +78 -77
  157. warp/native/hashgrid.h +227 -227
  158. warp/native/initializer_array.h +32 -32
  159. warp/native/intersect.h +1204 -1204
  160. warp/native/intersect_adj.h +365 -365
  161. warp/native/intersect_tri.h +322 -322
  162. warp/native/marching.cpp +2 -2
  163. warp/native/marching.cu +497 -497
  164. warp/native/marching.h +2 -2
  165. warp/native/mat.h +1498 -1498
  166. warp/native/matnn.h +333 -333
  167. warp/native/mesh.cpp +203 -203
  168. warp/native/mesh.cu +293 -293
  169. warp/native/mesh.h +1887 -1887
  170. warp/native/nanovdb/NanoVDB.h +4782 -4782
  171. warp/native/nanovdb/PNanoVDB.h +2553 -2553
  172. warp/native/nanovdb/PNanoVDBWrite.h +294 -294
  173. warp/native/noise.h +850 -850
  174. warp/native/quat.h +1084 -1084
  175. warp/native/rand.h +299 -299
  176. warp/native/range.h +108 -108
  177. warp/native/reduce.cpp +156 -156
  178. warp/native/reduce.cu +348 -348
  179. warp/native/runlength_encode.cpp +61 -61
  180. warp/native/runlength_encode.cu +46 -46
  181. warp/native/scan.cpp +30 -30
  182. warp/native/scan.cu +36 -36
  183. warp/native/scan.h +7 -7
  184. warp/native/solid_angle.h +442 -442
  185. warp/native/sort.cpp +94 -94
  186. warp/native/sort.cu +97 -97
  187. warp/native/sort.h +14 -14
  188. warp/native/sparse.cpp +337 -337
  189. warp/native/sparse.cu +544 -544
  190. warp/native/spatial.h +630 -630
  191. warp/native/svd.h +562 -562
  192. warp/native/temp_buffer.h +30 -30
  193. warp/native/vec.h +1132 -1132
  194. warp/native/volume.cpp +297 -297
  195. warp/native/volume.cu +32 -32
  196. warp/native/volume.h +538 -538
  197. warp/native/volume_builder.cu +425 -425
  198. warp/native/volume_builder.h +19 -19
  199. warp/native/warp.cpp +1057 -1052
  200. warp/native/warp.cu +2943 -2828
  201. warp/native/warp.h +313 -305
  202. warp/optim/__init__.py +9 -9
  203. warp/optim/adam.py +120 -120
  204. warp/optim/linear.py +1104 -939
  205. warp/optim/sgd.py +104 -92
  206. warp/render/__init__.py +10 -10
  207. warp/render/render_opengl.py +3217 -3204
  208. warp/render/render_usd.py +768 -749
  209. warp/render/utils.py +152 -150
  210. warp/sim/__init__.py +52 -59
  211. warp/sim/articulation.py +685 -685
  212. warp/sim/collide.py +1594 -1590
  213. warp/sim/import_mjcf.py +489 -481
  214. warp/sim/import_snu.py +220 -221
  215. warp/sim/import_urdf.py +536 -516
  216. warp/sim/import_usd.py +887 -881
  217. warp/sim/inertia.py +316 -317
  218. warp/sim/integrator.py +234 -233
  219. warp/sim/integrator_euler.py +1956 -1956
  220. warp/sim/integrator_featherstone.py +1910 -1991
  221. warp/sim/integrator_xpbd.py +3294 -3312
  222. warp/sim/model.py +4473 -4314
  223. warp/sim/particles.py +113 -112
  224. warp/sim/render.py +417 -403
  225. warp/sim/utils.py +413 -410
  226. warp/sparse.py +1227 -1227
  227. warp/stubs.py +2109 -2469
  228. warp/tape.py +1162 -225
  229. warp/tests/__init__.py +1 -1
  230. warp/tests/__main__.py +4 -4
  231. warp/tests/assets/torus.usda +105 -105
  232. warp/tests/aux_test_class_kernel.py +26 -26
  233. warp/tests/aux_test_compile_consts_dummy.py +10 -10
  234. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
  235. warp/tests/aux_test_dependent.py +22 -22
  236. warp/tests/aux_test_grad_customs.py +23 -23
  237. warp/tests/aux_test_reference.py +11 -11
  238. warp/tests/aux_test_reference_reference.py +10 -10
  239. warp/tests/aux_test_square.py +17 -17
  240. warp/tests/aux_test_unresolved_func.py +14 -14
  241. warp/tests/aux_test_unresolved_symbol.py +14 -14
  242. warp/tests/disabled_kinematics.py +239 -239
  243. warp/tests/run_coverage_serial.py +31 -31
  244. warp/tests/test_adam.py +157 -157
  245. warp/tests/test_arithmetic.py +1124 -1124
  246. warp/tests/test_array.py +2417 -2326
  247. warp/tests/test_array_reduce.py +150 -150
  248. warp/tests/test_async.py +668 -656
  249. warp/tests/test_atomic.py +141 -141
  250. warp/tests/test_bool.py +204 -149
  251. warp/tests/test_builtins_resolution.py +1292 -1292
  252. warp/tests/test_bvh.py +164 -171
  253. warp/tests/test_closest_point_edge_edge.py +228 -228
  254. warp/tests/test_codegen.py +566 -553
  255. warp/tests/test_compile_consts.py +97 -101
  256. warp/tests/test_conditional.py +246 -246
  257. warp/tests/test_copy.py +232 -215
  258. warp/tests/test_ctypes.py +632 -632
  259. warp/tests/test_dense.py +67 -67
  260. warp/tests/test_devices.py +91 -98
  261. warp/tests/test_dlpack.py +530 -529
  262. warp/tests/test_examples.py +400 -378
  263. warp/tests/test_fabricarray.py +955 -955
  264. warp/tests/test_fast_math.py +62 -54
  265. warp/tests/test_fem.py +1277 -1278
  266. warp/tests/test_fp16.py +130 -130
  267. warp/tests/test_func.py +338 -337
  268. warp/tests/test_generics.py +571 -571
  269. warp/tests/test_grad.py +746 -640
  270. warp/tests/test_grad_customs.py +333 -336
  271. warp/tests/test_hash_grid.py +210 -164
  272. warp/tests/test_import.py +39 -39
  273. warp/tests/test_indexedarray.py +1134 -1134
  274. warp/tests/test_intersect.py +67 -67
  275. warp/tests/test_jax.py +307 -307
  276. warp/tests/test_large.py +167 -164
  277. warp/tests/test_launch.py +354 -354
  278. warp/tests/test_lerp.py +261 -261
  279. warp/tests/test_linear_solvers.py +191 -171
  280. warp/tests/test_lvalue.py +421 -493
  281. warp/tests/test_marching_cubes.py +65 -65
  282. warp/tests/test_mat.py +1801 -1827
  283. warp/tests/test_mat_lite.py +115 -115
  284. warp/tests/test_mat_scalar_ops.py +2907 -2889
  285. warp/tests/test_math.py +126 -193
  286. warp/tests/test_matmul.py +500 -499
  287. warp/tests/test_matmul_lite.py +410 -410
  288. warp/tests/test_mempool.py +188 -190
  289. warp/tests/test_mesh.py +284 -324
  290. warp/tests/test_mesh_query_aabb.py +228 -241
  291. warp/tests/test_mesh_query_point.py +692 -702
  292. warp/tests/test_mesh_query_ray.py +292 -303
  293. warp/tests/test_mlp.py +276 -276
  294. warp/tests/test_model.py +110 -110
  295. warp/tests/test_modules_lite.py +39 -39
  296. warp/tests/test_multigpu.py +163 -163
  297. warp/tests/test_noise.py +248 -248
  298. warp/tests/test_operators.py +250 -250
  299. warp/tests/test_options.py +123 -125
  300. warp/tests/test_peer.py +133 -137
  301. warp/tests/test_pinned.py +78 -78
  302. warp/tests/test_print.py +54 -54
  303. warp/tests/test_quat.py +2086 -2086
  304. warp/tests/test_rand.py +288 -288
  305. warp/tests/test_reload.py +217 -217
  306. warp/tests/test_rounding.py +179 -179
  307. warp/tests/test_runlength_encode.py +190 -190
  308. warp/tests/test_sim_grad.py +243 -0
  309. warp/tests/test_sim_kinematics.py +91 -97
  310. warp/tests/test_smoothstep.py +168 -168
  311. warp/tests/test_snippet.py +305 -266
  312. warp/tests/test_sparse.py +468 -460
  313. warp/tests/test_spatial.py +2148 -2148
  314. warp/tests/test_streams.py +486 -473
  315. warp/tests/test_struct.py +710 -675
  316. warp/tests/test_tape.py +173 -148
  317. warp/tests/test_torch.py +743 -743
  318. warp/tests/test_transient_module.py +87 -87
  319. warp/tests/test_types.py +556 -659
  320. warp/tests/test_utils.py +490 -499
  321. warp/tests/test_vec.py +1264 -1268
  322. warp/tests/test_vec_lite.py +73 -73
  323. warp/tests/test_vec_scalar_ops.py +2099 -2099
  324. warp/tests/test_verify_fp.py +94 -94
  325. warp/tests/test_volume.py +737 -736
  326. warp/tests/test_volume_write.py +255 -265
  327. warp/tests/unittest_serial.py +37 -37
  328. warp/tests/unittest_suites.py +363 -359
  329. warp/tests/unittest_utils.py +603 -578
  330. warp/tests/unused_test_misc.py +71 -71
  331. warp/tests/walkthrough_debug.py +85 -85
  332. warp/thirdparty/appdirs.py +598 -598
  333. warp/thirdparty/dlpack.py +143 -143
  334. warp/thirdparty/unittest_parallel.py +566 -561
  335. warp/torch.py +321 -295
  336. warp/types.py +4504 -4450
  337. warp/utils.py +1008 -821
  338. {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/LICENSE.md +126 -126
  339. {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/METADATA +338 -400
  340. warp_lang-1.1.0.dist-info/RECORD +352 -0
  341. warp/examples/assets/cube.usda +0 -42
  342. warp/examples/assets/sphere.usda +0 -56
  343. warp/examples/assets/torus.usda +0 -105
  344. warp_lang-1.0.2.dist-info/RECORD +0 -352
  345. {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/WHEEL +0 -0
  346. {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,1991 +1,1910 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
-
8
- import warp as wp
9
-
10
- from .model import Model, State, Control
11
-
12
- from .integrator import Integrator
13
-
14
- from .integrator_euler import (
15
- eval_spring_forces,
16
- eval_triangle_forces,
17
- eval_triangle_contact_forces,
18
- eval_bending_forces,
19
- eval_tetrahedral_forces,
20
- eval_particle_forces,
21
- eval_particle_ground_contact_forces,
22
- eval_particle_body_contact_forces,
23
- eval_muscle_forces,
24
- eval_rigid_contacts,
25
- eval_joint_force,
26
- )
27
-
28
- from .articulation import (
29
- compute_2d_rotational_dofs,
30
- compute_3d_rotational_dofs,
31
- )
32
-
33
-
34
# Change the frame of a spatial twist ``x`` by the transform ``t``
# (Frank & Park definition 3.20, pg 100). The twist is laid out as
# (top = angular part w, bottom = linear part v). With rotation q and
# translation p of ``t``:
#   w' = q * w
#   v' = q * v + p x w'
@wp.func
def transform_twist(t: wp.transform, x: wp.spatial_vector):
    rot = wp.transform_get_rotation(t)
    ang = wp.quat_rotate(rot, wp.spatial_top(x))
    # the linear part picks up the moment of the already-rotated angular part
    lin = wp.quat_rotate(rot, wp.spatial_bottom(x)) + wp.cross(wp.transform_get_translation(t), ang)
    return wp.spatial_vector(ang, lin)
47
-
48
-
49
# Change the frame of a spatial wrench ``x`` by the transform ``t``. The
# wrench is laid out as (top = angular part w, bottom = linear part v).
# With rotation q and translation p of ``t``:
#   v' = q * v
#   w' = q * w + p x v'
@wp.func
def transform_wrench(t: wp.transform, x: wp.spatial_vector):
    rot = wp.transform_get_rotation(t)
    lin = wp.quat_rotate(rot, wp.spatial_bottom(x))
    # the angular part picks up the moment of the already-rotated linear part
    ang = wp.quat_rotate(rot, wp.spatial_top(x)) + wp.cross(wp.transform_get_translation(t), lin)
    return wp.spatial_vector(ang, lin)
61
-
62
-
63
# Assemble a 6x6 spatial matrix from two 3x3 blocks:
#   T = [R 0]
#       [S R]
# R fills both diagonal blocks, S the lower-left block, and the upper-right
# block is zero. (spatial_transform_inertia calls this with S = skew(p) @ R.)
@wp.func
def spatial_adjoint(R: wp.mat33, S: wp.mat33):
    zero = 0.0
    # fmt: off
    return wp.spatial_matrix(
        R[0, 0], R[0, 1], R[0, 2], zero,    zero,    zero,
        R[1, 0], R[1, 1], R[1, 2], zero,    zero,    zero,
        R[2, 0], R[2, 1], R[2, 2], zero,    zero,    zero,
        S[0, 0], S[0, 1], S[0, 2], R[0, 0], R[0, 1], R[0, 2],
        S[1, 0], S[1, 1], S[1, 2], R[1, 0], R[1, 1], R[1, 2],
        S[2, 0], S[2, 1], S[2, 2], R[2, 0], R[2, 1], R[2, 2],
    )
    # fmt: on
78
-
79
-
80
# Kernel (one thread per body): pack the body's 3x3 rotational inertia and
# scalar mass into a 6x6 spatial inertia matrix
#   [J 0     ]
#   [0 mass*1]
# with zero off-diagonal coupling blocks, written to body_I_m.
@wp.kernel
def compute_spatial_inertia(
    body_inertia: wp.array(dtype=wp.mat33),
    body_mass: wp.array(dtype=float),
    # outputs
    body_I_m: wp.array(dtype=wp.spatial_matrix),
):
    b = wp.tid()
    J = body_inertia[b]
    mass = body_mass[b]
    # fmt: off
    body_I_m[b] = wp.spatial_matrix(
        J[0, 0], J[0, 1], J[0, 2], 0.0,  0.0,  0.0,
        J[1, 0], J[1, 1], J[1, 2], 0.0,  0.0,  0.0,
        J[2, 0], J[2, 1], J[2, 2], 0.0,  0.0,  0.0,
        0.0,     0.0,     0.0,     mass, 0.0,  0.0,
        0.0,     0.0,     0.0,     0.0,  mass, 0.0,
        0.0,     0.0,     0.0,     0.0,  0.0,  mass,
    )
    # fmt: on
100
-
101
-
102
# Kernel (one thread per body): store a transform whose translation is the
# body's center-of-mass offset body_com[b] and whose rotation is identity.
@wp.kernel
def compute_com_transforms(
    body_com: wp.array(dtype=wp.vec3),
    # outputs
    body_X_com: wp.array(dtype=wp.transform),
):
    b = wp.tid()
    body_X_com[b] = wp.transform(body_com[b], wp.quat_identity())
111
-
112
-
113
# Tensor change of coordinates for a spatial inertia: returns
#   T^T * I * T
# i.e. adj_t^-T * I * adj_t^-1 (Frank & Park, section 8.2.3, pg 290).
# T is assembled by spatial_adjoint from the inverse of ``t``:
#   R = the three unit axes rotated by the inverse rotation,
#   S = skew(p_inv) @ R.
@wp.func
def spatial_transform_inertia(t: wp.transform, I: wp.spatial_matrix):
    inv = wp.transform_inverse(t)
    rot = wp.transform_get_rotation(inv)

    R = wp.mat33(
        wp.quat_rotate(rot, wp.vec3(1.0, 0.0, 0.0)),
        wp.quat_rotate(rot, wp.vec3(0.0, 1.0, 0.0)),
        wp.quat_rotate(rot, wp.vec3(0.0, 0.0, 1.0)),
    )
    S = wp.skew(wp.transform_get_translation(inv)) @ R

    adj = spatial_adjoint(R, S)
    return wp.mul(wp.mul(wp.transpose(adj), I), adj)
131
-
132
-
133
# compute transform across a joint
#
# Builds the joint transform X_jc from the joint's generalized coordinates.
#   type           -- joint type, compared against the wp.sim.JOINT_* constants
#   joint_axis     -- joint axis directions; this joint's axes begin at axis_start
#   lin_axis_count -- number of linear axes (only read by the JOINT_D6 case)
#   ang_axis_count -- number of angular axes (only read by the JOINT_D6 case)
#   joint_q        -- generalized coordinates; this joint's begin at `start`
# Joint types not handled below fall through to the identity transform.
# NOTE(review): the parameter name `type` shadows the Python builtin.
@wp.func
def jcalc_transform(
    type: int,
    joint_axis: wp.array(dtype=wp.vec3),
    axis_start: int,
    lin_axis_count: int,
    ang_axis_count: int,
    joint_q: wp.array(dtype=float),
    start: int,
):
    # one coordinate: translation by q along the joint axis
    if type == wp.sim.JOINT_PRISMATIC:
        q = joint_q[start]
        axis = joint_axis[axis_start]
        X_jc = wp.transform(axis * q, wp.quat_identity())
        return X_jc

    # one coordinate: rotation by angle q about the joint axis
    if type == wp.sim.JOINT_REVOLUTE:
        q = joint_q[start]
        axis = joint_axis[axis_start]
        X_jc = wp.transform(wp.vec3(), wp.quat_from_axis_angle(axis, q))
        return X_jc

    # four coordinates, read as quaternion components in (x, y, z, w) order
    if type == wp.sim.JOINT_BALL:
        qx = joint_q[start + 0]
        qy = joint_q[start + 1]
        qz = joint_q[start + 2]
        qw = joint_q[start + 3]

        X_jc = wp.transform(wp.vec3(), wp.quat(qx, qy, qz, qw))
        return X_jc

    # no coordinates
    if type == wp.sim.JOINT_FIXED:
        X_jc = wp.transform_identity()
        return X_jc

    # seven coordinates: position (x, y, z) followed by quaternion (x, y, z, w)
    if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
        px = joint_q[start + 0]
        py = joint_q[start + 1]
        pz = joint_q[start + 2]

        qx = joint_q[start + 3]
        qy = joint_q[start + 4]
        qz = joint_q[start + 5]
        qw = joint_q[start + 6]

        X_jc = wp.transform(wp.vec3(px, py, pz), wp.quat(qx, qy, qz, qw))
        return X_jc

    # three angles about the joint's three axes. Only the rotation returned by
    # the helper is kept; its second result is discarded. The trailing 0.0
    # arguments are presumably joint velocities that do not affect the
    # rotation -- TODO confirm against compute_3d_rotational_dofs.
    if type == wp.sim.JOINT_COMPOUND:
        rot, _ = compute_3d_rotational_dofs(
            joint_axis[axis_start],
            joint_axis[axis_start + 1],
            joint_axis[axis_start + 2],
            joint_q[start + 0],
            joint_q[start + 1],
            joint_q[start + 2],
            0.0,
            0.0,
            0.0,
        )

        X_jc = wp.transform(wp.vec3(), rot)
        return X_jc

    # two angles about the joint's two axes (same conventions as the compound
    # case above, via the 2-axis helper)
    if type == wp.sim.JOINT_UNIVERSAL:
        rot, _ = compute_2d_rotational_dofs(
            joint_axis[axis_start],
            joint_axis[axis_start + 1],
            joint_q[start + 0],
            joint_q[start + 1],
            0.0,
            0.0,
        )

        X_jc = wp.transform(wp.vec3(), rot)
        return X_jc

    # generic joint: up to three linear axes first, then up to three angular
    # axes; the angular axes/coordinates are offset by lin_axis_count
    if type == wp.sim.JOINT_D6:
        pos = wp.vec3(0.0)
        rot = wp.quat_identity()

        # unroll for loop to ensure joint actions remain differentiable
        # (since differentiating through a for loop that updates a local variable is not supported)

        # translation: sum of each linear axis scaled by its coordinate
        if lin_axis_count > 0:
            axis = joint_axis[axis_start + 0]
            pos += axis * joint_q[start + 0]
        if lin_axis_count > 1:
            axis = joint_axis[axis_start + 1]
            pos += axis * joint_q[start + 1]
        if lin_axis_count > 2:
            axis = joint_axis[axis_start + 2]
            pos += axis * joint_q[start + 2]

        # index of the first angular axis / first angular coordinate
        ia = axis_start + lin_axis_count
        iq = start + lin_axis_count
        # rotation: identity when there are no angular axes
        if ang_axis_count == 1:
            axis = joint_axis[ia]
            rot = wp.quat_from_axis_angle(axis, joint_q[iq])
        if ang_axis_count == 2:
            rot, _ = compute_2d_rotational_dofs(
                joint_axis[ia + 0],
                joint_axis[ia + 1],
                joint_q[iq + 0],
                joint_q[iq + 1],
                0.0,
                0.0,
            )
        if ang_axis_count == 3:
            rot, _ = compute_3d_rotational_dofs(
                joint_axis[ia + 0],
                joint_axis[ia + 1],
                joint_axis[ia + 2],
                joint_q[iq + 0],
                joint_q[iq + 1],
                joint_q[iq + 2],
                0.0,
                0.0,
                0.0,
            )

        X_jc = wp.transform(pos, rot)
        return X_jc

    # default case
    return wp.transform_identity()
260
-
261
-
262
# compute motion subspace and velocity for a joint
#
# For each degree of freedom k of the joint this:
#   - writes the motion-subspace column S_k (a unit twist mapped through X_sc
#     by transform_twist) to joint_S_s[qd_start + k], and
#   - returns the joint's spatial velocity
#       v_j_s = sum_k S_k * joint_qd[qd_start + k].
# Spatial vectors are laid out as (angular, linear): revolute axes go in the
# top half and prismatic axes in the bottom half.
#
# joint_q is only read by the universal and compound cases, where later axes
# depend on earlier joint angles. lin_axis_count / ang_axis_count are only
# read by the D6 case. A fixed joint returns a zero twist and writes nothing.
# Any other unhandled type prints a message and also returns a zero twist.
@wp.func
def jcalc_motion(
    type: int,
    joint_axis: wp.array(dtype=wp.vec3),
    axis_start: int,
    lin_axis_count: int,
    ang_axis_count: int,
    X_sc: wp.transform,
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    q_start: int,
    qd_start: int,
    # outputs
    joint_S_s: wp.array(dtype=wp.spatial_vector),
):
    # one linear dof along the joint axis
    if type == wp.sim.JOINT_PRISMATIC:
        axis = joint_axis[axis_start]
        S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis))
        v_j_s = S_s * joint_qd[qd_start]
        joint_S_s[qd_start] = S_s
        return v_j_s

    # one angular dof about the joint axis
    if type == wp.sim.JOINT_REVOLUTE:
        axis = joint_axis[axis_start]
        S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3()))
        v_j_s = S_s * joint_qd[qd_start]
        joint_S_s[qd_start] = S_s
        return v_j_s

    # two angular dofs; the second axis is rotated by the first joint angle
    if type == wp.sim.JOINT_UNIVERSAL:
        axis_0 = joint_axis[axis_start + 0]
        axis_1 = joint_axis[axis_start + 1]
        # offset rotation built from the user axes (third column/row completed
        # with their cross product); local_k are the unit axes under q_off.
        # NOTE(review): assumes axis_0/axis_1 are orthonormal, and the result
        # depends on whether wp.mat33(v0, v1, v2) places vectors in rows or
        # columns -- confirm.
        q_off = wp.quat_from_matrix(wp.mat33(axis_0, axis_1, wp.cross(axis_0, axis_1)))
        local_0 = wp.quat_rotate(q_off, wp.vec3(1.0, 0.0, 0.0))
        local_1 = wp.quat_rotate(q_off, wp.vec3(0.0, 1.0, 0.0))

        axis_0 = local_0
        q_0 = wp.quat_from_axis_angle(axis_0, joint_q[q_start + 0])

        axis_1 = wp.quat_rotate(q_0, local_1)

        S_0 = transform_twist(X_sc, wp.spatial_vector(axis_0, wp.vec3()))
        S_1 = transform_twist(X_sc, wp.spatial_vector(axis_1, wp.vec3()))

        joint_S_s[qd_start + 0] = S_0
        joint_S_s[qd_start + 1] = S_1

        return S_0 * joint_qd[qd_start + 0] + S_1 * joint_qd[qd_start + 1]

    # three angular dofs; axis k is rotated by the accumulated rotation of the
    # preceding joint angles (axis_2 uses q_1 * q_0)
    if type == wp.sim.JOINT_COMPOUND:
        axis_0 = joint_axis[axis_start + 0]
        axis_1 = joint_axis[axis_start + 1]
        axis_2 = joint_axis[axis_start + 2]
        # NOTE(review): same row/column and orthonormality assumption as the
        # universal case above
        q_off = wp.quat_from_matrix(wp.mat33(axis_0, axis_1, axis_2))
        local_0 = wp.quat_rotate(q_off, wp.vec3(1.0, 0.0, 0.0))
        local_1 = wp.quat_rotate(q_off, wp.vec3(0.0, 1.0, 0.0))
        local_2 = wp.quat_rotate(q_off, wp.vec3(0.0, 0.0, 1.0))

        axis_0 = local_0
        q_0 = wp.quat_from_axis_angle(axis_0, joint_q[q_start + 0])

        axis_1 = wp.quat_rotate(q_0, local_1)
        q_1 = wp.quat_from_axis_angle(axis_1, joint_q[q_start + 1])

        axis_2 = wp.quat_rotate(q_1 * q_0, local_2)

        S_0 = transform_twist(X_sc, wp.spatial_vector(axis_0, wp.vec3()))
        S_1 = transform_twist(X_sc, wp.spatial_vector(axis_1, wp.vec3()))
        S_2 = transform_twist(X_sc, wp.spatial_vector(axis_2, wp.vec3()))

        joint_S_s[qd_start + 0] = S_0
        joint_S_s[qd_start + 1] = S_1
        joint_S_s[qd_start + 2] = S_2

        return S_0 * joint_qd[qd_start + 0] + S_1 * joint_qd[qd_start + 1] + S_2 * joint_qd[qd_start + 2]

    # generic joint: linear dofs first, then angular dofs offset by
    # lin_axis_count; unrolled rather than looped (see jcalc_transform's D6 note)
    if type == wp.sim.JOINT_D6:
        v_j_s = wp.spatial_vector()
        if lin_axis_count > 0:
            axis = joint_axis[axis_start + 0]
            S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis))
            v_j_s += S_s * joint_qd[qd_start + 0]
            joint_S_s[qd_start + 0] = S_s
        if lin_axis_count > 1:
            axis = joint_axis[axis_start + 1]
            S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis))
            v_j_s += S_s * joint_qd[qd_start + 1]
            joint_S_s[qd_start + 1] = S_s
        if lin_axis_count > 2:
            axis = joint_axis[axis_start + 2]
            S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis))
            v_j_s += S_s * joint_qd[qd_start + 2]
            joint_S_s[qd_start + 2] = S_s
        if ang_axis_count > 0:
            axis = joint_axis[axis_start + lin_axis_count + 0]
            S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3()))
            v_j_s += S_s * joint_qd[qd_start + lin_axis_count + 0]
            joint_S_s[qd_start + lin_axis_count + 0] = S_s
        if ang_axis_count > 1:
            axis = joint_axis[axis_start + lin_axis_count + 1]
            S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3()))
            v_j_s += S_s * joint_qd[qd_start + lin_axis_count + 1]
            joint_S_s[qd_start + lin_axis_count + 1] = S_s
        if ang_axis_count > 2:
            axis = joint_axis[axis_start + lin_axis_count + 2]
            S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3()))
            v_j_s += S_s * joint_qd[qd_start + lin_axis_count + 2]
            joint_S_s[qd_start + lin_axis_count + 2] = S_s

        return v_j_s

    # three angular dofs about the unit x, y, z axes (joint_axis is not read)
    if type == wp.sim.JOINT_BALL:
        S_0 = transform_twist(X_sc, wp.spatial_vector(1.0, 0.0, 0.0, 0.0, 0.0, 0.0))
        S_1 = transform_twist(X_sc, wp.spatial_vector(0.0, 1.0, 0.0, 0.0, 0.0, 0.0))
        S_2 = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 1.0, 0.0, 0.0, 0.0))

        joint_S_s[qd_start + 0] = S_0
        joint_S_s[qd_start + 1] = S_1
        joint_S_s[qd_start + 2] = S_2

        return S_0 * joint_qd[qd_start + 0] + S_1 * joint_qd[qd_start + 1] + S_2 * joint_qd[qd_start + 2]

    # no dofs: nothing written to joint_S_s
    if type == wp.sim.JOINT_FIXED:
        return wp.spatial_vector()

    # six dofs: the six joint velocities form the local twist directly, and the
    # subspace columns are the six unit twists mapped through X_sc
    if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
        v_j_s = transform_twist(
            X_sc,
            wp.spatial_vector(
                joint_qd[qd_start + 0],
                joint_qd[qd_start + 1],
                joint_qd[qd_start + 2],
                joint_qd[qd_start + 3],
                joint_qd[qd_start + 4],
                joint_qd[qd_start + 5],
            ),
        )

        joint_S_s[qd_start + 0] = transform_twist(X_sc, wp.spatial_vector(1.0, 0.0, 0.0, 0.0, 0.0, 0.0))
        joint_S_s[qd_start + 1] = transform_twist(X_sc, wp.spatial_vector(0.0, 1.0, 0.0, 0.0, 0.0, 0.0))
        joint_S_s[qd_start + 2] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 1.0, 0.0, 0.0, 0.0))
        joint_S_s[qd_start + 3] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 1.0, 0.0, 0.0))
        joint_S_s[qd_start + 4] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 1.0, 0.0))
        joint_S_s[qd_start + 5] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 0.0, 1.0))

        return v_j_s

    wp.printf("jcalc_motion not implemented for joint type %d\n", type)

    # default case
    return wp.spatial_vector()
414
-
415
-
416
# computes joint space forces/torques in tau
#
# For each degree of freedom k of the joint, writes
#   tau[dof_start + k] = -dot(S_k, body_f_s) [+ eval_joint_force(...)]
# where S_k = joint_S_s[dof_start + k] is the motion-subspace column and
# body_f_s is the spatial force supplied by the caller.
#
# Indexing conventions used below:
#   - coordinates by coord_start + k, velocities by dof_start + k;
#   - actuation, gains, limits and axis mode by axis_start + k.
#
# The eval_joint_force term is only added for prismatic, revolute, compound,
# universal and D6 joints. Judging by its arguments it combines target drive,
# actuation and limit terms -- see integrator_euler for its exact definition.
# Ball and free/distance joints receive only the projected body force. No
# branch handles JOINT_FIXED (or unknown types), so tau is left untouched.
@wp.func
def jcalc_tau(
    type: int,
    joint_target_ke: wp.array(dtype=float),
    joint_target_kd: wp.array(dtype=float),
    joint_limit_ke: wp.array(dtype=float),
    joint_limit_kd: wp.array(dtype=float),
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_act: wp.array(dtype=float),
    joint_axis_mode: wp.array(dtype=int),
    joint_limit_lower: wp.array(dtype=float),
    joint_limit_upper: wp.array(dtype=float),
    coord_start: int,
    dof_start: int,
    axis_start: int,
    lin_axis_count: int,
    ang_axis_count: int,
    body_f_s: wp.spatial_vector,
    # outputs
    tau: wp.array(dtype=float),
):
    # single-dof joints: projected body force plus the joint's own force term
    if type == wp.sim.JOINT_PRISMATIC or type == wp.sim.JOINT_REVOLUTE:
        S_s = joint_S_s[dof_start]

        q = joint_q[coord_start]
        qd = joint_qd[dof_start]
        act = joint_act[axis_start]

        lower = joint_limit_lower[axis_start]
        upper = joint_limit_upper[axis_start]

        limit_ke = joint_limit_ke[axis_start]
        limit_kd = joint_limit_kd[axis_start]
        target_ke = joint_target_ke[axis_start]
        target_kd = joint_target_kd[axis_start]
        mode = joint_axis_mode[axis_start]

        # total torque / force on the joint
        t = -wp.dot(S_s, body_f_s) + eval_joint_force(
            q, qd, act, target_ke, target_kd, lower, upper, limit_ke, limit_kd, mode
        )

        tau[dof_start] = t

        return

    # three angular dofs: only the projected body force is applied.
    # NOTE(review): the commented-out lines below sketch a target drive that is
    # currently disabled.
    if type == wp.sim.JOINT_BALL:
        # target_ke = joint_target_ke[axis_start]
        # target_kd = joint_target_kd[axis_start]

        for i in range(3):
            S_s = joint_S_s[dof_start + i]

            # w = joint_qd[dof_start + i]
            # r = joint_q[coord_start + i]

            tau[dof_start + i] = -wp.dot(S_s, body_f_s)  # - w * target_kd - r * target_ke

        return

    # six dofs: only the projected body force is applied
    if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
        for i in range(6):
            S_s = joint_S_s[dof_start + i]
            tau[dof_start + i] = -wp.dot(S_s, body_f_s)

        return

    # multi-axis joints: one dof per axis (linear axes first, then angular),
    # each with its own per-axis force term
    if type == wp.sim.JOINT_COMPOUND or type == wp.sim.JOINT_UNIVERSAL or type == wp.sim.JOINT_D6:
        axis_count = lin_axis_count + ang_axis_count

        for i in range(axis_count):
            S_s = joint_S_s[dof_start + i]

            q = joint_q[coord_start + i]
            qd = joint_qd[dof_start + i]
            act = joint_act[axis_start + i]

            lower = joint_limit_lower[axis_start + i]
            upper = joint_limit_upper[axis_start + i]
            limit_ke = joint_limit_ke[axis_start + i]
            limit_kd = joint_limit_kd[axis_start + i]
            target_ke = joint_target_ke[axis_start + i]
            target_kd = joint_target_kd[axis_start + i]
            mode = joint_axis_mode[axis_start + i]

            f = eval_joint_force(q, qd, act, target_ke, target_kd, lower, upper, limit_ke, limit_kd, mode)

            # total torque / force on the joint
            t = -wp.dot(S_s, body_f_s) + f

            tau[dof_start + i] = t

        return
512
-
513
-
514
- @wp.func
515
- def jcalc_integrate(
516
- type: int,
517
- joint_q: wp.array(dtype=float),
518
- joint_qd: wp.array(dtype=float),
519
- joint_qdd: wp.array(dtype=float),
520
- coord_start: int,
521
- dof_start: int,
522
- lin_axis_count: int,
523
- ang_axis_count: int,
524
- dt: float,
525
- # outputs
526
- joint_q_new: wp.array(dtype=float),
527
- joint_qd_new: wp.array(dtype=float),
528
- ):
529
- if type == wp.sim.JOINT_FIXED:
530
- return
531
-
532
- # prismatic / revolute
533
- if type == wp.sim.JOINT_PRISMATIC or type == wp.sim.JOINT_REVOLUTE:
534
- qdd = joint_qdd[dof_start]
535
- qd = joint_qd[dof_start]
536
- q = joint_q[coord_start]
537
-
538
- qd_new = qd + qdd * dt
539
- q_new = q + qd_new * dt
540
-
541
- joint_qd_new[dof_start] = qd_new
542
- joint_q_new[coord_start] = q_new
543
-
544
- return
545
-
546
- # ball
547
- if type == wp.sim.JOINT_BALL:
548
- m_j = wp.vec3(joint_qdd[dof_start + 0], joint_qdd[dof_start + 1], joint_qdd[dof_start + 2])
549
- w_j = wp.vec3(joint_qd[dof_start + 0], joint_qd[dof_start + 1], joint_qd[dof_start + 2])
550
-
551
- r_j = wp.quat(
552
- joint_q[coord_start + 0], joint_q[coord_start + 1], joint_q[coord_start + 2], joint_q[coord_start + 3]
553
- )
554
-
555
- # symplectic Euler
556
- w_j_new = w_j + m_j * dt
557
-
558
- drdt_j = wp.quat(w_j_new, 0.0) * r_j * 0.5
559
-
560
- # new orientation (normalized)
561
- r_j_new = wp.normalize(r_j + drdt_j * dt)
562
-
563
- # update joint coords
564
- joint_q_new[coord_start + 0] = r_j_new[0]
565
- joint_q_new[coord_start + 1] = r_j_new[1]
566
- joint_q_new[coord_start + 2] = r_j_new[2]
567
- joint_q_new[coord_start + 3] = r_j_new[3]
568
-
569
- # update joint vel
570
- joint_qd_new[dof_start + 0] = w_j_new[0]
571
- joint_qd_new[dof_start + 1] = w_j_new[1]
572
- joint_qd_new[dof_start + 2] = w_j_new[2]
573
-
574
- return
575
-
576
- # free joint
577
- if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
578
- # dofs: qd = (omega_x, omega_y, omega_z, vel_x, vel_y, vel_z)
579
- # coords: q = (trans_x, trans_y, trans_z, quat_x, quat_y, quat_z, quat_w)
580
-
581
- # angular and linear acceleration
582
- m_s = wp.vec3(joint_qdd[dof_start + 0], joint_qdd[dof_start + 1], joint_qdd[dof_start + 2])
583
- a_s = wp.vec3(joint_qdd[dof_start + 3], joint_qdd[dof_start + 4], joint_qdd[dof_start + 5])
584
-
585
- # angular and linear velocity
586
- w_s = wp.vec3(joint_qd[dof_start + 0], joint_qd[dof_start + 1], joint_qd[dof_start + 2])
587
- v_s = wp.vec3(joint_qd[dof_start + 3], joint_qd[dof_start + 4], joint_qd[dof_start + 5])
588
-
589
- # symplectic Euler
590
- w_s = w_s + m_s * dt
591
- v_s = v_s + a_s * dt
592
-
593
- # translation of origin
594
- p_s = wp.vec3(joint_q[coord_start + 0], joint_q[coord_start + 1], joint_q[coord_start + 2])
595
-
596
- # linear vel of origin (note q/qd switch order of linear angular elements)
597
- # note we are converting the body twist in the space frame (w_s, v_s) to compute center of mass velcity
598
- dpdt_s = v_s + wp.cross(w_s, p_s)
599
-
600
- # quat and quat derivative
601
- r_s = wp.quat(
602
- joint_q[coord_start + 3], joint_q[coord_start + 4], joint_q[coord_start + 5], joint_q[coord_start + 6]
603
- )
604
-
605
- drdt_s = wp.quat(w_s, 0.0) * r_s * 0.5
606
-
607
- # new orientation (normalized)
608
- p_s_new = p_s + dpdt_s * dt
609
- r_s_new = wp.normalize(r_s + drdt_s * dt)
610
-
611
- # update transform
612
- joint_q_new[coord_start + 0] = p_s_new[0]
613
- joint_q_new[coord_start + 1] = p_s_new[1]
614
- joint_q_new[coord_start + 2] = p_s_new[2]
615
-
616
- joint_q_new[coord_start + 3] = r_s_new[0]
617
- joint_q_new[coord_start + 4] = r_s_new[1]
618
- joint_q_new[coord_start + 5] = r_s_new[2]
619
- joint_q_new[coord_start + 6] = r_s_new[3]
620
-
621
- # update joint_twist
622
- joint_qd_new[dof_start + 0] = w_s[0]
623
- joint_qd_new[dof_start + 1] = w_s[1]
624
- joint_qd_new[dof_start + 2] = w_s[2]
625
- joint_qd_new[dof_start + 3] = v_s[0]
626
- joint_qd_new[dof_start + 4] = v_s[1]
627
- joint_qd_new[dof_start + 5] = v_s[2]
628
-
629
- return
630
-
631
- # other joint types (compound, universal, D6)
632
- if type == wp.sim.JOINT_COMPOUND or type == wp.sim.JOINT_UNIVERSAL or type == wp.sim.JOINT_D6:
633
- axis_count = lin_axis_count + ang_axis_count
634
-
635
- for i in range(axis_count):
636
- qdd = joint_qdd[dof_start + i]
637
- qd = joint_qd[dof_start + i]
638
- q = joint_q[coord_start + i]
639
-
640
- qd_new = qd + qdd * dt
641
- q_new = q + qd_new * dt
642
-
643
- joint_qd_new[dof_start + i] = qd_new
644
- joint_q_new[coord_start + i] = q_new
645
-
646
- return
647
-
648
-
649
- @wp.func
650
- def compute_link_transform(
651
- i: int,
652
- joint_type: wp.array(dtype=int),
653
- joint_parent: wp.array(dtype=int),
654
- joint_child: wp.array(dtype=int),
655
- joint_q_start: wp.array(dtype=int),
656
- joint_q: wp.array(dtype=float),
657
- joint_X_p: wp.array(dtype=wp.transform),
658
- joint_X_c: wp.array(dtype=wp.transform),
659
- body_X_com: wp.array(dtype=wp.transform),
660
- joint_axis: wp.array(dtype=wp.vec3),
661
- joint_axis_start: wp.array(dtype=int),
662
- joint_axis_dim: wp.array(dtype=int, ndim=2),
663
- # outputs
664
- body_q: wp.array(dtype=wp.transform),
665
- body_q_com: wp.array(dtype=wp.transform),
666
- ):
667
- # parent transform
668
- parent = joint_parent[i]
669
- child = joint_child[i]
670
-
671
- # parent transform in spatial coordinates
672
- X_pj = joint_X_p[i]
673
- X_cj = joint_X_c[i]
674
- # parent anchor frame in world space
675
- X_wpj = X_pj
676
- if parent >= 0:
677
- X_wp = body_q[parent]
678
- X_wpj = X_wp * X_wpj
679
-
680
- type = joint_type[i]
681
- axis_start = joint_axis_start[i]
682
- lin_axis_count = joint_axis_dim[i, 0]
683
- ang_axis_count = joint_axis_dim[i, 1]
684
- coord_start = joint_q_start[i]
685
-
686
- # compute transform across joint
687
- X_j = jcalc_transform(type, joint_axis, axis_start, lin_axis_count, ang_axis_count, joint_q, coord_start)
688
-
689
- # transform from world to joint anchor frame at child body
690
- X_wcj = X_wpj * X_j
691
- # transform from world to child body frame
692
- X_wc = X_wcj * wp.transform_inverse(X_cj)
693
-
694
- # compute transform of center of mass
695
- X_cm = body_X_com[child]
696
- X_sm = X_wc * X_cm
697
-
698
- # store geometry transforms
699
- body_q[child] = X_wc
700
- body_q_com[child] = X_sm
701
-
702
-
703
- @wp.kernel
704
- def eval_rigid_fk(
705
- articulation_start: wp.array(dtype=int),
706
- joint_type: wp.array(dtype=int),
707
- joint_parent: wp.array(dtype=int),
708
- joint_child: wp.array(dtype=int),
709
- joint_q_start: wp.array(dtype=int),
710
- joint_q: wp.array(dtype=float),
711
- joint_X_p: wp.array(dtype=wp.transform),
712
- joint_X_c: wp.array(dtype=wp.transform),
713
- body_X_com: wp.array(dtype=wp.transform),
714
- joint_axis: wp.array(dtype=wp.vec3),
715
- joint_axis_start: wp.array(dtype=int),
716
- joint_axis_dim: wp.array(dtype=int, ndim=2),
717
- # outputs
718
- body_q: wp.array(dtype=wp.transform),
719
- body_q_com: wp.array(dtype=wp.transform),
720
- ):
721
- # one thread per-articulation
722
- index = wp.tid()
723
-
724
- start = articulation_start[index]
725
- end = articulation_start[index + 1]
726
-
727
- for i in range(start, end):
728
- compute_link_transform(
729
- i,
730
- joint_type,
731
- joint_parent,
732
- joint_child,
733
- joint_q_start,
734
- joint_q,
735
- joint_X_p,
736
- joint_X_c,
737
- body_X_com,
738
- joint_axis,
739
- joint_axis_start,
740
- joint_axis_dim,
741
- body_q,
742
- body_q_com,
743
- )
744
-
745
-
746
- @wp.func
747
- def spatial_cross(a: wp.spatial_vector, b: wp.spatial_vector):
748
- w_a = wp.spatial_top(a)
749
- v_a = wp.spatial_bottom(a)
750
-
751
- w_b = wp.spatial_top(b)
752
- v_b = wp.spatial_bottom(b)
753
-
754
- w = wp.cross(w_a, w_b)
755
- v = wp.cross(w_a, v_b) + wp.cross(v_a, w_b)
756
-
757
- return wp.spatial_vector(w, v)
758
-
759
-
760
- @wp.func
761
- def spatial_cross_dual(a: wp.spatial_vector, b: wp.spatial_vector):
762
- w_a = wp.spatial_top(a)
763
- v_a = wp.spatial_bottom(a)
764
-
765
- w_b = wp.spatial_top(b)
766
- v_b = wp.spatial_bottom(b)
767
-
768
- w = wp.cross(w_a, w_b) + wp.cross(v_a, v_b)
769
- v = wp.cross(w_a, v_b)
770
-
771
- return wp.spatial_vector(w, v)
772
-
773
-
774
- @wp.func
775
- def dense_index(stride: int, i: int, j: int):
776
- return i * stride + j
777
-
778
-
779
- @wp.func
780
- def compute_link_velocity(
781
- i: int,
782
- joint_type: wp.array(dtype=int),
783
- joint_parent: wp.array(dtype=int),
784
- joint_child: wp.array(dtype=int),
785
- joint_q_start: wp.array(dtype=int),
786
- joint_qd_start: wp.array(dtype=int),
787
- joint_q: wp.array(dtype=float),
788
- joint_qd: wp.array(dtype=float),
789
- joint_axis: wp.array(dtype=wp.vec3),
790
- joint_axis_start: wp.array(dtype=int),
791
- joint_axis_dim: wp.array(dtype=int, ndim=2),
792
- body_I_m: wp.array(dtype=wp.spatial_matrix),
793
- body_q: wp.array(dtype=wp.transform),
794
- body_q_com: wp.array(dtype=wp.transform),
795
- joint_X_p: wp.array(dtype=wp.transform),
796
- joint_X_c: wp.array(dtype=wp.transform),
797
- gravity: wp.vec3,
798
- # outputs
799
- joint_S_s: wp.array(dtype=wp.spatial_vector),
800
- body_I_s: wp.array(dtype=wp.spatial_matrix),
801
- body_v_s: wp.array(dtype=wp.spatial_vector),
802
- body_f_s: wp.array(dtype=wp.spatial_vector),
803
- body_a_s: wp.array(dtype=wp.spatial_vector),
804
- ):
805
- type = joint_type[i]
806
- child = joint_child[i]
807
- parent = joint_parent[i]
808
- q_start = joint_q_start[i]
809
- qd_start = joint_qd_start[i]
810
-
811
- X_pj = joint_X_p[i]
812
- X_cj = joint_X_c[i]
813
-
814
- # parent anchor frame in world space
815
- X_wpj = X_pj
816
- if parent >= 0:
817
- X_wp = body_q[parent]
818
- X_wpj = X_wp * X_wpj
819
-
820
- # compute motion subspace and velocity across the joint (also stores S_s to global memory)
821
- axis_start = joint_axis_start[i]
822
- lin_axis_count = joint_axis_dim[i, 0]
823
- ang_axis_count = joint_axis_dim[i, 1]
824
- v_j_s = jcalc_motion(
825
- type,
826
- joint_axis,
827
- axis_start,
828
- lin_axis_count,
829
- ang_axis_count,
830
- X_wpj,
831
- joint_q,
832
- joint_qd,
833
- q_start,
834
- qd_start,
835
- joint_S_s,
836
- )
837
-
838
- # parent velocity
839
- v_parent_s = wp.spatial_vector()
840
- a_parent_s = wp.spatial_vector()
841
-
842
- if parent >= 0:
843
- v_parent_s = body_v_s[parent]
844
- a_parent_s = body_a_s[parent]
845
-
846
- # body velocity, acceleration
847
- v_s = v_parent_s + v_j_s
848
- a_s = a_parent_s + spatial_cross(v_s, v_j_s) # + joint_S_s[i]*self.joint_qdd[i]
849
-
850
- # compute body forces
851
- X_sm = body_q_com[child]
852
- I_m = body_I_m[child]
853
-
854
- # gravity and external forces (expressed in frame aligned with s but centered at body mass)
855
- m = I_m[3, 3]
856
-
857
- f_g = m * gravity
858
- r_com = wp.transform_get_translation(X_sm)
859
- f_g_s = wp.spatial_vector(wp.cross(r_com, f_g), f_g)
860
-
861
- # body forces
862
- I_s = spatial_transform_inertia(X_sm, I_m)
863
-
864
- f_b_s = I_s * a_s + spatial_cross_dual(v_s, I_s * v_s)
865
-
866
- body_v_s[child] = v_s
867
- body_a_s[child] = a_s
868
- body_f_s[child] = f_b_s - f_g_s
869
- body_I_s[child] = I_s
870
-
871
-
872
- # Inverse dynamics via Recursive Newton-Euler algorithm (Featherstone Table 5.1)
873
- @wp.kernel
874
- def eval_rigid_id(
875
- articulation_start: wp.array(dtype=int),
876
- joint_type: wp.array(dtype=int),
877
- joint_parent: wp.array(dtype=int),
878
- joint_child: wp.array(dtype=int),
879
- joint_q_start: wp.array(dtype=int),
880
- joint_qd_start: wp.array(dtype=int),
881
- joint_q: wp.array(dtype=float),
882
- joint_qd: wp.array(dtype=float),
883
- joint_axis: wp.array(dtype=wp.vec3),
884
- joint_axis_start: wp.array(dtype=int),
885
- joint_axis_dim: wp.array(dtype=int, ndim=2),
886
- body_I_m: wp.array(dtype=wp.spatial_matrix),
887
- body_q: wp.array(dtype=wp.transform),
888
- body_q_com: wp.array(dtype=wp.transform),
889
- joint_X_p: wp.array(dtype=wp.transform),
890
- joint_X_c: wp.array(dtype=wp.transform),
891
- gravity: wp.vec3,
892
- # outputs
893
- joint_S_s: wp.array(dtype=wp.spatial_vector),
894
- body_I_s: wp.array(dtype=wp.spatial_matrix),
895
- body_v_s: wp.array(dtype=wp.spatial_vector),
896
- body_f_s: wp.array(dtype=wp.spatial_vector),
897
- body_a_s: wp.array(dtype=wp.spatial_vector),
898
- ):
899
- # one thread per-articulation
900
- index = wp.tid()
901
-
902
- start = articulation_start[index]
903
- end = articulation_start[index + 1]
904
-
905
- # compute link velocities and coriolis forces
906
- for i in range(start, end):
907
- compute_link_velocity(
908
- i,
909
- joint_type,
910
- joint_parent,
911
- joint_child,
912
- joint_q_start,
913
- joint_qd_start,
914
- joint_q,
915
- joint_qd,
916
- joint_axis,
917
- joint_axis_start,
918
- joint_axis_dim,
919
- body_I_m,
920
- body_q,
921
- body_q_com,
922
- joint_X_p,
923
- joint_X_c,
924
- gravity,
925
- joint_S_s,
926
- body_I_s,
927
- body_v_s,
928
- body_f_s,
929
- body_a_s,
930
- )
931
-
932
-
933
- @wp.kernel
934
- def eval_rigid_tau(
935
- articulation_start: wp.array(dtype=int),
936
- joint_type: wp.array(dtype=int),
937
- joint_parent: wp.array(dtype=int),
938
- joint_child: wp.array(dtype=int),
939
- joint_q_start: wp.array(dtype=int),
940
- joint_qd_start: wp.array(dtype=int),
941
- joint_axis_start: wp.array(dtype=int),
942
- joint_axis_dim: wp.array(dtype=int, ndim=2),
943
- joint_axis_mode: wp.array(dtype=int),
944
- joint_q: wp.array(dtype=float),
945
- joint_qd: wp.array(dtype=float),
946
- joint_act: wp.array(dtype=float),
947
- joint_target_ke: wp.array(dtype=float),
948
- joint_target_kd: wp.array(dtype=float),
949
- joint_limit_lower: wp.array(dtype=float),
950
- joint_limit_upper: wp.array(dtype=float),
951
- joint_limit_ke: wp.array(dtype=float),
952
- joint_limit_kd: wp.array(dtype=float),
953
- joint_S_s: wp.array(dtype=wp.spatial_vector),
954
- body_fb_s: wp.array(dtype=wp.spatial_vector),
955
- body_f_ext: wp.array(dtype=wp.spatial_vector),
956
- # outputs
957
- body_ft_s: wp.array(dtype=wp.spatial_vector),
958
- tau: wp.array(dtype=float),
959
- ):
960
- # one thread per-articulation
961
- index = wp.tid()
962
-
963
- start = articulation_start[index]
964
- end = articulation_start[index + 1]
965
- count = end - start
966
-
967
- # compute joint forces
968
- for offset in range(count):
969
- # for backwards traversal
970
- i = end - offset - 1
971
-
972
- type = joint_type[i]
973
- parent = joint_parent[i]
974
- child = joint_child[i]
975
- dof_start = joint_qd_start[i]
976
- coord_start = joint_q_start[i]
977
- axis_start = joint_axis_start[i]
978
- lin_axis_count = joint_axis_dim[i, 0]
979
- ang_axis_count = joint_axis_dim[i, 1]
980
-
981
- # total forces on body
982
- f_b_s = body_fb_s[child]
983
- f_t_s = body_ft_s[child]
984
- f_ext = body_f_ext[child]
985
- f_s = f_b_s + f_t_s + f_ext
986
-
987
- # compute joint-space forces, writes out tau
988
- jcalc_tau(
989
- type,
990
- joint_target_ke,
991
- joint_target_kd,
992
- joint_limit_ke,
993
- joint_limit_kd,
994
- joint_S_s,
995
- joint_q,
996
- joint_qd,
997
- joint_act,
998
- joint_axis_mode,
999
- joint_limit_lower,
1000
- joint_limit_upper,
1001
- coord_start,
1002
- dof_start,
1003
- axis_start,
1004
- lin_axis_count,
1005
- ang_axis_count,
1006
- f_s,
1007
- tau,
1008
- )
1009
-
1010
- # update parent forces, todo: check that this is valid for the backwards pass
1011
- if parent >= 0:
1012
- wp.atomic_add(body_ft_s, parent, f_s)
1013
-
1014
-
1015
- # builds spatial Jacobian J which is an (joint_count*6)x(dof_count) matrix
1016
- @wp.kernel
1017
- def eval_rigid_jacobian(
1018
- articulation_start: wp.array(dtype=int),
1019
- articulation_J_start: wp.array(dtype=int),
1020
- joint_parent: wp.array(dtype=int),
1021
- joint_qd_start: wp.array(dtype=int),
1022
- joint_S_s: wp.array(dtype=wp.spatial_vector),
1023
- # outputs
1024
- J: wp.array(dtype=float),
1025
- ):
1026
- # one thread per-articulation
1027
- index = wp.tid()
1028
-
1029
- joint_start = articulation_start[index]
1030
- joint_end = articulation_start[index + 1]
1031
- joint_count = joint_end - joint_start
1032
-
1033
- J_offset = articulation_J_start[index]
1034
-
1035
- articulation_dof_start = joint_qd_start[joint_start]
1036
- articulation_dof_end = joint_qd_start[joint_end]
1037
- articulation_dof_count = articulation_dof_end - articulation_dof_start
1038
-
1039
- for i in range(joint_count):
1040
- row_start = i * 6
1041
-
1042
- j = joint_start + i
1043
- while j != -1:
1044
- joint_dof_start = joint_qd_start[j]
1045
- joint_dof_end = joint_qd_start[j + 1]
1046
- joint_dof_count = joint_dof_end - joint_dof_start
1047
-
1048
- # fill out each row of the Jacobian walking up the tree
1049
- for dof in range(joint_dof_count):
1050
- col = (joint_dof_start - articulation_dof_start) + dof
1051
- S = joint_S_s[joint_dof_start + dof]
1052
-
1053
- for k in range(6):
1054
- J[J_offset + dense_index(articulation_dof_count, row_start + k, col)] = S[k]
1055
-
1056
- j = joint_parent[j]
1057
-
1058
-
1059
- @wp.func
1060
- def spatial_mass(
1061
- body_I_s: wp.array(dtype=wp.spatial_matrix),
1062
- joint_start: int,
1063
- joint_count: int,
1064
- M_start: int,
1065
- # outputs
1066
- M: wp.array(dtype=float),
1067
- ):
1068
- stride = joint_count * 6
1069
- for l in range(joint_count):
1070
- I = body_I_s[joint_start + l]
1071
- for i in range(6):
1072
- for j in range(6):
1073
- M[M_start + dense_index(stride, l * 6 + i, l * 6 + j)] = I[i, j]
1074
-
1075
-
1076
- @wp.kernel
1077
- def eval_rigid_mass(
1078
- articulation_start: wp.array(dtype=int),
1079
- articulation_M_start: wp.array(dtype=int),
1080
- body_I_s: wp.array(dtype=wp.spatial_matrix),
1081
- # outputs
1082
- M: wp.array(dtype=float),
1083
- ):
1084
- # one thread per-articulation
1085
- index = wp.tid()
1086
-
1087
- joint_start = articulation_start[index]
1088
- joint_end = articulation_start[index + 1]
1089
- joint_count = joint_end - joint_start
1090
-
1091
- M_offset = articulation_M_start[index]
1092
-
1093
- spatial_mass(body_I_s, joint_start, joint_count, M_offset, M)
1094
-
1095
-
1096
- @wp.func
1097
- def dense_gemm(
1098
- m: int,
1099
- n: int,
1100
- p: int,
1101
- transpose_A: bool,
1102
- transpose_B: bool,
1103
- add_to_C: bool,
1104
- A_start: int,
1105
- B_start: int,
1106
- C_start: int,
1107
- A: wp.array(dtype=float),
1108
- B: wp.array(dtype=float),
1109
- # outputs
1110
- C: wp.array(dtype=float),
1111
- ):
1112
- # multiply a `m x p` matrix A by a `p x n` matrix B to produce a `m x n` matrix C
1113
- for i in range(m):
1114
- for j in range(n):
1115
- sum = float(0.0)
1116
- for k in range(p):
1117
- if transpose_A:
1118
- a_i = k * m + i
1119
- else:
1120
- a_i = i * p + k
1121
- if transpose_B:
1122
- b_j = j * p + k
1123
- else:
1124
- b_j = k * n + j
1125
- sum += A[A_start + a_i] * B[B_start + b_j]
1126
-
1127
- if add_to_C:
1128
- C[C_start + i * n + j] += sum
1129
- else:
1130
- C[C_start + i * n + j] = sum
1131
-
1132
-
1133
- @wp.func_grad(dense_gemm)
1134
- def adj_dense_gemm(
1135
- m: int,
1136
- n: int,
1137
- p: int,
1138
- transpose_A: bool,
1139
- transpose_B: bool,
1140
- add_to_C: bool,
1141
- A_start: int,
1142
- B_start: int,
1143
- C_start: int,
1144
- A: wp.array(dtype=float),
1145
- B: wp.array(dtype=float),
1146
- # outputs
1147
- C: wp.array(dtype=float),
1148
- ):
1149
- add_to_C = True
1150
- if transpose_A:
1151
- dense_gemm(p, m, n, False, True, add_to_C, A_start, B_start, C_start, B, wp.adjoint[C], wp.adjoint[A])
1152
- dense_gemm(p, n, m, False, False, add_to_C, A_start, B_start, C_start, A, wp.adjoint[C], wp.adjoint[B])
1153
- else:
1154
- dense_gemm(
1155
- m, p, n, False, not transpose_B, add_to_C, A_start, B_start, C_start, wp.adjoint[C], B, wp.adjoint[A]
1156
- )
1157
- dense_gemm(p, n, m, True, False, add_to_C, A_start, B_start, C_start, A, wp.adjoint[C], wp.adjoint[B])
1158
-
1159
-
1160
- @wp.kernel
1161
- def eval_dense_gemm_batched(
1162
- m: wp.array(dtype=int),
1163
- n: wp.array(dtype=int),
1164
- p: wp.array(dtype=int),
1165
- transpose_A: bool,
1166
- transpose_B: bool,
1167
- A_start: wp.array(dtype=int),
1168
- B_start: wp.array(dtype=int),
1169
- C_start: wp.array(dtype=int),
1170
- A: wp.array(dtype=float),
1171
- B: wp.array(dtype=float),
1172
- C: wp.array(dtype=float),
1173
- ):
1174
- # on the CPU each thread computes the whole matrix multiply
1175
- # on the GPU each block computes the multiply with one output per-thread
1176
- batch = wp.tid() # /kNumThreadsPerBlock;
1177
- add_to_C = False
1178
-
1179
- dense_gemm(
1180
- m[batch],
1181
- n[batch],
1182
- p[batch],
1183
- transpose_A,
1184
- transpose_B,
1185
- add_to_C,
1186
- A_start[batch],
1187
- B_start[batch],
1188
- C_start[batch],
1189
- A,
1190
- B,
1191
- C,
1192
- )
1193
-
1194
-
1195
- @wp.func
1196
- def dense_cholesky(
1197
- n: int,
1198
- A: wp.array(dtype=float),
1199
- R: wp.array(dtype=float),
1200
- A_start: int,
1201
- R_start: int,
1202
- # outputs
1203
- L: wp.array(dtype=float),
1204
- ):
1205
- # compute the Cholesky factorization of A = L L^T with diagonal regularization R
1206
- for j in range(n):
1207
- s = A[A_start + dense_index(n, j, j)] + R[R_start + j]
1208
-
1209
- for k in range(j):
1210
- r = L[A_start + dense_index(n, j, k)]
1211
- s -= r * r
1212
-
1213
- s = wp.sqrt(s)
1214
- invS = 1.0 / s
1215
-
1216
- L[A_start + dense_index(n, j, j)] = s
1217
-
1218
- for i in range(j + 1, n):
1219
- s = A[A_start + dense_index(n, i, j)]
1220
-
1221
- for k in range(j):
1222
- s -= L[A_start + dense_index(n, i, k)] * L[A_start + dense_index(n, j, k)]
1223
-
1224
- L[A_start + dense_index(n, i, j)] = s * invS
1225
-
1226
-
1227
- @wp.func_grad(dense_cholesky)
1228
- def adj_dense_cholesky(
1229
- n: int,
1230
- A: wp.array(dtype=float),
1231
- R: wp.array(dtype=float),
1232
- A_start: int,
1233
- R_start: int,
1234
- # outputs
1235
- L: wp.array(dtype=float),
1236
- ):
1237
- # nop, use dense_solve to differentiate through (A^-1)b = x
1238
- pass
1239
-
1240
-
1241
- @wp.kernel
1242
- def eval_dense_cholesky_batched(
1243
- A_starts: wp.array(dtype=int),
1244
- A_dim: wp.array(dtype=int),
1245
- A: wp.array(dtype=float),
1246
- R: wp.array(dtype=float),
1247
- L: wp.array(dtype=float),
1248
- ):
1249
- batch = wp.tid()
1250
-
1251
- n = A_dim[batch]
1252
- A_start = A_starts[batch]
1253
- R_start = n * batch
1254
-
1255
- dense_cholesky(n, A, R, A_start, R_start, L)
1256
-
1257
-
1258
- @wp.func
1259
- def dense_subs(
1260
- n: int,
1261
- L_start: int,
1262
- b_start: int,
1263
- L: wp.array(dtype=float),
1264
- b: wp.array(dtype=float),
1265
- # outputs
1266
- x: wp.array(dtype=float),
1267
- ):
1268
- # Solves (L L^T) x = b for x given the Cholesky factor L
1269
- # forward substitution solves the lower triangular system L y = b for y
1270
- for i in range(n):
1271
- s = b[b_start + i]
1272
-
1273
- for j in range(i):
1274
- s -= L[L_start + dense_index(n, i, j)] * x[b_start + j]
1275
-
1276
- x[b_start + i] = s / L[L_start + dense_index(n, i, i)]
1277
-
1278
- # backward substitution solves the upper triangular system L^T x = y for x
1279
- for i in range(n - 1, -1, -1):
1280
- s = x[b_start + i]
1281
-
1282
- for j in range(i + 1, n):
1283
- s -= L[L_start + dense_index(n, j, i)] * x[b_start + j]
1284
-
1285
- x[b_start + i] = s / L[L_start + dense_index(n, i, i)]
1286
-
1287
-
1288
- @wp.func
1289
- def dense_solve(
1290
- n: int,
1291
- L_start: int,
1292
- b_start: int,
1293
- L: wp.array(dtype=float),
1294
- b: wp.array(dtype=float),
1295
- # outputs
1296
- x: wp.array(dtype=float),
1297
- tmp: wp.array(dtype=float),
1298
- ):
1299
- # helper function to include tmp argument for backward pass
1300
- dense_subs(n, L_start, b_start, L, b, x)
1301
-
1302
-
1303
- @wp.func_grad(dense_solve)
1304
- def adj_dense_solve(
1305
- n: int,
1306
- L_start: int,
1307
- b_start: int,
1308
- L: wp.array(dtype=float),
1309
- b: wp.array(dtype=float),
1310
- # outputs
1311
- x: wp.array(dtype=float),
1312
- tmp: wp.array(dtype=float),
1313
- ):
1314
- for i in range(n):
1315
- tmp[b_start + i] = 0.0
1316
-
1317
- dense_subs(n, L_start, b_start, L, wp.adjoint[x], tmp)
1318
-
1319
- for i in range(n):
1320
- wp.adjoint[b][b_start + i] += tmp[b_start + i]
1321
-
1322
- # A* = -adj_b*x^T
1323
- for i in range(n):
1324
- for j in range(n):
1325
- wp.adjoint[L][L_start + dense_index(n, i, j)] += -tmp[b_start + i] * x[b_start + j]
1326
-
1327
-
1328
- @wp.kernel
1329
- def eval_dense_solve_batched(
1330
- L_start: wp.array(dtype=int),
1331
- L_dim: wp.array(dtype=int),
1332
- b_start: wp.array(dtype=int),
1333
- L: wp.array(dtype=float),
1334
- b: wp.array(dtype=float),
1335
- x: wp.array(dtype=float),
1336
- tmp: wp.array(dtype=float),
1337
- ):
1338
- batch = wp.tid()
1339
-
1340
- dense_solve(L_dim[batch], L_start[batch], b_start[batch], L, b, x, tmp)
1341
-
1342
-
1343
- @wp.kernel
1344
- def integrate_generalized_joints(
1345
- joint_type: wp.array(dtype=int),
1346
- joint_q_start: wp.array(dtype=int),
1347
- joint_qd_start: wp.array(dtype=int),
1348
- joint_axis_dim: wp.array(dtype=int, ndim=2),
1349
- joint_q: wp.array(dtype=float),
1350
- joint_qd: wp.array(dtype=float),
1351
- joint_qdd: wp.array(dtype=float),
1352
- dt: float,
1353
- # outputs
1354
- joint_q_new: wp.array(dtype=float),
1355
- joint_qd_new: wp.array(dtype=float),
1356
- ):
1357
- # one thread per-articulation
1358
- index = wp.tid()
1359
-
1360
- type = joint_type[index]
1361
- coord_start = joint_q_start[index]
1362
- dof_start = joint_qd_start[index]
1363
- lin_axis_count = joint_axis_dim[index, 0]
1364
- ang_axis_count = joint_axis_dim[index, 1]
1365
-
1366
- jcalc_integrate(
1367
- type,
1368
- joint_q,
1369
- joint_qd,
1370
- joint_qdd,
1371
- coord_start,
1372
- dof_start,
1373
- lin_axis_count,
1374
- ang_axis_count,
1375
- dt,
1376
- joint_q_new,
1377
- joint_qd_new,
1378
- )
1379
-
1380
-
1381
- @wp.kernel
1382
- def eval_body_inertial_velocities(
1383
- body_q: wp.array(dtype=wp.transform),
1384
- body_v_s: wp.array(dtype=wp.spatial_vector),
1385
- # outputs
1386
- body_qd: wp.array(dtype=wp.spatial_vector),
1387
- ):
1388
- tid = wp.tid()
1389
-
1390
- X_sc = body_q[tid]
1391
- v_s = body_v_s[tid]
1392
- w = wp.spatial_top(v_s)
1393
- v = wp.spatial_bottom(v_s)
1394
-
1395
- v_inertial = v + wp.cross(w, wp.transform_get_translation(X_sc))
1396
-
1397
- body_qd[tid] = wp.spatial_vector(w, v_inertial)
1398
-
1399
-
1400
- class FeatherstoneIntegrator(Integrator):
1401
- """A semi-implicit integrator using symplectic Euler that operates
1402
- on reduced (also called generalized) coordinates to simulate articulated rigid body dynamics
1403
- based on Featherstone's composite rigid body algorithm (CRBA).
1404
-
1405
- See: Featherstone, Roy. Rigid Body Dynamics Algorithms. Springer US, 2014.
1406
-
1407
- Instead of maximal coordinates :attr:`State.body_q` (rigid body positions) and :attr:`State.body_qd`
1408
- (rigid body velocities) as is the case :class:`SemiImplicitIntegrator`, :class:`FeatherstoneIntegrator`
1409
- uses :attr:`State.joint_q` and :attr:`State.joint_qd` to represent the positions and velocities of
1410
- joints without allowing any redundant degrees of freedom.
1411
-
1412
- After constructing :class:`Model` and :class:`State` objects this time-integrator
1413
- may be used to advance the simulation state forward in time.
1414
-
1415
- Note:
1416
- Unlike :class:`SemiImplicitIntegrator` and :class:`XPBDIntegrator`, :class:`FeatherstoneIntegrator` does not simulate rigid bodies with nonzero mass as floating bodies if they are not connected through any joints. Floating-base systems require an explicit free joint with which the body is connected to the world, see :meth:`ModelBuilder.add_joint_free`.
1417
-
1418
- Semi-implicit time integration is a variational integrator that
1419
- preserves energy, however it not unconditionally stable, and requires a time-step
1420
- small enough to support the required stiffness and damping forces.
1421
-
1422
- See: https://en.wikipedia.org/wiki/Semi-implicit_Euler_method
1423
-
1424
- Example
1425
- -------
1426
-
1427
- .. code-block:: python
1428
-
1429
- integrator = wp.FeatherstoneIntegrator(model)
1430
-
1431
- # simulation loop
1432
- for i in range(100):
1433
- state = integrator.simulate(model, state_in, state_out, dt)
1434
-
1435
- Note:
1436
- The :class:`FeatherstoneIntegrator` requires the :class:`Model` to be passed in as a constructor argument.
1437
-
1438
- """
1439
-
1440
- def __init__(self, model, angular_damping=0.05, update_mass_matrix_every=1):
1441
- """
1442
- Args:
1443
- model (Model): the model to be simulated.
1444
- angular_damping (float, optional): Angular damping factor. Defaults to 0.05.
1445
- update_mass_matrix_every (int, optional): How often to update the mass matrix (every n-th time the :meth:`simulate` function gets called). Defaults to 1.
1446
- """
1447
- self.angular_damping = angular_damping
1448
- self.update_mass_matrix_every = update_mass_matrix_every
1449
- self.compute_articulation_indices(model)
1450
- self.allocate_model_aux_vars(model)
1451
- self._step = 0
1452
-
1453
- def compute_articulation_indices(self, model):
1454
- # calculate total size and offsets of Jacobian and mass matrices for entire system
1455
- if model.joint_count:
1456
- self.J_size = 0
1457
- self.M_size = 0
1458
- self.H_size = 0
1459
-
1460
- articulation_J_start = []
1461
- articulation_M_start = []
1462
- articulation_H_start = []
1463
-
1464
- articulation_M_rows = []
1465
- articulation_H_rows = []
1466
- articulation_J_rows = []
1467
- articulation_J_cols = []
1468
-
1469
- articulation_dof_start = []
1470
- articulation_coord_start = []
1471
-
1472
- articulation_start = model.articulation_start.numpy()
1473
- joint_q_start = model.joint_q_start.numpy()
1474
- joint_qd_start = model.joint_qd_start.numpy()
1475
-
1476
- for i in range(model.articulation_count):
1477
- first_joint = articulation_start[i]
1478
- last_joint = articulation_start[i + 1]
1479
-
1480
- first_coord = joint_q_start[first_joint]
1481
-
1482
- first_dof = joint_qd_start[first_joint]
1483
- last_dof = joint_qd_start[last_joint]
1484
-
1485
- joint_count = last_joint - first_joint
1486
- dof_count = last_dof - first_dof
1487
-
1488
- articulation_J_start.append(self.J_size)
1489
- articulation_M_start.append(self.M_size)
1490
- articulation_H_start.append(self.H_size)
1491
- articulation_dof_start.append(first_dof)
1492
- articulation_coord_start.append(first_coord)
1493
-
1494
- # bit of data duplication here, but will leave it as such for clarity
1495
- articulation_M_rows.append(joint_count * 6)
1496
- articulation_H_rows.append(dof_count)
1497
- articulation_J_rows.append(joint_count * 6)
1498
- articulation_J_cols.append(dof_count)
1499
-
1500
- self.J_size += 6 * joint_count * dof_count
1501
- self.M_size += 6 * joint_count * 6 * joint_count
1502
- self.H_size += dof_count * dof_count
1503
-
1504
- # matrix offsets for batched gemm
1505
- self.articulation_J_start = wp.array(articulation_J_start, dtype=wp.int32, device=model.device)
1506
- self.articulation_M_start = wp.array(articulation_M_start, dtype=wp.int32, device=model.device)
1507
- self.articulation_H_start = wp.array(articulation_H_start, dtype=wp.int32, device=model.device)
1508
-
1509
- self.articulation_M_rows = wp.array(articulation_M_rows, dtype=wp.int32, device=model.device)
1510
- self.articulation_H_rows = wp.array(articulation_H_rows, dtype=wp.int32, device=model.device)
1511
- self.articulation_J_rows = wp.array(articulation_J_rows, dtype=wp.int32, device=model.device)
1512
- self.articulation_J_cols = wp.array(articulation_J_cols, dtype=wp.int32, device=model.device)
1513
-
1514
- self.articulation_dof_start = wp.array(articulation_dof_start, dtype=wp.int32, device=model.device)
1515
- self.articulation_coord_start = wp.array(articulation_coord_start, dtype=wp.int32, device=model.device)
1516
-
1517
- def allocate_model_aux_vars(self, model):
1518
- # allocate mass, Jacobian matrices, and other auxiliary variables pertaining to the model
1519
- if model.joint_count:
1520
- # system matrices
1521
- self.M = wp.zeros((self.M_size,), dtype=wp.float32, device=model.device, requires_grad=model.requires_grad)
1522
- self.J = wp.zeros((self.J_size,), dtype=wp.float32, device=model.device, requires_grad=model.requires_grad)
1523
- self.P = wp.empty_like(self.J, requires_grad=model.requires_grad)
1524
- self.H = wp.empty((self.H_size,), dtype=wp.float32, device=model.device, requires_grad=model.requires_grad)
1525
-
1526
- # zero since only upper triangle is set which can trigger NaN detection
1527
- self.L = wp.zeros_like(self.H)
1528
-
1529
- if model.body_count:
1530
- # TODO use requires_grad here?
1531
- self.body_I_m = wp.empty((model.body_count,), dtype=wp.spatial_matrix, device=model.device)
1532
- wp.launch(
1533
- compute_spatial_inertia,
1534
- model.body_count,
1535
- inputs=[model.body_inertia, model.body_mass],
1536
- outputs=[self.body_I_m],
1537
- device=model.device,
1538
- )
1539
- self.body_X_com = wp.empty((model.body_count,), dtype=wp.transform, device=model.device)
1540
- wp.launch(
1541
- compute_com_transforms,
1542
- model.body_count,
1543
- inputs=[model.body_com],
1544
- outputs=[self.body_X_com],
1545
- device=model.device,
1546
- )
1547
-
1548
- def allocate_state_aux_vars(self, model, target, requires_grad):
1549
- # allocate auxiliary variables that vary with state
1550
- if model.body_count:
1551
- # joints
1552
- target.joint_qdd = wp.zeros_like(model.joint_qd, requires_grad=requires_grad)
1553
- target.joint_tau = wp.empty_like(model.joint_qd, requires_grad=requires_grad)
1554
- if requires_grad:
1555
- # used in the custom grad implementation of eval_dense_solve_batched
1556
- target.joint_solve_tmp = wp.zeros_like(model.joint_qd, requires_grad=True)
1557
- else:
1558
- target.joint_solve_tmp = None
1559
- target.joint_S_s = wp.empty(
1560
- (model.joint_dof_count,),
1561
- dtype=wp.spatial_vector,
1562
- device=model.device,
1563
- requires_grad=requires_grad,
1564
- )
1565
-
1566
- # derived rigid body data (maximal coordinates)
1567
- target.body_q_com = wp.empty_like(model.body_q, requires_grad=requires_grad)
1568
- target.body_I_s = wp.empty(
1569
- (model.body_count,), dtype=wp.spatial_matrix, device=model.device, requires_grad=requires_grad
1570
- )
1571
- target.body_v_s = wp.empty(
1572
- (model.body_count,), dtype=wp.spatial_vector, device=model.device, requires_grad=requires_grad
1573
- )
1574
- target.body_a_s = wp.empty(
1575
- (model.body_count,), dtype=wp.spatial_vector, device=model.device, requires_grad=requires_grad
1576
- )
1577
- target.body_f_s = wp.zeros(
1578
- (model.body_count,), dtype=wp.spatial_vector, device=model.device, requires_grad=requires_grad
1579
- )
1580
- target.body_ft_s = wp.zeros(
1581
- (model.body_count,), dtype=wp.spatial_vector, device=model.device, requires_grad=requires_grad
1582
- )
1583
-
1584
- target._featherstone_augmented = True
1585
-
1586
- def simulate(self, model: Model, state_in: State, state_out: State, dt: float, control: Control = None):
1587
- requires_grad = state_in.requires_grad
1588
-
1589
- # optionally create dynamical auxiliary variables
1590
- if requires_grad:
1591
- state_aug = state_out
1592
- else:
1593
- state_aug = self
1594
-
1595
- if not getattr(state_aug, "_featherstone_augmented", False):
1596
- self.allocate_state_aux_vars(model, state_aug, requires_grad)
1597
- if control is None:
1598
- control = model.control(clone_variables=False)
1599
-
1600
- with wp.ScopedTimer("simulate", False):
1601
- particle_f = None
1602
- body_f = None
1603
-
1604
- if state_in.particle_count:
1605
- particle_f = state_in.particle_f
1606
-
1607
- if state_in.body_count:
1608
- body_f = state_in.body_f
1609
-
1610
- # damped springs
1611
- eval_spring_forces(model, state_in, particle_f)
1612
-
1613
- # triangle elastic and lift/drag forces
1614
- eval_triangle_forces(model, state_in, control, particle_f)
1615
-
1616
- # triangle/triangle contacts
1617
- eval_triangle_contact_forces(model, state_in, particle_f)
1618
-
1619
- # triangle bending
1620
- eval_bending_forces(model, state_in, particle_f)
1621
-
1622
- # tetrahedral FEM
1623
- eval_tetrahedral_forces(model, state_in, control, particle_f)
1624
-
1625
- # particle-particle interactions
1626
- eval_particle_forces(model, state_in, particle_f)
1627
-
1628
- # particle ground contacts
1629
- eval_particle_ground_contact_forces(model, state_in, particle_f)
1630
-
1631
- # particle shape contact
1632
- eval_particle_body_contact_forces(model, state_in, particle_f, body_f)
1633
-
1634
- # muscles
1635
- if False:
1636
- eval_muscle_forces(model, state_in, control, body_f)
1637
-
1638
- # ----------------------------
1639
- # articulations
1640
-
1641
- if model.joint_count:
1642
-
1643
- # evaluate body transforms
1644
- wp.launch(
1645
- eval_rigid_fk,
1646
- dim=model.articulation_count,
1647
- inputs=[
1648
- model.articulation_start,
1649
- model.joint_type,
1650
- model.joint_parent,
1651
- model.joint_child,
1652
- model.joint_q_start,
1653
- state_in.joint_q,
1654
- model.joint_X_p,
1655
- model.joint_X_c,
1656
- self.body_X_com,
1657
- model.joint_axis,
1658
- model.joint_axis_start,
1659
- model.joint_axis_dim,
1660
- ],
1661
- outputs=[state_in.body_q, state_aug.body_q_com],
1662
- device=model.device,
1663
- )
1664
-
1665
- # print("body_X_sc:")
1666
- # print(state_in.body_q.numpy())
1667
-
1668
- # evaluate joint inertias, motion vectors, and forces
1669
- state_aug.body_f_s.zero_()
1670
- wp.launch(
1671
- eval_rigid_id,
1672
- dim=model.articulation_count,
1673
- inputs=[
1674
- model.articulation_start,
1675
- model.joint_type,
1676
- model.joint_parent,
1677
- model.joint_child,
1678
- model.joint_q_start,
1679
- model.joint_qd_start,
1680
- state_in.joint_q,
1681
- state_in.joint_qd,
1682
- model.joint_axis,
1683
- model.joint_axis_start,
1684
- model.joint_axis_dim,
1685
- self.body_I_m,
1686
- state_in.body_q,
1687
- state_aug.body_q_com,
1688
- model.joint_X_p,
1689
- model.joint_X_c,
1690
- model.gravity,
1691
- ],
1692
- outputs=[
1693
- state_aug.joint_S_s,
1694
- state_aug.body_I_s,
1695
- state_aug.body_v_s,
1696
- state_aug.body_f_s,
1697
- state_aug.body_a_s,
1698
- ],
1699
- device=model.device,
1700
- )
1701
-
1702
- if model.rigid_contact_max and (
1703
- model.ground and model.shape_ground_contact_pair_count or model.shape_contact_pair_count
1704
- ):
1705
- wp.launch(
1706
- kernel=eval_rigid_contacts,
1707
- dim=model.rigid_contact_max,
1708
- inputs=[
1709
- state_in.body_q,
1710
- state_aug.body_v_s,
1711
- model.body_com,
1712
- model.shape_materials,
1713
- model.shape_geo,
1714
- model.shape_body,
1715
- model.rigid_contact_count,
1716
- model.rigid_contact_point0,
1717
- model.rigid_contact_point1,
1718
- model.rigid_contact_normal,
1719
- model.rigid_contact_shape0,
1720
- model.rigid_contact_shape1,
1721
- True,
1722
- ],
1723
- outputs=[body_f],
1724
- device=model.device,
1725
- )
1726
-
1727
- # if model.rigid_contact_count.numpy()[0] > 0:
1728
- # print(body_f.numpy())
1729
-
1730
- if model.articulation_count:
1731
- # evaluate joint torques
1732
- state_aug.body_ft_s.zero_()
1733
- wp.launch(
1734
- eval_rigid_tau,
1735
- dim=model.articulation_count,
1736
- inputs=[
1737
- model.articulation_start,
1738
- model.joint_type,
1739
- model.joint_parent,
1740
- model.joint_child,
1741
- model.joint_q_start,
1742
- model.joint_qd_start,
1743
- model.joint_axis_start,
1744
- model.joint_axis_dim,
1745
- model.joint_axis_mode,
1746
- state_in.joint_q,
1747
- state_in.joint_qd,
1748
- control.joint_act,
1749
- model.joint_target_ke,
1750
- model.joint_target_kd,
1751
- model.joint_limit_lower,
1752
- model.joint_limit_upper,
1753
- model.joint_limit_ke,
1754
- model.joint_limit_kd,
1755
- state_aug.joint_S_s,
1756
- state_aug.body_f_s,
1757
- body_f,
1758
- ],
1759
- outputs=[
1760
- state_aug.body_ft_s,
1761
- state_aug.joint_tau,
1762
- ],
1763
- device=model.device,
1764
- )
1765
-
1766
- # print("joint_tau:")
1767
- # print(state_aug.joint_tau.numpy())
1768
- # print("body_q:")
1769
- # print(state_in.body_q.numpy())
1770
- # print("body_qd:")
1771
- # print(state_in.body_qd.numpy())
1772
-
1773
- if self._step % self.update_mass_matrix_every == 0:
1774
- # build J
1775
- wp.launch(
1776
- eval_rigid_jacobian,
1777
- dim=model.articulation_count,
1778
- inputs=[
1779
- model.articulation_start,
1780
- self.articulation_J_start,
1781
- model.joint_parent,
1782
- model.joint_qd_start,
1783
- state_aug.joint_S_s,
1784
- ],
1785
- outputs=[self.J],
1786
- device=model.device,
1787
- )
1788
-
1789
- # build M
1790
- wp.launch(
1791
- eval_rigid_mass,
1792
- dim=model.articulation_count,
1793
- inputs=[
1794
- model.articulation_start,
1795
- self.articulation_M_start,
1796
- state_aug.body_I_s,
1797
- ],
1798
- outputs=[self.M],
1799
- device=model.device,
1800
- )
1801
-
1802
- # form P = M*J
1803
- wp.launch(
1804
- eval_dense_gemm_batched,
1805
- dim=model.articulation_count,
1806
- inputs=[
1807
- self.articulation_M_rows,
1808
- self.articulation_J_cols,
1809
- self.articulation_J_rows,
1810
- False,
1811
- False,
1812
- self.articulation_M_start,
1813
- self.articulation_J_start,
1814
- # P start is the same as J start since it has the same dims as J
1815
- self.articulation_J_start,
1816
- self.M,
1817
- self.J,
1818
- ],
1819
- outputs=[self.P],
1820
- device=model.device,
1821
- )
1822
-
1823
- # form H = J^T*P
1824
- wp.launch(
1825
- eval_dense_gemm_batched,
1826
- dim=model.articulation_count,
1827
- inputs=[
1828
- self.articulation_J_cols,
1829
- self.articulation_J_cols,
1830
- # P rows is the same as J rows
1831
- self.articulation_J_rows,
1832
- True,
1833
- False,
1834
- self.articulation_J_start,
1835
- # P start is the same as J start since it has the same dims as J
1836
- self.articulation_J_start,
1837
- self.articulation_H_start,
1838
- self.J,
1839
- self.P,
1840
- ],
1841
- outputs=[self.H],
1842
- device=model.device,
1843
- )
1844
-
1845
- # compute decomposition
1846
- wp.launch(
1847
- eval_dense_cholesky_batched,
1848
- dim=model.articulation_count,
1849
- inputs=[
1850
- self.articulation_H_start,
1851
- self.articulation_H_rows,
1852
- self.H,
1853
- model.joint_armature,
1854
- ],
1855
- outputs=[self.L],
1856
- device=model.device,
1857
- )
1858
-
1859
- # print("joint_act:")
1860
- # print(control.joint_act.numpy())
1861
- # print("joint_tau:")
1862
- # print(state_aug.joint_tau.numpy())
1863
- # print("H:")
1864
- # print(self.H.numpy())
1865
- # print("L:")
1866
- # print(self.L.numpy())
1867
-
1868
- # solve for qdd
1869
- state_aug.joint_qdd.zero_()
1870
- wp.launch(
1871
- eval_dense_solve_batched,
1872
- dim=model.articulation_count,
1873
- inputs=[
1874
- self.articulation_H_start,
1875
- self.articulation_H_rows,
1876
- self.articulation_dof_start,
1877
- self.L,
1878
- state_aug.joint_tau,
1879
- ],
1880
- outputs=[
1881
- state_aug.joint_qdd,
1882
- state_aug.joint_solve_tmp,
1883
- ],
1884
- device=model.device,
1885
- )
1886
- # if wp.context.runtime.tape:
1887
- # wp.context.runtime.tape.record_func(
1888
- # backward=lambda: adj_matmul(
1889
- # a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith, device
1890
- # ),
1891
- # arrays=[a, b, c, d],
1892
- # )
1893
- # print("joint_qdd:")
1894
- # print(state_aug.joint_qdd.numpy())
1895
- # print("\n\n")
1896
-
1897
- # -------------------------------------
1898
- # integrate bodies
1899
-
1900
- if model.joint_count:
1901
- wp.launch(
1902
- kernel=integrate_generalized_joints,
1903
- dim=model.joint_count,
1904
- inputs=[
1905
- model.joint_type,
1906
- model.joint_q_start,
1907
- model.joint_qd_start,
1908
- model.joint_axis_dim,
1909
- state_in.joint_q,
1910
- state_in.joint_qd,
1911
- state_aug.joint_qdd,
1912
- dt,
1913
- ],
1914
- outputs=[state_out.joint_q, state_out.joint_qd],
1915
- device=model.device,
1916
- )
1917
-
1918
- wp.launch(
1919
- eval_rigid_fk,
1920
- dim=model.articulation_count,
1921
- inputs=[
1922
- model.articulation_start,
1923
- model.joint_type,
1924
- model.joint_parent,
1925
- model.joint_child,
1926
- model.joint_q_start,
1927
- state_out.joint_q,
1928
- model.joint_X_p,
1929
- model.joint_X_c,
1930
- self.body_X_com,
1931
- model.joint_axis,
1932
- model.joint_axis_start,
1933
- model.joint_axis_dim,
1934
- ],
1935
- outputs=[state_out.body_q, state_aug.body_q_com],
1936
- device=model.device,
1937
- )
1938
-
1939
- # compute body_qd
1940
- state_aug.body_f_s.zero_()
1941
- wp.launch(
1942
- eval_rigid_id,
1943
- dim=model.articulation_count,
1944
- inputs=[
1945
- model.articulation_start,
1946
- model.joint_type,
1947
- model.joint_parent,
1948
- model.joint_child,
1949
- model.joint_q_start,
1950
- model.joint_qd_start,
1951
- state_out.joint_q,
1952
- state_out.joint_qd,
1953
- model.joint_axis,
1954
- model.joint_axis_start,
1955
- model.joint_axis_dim,
1956
- self.body_I_m,
1957
- state_out.body_q,
1958
- state_aug.body_q_com,
1959
- model.joint_X_p,
1960
- model.joint_X_c,
1961
- model.gravity,
1962
- ],
1963
- outputs=[
1964
- state_aug.joint_S_s,
1965
- state_aug.body_I_s,
1966
- state_aug.body_v_s,
1967
- state_aug.body_f_s,
1968
- state_aug.body_a_s,
1969
- ],
1970
- device=model.device,
1971
- )
1972
-
1973
- # body velocity in inertial frame
1974
- wp.launch(
1975
- kernel=eval_body_inertial_velocities,
1976
- dim=model.body_count,
1977
- inputs=[
1978
- state_out.body_q,
1979
- state_aug.body_v_s,
1980
- ],
1981
- outputs=[
1982
- state_out.body_qd,
1983
- ],
1984
- device=model.device,
1985
- )
1986
-
1987
- self.integrate_particles(model, state_in, state_out, dt)
1988
-
1989
- self._step += 1
1990
-
1991
- return state_out
1
+ # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import warp as wp
9
+
10
+ from .articulation import (
11
+ compute_2d_rotational_dofs,
12
+ compute_3d_rotational_dofs,
13
+ eval_fk,
14
+ )
15
+ from .integrator import Integrator
16
+ from .integrator_euler import (
17
+ eval_bending_forces,
18
+ eval_joint_force,
19
+ eval_muscle_forces,
20
+ eval_particle_body_contact_forces,
21
+ eval_particle_forces,
22
+ eval_particle_ground_contact_forces,
23
+ eval_rigid_contacts,
24
+ eval_spring_forces,
25
+ eval_tetrahedral_forces,
26
+ eval_triangle_contact_forces,
27
+ eval_triangle_forces,
28
+ )
29
+ from .model import Control, Model, State
30
+
31
+
32
+ # Frank & Park definition 3.20, pg 100
33
+ @wp.func
34
+ def transform_twist(t: wp.transform, x: wp.spatial_vector):
35
+ q = wp.transform_get_rotation(t)
36
+ p = wp.transform_get_translation(t)
37
+
38
+ w = wp.spatial_top(x)
39
+ v = wp.spatial_bottom(x)
40
+
41
+ w = wp.quat_rotate(q, w)
42
+ v = wp.quat_rotate(q, v) + wp.cross(p, w)
43
+
44
+ return wp.spatial_vector(w, v)
45
+
46
+
47
+ @wp.func
48
+ def transform_wrench(t: wp.transform, x: wp.spatial_vector):
49
+ q = wp.transform_get_rotation(t)
50
+ p = wp.transform_get_translation(t)
51
+
52
+ w = wp.spatial_top(x)
53
+ v = wp.spatial_bottom(x)
54
+
55
+ v = wp.quat_rotate(q, v)
56
+ w = wp.quat_rotate(q, w) + wp.cross(p, v)
57
+
58
+ return wp.spatial_vector(w, v)
59
+
60
+
61
+ @wp.func
62
+ def spatial_adjoint(R: wp.mat33, S: wp.mat33):
63
+ # T = [R 0]
64
+ # [S R]
65
+
66
+ # fmt: off
67
+ return wp.spatial_matrix(
68
+ R[0, 0], R[0, 1], R[0, 2], 0.0, 0.0, 0.0,
69
+ R[1, 0], R[1, 1], R[1, 2], 0.0, 0.0, 0.0,
70
+ R[2, 0], R[2, 1], R[2, 2], 0.0, 0.0, 0.0,
71
+ S[0, 0], S[0, 1], S[0, 2], R[0, 0], R[0, 1], R[0, 2],
72
+ S[1, 0], S[1, 1], S[1, 2], R[1, 0], R[1, 1], R[1, 2],
73
+ S[2, 0], S[2, 1], S[2, 2], R[2, 0], R[2, 1], R[2, 2],
74
+ )
75
+ # fmt: on
76
+
77
+
78
+ @wp.kernel
79
+ def compute_spatial_inertia(
80
+ body_inertia: wp.array(dtype=wp.mat33),
81
+ body_mass: wp.array(dtype=float),
82
+ # outputs
83
+ body_I_m: wp.array(dtype=wp.spatial_matrix),
84
+ ):
85
+ tid = wp.tid()
86
+ I = body_inertia[tid]
87
+ m = body_mass[tid]
88
+ # fmt: off
89
+ body_I_m[tid] = wp.spatial_matrix(
90
+ I[0, 0], I[0, 1], I[0, 2], 0.0, 0.0, 0.0,
91
+ I[1, 0], I[1, 1], I[1, 2], 0.0, 0.0, 0.0,
92
+ I[2, 0], I[2, 1], I[2, 2], 0.0, 0.0, 0.0,
93
+ 0.0, 0.0, 0.0, m, 0.0, 0.0,
94
+ 0.0, 0.0, 0.0, 0.0, m, 0.0,
95
+ 0.0, 0.0, 0.0, 0.0, 0.0, m,
96
+ )
97
+ # fmt: on
98
+
99
+
100
+ @wp.kernel
101
+ def compute_com_transforms(
102
+ body_com: wp.array(dtype=wp.vec3),
103
+ # outputs
104
+ body_X_com: wp.array(dtype=wp.transform),
105
+ ):
106
+ tid = wp.tid()
107
+ com = body_com[tid]
108
+ body_X_com[tid] = wp.transform(com, wp.quat_identity())
109
+
110
+
111
+ # computes adj_t^-T*I*adj_t^-1 (tensor change of coordinates), Frank & Park, section 8.2.3, pg 290
112
+ @wp.func
113
+ def spatial_transform_inertia(t: wp.transform, I: wp.spatial_matrix):
114
+ t_inv = wp.transform_inverse(t)
115
+
116
+ q = wp.transform_get_rotation(t_inv)
117
+ p = wp.transform_get_translation(t_inv)
118
+
119
+ r1 = wp.quat_rotate(q, wp.vec3(1.0, 0.0, 0.0))
120
+ r2 = wp.quat_rotate(q, wp.vec3(0.0, 1.0, 0.0))
121
+ r3 = wp.quat_rotate(q, wp.vec3(0.0, 0.0, 1.0))
122
+
123
+ R = wp.mat33(r1, r2, r3)
124
+ S = wp.skew(p) @ R
125
+
126
+ T = spatial_adjoint(R, S)
127
+
128
+ return wp.mul(wp.mul(wp.transpose(T), I), T)
129
+
130
+
131
+ # compute transform across a joint
132
+ @wp.func
133
+ def jcalc_transform(
134
+ type: int,
135
+ joint_axis: wp.array(dtype=wp.vec3),
136
+ axis_start: int,
137
+ lin_axis_count: int,
138
+ ang_axis_count: int,
139
+ joint_q: wp.array(dtype=float),
140
+ start: int,
141
+ ):
142
+ if type == wp.sim.JOINT_PRISMATIC:
143
+ q = joint_q[start]
144
+ axis = joint_axis[axis_start]
145
+ X_jc = wp.transform(axis * q, wp.quat_identity())
146
+ return X_jc
147
+
148
+ if type == wp.sim.JOINT_REVOLUTE:
149
+ q = joint_q[start]
150
+ axis = joint_axis[axis_start]
151
+ X_jc = wp.transform(wp.vec3(), wp.quat_from_axis_angle(axis, q))
152
+ return X_jc
153
+
154
+ if type == wp.sim.JOINT_BALL:
155
+ qx = joint_q[start + 0]
156
+ qy = joint_q[start + 1]
157
+ qz = joint_q[start + 2]
158
+ qw = joint_q[start + 3]
159
+
160
+ X_jc = wp.transform(wp.vec3(), wp.quat(qx, qy, qz, qw))
161
+ return X_jc
162
+
163
+ if type == wp.sim.JOINT_FIXED:
164
+ X_jc = wp.transform_identity()
165
+ return X_jc
166
+
167
+ if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
168
+ px = joint_q[start + 0]
169
+ py = joint_q[start + 1]
170
+ pz = joint_q[start + 2]
171
+
172
+ qx = joint_q[start + 3]
173
+ qy = joint_q[start + 4]
174
+ qz = joint_q[start + 5]
175
+ qw = joint_q[start + 6]
176
+
177
+ X_jc = wp.transform(wp.vec3(px, py, pz), wp.quat(qx, qy, qz, qw))
178
+ return X_jc
179
+
180
+ if type == wp.sim.JOINT_COMPOUND:
181
+ rot, _ = compute_3d_rotational_dofs(
182
+ joint_axis[axis_start],
183
+ joint_axis[axis_start + 1],
184
+ joint_axis[axis_start + 2],
185
+ joint_q[start + 0],
186
+ joint_q[start + 1],
187
+ joint_q[start + 2],
188
+ 0.0,
189
+ 0.0,
190
+ 0.0,
191
+ )
192
+
193
+ X_jc = wp.transform(wp.vec3(), rot)
194
+ return X_jc
195
+
196
+ if type == wp.sim.JOINT_UNIVERSAL:
197
+ rot, _ = compute_2d_rotational_dofs(
198
+ joint_axis[axis_start],
199
+ joint_axis[axis_start + 1],
200
+ joint_q[start + 0],
201
+ joint_q[start + 1],
202
+ 0.0,
203
+ 0.0,
204
+ )
205
+
206
+ X_jc = wp.transform(wp.vec3(), rot)
207
+ return X_jc
208
+
209
+ if type == wp.sim.JOINT_D6:
210
+ pos = wp.vec3(0.0)
211
+ rot = wp.quat_identity()
212
+
213
+ # unroll for loop to ensure joint actions remain differentiable
214
+ # (since differentiating through a for loop that updates a local variable is not supported)
215
+
216
+ if lin_axis_count > 0:
217
+ axis = joint_axis[axis_start + 0]
218
+ pos += axis * joint_q[start + 0]
219
+ if lin_axis_count > 1:
220
+ axis = joint_axis[axis_start + 1]
221
+ pos += axis * joint_q[start + 1]
222
+ if lin_axis_count > 2:
223
+ axis = joint_axis[axis_start + 2]
224
+ pos += axis * joint_q[start + 2]
225
+
226
+ ia = axis_start + lin_axis_count
227
+ iq = start + lin_axis_count
228
+ if ang_axis_count == 1:
229
+ axis = joint_axis[ia]
230
+ rot = wp.quat_from_axis_angle(axis, joint_q[iq])
231
+ if ang_axis_count == 2:
232
+ rot, _ = compute_2d_rotational_dofs(
233
+ joint_axis[ia + 0],
234
+ joint_axis[ia + 1],
235
+ joint_q[iq + 0],
236
+ joint_q[iq + 1],
237
+ 0.0,
238
+ 0.0,
239
+ )
240
+ if ang_axis_count == 3:
241
+ rot, _ = compute_3d_rotational_dofs(
242
+ joint_axis[ia + 0],
243
+ joint_axis[ia + 1],
244
+ joint_axis[ia + 2],
245
+ joint_q[iq + 0],
246
+ joint_q[iq + 1],
247
+ joint_q[iq + 2],
248
+ 0.0,
249
+ 0.0,
250
+ 0.0,
251
+ )
252
+
253
+ X_jc = wp.transform(pos, rot)
254
+ return X_jc
255
+
256
+ # default case
257
+ return wp.transform_identity()
258
+
259
+
260
+ # compute motion subspace and velocity for a joint
261
+ @wp.func
262
+ def jcalc_motion(
263
+ type: int,
264
+ joint_axis: wp.array(dtype=wp.vec3),
265
+ axis_start: int,
266
+ lin_axis_count: int,
267
+ ang_axis_count: int,
268
+ X_sc: wp.transform,
269
+ joint_q: wp.array(dtype=float),
270
+ joint_qd: wp.array(dtype=float),
271
+ q_start: int,
272
+ qd_start: int,
273
+ # outputs
274
+ joint_S_s: wp.array(dtype=wp.spatial_vector),
275
+ ):
276
+ if type == wp.sim.JOINT_PRISMATIC:
277
+ axis = joint_axis[axis_start]
278
+ S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis))
279
+ v_j_s = S_s * joint_qd[qd_start]
280
+ joint_S_s[qd_start] = S_s
281
+ return v_j_s
282
+
283
+ if type == wp.sim.JOINT_REVOLUTE:
284
+ axis = joint_axis[axis_start]
285
+ S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3()))
286
+ v_j_s = S_s * joint_qd[qd_start]
287
+ joint_S_s[qd_start] = S_s
288
+ return v_j_s
289
+
290
+ if type == wp.sim.JOINT_UNIVERSAL:
291
+ axis_0 = joint_axis[axis_start + 0]
292
+ axis_1 = joint_axis[axis_start + 1]
293
+ q_off = wp.quat_from_matrix(wp.mat33(axis_0, axis_1, wp.cross(axis_0, axis_1)))
294
+ local_0 = wp.quat_rotate(q_off, wp.vec3(1.0, 0.0, 0.0))
295
+ local_1 = wp.quat_rotate(q_off, wp.vec3(0.0, 1.0, 0.0))
296
+
297
+ axis_0 = local_0
298
+ q_0 = wp.quat_from_axis_angle(axis_0, joint_q[q_start + 0])
299
+
300
+ axis_1 = wp.quat_rotate(q_0, local_1)
301
+
302
+ S_0 = transform_twist(X_sc, wp.spatial_vector(axis_0, wp.vec3()))
303
+ S_1 = transform_twist(X_sc, wp.spatial_vector(axis_1, wp.vec3()))
304
+
305
+ joint_S_s[qd_start + 0] = S_0
306
+ joint_S_s[qd_start + 1] = S_1
307
+
308
+ return S_0 * joint_qd[qd_start + 0] + S_1 * joint_qd[qd_start + 1]
309
+
310
+ if type == wp.sim.JOINT_COMPOUND:
311
+ axis_0 = joint_axis[axis_start + 0]
312
+ axis_1 = joint_axis[axis_start + 1]
313
+ axis_2 = joint_axis[axis_start + 2]
314
+ q_off = wp.quat_from_matrix(wp.mat33(axis_0, axis_1, axis_2))
315
+ local_0 = wp.quat_rotate(q_off, wp.vec3(1.0, 0.0, 0.0))
316
+ local_1 = wp.quat_rotate(q_off, wp.vec3(0.0, 1.0, 0.0))
317
+ local_2 = wp.quat_rotate(q_off, wp.vec3(0.0, 0.0, 1.0))
318
+
319
+ axis_0 = local_0
320
+ q_0 = wp.quat_from_axis_angle(axis_0, joint_q[q_start + 0])
321
+
322
+ axis_1 = wp.quat_rotate(q_0, local_1)
323
+ q_1 = wp.quat_from_axis_angle(axis_1, joint_q[q_start + 1])
324
+
325
+ axis_2 = wp.quat_rotate(q_1 * q_0, local_2)
326
+
327
+ S_0 = transform_twist(X_sc, wp.spatial_vector(axis_0, wp.vec3()))
328
+ S_1 = transform_twist(X_sc, wp.spatial_vector(axis_1, wp.vec3()))
329
+ S_2 = transform_twist(X_sc, wp.spatial_vector(axis_2, wp.vec3()))
330
+
331
+ joint_S_s[qd_start + 0] = S_0
332
+ joint_S_s[qd_start + 1] = S_1
333
+ joint_S_s[qd_start + 2] = S_2
334
+
335
+ return S_0 * joint_qd[qd_start + 0] + S_1 * joint_qd[qd_start + 1] + S_2 * joint_qd[qd_start + 2]
336
+
337
+ if type == wp.sim.JOINT_D6:
338
+ v_j_s = wp.spatial_vector()
339
+ if lin_axis_count > 0:
340
+ axis = joint_axis[axis_start + 0]
341
+ S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis))
342
+ v_j_s += S_s * joint_qd[qd_start + 0]
343
+ joint_S_s[qd_start + 0] = S_s
344
+ if lin_axis_count > 1:
345
+ axis = joint_axis[axis_start + 1]
346
+ S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis))
347
+ v_j_s += S_s * joint_qd[qd_start + 1]
348
+ joint_S_s[qd_start + 1] = S_s
349
+ if lin_axis_count > 2:
350
+ axis = joint_axis[axis_start + 2]
351
+ S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis))
352
+ v_j_s += S_s * joint_qd[qd_start + 2]
353
+ joint_S_s[qd_start + 2] = S_s
354
+ if ang_axis_count > 0:
355
+ axis = joint_axis[axis_start + lin_axis_count + 0]
356
+ S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3()))
357
+ v_j_s += S_s * joint_qd[qd_start + lin_axis_count + 0]
358
+ joint_S_s[qd_start + lin_axis_count + 0] = S_s
359
+ if ang_axis_count > 1:
360
+ axis = joint_axis[axis_start + lin_axis_count + 1]
361
+ S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3()))
362
+ v_j_s += S_s * joint_qd[qd_start + lin_axis_count + 1]
363
+ joint_S_s[qd_start + lin_axis_count + 1] = S_s
364
+ if ang_axis_count > 2:
365
+ axis = joint_axis[axis_start + lin_axis_count + 2]
366
+ S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3()))
367
+ v_j_s += S_s * joint_qd[qd_start + lin_axis_count + 2]
368
+ joint_S_s[qd_start + lin_axis_count + 2] = S_s
369
+
370
+ return v_j_s
371
+
372
+ if type == wp.sim.JOINT_BALL:
373
+ S_0 = transform_twist(X_sc, wp.spatial_vector(1.0, 0.0, 0.0, 0.0, 0.0, 0.0))
374
+ S_1 = transform_twist(X_sc, wp.spatial_vector(0.0, 1.0, 0.0, 0.0, 0.0, 0.0))
375
+ S_2 = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 1.0, 0.0, 0.0, 0.0))
376
+
377
+ joint_S_s[qd_start + 0] = S_0
378
+ joint_S_s[qd_start + 1] = S_1
379
+ joint_S_s[qd_start + 2] = S_2
380
+
381
+ return S_0 * joint_qd[qd_start + 0] + S_1 * joint_qd[qd_start + 1] + S_2 * joint_qd[qd_start + 2]
382
+
383
+ if type == wp.sim.JOINT_FIXED:
384
+ return wp.spatial_vector()
385
+
386
+ if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
387
+ v_j_s = transform_twist(
388
+ X_sc,
389
+ wp.spatial_vector(
390
+ joint_qd[qd_start + 0],
391
+ joint_qd[qd_start + 1],
392
+ joint_qd[qd_start + 2],
393
+ joint_qd[qd_start + 3],
394
+ joint_qd[qd_start + 4],
395
+ joint_qd[qd_start + 5],
396
+ ),
397
+ )
398
+
399
+ joint_S_s[qd_start + 0] = transform_twist(X_sc, wp.spatial_vector(1.0, 0.0, 0.0, 0.0, 0.0, 0.0))
400
+ joint_S_s[qd_start + 1] = transform_twist(X_sc, wp.spatial_vector(0.0, 1.0, 0.0, 0.0, 0.0, 0.0))
401
+ joint_S_s[qd_start + 2] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 1.0, 0.0, 0.0, 0.0))
402
+ joint_S_s[qd_start + 3] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 1.0, 0.0, 0.0))
403
+ joint_S_s[qd_start + 4] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 1.0, 0.0))
404
+ joint_S_s[qd_start + 5] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 0.0, 1.0))
405
+
406
+ return v_j_s
407
+
408
+ wp.printf("jcalc_motion not implemented for joint type %d\n", type)
409
+
410
+ # default case
411
+ return wp.spatial_vector()
412
+
413
+
414
+ # computes joint space forces/torques in tau
415
+ @wp.func
416
+ def jcalc_tau(
417
+ type: int,
418
+ joint_target_ke: wp.array(dtype=float),
419
+ joint_target_kd: wp.array(dtype=float),
420
+ joint_limit_ke: wp.array(dtype=float),
421
+ joint_limit_kd: wp.array(dtype=float),
422
+ joint_S_s: wp.array(dtype=wp.spatial_vector),
423
+ joint_q: wp.array(dtype=float),
424
+ joint_qd: wp.array(dtype=float),
425
+ joint_act: wp.array(dtype=float),
426
+ joint_axis_mode: wp.array(dtype=int),
427
+ joint_limit_lower: wp.array(dtype=float),
428
+ joint_limit_upper: wp.array(dtype=float),
429
+ coord_start: int,
430
+ dof_start: int,
431
+ axis_start: int,
432
+ lin_axis_count: int,
433
+ ang_axis_count: int,
434
+ body_f_s: wp.spatial_vector,
435
+ # outputs
436
+ tau: wp.array(dtype=float),
437
+ ):
438
+ if type == wp.sim.JOINT_PRISMATIC or type == wp.sim.JOINT_REVOLUTE:
439
+ S_s = joint_S_s[dof_start]
440
+
441
+ q = joint_q[coord_start]
442
+ qd = joint_qd[dof_start]
443
+ act = joint_act[axis_start]
444
+
445
+ lower = joint_limit_lower[axis_start]
446
+ upper = joint_limit_upper[axis_start]
447
+
448
+ limit_ke = joint_limit_ke[axis_start]
449
+ limit_kd = joint_limit_kd[axis_start]
450
+ target_ke = joint_target_ke[axis_start]
451
+ target_kd = joint_target_kd[axis_start]
452
+ mode = joint_axis_mode[axis_start]
453
+
454
+ # total torque / force on the joint
455
+ t = -wp.dot(S_s, body_f_s) + eval_joint_force(
456
+ q, qd, act, target_ke, target_kd, lower, upper, limit_ke, limit_kd, mode
457
+ )
458
+
459
+ tau[dof_start] = t
460
+
461
+ return
462
+
463
+ if type == wp.sim.JOINT_BALL:
464
+ # target_ke = joint_target_ke[axis_start]
465
+ # target_kd = joint_target_kd[axis_start]
466
+
467
+ for i in range(3):
468
+ S_s = joint_S_s[dof_start + i]
469
+
470
+ # w = joint_qd[dof_start + i]
471
+ # r = joint_q[coord_start + i]
472
+
473
+ tau[dof_start + i] = -wp.dot(S_s, body_f_s) # - w * target_kd - r * target_ke
474
+
475
+ return
476
+
477
+ if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
478
+ for i in range(6):
479
+ S_s = joint_S_s[dof_start + i]
480
+ tau[dof_start + i] = -wp.dot(S_s, body_f_s)
481
+
482
+ return
483
+
484
+ if type == wp.sim.JOINT_COMPOUND or type == wp.sim.JOINT_UNIVERSAL or type == wp.sim.JOINT_D6:
485
+ axis_count = lin_axis_count + ang_axis_count
486
+
487
+ for i in range(axis_count):
488
+ S_s = joint_S_s[dof_start + i]
489
+
490
+ q = joint_q[coord_start + i]
491
+ qd = joint_qd[dof_start + i]
492
+ act = joint_act[axis_start + i]
493
+
494
+ lower = joint_limit_lower[axis_start + i]
495
+ upper = joint_limit_upper[axis_start + i]
496
+ limit_ke = joint_limit_ke[axis_start + i]
497
+ limit_kd = joint_limit_kd[axis_start + i]
498
+ target_ke = joint_target_ke[axis_start + i]
499
+ target_kd = joint_target_kd[axis_start + i]
500
+ mode = joint_axis_mode[axis_start + i]
501
+
502
+ f = eval_joint_force(q, qd, act, target_ke, target_kd, lower, upper, limit_ke, limit_kd, mode)
503
+
504
+ # total torque / force on the joint
505
+ t = -wp.dot(S_s, body_f_s) + f
506
+
507
+ tau[dof_start + i] = t
508
+
509
+ return
510
+
511
+
512
# Advance the generalized coordinates and velocities of a single joint by one
# time step with semi-implicit (symplectic) Euler: the velocity is updated from
# the acceleration first, and the *updated* velocity then drives the coordinate
# update. Results are written to joint_q_new / joint_qd_new.
#
# Args:
#   type: joint type constant (wp.sim.JOINT_*).
#   joint_q, joint_qd, joint_qdd: current coordinates, velocities, accelerations.
#   coord_start: index of this joint's first entry in joint_q / joint_q_new.
#   dof_start: index of this joint's first entry in joint_qd / joint_qdd / joint_qd_new.
#   lin_axis_count, ang_axis_count: number of linear / angular axes; only used
#       for compound, universal and D6 joints.
#   dt: time step.
#   joint_q_new, joint_qd_new: output coordinate / velocity arrays.
#
# Fixed joints return without writing anything, and joint types not matched by
# any branch below also leave the outputs untouched.
@wp.func
def jcalc_integrate(
    type: int,
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_qdd: wp.array(dtype=float),
    coord_start: int,
    dof_start: int,
    lin_axis_count: int,
    ang_axis_count: int,
    dt: float,
    # outputs
    joint_q_new: wp.array(dtype=float),
    joint_qd_new: wp.array(dtype=float),
):
    # fixed joints: nothing to integrate
    if type == wp.sim.JOINT_FIXED:
        return

    # prismatic / revolute: a single coordinate and a single dof
    if type == wp.sim.JOINT_PRISMATIC or type == wp.sim.JOINT_REVOLUTE:
        qdd = joint_qdd[dof_start]
        qd = joint_qd[dof_start]
        q = joint_q[coord_start]

        # symplectic Euler: the new velocity is used for the position update
        qd_new = qd + qdd * dt
        q_new = q + qd_new * dt

        joint_qd_new[dof_start] = qd_new
        joint_q_new[coord_start] = q_new

        return

    # ball: 3 angular dofs, orientation stored as 4 quaternion coordinates
    if type == wp.sim.JOINT_BALL:
        m_j = wp.vec3(joint_qdd[dof_start + 0], joint_qdd[dof_start + 1], joint_qdd[dof_start + 2])
        w_j = wp.vec3(joint_qd[dof_start + 0], joint_qd[dof_start + 1], joint_qd[dof_start + 2])

        r_j = wp.quat(
            joint_q[coord_start + 0], joint_q[coord_start + 1], joint_q[coord_start + 2], joint_q[coord_start + 3]
        )

        # symplectic Euler
        w_j_new = w_j + m_j * dt

        # quaternion time derivative 0.5 * (w, 0) * r, using the updated angular velocity
        drdt_j = wp.quat(w_j_new, 0.0) * r_j * 0.5

        # new orientation (normalized)
        r_j_new = wp.normalize(r_j + drdt_j * dt)

        # update joint coords
        joint_q_new[coord_start + 0] = r_j_new[0]
        joint_q_new[coord_start + 1] = r_j_new[1]
        joint_q_new[coord_start + 2] = r_j_new[2]
        joint_q_new[coord_start + 3] = r_j_new[3]

        # update joint vel
        joint_qd_new[dof_start + 0] = w_j_new[0]
        joint_qd_new[dof_start + 1] = w_j_new[1]
        joint_qd_new[dof_start + 2] = w_j_new[2]

        return

    # free joint (distance joints are integrated the same way)
    if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
        # dofs: qd = (omega_x, omega_y, omega_z, vel_x, vel_y, vel_z)
        # coords: q = (trans_x, trans_y, trans_z, quat_x, quat_y, quat_z, quat_w)

        # angular and linear acceleration
        m_s = wp.vec3(joint_qdd[dof_start + 0], joint_qdd[dof_start + 1], joint_qdd[dof_start + 2])
        a_s = wp.vec3(joint_qdd[dof_start + 3], joint_qdd[dof_start + 4], joint_qdd[dof_start + 5])

        # angular and linear velocity
        w_s = wp.vec3(joint_qd[dof_start + 0], joint_qd[dof_start + 1], joint_qd[dof_start + 2])
        v_s = wp.vec3(joint_qd[dof_start + 3], joint_qd[dof_start + 4], joint_qd[dof_start + 5])

        # symplectic Euler
        w_s = w_s + m_s * dt
        v_s = v_s + a_s * dt

        # translation of origin
        p_s = wp.vec3(joint_q[coord_start + 0], joint_q[coord_start + 1], joint_q[coord_start + 2])

        # linear velocity of the point p_s implied by the (updated) spatial twist
        # (w_s, v_s): v_s + w_s x p_s. Note q stores translation before rotation,
        # while qd stores angular before linear components.
        dpdt_s = v_s + wp.cross(w_s, p_s)

        # quat and quat derivative 0.5 * (w, 0) * r
        r_s = wp.quat(
            joint_q[coord_start + 3], joint_q[coord_start + 4], joint_q[coord_start + 5], joint_q[coord_start + 6]
        )

        drdt_s = wp.quat(w_s, 0.0) * r_s * 0.5

        # new position and new orientation (orientation re-normalized)
        p_s_new = p_s + dpdt_s * dt
        r_s_new = wp.normalize(r_s + drdt_s * dt)

        # update transform
        joint_q_new[coord_start + 0] = p_s_new[0]
        joint_q_new[coord_start + 1] = p_s_new[1]
        joint_q_new[coord_start + 2] = p_s_new[2]

        joint_q_new[coord_start + 3] = r_s_new[0]
        joint_q_new[coord_start + 4] = r_s_new[1]
        joint_q_new[coord_start + 5] = r_s_new[2]
        joint_q_new[coord_start + 6] = r_s_new[3]

        # update joint_twist
        joint_qd_new[dof_start + 0] = w_s[0]
        joint_qd_new[dof_start + 1] = w_s[1]
        joint_qd_new[dof_start + 2] = w_s[2]
        joint_qd_new[dof_start + 3] = v_s[0]
        joint_qd_new[dof_start + 4] = v_s[1]
        joint_qd_new[dof_start + 5] = v_s[2]

        return

    # other joint types (compound, universal, D6): one coordinate per axis, each
    # integrated independently like a prismatic/revolute joint
    if type == wp.sim.JOINT_COMPOUND or type == wp.sim.JOINT_UNIVERSAL or type == wp.sim.JOINT_D6:
        axis_count = lin_axis_count + ang_axis_count

        for i in range(axis_count):
            qdd = joint_qdd[dof_start + i]
            qd = joint_qd[dof_start + i]
            q = joint_q[coord_start + i]

            qd_new = qd + qdd * dt
            q_new = q + qd_new * dt

            joint_qd_new[dof_start + i] = qd_new
            joint_q_new[coord_start + i] = q_new

        return
645
+
646
+
647
# Forward kinematics for a single joint ``i``.
#
# The child body's world transform is composed from:
#   - the parent body's world transform body_q[parent] (omitted for root joints,
#     i.e. parent < 0, which are anchored directly in the world),
#   - the joint's anchor frame on the parent, joint_X_p[i],
#   - the transform across the joint, from jcalc_transform and the joint coordinates,
#   - the inverse of the joint's anchor frame on the child, joint_X_c[i].
# The result is stored in body_q[child]. The world transform of the child's
# center of mass (body_X_com[child] applied on top) is stored in body_q_com[child].
# body_q[parent] is read, so it must already have been computed for this step.
@wp.func
def compute_link_transform(
    i: int,
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_q: wp.array(dtype=float),
    joint_X_p: wp.array(dtype=wp.transform),
    joint_X_c: wp.array(dtype=wp.transform),
    body_X_com: wp.array(dtype=wp.transform),
    joint_axis: wp.array(dtype=wp.vec3),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    # outputs
    body_q: wp.array(dtype=wp.transform),
    body_q_com: wp.array(dtype=wp.transform),
):
    parent_body = joint_parent[i]
    child_body = joint_child[i]

    # joint anchor on the parent, lifted into world space when a parent body exists
    X_parent_anchor = joint_X_p[i]
    if parent_body >= 0:
        X_parent_anchor = body_q[parent_body] * X_parent_anchor

    # relative transform produced by the joint coordinates
    X_joint = jcalc_transform(
        joint_type[i],
        joint_axis,
        joint_axis_start[i],
        joint_axis_dim[i, 0],
        joint_axis_dim[i, 1],
        joint_q,
        joint_q_start[i],
    )

    # world -> child anchor frame -> child body frame
    X_child = X_parent_anchor * X_joint * wp.transform_inverse(joint_X_c[i])

    # world transform of the child's center of mass
    X_com = X_child * body_X_com[child_body]

    body_q[child_body] = X_child
    body_q_com[child_body] = X_com
699
+
700
+
701
# Forward-kinematics kernel, launched with one thread per articulation.
# Each thread visits its articulation's joints,
# [articulation_start[tid], articulation_start[tid + 1]), in increasing index
# order and calls compute_link_transform for each one. A joint reads its
# parent's body_q entry, so this relies on parents preceding their children in
# joint order (assumed; not checked here).
@wp.kernel
def eval_rigid_fk(
    articulation_start: wp.array(dtype=int),
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_q: wp.array(dtype=float),
    joint_X_p: wp.array(dtype=wp.transform),
    joint_X_c: wp.array(dtype=wp.transform),
    body_X_com: wp.array(dtype=wp.transform),
    joint_axis: wp.array(dtype=wp.vec3),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    # outputs
    body_q: wp.array(dtype=wp.transform),
    body_q_com: wp.array(dtype=wp.transform),
):
    articulation = wp.tid()

    joint_begin = articulation_start[articulation]
    joint_end = articulation_start[articulation + 1]

    for joint in range(joint_begin, joint_end):
        compute_link_transform(
            joint,
            joint_type,
            joint_parent,
            joint_child,
            joint_q_start,
            joint_q,
            joint_X_p,
            joint_X_c,
            body_X_com,
            joint_axis,
            joint_axis_start,
            joint_axis_dim,
            body_q,
            body_q_com,
        )
742
+
743
+
744
# Spatial cross product for motion vectors stored as (angular, linear),
# corresponding to Featherstone's `a x b`:
#   angular part: w_a x w_b
#   linear part:  w_a x v_b + v_a x w_b
@wp.func
def spatial_cross(a: wp.spatial_vector, b: wp.spatial_vector):
    omega_a = wp.spatial_top(a)
    omega_b = wp.spatial_top(b)
    lin_a = wp.spatial_bottom(a)
    lin_b = wp.spatial_bottom(b)

    top = wp.cross(omega_a, omega_b)
    bottom = wp.cross(omega_a, lin_b) + wp.cross(lin_a, omega_b)
    return wp.spatial_vector(top, bottom)
756
+
757
+
758
# Dual spatial cross product (motion vector a acting on force vector b), with
# both stored as (angular, linear), corresponding to Featherstone's `a x* b`:
#   angular part: w_a x w_b + v_a x v_b
#   linear part:  w_a x v_b
@wp.func
def spatial_cross_dual(a: wp.spatial_vector, b: wp.spatial_vector):
    omega_a = wp.spatial_top(a)
    omega_b = wp.spatial_top(b)
    lin_a = wp.spatial_bottom(a)
    lin_b = wp.spatial_bottom(b)

    top = wp.cross(omega_a, omega_b) + wp.cross(lin_a, lin_b)
    bottom = wp.cross(omega_a, lin_b)
    return wp.spatial_vector(top, bottom)
770
+
771
+
772
# Flat offset of element (i, j) in a row-major matrix whose rows hold `stride` elements.
@wp.func
def dense_index(stride: int, i: int, j: int):
    row_offset = i * stride
    return row_offset + j
775
+
776
+
777
# Forward-pass step of the Recursive Newton-Euler algorithm for joint ``i``.
# For the joint's child body, in spatial ("_s") coordinates, it computes:
#   - the joint motion subspace (written into joint_S_s by jcalc_motion) and
#     the velocity across the joint, v_j_s,
#   - body velocity      v_s = v_parent + v_j_s,
#   - body acceleration  a_s = a_parent + v_s x v_j_s (velocity-product term
#     only; the S * qdd term is not included),
#   - spatial inertia    I_s = body_I_m[child] transformed by body_q_com[child],
#   - body force         I_s a_s + v_s x* (I_s v_s) minus the gravity wrench.
# The results are stored in body_v_s, body_a_s, body_I_s and body_f_s at index
# ``child``.
#
# Root joints (parent < 0) use zero parent velocity and acceleration and the
# world as the parent frame. Otherwise body_q / body_v_s / body_a_s of the
# parent are read, so they must already be up to date.
#
# Note: joint_X_c is accepted but currently unused.
@wp.func
def compute_link_velocity(
    i: int,
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_axis: wp.array(dtype=wp.vec3),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    body_I_m: wp.array(dtype=wp.spatial_matrix),
    body_q: wp.array(dtype=wp.transform),
    body_q_com: wp.array(dtype=wp.transform),
    joint_X_p: wp.array(dtype=wp.transform),
    joint_X_c: wp.array(dtype=wp.transform),
    gravity: wp.vec3,
    # outputs
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    body_I_s: wp.array(dtype=wp.spatial_matrix),
    body_v_s: wp.array(dtype=wp.spatial_vector),
    body_f_s: wp.array(dtype=wp.spatial_vector),
    body_a_s: wp.array(dtype=wp.spatial_vector),
):
    type = joint_type[i]
    child = joint_child[i]
    parent = joint_parent[i]
    q_start = joint_q_start[i]
    qd_start = joint_qd_start[i]

    X_pj = joint_X_p[i]
    # X_cj = joint_X_c[i]

    # parent anchor frame in world space
    X_wpj = X_pj
    if parent >= 0:
        X_wp = body_q[parent]
        X_wpj = X_wp * X_wpj

    # compute motion subspace and velocity across the joint (jcalc_motion also
    # writes the joint's motion subspace columns into joint_S_s)
    axis_start = joint_axis_start[i]
    lin_axis_count = joint_axis_dim[i, 0]
    ang_axis_count = joint_axis_dim[i, 1]
    v_j_s = jcalc_motion(
        type,
        joint_axis,
        axis_start,
        lin_axis_count,
        ang_axis_count,
        X_wpj,
        joint_q,
        joint_qd,
        q_start,
        qd_start,
        joint_S_s,
    )

    # parent velocity / acceleration (zero for root joints)
    v_parent_s = wp.spatial_vector()
    a_parent_s = wp.spatial_vector()

    if parent >= 0:
        v_parent_s = body_v_s[parent]
        a_parent_s = body_a_s[parent]

    # body velocity, acceleration
    # (the joint-acceleration term S * qdd is intentionally left out here)
    v_s = v_parent_s + v_j_s
    a_s = a_parent_s + spatial_cross(v_s, v_j_s)  # + joint_S_s[i]*self.joint_qdd[i]

    # compute body forces
    X_sm = body_q_com[child]
    I_m = body_I_m[child]

    # gravity only (no other external forces are added here). The body mass is
    # read from entry [3, 3] of the spatial inertia.
    # NOTE(review): assumes the lower-right 3x3 block of body_I_m is m * identity — confirm
    m = I_m[3, 3]

    # gravity force acts at the center of mass r_com; express it as a spatial
    # force with torque r_com x f_g and force f_g
    f_g = m * gravity
    r_com = wp.transform_get_translation(X_sm)
    f_g_s = wp.spatial_vector(wp.cross(r_com, f_g), f_g)

    # body inertia moved into the spatial frame, then inertial + velocity-product forces
    I_s = spatial_transform_inertia(X_sm, I_m)

    f_b_s = I_s * a_s + spatial_cross_dual(v_s, I_s * v_s)

    body_v_s[child] = v_s
    body_a_s[child] = a_s
    body_f_s[child] = f_b_s - f_g_s
    body_I_s[child] = I_s
868
+
869
+
870
# Forward pass of inverse dynamics via the Recursive Newton-Euler algorithm
# (Featherstone, Table 5.1).
#
# Launched with one thread per articulation. Each thread visits its joints,
# [articulation_start[tid], articulation_start[tid + 1]), in increasing index
# order and calls compute_link_velocity for each. That fills joint_S_s and the
# per-body spatial inertia, velocity, acceleration and force outputs.
# Parents are assumed to precede their children in joint order.
# See eval_rigid_tau for the backward (force) pass.
@wp.kernel
def eval_rigid_id(
    articulation_start: wp.array(dtype=int),
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_axis: wp.array(dtype=wp.vec3),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    body_I_m: wp.array(dtype=wp.spatial_matrix),
    body_q: wp.array(dtype=wp.transform),
    body_q_com: wp.array(dtype=wp.transform),
    joint_X_p: wp.array(dtype=wp.transform),
    joint_X_c: wp.array(dtype=wp.transform),
    gravity: wp.vec3,
    # outputs
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    body_I_s: wp.array(dtype=wp.spatial_matrix),
    body_v_s: wp.array(dtype=wp.spatial_vector),
    body_f_s: wp.array(dtype=wp.spatial_vector),
    body_a_s: wp.array(dtype=wp.spatial_vector),
):
    articulation = wp.tid()

    joint_begin = articulation_start[articulation]
    joint_end = articulation_start[articulation + 1]

    # link velocities, accelerations and velocity-product (coriolis) forces
    for joint in range(joint_begin, joint_end):
        compute_link_velocity(
            joint,
            joint_type,
            joint_parent,
            joint_child,
            joint_q_start,
            joint_qd_start,
            joint_q,
            joint_qd,
            joint_axis,
            joint_axis_start,
            joint_axis_dim,
            body_I_m,
            body_q,
            body_q_com,
            joint_X_p,
            joint_X_c,
            gravity,
            joint_S_s,
            body_I_s,
            body_v_s,
            body_f_s,
            body_a_s,
        )
929
+
930
+
931
# Backward pass of the Recursive Newton-Euler algorithm: turns the per-body
# forces from the forward pass into generalized joint forces ``tau``.
#
# Launched with one thread per articulation. Each thread visits the joints of
# its articulation in reverse index order. Assuming parents precede children in
# joint order, every child is therefore processed before its parent.
#
# For each joint, the total spatial force on its child body is
#   f_s = body_fb_s[child] + body_ft_s[child] + body_f_ext[child]
# where body_ft_s[child] holds what the child's own children accumulated into it.
# jcalc_tau projects f_s onto the joint's motion subspace and writes tau; for
# some joint types it also adds drive/limit forces. f_s is then accumulated into
# the parent body's entry of body_ft_s.
#
# NOTE(review): body_ft_s is only read and accumulated into here, so it
# presumably has to be zeroed by the caller before each launch — confirm.
@wp.kernel
def eval_rigid_tau(
    articulation_start: wp.array(dtype=int),
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    joint_axis_mode: wp.array(dtype=int),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_act: wp.array(dtype=float),
    joint_target_ke: wp.array(dtype=float),
    joint_target_kd: wp.array(dtype=float),
    joint_limit_lower: wp.array(dtype=float),
    joint_limit_upper: wp.array(dtype=float),
    joint_limit_ke: wp.array(dtype=float),
    joint_limit_kd: wp.array(dtype=float),
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    body_fb_s: wp.array(dtype=wp.spatial_vector),
    body_f_ext: wp.array(dtype=wp.spatial_vector),
    # outputs
    body_ft_s: wp.array(dtype=wp.spatial_vector),
    tau: wp.array(dtype=float),
):
    # one thread per-articulation
    index = wp.tid()

    start = articulation_start[index]
    end = articulation_start[index + 1]
    count = end - start

    # compute joint forces
    for offset in range(count):
        # for backwards traversal: i runs from end - 1 down to start
        i = end - offset - 1

        type = joint_type[i]
        parent = joint_parent[i]
        child = joint_child[i]
        dof_start = joint_qd_start[i]
        coord_start = joint_q_start[i]
        axis_start = joint_axis_start[i]
        lin_axis_count = joint_axis_dim[i, 0]
        ang_axis_count = joint_axis_dim[i, 1]

        # total forces on body: its own (forward-pass) force, forces passed up
        # from its children, and external forces
        f_b_s = body_fb_s[child]
        f_t_s = body_ft_s[child]
        f_ext = body_f_ext[child]
        f_s = f_b_s + f_t_s + f_ext

        # compute joint-space forces, writes out tau
        jcalc_tau(
            type,
            joint_target_ke,
            joint_target_kd,
            joint_limit_ke,
            joint_limit_kd,
            joint_S_s,
            joint_q,
            joint_qd,
            joint_act,
            joint_axis_mode,
            joint_limit_lower,
            joint_limit_upper,
            coord_start,
            dof_start,
            axis_start,
            lin_axis_count,
            ang_axis_count,
            f_s,
            tau,
        )

        # update parent forces, todo: check that this is valid for the backwards pass
        if parent >= 0:
            wp.atomic_add(body_ft_s, parent, f_s)
1011
+
1012
+
1013
# Builds the spatial Jacobian of each articulation (one thread per articulation).
#
# Per articulation, J is a dense row-major (joint_count * 6) x (dof_count)
# matrix stored in the flat array J, starting at articulation_J_start[tid].
# Row block i (rows 6*i .. 6*i+5) belongs to the i-th joint of the articulation.
# It receives the motion-subspace columns joint_S_s of that joint and of every
# ancestor reached by following joint_parent until -1. Columns are the
# articulation's dofs, relative to its first dof.
#
# Entries for non-ancestor joints are never written here, so J must be
# zero-initialized beforehand.
#
# NOTE(review): joint_parent holds a parent *body* index (compute_link_transform
# uses it to index body_q), but the walk below uses it directly as a joint
# index. This assumes body k is the child of joint k — confirm.
# NOTE: joint_qd_start is read at joint_end and at j + 1, so it presumably
# carries one trailing sentinel entry (total dof count).
@wp.kernel
def eval_rigid_jacobian(
    articulation_start: wp.array(dtype=int),
    articulation_J_start: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    # outputs
    J: wp.array(dtype=float),
):
    # one thread per-articulation
    index = wp.tid()

    joint_start = articulation_start[index]
    joint_end = articulation_start[index + 1]
    joint_count = joint_end - joint_start

    J_offset = articulation_J_start[index]

    articulation_dof_start = joint_qd_start[joint_start]
    articulation_dof_end = joint_qd_start[joint_end]
    # number of columns (row stride) of this articulation's Jacobian
    articulation_dof_count = articulation_dof_end - articulation_dof_start

    for i in range(joint_count):
        row_start = i * 6

        # walk from joint i up to the root, filling one column per ancestor dof
        j = joint_start + i
        while j != -1:
            joint_dof_start = joint_qd_start[j]
            joint_dof_end = joint_qd_start[j + 1]
            joint_dof_count = joint_dof_end - joint_dof_start

            # fill out each row of the Jacobian walking up the tree
            for dof in range(joint_dof_count):
                col = (joint_dof_start - articulation_dof_start) + dof
                S = joint_S_s[joint_dof_start + dof]

                for k in range(6):
                    J[J_offset + dense_index(articulation_dof_count, row_start + k, col)] = S[k]

            j = joint_parent[j]
1055
+
1056
+
1057
# Writes the block-diagonal spatial mass matrix of one articulation into the
# flat array M, starting at M_start. The matrix is (6 * joint_count) squared and
# row-major. The 6x6 diagonal block for link k is body_I_s[joint_start + k];
# off-diagonal blocks are not written, so M should be zero-initialized.
# NOTE(review): body_I_s is indexed with the joint index, which assumes body k
# is the child of joint k — confirm.
@wp.func
def spatial_mass(
    body_I_s: wp.array(dtype=wp.spatial_matrix),
    joint_start: int,
    joint_count: int,
    M_start: int,
    # outputs
    M: wp.array(dtype=float),
):
    stride = joint_count * 6
    for link in range(joint_count):
        inertia = body_I_s[joint_start + link]
        block = link * 6
        for row in range(6):
            for col in range(6):
                M[M_start + dense_index(stride, block + row, block + col)] = inertia[row, col]
1072
+
1073
+
1074
# Assembles the block-diagonal spatial mass matrix of every articulation into M
# (one thread per articulation). Articulation tid's matrix starts at
# articulation_M_start[tid]; spatial_mass defines the layout.
@wp.kernel
def eval_rigid_mass(
    articulation_start: wp.array(dtype=int),
    articulation_M_start: wp.array(dtype=int),
    body_I_s: wp.array(dtype=wp.spatial_matrix),
    # outputs
    M: wp.array(dtype=float),
):
    articulation = wp.tid()

    first_joint = articulation_start[articulation]
    num_joints = articulation_start[articulation + 1] - first_joint

    spatial_mass(body_I_s, first_joint, num_joints, articulation_M_start[articulation], M)
1092
+
1093
+
1094
# Dense row-major matrix multiply on flat float arrays: C = op(A) * op(B), where
# op(A) is m x p and op(B) is p x n, so C is m x n.
#   transpose_A: A is stored as a p x m matrix and used transposed.
#   transpose_B: B is stored as an n x p matrix and used transposed.
#   add_to_C:    accumulate into C instead of overwriting it.
#   A_start / B_start / C_start: element offsets of each matrix in its array.
# Plain triple loop; the calling thread computes every output element.
@wp.func
def dense_gemm(
    m: int,
    n: int,
    p: int,
    transpose_A: bool,
    transpose_B: bool,
    add_to_C: bool,
    A_start: int,
    B_start: int,
    C_start: int,
    A: wp.array(dtype=float),
    B: wp.array(dtype=float),
    # outputs
    C: wp.array(dtype=float),
):
    for row in range(m):
        for col in range(n):
            acc = float(0.0)
            for depth in range(p):
                if transpose_A:
                    a_idx = depth * m + row
                else:
                    a_idx = row * p + depth
                if transpose_B:
                    b_idx = col * p + depth
                else:
                    b_idx = depth * n + col
                acc += A[A_start + a_idx] * B[B_start + b_idx]

            c_idx = C_start + row * n + col
            if add_to_C:
                C[c_idx] += acc
            else:
                C[c_idx] = acc
1129
+
1130
+
1131
+ # @wp.func_grad(dense_gemm)
1132
+ # def adj_dense_gemm(
1133
+ # m: int,
1134
+ # n: int,
1135
+ # p: int,
1136
+ # transpose_A: bool,
1137
+ # transpose_B: bool,
1138
+ # add_to_C: bool,
1139
+ # A_start: int,
1140
+ # B_start: int,
1141
+ # C_start: int,
1142
+ # A: wp.array(dtype=float),
1143
+ # B: wp.array(dtype=float),
1144
+ # # outputs
1145
+ # C: wp.array(dtype=float),
1146
+ # ):
1147
+ # add_to_C = True
1148
+ # if transpose_A:
1149
+ # dense_gemm(p, m, n, False, True, add_to_C, A_start, B_start, C_start, B, wp.adjoint[C], wp.adjoint[A])
1150
+ # dense_gemm(p, n, m, False, False, add_to_C, A_start, B_start, C_start, A, wp.adjoint[C], wp.adjoint[B])
1151
+ # else:
1152
+ # dense_gemm(
1153
+ # m, p, n, False, not transpose_B, add_to_C, A_start, B_start, C_start, wp.adjoint[C], B, wp.adjoint[A]
1154
+ # )
1155
+ # dense_gemm(p, n, m, True, False, add_to_C, A_start, B_start, C_start, A, wp.adjoint[C], wp.adjoint[B])
1156
+
1157
+
1158
# Batched dense matrix multiply, launched with one thread per batch entry.
# Each thread runs dense_gemm over its entry's whole product, reading the
# entry's sizes (m, n, p) and matrix offsets from the per-batch arrays.
# The transpose flags are shared by all entries, and C is always overwritten
# (never accumulated into).
@wp.kernel
def eval_dense_gemm_batched(
    m: wp.array(dtype=int),
    n: wp.array(dtype=int),
    p: wp.array(dtype=int),
    transpose_A: bool,
    transpose_B: bool,
    A_start: wp.array(dtype=int),
    B_start: wp.array(dtype=int),
    C_start: wp.array(dtype=int),
    A: wp.array(dtype=float),
    B: wp.array(dtype=float),
    C: wp.array(dtype=float),
):
    batch_id = wp.tid()
    accumulate = False

    dense_gemm(
        m[batch_id],
        n[batch_id],
        p[batch_id],
        transpose_A,
        transpose_B,
        accumulate,
        A_start[batch_id],
        B_start[batch_id],
        C_start[batch_id],
        A,
        B,
        C,
    )
1191
+
1192
+
1193
# Dense Cholesky factorization of one symmetric n x n matrix stored row-major in
# a flat array. It computes the lower-triangular L with
#   L L^T = A + diag(R[R_start], ..., R[R_start + n - 1])
# where A starts at A_start.
#
# L is written with the same offset (A_start) and layout as A. Only the
# diagonal and strictly-lower triangle of L are written, and only the diagonal
# and lower triangle of A are read (A is assumed symmetric).
#
# Columns are processed left to right: first the diagonal entry of column j,
# then the entries below it.
#
# NOTE: there is no guard against a non-positive pivot. If A + diag(R) is not
# positive definite, wp.sqrt receives a negative value and the result is NaN.
@wp.func
def dense_cholesky(
    n: int,
    A: wp.array(dtype=float),
    R: wp.array(dtype=float),
    A_start: int,
    R_start: int,
    # outputs
    L: wp.array(dtype=float),
):
    # compute the Cholesky factorization of A = L L^T with diagonal regularization R
    for j in range(n):
        # diagonal: L[j,j] = sqrt(A[j,j] + R[j] - sum_k L[j,k]^2)
        s = A[A_start + dense_index(n, j, j)] + R[R_start + j]

        for k in range(j):
            r = L[A_start + dense_index(n, j, k)]
            s -= r * r

        s = wp.sqrt(s)
        invS = 1.0 / s

        L[A_start + dense_index(n, j, j)] = s

        # below the diagonal: L[i,j] = (A[i,j] - sum_k L[i,k] L[j,k]) / L[j,j]
        for i in range(j + 1, n):
            s = A[A_start + dense_index(n, i, j)]

            for k in range(j):
                s -= L[A_start + dense_index(n, i, k)] * L[A_start + dense_index(n, j, k)]

            L[A_start + dense_index(n, i, j)] = s * invS
1223
+
1224
+
1225
# Custom adjoint registered for dense_cholesky. It is intentionally a no-op:
# gradients are handled in the adjoint of dense_solve (x = (L L^T)^-1 b) instead.
# NOTE(review): nothing is propagated from L back to A or R here, so any gradient
# accumulated into L's adjoint stops at L — confirm this approximation is intended.
@wp.func_grad(dense_cholesky)
def adj_dense_cholesky(
    n: int,
    A: wp.array(dtype=float),
    R: wp.array(dtype=float),
    A_start: int,
    R_start: int,
    # outputs
    L: wp.array(dtype=float),
):
    # nop, use dense_solve to differentiate through (A^-1)b = x
    pass
1237
+
1238
+
1239
# Batched Cholesky factorization, launched with one thread per matrix. Matrix
# `batch_id` has dimension A_dim[batch_id] and starts at A_starts[batch_id] in
# both A and L. Its diagonal regularization is read from R starting at
# dim * batch_id.
# NOTE(review): that R offset is only the running dof offset when every batch
# entry has the same dimension. With differently sized matrices it selects the
# wrong R entries — confirm against the caller.
@wp.kernel
def eval_dense_cholesky_batched(
    A_starts: wp.array(dtype=int),
    A_dim: wp.array(dtype=int),
    A: wp.array(dtype=float),
    R: wp.array(dtype=float),
    L: wp.array(dtype=float),
):
    batch_id = wp.tid()

    dim = A_dim[batch_id]
    regularization_offset = dim * batch_id

    dense_cholesky(dim, A, R, A_starts[batch_id], regularization_offset, L)
1254
+
1255
+
1256
# Solves (L L^T) x = b for one system, given the lower-triangular Cholesky
# factor L (n x n, row-major, starting at L_start) as produced by
# dense_cholesky. b and x share the element offset b_start.
#
# Two passes over x:
#   1. forward substitution solves L y = b and stores y in x;
#   2. backward substitution solves L^T x = y, overwriting x in place.
# Only the diagonal and lower triangle of L are read.
@wp.func
def dense_subs(
    n: int,
    L_start: int,
    b_start: int,
    L: wp.array(dtype=float),
    b: wp.array(dtype=float),
    # outputs
    x: wp.array(dtype=float),
):
    # Solves (L L^T) x = b for x given the Cholesky factor L
    # forward substitution solves the lower triangular system L y = b for y
    for i in range(n):
        s = b[b_start + i]

        for j in range(i):
            s -= L[L_start + dense_index(n, i, j)] * x[b_start + j]

        x[b_start + i] = s / L[L_start + dense_index(n, i, i)]

    # backward substitution solves the upper triangular system L^T x = y for x
    # (L^T[i, j] is read as L[j, i])
    for i in range(n - 1, -1, -1):
        s = x[b_start + i]

        for j in range(i + 1, n):
            s -= L[L_start + dense_index(n, j, i)] * x[b_start + j]

        x[b_start + i] = s / L[L_start + dense_index(n, i, i)]
1284
+
1285
+
1286
# Solves (L L^T) x = b for one system by delegating to dense_subs.
# ``tmp`` is not used in the forward computation. It is part of the signature
# so the custom adjoint (adj_dense_solve) gets a scratch buffer laid out like b
# and x (offset b_start).
@wp.func
def dense_solve(
    n: int,
    L_start: int,
    b_start: int,
    L: wp.array(dtype=float),
    b: wp.array(dtype=float),
    # outputs
    x: wp.array(dtype=float),
    tmp: wp.array(dtype=float),
):
    # helper function to include tmp argument for backward pass
    dense_subs(n, L_start, b_start, L, b, x)
1299
+
1300
+
1301
# Custom adjoint for dense_solve. With A = L L^T (symmetric) and x = A^-1 b:
#   adj_b += A^-1 adj_x
#   adj_A += -(A^-1 adj_x) x^T
# ``tmp`` holds A^-1 adj_x as scratch.
#
# The adj_A term is accumulated into the adjoint of L. It is the gradient with
# respect to the full matrix A = L L^T, not with respect to the factor L itself.
#
# Returns early, producing no gradient at all, if tmp, adj_x or adj_L is null.
# NOTE(review): adj_b is written without a matching null check — presumably b
# always has a gradient whenever x and L do; confirm.
@wp.func_grad(dense_solve)
def adj_dense_solve(
    n: int,
    L_start: int,
    b_start: int,
    L: wp.array(dtype=float),
    b: wp.array(dtype=float),
    # outputs
    x: wp.array(dtype=float),
    tmp: wp.array(dtype=float),
):
    if not tmp or not wp.adjoint[x] or not wp.adjoint[L]:
        return
    # clear the scratch buffer (defensive: dense_subs writes every entry before reading it)
    for i in range(n):
        tmp[b_start + i] = 0.0

    # tmp = (L L^T)^-1 adj_x
    dense_subs(n, L_start, b_start, L, wp.adjoint[x], tmp)

    for i in range(n):
        wp.adjoint[b][b_start + i] += tmp[b_start + i]

    # A* = -adj_b*x^T  (adj_b contribution == tmp; accumulated into L's adjoint)
    for i in range(n):
        for j in range(n):
            wp.adjoint[L][L_start + dense_index(n, i, j)] += -tmp[b_start + i] * x[b_start + j]
1326
+
1327
+
1328
# Batched Cholesky solve, launched with one thread per system. System
# `batch_id` has dimension L_dim[batch_id], its factor starts at
# L_start[batch_id], and its right-hand side / solution start at
# b_start[batch_id]. tmp is only used by the custom backward pass of dense_solve.
@wp.kernel
def eval_dense_solve_batched(
    L_start: wp.array(dtype=int),
    L_dim: wp.array(dtype=int),
    b_start: wp.array(dtype=int),
    L: wp.array(dtype=float),
    b: wp.array(dtype=float),
    # outputs
    x: wp.array(dtype=float),
    tmp: wp.array(dtype=float),
):
    batch_id = wp.tid()

    dim = L_dim[batch_id]
    factor_offset = L_start[batch_id]
    rhs_offset = b_start[batch_id]

    dense_solve(dim, factor_offset, rhs_offset, L, b, x, tmp)
1342
+
1343
+
1344
# Integrates the generalized coordinates of every joint by one time step.
# Launched with one thread per joint. The thread reads that joint's type,
# coordinate/dof offsets and axis counts, then calls jcalc_integrate to write
# joint_q_new / joint_qd_new.
@wp.kernel
def integrate_generalized_joints(
    joint_type: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_qdd: wp.array(dtype=float),
    dt: float,
    # outputs
    joint_q_new: wp.array(dtype=float),
    joint_qd_new: wp.array(dtype=float),
):
    joint = wp.tid()

    jcalc_integrate(
        joint_type[joint],
        joint_q,
        joint_qd,
        joint_qdd,
        joint_q_start[joint],
        joint_qd_start[joint],
        joint_axis_dim[joint, 0],
        joint_axis_dim[joint, 1],
        dt,
        joint_q_new,
        joint_qd_new,
    )
1380
+
1381
+
1382
+ class FeatherstoneIntegrator(Integrator):
1383
+ """A semi-implicit integrator using symplectic Euler that operates
1384
+ on reduced (also called generalized) coordinates to simulate articulated rigid body dynamics
1385
+ based on Featherstone's composite rigid body algorithm (CRBA).
1386
+
1387
+ See: Featherstone, Roy. Rigid Body Dynamics Algorithms. Springer US, 2014.
1388
+
1389
+ Instead of maximal coordinates :attr:`State.body_q` (rigid body positions) and :attr:`State.body_qd`
1390
+ (rigid body velocities) as is the case for :class:`SemiImplicitIntegrator`, :class:`FeatherstoneIntegrator`
1391
+ uses :attr:`State.joint_q` and :attr:`State.joint_qd` to represent the positions and velocities of
1392
+ joints without allowing any redundant degrees of freedom.
1393
+
1394
+ After constructing :class:`Model` and :class:`State` objects this time-integrator
1395
+ may be used to advance the simulation state forward in time.
1396
+
1397
+ Note:
1398
+ Unlike :class:`SemiImplicitIntegrator` and :class:`XPBDIntegrator`, :class:`FeatherstoneIntegrator` does not simulate rigid bodies with nonzero mass as floating bodies if they are not connected through any joints. Floating-base systems require an explicit free joint with which the body is connected to the world, see :meth:`ModelBuilder.add_joint_free`.
1399
+
1400
+ Semi-implicit time integration is a variational integrator that
1401
+ preserves energy; however, it is not unconditionally stable and requires a time-step
1402
+ small enough to support the required stiffness and damping forces.
1403
+
1404
+ See: https://en.wikipedia.org/wiki/Semi-implicit_Euler_method
1405
+
1406
+ Example
1407
+ -------
1408
+
1409
+ .. code-block:: python
1410
+
1411
+ integrator = wp.FeatherstoneIntegrator(model)
1412
+
1413
+ # simulation loop
1414
+ for i in range(100):
1415
+ state = integrator.simulate(model, state_in, state_out, dt)
1416
+
1417
+ Note:
1418
+ The :class:`FeatherstoneIntegrator` requires the :class:`Model` to be passed in as a constructor argument.
1419
+
1420
+ """
1421
+
1422
+ def __init__(self, model, angular_damping=0.05, update_mass_matrix_every=1):
1423
+ """
1424
+ Args:
1425
+ model (Model): the model to be simulated.
1426
+ angular_damping (float, optional): Angular damping factor. Defaults to 0.05.
1427
+ update_mass_matrix_every (int, optional): How often to update the mass matrix (every n-th time the :meth:`simulate` function gets called). Defaults to 1.
1428
+ """
1429
+ self.angular_damping = angular_damping
1430
+ self.update_mass_matrix_every = update_mass_matrix_every
1431
+ self.compute_articulation_indices(model)
1432
+ self.allocate_model_aux_vars(model)
1433
+ self._step = 0
1434
+
1435
+ def compute_articulation_indices(self, model):
1436
+ # calculate total size and offsets of Jacobian and mass matrices for entire system
1437
+ if model.joint_count:
1438
+ self.J_size = 0
1439
+ self.M_size = 0
1440
+ self.H_size = 0
1441
+
1442
+ articulation_J_start = []
1443
+ articulation_M_start = []
1444
+ articulation_H_start = []
1445
+
1446
+ articulation_M_rows = []
1447
+ articulation_H_rows = []
1448
+ articulation_J_rows = []
1449
+ articulation_J_cols = []
1450
+
1451
+ articulation_dof_start = []
1452
+ articulation_coord_start = []
1453
+
1454
+ articulation_start = model.articulation_start.numpy()
1455
+ joint_q_start = model.joint_q_start.numpy()
1456
+ joint_qd_start = model.joint_qd_start.numpy()
1457
+
1458
+ for i in range(model.articulation_count):
1459
+ first_joint = articulation_start[i]
1460
+ last_joint = articulation_start[i + 1]
1461
+
1462
+ first_coord = joint_q_start[first_joint]
1463
+
1464
+ first_dof = joint_qd_start[first_joint]
1465
+ last_dof = joint_qd_start[last_joint]
1466
+
1467
+ joint_count = last_joint - first_joint
1468
+ dof_count = last_dof - first_dof
1469
+
1470
+ articulation_J_start.append(self.J_size)
1471
+ articulation_M_start.append(self.M_size)
1472
+ articulation_H_start.append(self.H_size)
1473
+ articulation_dof_start.append(first_dof)
1474
+ articulation_coord_start.append(first_coord)
1475
+
1476
+ # bit of data duplication here, but will leave it as such for clarity
1477
+ articulation_M_rows.append(joint_count * 6)
1478
+ articulation_H_rows.append(dof_count)
1479
+ articulation_J_rows.append(joint_count * 6)
1480
+ articulation_J_cols.append(dof_count)
1481
+
1482
+ self.J_size += 6 * joint_count * dof_count
1483
+ self.M_size += 6 * joint_count * 6 * joint_count
1484
+ self.H_size += dof_count * dof_count
1485
+
1486
+ # matrix offsets for batched gemm
1487
+ self.articulation_J_start = wp.array(articulation_J_start, dtype=wp.int32, device=model.device)
1488
+ self.articulation_M_start = wp.array(articulation_M_start, dtype=wp.int32, device=model.device)
1489
+ self.articulation_H_start = wp.array(articulation_H_start, dtype=wp.int32, device=model.device)
1490
+
1491
+ self.articulation_M_rows = wp.array(articulation_M_rows, dtype=wp.int32, device=model.device)
1492
+ self.articulation_H_rows = wp.array(articulation_H_rows, dtype=wp.int32, device=model.device)
1493
+ self.articulation_J_rows = wp.array(articulation_J_rows, dtype=wp.int32, device=model.device)
1494
+ self.articulation_J_cols = wp.array(articulation_J_cols, dtype=wp.int32, device=model.device)
1495
+
1496
+ self.articulation_dof_start = wp.array(articulation_dof_start, dtype=wp.int32, device=model.device)
1497
+ self.articulation_coord_start = wp.array(articulation_coord_start, dtype=wp.int32, device=model.device)
1498
+
1499
def allocate_model_aux_vars(self, model):
    """Allocate the model-dependent auxiliary buffers used by the solver.

    When the model has joints, this creates the flat storage for the batched
    system matrices (M, J, P, H and the factor L). When the model has bodies,
    it precomputes the per-body spatial inertia (``self.body_I_m``) and
    center-of-mass transforms (``self.body_X_com``) by launching the
    ``compute_spatial_inertia`` and ``compute_com_transforms`` kernels.

    Args:
        model: the simulation model. Reads ``joint_count``, ``body_count``,
            ``device``, ``requires_grad``, ``body_inertia``, ``body_mass`` and
            ``body_com``. ``self.M_size``, ``self.J_size`` and ``self.H_size``
            must already be set on the integrator.
    """
    device = model.device
    grad = model.requires_grad

    if model.joint_count:
        # flat, concatenated storage for every articulation's system matrices
        self.M = wp.zeros((self.M_size,), dtype=wp.float32, device=device, requires_grad=grad)
        self.J = wp.zeros((self.J_size,), dtype=wp.float32, device=device, requires_grad=grad)
        # P = M*J shares J's layout, so it is allocated like J
        self.P = wp.empty_like(self.J, requires_grad=grad)
        self.H = wp.empty((self.H_size,), dtype=wp.float32, device=device, requires_grad=grad)

        # only one triangle of L is written by the factorization; zero-fill it so
        # the entries left unset cannot trigger NaN detection
        self.L = wp.zeros_like(self.H)

    if model.body_count:
        n_bodies = model.body_count

        self.body_I_m = wp.empty((n_bodies,), dtype=wp.spatial_matrix, device=device, requires_grad=grad)
        wp.launch(
            compute_spatial_inertia,
            n_bodies,
            inputs=[model.body_inertia, model.body_mass],
            outputs=[self.body_I_m],
            device=device,
        )

        self.body_X_com = wp.empty((n_bodies,), dtype=wp.transform, device=device, requires_grad=grad)
        wp.launch(
            compute_com_transforms,
            n_bodies,
            inputs=[model.body_com],
            outputs=[self.body_X_com],
            device=device,
        )
1533
+
1534
def allocate_state_aux_vars(self, model, target, requires_grad):
    """Attach the per-step Featherstone scratch buffers to ``target``.

    Nothing is allocated when the model has no bodies. Otherwise the joint-space
    buffers (``joint_qdd``, ``joint_tau``, ``joint_solve_tmp``, ``joint_S_s``) and
    the maximal-coordinate body buffers (``body_q_com``, ``body_I_s``, ``body_v_s``,
    ``body_a_s``, ``body_f_s``, ``body_ft_s``) are set as attributes on ``target``.
    ``target._featherstone_augmented`` is then marked True.

    Args:
        model: the simulation model providing sizes, device and template arrays.
        target: the object that receives the buffers as attributes.
        requires_grad: whether the created arrays should track gradients.
    """
    if not model.body_count:
        return

    device = model.device
    n_bodies = model.body_count

    # generalized (joint-space) quantities
    target.joint_qdd = wp.zeros_like(model.joint_qd, requires_grad=requires_grad)
    target.joint_tau = wp.empty_like(model.joint_qd, requires_grad=requires_grad)
    # scratch consumed by the custom grad implementation of eval_dense_solve_batched;
    # only needed when gradients are tracked
    target.joint_solve_tmp = wp.zeros_like(model.joint_qd, requires_grad=True) if requires_grad else None
    target.joint_S_s = wp.empty(
        (model.joint_dof_count,),
        dtype=wp.spatial_vector,
        device=device,
        requires_grad=requires_grad,
    )

    # derived rigid body data (maximal coordinates)
    target.body_q_com = wp.empty_like(model.body_q, requires_grad=requires_grad)
    target.body_I_s = wp.empty((n_bodies,), dtype=wp.spatial_matrix, device=device, requires_grad=requires_grad)
    target.body_v_s = wp.empty((n_bodies,), dtype=wp.spatial_vector, device=device, requires_grad=requires_grad)
    target.body_a_s = wp.empty((n_bodies,), dtype=wp.spatial_vector, device=device, requires_grad=requires_grad)
    # zero-initialized; simulate() also clears these before each use
    target.body_f_s = wp.zeros((n_bodies,), dtype=wp.spatial_vector, device=device, requires_grad=requires_grad)
    target.body_ft_s = wp.zeros((n_bodies,), dtype=wp.spatial_vector, device=device, requires_grad=requires_grad)

    target._featherstone_augmented = True
1571
+
1572
def simulate(self, model: Model, state_in: State, state_out: State, dt: float, control: Control = None):
    """Advance the model by one step of length ``dt`` using Featherstone's method.

    Particle forces are evaluated with the shared force routines. Articulated
    rigid bodies are handled in generalized (joint) coordinates, in this order:
    forward kinematics, inverse dynamics, rigid contacts, joint torques, and then
    the joint-space mass matrix H = J^T * (M * J) with its batched Cholesky
    factor L. H and L are rebuilt only every ``self.update_mass_matrix_every``
    steps; in between, the previous L is reused. A batched solve then yields
    joint accelerations. Joint coordinates are integrated into ``state_out``,
    and maximal coordinates are recomputed with ``eval_fk``.

    Args:
        model: the model to simulate.
        state_in: state at the start of the step. Its ``body_q`` is written by
            the forward-kinematics pass. Its ``particle_f`` / ``body_f`` arrays
            are passed to the force routines as the force targets.
        state_out: state that receives the integrated result.
        dt: time step.
        control: control inputs. If None, one is obtained from
            ``model.control(clone_variables=False)``.

    Returns:
        ``state_out``.
    """
    requires_grad = state_in.requires_grad

    # Scratch buffers live on state_out when gradients are required and on the
    # integrator itself otherwise; each target is allocated only once.
    # NOTE(review): presumably per-state storage keeps every step's intermediates
    # alive for the backward pass -- confirm.
    if requires_grad:
        state_aug = state_out
    else:
        state_aug = self

    if not getattr(state_aug, "_featherstone_augmented", False):
        self.allocate_state_aux_vars(model, state_aug, requires_grad)
    if control is None:
        control = model.control(clone_variables=False)

    with wp.ScopedTimer("simulate", False):
        particle_f = None
        body_f = None

        if state_in.particle_count:
            particle_f = state_in.particle_f

        if state_in.body_count:
            body_f = state_in.body_f

        # damped springs
        eval_spring_forces(model, state_in, particle_f)

        # triangle elastic and lift/drag forces
        eval_triangle_forces(model, state_in, control, particle_f)

        # triangle/triangle contacts
        eval_triangle_contact_forces(model, state_in, particle_f)

        # triangle bending
        eval_bending_forces(model, state_in, particle_f)

        # tetrahedral FEM
        eval_tetrahedral_forces(model, state_in, control, particle_f)

        # particle-particle interactions
        eval_particle_forces(model, state_in, particle_f)

        # particle ground contacts
        eval_particle_ground_contact_forces(model, state_in, particle_f)

        # particle shape contact
        eval_particle_body_contact_forces(model, state_in, particle_f, body_f)

        # NOTE: muscle forces (eval_muscle_forces) are not evaluated by this integrator.

        # ----------------------------
        # articulations

        if model.joint_count:
            # evaluate body transforms
            wp.launch(
                eval_rigid_fk,
                dim=model.articulation_count,
                inputs=[
                    model.articulation_start,
                    model.joint_type,
                    model.joint_parent,
                    model.joint_child,
                    model.joint_q_start,
                    state_in.joint_q,
                    model.joint_X_p,
                    model.joint_X_c,
                    self.body_X_com,
                    model.joint_axis,
                    model.joint_axis_start,
                    model.joint_axis_dim,
                ],
                outputs=[state_in.body_q, state_aug.body_q_com],
                device=model.device,
            )

            # evaluate joint inertias, motion vectors, and forces
            state_aug.body_f_s.zero_()
            wp.launch(
                eval_rigid_id,
                dim=model.articulation_count,
                inputs=[
                    model.articulation_start,
                    model.joint_type,
                    model.joint_parent,
                    model.joint_child,
                    model.joint_q_start,
                    model.joint_qd_start,
                    state_in.joint_q,
                    state_in.joint_qd,
                    model.joint_axis,
                    model.joint_axis_start,
                    model.joint_axis_dim,
                    self.body_I_m,
                    state_in.body_q,
                    state_aug.body_q_com,
                    model.joint_X_p,
                    model.joint_X_c,
                    model.gravity,
                ],
                outputs=[
                    state_aug.joint_S_s,
                    state_aug.body_I_s,
                    state_aug.body_v_s,
                    state_aug.body_f_s,
                    state_aug.body_a_s,
                ],
                device=model.device,
            )

            # rigid contacts write their forces into body_f
            if model.rigid_contact_max and (
                (model.ground and model.shape_ground_contact_pair_count) or model.shape_contact_pair_count
            ):
                wp.launch(
                    kernel=eval_rigid_contacts,
                    dim=model.rigid_contact_max,
                    inputs=[
                        state_in.body_q,
                        state_aug.body_v_s,
                        model.body_com,
                        model.shape_materials,
                        model.shape_geo,
                        model.shape_body,
                        model.rigid_contact_count,
                        model.rigid_contact_point0,
                        model.rigid_contact_point1,
                        model.rigid_contact_normal,
                        model.rigid_contact_shape0,
                        model.rigid_contact_shape1,
                        True,
                    ],
                    outputs=[body_f],
                    device=model.device,
                )

            if model.articulation_count:
                # evaluate joint torques
                state_aug.body_ft_s.zero_()
                wp.launch(
                    eval_rigid_tau,
                    dim=model.articulation_count,
                    inputs=[
                        model.articulation_start,
                        model.joint_type,
                        model.joint_parent,
                        model.joint_child,
                        model.joint_q_start,
                        model.joint_qd_start,
                        model.joint_axis_start,
                        model.joint_axis_dim,
                        model.joint_axis_mode,
                        state_in.joint_q,
                        state_in.joint_qd,
                        control.joint_act,
                        model.joint_target_ke,
                        model.joint_target_kd,
                        model.joint_limit_lower,
                        model.joint_limit_upper,
                        model.joint_limit_ke,
                        model.joint_limit_kd,
                        state_aug.joint_S_s,
                        state_aug.body_f_s,
                        body_f,
                    ],
                    outputs=[
                        state_aug.body_ft_s,
                        state_aug.joint_tau,
                    ],
                    device=model.device,
                )

                # rebuild H and its factor L only every `update_mass_matrix_every` steps;
                # in between, the previously computed L is reused by the solve below
                if self._step % self.update_mass_matrix_every == 0:
                    # build J
                    wp.launch(
                        eval_rigid_jacobian,
                        dim=model.articulation_count,
                        inputs=[
                            model.articulation_start,
                            self.articulation_J_start,
                            model.joint_parent,
                            model.joint_qd_start,
                            state_aug.joint_S_s,
                        ],
                        outputs=[self.J],
                        device=model.device,
                    )

                    # build M
                    wp.launch(
                        eval_rigid_mass,
                        dim=model.articulation_count,
                        inputs=[
                            model.articulation_start,
                            self.articulation_M_start,
                            state_aug.body_I_s,
                        ],
                        outputs=[self.M],
                        device=model.device,
                    )

                    # form P = M*J
                    # NOTE(review): the two boolean flags appear to select transposition of
                    # the first/second operand -- confirm against eval_dense_gemm_batched
                    wp.launch(
                        eval_dense_gemm_batched,
                        dim=model.articulation_count,
                        inputs=[
                            self.articulation_M_rows,
                            self.articulation_J_cols,
                            self.articulation_J_rows,
                            False,
                            False,
                            self.articulation_M_start,
                            self.articulation_J_start,
                            # P start is the same as J start since it has the same dims as J
                            self.articulation_J_start,
                            self.M,
                            self.J,
                        ],
                        outputs=[self.P],
                        device=model.device,
                    )

                    # form H = J^T*P
                    wp.launch(
                        eval_dense_gemm_batched,
                        dim=model.articulation_count,
                        inputs=[
                            self.articulation_J_cols,
                            self.articulation_J_cols,
                            # P rows is the same as J rows
                            self.articulation_J_rows,
                            True,
                            False,
                            self.articulation_J_start,
                            # P start is the same as J start since it has the same dims as J
                            self.articulation_J_start,
                            self.articulation_H_start,
                            self.J,
                            self.P,
                        ],
                        outputs=[self.H],
                        device=model.device,
                    )

                    # compute decomposition
                    wp.launch(
                        eval_dense_cholesky_batched,
                        dim=model.articulation_count,
                        inputs=[
                            self.articulation_H_start,
                            self.articulation_H_rows,
                            self.H,
                            model.joint_armature,
                        ],
                        outputs=[self.L],
                        device=model.device,
                    )

                # solve for qdd; joint_solve_tmp is only allocated (non-None) when
                # gradients are required and backs the solve's custom gradient
                state_aug.joint_qdd.zero_()
                wp.launch(
                    eval_dense_solve_batched,
                    dim=model.articulation_count,
                    inputs=[
                        self.articulation_H_start,
                        self.articulation_H_rows,
                        self.articulation_dof_start,
                        self.L,
                        state_aug.joint_tau,
                    ],
                    outputs=[
                        state_aug.joint_qdd,
                        state_aug.joint_solve_tmp,
                    ],
                    device=model.device,
                )

        # -------------------------------------
        # integrate bodies

        if model.joint_count:
            wp.launch(
                kernel=integrate_generalized_joints,
                dim=model.joint_count,
                inputs=[
                    model.joint_type,
                    model.joint_q_start,
                    model.joint_qd_start,
                    model.joint_axis_dim,
                    state_in.joint_q,
                    state_in.joint_qd,
                    state_aug.joint_qdd,
                    dt,
                ],
                outputs=[state_out.joint_q, state_out.joint_qd],
                device=model.device,
            )

            # update maximal coordinates
            eval_fk(model, state_out.joint_q, state_out.joint_qd, None, state_out)

        self.integrate_particles(model, state_in, state_out, dt)

        self._step += 1

        return state_out