pymomentum-cpu 0.1.77.post30__cp313-cp313-manylinux_2_39_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pymomentum-cpu might be problematic. Click here for more details.

Files changed (555)
  1. include/axel/BoundingBox.h +58 -0
  2. include/axel/Bvh.h +708 -0
  3. include/axel/BvhBase.h +75 -0
  4. include/axel/BvhCommon.h +43 -0
  5. include/axel/BvhEmbree.h +86 -0
  6. include/axel/BvhFactory.h +34 -0
  7. include/axel/Checks.h +21 -0
  8. include/axel/DualContouring.h +79 -0
  9. include/axel/KdTree.h +199 -0
  10. include/axel/Log.h +22 -0
  11. include/axel/MeshToSdf.h +123 -0
  12. include/axel/Profile.h +64 -0
  13. include/axel/Ray.h +45 -0
  14. include/axel/SignedDistanceField.h +248 -0
  15. include/axel/SimdKdTree.h +515 -0
  16. include/axel/TriBvh.h +157 -0
  17. include/axel/TriBvhEmbree.h +57 -0
  18. include/axel/common/Constants.h +27 -0
  19. include/axel/common/Types.h +21 -0
  20. include/axel/common/VectorizationTypes.h +58 -0
  21. include/axel/math/BoundingBoxUtils.h +54 -0
  22. include/axel/math/ContinuousCollisionDetection.h +48 -0
  23. include/axel/math/CoplanarityCheck.h +30 -0
  24. include/axel/math/EdgeEdgeDistance.h +31 -0
  25. include/axel/math/MeshHoleFilling.h +117 -0
  26. include/axel/math/PointTriangleProjection.h +34 -0
  27. include/axel/math/PointTriangleProjectionDefinitions.h +209 -0
  28. include/axel/math/RayTriangleIntersection.h +36 -0
  29. include/momentum/character/blend_shape.h +91 -0
  30. include/momentum/character/blend_shape_base.h +70 -0
  31. include/momentum/character/blend_shape_skinning.h +96 -0
  32. include/momentum/character/character.h +272 -0
  33. include/momentum/character/character_state.h +108 -0
  34. include/momentum/character/character_utility.h +128 -0
  35. include/momentum/character/collision_geometry.h +80 -0
  36. include/momentum/character/collision_geometry_state.h +130 -0
  37. include/momentum/character/fwd.h +262 -0
  38. include/momentum/character/inverse_parameter_transform.h +58 -0
  39. include/momentum/character/joint.h +82 -0
  40. include/momentum/character/joint_state.h +241 -0
  41. include/momentum/character/linear_skinning.h +139 -0
  42. include/momentum/character/locator.h +82 -0
  43. include/momentum/character/locator_state.h +43 -0
  44. include/momentum/character/marker.h +48 -0
  45. include/momentum/character/mesh_state.h +71 -0
  46. include/momentum/character/parameter_limits.h +144 -0
  47. include/momentum/character/parameter_transform.h +250 -0
  48. include/momentum/character/pose_shape.h +65 -0
  49. include/momentum/character/skeleton.h +85 -0
  50. include/momentum/character/skeleton_state.h +181 -0
  51. include/momentum/character/skeleton_utility.h +38 -0
  52. include/momentum/character/skin_weights.h +67 -0
  53. include/momentum/character/skinned_locator.h +80 -0
  54. include/momentum/character/types.h +202 -0
  55. include/momentum/character_sequence_solver/fwd.h +200 -0
  56. include/momentum/character_sequence_solver/model_parameters_sequence_error_function.h +65 -0
  57. include/momentum/character_sequence_solver/multipose_solver.h +65 -0
  58. include/momentum/character_sequence_solver/multipose_solver_function.h +82 -0
  59. include/momentum/character_sequence_solver/sequence_error_function.h +104 -0
  60. include/momentum/character_sequence_solver/sequence_solver.h +144 -0
  61. include/momentum/character_sequence_solver/sequence_solver_function.h +134 -0
  62. include/momentum/character_sequence_solver/state_sequence_error_function.h +109 -0
  63. include/momentum/character_sequence_solver/vertex_sequence_error_function.h +128 -0
  64. include/momentum/character_solver/aim_error_function.h +112 -0
  65. include/momentum/character_solver/collision_error_function.h +92 -0
  66. include/momentum/character_solver/collision_error_function_stateless.h +75 -0
  67. include/momentum/character_solver/constraint_error_function-inl.h +324 -0
  68. include/momentum/character_solver/constraint_error_function.h +248 -0
  69. include/momentum/character_solver/distance_error_function.h +77 -0
  70. include/momentum/character_solver/error_function_utils.h +60 -0
  71. include/momentum/character_solver/fixed_axis_error_function.h +139 -0
  72. include/momentum/character_solver/fwd.h +924 -0
  73. include/momentum/character_solver/gauss_newton_solver_qr.h +64 -0
  74. include/momentum/character_solver/limit_error_function.h +57 -0
  75. include/momentum/character_solver/model_parameters_error_function.h +64 -0
  76. include/momentum/character_solver/normal_error_function.h +73 -0
  77. include/momentum/character_solver/orientation_error_function.h +74 -0
  78. include/momentum/character_solver/plane_error_function.h +102 -0
  79. include/momentum/character_solver/point_triangle_vertex_error_function.h +141 -0
  80. include/momentum/character_solver/pose_prior_error_function.h +80 -0
  81. include/momentum/character_solver/position_error_function.h +75 -0
  82. include/momentum/character_solver/projection_error_function.h +93 -0
  83. include/momentum/character_solver/simd_collision_error_function.h +99 -0
  84. include/momentum/character_solver/simd_normal_error_function.h +157 -0
  85. include/momentum/character_solver/simd_plane_error_function.h +164 -0
  86. include/momentum/character_solver/simd_position_error_function.h +165 -0
  87. include/momentum/character_solver/skeleton_error_function.h +151 -0
  88. include/momentum/character_solver/skeleton_solver_function.h +94 -0
  89. include/momentum/character_solver/skinned_locator_error_function.h +166 -0
  90. include/momentum/character_solver/skinned_locator_triangle_error_function.h +146 -0
  91. include/momentum/character_solver/skinning_weight_iterator.h +80 -0
  92. include/momentum/character_solver/state_error_function.h +94 -0
  93. include/momentum/character_solver/transform_pose.h +80 -0
  94. include/momentum/character_solver/trust_region_qr.h +80 -0
  95. include/momentum/character_solver/vertex_error_function.h +155 -0
  96. include/momentum/character_solver/vertex_projection_error_function.h +126 -0
  97. include/momentum/character_solver/vertex_vertex_distance_error_function.h +151 -0
  98. include/momentum/common/aligned.h +155 -0
  99. include/momentum/common/checks.h +27 -0
  100. include/momentum/common/exception.h +70 -0
  101. include/momentum/common/filesystem.h +20 -0
  102. include/momentum/common/fwd.h +27 -0
  103. include/momentum/common/log.h +173 -0
  104. include/momentum/common/log_channel.h +17 -0
  105. include/momentum/common/memory.h +71 -0
  106. include/momentum/common/profile.h +79 -0
  107. include/momentum/common/progress_bar.h +37 -0
  108. include/momentum/common/string.h +52 -0
  109. include/momentum/diff_ik/ceres_utility.h +73 -0
  110. include/momentum/diff_ik/fully_differentiable_body_ik.h +58 -0
  111. include/momentum/diff_ik/fully_differentiable_distance_error_function.h +69 -0
  112. include/momentum/diff_ik/fully_differentiable_motion_error_function.h +46 -0
  113. include/momentum/diff_ik/fully_differentiable_orientation_error_function.h +114 -0
  114. include/momentum/diff_ik/fully_differentiable_pose_prior_error_function.h +76 -0
  115. include/momentum/diff_ik/fully_differentiable_position_error_function.h +138 -0
  116. include/momentum/diff_ik/fully_differentiable_projection_error_function.h +65 -0
  117. include/momentum/diff_ik/fully_differentiable_skeleton_error_function.h +160 -0
  118. include/momentum/diff_ik/fully_differentiable_state_error_function.h +54 -0
  119. include/momentum/diff_ik/fwd.h +385 -0
  120. include/momentum/diff_ik/union_error_function.h +67 -0
  121. include/momentum/gui/rerun/eigen_adapters.h +70 -0
  122. include/momentum/gui/rerun/logger.h +102 -0
  123. include/momentum/gui/rerun/logging_redirect.h +27 -0
  124. include/momentum/io/character_io.h +56 -0
  125. include/momentum/io/common/gsl_utils.h +50 -0
  126. include/momentum/io/common/stream_utils.h +65 -0
  127. include/momentum/io/fbx/fbx_io.h +109 -0
  128. include/momentum/io/fbx/fbx_memory_stream.h +66 -0
  129. include/momentum/io/fbx/openfbx_loader.h +49 -0
  130. include/momentum/io/fbx/polygon_data.h +60 -0
  131. include/momentum/io/gltf/gltf_builder.h +132 -0
  132. include/momentum/io/gltf/gltf_file_format.h +19 -0
  133. include/momentum/io/gltf/gltf_io.h +148 -0
  134. include/momentum/io/gltf/utils/accessor_utils.h +299 -0
  135. include/momentum/io/gltf/utils/coordinate_utils.h +60 -0
  136. include/momentum/io/gltf/utils/json_utils.h +102 -0
  137. include/momentum/io/legacy_json/legacy_json_io.h +70 -0
  138. include/momentum/io/marker/c3d_io.h +29 -0
  139. include/momentum/io/marker/conversions.h +57 -0
  140. include/momentum/io/marker/coordinate_system.h +30 -0
  141. include/momentum/io/marker/marker_io.h +54 -0
  142. include/momentum/io/marker/trc_io.h +27 -0
  143. include/momentum/io/motion/mmo_io.h +97 -0
  144. include/momentum/io/shape/blend_shape_io.h +70 -0
  145. include/momentum/io/shape/pose_shape_io.h +21 -0
  146. include/momentum/io/skeleton/locator_io.h +41 -0
  147. include/momentum/io/skeleton/mppca_io.h +26 -0
  148. include/momentum/io/skeleton/parameter_limits_io.h +25 -0
  149. include/momentum/io/skeleton/parameter_transform_io.h +41 -0
  150. include/momentum/io/skeleton/parameters_io.h +20 -0
  151. include/momentum/io/urdf/urdf_io.h +26 -0
  152. include/momentum/io/usd/usd_io.h +36 -0
  153. include/momentum/marker_tracking/app_utils.h +62 -0
  154. include/momentum/marker_tracking/marker_tracker.h +213 -0
  155. include/momentum/marker_tracking/process_markers.h +58 -0
  156. include/momentum/marker_tracking/tracker_utils.h +90 -0
  157. include/momentum/math/constants.h +82 -0
  158. include/momentum/math/covariance_matrix.h +84 -0
  159. include/momentum/math/fmt_eigen.h +23 -0
  160. include/momentum/math/fwd.h +132 -0
  161. include/momentum/math/generalized_loss.h +61 -0
  162. include/momentum/math/intersection.h +32 -0
  163. include/momentum/math/mesh.h +84 -0
  164. include/momentum/math/mppca.h +67 -0
  165. include/momentum/math/online_householder_qr.h +516 -0
  166. include/momentum/math/random-inl.h +404 -0
  167. include/momentum/math/random.h +310 -0
  168. include/momentum/math/simd_generalized_loss.h +40 -0
  169. include/momentum/math/transform.h +229 -0
  170. include/momentum/math/types.h +461 -0
  171. include/momentum/math/utility.h +251 -0
  172. include/momentum/rasterizer/camera.h +453 -0
  173. include/momentum/rasterizer/fwd.h +102 -0
  174. include/momentum/rasterizer/geometry.h +83 -0
  175. include/momentum/rasterizer/image.h +18 -0
  176. include/momentum/rasterizer/rasterizer.h +583 -0
  177. include/momentum/rasterizer/tensor.h +140 -0
  178. include/momentum/rasterizer/utility.h +268 -0
  179. include/momentum/simd/simd.h +221 -0
  180. include/momentum/solver/fwd.h +131 -0
  181. include/momentum/solver/gauss_newton_solver.h +136 -0
  182. include/momentum/solver/gradient_descent_solver.h +65 -0
  183. include/momentum/solver/solver.h +155 -0
  184. include/momentum/solver/solver_function.h +126 -0
  185. include/momentum/solver/subset_gauss_newton_solver.h +109 -0
  186. include/rerun/archetypes/annotation_context.hpp +157 -0
  187. include/rerun/archetypes/arrows2d.hpp +271 -0
  188. include/rerun/archetypes/arrows3d.hpp +257 -0
  189. include/rerun/archetypes/asset3d.hpp +262 -0
  190. include/rerun/archetypes/asset_video.hpp +275 -0
  191. include/rerun/archetypes/bar_chart.hpp +261 -0
  192. include/rerun/archetypes/boxes2d.hpp +293 -0
  193. include/rerun/archetypes/boxes3d.hpp +369 -0
  194. include/rerun/archetypes/capsules3d.hpp +333 -0
  195. include/rerun/archetypes/clear.hpp +180 -0
  196. include/rerun/archetypes/depth_image.hpp +425 -0
  197. include/rerun/archetypes/ellipsoids3d.hpp +384 -0
  198. include/rerun/archetypes/encoded_image.hpp +250 -0
  199. include/rerun/archetypes/geo_line_strings.hpp +166 -0
  200. include/rerun/archetypes/geo_points.hpp +177 -0
  201. include/rerun/archetypes/graph_edges.hpp +152 -0
  202. include/rerun/archetypes/graph_nodes.hpp +206 -0
  203. include/rerun/archetypes/image.hpp +434 -0
  204. include/rerun/archetypes/instance_poses3d.hpp +221 -0
  205. include/rerun/archetypes/line_strips2d.hpp +289 -0
  206. include/rerun/archetypes/line_strips3d.hpp +270 -0
  207. include/rerun/archetypes/mesh3d.hpp +387 -0
  208. include/rerun/archetypes/pinhole.hpp +385 -0
  209. include/rerun/archetypes/points2d.hpp +333 -0
  210. include/rerun/archetypes/points3d.hpp +369 -0
  211. include/rerun/archetypes/recording_properties.hpp +132 -0
  212. include/rerun/archetypes/scalar.hpp +170 -0
  213. include/rerun/archetypes/scalars.hpp +153 -0
  214. include/rerun/archetypes/segmentation_image.hpp +305 -0
  215. include/rerun/archetypes/series_line.hpp +274 -0
  216. include/rerun/archetypes/series_lines.hpp +271 -0
  217. include/rerun/archetypes/series_point.hpp +265 -0
  218. include/rerun/archetypes/series_points.hpp +251 -0
  219. include/rerun/archetypes/tensor.hpp +213 -0
  220. include/rerun/archetypes/text_document.hpp +200 -0
  221. include/rerun/archetypes/text_log.hpp +211 -0
  222. include/rerun/archetypes/transform3d.hpp +925 -0
  223. include/rerun/archetypes/video_frame_reference.hpp +295 -0
  224. include/rerun/archetypes/view_coordinates.hpp +393 -0
  225. include/rerun/archetypes.hpp +43 -0
  226. include/rerun/arrow_utils.hpp +32 -0
  227. include/rerun/as_components.hpp +90 -0
  228. include/rerun/blueprint/archetypes/background.hpp +113 -0
  229. include/rerun/blueprint/archetypes/container_blueprint.hpp +259 -0
  230. include/rerun/blueprint/archetypes/dataframe_query.hpp +178 -0
  231. include/rerun/blueprint/archetypes/entity_behavior.hpp +130 -0
  232. include/rerun/blueprint/archetypes/force_center.hpp +115 -0
  233. include/rerun/blueprint/archetypes/force_collision_radius.hpp +141 -0
  234. include/rerun/blueprint/archetypes/force_link.hpp +136 -0
  235. include/rerun/blueprint/archetypes/force_many_body.hpp +124 -0
  236. include/rerun/blueprint/archetypes/force_position.hpp +132 -0
  237. include/rerun/blueprint/archetypes/line_grid3d.hpp +178 -0
  238. include/rerun/blueprint/archetypes/map_background.hpp +104 -0
  239. include/rerun/blueprint/archetypes/map_zoom.hpp +103 -0
  240. include/rerun/blueprint/archetypes/near_clip_plane.hpp +109 -0
  241. include/rerun/blueprint/archetypes/panel_blueprint.hpp +95 -0
  242. include/rerun/blueprint/archetypes/plot_legend.hpp +118 -0
  243. include/rerun/blueprint/archetypes/scalar_axis.hpp +116 -0
  244. include/rerun/blueprint/archetypes/tensor_scalar_mapping.hpp +146 -0
  245. include/rerun/blueprint/archetypes/tensor_slice_selection.hpp +167 -0
  246. include/rerun/blueprint/archetypes/tensor_view_fit.hpp +95 -0
  247. include/rerun/blueprint/archetypes/view_blueprint.hpp +170 -0
  248. include/rerun/blueprint/archetypes/view_contents.hpp +142 -0
  249. include/rerun/blueprint/archetypes/viewport_blueprint.hpp +200 -0
  250. include/rerun/blueprint/archetypes/visible_time_ranges.hpp +116 -0
  251. include/rerun/blueprint/archetypes/visual_bounds2d.hpp +109 -0
  252. include/rerun/blueprint/archetypes/visualizer_overrides.hpp +113 -0
  253. include/rerun/blueprint/archetypes.hpp +29 -0
  254. include/rerun/blueprint/components/active_tab.hpp +82 -0
  255. include/rerun/blueprint/components/apply_latest_at.hpp +79 -0
  256. include/rerun/blueprint/components/auto_layout.hpp +77 -0
  257. include/rerun/blueprint/components/auto_views.hpp +77 -0
  258. include/rerun/blueprint/components/background_kind.hpp +66 -0
  259. include/rerun/blueprint/components/column_share.hpp +78 -0
  260. include/rerun/blueprint/components/component_column_selector.hpp +81 -0
  261. include/rerun/blueprint/components/container_kind.hpp +65 -0
  262. include/rerun/blueprint/components/corner2d.hpp +64 -0
  263. include/rerun/blueprint/components/enabled.hpp +77 -0
  264. include/rerun/blueprint/components/filter_by_range.hpp +74 -0
  265. include/rerun/blueprint/components/filter_is_not_null.hpp +77 -0
  266. include/rerun/blueprint/components/force_distance.hpp +82 -0
  267. include/rerun/blueprint/components/force_iterations.hpp +82 -0
  268. include/rerun/blueprint/components/force_strength.hpp +82 -0
  269. include/rerun/blueprint/components/grid_columns.hpp +78 -0
  270. include/rerun/blueprint/components/grid_spacing.hpp +78 -0
  271. include/rerun/blueprint/components/included_content.hpp +86 -0
  272. include/rerun/blueprint/components/lock_range_during_zoom.hpp +82 -0
  273. include/rerun/blueprint/components/map_provider.hpp +64 -0
  274. include/rerun/blueprint/components/near_clip_plane.hpp +82 -0
  275. include/rerun/blueprint/components/panel_state.hpp +61 -0
  276. include/rerun/blueprint/components/query_expression.hpp +89 -0
  277. include/rerun/blueprint/components/root_container.hpp +77 -0
  278. include/rerun/blueprint/components/row_share.hpp +78 -0
  279. include/rerun/blueprint/components/selected_columns.hpp +76 -0
  280. include/rerun/blueprint/components/tensor_dimension_index_slider.hpp +90 -0
  281. include/rerun/blueprint/components/timeline_name.hpp +76 -0
  282. include/rerun/blueprint/components/view_class.hpp +76 -0
  283. include/rerun/blueprint/components/view_fit.hpp +61 -0
  284. include/rerun/blueprint/components/view_maximized.hpp +79 -0
  285. include/rerun/blueprint/components/view_origin.hpp +81 -0
  286. include/rerun/blueprint/components/viewer_recommendation_hash.hpp +82 -0
  287. include/rerun/blueprint/components/visible_time_range.hpp +77 -0
  288. include/rerun/blueprint/components/visual_bounds2d.hpp +74 -0
  289. include/rerun/blueprint/components/visualizer_override.hpp +86 -0
  290. include/rerun/blueprint/components/zoom_level.hpp +78 -0
  291. include/rerun/blueprint/components.hpp +41 -0
  292. include/rerun/blueprint/datatypes/component_column_selector.hpp +61 -0
  293. include/rerun/blueprint/datatypes/filter_by_range.hpp +59 -0
  294. include/rerun/blueprint/datatypes/filter_is_not_null.hpp +61 -0
  295. include/rerun/blueprint/datatypes/selected_columns.hpp +62 -0
  296. include/rerun/blueprint/datatypes/tensor_dimension_index_slider.hpp +63 -0
  297. include/rerun/blueprint/datatypes.hpp +9 -0
  298. include/rerun/c/arrow_c_data_interface.h +111 -0
  299. include/rerun/c/compiler_utils.h +10 -0
  300. include/rerun/c/rerun.h +627 -0
  301. include/rerun/c/sdk_info.h +28 -0
  302. include/rerun/collection.hpp +496 -0
  303. include/rerun/collection_adapter.hpp +43 -0
  304. include/rerun/collection_adapter_builtins.hpp +138 -0
  305. include/rerun/compiler_utils.hpp +61 -0
  306. include/rerun/component_batch.hpp +163 -0
  307. include/rerun/component_column.hpp +111 -0
  308. include/rerun/component_descriptor.hpp +142 -0
  309. include/rerun/component_type.hpp +35 -0
  310. include/rerun/components/aggregation_policy.hpp +76 -0
  311. include/rerun/components/albedo_factor.hpp +74 -0
  312. include/rerun/components/annotation_context.hpp +102 -0
  313. include/rerun/components/axis_length.hpp +74 -0
  314. include/rerun/components/blob.hpp +73 -0
  315. include/rerun/components/class_id.hpp +71 -0
  316. include/rerun/components/clear_is_recursive.hpp +75 -0
  317. include/rerun/components/color.hpp +99 -0
  318. include/rerun/components/colormap.hpp +99 -0
  319. include/rerun/components/depth_meter.hpp +84 -0
  320. include/rerun/components/draw_order.hpp +79 -0
  321. include/rerun/components/entity_path.hpp +83 -0
  322. include/rerun/components/fill_mode.hpp +72 -0
  323. include/rerun/components/fill_ratio.hpp +79 -0
  324. include/rerun/components/gamma_correction.hpp +80 -0
  325. include/rerun/components/geo_line_string.hpp +63 -0
  326. include/rerun/components/graph_edge.hpp +75 -0
  327. include/rerun/components/graph_node.hpp +79 -0
  328. include/rerun/components/graph_type.hpp +57 -0
  329. include/rerun/components/half_size2d.hpp +91 -0
  330. include/rerun/components/half_size3d.hpp +95 -0
  331. include/rerun/components/image_buffer.hpp +86 -0
  332. include/rerun/components/image_format.hpp +84 -0
  333. include/rerun/components/image_plane_distance.hpp +77 -0
  334. include/rerun/components/interactive.hpp +76 -0
  335. include/rerun/components/keypoint_id.hpp +74 -0
  336. include/rerun/components/lat_lon.hpp +89 -0
  337. include/rerun/components/length.hpp +77 -0
  338. include/rerun/components/line_strip2d.hpp +73 -0
  339. include/rerun/components/line_strip3d.hpp +73 -0
  340. include/rerun/components/magnification_filter.hpp +63 -0
  341. include/rerun/components/marker_shape.hpp +82 -0
  342. include/rerun/components/marker_size.hpp +74 -0
  343. include/rerun/components/media_type.hpp +157 -0
  344. include/rerun/components/name.hpp +83 -0
  345. include/rerun/components/opacity.hpp +77 -0
  346. include/rerun/components/pinhole_projection.hpp +94 -0
  347. include/rerun/components/plane3d.hpp +75 -0
  348. include/rerun/components/pose_rotation_axis_angle.hpp +73 -0
  349. include/rerun/components/pose_rotation_quat.hpp +71 -0
  350. include/rerun/components/pose_scale3d.hpp +102 -0
  351. include/rerun/components/pose_transform_mat3x3.hpp +87 -0
  352. include/rerun/components/pose_translation3d.hpp +96 -0
  353. include/rerun/components/position2d.hpp +86 -0
  354. include/rerun/components/position3d.hpp +90 -0
  355. include/rerun/components/radius.hpp +98 -0
  356. include/rerun/components/range1d.hpp +75 -0
  357. include/rerun/components/resolution.hpp +88 -0
  358. include/rerun/components/rotation_axis_angle.hpp +72 -0
  359. include/rerun/components/rotation_quat.hpp +71 -0
  360. include/rerun/components/scalar.hpp +76 -0
  361. include/rerun/components/scale3d.hpp +102 -0
  362. include/rerun/components/series_visible.hpp +76 -0
  363. include/rerun/components/show_labels.hpp +79 -0
  364. include/rerun/components/stroke_width.hpp +74 -0
  365. include/rerun/components/tensor_data.hpp +94 -0
  366. include/rerun/components/tensor_dimension_index_selection.hpp +77 -0
  367. include/rerun/components/tensor_height_dimension.hpp +71 -0
  368. include/rerun/components/tensor_width_dimension.hpp +71 -0
  369. include/rerun/components/texcoord2d.hpp +101 -0
  370. include/rerun/components/text.hpp +83 -0
  371. include/rerun/components/text_log_level.hpp +110 -0
  372. include/rerun/components/timestamp.hpp +76 -0
  373. include/rerun/components/transform_mat3x3.hpp +92 -0
  374. include/rerun/components/transform_relation.hpp +66 -0
  375. include/rerun/components/translation3d.hpp +96 -0
  376. include/rerun/components/triangle_indices.hpp +85 -0
  377. include/rerun/components/value_range.hpp +78 -0
  378. include/rerun/components/vector2d.hpp +92 -0
  379. include/rerun/components/vector3d.hpp +96 -0
  380. include/rerun/components/video_timestamp.hpp +120 -0
  381. include/rerun/components/view_coordinates.hpp +346 -0
  382. include/rerun/components/visible.hpp +74 -0
  383. include/rerun/components.hpp +77 -0
  384. include/rerun/config.hpp +52 -0
  385. include/rerun/datatypes/angle.hpp +76 -0
  386. include/rerun/datatypes/annotation_info.hpp +76 -0
  387. include/rerun/datatypes/blob.hpp +67 -0
  388. include/rerun/datatypes/bool.hpp +57 -0
  389. include/rerun/datatypes/channel_datatype.hpp +87 -0
  390. include/rerun/datatypes/class_description.hpp +92 -0
  391. include/rerun/datatypes/class_description_map_elem.hpp +69 -0
  392. include/rerun/datatypes/class_id.hpp +62 -0
  393. include/rerun/datatypes/color_model.hpp +68 -0
  394. include/rerun/datatypes/dvec2d.hpp +76 -0
  395. include/rerun/datatypes/entity_path.hpp +60 -0
  396. include/rerun/datatypes/float32.hpp +62 -0
  397. include/rerun/datatypes/float64.hpp +62 -0
  398. include/rerun/datatypes/image_format.hpp +107 -0
  399. include/rerun/datatypes/keypoint_id.hpp +63 -0
  400. include/rerun/datatypes/keypoint_pair.hpp +65 -0
  401. include/rerun/datatypes/mat3x3.hpp +105 -0
  402. include/rerun/datatypes/mat4x4.hpp +119 -0
  403. include/rerun/datatypes/pixel_format.hpp +142 -0
  404. include/rerun/datatypes/plane3d.hpp +60 -0
  405. include/rerun/datatypes/quaternion.hpp +110 -0
  406. include/rerun/datatypes/range1d.hpp +59 -0
  407. include/rerun/datatypes/range2d.hpp +55 -0
  408. include/rerun/datatypes/rgba32.hpp +94 -0
  409. include/rerun/datatypes/rotation_axis_angle.hpp +67 -0
  410. include/rerun/datatypes/tensor_buffer.hpp +529 -0
  411. include/rerun/datatypes/tensor_data.hpp +100 -0
  412. include/rerun/datatypes/tensor_dimension_index_selection.hpp +58 -0
  413. include/rerun/datatypes/tensor_dimension_selection.hpp +56 -0
  414. include/rerun/datatypes/time_int.hpp +62 -0
  415. include/rerun/datatypes/time_range.hpp +55 -0
  416. include/rerun/datatypes/time_range_boundary.hpp +175 -0
  417. include/rerun/datatypes/uint16.hpp +62 -0
  418. include/rerun/datatypes/uint32.hpp +62 -0
  419. include/rerun/datatypes/uint64.hpp +62 -0
  420. include/rerun/datatypes/utf8.hpp +76 -0
  421. include/rerun/datatypes/utf8pair.hpp +62 -0
  422. include/rerun/datatypes/uuid.hpp +60 -0
  423. include/rerun/datatypes/uvec2d.hpp +76 -0
  424. include/rerun/datatypes/uvec3d.hpp +80 -0
  425. include/rerun/datatypes/uvec4d.hpp +59 -0
  426. include/rerun/datatypes/vec2d.hpp +76 -0
  427. include/rerun/datatypes/vec3d.hpp +80 -0
  428. include/rerun/datatypes/vec4d.hpp +84 -0
  429. include/rerun/datatypes/video_timestamp.hpp +67 -0
  430. include/rerun/datatypes/view_coordinates.hpp +87 -0
  431. include/rerun/datatypes/visible_time_range.hpp +57 -0
  432. include/rerun/datatypes.hpp +51 -0
  433. include/rerun/demo_utils.hpp +75 -0
  434. include/rerun/entity_path.hpp +20 -0
  435. include/rerun/error.hpp +180 -0
  436. include/rerun/half.hpp +10 -0
  437. include/rerun/image_utils.hpp +187 -0
  438. include/rerun/indicator_component.hpp +59 -0
  439. include/rerun/loggable.hpp +54 -0
  440. include/rerun/recording_stream.hpp +960 -0
  441. include/rerun/rerun_sdk_export.hpp +25 -0
  442. include/rerun/result.hpp +86 -0
  443. include/rerun/rotation3d.hpp +33 -0
  444. include/rerun/sdk_info.hpp +20 -0
  445. include/rerun/spawn.hpp +21 -0
  446. include/rerun/spawn_options.hpp +57 -0
  447. include/rerun/string_utils.hpp +16 -0
  448. include/rerun/third_party/cxxopts.hpp +2198 -0
  449. include/rerun/time_column.hpp +288 -0
  450. include/rerun/timeline.hpp +38 -0
  451. include/rerun/type_traits.hpp +40 -0
  452. include/rerun.hpp +86 -0
  453. lib/cmake/rerun_sdk/rerun_sdkConfig.cmake +70 -0
  454. lib/cmake/rerun_sdk/rerun_sdkConfigVersion.cmake +83 -0
  455. lib/cmake/rerun_sdk/rerun_sdkTargets-release.cmake +19 -0
  456. lib/cmake/rerun_sdk/rerun_sdkTargets.cmake +108 -0
  457. lib/libarrow.a +0 -0
  458. lib/libarrow_bundled_dependencies.a +0 -0
  459. lib/librerun_c__linux_x64.a +0 -0
  460. lib/librerun_sdk.a +0 -0
  461. lib64/cmake/axel/axel-config.cmake +45 -0
  462. lib64/cmake/axel/axelTargets-release.cmake +19 -0
  463. lib64/cmake/axel/axelTargets.cmake +108 -0
  464. lib64/cmake/momentum/Findre2.cmake +52 -0
  465. lib64/cmake/momentum/momentum-config.cmake +67 -0
  466. lib64/cmake/momentum/momentumTargets-release.cmake +259 -0
  467. lib64/cmake/momentum/momentumTargets.cmake +377 -0
  468. lib64/libaxel.a +0 -0
  469. lib64/libmomentum_app_utils.a +0 -0
  470. lib64/libmomentum_character.a +0 -0
  471. lib64/libmomentum_character_sequence_solver.a +0 -0
  472. lib64/libmomentum_character_solver.a +0 -0
  473. lib64/libmomentum_common.a +0 -0
  474. lib64/libmomentum_diff_ik.a +0 -0
  475. lib64/libmomentum_io.a +0 -0
  476. lib64/libmomentum_io_common.a +0 -0
  477. lib64/libmomentum_io_fbx.a +0 -0
  478. lib64/libmomentum_io_gltf.a +0 -0
  479. lib64/libmomentum_io_legacy_json.a +0 -0
  480. lib64/libmomentum_io_marker.a +0 -0
  481. lib64/libmomentum_io_motion.a +0 -0
  482. lib64/libmomentum_io_shape.a +0 -0
  483. lib64/libmomentum_io_skeleton.a +0 -0
  484. lib64/libmomentum_io_urdf.a +0 -0
  485. lib64/libmomentum_marker_tracker.a +0 -0
  486. lib64/libmomentum_math.a +0 -0
  487. lib64/libmomentum_online_qr.a +0 -0
  488. lib64/libmomentum_process_markers.a +0 -0
  489. lib64/libmomentum_rerun.a +0 -0
  490. lib64/libmomentum_simd_constraints.a +0 -0
  491. lib64/libmomentum_simd_generalized_loss.a +0 -0
  492. lib64/libmomentum_skeleton.a +0 -0
  493. lib64/libmomentum_solver.a +0 -0
  494. pymomentum/axel.cpython-313-x86_64-linux-gnu.so +0 -0
  495. pymomentum/backend/__init__.py +16 -0
  496. pymomentum/backend/skel_state_backend.py +614 -0
  497. pymomentum/backend/trs_backend.py +871 -0
  498. pymomentum/backend/utils.py +224 -0
  499. pymomentum/geometry.cpython-313-x86_64-linux-gnu.so +0 -0
  500. pymomentum/marker_tracking.cpython-313-x86_64-linux-gnu.so +0 -0
  501. pymomentum/quaternion.py +740 -0
  502. pymomentum/skel_state.py +514 -0
  503. pymomentum/solver.cpython-313-x86_64-linux-gnu.so +0 -0
  504. pymomentum/solver2.cpython-313-x86_64-linux-gnu.so +0 -0
  505. pymomentum/torch/character.py +809 -0
  506. pymomentum/torch/parameter_limits.py +494 -0
  507. pymomentum/torch/utility.py +20 -0
  508. pymomentum/trs.py +535 -0
  509. pymomentum_cpu-0.1.77.post30.dist-info/METADATA +208 -0
  510. pymomentum_cpu-0.1.77.post30.dist-info/RECORD +555 -0
  511. pymomentum_cpu-0.1.77.post30.dist-info/WHEEL +5 -0
  512. pymomentum_cpu-0.1.77.post30.dist-info/licenses/LICENSE +21 -0
  513. pymomentum_cpu.libs/libabsl_base-86f3b38c.so.2505.0.0 +0 -0
  514. pymomentum_cpu.libs/libabsl_city-31b65ca2.so.2505.0.0 +0 -0
  515. pymomentum_cpu.libs/libabsl_debugging_internal-38680253.so.2505.0.0 +0 -0
  516. pymomentum_cpu.libs/libabsl_decode_rust_punycode-750652c3.so.2505.0.0 +0 -0
  517. pymomentum_cpu.libs/libabsl_demangle_internal-9a0351a3.so.2505.0.0 +0 -0
  518. pymomentum_cpu.libs/libabsl_demangle_rust-71629506.so.2505.0.0 +0 -0
  519. pymomentum_cpu.libs/libabsl_examine_stack-57661ecd.so.2505.0.0 +0 -0
  520. pymomentum_cpu.libs/libabsl_hash-8c523b7e.so.2505.0.0 +0 -0
  521. pymomentum_cpu.libs/libabsl_hashtablez_sampler-b5c3e343.so.2505.0.0 +0 -0
  522. pymomentum_cpu.libs/libabsl_int128-295bfed5.so.2505.0.0 +0 -0
  523. pymomentum_cpu.libs/libabsl_kernel_timeout_internal-29296ac1.so.2505.0.0 +0 -0
  524. pymomentum_cpu.libs/libabsl_log_globals-6cfa8af5.so.2505.0.0 +0 -0
  525. pymomentum_cpu.libs/libabsl_log_internal_format-a5c79460.so.2505.0.0 +0 -0
  526. pymomentum_cpu.libs/libabsl_log_internal_globals-481e9a7c.so.2505.0.0 +0 -0
  527. pymomentum_cpu.libs/libabsl_log_internal_log_sink_set-ac08f942.so.2505.0.0 +0 -0
  528. pymomentum_cpu.libs/libabsl_log_internal_message-7dfe150a.so.2505.0.0 +0 -0
  529. pymomentum_cpu.libs/libabsl_log_internal_nullguard-883adc72.so.2505.0.0 +0 -0
  530. pymomentum_cpu.libs/libabsl_log_internal_proto-a5da8c75.so.2505.0.0 +0 -0
  531. pymomentum_cpu.libs/libabsl_log_internal_structured_proto-e601fd9b.so.2505.0.0 +0 -0
  532. pymomentum_cpu.libs/libabsl_log_sink-894261b2.so.2505.0.0 +0 -0
  533. pymomentum_cpu.libs/libabsl_low_level_hash-a3284638.so.2505.0.0 +0 -0
  534. pymomentum_cpu.libs/libabsl_malloc_internal-814569de.so.2505.0.0 +0 -0
  535. pymomentum_cpu.libs/libabsl_raw_hash_set-922d64ad.so.2505.0.0 +0 -0
  536. pymomentum_cpu.libs/libabsl_raw_logging_internal-477f78ec.so.2505.0.0 +0 -0
  537. pymomentum_cpu.libs/libabsl_spinlock_wait-8b85a473.so.2505.0.0 +0 -0
  538. pymomentum_cpu.libs/libabsl_stacktrace-7369e71d.so.2505.0.0 +0 -0
  539. pymomentum_cpu.libs/libabsl_str_format_internal-98de729d.so.2505.0.0 +0 -0
  540. pymomentum_cpu.libs/libabsl_strerror-39a52998.so.2505.0.0 +0 -0
  541. pymomentum_cpu.libs/libabsl_strings-a57d5127.so.2505.0.0 +0 -0
  542. pymomentum_cpu.libs/libabsl_strings_internal-ed8c7c0d.so.2505.0.0 +0 -0
  543. pymomentum_cpu.libs/libabsl_symbolize-eba17dd1.so.2505.0.0 +0 -0
  544. pymomentum_cpu.libs/libabsl_synchronization-2f8cf326.so.2505.0.0 +0 -0
  545. pymomentum_cpu.libs/libabsl_time-066c0dde.so.2505.0.0 +0 -0
  546. pymomentum_cpu.libs/libabsl_time_zone-72867365.so.2505.0.0 +0 -0
  547. pymomentum_cpu.libs/libabsl_tracing_internal-021e37ee.so.2505.0.0 +0 -0
  548. pymomentum_cpu.libs/libabsl_utf8_for_code_point-de2a4d4a.so.2505.0.0 +0 -0
  549. pymomentum_cpu.libs/libconsole_bridge-f26e11cc.so.1.0 +0 -0
  550. pymomentum_cpu.libs/libdeflate-577b71e3.so.0 +0 -0
  551. pymomentum_cpu.libs/libdispenso-67ac1721.so.1.4.0 +0 -0
  552. pymomentum_cpu.libs/libezc3d-4a95ab2c.so +0 -0
  553. pymomentum_cpu.libs/libre2-985fb83c.so.11 +0 -0
  554. pymomentum_cpu.libs/libtinyxml2-8d10763c.so.11.0.0 +0 -0
  555. pymomentum_cpu.libs/liburdfdom_model-7b26ae88.so.4.0 +0 -0
