pymomentum-cpu 0.1.93.post0__cp312-cp312-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (517)
  1. include/axel/BoundingBox.h +59 -0
  2. include/axel/Bvh.h +708 -0
  3. include/axel/BvhBase.h +75 -0
  4. include/axel/BvhCommon.h +43 -0
  5. include/axel/BvhEmbree.h +87 -0
  6. include/axel/BvhFactory.h +34 -0
  7. include/axel/Checks.h +21 -0
  8. include/axel/DualContouring.h +79 -0
  9. include/axel/KdTree.h +208 -0
  10. include/axel/Log.h +22 -0
  11. include/axel/MeshToSdf.h +123 -0
  12. include/axel/Profile.h +64 -0
  13. include/axel/Ray.h +45 -0
  14. include/axel/SignedDistanceField.h +248 -0
  15. include/axel/SimdKdTree.h +515 -0
  16. include/axel/TriBvh.h +157 -0
  17. include/axel/TriBvhEmbree.h +57 -0
  18. include/axel/common/Constants.h +27 -0
  19. include/axel/common/Types.h +21 -0
  20. include/axel/common/VectorizationTypes.h +58 -0
  21. include/axel/math/BoundingBoxUtils.h +54 -0
  22. include/axel/math/ContinuousCollisionDetection.h +48 -0
  23. include/axel/math/CoplanarityCheck.h +30 -0
  24. include/axel/math/EdgeEdgeDistance.h +31 -0
  25. include/axel/math/MeshHoleFilling.h +117 -0
  26. include/axel/math/PointTriangleProjection.h +34 -0
  27. include/axel/math/PointTriangleProjectionDefinitions.h +209 -0
  28. include/axel/math/RayTriangleIntersection.h +36 -0
  29. include/momentum/character/blend_shape.h +97 -0
  30. include/momentum/character/blend_shape_base.h +86 -0
  31. include/momentum/character/blend_shape_skinning.h +96 -0
  32. include/momentum/character/character.h +272 -0
  33. include/momentum/character/character_state.h +108 -0
  34. include/momentum/character/character_utility.h +128 -0
  35. include/momentum/character/collision_geometry.h +80 -0
  36. include/momentum/character/collision_geometry_state.h +130 -0
  37. include/momentum/character/fwd.h +262 -0
  38. include/momentum/character/inverse_parameter_transform.h +58 -0
  39. include/momentum/character/joint.h +82 -0
  40. include/momentum/character/joint_state.h +241 -0
  41. include/momentum/character/linear_skinning.h +139 -0
  42. include/momentum/character/locator.h +94 -0
  43. include/momentum/character/locator_state.h +43 -0
  44. include/momentum/character/marker.h +48 -0
  45. include/momentum/character/mesh_state.h +71 -0
  46. include/momentum/character/parameter_limits.h +144 -0
  47. include/momentum/character/parameter_transform.h +207 -0
  48. include/momentum/character/pose_shape.h +65 -0
  49. include/momentum/character/skeleton.h +85 -0
  50. include/momentum/character/skeleton_state.h +227 -0
  51. include/momentum/character/skeleton_utility.h +38 -0
  52. include/momentum/character/skin_weights.h +67 -0
  53. include/momentum/character/skinned_locator.h +80 -0
  54. include/momentum/character/types.h +202 -0
  55. include/momentum/character_sequence_solver/fwd.h +200 -0
  56. include/momentum/character_sequence_solver/model_parameters_sequence_error_function.h +65 -0
  57. include/momentum/character_sequence_solver/multipose_solver.h +65 -0
  58. include/momentum/character_sequence_solver/multipose_solver_function.h +82 -0
  59. include/momentum/character_sequence_solver/sequence_error_function.h +104 -0
  60. include/momentum/character_sequence_solver/sequence_solver.h +155 -0
  61. include/momentum/character_sequence_solver/sequence_solver_function.h +158 -0
  62. include/momentum/character_sequence_solver/state_sequence_error_function.h +117 -0
  63. include/momentum/character_sequence_solver/vertex_sequence_error_function.h +123 -0
  64. include/momentum/character_solver/aim_error_function.h +112 -0
  65. include/momentum/character_solver/collision_error_function.h +92 -0
  66. include/momentum/character_solver/collision_error_function_stateless.h +75 -0
  67. include/momentum/character_solver/constraint_error_function-inl.h +324 -0
  68. include/momentum/character_solver/constraint_error_function.h +248 -0
  69. include/momentum/character_solver/distance_error_function.h +77 -0
  70. include/momentum/character_solver/error_function_utils.h +60 -0
  71. include/momentum/character_solver/fixed_axis_error_function.h +139 -0
  72. include/momentum/character_solver/fwd.h +943 -0
  73. include/momentum/character_solver/gauss_newton_solver_qr.h +64 -0
  74. include/momentum/character_solver/height_error_function.h +176 -0
  75. include/momentum/character_solver/joint_to_joint_distance_error_function.h +111 -0
  76. include/momentum/character_solver/limit_error_function.h +57 -0
  77. include/momentum/character_solver/model_parameters_error_function.h +64 -0
  78. include/momentum/character_solver/normal_error_function.h +73 -0
  79. include/momentum/character_solver/orientation_error_function.h +74 -0
  80. include/momentum/character_solver/plane_error_function.h +102 -0
  81. include/momentum/character_solver/point_triangle_vertex_error_function.h +141 -0
  82. include/momentum/character_solver/pose_prior_error_function.h +80 -0
  83. include/momentum/character_solver/position_error_function.h +75 -0
  84. include/momentum/character_solver/projection_error_function.h +93 -0
  85. include/momentum/character_solver/simd_collision_error_function.h +99 -0
  86. include/momentum/character_solver/simd_normal_error_function.h +157 -0
  87. include/momentum/character_solver/simd_plane_error_function.h +164 -0
  88. include/momentum/character_solver/simd_position_error_function.h +165 -0
  89. include/momentum/character_solver/skeleton_error_function.h +151 -0
  90. include/momentum/character_solver/skeleton_solver_function.h +94 -0
  91. include/momentum/character_solver/skinned_locator_error_function.h +166 -0
  92. include/momentum/character_solver/skinned_locator_triangle_error_function.h +146 -0
  93. include/momentum/character_solver/skinning_weight_iterator.h +80 -0
  94. include/momentum/character_solver/state_error_function.h +119 -0
  95. include/momentum/character_solver/transform_pose.h +80 -0
  96. include/momentum/character_solver/trust_region_qr.h +80 -0
  97. include/momentum/character_solver/vertex_error_function.h +155 -0
  98. include/momentum/character_solver/vertex_projection_error_function.h +117 -0
  99. include/momentum/character_solver/vertex_vertex_distance_error_function.h +147 -0
  100. include/momentum/common/aligned.h +155 -0
  101. include/momentum/common/checks.h +27 -0
  102. include/momentum/common/exception.h +70 -0
  103. include/momentum/common/filesystem.h +20 -0
  104. include/momentum/common/fwd.h +27 -0
  105. include/momentum/common/log.h +173 -0
  106. include/momentum/common/log_channel.h +17 -0
  107. include/momentum/common/memory.h +71 -0
  108. include/momentum/common/profile.h +79 -0
  109. include/momentum/common/progress_bar.h +37 -0
  110. include/momentum/common/string.h +52 -0
  111. include/momentum/diff_ik/ceres_utility.h +73 -0
  112. include/momentum/diff_ik/fully_differentiable_body_ik.h +58 -0
  113. include/momentum/diff_ik/fully_differentiable_distance_error_function.h +69 -0
  114. include/momentum/diff_ik/fully_differentiable_motion_error_function.h +46 -0
  115. include/momentum/diff_ik/fully_differentiable_orientation_error_function.h +114 -0
  116. include/momentum/diff_ik/fully_differentiable_pose_prior_error_function.h +76 -0
  117. include/momentum/diff_ik/fully_differentiable_position_error_function.h +138 -0
  118. include/momentum/diff_ik/fully_differentiable_projection_error_function.h +65 -0
  119. include/momentum/diff_ik/fully_differentiable_skeleton_error_function.h +160 -0
  120. include/momentum/diff_ik/fully_differentiable_state_error_function.h +54 -0
  121. include/momentum/diff_ik/fwd.h +385 -0
  122. include/momentum/diff_ik/union_error_function.h +67 -0
  123. include/momentum/gui/rerun/eigen_adapters.h +70 -0
  124. include/momentum/gui/rerun/logger.h +102 -0
  125. include/momentum/gui/rerun/logging_redirect.h +27 -0
  126. include/momentum/io/character_io.h +98 -0
  127. include/momentum/io/common/gsl_utils.h +50 -0
  128. include/momentum/io/common/stream_utils.h +65 -0
  129. include/momentum/io/fbx/fbx_io.h +135 -0
  130. include/momentum/io/fbx/fbx_memory_stream.h +70 -0
  131. include/momentum/io/fbx/openfbx_loader.h +62 -0
  132. include/momentum/io/fbx/polygon_data.h +60 -0
  133. include/momentum/io/file_save_options.h +107 -0
  134. include/momentum/io/gltf/gltf_builder.h +141 -0
  135. include/momentum/io/gltf/gltf_io.h +149 -0
  136. include/momentum/io/gltf/utils/accessor_utils.h +299 -0
  137. include/momentum/io/gltf/utils/coordinate_utils.h +60 -0
  138. include/momentum/io/gltf/utils/json_utils.h +102 -0
  139. include/momentum/io/legacy_json/legacy_json_io.h +70 -0
  140. include/momentum/io/marker/c3d_io.h +30 -0
  141. include/momentum/io/marker/conversions.h +57 -0
  142. include/momentum/io/marker/coordinate_system.h +30 -0
  143. include/momentum/io/marker/marker_io.h +56 -0
  144. include/momentum/io/marker/trc_io.h +27 -0
  145. include/momentum/io/motion/mmo_io.h +97 -0
  146. include/momentum/io/shape/blend_shape_io.h +82 -0
  147. include/momentum/io/shape/pose_shape_io.h +21 -0
  148. include/momentum/io/skeleton/locator_io.h +41 -0
  149. include/momentum/io/skeleton/mppca_io.h +26 -0
  150. include/momentum/io/skeleton/parameter_limits_io.h +38 -0
  151. include/momentum/io/skeleton/parameter_transform_io.h +80 -0
  152. include/momentum/io/skeleton/parameters_io.h +20 -0
  153. include/momentum/io/skeleton/utility.h +67 -0
  154. include/momentum/io/urdf/urdf_io.h +26 -0
  155. include/momentum/io/usd/usd_io.h +36 -0
  156. include/momentum/marker_tracking/app_utils.h +64 -0
  157. include/momentum/marker_tracking/marker_tracker.h +221 -0
  158. include/momentum/marker_tracking/process_markers.h +58 -0
  159. include/momentum/marker_tracking/tracker_utils.h +99 -0
  160. include/momentum/math/constants.h +82 -0
  161. include/momentum/math/covariance_matrix.h +84 -0
  162. include/momentum/math/fmt_eigen.h +23 -0
  163. include/momentum/math/fwd.h +132 -0
  164. include/momentum/math/generalized_loss.h +61 -0
  165. include/momentum/math/intersection.h +32 -0
  166. include/momentum/math/mesh.h +84 -0
  167. include/momentum/math/mppca.h +67 -0
  168. include/momentum/math/online_householder_qr.h +516 -0
  169. include/momentum/math/random-inl.h +404 -0
  170. include/momentum/math/random.h +310 -0
  171. include/momentum/math/simd_generalized_loss.h +40 -0
  172. include/momentum/math/transform.h +229 -0
  173. include/momentum/math/types.h +461 -0
  174. include/momentum/math/utility.h +324 -0
  175. include/momentum/rasterizer/camera.h +453 -0
  176. include/momentum/rasterizer/fwd.h +102 -0
  177. include/momentum/rasterizer/geometry.h +83 -0
  178. include/momentum/rasterizer/image.h +18 -0
  179. include/momentum/rasterizer/rasterizer.h +583 -0
  180. include/momentum/rasterizer/tensor.h +140 -0
  181. include/momentum/rasterizer/text_rasterizer.h +89 -0
  182. include/momentum/rasterizer/utility.h +268 -0
  183. include/momentum/simd/simd.h +221 -0
  184. include/momentum/solver/fwd.h +131 -0
  185. include/momentum/solver/gauss_newton_solver.h +136 -0
  186. include/momentum/solver/gradient_descent_solver.h +65 -0
  187. include/momentum/solver/solver.h +155 -0
  188. include/momentum/solver/solver_function.h +126 -0
  189. include/momentum/solver/subset_gauss_newton_solver.h +109 -0
  190. include/rerun/archetypes/annotation_context.hpp +157 -0
  191. include/rerun/archetypes/arrows2d.hpp +271 -0
  192. include/rerun/archetypes/arrows3d.hpp +257 -0
  193. include/rerun/archetypes/asset3d.hpp +262 -0
  194. include/rerun/archetypes/asset_video.hpp +275 -0
  195. include/rerun/archetypes/bar_chart.hpp +261 -0
  196. include/rerun/archetypes/boxes2d.hpp +293 -0
  197. include/rerun/archetypes/boxes3d.hpp +369 -0
  198. include/rerun/archetypes/capsules3d.hpp +333 -0
  199. include/rerun/archetypes/clear.hpp +180 -0
  200. include/rerun/archetypes/depth_image.hpp +425 -0
  201. include/rerun/archetypes/ellipsoids3d.hpp +384 -0
  202. include/rerun/archetypes/encoded_image.hpp +250 -0
  203. include/rerun/archetypes/geo_line_strings.hpp +166 -0
  204. include/rerun/archetypes/geo_points.hpp +177 -0
  205. include/rerun/archetypes/graph_edges.hpp +152 -0
  206. include/rerun/archetypes/graph_nodes.hpp +206 -0
  207. include/rerun/archetypes/image.hpp +434 -0
  208. include/rerun/archetypes/instance_poses3d.hpp +221 -0
  209. include/rerun/archetypes/line_strips2d.hpp +289 -0
  210. include/rerun/archetypes/line_strips3d.hpp +270 -0
  211. include/rerun/archetypes/mesh3d.hpp +387 -0
  212. include/rerun/archetypes/pinhole.hpp +385 -0
  213. include/rerun/archetypes/points2d.hpp +333 -0
  214. include/rerun/archetypes/points3d.hpp +369 -0
  215. include/rerun/archetypes/recording_properties.hpp +132 -0
  216. include/rerun/archetypes/scalar.hpp +170 -0
  217. include/rerun/archetypes/scalars.hpp +153 -0
  218. include/rerun/archetypes/segmentation_image.hpp +305 -0
  219. include/rerun/archetypes/series_line.hpp +274 -0
  220. include/rerun/archetypes/series_lines.hpp +271 -0
  221. include/rerun/archetypes/series_point.hpp +265 -0
  222. include/rerun/archetypes/series_points.hpp +251 -0
  223. include/rerun/archetypes/tensor.hpp +213 -0
  224. include/rerun/archetypes/text_document.hpp +200 -0
  225. include/rerun/archetypes/text_log.hpp +211 -0
  226. include/rerun/archetypes/transform3d.hpp +925 -0
  227. include/rerun/archetypes/video_frame_reference.hpp +295 -0
  228. include/rerun/archetypes/view_coordinates.hpp +393 -0
  229. include/rerun/archetypes.hpp +43 -0
  230. include/rerun/arrow_utils.hpp +32 -0
  231. include/rerun/as_components.hpp +90 -0
  232. include/rerun/blueprint/archetypes/background.hpp +113 -0
  233. include/rerun/blueprint/archetypes/container_blueprint.hpp +259 -0
  234. include/rerun/blueprint/archetypes/dataframe_query.hpp +178 -0
  235. include/rerun/blueprint/archetypes/entity_behavior.hpp +130 -0
  236. include/rerun/blueprint/archetypes/force_center.hpp +115 -0
  237. include/rerun/blueprint/archetypes/force_collision_radius.hpp +141 -0
  238. include/rerun/blueprint/archetypes/force_link.hpp +136 -0
  239. include/rerun/blueprint/archetypes/force_many_body.hpp +124 -0
  240. include/rerun/blueprint/archetypes/force_position.hpp +132 -0
  241. include/rerun/blueprint/archetypes/line_grid3d.hpp +178 -0
  242. include/rerun/blueprint/archetypes/map_background.hpp +104 -0
  243. include/rerun/blueprint/archetypes/map_zoom.hpp +103 -0
  244. include/rerun/blueprint/archetypes/near_clip_plane.hpp +109 -0
  245. include/rerun/blueprint/archetypes/panel_blueprint.hpp +95 -0
  246. include/rerun/blueprint/archetypes/plot_legend.hpp +118 -0
  247. include/rerun/blueprint/archetypes/scalar_axis.hpp +116 -0
  248. include/rerun/blueprint/archetypes/tensor_scalar_mapping.hpp +146 -0
  249. include/rerun/blueprint/archetypes/tensor_slice_selection.hpp +167 -0
  250. include/rerun/blueprint/archetypes/tensor_view_fit.hpp +95 -0
  251. include/rerun/blueprint/archetypes/view_blueprint.hpp +170 -0
  252. include/rerun/blueprint/archetypes/view_contents.hpp +142 -0
  253. include/rerun/blueprint/archetypes/viewport_blueprint.hpp +200 -0
  254. include/rerun/blueprint/archetypes/visible_time_ranges.hpp +116 -0
  255. include/rerun/blueprint/archetypes/visual_bounds2d.hpp +109 -0
  256. include/rerun/blueprint/archetypes/visualizer_overrides.hpp +113 -0
  257. include/rerun/blueprint/archetypes.hpp +29 -0
  258. include/rerun/blueprint/components/active_tab.hpp +82 -0
  259. include/rerun/blueprint/components/apply_latest_at.hpp +79 -0
  260. include/rerun/blueprint/components/auto_layout.hpp +77 -0
  261. include/rerun/blueprint/components/auto_views.hpp +77 -0
  262. include/rerun/blueprint/components/background_kind.hpp +66 -0
  263. include/rerun/blueprint/components/column_share.hpp +78 -0
  264. include/rerun/blueprint/components/component_column_selector.hpp +81 -0
  265. include/rerun/blueprint/components/container_kind.hpp +65 -0
  266. include/rerun/blueprint/components/corner2d.hpp +64 -0
  267. include/rerun/blueprint/components/enabled.hpp +77 -0
  268. include/rerun/blueprint/components/filter_by_range.hpp +74 -0
  269. include/rerun/blueprint/components/filter_is_not_null.hpp +77 -0
  270. include/rerun/blueprint/components/force_distance.hpp +82 -0
  271. include/rerun/blueprint/components/force_iterations.hpp +82 -0
  272. include/rerun/blueprint/components/force_strength.hpp +82 -0
  273. include/rerun/blueprint/components/grid_columns.hpp +78 -0
  274. include/rerun/blueprint/components/grid_spacing.hpp +78 -0
  275. include/rerun/blueprint/components/included_content.hpp +86 -0
  276. include/rerun/blueprint/components/lock_range_during_zoom.hpp +82 -0
  277. include/rerun/blueprint/components/map_provider.hpp +64 -0
  278. include/rerun/blueprint/components/near_clip_plane.hpp +82 -0
  279. include/rerun/blueprint/components/panel_state.hpp +61 -0
  280. include/rerun/blueprint/components/query_expression.hpp +89 -0
  281. include/rerun/blueprint/components/root_container.hpp +77 -0
  282. include/rerun/blueprint/components/row_share.hpp +78 -0
  283. include/rerun/blueprint/components/selected_columns.hpp +76 -0
  284. include/rerun/blueprint/components/tensor_dimension_index_slider.hpp +90 -0
  285. include/rerun/blueprint/components/timeline_name.hpp +76 -0
  286. include/rerun/blueprint/components/view_class.hpp +76 -0
  287. include/rerun/blueprint/components/view_fit.hpp +61 -0
  288. include/rerun/blueprint/components/view_maximized.hpp +79 -0
  289. include/rerun/blueprint/components/view_origin.hpp +81 -0
  290. include/rerun/blueprint/components/viewer_recommendation_hash.hpp +82 -0
  291. include/rerun/blueprint/components/visible_time_range.hpp +77 -0
  292. include/rerun/blueprint/components/visual_bounds2d.hpp +74 -0
  293. include/rerun/blueprint/components/visualizer_override.hpp +86 -0
  294. include/rerun/blueprint/components/zoom_level.hpp +78 -0
  295. include/rerun/blueprint/components.hpp +41 -0
  296. include/rerun/blueprint/datatypes/component_column_selector.hpp +61 -0
  297. include/rerun/blueprint/datatypes/filter_by_range.hpp +59 -0
  298. include/rerun/blueprint/datatypes/filter_is_not_null.hpp +61 -0
  299. include/rerun/blueprint/datatypes/selected_columns.hpp +62 -0
  300. include/rerun/blueprint/datatypes/tensor_dimension_index_slider.hpp +63 -0
  301. include/rerun/blueprint/datatypes.hpp +9 -0
  302. include/rerun/c/arrow_c_data_interface.h +111 -0
  303. include/rerun/c/compiler_utils.h +10 -0
  304. include/rerun/c/rerun.h +627 -0
  305. include/rerun/c/sdk_info.h +28 -0
  306. include/rerun/collection.hpp +496 -0
  307. include/rerun/collection_adapter.hpp +43 -0
  308. include/rerun/collection_adapter_builtins.hpp +138 -0
  309. include/rerun/compiler_utils.hpp +61 -0
  310. include/rerun/component_batch.hpp +163 -0
  311. include/rerun/component_column.hpp +111 -0
  312. include/rerun/component_descriptor.hpp +142 -0
  313. include/rerun/component_type.hpp +35 -0
  314. include/rerun/components/aggregation_policy.hpp +76 -0
  315. include/rerun/components/albedo_factor.hpp +74 -0
  316. include/rerun/components/annotation_context.hpp +102 -0
  317. include/rerun/components/axis_length.hpp +74 -0
  318. include/rerun/components/blob.hpp +73 -0
  319. include/rerun/components/class_id.hpp +71 -0
  320. include/rerun/components/clear_is_recursive.hpp +75 -0
  321. include/rerun/components/color.hpp +99 -0
  322. include/rerun/components/colormap.hpp +99 -0
  323. include/rerun/components/depth_meter.hpp +84 -0
  324. include/rerun/components/draw_order.hpp +79 -0
  325. include/rerun/components/entity_path.hpp +83 -0
  326. include/rerun/components/fill_mode.hpp +72 -0
  327. include/rerun/components/fill_ratio.hpp +79 -0
  328. include/rerun/components/gamma_correction.hpp +80 -0
  329. include/rerun/components/geo_line_string.hpp +63 -0
  330. include/rerun/components/graph_edge.hpp +75 -0
  331. include/rerun/components/graph_node.hpp +79 -0
  332. include/rerun/components/graph_type.hpp +57 -0
  333. include/rerun/components/half_size2d.hpp +91 -0
  334. include/rerun/components/half_size3d.hpp +95 -0
  335. include/rerun/components/image_buffer.hpp +86 -0
  336. include/rerun/components/image_format.hpp +84 -0
  337. include/rerun/components/image_plane_distance.hpp +77 -0
  338. include/rerun/components/interactive.hpp +76 -0
  339. include/rerun/components/keypoint_id.hpp +74 -0
  340. include/rerun/components/lat_lon.hpp +89 -0
  341. include/rerun/components/length.hpp +77 -0
  342. include/rerun/components/line_strip2d.hpp +73 -0
  343. include/rerun/components/line_strip3d.hpp +73 -0
  344. include/rerun/components/magnification_filter.hpp +63 -0
  345. include/rerun/components/marker_shape.hpp +82 -0
  346. include/rerun/components/marker_size.hpp +74 -0
  347. include/rerun/components/media_type.hpp +157 -0
  348. include/rerun/components/name.hpp +83 -0
  349. include/rerun/components/opacity.hpp +77 -0
  350. include/rerun/components/pinhole_projection.hpp +94 -0
  351. include/rerun/components/plane3d.hpp +75 -0
  352. include/rerun/components/pose_rotation_axis_angle.hpp +73 -0
  353. include/rerun/components/pose_rotation_quat.hpp +71 -0
  354. include/rerun/components/pose_scale3d.hpp +102 -0
  355. include/rerun/components/pose_transform_mat3x3.hpp +87 -0
  356. include/rerun/components/pose_translation3d.hpp +96 -0
  357. include/rerun/components/position2d.hpp +86 -0
  358. include/rerun/components/position3d.hpp +90 -0
  359. include/rerun/components/radius.hpp +98 -0
  360. include/rerun/components/range1d.hpp +75 -0
  361. include/rerun/components/resolution.hpp +88 -0
  362. include/rerun/components/rotation_axis_angle.hpp +72 -0
  363. include/rerun/components/rotation_quat.hpp +71 -0
  364. include/rerun/components/scalar.hpp +76 -0
  365. include/rerun/components/scale3d.hpp +102 -0
  366. include/rerun/components/series_visible.hpp +76 -0
  367. include/rerun/components/show_labels.hpp +79 -0
  368. include/rerun/components/stroke_width.hpp +74 -0
  369. include/rerun/components/tensor_data.hpp +94 -0
  370. include/rerun/components/tensor_dimension_index_selection.hpp +77 -0
  371. include/rerun/components/tensor_height_dimension.hpp +71 -0
  372. include/rerun/components/tensor_width_dimension.hpp +71 -0
  373. include/rerun/components/texcoord2d.hpp +101 -0
  374. include/rerun/components/text.hpp +83 -0
  375. include/rerun/components/text_log_level.hpp +110 -0
  376. include/rerun/components/timestamp.hpp +76 -0
  377. include/rerun/components/transform_mat3x3.hpp +92 -0
  378. include/rerun/components/transform_relation.hpp +66 -0
  379. include/rerun/components/translation3d.hpp +96 -0
  380. include/rerun/components/triangle_indices.hpp +85 -0
  381. include/rerun/components/value_range.hpp +78 -0
  382. include/rerun/components/vector2d.hpp +92 -0
  383. include/rerun/components/vector3d.hpp +96 -0
  384. include/rerun/components/video_timestamp.hpp +120 -0
  385. include/rerun/components/view_coordinates.hpp +346 -0
  386. include/rerun/components/visible.hpp +74 -0
  387. include/rerun/components.hpp +77 -0
  388. include/rerun/config.hpp +52 -0
  389. include/rerun/datatypes/angle.hpp +76 -0
  390. include/rerun/datatypes/annotation_info.hpp +76 -0
  391. include/rerun/datatypes/blob.hpp +67 -0
  392. include/rerun/datatypes/bool.hpp +57 -0
  393. include/rerun/datatypes/channel_datatype.hpp +87 -0
  394. include/rerun/datatypes/class_description.hpp +92 -0
  395. include/rerun/datatypes/class_description_map_elem.hpp +69 -0
  396. include/rerun/datatypes/class_id.hpp +62 -0
  397. include/rerun/datatypes/color_model.hpp +68 -0
  398. include/rerun/datatypes/dvec2d.hpp +76 -0
  399. include/rerun/datatypes/entity_path.hpp +60 -0
  400. include/rerun/datatypes/float32.hpp +62 -0
  401. include/rerun/datatypes/float64.hpp +62 -0
  402. include/rerun/datatypes/image_format.hpp +107 -0
  403. include/rerun/datatypes/keypoint_id.hpp +63 -0
  404. include/rerun/datatypes/keypoint_pair.hpp +65 -0
  405. include/rerun/datatypes/mat3x3.hpp +105 -0
  406. include/rerun/datatypes/mat4x4.hpp +119 -0
  407. include/rerun/datatypes/pixel_format.hpp +142 -0
  408. include/rerun/datatypes/plane3d.hpp +60 -0
  409. include/rerun/datatypes/quaternion.hpp +110 -0
  410. include/rerun/datatypes/range1d.hpp +59 -0
  411. include/rerun/datatypes/range2d.hpp +55 -0
  412. include/rerun/datatypes/rgba32.hpp +94 -0
  413. include/rerun/datatypes/rotation_axis_angle.hpp +67 -0
  414. include/rerun/datatypes/tensor_buffer.hpp +529 -0
  415. include/rerun/datatypes/tensor_data.hpp +100 -0
  416. include/rerun/datatypes/tensor_dimension_index_selection.hpp +58 -0
  417. include/rerun/datatypes/tensor_dimension_selection.hpp +56 -0
  418. include/rerun/datatypes/time_int.hpp +62 -0
  419. include/rerun/datatypes/time_range.hpp +55 -0
  420. include/rerun/datatypes/time_range_boundary.hpp +175 -0
  421. include/rerun/datatypes/uint16.hpp +62 -0
  422. include/rerun/datatypes/uint32.hpp +62 -0
  423. include/rerun/datatypes/uint64.hpp +62 -0
  424. include/rerun/datatypes/utf8.hpp +76 -0
  425. include/rerun/datatypes/utf8pair.hpp +62 -0
  426. include/rerun/datatypes/uuid.hpp +60 -0
  427. include/rerun/datatypes/uvec2d.hpp +76 -0
  428. include/rerun/datatypes/uvec3d.hpp +80 -0
  429. include/rerun/datatypes/uvec4d.hpp +59 -0
  430. include/rerun/datatypes/vec2d.hpp +76 -0
  431. include/rerun/datatypes/vec3d.hpp +80 -0
  432. include/rerun/datatypes/vec4d.hpp +84 -0
  433. include/rerun/datatypes/video_timestamp.hpp +67 -0
  434. include/rerun/datatypes/view_coordinates.hpp +87 -0
  435. include/rerun/datatypes/visible_time_range.hpp +57 -0
  436. include/rerun/datatypes.hpp +51 -0
  437. include/rerun/demo_utils.hpp +75 -0
  438. include/rerun/entity_path.hpp +20 -0
  439. include/rerun/error.hpp +180 -0
  440. include/rerun/half.hpp +10 -0
  441. include/rerun/image_utils.hpp +187 -0
  442. include/rerun/indicator_component.hpp +59 -0
  443. include/rerun/loggable.hpp +54 -0
  444. include/rerun/recording_stream.hpp +960 -0
  445. include/rerun/rerun_sdk_export.hpp +25 -0
  446. include/rerun/result.hpp +86 -0
  447. include/rerun/rotation3d.hpp +33 -0
  448. include/rerun/sdk_info.hpp +20 -0
  449. include/rerun/spawn.hpp +21 -0
  450. include/rerun/spawn_options.hpp +57 -0
  451. include/rerun/string_utils.hpp +16 -0
  452. include/rerun/third_party/cxxopts.hpp +2198 -0
  453. include/rerun/time_column.hpp +288 -0
  454. include/rerun/timeline.hpp +38 -0
  455. include/rerun/type_traits.hpp +40 -0
  456. include/rerun.hpp +86 -0
  457. lib/cmake/axel/axel-config.cmake +45 -0
  458. lib/cmake/axel/axelTargets-release.cmake +19 -0
  459. lib/cmake/axel/axelTargets.cmake +108 -0
  460. lib/cmake/momentum/FindFbxSdk.cmake +115 -0
  461. lib/cmake/momentum/Findre2.cmake +52 -0
  462. lib/cmake/momentum/momentum-config.cmake +67 -0
  463. lib/cmake/momentum/momentumTargets-release.cmake +259 -0
  464. lib/cmake/momentum/momentumTargets.cmake +385 -0
  465. lib/cmake/rerun_sdk/rerun_sdkConfig.cmake +70 -0
  466. lib/cmake/rerun_sdk/rerun_sdkConfigVersion.cmake +83 -0
  467. lib/cmake/rerun_sdk/rerun_sdkTargets-release.cmake +19 -0
  468. lib/cmake/rerun_sdk/rerun_sdkTargets.cmake +108 -0
  469. lib/libarrow.a +0 -0
  470. lib/libarrow_bundled_dependencies.a +0 -0
  471. lib/libaxel.a +0 -0
  472. lib/libmomentum_app_utils.a +0 -0
  473. lib/libmomentum_character.a +0 -0
  474. lib/libmomentum_character_sequence_solver.a +0 -0
  475. lib/libmomentum_character_solver.a +0 -0
  476. lib/libmomentum_common.a +0 -0
  477. lib/libmomentum_diff_ik.a +0 -0
  478. lib/libmomentum_io.a +0 -0
  479. lib/libmomentum_io_common.a +0 -0
  480. lib/libmomentum_io_fbx.a +0 -0
  481. lib/libmomentum_io_gltf.a +0 -0
  482. lib/libmomentum_io_legacy_json.a +0 -0
  483. lib/libmomentum_io_marker.a +0 -0
  484. lib/libmomentum_io_motion.a +0 -0
  485. lib/libmomentum_io_shape.a +0 -0
  486. lib/libmomentum_io_skeleton.a +0 -0
  487. lib/libmomentum_io_urdf.a +0 -0
  488. lib/libmomentum_marker_tracker.a +0 -0
  489. lib/libmomentum_math.a +0 -0
  490. lib/libmomentum_online_qr.a +0 -0
  491. lib/libmomentum_process_markers.a +0 -0
  492. lib/libmomentum_rerun.a +0 -0
  493. lib/libmomentum_simd_constraints.a +0 -0
  494. lib/libmomentum_simd_generalized_loss.a +0 -0
  495. lib/libmomentum_skeleton.a +0 -0
  496. lib/libmomentum_solver.a +0 -0
  497. lib/librerun_c__macos_arm64.a +0 -0
  498. lib/librerun_sdk.a +0 -0
  499. pymomentum/axel.cpython-312-darwin.so +0 -0
  500. pymomentum/backend/__init__.py +16 -0
  501. pymomentum/backend/skel_state_backend.py +631 -0
  502. pymomentum/backend/trs_backend.py +889 -0
  503. pymomentum/backend/utils.py +224 -0
  504. pymomentum/geometry.cpython-312-darwin.so +0 -0
  505. pymomentum/marker_tracking.cpython-312-darwin.so +0 -0
  506. pymomentum/quaternion.py +740 -0
  507. pymomentum/skel_state.py +514 -0
  508. pymomentum/solver.cpython-312-darwin.so +0 -0
  509. pymomentum/solver2.cpython-312-darwin.so +0 -0
  510. pymomentum/torch/character.py +868 -0
  511. pymomentum/torch/parameter_limits.py +494 -0
  512. pymomentum/torch/utility.py +20 -0
  513. pymomentum/trs.py +535 -0
  514. pymomentum_cpu-0.1.93.post0.dist-info/METADATA +126 -0
  515. pymomentum_cpu-0.1.93.post0.dist-info/RECORD +517 -0
  516. pymomentum_cpu-0.1.93.post0.dist-info/WHEEL +5 -0
  517. pymomentum_cpu-0.1.93.post0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,889 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # pyre-strict
