pymomentum-cpu 0.1.82.post0__cp313-cp313-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pymomentum-cpu might be problematic. Click here for more details.
- include/axel/BoundingBox.h +58 -0
- include/axel/Bvh.h +708 -0
- include/axel/BvhBase.h +75 -0
- include/axel/BvhCommon.h +43 -0
- include/axel/BvhEmbree.h +86 -0
- include/axel/BvhFactory.h +34 -0
- include/axel/Checks.h +21 -0
- include/axel/DualContouring.h +79 -0
- include/axel/KdTree.h +199 -0
- include/axel/Log.h +22 -0
- include/axel/MeshToSdf.h +123 -0
- include/axel/Profile.h +64 -0
- include/axel/Ray.h +45 -0
- include/axel/SignedDistanceField.h +248 -0
- include/axel/SimdKdTree.h +515 -0
- include/axel/TriBvh.h +157 -0
- include/axel/TriBvhEmbree.h +57 -0
- include/axel/common/Constants.h +27 -0
- include/axel/common/Types.h +21 -0
- include/axel/common/VectorizationTypes.h +58 -0
- include/axel/math/BoundingBoxUtils.h +54 -0
- include/axel/math/ContinuousCollisionDetection.h +48 -0
- include/axel/math/CoplanarityCheck.h +30 -0
- include/axel/math/EdgeEdgeDistance.h +31 -0
- include/axel/math/MeshHoleFilling.h +117 -0
- include/axel/math/PointTriangleProjection.h +34 -0
- include/axel/math/PointTriangleProjectionDefinitions.h +209 -0
- include/axel/math/RayTriangleIntersection.h +36 -0
- include/momentum/character/blend_shape.h +91 -0
- include/momentum/character/blend_shape_base.h +70 -0
- include/momentum/character/blend_shape_skinning.h +96 -0
- include/momentum/character/character.h +272 -0
- include/momentum/character/character_state.h +108 -0
- include/momentum/character/character_utility.h +128 -0
- include/momentum/character/collision_geometry.h +80 -0
- include/momentum/character/collision_geometry_state.h +130 -0
- include/momentum/character/fwd.h +262 -0
- include/momentum/character/inverse_parameter_transform.h +58 -0
- include/momentum/character/joint.h +82 -0
- include/momentum/character/joint_state.h +241 -0
- include/momentum/character/linear_skinning.h +139 -0
- include/momentum/character/locator.h +82 -0
- include/momentum/character/locator_state.h +43 -0
- include/momentum/character/marker.h +48 -0
- include/momentum/character/mesh_state.h +71 -0
- include/momentum/character/parameter_limits.h +144 -0
- include/momentum/character/parameter_transform.h +250 -0
- include/momentum/character/pose_shape.h +65 -0
- include/momentum/character/skeleton.h +85 -0
- include/momentum/character/skeleton_state.h +181 -0
- include/momentum/character/skeleton_utility.h +38 -0
- include/momentum/character/skin_weights.h +67 -0
- include/momentum/character/skinned_locator.h +80 -0
- include/momentum/character/types.h +202 -0
- include/momentum/character_sequence_solver/fwd.h +200 -0
- include/momentum/character_sequence_solver/model_parameters_sequence_error_function.h +65 -0
- include/momentum/character_sequence_solver/multipose_solver.h +65 -0
- include/momentum/character_sequence_solver/multipose_solver_function.h +82 -0
- include/momentum/character_sequence_solver/sequence_error_function.h +104 -0
- include/momentum/character_sequence_solver/sequence_solver.h +144 -0
- include/momentum/character_sequence_solver/sequence_solver_function.h +134 -0
- include/momentum/character_sequence_solver/state_sequence_error_function.h +109 -0
- include/momentum/character_sequence_solver/vertex_sequence_error_function.h +123 -0
- include/momentum/character_solver/aim_error_function.h +112 -0
- include/momentum/character_solver/collision_error_function.h +92 -0
- include/momentum/character_solver/collision_error_function_stateless.h +75 -0
- include/momentum/character_solver/constraint_error_function-inl.h +324 -0
- include/momentum/character_solver/constraint_error_function.h +248 -0
- include/momentum/character_solver/distance_error_function.h +77 -0
- include/momentum/character_solver/error_function_utils.h +60 -0
- include/momentum/character_solver/fixed_axis_error_function.h +139 -0
- include/momentum/character_solver/fwd.h +924 -0
- include/momentum/character_solver/gauss_newton_solver_qr.h +64 -0
- include/momentum/character_solver/limit_error_function.h +57 -0
- include/momentum/character_solver/model_parameters_error_function.h +64 -0
- include/momentum/character_solver/normal_error_function.h +73 -0
- include/momentum/character_solver/orientation_error_function.h +74 -0
- include/momentum/character_solver/plane_error_function.h +102 -0
- include/momentum/character_solver/point_triangle_vertex_error_function.h +141 -0
- include/momentum/character_solver/pose_prior_error_function.h +80 -0
- include/momentum/character_solver/position_error_function.h +75 -0
- include/momentum/character_solver/projection_error_function.h +93 -0
- include/momentum/character_solver/simd_collision_error_function.h +99 -0
- include/momentum/character_solver/simd_normal_error_function.h +157 -0
- include/momentum/character_solver/simd_plane_error_function.h +164 -0
- include/momentum/character_solver/simd_position_error_function.h +165 -0
- include/momentum/character_solver/skeleton_error_function.h +151 -0
- include/momentum/character_solver/skeleton_solver_function.h +94 -0
- include/momentum/character_solver/skinned_locator_error_function.h +166 -0
- include/momentum/character_solver/skinned_locator_triangle_error_function.h +146 -0
- include/momentum/character_solver/skinning_weight_iterator.h +80 -0
- include/momentum/character_solver/state_error_function.h +94 -0
- include/momentum/character_solver/transform_pose.h +80 -0
- include/momentum/character_solver/trust_region_qr.h +80 -0
- include/momentum/character_solver/vertex_error_function.h +155 -0
- include/momentum/character_solver/vertex_projection_error_function.h +117 -0
- include/momentum/character_solver/vertex_vertex_distance_error_function.h +147 -0
- include/momentum/common/aligned.h +155 -0
- include/momentum/common/checks.h +27 -0
- include/momentum/common/exception.h +70 -0
- include/momentum/common/filesystem.h +20 -0
- include/momentum/common/fwd.h +27 -0
- include/momentum/common/log.h +173 -0
- include/momentum/common/log_channel.h +17 -0
- include/momentum/common/memory.h +71 -0
- include/momentum/common/profile.h +79 -0
- include/momentum/common/progress_bar.h +37 -0
- include/momentum/common/string.h +52 -0
- include/momentum/diff_ik/ceres_utility.h +73 -0
- include/momentum/diff_ik/fully_differentiable_body_ik.h +58 -0
- include/momentum/diff_ik/fully_differentiable_distance_error_function.h +69 -0
- include/momentum/diff_ik/fully_differentiable_motion_error_function.h +46 -0
- include/momentum/diff_ik/fully_differentiable_orientation_error_function.h +114 -0
- include/momentum/diff_ik/fully_differentiable_pose_prior_error_function.h +76 -0
- include/momentum/diff_ik/fully_differentiable_position_error_function.h +138 -0
- include/momentum/diff_ik/fully_differentiable_projection_error_function.h +65 -0
- include/momentum/diff_ik/fully_differentiable_skeleton_error_function.h +160 -0
- include/momentum/diff_ik/fully_differentiable_state_error_function.h +54 -0
- include/momentum/diff_ik/fwd.h +385 -0
- include/momentum/diff_ik/union_error_function.h +67 -0
- include/momentum/gui/rerun/eigen_adapters.h +70 -0
- include/momentum/gui/rerun/logger.h +102 -0
- include/momentum/gui/rerun/logging_redirect.h +27 -0
- include/momentum/io/character_io.h +56 -0
- include/momentum/io/common/gsl_utils.h +50 -0
- include/momentum/io/common/stream_utils.h +65 -0
- include/momentum/io/fbx/fbx_io.h +109 -0
- include/momentum/io/fbx/fbx_memory_stream.h +66 -0
- include/momentum/io/fbx/openfbx_loader.h +49 -0
- include/momentum/io/fbx/polygon_data.h +60 -0
- include/momentum/io/gltf/gltf_builder.h +132 -0
- include/momentum/io/gltf/gltf_file_format.h +19 -0
- include/momentum/io/gltf/gltf_io.h +148 -0
- include/momentum/io/gltf/utils/accessor_utils.h +299 -0
- include/momentum/io/gltf/utils/coordinate_utils.h +60 -0
- include/momentum/io/gltf/utils/json_utils.h +102 -0
- include/momentum/io/legacy_json/legacy_json_io.h +70 -0
- include/momentum/io/marker/c3d_io.h +30 -0
- include/momentum/io/marker/conversions.h +57 -0
- include/momentum/io/marker/coordinate_system.h +30 -0
- include/momentum/io/marker/marker_io.h +56 -0
- include/momentum/io/marker/trc_io.h +27 -0
- include/momentum/io/motion/mmo_io.h +97 -0
- include/momentum/io/shape/blend_shape_io.h +70 -0
- include/momentum/io/shape/pose_shape_io.h +21 -0
- include/momentum/io/skeleton/locator_io.h +41 -0
- include/momentum/io/skeleton/mppca_io.h +26 -0
- include/momentum/io/skeleton/parameter_limits_io.h +25 -0
- include/momentum/io/skeleton/parameter_transform_io.h +41 -0
- include/momentum/io/skeleton/parameters_io.h +20 -0
- include/momentum/io/urdf/urdf_io.h +26 -0
- include/momentum/io/usd/usd_io.h +36 -0
- include/momentum/marker_tracking/app_utils.h +62 -0
- include/momentum/marker_tracking/marker_tracker.h +213 -0
- include/momentum/marker_tracking/process_markers.h +58 -0
- include/momentum/marker_tracking/tracker_utils.h +90 -0
- include/momentum/math/constants.h +82 -0
- include/momentum/math/covariance_matrix.h +84 -0
- include/momentum/math/fmt_eigen.h +23 -0
- include/momentum/math/fwd.h +132 -0
- include/momentum/math/generalized_loss.h +61 -0
- include/momentum/math/intersection.h +32 -0
- include/momentum/math/mesh.h +84 -0
- include/momentum/math/mppca.h +67 -0
- include/momentum/math/online_householder_qr.h +516 -0
- include/momentum/math/random-inl.h +404 -0
- include/momentum/math/random.h +310 -0
- include/momentum/math/simd_generalized_loss.h +40 -0
- include/momentum/math/transform.h +229 -0
- include/momentum/math/types.h +461 -0
- include/momentum/math/utility.h +251 -0
- include/momentum/rasterizer/camera.h +453 -0
- include/momentum/rasterizer/fwd.h +102 -0
- include/momentum/rasterizer/geometry.h +83 -0
- include/momentum/rasterizer/image.h +18 -0
- include/momentum/rasterizer/rasterizer.h +583 -0
- include/momentum/rasterizer/tensor.h +140 -0
- include/momentum/rasterizer/utility.h +268 -0
- include/momentum/simd/simd.h +221 -0
- include/momentum/solver/fwd.h +131 -0
- include/momentum/solver/gauss_newton_solver.h +136 -0
- include/momentum/solver/gradient_descent_solver.h +65 -0
- include/momentum/solver/solver.h +155 -0
- include/momentum/solver/solver_function.h +126 -0
- include/momentum/solver/subset_gauss_newton_solver.h +109 -0
- include/rerun/archetypes/annotation_context.hpp +157 -0
- include/rerun/archetypes/arrows2d.hpp +271 -0
- include/rerun/archetypes/arrows3d.hpp +257 -0
- include/rerun/archetypes/asset3d.hpp +262 -0
- include/rerun/archetypes/asset_video.hpp +275 -0
- include/rerun/archetypes/bar_chart.hpp +261 -0
- include/rerun/archetypes/boxes2d.hpp +293 -0
- include/rerun/archetypes/boxes3d.hpp +369 -0
- include/rerun/archetypes/capsules3d.hpp +333 -0
- include/rerun/archetypes/clear.hpp +180 -0
- include/rerun/archetypes/depth_image.hpp +425 -0
- include/rerun/archetypes/ellipsoids3d.hpp +384 -0
- include/rerun/archetypes/encoded_image.hpp +250 -0
- include/rerun/archetypes/geo_line_strings.hpp +166 -0
- include/rerun/archetypes/geo_points.hpp +177 -0
- include/rerun/archetypes/graph_edges.hpp +152 -0
- include/rerun/archetypes/graph_nodes.hpp +206 -0
- include/rerun/archetypes/image.hpp +434 -0
- include/rerun/archetypes/instance_poses3d.hpp +221 -0
- include/rerun/archetypes/line_strips2d.hpp +289 -0
- include/rerun/archetypes/line_strips3d.hpp +270 -0
- include/rerun/archetypes/mesh3d.hpp +387 -0
- include/rerun/archetypes/pinhole.hpp +385 -0
- include/rerun/archetypes/points2d.hpp +333 -0
- include/rerun/archetypes/points3d.hpp +369 -0
- include/rerun/archetypes/recording_properties.hpp +132 -0
- include/rerun/archetypes/scalar.hpp +170 -0
- include/rerun/archetypes/scalars.hpp +153 -0
- include/rerun/archetypes/segmentation_image.hpp +305 -0
- include/rerun/archetypes/series_line.hpp +274 -0
- include/rerun/archetypes/series_lines.hpp +271 -0
- include/rerun/archetypes/series_point.hpp +265 -0
- include/rerun/archetypes/series_points.hpp +251 -0
- include/rerun/archetypes/tensor.hpp +213 -0
- include/rerun/archetypes/text_document.hpp +200 -0
- include/rerun/archetypes/text_log.hpp +211 -0
- include/rerun/archetypes/transform3d.hpp +925 -0
- include/rerun/archetypes/video_frame_reference.hpp +295 -0
- include/rerun/archetypes/view_coordinates.hpp +393 -0
- include/rerun/archetypes.hpp +43 -0
- include/rerun/arrow_utils.hpp +32 -0
- include/rerun/as_components.hpp +90 -0
- include/rerun/blueprint/archetypes/background.hpp +113 -0
- include/rerun/blueprint/archetypes/container_blueprint.hpp +259 -0
- include/rerun/blueprint/archetypes/dataframe_query.hpp +178 -0
- include/rerun/blueprint/archetypes/entity_behavior.hpp +130 -0
- include/rerun/blueprint/archetypes/force_center.hpp +115 -0
- include/rerun/blueprint/archetypes/force_collision_radius.hpp +141 -0
- include/rerun/blueprint/archetypes/force_link.hpp +136 -0
- include/rerun/blueprint/archetypes/force_many_body.hpp +124 -0
- include/rerun/blueprint/archetypes/force_position.hpp +132 -0
- include/rerun/blueprint/archetypes/line_grid3d.hpp +178 -0
- include/rerun/blueprint/archetypes/map_background.hpp +104 -0
- include/rerun/blueprint/archetypes/map_zoom.hpp +103 -0
- include/rerun/blueprint/archetypes/near_clip_plane.hpp +109 -0
- include/rerun/blueprint/archetypes/panel_blueprint.hpp +95 -0
- include/rerun/blueprint/archetypes/plot_legend.hpp +118 -0
- include/rerun/blueprint/archetypes/scalar_axis.hpp +116 -0
- include/rerun/blueprint/archetypes/tensor_scalar_mapping.hpp +146 -0
- include/rerun/blueprint/archetypes/tensor_slice_selection.hpp +167 -0
- include/rerun/blueprint/archetypes/tensor_view_fit.hpp +95 -0
- include/rerun/blueprint/archetypes/view_blueprint.hpp +170 -0
- include/rerun/blueprint/archetypes/view_contents.hpp +142 -0
- include/rerun/blueprint/archetypes/viewport_blueprint.hpp +200 -0
- include/rerun/blueprint/archetypes/visible_time_ranges.hpp +116 -0
- include/rerun/blueprint/archetypes/visual_bounds2d.hpp +109 -0
- include/rerun/blueprint/archetypes/visualizer_overrides.hpp +113 -0
- include/rerun/blueprint/archetypes.hpp +29 -0
- include/rerun/blueprint/components/active_tab.hpp +82 -0
- include/rerun/blueprint/components/apply_latest_at.hpp +79 -0
- include/rerun/blueprint/components/auto_layout.hpp +77 -0
- include/rerun/blueprint/components/auto_views.hpp +77 -0
- include/rerun/blueprint/components/background_kind.hpp +66 -0
- include/rerun/blueprint/components/column_share.hpp +78 -0
- include/rerun/blueprint/components/component_column_selector.hpp +81 -0
- include/rerun/blueprint/components/container_kind.hpp +65 -0
- include/rerun/blueprint/components/corner2d.hpp +64 -0
- include/rerun/blueprint/components/enabled.hpp +77 -0
- include/rerun/blueprint/components/filter_by_range.hpp +74 -0
- include/rerun/blueprint/components/filter_is_not_null.hpp +77 -0
- include/rerun/blueprint/components/force_distance.hpp +82 -0
- include/rerun/blueprint/components/force_iterations.hpp +82 -0
- include/rerun/blueprint/components/force_strength.hpp +82 -0
- include/rerun/blueprint/components/grid_columns.hpp +78 -0
- include/rerun/blueprint/components/grid_spacing.hpp +78 -0
- include/rerun/blueprint/components/included_content.hpp +86 -0
- include/rerun/blueprint/components/lock_range_during_zoom.hpp +82 -0
- include/rerun/blueprint/components/map_provider.hpp +64 -0
- include/rerun/blueprint/components/near_clip_plane.hpp +82 -0
- include/rerun/blueprint/components/panel_state.hpp +61 -0
- include/rerun/blueprint/components/query_expression.hpp +89 -0
- include/rerun/blueprint/components/root_container.hpp +77 -0
- include/rerun/blueprint/components/row_share.hpp +78 -0
- include/rerun/blueprint/components/selected_columns.hpp +76 -0
- include/rerun/blueprint/components/tensor_dimension_index_slider.hpp +90 -0
- include/rerun/blueprint/components/timeline_name.hpp +76 -0
- include/rerun/blueprint/components/view_class.hpp +76 -0
- include/rerun/blueprint/components/view_fit.hpp +61 -0
- include/rerun/blueprint/components/view_maximized.hpp +79 -0
- include/rerun/blueprint/components/view_origin.hpp +81 -0
- include/rerun/blueprint/components/viewer_recommendation_hash.hpp +82 -0
- include/rerun/blueprint/components/visible_time_range.hpp +77 -0
- include/rerun/blueprint/components/visual_bounds2d.hpp +74 -0
- include/rerun/blueprint/components/visualizer_override.hpp +86 -0
- include/rerun/blueprint/components/zoom_level.hpp +78 -0
- include/rerun/blueprint/components.hpp +41 -0
- include/rerun/blueprint/datatypes/component_column_selector.hpp +61 -0
- include/rerun/blueprint/datatypes/filter_by_range.hpp +59 -0
- include/rerun/blueprint/datatypes/filter_is_not_null.hpp +61 -0
- include/rerun/blueprint/datatypes/selected_columns.hpp +62 -0
- include/rerun/blueprint/datatypes/tensor_dimension_index_slider.hpp +63 -0
- include/rerun/blueprint/datatypes.hpp +9 -0
- include/rerun/c/arrow_c_data_interface.h +111 -0
- include/rerun/c/compiler_utils.h +10 -0
- include/rerun/c/rerun.h +627 -0
- include/rerun/c/sdk_info.h +28 -0
- include/rerun/collection.hpp +496 -0
- include/rerun/collection_adapter.hpp +43 -0
- include/rerun/collection_adapter_builtins.hpp +138 -0
- include/rerun/compiler_utils.hpp +61 -0
- include/rerun/component_batch.hpp +163 -0
- include/rerun/component_column.hpp +111 -0
- include/rerun/component_descriptor.hpp +142 -0
- include/rerun/component_type.hpp +35 -0
- include/rerun/components/aggregation_policy.hpp +76 -0
- include/rerun/components/albedo_factor.hpp +74 -0
- include/rerun/components/annotation_context.hpp +102 -0
- include/rerun/components/axis_length.hpp +74 -0
- include/rerun/components/blob.hpp +73 -0
- include/rerun/components/class_id.hpp +71 -0
- include/rerun/components/clear_is_recursive.hpp +75 -0
- include/rerun/components/color.hpp +99 -0
- include/rerun/components/colormap.hpp +99 -0
- include/rerun/components/depth_meter.hpp +84 -0
- include/rerun/components/draw_order.hpp +79 -0
- include/rerun/components/entity_path.hpp +83 -0
- include/rerun/components/fill_mode.hpp +72 -0
- include/rerun/components/fill_ratio.hpp +79 -0
- include/rerun/components/gamma_correction.hpp +80 -0
- include/rerun/components/geo_line_string.hpp +63 -0
- include/rerun/components/graph_edge.hpp +75 -0
- include/rerun/components/graph_node.hpp +79 -0
- include/rerun/components/graph_type.hpp +57 -0
- include/rerun/components/half_size2d.hpp +91 -0
- include/rerun/components/half_size3d.hpp +95 -0
- include/rerun/components/image_buffer.hpp +86 -0
- include/rerun/components/image_format.hpp +84 -0
- include/rerun/components/image_plane_distance.hpp +77 -0
- include/rerun/components/interactive.hpp +76 -0
- include/rerun/components/keypoint_id.hpp +74 -0
- include/rerun/components/lat_lon.hpp +89 -0
- include/rerun/components/length.hpp +77 -0
- include/rerun/components/line_strip2d.hpp +73 -0
- include/rerun/components/line_strip3d.hpp +73 -0
- include/rerun/components/magnification_filter.hpp +63 -0
- include/rerun/components/marker_shape.hpp +82 -0
- include/rerun/components/marker_size.hpp +74 -0
- include/rerun/components/media_type.hpp +157 -0
- include/rerun/components/name.hpp +83 -0
- include/rerun/components/opacity.hpp +77 -0
- include/rerun/components/pinhole_projection.hpp +94 -0
- include/rerun/components/plane3d.hpp +75 -0
- include/rerun/components/pose_rotation_axis_angle.hpp +73 -0
- include/rerun/components/pose_rotation_quat.hpp +71 -0
- include/rerun/components/pose_scale3d.hpp +102 -0
- include/rerun/components/pose_transform_mat3x3.hpp +87 -0
- include/rerun/components/pose_translation3d.hpp +96 -0
- include/rerun/components/position2d.hpp +86 -0
- include/rerun/components/position3d.hpp +90 -0
- include/rerun/components/radius.hpp +98 -0
- include/rerun/components/range1d.hpp +75 -0
- include/rerun/components/resolution.hpp +88 -0
- include/rerun/components/rotation_axis_angle.hpp +72 -0
- include/rerun/components/rotation_quat.hpp +71 -0
- include/rerun/components/scalar.hpp +76 -0
- include/rerun/components/scale3d.hpp +102 -0
- include/rerun/components/series_visible.hpp +76 -0
- include/rerun/components/show_labels.hpp +79 -0
- include/rerun/components/stroke_width.hpp +74 -0
- include/rerun/components/tensor_data.hpp +94 -0
- include/rerun/components/tensor_dimension_index_selection.hpp +77 -0
- include/rerun/components/tensor_height_dimension.hpp +71 -0
- include/rerun/components/tensor_width_dimension.hpp +71 -0
- include/rerun/components/texcoord2d.hpp +101 -0
- include/rerun/components/text.hpp +83 -0
- include/rerun/components/text_log_level.hpp +110 -0
- include/rerun/components/timestamp.hpp +76 -0
- include/rerun/components/transform_mat3x3.hpp +92 -0
- include/rerun/components/transform_relation.hpp +66 -0
- include/rerun/components/translation3d.hpp +96 -0
- include/rerun/components/triangle_indices.hpp +85 -0
- include/rerun/components/value_range.hpp +78 -0
- include/rerun/components/vector2d.hpp +92 -0
- include/rerun/components/vector3d.hpp +96 -0
- include/rerun/components/video_timestamp.hpp +120 -0
- include/rerun/components/view_coordinates.hpp +346 -0
- include/rerun/components/visible.hpp +74 -0
- include/rerun/components.hpp +77 -0
- include/rerun/config.hpp +52 -0
- include/rerun/datatypes/angle.hpp +76 -0
- include/rerun/datatypes/annotation_info.hpp +76 -0
- include/rerun/datatypes/blob.hpp +67 -0
- include/rerun/datatypes/bool.hpp +57 -0
- include/rerun/datatypes/channel_datatype.hpp +87 -0
- include/rerun/datatypes/class_description.hpp +92 -0
- include/rerun/datatypes/class_description_map_elem.hpp +69 -0
- include/rerun/datatypes/class_id.hpp +62 -0
- include/rerun/datatypes/color_model.hpp +68 -0
- include/rerun/datatypes/dvec2d.hpp +76 -0
- include/rerun/datatypes/entity_path.hpp +60 -0
- include/rerun/datatypes/float32.hpp +62 -0
- include/rerun/datatypes/float64.hpp +62 -0
- include/rerun/datatypes/image_format.hpp +107 -0
- include/rerun/datatypes/keypoint_id.hpp +63 -0
- include/rerun/datatypes/keypoint_pair.hpp +65 -0
- include/rerun/datatypes/mat3x3.hpp +105 -0
- include/rerun/datatypes/mat4x4.hpp +119 -0
- include/rerun/datatypes/pixel_format.hpp +142 -0
- include/rerun/datatypes/plane3d.hpp +60 -0
- include/rerun/datatypes/quaternion.hpp +110 -0
- include/rerun/datatypes/range1d.hpp +59 -0
- include/rerun/datatypes/range2d.hpp +55 -0
- include/rerun/datatypes/rgba32.hpp +94 -0
- include/rerun/datatypes/rotation_axis_angle.hpp +67 -0
- include/rerun/datatypes/tensor_buffer.hpp +529 -0
- include/rerun/datatypes/tensor_data.hpp +100 -0
- include/rerun/datatypes/tensor_dimension_index_selection.hpp +58 -0
- include/rerun/datatypes/tensor_dimension_selection.hpp +56 -0
- include/rerun/datatypes/time_int.hpp +62 -0
- include/rerun/datatypes/time_range.hpp +55 -0
- include/rerun/datatypes/time_range_boundary.hpp +175 -0
- include/rerun/datatypes/uint16.hpp +62 -0
- include/rerun/datatypes/uint32.hpp +62 -0
- include/rerun/datatypes/uint64.hpp +62 -0
- include/rerun/datatypes/utf8.hpp +76 -0
- include/rerun/datatypes/utf8pair.hpp +62 -0
- include/rerun/datatypes/uuid.hpp +60 -0
- include/rerun/datatypes/uvec2d.hpp +76 -0
- include/rerun/datatypes/uvec3d.hpp +80 -0
- include/rerun/datatypes/uvec4d.hpp +59 -0
- include/rerun/datatypes/vec2d.hpp +76 -0
- include/rerun/datatypes/vec3d.hpp +80 -0
- include/rerun/datatypes/vec4d.hpp +84 -0
- include/rerun/datatypes/video_timestamp.hpp +67 -0
- include/rerun/datatypes/view_coordinates.hpp +87 -0
- include/rerun/datatypes/visible_time_range.hpp +57 -0
- include/rerun/datatypes.hpp +51 -0
- include/rerun/demo_utils.hpp +75 -0
- include/rerun/entity_path.hpp +20 -0
- include/rerun/error.hpp +180 -0
- include/rerun/half.hpp +10 -0
- include/rerun/image_utils.hpp +187 -0
- include/rerun/indicator_component.hpp +59 -0
- include/rerun/loggable.hpp +54 -0
- include/rerun/recording_stream.hpp +960 -0
- include/rerun/rerun_sdk_export.hpp +25 -0
- include/rerun/result.hpp +86 -0
- include/rerun/rotation3d.hpp +33 -0
- include/rerun/sdk_info.hpp +20 -0
- include/rerun/spawn.hpp +21 -0
- include/rerun/spawn_options.hpp +57 -0
- include/rerun/string_utils.hpp +16 -0
- include/rerun/third_party/cxxopts.hpp +2198 -0
- include/rerun/time_column.hpp +288 -0
- include/rerun/timeline.hpp +38 -0
- include/rerun/type_traits.hpp +40 -0
- include/rerun.hpp +86 -0
- lib/cmake/axel/axel-config.cmake +45 -0
- lib/cmake/axel/axelTargets-release.cmake +19 -0
- lib/cmake/axel/axelTargets.cmake +108 -0
- lib/cmake/momentum/Findre2.cmake +52 -0
- lib/cmake/momentum/momentum-config.cmake +67 -0
- lib/cmake/momentum/momentumTargets-release.cmake +259 -0
- lib/cmake/momentum/momentumTargets.cmake +377 -0
- lib/cmake/rerun_sdk/rerun_sdkConfig.cmake +70 -0
- lib/cmake/rerun_sdk/rerun_sdkConfigVersion.cmake +83 -0
- lib/cmake/rerun_sdk/rerun_sdkTargets-release.cmake +19 -0
- lib/cmake/rerun_sdk/rerun_sdkTargets.cmake +108 -0
- lib/libarrow.a +0 -0
- lib/libarrow_bundled_dependencies.a +0 -0
- lib/libaxel.a +0 -0
- lib/libmomentum_app_utils.a +0 -0
- lib/libmomentum_character.a +0 -0
- lib/libmomentum_character_sequence_solver.a +0 -0
- lib/libmomentum_character_solver.a +0 -0
- lib/libmomentum_common.a +0 -0
- lib/libmomentum_diff_ik.a +0 -0
- lib/libmomentum_io.a +0 -0
- lib/libmomentum_io_common.a +0 -0
- lib/libmomentum_io_fbx.a +0 -0
- lib/libmomentum_io_gltf.a +0 -0
- lib/libmomentum_io_legacy_json.a +0 -0
- lib/libmomentum_io_marker.a +0 -0
- lib/libmomentum_io_motion.a +0 -0
- lib/libmomentum_io_shape.a +0 -0
- lib/libmomentum_io_skeleton.a +0 -0
- lib/libmomentum_io_urdf.a +0 -0
- lib/libmomentum_marker_tracker.a +0 -0
- lib/libmomentum_math.a +0 -0
- lib/libmomentum_online_qr.a +0 -0
- lib/libmomentum_process_markers.a +0 -0
- lib/libmomentum_rerun.a +0 -0
- lib/libmomentum_simd_constraints.a +0 -0
- lib/libmomentum_simd_generalized_loss.a +0 -0
- lib/libmomentum_skeleton.a +0 -0
- lib/libmomentum_solver.a +0 -0
- lib/librerun_c__macos_arm64.a +0 -0
- lib/librerun_sdk.a +0 -0
- pymomentum/axel.cpython-313-darwin.so +0 -0
- pymomentum/backend/__init__.py +16 -0
- pymomentum/backend/skel_state_backend.py +614 -0
- pymomentum/backend/trs_backend.py +871 -0
- pymomentum/backend/utils.py +224 -0
- pymomentum/geometry.cpython-313-darwin.so +0 -0
- pymomentum/marker_tracking.cpython-313-darwin.so +0 -0
- pymomentum/quaternion.py +740 -0
- pymomentum/skel_state.py +514 -0
- pymomentum/solver.cpython-313-darwin.so +0 -0
- pymomentum/solver2.cpython-313-darwin.so +0 -0
- pymomentum/torch/character.py +856 -0
- pymomentum/torch/parameter_limits.py +494 -0
- pymomentum/torch/utility.py +20 -0
- pymomentum/trs.py +535 -0
- pymomentum_cpu-0.1.82.post0.dist-info/METADATA +121 -0
- pymomentum_cpu-0.1.82.post0.dist-info/RECORD +512 -0
- pymomentum_cpu-0.1.82.post0.dist-info/WHEEL +5 -0
- pymomentum_cpu-0.1.82.post0.dist-info/licenses/LICENSE +21 -0
pymomentum/quaternion.py
ADDED
|
@@ -0,0 +1,740 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
"""
|
|
7
|
+
Quaternion Utilities
|
|
8
|
+
====================
|
|
9
|
+
|
|
10
|
+
This module provides comprehensive utilities for working with quaternions in PyMomentum.
|
|
11
|
+
|
|
12
|
+
Quaternions are a mathematical representation of rotations in 3D space that offer several
|
|
13
|
+
advantages over other rotation representations like Euler angles or rotation matrices:
|
|
14
|
+
|
|
15
|
+
- **No gimbal lock**: Unlike Euler angles, quaternions don't suffer from singularities
|
|
16
|
+
- **Compact representation**: Only 4 components vs 9 for rotation matrices
|
|
17
|
+
- **Efficient composition**: Quaternion multiplication is faster than matrix multiplication
|
|
18
|
+
- **Smooth interpolation**: SLERP provides natural rotation interpolation
|
|
19
|
+
|
|
20
|
+
Quaternion Format
|
|
21
|
+
-----------------
|
|
22
|
+
This module uses the (x, y, z, w) format where:
|
|
23
|
+
|
|
24
|
+
- **(x, y, z)**: Vector part representing the rotation axis scaled by sin(θ/2)
|
|
25
|
+
- **w**: Scalar part representing cos(θ/2), where θ is the rotation angle
|
|
26
|
+
|
|
27
|
+
The identity quaternion is (0, 0, 0, 1), representing no rotation.
|
|
28
|
+
|
|
29
|
+
Core Operations
|
|
30
|
+
---------------
|
|
31
|
+
The module provides functions for:
|
|
32
|
+
|
|
33
|
+
- **Basic operations**: :func:`multiply`, :func:`conjugate`, :func:`inverse`, :func:`normalize`
|
|
34
|
+
- **Conversions**: :func:`from_axis_angle`, :func:`euler_xyz_to_quaternion`,
|
|
35
|
+
:func:`from_rotation_matrix`, :func:`to_rotation_matrix`
|
|
36
|
+
- **Vector operations**: :func:`rotate_vector`, :func:`from_two_vectors`
|
|
37
|
+
- **Interpolation**: :func:`slerp`, :func:`blend`
|
|
38
|
+
- **Utilities**: :func:`check`, :func:`split`, :func:`identity`
|
|
39
|
+
|
|
40
|
+
Example:
|
|
41
|
+
Basic quaternion operations::
|
|
42
|
+
|
|
43
|
+
import torch
|
|
44
|
+
from pymomentum import quaternion
|
|
45
|
+
|
|
46
|
+
# Create identity quaternion
|
|
47
|
+
q_identity = quaternion.identity()
|
|
48
|
+
|
|
49
|
+
# Create quaternion from axis-angle
|
|
50
|
+
axis_angle = torch.tensor([0.0, 0.0, 1.57]) # 90° rotation around Z
|
|
51
|
+
q_rot = quaternion.from_axis_angle(axis_angle)
|
|
52
|
+
|
|
53
|
+
# Rotate a vector
|
|
54
|
+
vector = torch.tensor([1.0, 0.0, 0.0])
|
|
55
|
+
rotated = quaternion.rotate_vector(q_rot, vector)
|
|
56
|
+
|
|
57
|
+
# Interpolate between quaternions
|
|
58
|
+
q_interp = quaternion.slerp(q_identity, q_rot, 0.5)
|
|
59
|
+
|
|
60
|
+
Note:
|
|
61
|
+
All functions expect quaternions as PyTorch tensors with the last dimension
|
|
62
|
+
having size 4, following the (x, y, z, w) format. Most functions support
|
|
63
|
+
batched operations for efficient processing of multiple quaternions.
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
from typing import Sequence, Tuple
|
|
67
|
+
|
|
68
|
+
import torch
|
|
69
|
+
|
|
70
|
+
# pyre-strict
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def check(q: torch.Tensor) -> None:
    """
    Check if a tensor represents a quaternion.

    A quaternion tensor must have at least one dimension, and its last
    dimension must have size 4 (the (x, y, z, w) components). Any leading
    dimensions are treated as batch dimensions and are not constrained.

    The check uses an explicit ``raise`` rather than an ``assert`` statement
    so that it is still enforced when Python runs with optimizations enabled
    (``python -O`` strips ``assert`` statements). ``AssertionError`` is kept
    as the raised type so existing callers that catch it keep working.

    :parameter q: A tensor representing a quaternion.
    :raises AssertionError: If ``q`` is 0-dimensional or its last dimension is not 4.
    """
    # A 0-d tensor has no last dimension; q.size(-1) would raise IndexError
    # instead of the documented quaternion-shape error.
    if q.dim() == 0 or q.size(-1) != 4:
        raise AssertionError("Quaternion should have last dimension equal to 4.")
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def split(q: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Separate a quaternion into its scalar (w) and vector (x, y, z) parts.

    Both returned tensors are slices of ``q`` that keep the trailing
    dimension: the scalar part has a last dimension of size 1 and the
    vector part has a last dimension of size 3.

    :parameter q: A tensor representing a quaternion in (x, y, z, w) order.
    :return: A ``(scalar, vector)`` tuple.
    """
    check(q)
    vector_part = q[..., :3]
    scalar_part = q[..., 3:]
    return scalar_part, vector_part
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def multiply(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """
    Compute the quaternion product ``q1 * q2`` after normalizing both inputs.

    Each input is scaled to unit length first, for numerical stability.
    When both inputs are already known to be unit quaternions, call
    :func:`multiply_assume_normalized` directly to skip that extra work.

    :param q1: Left quaternion in (x, y, z, w) order.
    :param q2: Right quaternion in (x, y, z, w) order.
    :return: The product of the normalized inputs, in (x, y, z, w) order.
    """
    unit_q1 = normalize(q1)
    unit_q2 = normalize(q2)
    return multiply_assume_normalized(unit_q1, unit_q2)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def multiply_assume_normalized(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """
    Hamilton product q1 * q2 without normalizing the operands first.

    The component formula below is the general quaternion product; "assume
    normalized" only means that no normalization is applied here. Callers that
    need a unit result must pass unit quaternions (or use :func:`multiply`).

    :param q1: Left quaternion ((x, y, z), w), expected to be normalized.
    :param q2: Right quaternion ((x, y, z), w), expected to be normalized.
    :return: The product q1*q2, shape (..., 4).
    """
    check(q1)
    check(q2)

    ax, ay, az, aw = q1.unbind(-1)
    bx, by, bz, bw = q2.unbind(-1)

    # Product components in (x, y, z, w) order.
    components = (
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
        aw * bw - ax * bx - ay * by - az * bz,
    )
    return torch.stack(components, dim=-1)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def normalize(q: torch.Tensor) -> torch.Tensor:
    """
    Scale quaternions to unit Euclidean length along the last dimension.

    This is a thin wrapper around :func:`torch.nn.functional.normalize`, which
    clamps the norm from below so all-zero inputs do not divide by zero.

    :parameter q: Quaternion tensor of shape (..., 4) in ((x, y, z), w) layout.
    :return: Tensor of the same shape with unit-norm quaternions.
    """
    check(q)
    functional = torch.nn.functional
    return functional.normalize(q, p=2.0, dim=-1)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def conjugate(q: torch.Tensor) -> torch.Tensor:
    """
    Return the quaternion conjugate: the vector part is negated and w is kept.

    :parameter q: Quaternion tensor of shape (..., 4) in ((x, y, z), w) layout.
    :return: Tensor of the same shape holding (-x, -y, -z, w).
    """
    check(q)
    w, xyz = split(q)
    negated_xyz = xyz.neg()
    return torch.cat([negated_xyz, w], dim=-1)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def inverse(q: torch.Tensor) -> torch.Tensor:
    """
    Multiplicative inverse of a quaternion, q^-1 = conj(q) / |q|^2.

    The squared norm is clamped from below at 1e-7, so quaternions with (nearly)
    zero norm do not cause a division by (almost) zero.

    :parameter q: Quaternion tensor of shape (..., 4) in ((x, y, z), w) layout.
    :return: The inverse, same shape as ``q``.
    """
    check(q)
    conj = conjugate(q)
    squared_norm = (q * q).sum(-1, keepdim=True)
    return conj / torch.clamp(squared_norm, min=1e-7)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _get_nonzero_denominator(d: torch.Tensor, eps: float) -> torch.Tensor:
|
|
172
|
+
near_zeros = torch.abs(d) < eps
|
|
173
|
+
d = d * (near_zeros.logical_not())
|
|
174
|
+
d = d + torch.sign(d) * (near_zeros * eps)
|
|
175
|
+
return d
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def quaternion_to_xyz_euler(q: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """
    Convert quaternions to XYZ Euler angles [rx, ry, rz].

    The input is normalized first. rx and rz come from ``atan2`` and ry from
    ``asin``. The angle order and conventions are those taken by
    :func:`euler_xyz_to_quaternion` ([roll, pitch, yaw]).

    :parameter q: Quaternion tensor of shape (..., 4) in ((x, y, z), w) layout.
    :param eps: a small number to avoid calling asin(1) or asin(-1); it is also the
        minimum magnitude used for the ``atan2`` denominators.
        Should not be smaller than 1e-6 as this can cause NaN gradients for some models.
    :return: Tensor of shape (..., 3) holding [rx, ry, rz] in radians.
    """
    check(q)
    unit_q = normalize(q)
    x, y, z, w = unit_q.unbind(-1)

    # Rotation about X.
    rx_denominator = _get_nonzero_denominator(
        1 - 2 * (torch.square(x) + torch.square(y)), eps=eps
    )
    rx = torch.atan2(2 * (w * x + y * z), rx_denominator)

    # Rotation about Y. The asin argument is kept strictly inside (-1, 1).
    sin_ry = torch.clamp(2 * (w * y - z * x), -1 + eps, 1 - eps)
    ry = torch.asin(sin_ry)

    # Rotation about Z.
    rz_denominator = _get_nonzero_denominator(
        1 - 2 * (torch.square(y) + torch.square(z)), eps=eps
    )
    rz = torch.atan2(2 * (w * z + x * y), rz_denominator)

    return torch.stack([rx, ry, rz], -1)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def rotate_vector(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Rotate vectors by quaternions.

    The quaternion is normalized before the rotation is applied. When the
    quaternions are guaranteed to be unit length, call
    :func:`rotate_vector_assume_normalized` to skip that step.

    :param q: (nBatch x k x 4) tensor with the quaternions in ((x, y, z), w) format.
    :param v: (nBatch x k x 3) vector.
    :return: (nBatch x k x 3) rotated vectors.
    """
    unit_q = normalize(q)
    return rotate_vector_assume_normalized(unit_q, v)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def rotate_vector_assume_normalized(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Rotate vectors by quaternions that are already unit length.

    Evaluates v' = v + 2 * (w * (u x v) + u x (u x v)), where u is the vector part
    and w the scalar part of ``q``. This is a proper rotation only for unit
    quaternions; no normalization is done here (see :func:`rotate_vector`).

    :param q: (nBatch x k x 4) tensor with normalized quaternions in ((x, y, z), w) format.
    :param v: (nBatch x k x 3) vector.
    :return: (nBatch x k x 3) rotated vectors.
    """
    check(q)
    w, u = split(q)
    u_cross_v = torch.cross(u, v, -1)
    u_cross_u_cross_v = torch.cross(u, u_cross_v, -1)
    return v + 2 * (u_cross_v * w + u_cross_u_cross_v)
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def to_rotation_matrix_assume_normalized(q: torch.Tensor) -> torch.Tensor:
    """
    Convert unit quaternions to 3x3 rotation matrices without normalizing them.

    Any leading dimensions are kept. The output has shape ``q.shape[:-1] + (3, 3)``.

    :parameter q: (..., 4) tensor with normalized quaternions in ((x, y, z), w) format.
    :return: (..., 3, 3) tensor with 3x3 rotation matrices.
    """
    check(q)
    # Each component keeps a trailing singleton dimension, so the entries can be concatenated.
    x, y, z, w = q.split(1, dim=-1)

    xx = torch.square(x)
    yy = torch.square(y)
    zz = torch.square(z)
    xy = x * y
    xz = x * z
    xw = x * w
    yz = y * z
    yw = y * w
    zw = z * w
    one = torch.ones_like(x)

    entries = [
        # first row
        one - 2 * (yy + zz),
        2 * (xy - zw),
        2 * (xz + yw),
        # second row
        2 * (xy + zw),
        one - 2 * (xx + zz),
        2 * (yz - xw),
        # third row
        2 * (xz - yw),
        2 * (yz + xw),
        one - 2 * (xx + yy),
    ]
    flat = torch.cat(entries, -1)
    return flat.reshape(list(q.shape[:-1]) + [3, 3])
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def to_rotation_matrix(q: torch.Tensor) -> torch.Tensor:
    """
    Convert quaternions to 3x3 rotation matrices, normalizing them first.

    :parameter q: (nBatch x k x 4) tensor with the quaternions in ((x, y, z), w) format.
    :return: (nBatch x k x 3 x 3) tensor with 3x3 rotation matrices.
    """
    unit_q = normalize(q)
    return to_rotation_matrix_assume_normalized(unit_q)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def identity(
|
|
284
|
+
size: Sequence[int] | None = None,
|
|
285
|
+
device: torch.device | None = None,
|
|
286
|
+
dtype: torch.dtype = torch.float32,
|
|
287
|
+
) -> torch.Tensor:
|
|
288
|
+
"""
|
|
289
|
+
Create a quaternion identity tensor.
|
|
290
|
+
|
|
291
|
+
:parameter sizes: A tuple of integers representing the size of the quaternion tensor.
|
|
292
|
+
:parameter device: The device on which to create the tensor.
|
|
293
|
+
:return: A quaternion identity tensor with the specified sizes and device.
|
|
294
|
+
"""
|
|
295
|
+
size = size or ()
|
|
296
|
+
return torch.cat(
|
|
297
|
+
[
|
|
298
|
+
torch.zeros(*size, 3, device=device, dtype=dtype),
|
|
299
|
+
torch.ones(*size, 1, device=device, dtype=dtype),
|
|
300
|
+
],
|
|
301
|
+
dim=-1,
|
|
302
|
+
)
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def from_axis_angle(axis_angle: torch.Tensor) -> torch.Tensor:
    """
    Convert an axis-angle tensor to a quaternion.

    The rotation angle is the Euclidean norm of ``axis_angle`` and the axis is its
    direction. The vector part is ``axis_angle * sin(angle / 2) / angle``, computed
    with ``torch.sinc``. This keeps it well defined at ``angle == 0`` and gives
    the correct (non-zero) gradient there.

    :parameter axis_angle: A tensor of shape (..., 3) representing the axis-angle.
    :return: A tensor of shape (..., 4) representing the quaternion in ((x, y, z), w) format.
    """
    import math

    angles = axis_angle.norm(dim=-1, keepdim=True)
    # sin(angle / 2) / angle == 0.5 * sinc(angle / (2 * pi)), where torch.sinc(x) = sin(pi x) / (pi x).
    # Dividing by a clamped angle instead zeroes the whole Jacobian at the identity rotation.
    sin_half_over_angle = 0.5 * torch.sinc(angles / (2.0 * math.pi))
    cos_half_angles = torch.cos(angles / 2)

    return torch.cat([axis_angle * sin_half_over_angle, cos_half_angles], dim=-1)
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def euler_xyz_to_quaternion(euler_xyz: torch.Tensor) -> torch.Tensor:
    """
    Convert XYZ Euler angles to quaternions.

    The rotation order is X-Y-Z: rotate about the X axis first, then about Y,
    then about Z.

    :parameter euler_xyz: A tensor of shape (..., 3) representing the Euler XYZ angles
        in order [roll, pitch, yaw].
    :return: A tensor of shape (..., 4) representing the quaternion in ((x, y, z), w) format.
    """
    half_angles = euler_xyz * 0.5
    cos_half = torch.cos(half_angles)
    sin_half = torch.sin(half_angles)
    cr, cp, cy = cos_half.unbind(-1)
    sr, sp, sy = sin_half.unbind(-1)

    qx = sr * cp * cy - cr * sp * sy
    qy = cr * sp * cy + sr * cp * sy
    qz = cr * cp * sy - sr * sp * cy
    qw = cr * cp * cy + sr * sp * sy

    return torch.stack((qx, qy, qz, qw), dim=-1)
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def euler_zyx_to_quaternion(euler_zyx: torch.Tensor) -> torch.Tensor:
    """
    Convert ZYX Euler angles (yaw-pitch-roll convention) to quaternions.

    The rotation order is Z-Y-X: rotate about the Z axis (yaw) first, then about
    Y (pitch), then about X (roll).

    :parameter euler_zyx: A tensor of shape (..., 3) representing the Euler ZYX angles
        in order [yaw, pitch, roll].
    :return: A tensor of shape (..., 4) representing the quaternion in ((x, y, z), w) format.
    """
    half_angles = euler_zyx * 0.5
    cos_half = torch.cos(half_angles)
    sin_half = torch.sin(half_angles)
    cy, cp, cr = cos_half.unbind(-1)
    sy, sp, sr = sin_half.unbind(-1)

    # Quaternion components for the ZYX convention.
    qx = sr * cp * cy + cr * sp * sy
    qy = cr * sp * cy - sr * cp * sy
    qz = cr * cp * sy + sr * sp * cy
    qw = cr * cp * cy - sr * sp * sy

    return torch.stack((qx, qy, qz, qw), dim=-1)
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def from_rotation_matrix(matrices: torch.Tensor, eta: float = 1e-6) -> torch.Tensor:
    """
    Convert rotation matrices to quaternions with a numerically stable scheme.

    Four candidate quaternions are formed, one scaled by each of |w|, |x|, |y|, |z|
    (estimated from the matrix diagonal). The candidate with the largest such
    magnitude is selected, which avoids dividing by a small number. The result is
    normalized.

    :parameter matrices: A tensor of shape (..., 3, 3) representing the rotation matrices.
    :parameter eta: Numerical precision threshold (unused, kept for compatibility).
    :return: A tensor of shape (..., 4) representing the quaternions in ((x, y, z), w) format.
    """
    m00, m01, m02 = matrices[..., 0, 0], matrices[..., 0, 1], matrices[..., 0, 2]
    m10, m11, m12 = matrices[..., 1, 0], matrices[..., 1, 1], matrices[..., 1, 2]
    m20, m21, m22 = matrices[..., 2, 0], matrices[..., 2, 1], matrices[..., 2, 2]

    # For a rotation matrix these are 4*w^2, 4*x^2, 4*y^2, 4*z^2. The clamp keeps sqrt finite.
    diagonal_terms = torch.stack(
        [
            1.0 + m00 + m11 + m22,  # w component
            1.0 + m00 - m11 - m22,  # x component
            1.0 - m00 + m11 - m22,  # y component
            1.0 - m00 - m11 + m22,  # z component
        ],
        dim=-1,
    )
    q_abs = torch.sqrt(torch.clamp(diagonal_terms, min=1e-15))
    abs_w, abs_x, abs_y, abs_z = q_abs.unbind(-1)

    # Row i is the (x, y, z, w) quaternion multiplied by 2*|component_i|.
    scaled_rows = torch.stack(
        [
            torch.stack([m21 - m12, m02 - m20, m10 - m01, torch.square(abs_w)], dim=-1),
            torch.stack([torch.square(abs_x), m10 + m01, m02 + m20, m21 - m12], dim=-1),
            torch.stack([m10 + m01, torch.square(abs_y), m12 + m21, m02 - m20], dim=-1),
            torch.stack([m20 + m02, m21 + m12, torch.square(abs_z), m10 - m01], dim=-1),
        ],
        dim=-2,
    )

    # The 0.01 floor only prevents division by zero; a candidate with a small
    # magnitude is never the one selected below.
    candidates = scaled_rows / (2.0 * torch.clamp(q_abs[..., None], min=0.01))

    # Select the best-conditioned candidate (largest magnitude; earliest wins ties).
    aw = abs_w[..., None]
    ax = abs_x[..., None]
    ay = abs_y[..., None]
    az = abs_z[..., None]
    use_x = ax > aw
    use_y = (ay > aw) & (ay > ax)
    use_z = (az > aw) & (az > ax) & (az > ay)

    best = candidates[..., 0, :]
    best = torch.where(use_x, candidates[..., 1, :], best)
    best = torch.where(use_y, candidates[..., 2, :], best)
    best = torch.where(use_z, candidates[..., 3, :], best)
    return normalize(best)
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
def check_and_normalize_weights(
|
|
468
|
+
quaternions: torch.Tensor, weights_in: torch.Tensor | None = None
|
|
469
|
+
) -> torch.Tensor:
|
|
470
|
+
"""
|
|
471
|
+
Check and normalize the weights for blending quaternions.
|
|
472
|
+
|
|
473
|
+
:parameter quaternions: A tensor of shape (..., k, 4) representing the quaternions to blend.
|
|
474
|
+
:parameter weights_in: An optional tensor of shape (..., k) representing the weights for each quaternion.
|
|
475
|
+
If not provided, all quaternions will be weighted equally.
|
|
476
|
+
:return: A tensor of shape (..., k) representing the normalized weights.
|
|
477
|
+
"""
|
|
478
|
+
if weights_in is not None:
|
|
479
|
+
weights = weights_in
|
|
480
|
+
else:
|
|
481
|
+
weights = torch.ones_like(quaternions.select(-1, 0))
|
|
482
|
+
|
|
483
|
+
if weights.dim() == quaternions.dim():
|
|
484
|
+
weights = weights.squeeze(-1)
|
|
485
|
+
|
|
486
|
+
if weights.dim() + 1 != quaternions.dim():
|
|
487
|
+
raise ValueError(
|
|
488
|
+
f"Expected weights vector to match quaternion vector in all dimensions except the last; "
|
|
489
|
+
f"got weights={weights.size()} and quaternions={quaternions.size()}"
|
|
490
|
+
)
|
|
491
|
+
|
|
492
|
+
for i in range(weights.dim()):
|
|
493
|
+
if weights.size(i) != quaternions.size(i):
|
|
494
|
+
raise ValueError(
|
|
495
|
+
f"Expected weights vector to match quaternion vector in all dimensions except the last; "
|
|
496
|
+
f"got weights={weights.size()} and quaternions={quaternions.size()}"
|
|
497
|
+
)
|
|
498
|
+
|
|
499
|
+
# Normalize the weights
|
|
500
|
+
weights = weights.clamp(min=0)
|
|
501
|
+
weight_sum = weights.sum(dim=-1, keepdim=True)
|
|
502
|
+
return weights / weight_sum.expand_as(weights)
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
def blend(
    quaternions: torch.Tensor, weights_in: torch.Tensor | None = None
) -> torch.Tensor:
    """
    Compute a weighted average of quaternions with the eigenvector method from
    https://stackoverflow.com/questions/12374087/average-of-multiple-quaternions
    and http://www.acsu.buffalo.edu/~johnc/ave_quat07.pdf.

    The weighted sum of outer products q q^T is formed, and its eigenvector with
    the largest eigenvalue is returned. The sign of the result is arbitrary, since
    q and -q encode the same rotation.

    :parameter quaternions: A tensor of shape (..., k, 4) representing the quaternions to blend.
    :parameter weights_in: An optional tensor of shape (..., k) representing the weights for each quaternion.
        If not provided, all quaternions will be weighted equally.
    :return: A tensor of shape (..., 4) representing the blended quaternion.
    """
    weights = check_and_normalize_weights(quaternions, weights_in)
    check(quaternions)

    # Outer products q q^T for each of the k quaternions: (..., k, 4, 4).
    outer = quaternions.unsqueeze(-1) * quaternions.unsqueeze(-2)
    accumulated = (weights[..., None, None] * outer).sum(dim=-3)

    # eigh returns eigenvalues in ascending order, so the last column is the
    # eigenvector of the largest eigenvalue.
    _, eigvecs = torch.linalg.eigh(accumulated)
    return eigvecs[..., 3]
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
def slerp(q0: torch.Tensor, q1: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """
    Perform spherical linear interpolation (slerp) between two quaternions.

    The dot product of the inputs is used as cos(theta), which assumes unit-length
    quaternions. Nearly identical inputs (cos(theta) > 0.9995) use normalized
    linear interpolation instead.

    :parameter q0: The starting quaternion, shape (..., 4).
    :parameter q1: The ending quaternion, shape (..., 4).
    :parameter t: The interpolation parameter, where 0 <= t <= 1. t=0 corresponds to q0,
        t=1 corresponds to q1. It must broadcast against (..., 1).
    :return: The interpolated unit quaternion.
    """
    check(q0)
    check(q1)

    dot = torch.einsum("...x,...x", q0, q1)[..., None]
    dot = torch.clamp(dot, -1.0, 1.0)

    # q and -q encode the same rotation. A negative dot product means q1 lies in
    # the opposite hemisphere; flip it so the interpolation takes the shorter arc.
    q1 = torch.where(dot < 0, -q1, q1)
    dot = torch.abs(dot)

    # Normalized linear interpolation, selected for nearly identical inputs.
    nlerp_result = normalize(q0 + t * (q1 - q0))

    # Spherical interpolation. The upper clamp keeps sin(theta) away from zero,
    # so this branch stays finite even where torch.where discards it.
    eps = 1e-4
    theta = torch.acos(torch.clamp(dot, 0, 1.0 - eps))
    inv_sin_theta = torch.reciprocal(torch.sin(theta))
    weight0 = torch.sin((1 - t) * theta) * inv_sin_theta
    weight1 = torch.sin(t * theta) * inv_sin_theta
    spherical_result = normalize(weight0 * q0 + weight1 * q1)

    return torch.where(dot > 0.9995, nlerp_result, spherical_result)
|
|
565
|
+
|
|
566
|
+
|
|
567
|
+
def from_two_vectors(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor:
    """
    Construct the shortest-arc quaternion that rotates one vector into another.

    Both inputs are normalized, then q = normalize((v1 x v2, 1 + v1 . v2)).
    When v1 and v2 are numerically opposite, this construction collapses towards
    the zero quaternion. In that case the vector part is replaced by a unit axis
    perpendicular to v1, giving a 180-degree rotation.

    :parameter v1: The initial vector, shape (..., 3); need not be unit length.
    :parameter v2: The target vector, shape (..., 3); need not be unit length.
    :return: A unit quaternion of shape (..., 4) in ((x, y, z), w) format
        representing the rotation from v1 to v2.
    """
    functional = torch.nn.functional
    v1 = functional.normalize(v1, dim=-1)
    v2 = functional.normalize(v2, dim=-1)

    scalar = torch.sum(v1 * v2, dim=-1, keepdim=True) + 1
    vec = torch.cross(v1, v2, dim=-1)

    # |(vec, scalar)|^2 = 2 * (1 + cos(angle)), which vanishes for opposite vectors.
    # Testing the computed squared norm, rather than `scalar <= 0`, also catches
    # exactly opposite inputs whose dot product rounds slightly above -1. Those
    # would otherwise normalize to the identity instead of a half-turn.
    norm_sq = torch.sum(vec * vec, dim=-1, keepdim=True) + scalar * scalar
    degenerate = norm_sq <= torch.finfo(norm_sq.dtype).eps

    # An axis perpendicular to v1: cross v1 with the coordinate axis it is least aligned with.
    least_aligned = v1.abs().argmin(dim=-1, keepdim=True)
    basis = torch.zeros_like(v1).scatter(-1, least_aligned, 1.0)
    axis = functional.normalize(torch.cross(v1, basis, dim=-1), dim=-1)

    vec = torch.where(degenerate, axis, vec)
    return functional.normalize(torch.cat((vec, scalar), dim=-1), dim=-1)
|
|
590
|
+
|
|
591
|
+
|
|
592
|
+
def normalize_backprop(
    q: torch.Tensor, grad: torch.Tensor, eps: float = 1e-12
) -> torch.Tensor:
    """
    Custom backpropagation for quaternion normalization q -> q / |q|.

    Returns the vector-Jacobian product (grad - q (q . grad) / |q|^2) / |q|.
    :func:`torch.nn.functional.normalize` (used by :func:`normalize`) divides by
    ``max(|q|, eps)``. Below ``eps`` the forward map is therefore q / eps and its
    gradient is grad / eps; this is returned instead of dividing by a (near-)zero
    norm, which would give NaN/inf.

    :param q: The input quaternion tensor of shape (..., 4).
    :param grad: The gradient from the output of shape (..., 4).
    :param eps: Lower bound on the norm; the default matches the default ``eps`` of
        :func:`torch.nn.functional.normalize`.
    :return: The gradient with respect to the input quaternion q.
    """
    with torch.no_grad():
        s = torch.linalg.norm(q, dim=-1, keepdim=True)
        # Clamping only affects entries that the final torch.where replaces anyway.
        s_safe = s.clamp_min(eps)
        g = s_safe * s_safe * grad - q * (torch.sum(q * grad, dim=-1, keepdim=True))
        g = g / (s_safe * s_safe * s_safe)
        return torch.where(s < eps, grad / eps, g)
|
|
609
|
+
|
|
610
|
+
|
|
611
|
+
def rotate_vector_backprop(
    q: torch.Tensor, v: torch.Tensor, grad: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Custom backpropagation for quaternion vector rotation (Euler-Rodrigues formula).

    The quaternion is normalized, the gradient of the unit-quaternion rotation is
    taken from :func:`rotate_vector_backprop_assume_normalized`, and the quaternion
    part is then pulled back through the normalization with :func:`normalize_backprop`.
    When the quaternions are guaranteed to be normalized, call
    :func:`rotate_vector_backprop_assume_normalized` directly.

    :param q: The quaternion tensor of shape (..., 4).
    :param v: The vector tensor of shape (..., 3).
    :param grad: The gradient from the output of shape (..., 3).
    :return: A tuple of (grad_q, grad_v) representing gradients with respect
        to the quaternion and vector respectively.
    """
    unit_q = normalize(q)
    grad_unit_q, grad_v = rotate_vector_backprop_assume_normalized(unit_q, v, grad)
    # Chain rule through q -> q / |q|.
    grad_q = normalize_backprop(q, grad_unit_q)
    return grad_q, grad_v
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
def rotate_vector_backprop_assume_normalized(
    q: torch.Tensor, v: torch.Tensor, grad: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Custom backpropagation for v' = v + 2 * (w * (u x v) + u x (u x v)) with unit q.

    Here u = q[..., :3] and w = q[..., 3]. The quaternion gradient is projected
    onto the tangent space of the unit sphere at ``q``, i.e. its component along
    ``q`` is removed. The input is not normalized here; use
    :func:`rotate_vector_backprop` for general quaternions.

    :param q: The normalized quaternion tensor of shape (..., 4).
    :param v: The vector tensor of shape (..., 3).
    :param grad: The gradient from the output of shape (..., 3).
    :return: A tuple of (grad_q, grad_v) representing gradients with respect
        to the quaternion and vector respectively.
    """
    with torch.no_grad():
        u = q[..., :3]  # vector part
        w = q[..., 3:]  # scalar part, kept as (..., 1) for broadcasting

        u_x_v = torch.cross(u, v, dim=-1)
        u_x_g = torch.cross(u, grad, dim=-1)
        u_x_u_x_g = torch.cross(u, u_x_g, dim=-1)
        g_x_v = torch.cross(grad, v, dim=-1)

        u_dot_v = (u * v).sum(dim=-1, keepdim=True)
        u_dot_g = (u * grad).sum(dim=-1, keepdim=True)
        v_dot_g = (v * grad).sum(dim=-1, keepdim=True)
        uxv_dot_g = (u_x_v * grad).sum(dim=-1, keepdim=True)

        grad_v = grad - 2 * w * u_x_g + 2 * u_x_u_x_g
        grad_w = 2 * uxv_dot_g
        grad_u = -2 * g_x_v * w + 2 * (u_dot_v * grad + v * u_dot_g - 2 * u * v_dot_g)

        grad_q = torch.cat([grad_u, grad_w], dim=-1)
        # Remove the radial component so the gradient is tangent to the unit sphere at q.
        radial = torch.sum(q * grad_q, dim=-1, keepdim=True)
        grad_q = grad_q - q * radial

    return grad_q, grad_v
|
|
682
|
+
|
|
683
|
+
|
|
684
|
+
def multiply_backprop(
    q1: torch.Tensor, q2: torch.Tensor, grad_q: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Custom backpropagation for quaternion multiplication with input normalization.

    Both inputs are normalized. The gradients of the unit-quaternion product come
    from :func:`multiply_backprop_assume_normalized`, and each is then pulled back
    through its normalization with :func:`normalize_backprop`. When the inputs are
    guaranteed to be normalized, call :func:`multiply_backprop_assume_normalized`
    directly.

    :param q1: The first quaternion tensor of shape (..., 4).
    :param q2: The second quaternion tensor of shape (..., 4).
    :param grad_q: The gradient from the output of shape (..., 4).
    :return: A tuple of (grad_q1, grad_q2) representing gradients with respect
        to the first and second quaternions respectively.
    """
    unit_q1 = normalize(q1)
    unit_q2 = normalize(q2)
    grad_unit_q1, grad_unit_q2 = multiply_backprop_assume_normalized(
        unit_q1, unit_q2, grad_q
    )
    # Chain rule through each q -> q / |q|.
    grad_q1 = normalize_backprop(q1, grad_unit_q1)
    grad_q2 = normalize_backprop(q2, grad_unit_q2)
    return grad_q1, grad_q2
|
|
712
|
+
|
|
713
|
+
|
|
714
|
+
def multiply_backprop_assume_normalized(
    q1: torch.Tensor, q2: torch.Tensor, grad_q: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Custom backpropagation for the product q1 * q2 of unit quaternions.

    The raw gradients are grad_q * conj(q2) for q1 and conj(q1) * grad_q for q2.
    Each is then projected onto the tangent space of the unit sphere at its input
    by subtracting ``q_i * <q1 * q2, grad_q>``. Only use this when both inputs are
    unit quaternions; otherwise call :func:`multiply_backprop`.

    :param q1: The first normalized quaternion tensor of shape (..., 4).
    :param q2: The second normalized quaternion tensor of shape (..., 4).
    :param grad_q: The gradient from the output of shape (..., 4).
    :return: A tuple of (grad_q1, grad_q2) representing gradients with respect
        to the first and second quaternions respectively.
    """
    with torch.no_grad():
        # The adjoint of right-multiplication by q2 is right-multiplication by conj(q2).
        grad_left = multiply_assume_normalized(grad_q, conjugate(q2))
        # The adjoint of left-multiplication by q1 is left-multiplication by conj(q1).
        grad_right = multiply_assume_normalized(conjugate(q1), grad_q)

        product = multiply_assume_normalized(q1, q2)
        radial = torch.sum(product * grad_q, dim=-1, keepdim=True)
        grad_left = grad_left - q1 * radial
        grad_right = grad_right - q2 * radial

    return grad_left, grad_right
|