@@ -0,0 +1,871 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # pyre-strict
7
+ """
8
+ TRS Backend for PyMomentum
9
+
10
+ This module provides efficient forward kinematics and skinning operations using
11
+ the TRS (Translation-Rotation-Scale) representation where each transformation
12
+ component is stored separately.
13
+
14
+ The TRS representation uses separate tensors for translation (3D), rotation matrices (3x3),
15
+ and scale factors (1D), making it suitable for applications that need explicit access
16
+ to individual transformation components.
17
+
18
+ Performance Notes:
19
+ This backend is typically 25-50% faster than the skeleton state backend in PyTorch,
20
+ likely due to not requiring quaternion normalization operations. While it doesn't
21
+ match the C++ reference implementation exactly (use skel_state_backend for that),
22
+ it provides excellent performance for PyTorch-based applications.
23
+
24
+ Key Functions:
25
+ - global_trs_state_from_local_trs_state: Forward kinematics from local to global joint states
26
+ - skin_points_from_trs_state: Linear blend skinning using TRS transformations
27
+ - local_trs_state_from_joint_params: Convert joint parameters to local TRS states
28
+
29
+ Related Modules:
30
+ - skel_state_backend: Alternative backend using compact 8-parameter skeleton states
31
+ - trs: Core TRS transformation operations and utilities
32
+ """
33
+
34
+ from typing import List, Tuple
35
+
36
+ import torch as th
37
+ from pymomentum import trs
38
+
39
+
40
@th.jit.script
def global_trs_state_from_local_trs_state_impl(
    local_state_t: th.Tensor,
    local_state_r: th.Tensor,
    local_state_s: th.Tensor,
    prefix_mul_indices: List[th.Tensor],
    save_intermediate_results: bool = True,
    use_double_precision: bool = True,
) -> Tuple[
    th.Tensor, th.Tensor, th.Tensor, List[Tuple[th.Tensor, th.Tensor, th.Tensor]]
]:
    """
    Forward kinematics on TRS states using level-wise prefix multiplication.

    Every level of ``prefix_mul_indices`` composes a set of parent ("target")
    states onto a set of child ("source") states with ``trs.multiply``:

        s_child <- s_parent * s_child
        R_child <- R_parent @ R_child
        t_child <- t_parent + s_parent * R_parent @ t_child

    i.e. the similarity transform y = s * R * x + t. After all levels have been
    applied, every joint holds its global transformation.

    Args:
        local_state_t: Local joint translations, shape (batch_size, num_joints, 3).
        local_state_r: Local joint rotations, shape (batch_size, num_joints, 3, 3).
        local_state_s: Local joint scales, shape (batch_size, num_joints, 1).
        prefix_mul_indices: One tensor per level; row 0 holds the child (source)
            joint indices and row 1 the parent (target) joint indices.
        save_intermediate_results: If True, record each level's child states
            (taken before they are overwritten) for use by the backward pass.
        use_double_precision: If True, accumulate in float64 and cast the
            results back to the dtype of ``local_state_t``.

    Returns:
        Global translations, rotations and scales (same shapes as the inputs,
        dtype of ``local_state_t``) and the list of saved (t, r, s) child states
        (empty when ``save_intermediate_results`` is False).

    Note:
        Runs under ``torch.no_grad()``; gradients are provided by the
        hand-written backward ``global_trs_state_from_local_trs_state_backprop``.
    """
    out_dtype = local_state_t.dtype
    with th.no_grad():
        # Private copies, because the loop below overwrites joints in place.
        acc_t = local_state_t.clone()
        acc_r = local_state_r.clone()
        acc_s = local_state_s.clone()
        if use_double_precision:
            acc_t = acc_t.double()
            acc_r = acc_r.double()
            acc_s = acc_s.double()

        saved: List[Tuple[th.Tensor, th.Tensor, th.Tensor]] = []

        for level in prefix_mul_indices:
            children = level[0]
            parents = level[1]

            child_t = acc_t[:, children]
            child_r = acc_r[:, children]
            child_s = acc_s[:, children]

            if save_intermediate_results:
                # Child state *before* this level's update; the backward pass
                # restores it from here instead of recomputing it.
                saved.append((child_t.clone(), child_r.clone(), child_s.clone()))

            new_t, new_r, new_s = trs.multiply(
                (acc_t[:, parents], acc_r[:, parents], acc_s[:, parents]),
                (child_t, child_r, child_s),
            )

            acc_s[:, children] = new_s
            acc_r[:, children] = new_r
            acc_t[:, children] = new_t

    return acc_t.to(out_dtype), acc_r.to(out_dtype), acc_s.to(out_dtype), saved