7
+ """
8
+ TRS Backend for PyMomentum
9
+
10
+ This module provides efficient forward kinematics and skinning operations using
11
+ the TRS (Translation-Rotation-Scale) representation where each transformation
12
+ component is stored separately.
13
+
14
+ The TRS representation uses separate tensors for translation (3D), rotation matrices (3x3),
15
+ and scale factors (1D), making it suitable for applications that need explicit access
16
+ to individual transformation components.
17
+
18
+ Performance Notes:
19
+ This backend is typically 25-50% faster than the skeleton state backend in PyTorch,
20
+ likely due to not requiring quaternion normalization operations. While it doesn't
21
+ match the C++ reference implementation exactly (use skel_state_backend for that),
22
+ it provides excellent performance for PyTorch-based applications.
23
+
24
+ Key Functions:
25
+ - global_trs_state_from_local_trs_state: Forward kinematics from local to global joint states
26
+ - skin_points_from_trs_state: Linear blend skinning using TRS transformations
27
+ - local_trs_state_from_joint_params: Convert joint parameters to local TRS states
28
+
29
+ Related Modules:
30
+ - skel_state_backend: Alternative backend using compact 8-parameter skeleton states
31
+ - trs: Core TRS transformation operations and utilities
32
+ """
33
+
34
+ from typing import List, Tuple
35
+
36
+ import torch as th
37
+ from pymomentum import trs
38
+
39
+
40
@th.jit.script
def global_trs_state_from_local_trs_state_impl(
    local_state_t: th.Tensor,
    local_state_r: th.Tensor,
    local_state_s: th.Tensor,
    prefix_mul_indices: List[th.Tensor],
    save_intermediate_results: bool = True,
    use_double_precision: bool = True,
) -> Tuple[
    th.Tensor, th.Tensor, th.Tensor, List[Tuple[th.Tensor, th.Tensor, th.Tensor]]
]:
    """
    Compute global TRS joint states from local ones (forward kinematics, no autograd).

    The kinematic chain is evaluated by prefix multiplication: at every level of
    ``prefix_mul_indices`` each listed child joint is replaced by the composition
    of its listed parent's current state with its own current state (via
    ``trs.multiply``), so that after the last level every joint holds its global
    transform.

    Composition rule for a joint j with parent p (similarity y = s * R * x + t):
        s_global_j = s_p * s_local_j
        R_global_j = R_p * R_local_j
        t_global_j = t_p + s_p * R_p * t_local_j

    Args:
        local_state_t: Local joint translations, shape (batch_size, num_joints, 3).
        local_state_r: Local joint rotation matrices, shape (batch_size, num_joints, 3, 3).
        local_state_s: Local joint scales, shape (batch_size, num_joints, 1).
        prefix_mul_indices: One index tensor per level. Row 0 holds the child
            (source) joint indices updated at that level, row 1 the matching
            parent (target) joint indices.
        save_intermediate_results: If True, record each level's child states
            *before* they are overwritten; the backward pass consumes these.
            Pass False when no backward pass will follow, to save memory.
        use_double_precision: If True, accumulate in float64 and cast the results
            back to the input dtype at the end (limits error build-up along deep
            chains).

    Returns:
        global_state_t: Global joint translations, shape (batch_size, num_joints, 3).
        global_state_r: Global joint rotations, shape (batch_size, num_joints, 3, 3).
        global_state_s: Global joint scales, shape (batch_size, num_joints, 1).
        intermediate_results: One (t, r, s) tuple per level, or an empty list when
            ``save_intermediate_results`` is False.

    Note:
        Everything runs under ``torch.no_grad()``; gradients are provided by
        :class:`ForwardKinematicsFromLocalTransformationJIT` instead.

    See Also:
        :func:`global_trs_state_from_local_trs_state`: User-facing wrapper with autodiff
        :func:`local_trs_state_from_joint_params`: Convert joint parameters to local states
    """
    output_dtype = local_state_t.dtype
    saved: List[Tuple[th.Tensor, th.Tensor, th.Tensor]] = []

    with th.no_grad():
        # Private working copies: entries are overwritten in place level by level.
        state_t = local_state_t.clone()
        state_r = local_state_r.clone()
        state_s = local_state_s.clone()
        if use_double_precision:
            state_t = state_t.double()
            state_r = state_r.double()
            state_s = state_s.double()

        for level in prefix_mul_indices:
            child = level[0]
            parent = level[1]

            parent_t = state_t[:, parent]
            parent_r = state_r[:, parent]
            parent_s = state_s[:, parent]

            child_t = state_t[:, child]
            child_r = state_r[:, child]
            child_s = state_s[:, child]

            if save_intermediate_results:
                # Child state before this level; needed to undo it in backward.
                saved.append((child_t.clone(), child_r.clone(), child_s.clone()))

            new_t, new_r, new_s = trs.multiply(
                (parent_t, parent_r, parent_s), (child_t, child_r, child_s)
            )

            state_t[:, child] = new_t
            state_r[:, child] = new_r
            state_s[:, child] = new_s

    return (
        state_t.to(output_dtype),
        state_r.to(output_dtype),
        state_s.to(output_dtype),
        saved,
    )