148
+
149
+
150
@th.jit.script
def global_trs_state_from_local_trs_state_no_grad(
    local_state_t: th.Tensor,
    local_state_r: th.Tensor,
    local_state_s: th.Tensor,
    prefix_mul_indices: List[th.Tensor],
    save_intermediate_results: bool = True,
    use_double_precision: bool = True,
) -> Tuple[
    th.Tensor, th.Tensor, th.Tensor, List[Tuple[th.Tensor, th.Tensor, th.Tensor]]
]:
    """
    Run TRS forward kinematics with gradient tracking disabled.

    Delegates to ``global_trs_state_from_local_trs_state_impl`` inside a
    ``torch.no_grad()`` context and returns its outputs unchanged.

    Args:
        local_state_t: Local joint translations, shape (batch_size, num_joints, 3).
        local_state_r: Local joint rotations, shape (batch_size, num_joints, 3, 3).
        local_state_s: Local joint scales, shape (batch_size, num_joints, 1).
        prefix_mul_indices: Per-level tensors of [child indices, parent indices].
        save_intermediate_results: Whether to record per-level child states
            for the backward pass.
        use_double_precision: Whether to accumulate in float64.

    Returns:
        Global translations, rotations, scales and the saved intermediate
        (t, r, s) tuples, exactly as produced by the implementation function.
    """
    with th.no_grad():
        fk_result = global_trs_state_from_local_trs_state_impl(
            local_state_t,
            local_state_r,
            local_state_s,
            prefix_mul_indices,
            save_intermediate_results=save_intermediate_results,
            use_double_precision=use_double_precision,
        )
    return fk_result