148
+
149
+
150
@th.jit.script
def global_trs_state_from_local_trs_state_no_grad(
    local_state_t: th.Tensor,
    local_state_r: th.Tensor,
    local_state_s: th.Tensor,
    prefix_mul_indices: List[th.Tensor],
    save_intermediate_results: bool = True,
    use_double_precision: bool = True,
) -> Tuple[
    th.Tensor, th.Tensor, th.Tensor, List[Tuple[th.Tensor, th.Tensor, th.Tensor]]
]:
    """
    Run TRS forward kinematics with gradient tracking explicitly disabled.

    Thin wrapper that calls :func:`global_trs_state_from_local_trs_state_impl`
    inside ``torch.no_grad()`` and forwards all of its options unchanged.

    Args:
        local_state_t: Local joint translations, shape (batch_size, num_joints, 3)
        local_state_r: Local joint rotations, shape (batch_size, num_joints, 3, 3)
        local_state_s: Local joint scales, shape (batch_size, num_joints, 1)
        prefix_mul_indices: Per-level index tensors (row 0: child, row 1: parent)
        save_intermediate_results: Whether to record per-level child states for backprop
        use_double_precision: Whether to accumulate in float64

    Returns:
        global_state_t: Global joint translations, shape (batch_size, num_joints, 3)
        global_state_r: Global joint rotations, shape (batch_size, num_joints, 3, 3)
        global_state_s: Global joint scales, shape (batch_size, num_joints, 1)
        intermediate_results: List of (t, r, s) tuples from the forward pass

    See Also:
        :func:`global_trs_state_from_local_trs_state_impl`: Implementation function
    """
    with th.no_grad():
        global_t, global_r, global_s, saved = global_trs_state_from_local_trs_state_impl(
            local_state_t,
            local_state_r,
            local_state_s,
            prefix_mul_indices,
            save_intermediate_results=save_intermediate_results,
            use_double_precision=use_double_precision,
        )
    return global_t, global_r, global_s, saved