195
+
196
+
197
@th.jit.script
def ik_from_global_state(
    global_state_t: th.Tensor,
    global_state_r: th.Tensor,
    global_state_s: th.Tensor,
    prefix_mul_indices: List[th.Tensor],
    use_double_precision: bool = True,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
    """
    Recover local TRS states from global ones by undoing forward kinematics.

    The prefix-multiplication levels are visited in reverse order. For each
    level, every child joint is re-expressed relative to its parent:

        s_local = s_child / s_parent
        R_local = R_parent^-1 @ R_child
        t_local = R_parent^-1 @ ((t_child - t_parent) / s_parent)

    NOTE(review): this reads the parent's *current* state, so it assumes that
    within one level no parent index also appears as a child index (true for
    the example ordering in the backprop docstring), TODO confirm for all
    orderings produced upstream.

    Args:
        global_state_t: Global joint translations, (batch_size, num_joints, 3).
        global_state_r: Global joint rotations, (batch_size, num_joints, 3, 3).
        global_state_s: Global joint scales, (batch_size, num_joints, 1).
        prefix_mul_indices: The same per-level [child, parent] index tensors
            used for forward kinematics.
        use_double_precision: If True, compute in float64 and cast back to the
            dtype of ``global_state_t``.

    Returns:
        Local translations, rotations and scales with the input shapes.
    """
    out_dtype = global_state_t.dtype

    trans = global_state_t.clone()
    rot = global_state_r.clone()
    scale = global_state_s.clone()
    if use_double_precision:
        trans = trans.double()
        rot = rot.double()
        scale = scale.double()

    # Undo the forward levels last-to-first.
    num_levels = len(prefix_mul_indices)
    for level_idx in range(num_levels - 1, -1, -1):
        level = prefix_mul_indices[level_idx]
        child = level[0]
        parent = level[1]

        inv_parent_s = scale[:, parent].reciprocal()
        inv_parent_r = trs.rotmat_inverse(rot[:, parent])
        parent_t = trans[:, parent]

        child_s = scale[:, child]
        child_r = rot[:, child]
        child_t = trans[:, child]

        scale[:, child] = inv_parent_s * child_s
        rot[:, child] = trs.rotmat_multiply(inv_parent_r, child_r)
        trans[:, child] = trs.rotmat_rotate_vector(
            inv_parent_r, (child_t - parent_t) * inv_parent_s
        )

    return trans.to(out_dtype), rot.to(out_dtype), scale.to(out_dtype)
238
+
239
+
240
@th.jit.script
def global_trs_state_from_local_trs_state_backprop(
    joint_state_t: th.Tensor,
    joint_state_r: th.Tensor,
    joint_state_s: th.Tensor,
    grad_joint_state_t: th.Tensor,
    grad_joint_state_r: th.Tensor,
    grad_joint_state_s: th.Tensor,
    prefix_mul_indices: List[th.Tensor],
    intermediate_results: List[Tuple[th.Tensor, th.Tensor, th.Tensor]],
    use_double_precision: bool = True,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
    r"""
    Backward pass of :func:`global_trs_state_from_local_trs_state_no_grad`.

    Maps gradients w.r.t. the global joint states (tg, rg, sg) to gradients
    w.r.t. the local joint states (tl, rl, sl). Mathematically,

        dL/dtl_i = sum_j dL/dtg_j * dtg_j/dtl_i
        dL/dsl_i = sum_j dL/dtg_j * dtg_j/dsl_i + sum_j dL/dsg_j * dsg_j/dsl_i
        dL/drl_i = sum_j dL/dtg_j * dtg_j/drl_i + sum_j dL/drg_j * drg_j/drl_i

    but evaluating these sums directly is expensive (and the Jacobian is very
    sparse). Instead, the prefix-multiplication levels of the forward pass are
    undone one at a time, in reverse order.

    Each forward level composes a parent state p (target) onto a child state c
    (source):

        s_c' = s_p * s_c
        R_c' = R_p @ R_c
        t_c' = t_p + s_p * R_p @ t_c

    so for one level, with (g_t, g_R, g_s) the gradients of the updated child:

        dL/ds_c  = s_p * g_s
        dL/dR_c  = R_p^T @ g_R
        dL/dt_c  = s_p * R_p^T @ g_t
        dL/ds_p += s_c * g_s + g_t . (R_p @ t_c)
        dL/dR_p += g_R @ R_c^T + s_p * g_t t_c^T
        dL/dt_p += g_t

    Example with chain order [0, 1, 2, 3] (source <- target):

        level 0: [1, 3] <- [0, 2]   chain becomes [0, 01, 2, 23]
        level 1: [2, 3] <- [1, 1]   chain becomes [0, 01, 012, 0123]

    The backward pass first undoes level 1, turning the gradients of
    (0, 01, 012, 0123) into gradients of (0, 01, 2, 23), and then undoes
    level 0 to obtain the gradients of the local states (0, 1, 2, 3).

    Args:
        joint_state_t: Global translations from the forward pass,
            (batch_size, num_joints, 3). Copied, not modified.
        joint_state_r: Global rotations, (batch_size, num_joints, 3, 3).
        joint_state_s: Global scales, (batch_size, num_joints, 1).
        grad_joint_state_t: Upstream gradient w.r.t. the global translations.
        grad_joint_state_r: Upstream gradient w.r.t. the global rotations.
        grad_joint_state_s: Upstream gradient w.r.t. the global scales.
        prefix_mul_indices: The per-level [child, parent] index tensors used
            in the forward pass.
        intermediate_results: Per-level child (t, r, s) states recorded by the
            forward pass before each update (save_intermediate_results=True).
        use_double_precision: If True, accumulate in float64 and cast the
            result back to the dtype of ``joint_state_t``.

    Returns:
        Gradients w.r.t. the local translations, rotations and scales.
    """
    dtype = joint_state_t.dtype
    if use_double_precision:
        grad_local_state_t = grad_joint_state_t.clone().double()
        grad_local_state_r = grad_joint_state_r.clone().double()
        grad_local_state_s = grad_joint_state_s.clone().double()
        joint_state_t = joint_state_t.clone().double()
        joint_state_r = joint_state_r.clone().double()
        joint_state_s = joint_state_s.clone().double()
    else:
        grad_local_state_t = grad_joint_state_t.clone()
        grad_local_state_r = grad_joint_state_r.clone()
        grad_local_state_s = grad_joint_state_s.clone()
        joint_state_t = joint_state_t.clone()
        joint_state_r = joint_state_r.clone()
        joint_state_s = joint_state_s.clone()

    # Instead of recomputing the pre-level child states (t, r, s) from the
    # global state, load them from the forward pass's intermediate results.
    for prefix_mul_index, (t, r, s) in list(
        zip(prefix_mul_indices, intermediate_results)
    )[::-1]:
        source = prefix_mul_index[0]
        target = prefix_mul_index[1]

        # Gradients of the updated child states at this level.
        grad_s2 = grad_local_state_s[:, source]
        grad_r2 = grad_local_state_r[:, source]
        grad_t2 = grad_local_state_t[:, source]

        # The parent's state at this level (later levels are already undone).
        sg1 = joint_state_s[:, target]
        rg1 = joint_state_r[:, target]

        # Gradient w.r.t. the pre-level child state (source). einsum is
        # written out as explicit broadcast+sum for speed.
        grad_s = sg1 * grad_s2
        # == sg1 * th.einsum("bjyx,bjy->bjx", rg1, grad_t2)
        grad_t = sg1 * (rg1 * grad_t2[..., None]).sum(dim=2)
        # == th.einsum("bjyx,bjyz->bjxz", rg1, grad_r2)
        grad_r = (rg1[:, :, :, :, None] * grad_r2[:, :, :, None, :]).sum(dim=2)

        # Gradient contributions to the parent state (target).
        # == th.einsum("bjxy,bjy,bjx->bj", rg1, t, grad_t2)[:, :, None]
        grad_s1_accum = (
            (rg1 * t[:, :, None, :] * grad_t2[:, :, :, None])
            .sum(dim=3)
            .sum(dim=2, keepdim=True)
        )
        grad_s1_accum = grad_s1_accum + s * grad_s2

        # == th.einsum("bjxy,bjzy->bjxz", grad_r2, r)
        #    + th.einsum("bj,bjx,bjy->bjxy", sg1[:, :, 0], grad_t2, t)
        grad_r1_accum = (grad_r2[:, :, :, None, :] * r[:, :, None, :, :]).sum(dim=4)
        grad_r1_accum = (
            grad_r1_accum + (sg1 * grad_t2)[:, :, :, None] * t[:, :, None, :]
        )

        grad_t1_accum = grad_t2

        # Replace the child gradients with the pre-level ones.
        grad_local_state_t[:, source] = grad_t
        grad_local_state_r[:, source] = grad_r
        grad_local_state_s[:, source] = grad_s

        # index_add_ accumulates correctly when a parent appears repeatedly.
        grad_local_state_t.index_add_(1, target, grad_t1_accum)
        grad_local_state_r.index_add_(1, target, grad_r1_accum)
        grad_local_state_s.index_add_(1, target, grad_s1_accum)

        # Restore the child states to their pre-level values.
        joint_state_t[:, source] = t
        joint_state_r[:, source] = r
        joint_state_s[:, source] = s

    return (
        grad_local_state_t.to(dtype),
        grad_local_state_r.to(dtype),
        grad_local_state_s.to(dtype),
    )
374
+
375
+
376
class ForwardKinematicsFromLocalTransformationJIT(th.autograd.Function):
    """
    Autograd Function pairing TorchScript TRS forward kinematics with its
    hand-written backward pass.

    ``forward`` runs ``global_trs_state_from_local_trs_state_no_grad``.
    ``backward`` runs ``global_trs_state_from_local_trs_state_backprop``, using
    the global states and per-level intermediates stashed by ``setup_context``.
    """

    @staticmethod
    # pyre-ignore[14]
    def forward(
        local_state_t: th.Tensor,
        local_state_r: th.Tensor,
        local_state_s: th.Tensor,
        prefix_mul_indices: List[th.Tensor],
    ) -> Tuple[
        th.Tensor, th.Tensor, th.Tensor, List[Tuple[th.Tensor, th.Tensor, th.Tensor]]
    ]:
        """Compute global states plus the intermediates needed for backward."""
        fk_outputs = global_trs_state_from_local_trs_state_no_grad(
            local_state_t, local_state_r, local_state_s, prefix_mul_indices
        )
        return fk_outputs

    @staticmethod
    # pyre-ignore[14]
    # pyre-ignore[2]
    def setup_context(ctx, inputs, outputs) -> None:
        """Store the global states, intermediates and level indices on ``ctx``."""
        prefix_mul_indices = inputs[3]
        global_t, global_r, global_s, saved_levels = outputs
        # Save copies so later in-place edits of the outputs cannot corrupt them.
        ctx.save_for_backward(global_t.clone(), global_r.clone(), global_s.clone())
        ctx.intermediate_results = saved_levels
        ctx.prefix_mul_indices = prefix_mul_indices

    @staticmethod
    # pyre-ignore[14]
    def backward(
        # pyre-ignore[2]
        ctx,
        grad_joint_state_t: th.Tensor,
        grad_joint_state_r: th.Tensor,
        grad_joint_state_s: th.Tensor,
        _0,
    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, None]:
        """Propagate global-state gradients back to the local states."""
        global_t, global_r, global_s = ctx.saved_tensors
        local_grads = global_trs_state_from_local_trs_state_backprop(
            global_t,
            global_r,
            global_s,
            grad_joint_state_t,
            grad_joint_state_r,
            grad_joint_state_s,
            ctx.prefix_mul_indices,
            ctx.intermediate_results,
        )
        # prefix_mul_indices is not differentiable.
        return local_grads[0], local_grads[1], local_grads[2], None
453
+
454
+
455
def global_trs_state_from_local_trs_state(
    local_state_t: th.Tensor,
    local_state_r: th.Tensor,
    local_state_s: th.Tensor,
    prefix_mul_indices: List[th.Tensor],
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
    """
    Forward kinematics from local to global TRS states (user-facing entry point).

    While tracing or scripting, the TorchScript implementation is called
    directly. Otherwise the autograd Function is used so that gradients flow
    back to the local states through the hand-written backward pass.

    Args:
        local_state_t: Local joint translations, shape (batch_size, num_joints, 3).
        local_state_r: Local joint rotations, shape (batch_size, num_joints, 3, 3).
        local_state_s: Local joint scales, shape (batch_size, num_joints, 1).
        prefix_mul_indices: Per-level tensors of [child indices, parent indices]
            defining the traversal of the kinematic hierarchy.

    Returns:
        Global joint translations (batch_size, num_joints, 3), rotations
        (batch_size, num_joints, 3, 3) and scales (batch_size, num_joints, 1).

    See Also:
        :func:`global_trs_state_from_local_trs_state_impl`: JIT implementation
        :func:`local_trs_state_from_joint_params`: Convert joint parameters to local states
    """
    if th.jit.is_tracing() or th.jit.is_scripting():
        fk_outputs = global_trs_state_from_local_trs_state_impl(
            local_state_t, local_state_r, local_state_s, prefix_mul_indices
        )
    else:
        fk_outputs = ForwardKinematicsFromLocalTransformationJIT.apply(
            local_state_t, local_state_r, local_state_s, prefix_mul_indices
        )
    # Drop the per-level intermediates; callers only need the global state.
    global_t, global_r, global_s, _ = fk_outputs
    return global_t, global_r, global_s
517
+
518
+
519
def global_trs_state_from_local_trs_state_forward_only(
    local_state_t: th.Tensor,
    local_state_r: th.Tensor,
    local_state_s: th.Tensor,
    prefix_mul_indices: List[th.Tensor],
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
    """
    Forward kinematics from local to global TRS states, without autograd.

    Calls the TorchScript kernel under ``torch.no_grad()`` and skips recording
    the per-level intermediate states. Those are only consumed by the backward
    pass, so saving them here would only cost extra copies and memory.

    Args:
        local_state_t: Local joint translations, shape (batch_size, num_joints, 3).
        local_state_r: Local joint rotations, shape (batch_size, num_joints, 3, 3).
        local_state_s: Local joint scales, shape (batch_size, num_joints, 1).
        prefix_mul_indices: Per-level tensors of [child indices, parent indices].

    Returns:
        Global joint translations (batch_size, num_joints, 3), rotations
        (batch_size, num_joints, 3, 3) and scales (batch_size, num_joints, 1).

    See Also:
        :func:`global_trs_state_from_local_trs_state`: Main user-facing function with autograd
    """
    global_t, global_r, global_s, _ = global_trs_state_from_local_trs_state_no_grad(
        local_state_t,
        local_state_r,
        local_state_s,
        prefix_mul_indices,
        save_intermediate_results=False,
    )
    return global_t, global_r, global_s
561
+
562
+
563
@th.jit.script
def skinning(
    template: th.Tensor,
    t: th.Tensor,
    r: th.Tensor,
    s: th.Tensor,
    t0: th.Tensor,
    r0: th.Tensor,
    skin_indices_flattened: th.Tensor,
    skin_weights_flattened: th.Tensor,
    vert_indices_flattened: th.Tensor,
) -> th.Tensor:
    r"""
    Linear blend skinning (formula as in lbs_pytorch:
    https://ghe.oculus-rep.com/ydong142857/lbs_pytorch).

    TODO: we might want to change skinning to double precision; with the
    current float32 formulation the numerical error is above the 1e-3 level
    (but below the 1e-2 level).

    Basically,
        y_i = \sum_j w_ij (s_j * r_j * (r0_j * x_i + t0_j) + t_j)
    where \sum_j w_ij = 1, \forall i

    Args:
        template: (B, V, 3) per-sample template, (1, V, 3) or (V, 3) shared
            template. Shared templates are broadcast across the batch.
        t: (B, J, 3) Translation of the joints
        r: (B, J, 3, 3) Rotation of the joints
        s: (B, J, 1) Scale of the joints
        t0: (J, 3) Translation of inverse bind pose
        r0: (J, 3, 3) Rotation of inverse bind pose
            (for our setting, s0 == 1)
        skin_indices_flattened: (N, ) LBS skinning nbr joint indices
        skin_weights_flattened: (N, ) LBS skinning nbr joint weights
        vert_indices_flattened: (N, ) LBS skinning nbr corresponding vertex indices

    Returns:
        skinned: (B, V, 3) Skinned mesh
    """
    batch_size = t.shape[0]
    # Decide on broadcasting by rank, not by comparing shape[0] to the batch
    # size: a shared (V, 3) template with V == batch_size must still be
    # broadcast.
    if template.dim() == 2:
        template = template[None].expand(batch_size, -1, -1)
    elif template.shape[0] != batch_size:
        # (1, V, 3): broadcast the singleton batch dimension.
        template = template.expand(batch_size, -1, -1)

    # Per-joint affine map x -> A_j x + b_j combining the inverse bind pose
    # with the current joint transform.
    sr = s[:, :, :, None] * r
    A = trs.rotmat_multiply(sr, r0[None])
    b = trs.rotmat_rotate_vector(sr, t0[None]) + t

    # Weighted per-(vertex, joint) contributions, scattered into vertices.
    skinned = th.zeros_like(template)
    skinned = skinned.index_add(
        1,
        vert_indices_flattened,
        (
            trs.rotmat_rotate_vector(
                th.index_select(A, 1, skin_indices_flattened),
                th.index_select(template, 1, vert_indices_flattened),
            )
            + th.index_select(b, 1, skin_indices_flattened)
        )
        * skin_weights_flattened[None, :, None],
    )
    return skinned
624
+
625
+
626
@th.jit.script
def multi_topology_skinning(
    template: th.Tensor,
    t: th.Tensor,
    r: th.Tensor,
    s: th.Tensor,
    t0: th.Tensor,
    r0: th.Tensor,
    skin_indices_flattened: th.Tensor,
    skin_weights_flattened: th.Tensor,
    vert_indices_flattened: th.Tensor,
) -> th.Tensor:
    r"""
    Linear blend skinning (formula as in lbs_pytorch:
    https://ghe.oculus-rep.com/ydong142857/lbs_pytorch) for batches that mix
    several mesh topologies.

    Unlike :func:`skinning`, the flattened indices address the batch dimension
    as well:
    - ``skin_indices_flattened`` indexes rows of the joints flattened to (B * J).
    - ``vert_indices_flattened`` indexes rows of the vertices flattened to (B * V).

    TODO: we might want to change skinning to double precision; with the
    current float32 formulation the numerical error is above the 1e-3 level
    (but below the 1e-2 level).

    Basically,
        y_i = \sum_j w_ij (s_j * r_j * (r0_j * x_i + t0_j) + t_j)
    where \sum_j w_ij = 1, \forall i

    Args:
        template: (B, V, 3) per-sample template, (1, V, 3) or (V, 3) shared
            template. Shared templates are broadcast across the batch.
        t: (B, J, 3) Translation of the joints
        r: (B, J, 3, 3) Rotation of the joints
        s: (B, J, 1) Scale of the joints
        t0: (J, 3) Translation of inverse bind pose
        r0: (J, 3, 3) Rotation of inverse bind pose
            (for our setting, s0 == 1)
        skin_indices_flattened: (N, ) joint indices into the (B * J) axis
        skin_weights_flattened: (N, ) LBS skinning nbr joint weights
        vert_indices_flattened: (N, ) vertex indices into the (B * V) axis

    Returns:
        skinned: (B, V, 3) Skinned mesh
    """
    batch_size = t.shape[0]
    # Broadcast shared templates by rank (see skinning()).
    if template.dim() == 2:
        template = template[None].expand(batch_size, -1, -1)
    elif template.shape[0] != batch_size:
        template = template.expand(batch_size, -1, -1)

    sr = s[:, :, :, None] * r
    A = trs.rotmat_multiply(sr, r0[None])
    b = trs.rotmat_rotate_vector(sr, t0[None]) + t

    # The skin/vertex indices address the flattened batch*joint and
    # batch*vertex axes, so flatten before gathering. reshape (not view) is
    # required: an expanded template has stride 0 on the batch axis and
    # cannot be viewed as (B * V, 3).
    skinning_A = th.index_select(
        A.reshape(A.shape[0] * A.shape[1], A.shape[2], A.shape[3]),
        0,
        skin_indices_flattened,
    )

    skinning_b = th.index_select(
        b.reshape(b.shape[0] * b.shape[1], b.shape[2]), 0, skin_indices_flattened
    )

    skinning_verts = th.index_select(
        template.reshape(template.shape[0] * template.shape[1], template.shape[2]),
        0,
        vert_indices_flattened,
    )

    skinned = th.zeros_like(template).reshape(
        template.shape[0] * template.shape[1], template.shape[2]
    )
    skinned = skinned.index_add(
        0,
        vert_indices_flattened,
        (trs.rotmat_rotate_vector(skinning_A, skinning_verts) + skinning_b)
        * skin_weights_flattened[..., None],
    )
    return skinned.reshape(template.shape[0], template.shape[1], template.shape[2])
706
+
707
+
708
def unpose_from_global_joint_state(
    verts: th.Tensor,
    t: th.Tensor,
    r: th.Tensor,
    s: th.Tensor,
    t0: th.Tensor,
    r0: th.Tensor,
    skin_indices_flattened: th.Tensor,
    skin_weights_flattened: th.Tensor,
    vert_indices_flattened: th.Tensor,
    with_high_precision: bool = True,
) -> th.Tensor:
    """
    Invert skinning(): recover rest-pose vertices from posed ones.

    For each vertex, the skinning weights fuse the per-joint affine maps into a
    single map ``y = A_v x + b_v``. This is solved for ``x`` per vertex.
    WARNING: the precision is low...

    Args:
        verts: [batch_size, num_verts, 3] posed vertices
        t: (B, J, 3) Translation of the joints
        r: (B, J, 3, 3) Rotation of the joints
        s: (B, J, 1) Scale of the joints
        t0: (J, 3) Translation of inverse bind pose
        r0: (J, 3, 3) Rotation of inverse bind pose
        skin_indices_flattened: (N, ) LBS skinning nbr joint indices
        skin_weights_flattened: (N, ) LBS skinning nbr joint weights
        vert_indices_flattened: (N, ) LBS skinning nbr corresponding vertex indices
        with_high_precision: if True, solve the normal equations with an LDLT
            factorization on the CPU (requires a device sync); otherwise use
            torch.linalg.solve directly on the fused system.

    Returns:
        Unposed vertices, same shape as ``verts``.
    """
    # Per-joint affine map x -> A_j x + b_j, as in skinning().
    scaled_rot = s[:, :, :, None] * r
    joint_A = trs.rotmat_multiply(scaled_rot, r0[None])
    joint_b = trs.rotmat_rotate_vector(scaled_rot, t0[None]) + t

    # Blend the joint maps per vertex using the skinning weights.
    weights = skin_weights_flattened[None, :, None]
    fused_A = th.zeros(verts.shape + (3,), dtype=verts.dtype, device=verts.device)
    fused_A.index_add_(
        1,
        vert_indices_flattened,
        th.index_select(joint_A, 1, skin_indices_flattened) * weights[..., None],
    )
    fused_b = th.zeros(verts.shape, dtype=verts.dtype, device=verts.device)
    fused_b.index_add_(
        1,
        vert_indices_flattened,
        th.index_select(joint_b, 1, skin_indices_flattened) * weights,
    )

    residual = verts - fused_b
    if not with_high_precision:
        return th.linalg.solve(fused_A, residual)

    # th.linalg.solve is not aware of the condition number; solve the normal
    # equations A^T A x = A^T (y - b) with an LDLT decomposition instead.
    normal_mat = th.einsum("bvyx,bvyz->bvxz", fused_A, fused_A)
    normal_rhs = th.einsum("bvyx,bvy->bvx", fused_A, residual)

    # ldl_factor_ex is very slow on GPU, so factor on the CPU.
    LD, pivots, _ = th.linalg.ldl_factor_ex(normal_mat.cpu())
    solution = th.linalg.ldl_solve(LD, pivots, normal_rhs[..., None].cpu())[..., 0]
    return solution.to(normal_mat.device)
782
+
783
+
784
@th.jit.script
def get_local_state_from_joint_params(
    joint_params: th.Tensor,
    joint_offset: th.Tensor,
    joint_rotation: th.Tensor,
    joint_parents: th.Tensor | None = None,
    allow_inverse_kinematic_chain: bool = False,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
    """
    Calculate local joint TRS states from joint parameters.

    Per joint, the 7 parameters are split as: [0:3] translation (added to the
    joint offset), [3:6] Euler xyz angles (composed after the joint's
    pre-rotation), [6] log2 scale (the scale is 2 ** p).

    Args:
        joint_params: [batch_size, num_joints, 7] or [batch_size, num_joints * 7]
        joint_offset: [num_joints, 3]
        joint_rotation: [num_joints, 3, 3]
        joint_parents: 1-D tensor of parent indices, with -1 marking the root.
            Required when ``allow_inverse_kinematic_chain`` is True.
        allow_inverse_kinematic_chain: if set to True, this hints that the
            kinematic chain might be reversed (e.g. from wrist to root). Joints
            whose parent index is greater than their own index take the
            inverse of their parent's local transform. The root joint is set
            to the identity [0, I, 1] transformation.

    Returns:
        local_state_t: [batch_size, num_joints, 3]
        local_state_r: [batch_size, num_joints, 3, 3]
        local_state_s: [batch_size, num_joints, 1]
    """
    if len(joint_params.shape) == 2:
        # (batch_size, num_joints * 7) -> (batch_size, num_joints, 7); reshape
        # also accepts non-contiguous inputs.
        joint_params = joint_params.reshape(joint_params.shape[0], -1, 7)

    # The vanilla conversion.
    local_state_t = joint_params[:, :, :3] + joint_offset[None, :]
    local_state_r = trs.rotmat_multiply(
        joint_rotation[None], trs.rotmat_from_euler_xyz(joint_params[:, :, 3:6])
    )
    local_state_s = th.exp2(joint_params[:, :, 6:])

    if allow_inverse_kinematic_chain:
        assert joint_parents is not None
        assert len(joint_parents.shape) == 1
        parents_device = joint_parents.device
        root_joint = th.where(joint_parents == -1)[0]
        inversed_joints = th.where(
            joint_parents
            > th.arange(0, len(joint_parents), dtype=th.long, device=parents_device)
        )[0]
        inversed_joint_parents = joint_parents[inversed_joints]

        # Create new tensors so autograd does not fail on the in-place writes.
        _local_state_t = local_state_t.clone()
        _local_state_r = local_state_r.clone()
        _local_state_s = local_state_s.clone()

        # For the inverted joints, the transform direction is reversed.
        inv_t, inv_r, inv_s = trs.inverse(
            (
                local_state_t[:, inversed_joint_parents],
                local_state_r[:, inversed_joint_parents],
                local_state_s[:, inversed_joint_parents],
            )
        )
        _local_state_t[:, inversed_joints] = inv_t
        _local_state_r[:, inversed_joints] = inv_r
        _local_state_s[:, inversed_joints] = inv_s

        # Set the new root joint to identity. The identity matrix must match
        # the state's dtype and device: index assignment (index_put_) rejects
        # mismatched dtypes (e.g. float32 eye into float64 states).
        _local_state_t[:, root_joint] = 0
        _local_state_r[:, root_joint] = th.eye(
            3, dtype=_local_state_r.dtype, device=_local_state_r.device
        )[None]
        _local_state_s[:, root_joint] = 1

        local_state_t = _local_state_t
        local_state_r = _local_state_r
        local_state_s = _local_state_s

    return local_state_t, local_state_r, local_state_s