195
+
196
+
197
@th.jit.script
def ik_from_global_state(
    global_state_t: th.Tensor,
    global_state_r: th.Tensor,
    global_state_s: th.Tensor,
    prefix_mul_indices: List[th.Tensor],
    use_double_precision: bool = True,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
    """
    Recover local TRS joint states from global ones by undoing the FK prefix multiplication.

    The levels of ``prefix_mul_indices`` are visited last-to-first. At each level
    every listed joint state is replaced by ``inverse(parent_state) * joint_state``:

        s_local = s_joint / s_parent
        R_local = rotmat_inverse(R_parent) * R_joint
        t_local = rotmat_inverse(R_parent) * ((t_joint - t_parent) / s_parent)

    Unlike the FK kernel, this function does not disable autograd.

    Args:
        global_state_t: Global joint translations, shape (batch_size, num_joints, 3).
        global_state_r: Global joint rotation matrices, shape (batch_size, num_joints, 3, 3).
        global_state_s: Global joint scales, shape (batch_size, num_joints, 1).
        prefix_mul_indices: The same per-level index tensors used for forward
            kinematics (row 0: joint indices, row 1: parent indices).
        use_double_precision: If True, compute in float64 and cast back to the
            input dtype at the end.

    Returns:
        local_state_t, local_state_r, local_state_s with the input shapes and dtype.
    """
    out_dtype = global_state_t.dtype

    state_t = global_state_t.clone()
    state_r = global_state_r.clone()
    state_s = global_state_s.clone()
    if use_double_precision:
        state_t = state_t.double()
        state_r = state_r.double()
        state_s = state_s.double()

    # Undo the forward levels in reverse order.
    num_levels = len(prefix_mul_indices)
    for k in range(num_levels - 1, -1, -1):
        level = prefix_mul_indices[k]
        joint = level[0]
        parent = level[1]

        # Inverse of the parent's transform (scale and rotation parts).
        inv_parent_s = state_s[:, parent].reciprocal()
        inv_parent_r = trs.rotmat_inverse(state_r[:, parent])
        parent_t = state_t[:, parent]

        joint_s = state_s[:, joint]
        joint_r = state_r[:, joint]
        joint_t = state_t[:, joint]

        state_s[:, joint] = inv_parent_s * joint_s
        state_r[:, joint] = trs.rotmat_multiply(inv_parent_r, joint_r)
        state_t[:, joint] = trs.rotmat_rotate_vector(
            inv_parent_r, (joint_t - parent_t) * inv_parent_s
        )

    return (
        state_t.to(out_dtype),
        state_r.to(out_dtype),
        state_s.to(out_dtype),
    )
238
+
239
+
240
@th.jit.script
def global_trs_state_from_local_trs_state_backprop(
    joint_state_t: th.Tensor,
    joint_state_r: th.Tensor,
    joint_state_s: th.Tensor,
    grad_joint_state_t: th.Tensor,
    grad_joint_state_r: th.Tensor,
    grad_joint_state_s: th.Tensor,
    prefix_mul_indices: List[th.Tensor],
    intermediate_results: List[Tuple[th.Tensor, th.Tensor, th.Tensor]],
    use_double_precision: bool = True,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
    r"""
    Backward pass of :func:`global_trs_state_from_local_trs_state_impl`.

    With local states (tl, rl, sl) and global states (tg, rg, sg), the chain rule reads

        \partial L/\partial tl_i = \sum_j \partial L/\partial tg_j * \partial tg_j/\partial tl_i
        \partial L/\partial sl_i = \sum_j \partial L/\partial tg_j * \partial tg_j/\partial sl_i
                                 + \sum_j \partial L/\partial sg_j * \partial sg_j/\partial sl_i
        \partial L/\partial rl_i = \sum_j \partial L/\partial tg_j * \partial tg_j/\partial rl_i
                                 + \sum_j \partial L/\partial rg_j * \partial rg_j/\partial rl_i

    Evaluating the sums over j directly is expensive (and the Jacobian is very
    sparse), so instead the forward prefix-multiplication levels are undone one
    at a time, in reverse order.

    Example, chain [0, 1, 2, 3], forward (source <- target):

        level 0: [1, 3] <- [0, 2]   chain becomes [0, 01, 2, 23]
        level 1: [2, 3] <- [1, 1]   chain becomes [0, 01, 012, 0123]

    The first backward step undoes forward level 1 and turns gradients w.r.t.
    (factors numbered 1..4 for joints 0..3)

        g(s1), g(s1s2), g(s1s2s3), g(s1s2s3s4)
        g(R1), g(R1R2), g(R1R2R3), g(R1R2R3R4)
        g(t1), g(t1+s1R1t2), g(t1+s1R1t2+s1R1s2R2t3), ...

    into gradients w.r.t.

        g(s1), g(s1s2), g(s3), g(s3s4)
        g(R1), g(R1R2), g(R3), g(R3R4)
        g(t1), g(t1+s1R1t2), g(t3), g(t3+s3R3t4)

    Per level, the forward composition (t1, R1, s1) o (t2, R2, s2) =
    (t1 + s1 R1 t2, R1 R2, s1 s2) (1 = parent, 2 = child before the level) gives,
    for upstream gradients (gt, gR, gs) of the composed child:

        child:  g_t2 = s1 R1^T gt,   g_R2 = R1^T gR,   g_s2 = s1 gs
        parent: g_t1 += gt,   g_R1 += gR R2^T + s1 gt t2^T,   g_s1 += gt . (R1 t2) + s2 gs

    Args:
        joint_state_t/r/s: Global joint states returned by the forward pass.
        grad_joint_state_t/r/s: Upstream gradients w.r.t. those global states.
        prefix_mul_indices: Per-level index tensors used in the forward pass.
        intermediate_results: Per-level child states recorded by the forward pass
            (``save_intermediate_results=True``).
        use_double_precision: If True, accumulate in float64 and cast back to the
            dtype of ``joint_state_t`` at the end.

    Returns:
        Gradients w.r.t. the local translations, rotations and scales.
    """
    result_dtype = joint_state_t.dtype

    # Mutable working copies: gradients and the kinematic chain are both rewritten
    # in place while the forward levels are undone.
    acc_t = grad_joint_state_t.clone()
    acc_r = grad_joint_state_r.clone()
    acc_s = grad_joint_state_s.clone()
    chain_t = joint_state_t.clone()
    chain_r = joint_state_r.clone()
    chain_s = joint_state_s.clone()
    if use_double_precision:
        acc_t = acc_t.double()
        acc_r = acc_r.double()
        acc_s = acc_s.double()
        chain_t = chain_t.double()
        chain_r = chain_r.double()
        chain_s = chain_s.double()

    # Walk the levels last-to-first, pairing each level with its saved states
    # (only as many levels as both lists cover, like zip).
    num_levels = min(len(prefix_mul_indices), len(intermediate_results))
    for level in range(num_levels - 1, -1, -1):
        child = prefix_mul_indices[level][0]
        parent = prefix_mul_indices[level][1]
        # Child state from *before* this level in the forward pass; loading it
        # avoids recomputing it from the global state.
        t_in, r_in, s_in = intermediate_results[level]

        # Upstream gradients w.r.t. the composed child state.
        up_s = acc_s[:, child]
        up_r = acc_r[:, child]
        up_t = acc_t[:, child]

        # Parent state as it was during this forward level (all later levels
        # have already been undone in the chain).
        par_s = chain_s[:, parent]
        par_r = chain_r[:, parent]

        # Gradients w.r.t. the child's pre-level state. Einsum equivalents:
        # t: "bjyx,bjy->bjx", R: "bjyx,bjyz->bjxz" (explicit sums for speed).
        child_grad_s = par_s * up_s
        child_grad_t = par_s * (par_r * up_t[..., None]).sum(dim=2)
        child_grad_r = (par_r[:, :, :, :, None] * up_r[:, :, :, None, :]).sum(dim=2)

        # Contributions to the parent. Einsum equivalents:
        # s: "bjxy,bjy,bjx->bj", R: "bjxy,bjzy->bjxz" + outer(s1 * gt, t2).
        parent_grad_s = (
            (par_r * t_in[:, :, None, :] * up_t[:, :, :, None])
            .sum(dim=3)
            .sum(dim=2, keepdim=True)
        )
        parent_grad_s = parent_grad_s + s_in * up_s
        parent_grad_r = (up_r[:, :, :, None, :] * r_in[:, :, None, :, :]).sum(dim=4)
        parent_grad_r = (
            parent_grad_r + (par_s * up_t)[:, :, :, None] * t_in[:, :, None, :]
        )
        parent_grad_t = up_t

        # Replace the child's gradient, then accumulate into the parent(s);
        # index_add_ sums correctly when a parent appears several times in a level.
        acc_t[:, child] = child_grad_t
        acc_r[:, child] = child_grad_r
        acc_s[:, child] = child_grad_s
        acc_t.index_add_(1, parent, parent_grad_t)
        acc_r.index_add_(1, parent, parent_grad_r)
        acc_s.index_add_(1, parent, parent_grad_s)

        # Undo this level in the chain too, so earlier levels see the right parents.
        chain_t[:, child] = t_in
        chain_r[:, child] = r_in
        chain_s[:, child] = s_in

    return (
        acc_t.to(result_dtype),
        acc_r.to(result_dtype),
        acc_s.to(result_dtype),
    )
+ )
374
+
375
+
376
class ForwardKinematicsFromLocalTransformationJIT(th.autograd.Function):
    """
    Autograd wrapper around the TorchScript TRS forward-kinematics kernels.

    ``forward`` runs :func:`global_trs_state_from_local_trs_state_no_grad`, which
    records the per-level child states; ``backward`` replays those levels in
    reverse through :func:`global_trs_state_from_local_trs_state_backprop`.
    """

    @staticmethod
    # pyre-ignore[14]
    def forward(
        local_state_t: th.Tensor,
        local_state_r: th.Tensor,
        local_state_s: th.Tensor,
        prefix_mul_indices: List[th.Tensor],
    ) -> Tuple[
        th.Tensor, th.Tensor, th.Tensor, List[Tuple[th.Tensor, th.Tensor, th.Tensor]]
    ]:
        """
        Forward pass of differentiable TRS forward kinematics.

        Args:
            local_state_t: Local joint translations, shape (batch_size, num_joints, 3)
            local_state_r: Local joint rotations, shape (batch_size, num_joints, 3, 3)
            local_state_s: Local joint scales, shape (batch_size, num_joints, 1)
            prefix_mul_indices: Per-level index tensors (row 0: child, row 1: parent)

        Returns:
            (global_state_t, global_state_r, global_state_s, intermediate_results)
        """
        fk_result = global_trs_state_from_local_trs_state_no_grad(
            local_state_t,
            local_state_r,
            local_state_s,
            prefix_mul_indices,
        )
        return fk_result

    @staticmethod
    # pyre-ignore[14]
    # pyre-ignore[2]
    def setup_context(ctx, inputs, outputs) -> None:
        """
        Stash what ``backward`` needs.

        Args:
            ctx: Autograd context object.
            inputs: (local_state_t, local_state_r, local_state_s, prefix_mul_indices)
            outputs: (joint_state_t, joint_state_r, joint_state_s, intermediate_results)
        """
        _unused_t, _unused_r, _unused_s, prefix_mul_indices = inputs
        global_t, global_r, global_s, intermediate_results = outputs
        # Clone so that later in-place edits of the returned outputs cannot
        # corrupt the tensors needed by backward.
        saved_t = global_t.clone()
        saved_r = global_r.clone()
        saved_s = global_s.clone()
        ctx.save_for_backward(saved_t, saved_r, saved_s)
        ctx.prefix_mul_indices = prefix_mul_indices
        ctx.intermediate_results = intermediate_results

    @staticmethod
    # pyre-ignore[14]
    def backward(
        # pyre-ignore[2]
        ctx,
        grad_joint_state_t: th.Tensor,
        grad_joint_state_r: th.Tensor,
        grad_joint_state_s: th.Tensor,
        _0,
    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, None]:
        """
        Map gradients of the global states back to the local states.

        The gradient slot of ``intermediate_results`` (``_0``) is ignored, and
        ``prefix_mul_indices`` receives no gradient.
        """
        saved_t, saved_r, saved_s = ctx.saved_tensors
        local_grads = global_trs_state_from_local_trs_state_backprop(
            saved_t,
            saved_r,
            saved_s,
            grad_joint_state_t,
            grad_joint_state_r,
            grad_joint_state_s,
            ctx.prefix_mul_indices,
            ctx.intermediate_results,
        )
        grad_t, grad_r, grad_s = local_grads
        return grad_t, grad_r, grad_s, None
473
+
474
+
475
def global_trs_state_from_local_trs_state(
    local_state_t: th.Tensor,
    local_state_r: th.Tensor,
    local_state_s: th.Tensor,
    prefix_mul_indices: List[th.Tensor],
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
    """
    Forward kinematics on TRS states (main user-facing entry point).

    In eager mode the computation goes through
    :class:`ForwardKinematicsFromLocalTransformationJIT`, which supplies a custom
    backward pass. While tracing or scripting, the TorchScript kernel
    :func:`global_trs_state_from_local_trs_state_impl` is called directly instead;
    that kernel runs under ``torch.no_grad()``, so no gradients flow on that path.

    Args:
        local_state_t: Local joint translations, shape (batch_size, num_joints, 3).
        local_state_r: Local joint rotations, shape (batch_size, num_joints, 3, 3).
        local_state_s: Local joint scales, shape (batch_size, num_joints, 1).
        prefix_mul_indices: Per-level index tensors defining the kinematic
            hierarchy traversal order (row 0: child, row 1: parent).

    Returns:
        global_state_t: Global joint translations, shape (batch_size, num_joints, 3).
        global_state_r: Global joint rotations, shape (batch_size, num_joints, 3, 3).
        global_state_s: Global joint scales, shape (batch_size, num_joints, 1).

    See Also:
        :func:`global_trs_state_from_local_trs_state_impl`: JIT implementation
        :func:`local_trs_state_from_joint_params`: Convert joint parameters to local states
    """
    in_jit_context = th.jit.is_tracing() or th.jit.is_scripting()
    if not in_jit_context:
        fk_outputs = ForwardKinematicsFromLocalTransformationJIT.apply(
            local_state_t,
            local_state_r,
            local_state_s,
            prefix_mul_indices,
        )
    else:
        fk_outputs = global_trs_state_from_local_trs_state_impl(
            local_state_t,
            local_state_r,
            local_state_s,
            prefix_mul_indices,
        )
    global_t, global_r, global_s, _intermediate = fk_outputs
    return global_t, global_r, global_s
536
+
537
+
538
def global_trs_state_from_local_trs_state_forward_only(
    local_state_t: th.Tensor,
    local_state_r: th.Tensor,
    local_state_s: th.Tensor,
    prefix_mul_indices: List[th.Tensor],
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
    """
    Compute global TRS state from local joint transformations (forward-only wrapper).

    Bypasses autograd completely: runs the TorchScript kernel under
    ``torch.no_grad()`` and does not record per-level intermediate states, since
    no backward pass will ever consume them. Use this when gradients are not
    needed.

    Args:
        local_state_t: Local joint translations, shape (batch_size, num_joints, 3).
        local_state_r: Local joint rotations, shape (batch_size, num_joints, 3, 3).
        local_state_s: Local joint scales, shape (batch_size, num_joints, 1).
        prefix_mul_indices: Per-level index tensors (row 0: child, row 1: parent).

    Returns:
        global_state_t: Global joint translations, shape (batch_size, num_joints, 3).
        global_state_r: Global joint rotations, shape (batch_size, num_joints, 3, 3).
        global_state_s: Global joint scales, shape (batch_size, num_joints, 1).

    See Also:
        :func:`global_trs_state_from_local_trs_state`: Main user-facing function with autograd
    """
    # Going through ForwardKinematicsFromLocalTransformationJIT.forward would use
    # the default save_intermediate_results=True and clone the child states at
    # every level for a backward pass that never happens. Skip that work; the
    # returned global states are identical.
    global_t, global_r, global_s, _ = global_trs_state_from_local_trs_state_no_grad(
        local_state_t,
        local_state_r,
        local_state_s,
        prefix_mul_indices,
        save_intermediate_results=False,
    )
    return global_t, global_r, global_s
580
+
581
+
582
@th.jit.script
def skinning(
    template: th.Tensor,
    t: th.Tensor,
    r: th.Tensor,
    s: th.Tensor,
    t0: th.Tensor,
    r0: th.Tensor,
    skin_indices_flattened: th.Tensor,
    skin_weights_flattened: th.Tensor,
    vert_indices_flattened: th.Tensor,
) -> th.Tensor:
    r"""
    Linear blend skinning (LBS) with TRS joint transforms.

    LBS skinning formula as is in lbs_pytorch:
    https://ghe.oculus-rep.com/ydong142857/lbs_pytorch

    TODO: we might want to change skinning to double precision
    with current float32 formulation the numerical error is bigger than 1e-3 level
    (but smaller than 1e-2 level)

    Basically,
        y_i = \sum_j w_ij (s_j * r_j * (r0_j * x_i + t0_j) + t_j)
    where \sum_j w_ij = 1, \forall i

    Args:
        template: LBS template, either (V, 3) shared by the whole batch, or
            (B, V, 3) / (1, V, 3) per-sample templates.
        t: (B, J, 3) Translation of the joints
        r: (B, J, 3, 3) Rotation of the joints
        s: (B, J, 1) Scale of the joints
        t0: (J, 3) Translation of inverse bind pose
        r0: (J, 3, 3) Rotation of inverse bind pose
            (for our setting, s0 == 1)
        skin_indices_flattened: (N, ) LBS skinning nbr joint indices
        skin_weights_flattened: (N, ) LBS skinning nbr joint weights
        vert_indices_flattened: (N, ) LBS skinning nbr corresponding vertex indices

    Returns:
        skinned: (B, V, 3) Skinned mesh
    """
    batch_size = t.shape[0]
    # Broadcast based on the template's rank, not just its first dimension: a
    # shared (V, 3) template with V == batch_size must still be broadcast, and a
    # (1, V, 3) template must be expanded along the batch axis.
    if template.dim() == 2:
        template = template[None].expand(batch_size, -1, -1)
    elif template.shape[0] != batch_size:
        template = template.expand(batch_size, -1, -1)

    # Fold the scaled joint transform and the inverse bind pose into one affine
    # map per joint: y = A_j x + b_j.
    sr = s[:, :, :, None] * r
    A = trs.rotmat_multiply(sr, r0[None])
    b = trs.rotmat_rotate_vector(sr, t0[None]) + t

    # Transform every (vertex, joint) influence, weight it, and scatter-add the
    # result into its vertex.
    weighted = (
        trs.rotmat_rotate_vector(
            th.index_select(A, 1, skin_indices_flattened),
            th.index_select(template, 1, vert_indices_flattened),
        )
        + th.index_select(b, 1, skin_indices_flattened)
    ) * skin_weights_flattened[None, :, None]

    skinned = th.zeros_like(template).index_add(1, vert_indices_flattened, weighted)
    return skinned
643
+
644
+
645
@th.jit.script
def multi_topology_skinning(
    template: th.Tensor,
    t: th.Tensor,
    r: th.Tensor,
    s: th.Tensor,
    t0: th.Tensor,
    r0: th.Tensor,
    skin_indices_flattened: th.Tensor,
    skin_weights_flattened: th.Tensor,
    vert_indices_flattened: th.Tensor,
) -> th.Tensor:
    r"""
    Linear blend skinning where each batch sample may have its own topology.

    LBS skinning formula as is in lbs_pytorch:
    https://ghe.oculus-rep.com/ydong142857/lbs_pytorch

    The difference from :func:`skinning` is that the flattened indices address
    the batch dimension as well: ``skin_indices_flattened`` indexes the flattened
    (B * J) joint axis and ``vert_indices_flattened`` the flattened (B * V) vertex
    axis, so both need to be flattened together with the batch dimension.

    TODO: we might want to change skinning to double precision
    with current float32 formulation the numerical error is bigger than 1e-3 level
    (but smaller than 1e-2 level)

    Basically,
        y_i = \sum_j w_ij (s_j * r_j * (r0_j * x_i + t0_j) + t_j)
    where \sum_j w_ij = 1, \forall i

    Args:
        template: (B, V, 3) per-sample LBS templates (a (V, 3) or (1, V, 3)
            template is broadcast over the batch).
        t: (B, J, 3) Translation of the joints
        r: (B, J, 3, 3) Rotation of the joints
        s: (B, J, 1) Scale of the joints
        t0: (J, 3) Translation of inverse bind pose
        r0: (J, 3, 3) Rotation of inverse bind pose
            (for our setting, s0 == 1)
        skin_indices_flattened: (N, ) joint indices into the flattened (B * J) axis
        skin_weights_flattened: (N, ) LBS skinning nbr joint weights
        vert_indices_flattened: (N, ) vertex indices into the flattened (B * V) axis

    Returns:
        skinned: (B, V, 3) Skinned mesh
    """
    batch_size = t.shape[0]
    # Broadcast based on rank (see skinning): a shared (V, 3) template with
    # V == batch_size must still be broadcast, and (1, V, 3) must be expanded.
    if template.dim() == 2:
        template = template[None].expand(batch_size, -1, -1)
    elif template.shape[0] != batch_size:
        template = template.expand(batch_size, -1, -1)

    num_samples = template.shape[0]
    num_verts = template.shape[1]
    num_coords = template.shape[2]

    sr = s[:, :, :, None] * r
    A = trs.rotmat_multiply(sr, r0[None])
    b = trs.rotmat_rotate_vector(sr, t0[None]) + t

    # Index on the merged (batch * joint) / (batch * vertex) axis because the
    # flattened indices point into different samples of the batch. reshape (not
    # view) is required: an expanded template has stride 0 along the batch axis
    # and cannot be viewed as (B * V, 3); reshape still returns a view whenever
    # the layout allows it.
    skinning_A = th.index_select(
        A.reshape(A.shape[0] * A.shape[1], A.shape[2], A.shape[3]),
        0,
        skin_indices_flattened,
    )
    skinning_b = th.index_select(
        b.reshape(b.shape[0] * b.shape[1], b.shape[2]), 0, skin_indices_flattened
    )
    skinning_verts = th.index_select(
        template.reshape(num_samples * num_verts, num_coords),
        0,
        vert_indices_flattened,
    )

    skinned = th.zeros_like(template).reshape(num_samples * num_verts, num_coords)
    skinned = skinned.index_add(
        0,
        vert_indices_flattened,
        (trs.rotmat_rotate_vector(skinning_A, skinning_verts) + skinning_b)
        * skin_weights_flattened[..., None],
    )
    return skinned.reshape(num_samples, num_verts, num_coords)
725
+
726
+
727
def unpose_from_global_joint_state(
    verts: th.Tensor,
    t: th.Tensor,
    r: th.Tensor,
    s: th.Tensor,
    t0: th.Tensor,
    r0: th.Tensor,
    skin_indices_flattened: th.Tensor,
    skin_weights_flattened: th.Tensor,
    vert_indices_flattened: th.Tensor,
    with_high_precision: bool = True,
) -> th.Tensor:
    """
    The inverse function of skinning(): recover template-space vertices.

    Each skinned vertex satisfies ``y_i = M_i x_i + c_i``, where ``M_i`` and
    ``c_i`` are the skin-weighted sums of the per-joint affine maps used by
    skinning; this function solves that 3x3 system per vertex.
    WARNING: the precision is low...

    Args:
        verts: [batch_size, num_verts, 3]
        t: (B, J, 3) Translation of the joints
        r: (B, J, 3, 3) Rotation of the joints
        s: (B, J, 1) Scale of the joints
        t0: (J, 3) Translation of inverse bind pose
        r0: (J, 3, 3) Rotation of inverse bind pose
        skin_indices_flattened: (N, ) LBS skinning nbr joint indices
        skin_weights_flattened: (N, ) LBS skinning nbr joint weights
        vert_indices_flattened: (N, ) LBS skinning nbr corresponding vertex indices
        with_high_precision: if True, solve the normal equations with an LDLT
            factorization on the CPU (requires a device sync when inputs live on
            CUDA); otherwise call torch.linalg.solve directly.

    Returns:
        Unposed vertices with the same shape as ``verts``.
    """
    float_type = verts.dtype
    dev = verts.device

    # Per-joint affine maps, exactly as in skinning().
    scaled_rot = s[:, :, :, None] * r
    joint_A = trs.rotmat_multiply(scaled_rot, r0[None])
    joint_b = trs.rotmat_rotate_vector(scaled_rot, t0[None]) + t

    # Weighted per-influence contributions, accumulated per vertex.
    weighted_A = (
        th.index_select(joint_A, 1, skin_indices_flattened)
        * skin_weights_flattened[None, :, None, None]
    )
    weighted_b = (
        th.index_select(joint_b, 1, skin_indices_flattened)
        * skin_weights_flattened[None, :, None]
    )
    fused_A = th.zeros(verts.shape + (3,), dtype=float_type, device=dev).index_add_(
        1, vert_indices_flattened, weighted_A
    )
    fused_b = th.zeros(verts.shape, dtype=float_type, device=dev).index_add_(
        1, vert_indices_flattened, weighted_b
    )
    rhs = verts - fused_b

    if not with_high_precision:
        return th.linalg.solve(fused_A, rhs)

    # th.linalg.solve is not aware of the condition number, so solve the normal
    # equations A^T A x = A^T rhs with an LDLT decomposition instead.
    ATA = th.einsum("bvyx,bvyz->bvxz", fused_A, fused_A)
    ATb = th.einsum("bvyx,bvy->bvx", fused_A, rhs)

    # ldl_factor_ex is very slow on GPU, so factor on the CPU.
    LD, pivots, _ = th.linalg.ldl_factor_ex(ATA.cpu())
    solution = th.linalg.ldl_solve(LD, pivots, ATb[..., None].cpu())[..., 0]
    return solution.to(ATA.device)
800
+
801
+
802
@th.jit.script
def get_local_state_from_joint_params(
    joint_params: th.Tensor,
    joint_offset: th.Tensor,
    joint_rotation: th.Tensor,
    joint_parents: th.Tensor | None = None,
    allow_inverse_kinematic_chain: bool = False,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
    """
    Calculate local joint TRS states from joint parameters.

    Each joint has 7 parameters: a translation added to ``joint_offset`` (0:3),
    Euler XYZ angles converted by ``trs.rotmat_from_euler_xyz`` and applied after
    ``joint_rotation`` (3:6), and a log2 scale (6), i.e. scale = 2 ** param.

    Args:
        joint_params: [batch_size, num_joints, 7] or [batch_size, num_joints * 7]
        joint_offset: [num_joints, 3]
        joint_rotation: [num_joints, 3, 3]
        joint_parents: [num_joints] parent index per joint (-1 for the root).
            Required when ``allow_inverse_kinematic_chain`` is True.
        allow_inverse_kinematic_chain: if set to True, this hints that the kinematic
            chain might be reversed (e.g. from wrist to root). Joints whose parent
            index is larger than their own index take the inverse of their
            parent's local transform, and the joint(s) with parent -1 always get
            the identity [0, I, 1] transformation.

    Returns:
        local_state_t: [batch_size, num_joints, 3]
        local_state_r: [batch_size, num_joints, 3, 3]
        local_state_s: [batch_size, num_joints, 1]
    """
    if len(joint_params.shape) == 2:
        # reshape joint_params as (batch_size, num_joints, 7)
        joint_params = joint_params.view(joint_params.shape[0], -1, 7)

    # the vanilla conversion
    local_state_t = joint_params[:, :, :3] + joint_offset[None, :]
    local_state_r = trs.rotmat_multiply(
        joint_rotation[None], trs.rotmat_from_euler_xyz(joint_params[:, :, 3:6])
    )
    local_state_s = th.exp2(joint_params[:, :, 6:])

    if allow_inverse_kinematic_chain:
        assert joint_parents is not None
        assert len(joint_parents.shape) == 1
        parents_device = joint_parents.device
        root_joint = th.where(joint_parents == -1)[0]
        # joints whose parent comes later in the ordering are on the reversed chain
        inversed_joints = th.where(
            joint_parents
            > th.arange(0, len(joint_parents), dtype=th.long, device=parents_device)
        )[0]
        inversed_joint_parents = joint_parents[inversed_joints]

        # create a new node so the autograd does not fail
        (
            _local_state_t,
            _local_state_r,
            _local_state_s,
        ) = (
            local_state_t.clone(),
            local_state_r.clone(),
            local_state_s.clone(),
        )

        # for the inverse joints
        # the order needs to be inversed
        (
            _local_state_t[:, inversed_joints],
            _local_state_r[:, inversed_joints],
            _local_state_s[:, inversed_joints],
        ) = trs.inverse(
            (
                local_state_t[:, inversed_joint_parents],
                local_state_r[:, inversed_joint_parents],
                local_state_s[:, inversed_joint_parents],
            )
        )

        # set new root joint to identity.
        # Advanced-index assignment (index_put_) does not convert dtype/device, so
        # the identity must be created with the state's dtype and device; a
        # default float32 eye on joint_parents' device fails for float64/float16
        # states or states living on another device.
        _local_state_t[:, root_joint] = 0
        _local_state_r[:, root_joint] = th.eye(
            3, dtype=_local_state_r.dtype, device=_local_state_r.device
        )[None]
        _local_state_s[:, root_joint] = 1

        (
            local_state_t,
            local_state_r,
            local_state_s,
        ) = (
            _local_state_t,
            _local_state_r,
            _local_state_s,
        )

    return local_state_t, local_state_r, local_state_s