mediapipe-nightly 0.10.21.post20241223__cp39-cp39-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mediapipe/__init__.py +26 -0
- mediapipe/calculators/__init__.py +0 -0
- mediapipe/calculators/audio/__init__.py +0 -0
- mediapipe/calculators/audio/mfcc_mel_calculators_pb2.py +33 -0
- mediapipe/calculators/audio/rational_factor_resample_calculator_pb2.py +33 -0
- mediapipe/calculators/audio/spectrogram_calculator_pb2.py +37 -0
- mediapipe/calculators/audio/stabilized_log_calculator_pb2.py +31 -0
- mediapipe/calculators/audio/time_series_framer_calculator_pb2.py +33 -0
- mediapipe/calculators/core/__init__.py +0 -0
- mediapipe/calculators/core/bypass_calculator_pb2.py +31 -0
- mediapipe/calculators/core/clip_vector_size_calculator_pb2.py +31 -0
- mediapipe/calculators/core/concatenate_vector_calculator_pb2.py +31 -0
- mediapipe/calculators/core/constant_side_packet_calculator_pb2.py +39 -0
- mediapipe/calculators/core/dequantize_byte_array_calculator_pb2.py +31 -0
- mediapipe/calculators/core/flow_limiter_calculator_pb2.py +32 -0
- mediapipe/calculators/core/gate_calculator_pb2.py +33 -0
- mediapipe/calculators/core/get_vector_item_calculator_pb2.py +31 -0
- mediapipe/calculators/core/graph_profile_calculator_pb2.py +31 -0
- mediapipe/calculators/core/packet_cloner_calculator_pb2.py +31 -0
- mediapipe/calculators/core/packet_resampler_calculator_pb2.py +33 -0
- mediapipe/calculators/core/packet_thinner_calculator_pb2.py +33 -0
- mediapipe/calculators/core/quantize_float_vector_calculator_pb2.py +31 -0
- mediapipe/calculators/core/sequence_shift_calculator_pb2.py +31 -0
- mediapipe/calculators/core/split_vector_calculator_pb2.py +33 -0
- mediapipe/calculators/image/__init__.py +0 -0
- mediapipe/calculators/image/bilateral_filter_calculator_pb2.py +31 -0
- mediapipe/calculators/image/feature_detector_calculator_pb2.py +31 -0
- mediapipe/calculators/image/image_clone_calculator_pb2.py +31 -0
- mediapipe/calculators/image/image_cropping_calculator_pb2.py +33 -0
- mediapipe/calculators/image/image_transformation_calculator_pb2.py +38 -0
- mediapipe/calculators/image/mask_overlay_calculator_pb2.py +33 -0
- mediapipe/calculators/image/opencv_encoded_image_to_image_frame_calculator_pb2.py +31 -0
- mediapipe/calculators/image/opencv_image_encoder_calculator_pb2.py +35 -0
- mediapipe/calculators/image/recolor_calculator_pb2.py +34 -0
- mediapipe/calculators/image/rotation_mode_pb2.py +29 -0
- mediapipe/calculators/image/scale_image_calculator_pb2.py +34 -0
- mediapipe/calculators/image/segmentation_smoothing_calculator_pb2.py +31 -0
- mediapipe/calculators/image/set_alpha_calculator_pb2.py +31 -0
- mediapipe/calculators/image/warp_affine_calculator_pb2.py +36 -0
- mediapipe/calculators/internal/__init__.py +0 -0
- mediapipe/calculators/internal/callback_packet_calculator_pb2.py +33 -0
- mediapipe/calculators/tensor/__init__.py +0 -0
- mediapipe/calculators/tensor/audio_to_tensor_calculator_pb2.py +35 -0
- mediapipe/calculators/tensor/bert_preprocessor_calculator_pb2.py +31 -0
- mediapipe/calculators/tensor/feedback_tensors_calculator_pb2.py +37 -0
- mediapipe/calculators/tensor/image_to_tensor_calculator_pb2.py +40 -0
- mediapipe/calculators/tensor/inference_calculator_pb2.py +63 -0
- mediapipe/calculators/tensor/landmarks_to_tensor_calculator_pb2.py +33 -0
- mediapipe/calculators/tensor/regex_preprocessor_calculator_pb2.py +31 -0
- mediapipe/calculators/tensor/tensor_converter_calculator_pb2.py +34 -0
- mediapipe/calculators/tensor/tensor_to_joints_calculator_pb2.py +31 -0
- mediapipe/calculators/tensor/tensors_readback_calculator_pb2.py +35 -0
- mediapipe/calculators/tensor/tensors_to_audio_calculator_pb2.py +33 -0
- mediapipe/calculators/tensor/tensors_to_classification_calculator_pb2.py +44 -0
- mediapipe/calculators/tensor/tensors_to_detections_calculator_pb2.py +39 -0
- mediapipe/calculators/tensor/tensors_to_floats_calculator_pb2.py +33 -0
- mediapipe/calculators/tensor/tensors_to_landmarks_calculator_pb2.py +33 -0
- mediapipe/calculators/tensor/tensors_to_segmentation_calculator_pb2.py +34 -0
- mediapipe/calculators/tensor/vector_to_tensor_calculator_pb2.py +27 -0
- mediapipe/calculators/tflite/__init__.py +0 -0
- mediapipe/calculators/tflite/ssd_anchors_calculator_pb2.py +32 -0
- mediapipe/calculators/tflite/tflite_converter_calculator_pb2.py +33 -0
- mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator_pb2.py +31 -0
- mediapipe/calculators/tflite/tflite_inference_calculator_pb2.py +49 -0
- mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator_pb2.py +31 -0
- mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator_pb2.py +31 -0
- mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator_pb2.py +33 -0
- mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator_pb2.py +31 -0
- mediapipe/calculators/util/__init__.py +0 -0
- mediapipe/calculators/util/align_hand_to_pose_in_world_calculator_pb2.py +31 -0
- mediapipe/calculators/util/annotation_overlay_calculator_pb2.py +32 -0
- mediapipe/calculators/util/association_calculator_pb2.py +31 -0
- mediapipe/calculators/util/collection_has_min_size_calculator_pb2.py +31 -0
- mediapipe/calculators/util/combine_joints_calculator_pb2.py +36 -0
- mediapipe/calculators/util/detection_label_id_to_text_calculator_pb2.py +36 -0
- mediapipe/calculators/util/detections_to_rects_calculator_pb2.py +33 -0
- mediapipe/calculators/util/detections_to_render_data_calculator_pb2.py +33 -0
- mediapipe/calculators/util/face_to_rect_calculator_pb2.py +26 -0
- mediapipe/calculators/util/filter_detections_calculator_pb2.py +31 -0
- mediapipe/calculators/util/flat_color_image_calculator_pb2.py +32 -0
- mediapipe/calculators/util/labels_to_render_data_calculator_pb2.py +34 -0
- mediapipe/calculators/util/landmark_projection_calculator_pb2.py +31 -0
- mediapipe/calculators/util/landmarks_refinement_calculator_pb2.py +41 -0
- mediapipe/calculators/util/landmarks_smoothing_calculator_pb2.py +33 -0
- mediapipe/calculators/util/landmarks_to_detection_calculator_pb2.py +31 -0
- mediapipe/calculators/util/landmarks_to_floats_calculator_pb2.py +31 -0
- mediapipe/calculators/util/landmarks_to_render_data_calculator_pb2.py +32 -0
- mediapipe/calculators/util/landmarks_transformation_calculator_pb2.py +37 -0
- mediapipe/calculators/util/latency_pb2.py +26 -0
- mediapipe/calculators/util/local_file_contents_calculator_pb2.py +31 -0
- mediapipe/calculators/util/logic_calculator_pb2.py +34 -0
- mediapipe/calculators/util/non_max_suppression_calculator_pb2.py +35 -0
- mediapipe/calculators/util/packet_frequency_calculator_pb2.py +31 -0
- mediapipe/calculators/util/packet_frequency_pb2.py +26 -0
- mediapipe/calculators/util/packet_latency_calculator_pb2.py +31 -0
- mediapipe/calculators/util/rect_to_render_data_calculator_pb2.py +32 -0
- mediapipe/calculators/util/rect_to_render_scale_calculator_pb2.py +31 -0
- mediapipe/calculators/util/rect_transformation_calculator_pb2.py +31 -0
- mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator_pb2.py +31 -0
- mediapipe/calculators/util/resource_provider_calculator_pb2.py +28 -0
- mediapipe/calculators/util/set_joints_visibility_calculator_pb2.py +41 -0
- mediapipe/calculators/util/thresholding_calculator_pb2.py +31 -0
- mediapipe/calculators/util/timed_box_list_id_to_label_calculator_pb2.py +31 -0
- mediapipe/calculators/util/timed_box_list_to_render_data_calculator_pb2.py +32 -0
- mediapipe/calculators/util/top_k_scores_calculator_pb2.py +31 -0
- mediapipe/calculators/util/visibility_copy_calculator_pb2.py +27 -0
- mediapipe/calculators/util/visibility_smoothing_calculator_pb2.py +31 -0
- mediapipe/calculators/video/__init__.py +0 -0
- mediapipe/calculators/video/box_detector_calculator_pb2.py +32 -0
- mediapipe/calculators/video/box_tracker_calculator_pb2.py +32 -0
- mediapipe/calculators/video/flow_packager_calculator_pb2.py +32 -0
- mediapipe/calculators/video/flow_to_image_calculator_pb2.py +31 -0
- mediapipe/calculators/video/motion_analysis_calculator_pb2.py +42 -0
- mediapipe/calculators/video/opencv_video_encoder_calculator_pb2.py +31 -0
- mediapipe/calculators/video/tool/__init__.py +0 -0
- mediapipe/calculators/video/tool/flow_quantizer_model_pb2.py +26 -0
- mediapipe/calculators/video/tracked_detection_manager_calculator_pb2.py +32 -0
- mediapipe/calculators/video/video_pre_stream_calculator_pb2.py +35 -0
- mediapipe/examples/__init__.py +14 -0
- mediapipe/examples/desktop/__init__.py +14 -0
- mediapipe/framework/__init__.py +0 -0
- mediapipe/framework/calculator_options_pb2.py +29 -0
- mediapipe/framework/calculator_pb2.py +59 -0
- mediapipe/framework/calculator_profile_pb2.py +48 -0
- mediapipe/framework/deps/__init__.py +0 -0
- mediapipe/framework/deps/proto_descriptor_pb2.py +29 -0
- mediapipe/framework/formats/__init__.py +0 -0
- mediapipe/framework/formats/affine_transform_data_pb2.py +28 -0
- mediapipe/framework/formats/annotation/__init__.py +0 -0
- mediapipe/framework/formats/annotation/locus_pb2.py +32 -0
- mediapipe/framework/formats/annotation/rasterization_pb2.py +29 -0
- mediapipe/framework/formats/body_rig_pb2.py +28 -0
- mediapipe/framework/formats/classification_pb2.py +31 -0
- mediapipe/framework/formats/detection_pb2.py +36 -0
- mediapipe/framework/formats/image_file_properties_pb2.py +26 -0
- mediapipe/framework/formats/image_format_pb2.py +29 -0
- mediapipe/framework/formats/landmark_pb2.py +37 -0
- mediapipe/framework/formats/location_data_pb2.py +38 -0
- mediapipe/framework/formats/matrix_data_pb2.py +31 -0
- mediapipe/framework/formats/motion/__init__.py +0 -0
- mediapipe/framework/formats/motion/optical_flow_field_data_pb2.py +30 -0
- mediapipe/framework/formats/object_detection/__init__.py +0 -0
- mediapipe/framework/formats/object_detection/anchor_pb2.py +26 -0
- mediapipe/framework/formats/rect_pb2.py +29 -0
- mediapipe/framework/formats/time_series_header_pb2.py +28 -0
- mediapipe/framework/graph_runtime_info_pb2.py +31 -0
- mediapipe/framework/mediapipe_options_pb2.py +27 -0
- mediapipe/framework/packet_factory_pb2.py +31 -0
- mediapipe/framework/packet_generator_pb2.py +33 -0
- mediapipe/framework/status_handler_pb2.py +28 -0
- mediapipe/framework/stream_handler/__init__.py +0 -0
- mediapipe/framework/stream_handler/default_input_stream_handler_pb2.py +27 -0
- mediapipe/framework/stream_handler/fixed_size_input_stream_handler_pb2.py +27 -0
- mediapipe/framework/stream_handler/sync_set_input_stream_handler_pb2.py +29 -0
- mediapipe/framework/stream_handler/timestamp_align_input_stream_handler_pb2.py +27 -0
- mediapipe/framework/stream_handler_pb2.py +30 -0
- mediapipe/framework/test_calculators_pb2.py +31 -0
- mediapipe/framework/thread_pool_executor_pb2.py +29 -0
- mediapipe/framework/tool/__init__.py +0 -0
- mediapipe/framework/tool/calculator_graph_template_pb2.py +44 -0
- mediapipe/framework/tool/field_data_pb2.py +28 -0
- mediapipe/framework/tool/node_chain_subgraph_pb2.py +31 -0
- mediapipe/framework/tool/packet_generator_wrapper_calculator_pb2.py +28 -0
- mediapipe/framework/tool/source_pb2.py +33 -0
- mediapipe/framework/tool/switch_container_pb2.py +32 -0
- mediapipe/gpu/__init__.py +0 -0
- mediapipe/gpu/copy_calculator_pb2.py +33 -0
- mediapipe/gpu/gl_animation_overlay_calculator_pb2.py +31 -0
- mediapipe/gpu/gl_context_options_pb2.py +31 -0
- mediapipe/gpu/gl_scaler_calculator_pb2.py +32 -0
- mediapipe/gpu/gl_surface_sink_calculator_pb2.py +32 -0
- mediapipe/gpu/gpu_origin_pb2.py +29 -0
- mediapipe/gpu/scale_mode_pb2.py +28 -0
- mediapipe/model_maker/__init__.py +27 -0
- mediapipe/model_maker/setup.py +107 -0
- mediapipe/modules/__init__.py +0 -0
- mediapipe/modules/face_detection/__init__.py +0 -0
- mediapipe/modules/face_detection/face_detection_full_range_cpu.binarypb +0 -0
- mediapipe/modules/face_detection/face_detection_full_range_sparse.tflite +0 -0
- mediapipe/modules/face_detection/face_detection_pb2.py +30 -0
- mediapipe/modules/face_detection/face_detection_short_range.tflite +0 -0
- mediapipe/modules/face_detection/face_detection_short_range_cpu.binarypb +0 -0
- mediapipe/modules/face_geometry/__init__.py +0 -0
- mediapipe/modules/face_geometry/data/__init__.py +0 -0
- mediapipe/modules/face_geometry/effect_renderer_calculator_pb2.py +27 -0
- mediapipe/modules/face_geometry/env_generator_calculator_pb2.py +28 -0
- mediapipe/modules/face_geometry/geometry_pipeline_calculator_pb2.py +27 -0
- mediapipe/modules/face_geometry/libs/__init__.py +0 -0
- mediapipe/modules/face_geometry/protos/__init__.py +0 -0
- mediapipe/modules/face_geometry/protos/environment_pb2.py +31 -0
- mediapipe/modules/face_geometry/protos/face_geometry_pb2.py +29 -0
- mediapipe/modules/face_geometry/protos/geometry_pipeline_metadata_pb2.py +32 -0
- mediapipe/modules/face_geometry/protos/mesh_3d_pb2.py +31 -0
- mediapipe/modules/face_landmark/__init__.py +0 -0
- mediapipe/modules/face_landmark/face_landmark.tflite +0 -0
- mediapipe/modules/face_landmark/face_landmark_front_cpu.binarypb +0 -0
- mediapipe/modules/face_landmark/face_landmark_with_attention.tflite +0 -0
- mediapipe/modules/hand_landmark/__init__.py +0 -0
- mediapipe/modules/hand_landmark/calculators/__init__.py +0 -0
- mediapipe/modules/hand_landmark/hand_landmark_full.tflite +0 -0
- mediapipe/modules/hand_landmark/hand_landmark_lite.tflite +0 -0
- mediapipe/modules/hand_landmark/hand_landmark_tracking_cpu.binarypb +0 -0
- mediapipe/modules/hand_landmark/handedness.txt +2 -0
- mediapipe/modules/holistic_landmark/__init__.py +0 -0
- mediapipe/modules/holistic_landmark/calculators/__init__.py +0 -0
- mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator_pb2.py +37 -0
- mediapipe/modules/holistic_landmark/hand_recrop.tflite +0 -0
- mediapipe/modules/holistic_landmark/holistic_landmark_cpu.binarypb +0 -0
- mediapipe/modules/iris_landmark/__init__.py +0 -0
- mediapipe/modules/iris_landmark/iris_landmark.tflite +0 -0
- mediapipe/modules/objectron/__init__.py +0 -0
- mediapipe/modules/objectron/calculators/__init__.py +0 -0
- mediapipe/modules/objectron/calculators/a_r_capture_metadata_pb2.py +102 -0
- mediapipe/modules/objectron/calculators/annotation_data_pb2.py +38 -0
- mediapipe/modules/objectron/calculators/belief_decoder_config_pb2.py +28 -0
- mediapipe/modules/objectron/calculators/camera_parameters_pb2.py +30 -0
- mediapipe/modules/objectron/calculators/filter_detection_calculator_pb2.py +35 -0
- mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator_pb2.py +31 -0
- mediapipe/modules/objectron/calculators/frame_annotation_tracker_calculator_pb2.py +31 -0
- mediapipe/modules/objectron/calculators/lift_2d_frame_annotation_to_3d_calculator_pb2.py +32 -0
- mediapipe/modules/objectron/calculators/object_pb2.py +38 -0
- mediapipe/modules/objectron/calculators/tensors_to_objects_calculator_pb2.py +32 -0
- mediapipe/modules/objectron/calculators/tflite_tensors_to_objects_calculator_pb2.py +32 -0
- mediapipe/modules/objectron/object_detection_oidv4_labelmap.txt +24 -0
- mediapipe/modules/objectron/objectron_cpu.binarypb +0 -0
- mediapipe/modules/palm_detection/__init__.py +0 -0
- mediapipe/modules/palm_detection/palm_detection_full.tflite +0 -0
- mediapipe/modules/palm_detection/palm_detection_lite.tflite +0 -0
- mediapipe/modules/pose_detection/__init__.py +0 -0
- mediapipe/modules/pose_detection/pose_detection.tflite +0 -0
- mediapipe/modules/pose_landmark/__init__.py +0 -0
- mediapipe/modules/pose_landmark/pose_landmark_cpu.binarypb +0 -0
- mediapipe/modules/pose_landmark/pose_landmark_full.tflite +0 -0
- mediapipe/modules/selfie_segmentation/__init__.py +0 -0
- mediapipe/modules/selfie_segmentation/selfie_segmentation.tflite +0 -0
- mediapipe/modules/selfie_segmentation/selfie_segmentation_cpu.binarypb +0 -0
- mediapipe/modules/selfie_segmentation/selfie_segmentation_landscape.tflite +0 -0
- mediapipe/python/__init__.py +29 -0
- mediapipe/python/_framework_bindings.cpython-39-x86_64-linux-gnu.so +0 -0
- mediapipe/python/calculator_graph_test.py +251 -0
- mediapipe/python/image_frame_test.py +194 -0
- mediapipe/python/image_test.py +218 -0
- mediapipe/python/packet_creator.py +275 -0
- mediapipe/python/packet_getter.py +120 -0
- mediapipe/python/packet_test.py +533 -0
- mediapipe/python/solution_base.py +604 -0
- mediapipe/python/solution_base_test.py +396 -0
- mediapipe/python/solutions/__init__.py +27 -0
- mediapipe/python/solutions/download_utils.py +37 -0
- mediapipe/python/solutions/drawing_styles.py +249 -0
- mediapipe/python/solutions/drawing_utils.py +320 -0
- mediapipe/python/solutions/drawing_utils_test.py +258 -0
- mediapipe/python/solutions/face_detection.py +105 -0
- mediapipe/python/solutions/face_detection_test.py +92 -0
- mediapipe/python/solutions/face_mesh.py +125 -0
- mediapipe/python/solutions/face_mesh_connections.py +500 -0
- mediapipe/python/solutions/face_mesh_test.py +170 -0
- mediapipe/python/solutions/hands.py +153 -0
- mediapipe/python/solutions/hands_connections.py +32 -0
- mediapipe/python/solutions/hands_test.py +219 -0
- mediapipe/python/solutions/holistic.py +167 -0
- mediapipe/python/solutions/holistic_test.py +142 -0
- mediapipe/python/solutions/objectron.py +288 -0
- mediapipe/python/solutions/objectron_test.py +81 -0
- mediapipe/python/solutions/pose.py +192 -0
- mediapipe/python/solutions/pose_connections.py +22 -0
- mediapipe/python/solutions/pose_test.py +262 -0
- mediapipe/python/solutions/selfie_segmentation.py +76 -0
- mediapipe/python/solutions/selfie_segmentation_test.py +68 -0
- mediapipe/python/timestamp_test.py +78 -0
- mediapipe/tasks/__init__.py +14 -0
- mediapipe/tasks/cc/__init__.py +0 -0
- mediapipe/tasks/cc/audio/__init__.py +0 -0
- mediapipe/tasks/cc/audio/audio_classifier/__init__.py +0 -0
- mediapipe/tasks/cc/audio/audio_classifier/proto/__init__.py +0 -0
- mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options_pb2.py +35 -0
- mediapipe/tasks/cc/audio/audio_embedder/__init__.py +0 -0
- mediapipe/tasks/cc/audio/audio_embedder/proto/__init__.py +0 -0
- mediapipe/tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options_pb2.py +35 -0
- mediapipe/tasks/cc/audio/core/__init__.py +0 -0
- mediapipe/tasks/cc/audio/utils/__init__.py +0 -0
- mediapipe/tasks/cc/components/__init__.py +0 -0
- mediapipe/tasks/cc/components/calculators/__init__.py +0 -0
- mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_pb2.py +31 -0
- mediapipe/tasks/cc/components/calculators/score_calibration_calculator_pb2.py +35 -0
- mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator_pb2.py +32 -0
- mediapipe/tasks/cc/components/containers/__init__.py +0 -0
- mediapipe/tasks/cc/components/containers/proto/__init__.py +0 -0
- mediapipe/tasks/cc/components/containers/proto/classifications_pb2.py +30 -0
- mediapipe/tasks/cc/components/containers/proto/embeddings_pb2.py +35 -0
- mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result_pb2.py +32 -0
- mediapipe/tasks/cc/components/processors/__init__.py +0 -0
- mediapipe/tasks/cc/components/processors/proto/__init__.py +0 -0
- mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options_pb2.py +38 -0
- mediapipe/tasks/cc/components/processors/proto/classifier_options_pb2.py +27 -0
- mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options_pb2.py +36 -0
- mediapipe/tasks/cc/components/processors/proto/detector_options_pb2.py +27 -0
- mediapipe/tasks/cc/components/processors/proto/embedder_options_pb2.py +27 -0
- mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options_pb2.py +32 -0
- mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options_pb2.py +34 -0
- mediapipe/tasks/cc/components/processors/proto/text_model_type_pb2.py +28 -0
- mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options_pb2.py +32 -0
- mediapipe/tasks/cc/components/utils/__init__.py +0 -0
- mediapipe/tasks/cc/core/__init__.py +0 -0
- mediapipe/tasks/cc/core/proto/__init__.py +0 -0
- mediapipe/tasks/cc/core/proto/acceleration_pb2.py +28 -0
- mediapipe/tasks/cc/core/proto/base_options_pb2.py +30 -0
- mediapipe/tasks/cc/core/proto/external_file_pb2.py +31 -0
- mediapipe/tasks/cc/core/proto/inference_subgraph_pb2.py +32 -0
- mediapipe/tasks/cc/core/proto/model_resources_calculator_pb2.py +32 -0
- mediapipe/tasks/cc/genai/__init__.py +0 -0
- mediapipe/tasks/cc/genai/inference/__init__.py +0 -0
- mediapipe/tasks/cc/genai/inference/c/__init__.py +0 -0
- mediapipe/tasks/cc/genai/inference/calculators/__init__.py +0 -0
- mediapipe/tasks/cc/genai/inference/calculators/detokenizer_calculator_pb2.py +27 -0
- mediapipe/tasks/cc/genai/inference/calculators/llm_gpu_calculator_pb2.py +32 -0
- mediapipe/tasks/cc/genai/inference/calculators/model_data_calculator_pb2.py +27 -0
- mediapipe/tasks/cc/genai/inference/calculators/tokenizer_calculator_pb2.py +29 -0
- mediapipe/tasks/cc/genai/inference/common/__init__.py +0 -0
- mediapipe/tasks/cc/genai/inference/proto/__init__.py +0 -0
- mediapipe/tasks/cc/genai/inference/proto/llm_file_metadata_pb2.py +32 -0
- mediapipe/tasks/cc/genai/inference/proto/llm_params_pb2.py +33 -0
- mediapipe/tasks/cc/genai/inference/proto/prompt_template_pb2.py +27 -0
- mediapipe/tasks/cc/genai/inference/proto/sampler_params_pb2.py +29 -0
- mediapipe/tasks/cc/genai/inference/proto/transformer_params_pb2.py +45 -0
- mediapipe/tasks/cc/genai/inference/utils/__init__.py +0 -0
- mediapipe/tasks/cc/genai/inference/utils/llm_utils/__init__.py +0 -0
- mediapipe/tasks/cc/genai/inference/utils/xnn_utils/__init__.py +0 -0
- mediapipe/tasks/cc/metadata/__init__.py +0 -0
- mediapipe/tasks/cc/metadata/python/__init__.py +0 -0
- mediapipe/tasks/cc/metadata/python/_pywrap_metadata_version.cpython-39-x86_64-linux-gnu.so +0 -0
- mediapipe/tasks/cc/metadata/tests/__init__.py +0 -0
- mediapipe/tasks/cc/metadata/utils/__init__.py +0 -0
- mediapipe/tasks/cc/text/__init__.py +0 -0
- mediapipe/tasks/cc/text/custom_ops/__init__.py +0 -0
- mediapipe/tasks/cc/text/custom_ops/ragged/__init__.py +0 -0
- mediapipe/tasks/cc/text/custom_ops/sentencepiece/__init__.py +0 -0
- mediapipe/tasks/cc/text/custom_ops/sentencepiece/testdata/__init__.py +0 -0
- mediapipe/tasks/cc/text/language_detector/__init__.py +0 -0
- mediapipe/tasks/cc/text/language_detector/custom_ops/__init__.py +0 -0
- mediapipe/tasks/cc/text/language_detector/custom_ops/utils/__init__.py +0 -0
- mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/__init__.py +0 -0
- mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/__init__.py +0 -0
- mediapipe/tasks/cc/text/text_classifier/__init__.py +0 -0
- mediapipe/tasks/cc/text/text_classifier/proto/__init__.py +0 -0
- mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb2.py +35 -0
- mediapipe/tasks/cc/text/text_embedder/__init__.py +0 -0
- mediapipe/tasks/cc/text/text_embedder/proto/__init__.py +0 -0
- mediapipe/tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb2.py +35 -0
- mediapipe/tasks/cc/text/tokenizers/__init__.py +0 -0
- mediapipe/tasks/cc/text/utils/__init__.py +0 -0
- mediapipe/tasks/cc/vision/__init__.py +0 -0
- mediapipe/tasks/cc/vision/core/__init__.py +0 -0
- mediapipe/tasks/cc/vision/custom_ops/__init__.py +0 -0
- mediapipe/tasks/cc/vision/face_detector/__init__.py +0 -0
- mediapipe/tasks/cc/vision/face_detector/proto/__init__.py +0 -0
- mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options_pb2.py +34 -0
- mediapipe/tasks/cc/vision/face_geometry/__init__.py +0 -0
- mediapipe/tasks/cc/vision/face_geometry/calculators/__init__.py +0 -0
- mediapipe/tasks/cc/vision/face_geometry/calculators/env_generator_calculator_pb2.py +28 -0
- mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator_pb2.py +29 -0
- mediapipe/tasks/cc/vision/face_geometry/data/__init__.py +0 -0
- mediapipe/tasks/cc/vision/face_geometry/libs/__init__.py +0 -0
- mediapipe/tasks/cc/vision/face_geometry/proto/__init__.py +0 -0
- mediapipe/tasks/cc/vision/face_geometry/proto/environment_pb2.py +31 -0
- mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry_graph_options_pb2.py +29 -0
- mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry_pb2.py +29 -0
- mediapipe/tasks/cc/vision/face_geometry/proto/geometry_pipeline_metadata_pb2.py +32 -0
- mediapipe/tasks/cc/vision/face_geometry/proto/mesh_3d_pb2.py +31 -0
- mediapipe/tasks/cc/vision/face_landmarker/__init__.py +0 -0
- mediapipe/tasks/cc/vision/face_landmarker/proto/__init__.py +0 -0
- mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options_pb2.py +34 -0
- mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options_pb2.py +37 -0
- mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options_pb2.py +35 -0
- mediapipe/tasks/cc/vision/face_landmarker/proto/tensors_to_face_landmarks_graph_options_pb2.py +32 -0
- mediapipe/tasks/cc/vision/face_stylizer/__init__.py +0 -0
- mediapipe/tasks/cc/vision/face_stylizer/calculators/__init__.py +0 -0
- mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator_pb2.py +36 -0
- mediapipe/tasks/cc/vision/face_stylizer/proto/__init__.py +0 -0
- mediapipe/tasks/cc/vision/face_stylizer/proto/face_stylizer_graph_options_pb2.py +35 -0
- mediapipe/tasks/cc/vision/gesture_recognizer/__init__.py +0 -0
- mediapipe/tasks/cc/vision/gesture_recognizer/calculators/__init__.py +0 -0
- mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator_pb2.py +33 -0
- mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_pb2.py +31 -0
- mediapipe/tasks/cc/vision/gesture_recognizer/proto/__init__.py +0 -0
- mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options_pb2.py +35 -0
- mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options_pb2.py +34 -0
- mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options_pb2.py +36 -0
- mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options_pb2.py +36 -0
- mediapipe/tasks/cc/vision/hand_detector/__init__.py +0 -0
- mediapipe/tasks/cc/vision/hand_detector/proto/__init__.py +0 -0
- mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options_pb2.py +34 -0
- mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result_pb2.py +30 -0
- mediapipe/tasks/cc/vision/hand_landmarker/__init__.py +0 -0
- mediapipe/tasks/cc/vision/hand_landmarker/calculators/__init__.py +0 -0
- mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_pb2.py +31 -0
- mediapipe/tasks/cc/vision/hand_landmarker/proto/__init__.py +0 -0
- mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb2.py +36 -0
- mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb2.py +34 -0
- mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options_pb2.py +28 -0
- mediapipe/tasks/cc/vision/holistic_landmarker/__init__.py +0 -0
- mediapipe/tasks/cc/vision/holistic_landmarker/proto/__init__.py +0 -0
- mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_landmarker_graph_options_pb2.py +34 -0
- mediapipe/tasks/cc/vision/holistic_landmarker/proto/holistic_result_pb2.py +29 -0
- mediapipe/tasks/cc/vision/image_classifier/__init__.py +0 -0
- mediapipe/tasks/cc/vision/image_classifier/proto/__init__.py +0 -0
- mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options_pb2.py +35 -0
- mediapipe/tasks/cc/vision/image_embedder/__init__.py +0 -0
- mediapipe/tasks/cc/vision/image_embedder/proto/__init__.py +0 -0
- mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options_pb2.py +35 -0
- mediapipe/tasks/cc/vision/image_generator/__init__.py +0 -0
- mediapipe/tasks/cc/vision/image_generator/diffuser/__init__.py +0 -0
- mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator_pb2.py +40 -0
- mediapipe/tasks/cc/vision/image_generator/proto/__init__.py +0 -0
- mediapipe/tasks/cc/vision/image_generator/proto/conditioned_image_graph_options_pb2.py +40 -0
- mediapipe/tasks/cc/vision/image_generator/proto/control_plugin_graph_options_pb2.py +34 -0
- mediapipe/tasks/cc/vision/image_generator/proto/image_generator_graph_options_pb2.py +30 -0
- mediapipe/tasks/cc/vision/image_segmenter/__init__.py +0 -0
- mediapipe/tasks/cc/vision/image_segmenter/calculators/__init__.py +0 -0
- mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_pb2.py +34 -0
- mediapipe/tasks/cc/vision/image_segmenter/proto/__init__.py +0 -0
- mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options_pb2.py +35 -0
- mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options_pb2.py +33 -0
- mediapipe/tasks/cc/vision/interactive_segmenter/__init__.py +0 -0
- mediapipe/tasks/cc/vision/object_detector/__init__.py +0 -0
- mediapipe/tasks/cc/vision/object_detector/proto/__init__.py +0 -0
- mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options_pb2.py +34 -0
- mediapipe/tasks/cc/vision/pose_detector/__init__.py +0 -0
- mediapipe/tasks/cc/vision/pose_detector/proto/__init__.py +0 -0
- mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options_pb2.py +34 -0
- mediapipe/tasks/cc/vision/pose_landmarker/__init__.py +0 -0
- mediapipe/tasks/cc/vision/pose_landmarker/proto/__init__.py +0 -0
- mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options_pb2.py +36 -0
- mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options_pb2.py +34 -0
- mediapipe/tasks/cc/vision/utils/__init__.py +0 -0
- mediapipe/tasks/cc/vision/utils/ghum/__init__.py +0 -0
- mediapipe/tasks/metadata/image_segmenter_metadata_schema.fbs +59 -0
- mediapipe/tasks/metadata/image_segmenter_metadata_schema_py_generated.py +108 -0
- mediapipe/tasks/metadata/metadata_schema.fbs +732 -0
- mediapipe/tasks/metadata/metadata_schema_py_generated.py +3251 -0
- mediapipe/tasks/metadata/object_detector_metadata_schema.fbs +98 -0
- mediapipe/tasks/metadata/object_detector_metadata_schema_py_generated.py +674 -0
- mediapipe/tasks/metadata/schema_py_generated.py +18438 -0
- mediapipe/tasks/python/__init__.py +27 -0
- mediapipe/tasks/python/audio/__init__.py +33 -0
- mediapipe/tasks/python/audio/audio_classifier.py +324 -0
- mediapipe/tasks/python/audio/audio_embedder.py +285 -0
- mediapipe/tasks/python/audio/core/__init__.py +16 -0
- mediapipe/tasks/python/audio/core/audio_record.py +125 -0
- mediapipe/tasks/python/audio/core/audio_task_running_mode.py +29 -0
- mediapipe/tasks/python/audio/core/base_audio_task_api.py +181 -0
- mediapipe/tasks/python/benchmark/__init__.py +13 -0
- mediapipe/tasks/python/benchmark/benchmark_utils.py +70 -0
- mediapipe/tasks/python/benchmark/vision/__init__.py +13 -0
- mediapipe/tasks/python/benchmark/vision/benchmark.py +99 -0
- mediapipe/tasks/python/benchmark/vision/core/__init__.py +14 -0
- mediapipe/tasks/python/benchmark/vision/core/base_vision_benchmark_api.py +40 -0
- mediapipe/tasks/python/components/__init__.py +13 -0
- mediapipe/tasks/python/components/containers/__init__.py +53 -0
- mediapipe/tasks/python/components/containers/audio_data.py +137 -0
- mediapipe/tasks/python/components/containers/bounding_box.py +73 -0
- mediapipe/tasks/python/components/containers/category.py +78 -0
- mediapipe/tasks/python/components/containers/classification_result.py +111 -0
- mediapipe/tasks/python/components/containers/detections.py +181 -0
- mediapipe/tasks/python/components/containers/embedding_result.py +89 -0
- mediapipe/tasks/python/components/containers/keypoint.py +77 -0
- mediapipe/tasks/python/components/containers/landmark.py +122 -0
- mediapipe/tasks/python/components/containers/landmark_detection_result.py +106 -0
- mediapipe/tasks/python/components/containers/rect.py +109 -0
- mediapipe/tasks/python/components/processors/__init__.py +23 -0
- mediapipe/tasks/python/components/processors/classifier_options.py +86 -0
- mediapipe/tasks/python/components/utils/__init__.py +13 -0
- mediapipe/tasks/python/components/utils/cosine_similarity.py +68 -0
- mediapipe/tasks/python/core/__init__.py +13 -0
- mediapipe/tasks/python/core/base_options.py +121 -0
- mediapipe/tasks/python/core/optional_dependencies.py +25 -0
- mediapipe/tasks/python/core/task_info.py +139 -0
- mediapipe/tasks/python/genai/__init__.py +14 -0
- mediapipe/tasks/python/genai/bundler/__init__.py +23 -0
- mediapipe/tasks/python/genai/bundler/llm_bundler.py +130 -0
- mediapipe/tasks/python/genai/bundler/llm_bundler_test.py +168 -0
- mediapipe/tasks/python/genai/converter/__init__.py +24 -0
- mediapipe/tasks/python/genai/converter/converter_base.py +179 -0
- mediapipe/tasks/python/genai/converter/converter_factory.py +79 -0
- mediapipe/tasks/python/genai/converter/llm_converter.py +374 -0
- mediapipe/tasks/python/genai/converter/llm_converter_test.py +63 -0
- mediapipe/tasks/python/genai/converter/pytorch_converter.py +318 -0
- mediapipe/tasks/python/genai/converter/pytorch_converter_test.py +86 -0
- mediapipe/tasks/python/genai/converter/quantization_util.py +516 -0
- mediapipe/tasks/python/genai/converter/quantization_util_test.py +259 -0
- mediapipe/tasks/python/genai/converter/safetensors_converter.py +580 -0
- mediapipe/tasks/python/genai/converter/safetensors_converter_test.py +83 -0
- mediapipe/tasks/python/genai/converter/weight_bins_writer.py +120 -0
- mediapipe/tasks/python/genai/converter/weight_bins_writer_test.py +95 -0
- mediapipe/tasks/python/metadata/__init__.py +13 -0
- mediapipe/tasks/python/metadata/flatbuffers_lib/_pywrap_flatbuffers.cpython-39-x86_64-linux-gnu.so +0 -0
- mediapipe/tasks/python/metadata/metadata.py +928 -0
- mediapipe/tasks/python/metadata/metadata_displayer_cli.py +34 -0
- mediapipe/tasks/python/metadata/metadata_writers/__init__.py +13 -0
- mediapipe/tasks/python/metadata/metadata_writers/face_stylizer.py +138 -0
- mediapipe/tasks/python/metadata/metadata_writers/image_classifier.py +71 -0
- mediapipe/tasks/python/metadata/metadata_writers/image_segmenter.py +170 -0
- mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py +1166 -0
- mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py +845 -0
- mediapipe/tasks/python/metadata/metadata_writers/model_asset_bundle_utils.py +71 -0
- mediapipe/tasks/python/metadata/metadata_writers/object_detector.py +331 -0
- mediapipe/tasks/python/metadata/metadata_writers/text_classifier.py +119 -0
- mediapipe/tasks/python/metadata/metadata_writers/writer_utils.py +91 -0
- mediapipe/tasks/python/test/__init__.py +13 -0
- mediapipe/tasks/python/test/audio/__init__.py +13 -0
- mediapipe/tasks/python/test/audio/audio_classifier_test.py +387 -0
- mediapipe/tasks/python/test/audio/audio_embedder_test.py +297 -0
- mediapipe/tasks/python/test/test_utils.py +196 -0
- mediapipe/tasks/python/test/text/__init__.py +13 -0
- mediapipe/tasks/python/test/text/language_detector_test.py +228 -0
- mediapipe/tasks/python/test/text/text_classifier_test.py +235 -0
- mediapipe/tasks/python/test/text/text_embedder_test.py +326 -0
- mediapipe/tasks/python/test/vision/__init__.py +13 -0
- mediapipe/tasks/python/test/vision/face_aligner_test.py +190 -0
- mediapipe/tasks/python/test/vision/face_detector_test.py +523 -0
- mediapipe/tasks/python/test/vision/face_landmarker_test.py +565 -0
- mediapipe/tasks/python/test/vision/face_stylizer_test.py +191 -0
- mediapipe/tasks/python/test/vision/hand_landmarker_test.py +437 -0
- mediapipe/tasks/python/test/vision/holistic_landmarker_test.py +544 -0
- mediapipe/tasks/python/test/vision/image_classifier_test.py +657 -0
- mediapipe/tasks/python/test/vision/image_embedder_test.py +423 -0
- mediapipe/tasks/python/test/vision/image_segmenter_test.py +512 -0
- mediapipe/tasks/python/test/vision/interactive_segmenter_test.py +341 -0
- mediapipe/tasks/python/test/vision/object_detector_test.py +493 -0
- mediapipe/tasks/python/test/vision/pose_landmarker_test.py +518 -0
- mediapipe/tasks/python/text/__init__.py +35 -0
- mediapipe/tasks/python/text/core/__init__.py +16 -0
- mediapipe/tasks/python/text/core/base_text_task_api.py +54 -0
- mediapipe/tasks/python/text/language_detector.py +220 -0
- mediapipe/tasks/python/text/text_classifier.py +187 -0
- mediapipe/tasks/python/text/text_embedder.py +188 -0
- mediapipe/tasks/python/vision/__init__.py +90 -0
- mediapipe/tasks/python/vision/core/__init__.py +14 -0
- mediapipe/tasks/python/vision/core/base_vision_task_api.py +226 -0
- mediapipe/tasks/python/vision/core/image_processing_options.py +39 -0
- mediapipe/tasks/python/vision/core/vision_task_running_mode.py +31 -0
- mediapipe/tasks/python/vision/face_aligner.py +158 -0
- mediapipe/tasks/python/vision/face_detector.py +332 -0
- mediapipe/tasks/python/vision/face_landmarker.py +3244 -0
- mediapipe/tasks/python/vision/face_stylizer.py +158 -0
- mediapipe/tasks/python/vision/gesture_recognizer.py +480 -0
- mediapipe/tasks/python/vision/hand_landmarker.py +504 -0
- mediapipe/tasks/python/vision/holistic_landmarker.py +576 -0
- mediapipe/tasks/python/vision/image_classifier.py +358 -0
- mediapipe/tasks/python/vision/image_embedder.py +362 -0
- mediapipe/tasks/python/vision/image_segmenter.py +433 -0
- mediapipe/tasks/python/vision/interactive_segmenter.py +285 -0
- mediapipe/tasks/python/vision/object_detector.py +389 -0
- mediapipe/tasks/python/vision/pose_landmarker.py +455 -0
- mediapipe/util/__init__.py +0 -0
- mediapipe/util/analytics/__init__.py +0 -0
- mediapipe/util/analytics/mediapipe_log_extension_pb2.py +44 -0
- mediapipe/util/analytics/mediapipe_logging_enums_pb2.py +37 -0
- mediapipe/util/audio_decoder_pb2.py +33 -0
- mediapipe/util/color_pb2.py +33 -0
- mediapipe/util/label_map_pb2.py +27 -0
- mediapipe/util/render_data_pb2.py +58 -0
- mediapipe/util/sequence/__init__.py +14 -0
- mediapipe/util/sequence/media_sequence.py +716 -0
- mediapipe/util/sequence/media_sequence_test.py +290 -0
- mediapipe/util/sequence/media_sequence_util.py +800 -0
- mediapipe/util/sequence/media_sequence_util_test.py +389 -0
- mediapipe/util/tracking/__init__.py +0 -0
- mediapipe/util/tracking/box_detector_pb2.py +39 -0
- mediapipe/util/tracking/box_tracker_pb2.py +32 -0
- mediapipe/util/tracking/camera_motion_pb2.py +31 -0
- mediapipe/util/tracking/flow_packager_pb2.py +60 -0
- mediapipe/util/tracking/frame_selection_pb2.py +35 -0
- mediapipe/util/tracking/frame_selection_solution_evaluator_pb2.py +28 -0
- mediapipe/util/tracking/motion_analysis_pb2.py +35 -0
- mediapipe/util/tracking/motion_estimation_pb2.py +66 -0
- mediapipe/util/tracking/motion_models_pb2.py +42 -0
- mediapipe/util/tracking/motion_saliency_pb2.py +26 -0
- mediapipe/util/tracking/push_pull_filtering_pb2.py +26 -0
- mediapipe/util/tracking/region_flow_computation_pb2.py +59 -0
- mediapipe/util/tracking/region_flow_pb2.py +49 -0
- mediapipe/util/tracking/tone_estimation_pb2.py +45 -0
- mediapipe/util/tracking/tone_models_pb2.py +32 -0
- mediapipe/util/tracking/tracked_detection_manager_config_pb2.py +26 -0
- mediapipe/util/tracking/tracking_pb2.py +73 -0
- mediapipe_nightly-0.10.21.post20241223.dist-info/LICENSE +218 -0
- mediapipe_nightly-0.10.21.post20241223.dist-info/METADATA +199 -0
- mediapipe_nightly-0.10.21.post20241223.dist-info/RECORD +593 -0
- mediapipe_nightly-0.10.21.post20241223.dist-info/WHEEL +5 -0
- mediapipe_nightly-0.10.21.post20241223.dist-info/top_level.txt +4 -0
- mediapipe_nightly.libs/libEGL-48f73270.so.1.1.0 +0 -0
- mediapipe_nightly.libs/libGLESv2-ed5eda4f.so.2.1.0 +0 -0
- mediapipe_nightly.libs/libGLdispatch-64b28464.so.0.0.0 +0 -0
@@ -0,0 +1,516 @@
|
|
1
|
+
# Copyright 2024 The MediaPipe Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""Utilities for quantizing tensors.
|
16
|
+
|
17
|
+
Note that this is a reduced fork version of the praxis libraries to provide a
|
18
|
+
self-contained library for packaging.
|
19
|
+
"""
|
20
|
+
|
21
|
+
from typing import Any, List, Optional, Sequence, Tuple, Union
|
22
|
+
|
23
|
+
import jax
|
24
|
+
from jax import lax
|
25
|
+
from jax import numpy as jnp
|
26
|
+
import numpy as np
|
27
|
+
|
28
|
+
|
29
|
+
JTensor = jax.Array
|
30
|
+
_UINT4_ZP = 8 # Default zero point for unsigned 4-bit.
|
31
|
+
|
32
|
+
|
33
|
+
def _get_scan_range() -> np.ndarray:
|
34
|
+
# Produce candidate scan values.
|
35
|
+
return np.linspace(1.0, 0.5, num=11)
|
36
|
+
|
37
|
+
|
38
|
+
def _get_mean_error(bound, t, min_value, max_value, p_value):
|
39
|
+
scale = bound / max_value
|
40
|
+
candidate = jnp.divide(t, scale)
|
41
|
+
candidate = jnp.clip(jnp.round(candidate), min_value, max_value)
|
42
|
+
candidate = jnp.multiply(candidate, scale)
|
43
|
+
pmean_error = jnp.mean(jnp.abs(jnp.subtract(candidate, t)) ** p_value)
|
44
|
+
return pmean_error
|
45
|
+
|
46
|
+
|
47
|
+
def _get_best_bound_per_tensor(
    t: JTensor,
    bound: JTensor,
    min_value: float,
    max_value: float,
    p_value: float = 1.0,
) -> JTensor:
  """Searches [0.5, 1] * hard-max for the single best bound for the tensor.

  Evaluates each candidate factor from `_get_scan_range()` and keeps the one
  whose quantize/dequantize round trip has the lowest p-mean error against
  the original tensor (roughly equivalent to maximizing entropy).

  Args:
    t: The input float tensor.
    bound: The hard max value for tensor 't'. It has the same length as shape.
    min_value: Minimal value for the quantization bound.
    max_value: Maximal value for the quantization bound.
    p_value: Exponent of the p-mean error metric. Default to 1.0 which is MAE.

  Returns:
    The best bound value for 't', minimizing the p-mean error.
  """
  candidates = _get_scan_range()

  def _error_for(factor):
    return _get_mean_error(bound * factor, t, min_value, max_value, p_value)

  errors = jax.vmap(_error_for)(candidates)
  winner = candidates[jnp.argmin(errors)].astype(bound.dtype)
  return bound * winner
|
80
|
+
|
81
|
+
|
82
|
+
def _quantrow(
    vec: JTensor,
    bound: JTensor,
    min_value: float,
    max_value: float,
    p_value: float,
    factors: np.ndarray,
) -> JTensor:
  """Picks the rescaled bound for one channel that minimizes p-mean error.

  Args:
    vec: The values of a single channel.
    bound: The hard bound (max(abs(vec))) of the channel.
    min_value: The target min value.
    max_value: The target max value.
    p_value: Exponent of the p-mean error metric.
    factors: Candidate multipliers applied on top of bound.

  Returns:
    `bound` scaled by whichever factor yields the lowest error.
  """

  def _error_at(candidate_bound):
    return _get_mean_error(candidate_bound, vec, min_value, max_value, p_value)

  errors = jax.vmap(_error_at)(bound * factors)
  return bound * factors[jnp.argmin(errors)]
|
110
|
+
|
111
|
+
|
112
|
+
def _get_best_bound_per_channel(
    t: JTensor,
    bound: JTensor,
    min_value: float,
    max_value: float,
    p_value: float = 1.0,
) -> JTensor:
  """Searches [0.5, 1] * hard-max for the best bound of each channel.

  For every channel (column of `t`), scans the factors from
  `_get_scan_range()` and keeps the bound whose quantize/dequantize round
  trip minimizes the p-mean error for that channel.

  Args:
    t: The input float tensor, shaped (rows, channels).
    bound: The per-channel hard max values, shaped (1, channels).
    min_value: Minimal value for the quantization bound.
    max_value: Maximal value for the quantization bound.
    p_value: Exponent of the p-mean error metric. Default to 1.0 which is MAE.

  Returns:
    The per-channel best bounds, in the same shape as `bound`.
  """
  assert len(t.shape) == 2
  assert len(bound.shape) == 2
  assert t.shape[1] == bound.shape[1]
  assert bound.shape[0] == 1

  factors = _get_scan_range()
  per_channel_bounds = bound[0, :]
  # Each row of the transposed tensor is one channel of `t`.
  channels = list(t.transpose())

  adjusted = np.zeros(per_channel_bounds.shape)
  for idx, channel in enumerate(channels):
    adjusted[idx] = _quantrow(
        channel, per_channel_bounds[idx], min_value, max_value, p_value, factors
    )
  return adjusted.reshape(bound.shape)
|
155
|
+
|
156
|
+
|
157
|
+
def get_best_bound(
    t: JTensor,
    bound: JTensor,
    min_value: float,
    max_value: float,
    p_value: float = 1.0,
    per_channel: bool = False,
) -> JTensor:
  """Scans multiple factors on the hard max value to pick the best bound.

  Dispatches to the per-channel or per-tensor scanner; both minimize the
  p-mean error between 't' and its quantized round trip, which is (almost)
  equivalent to maximizing entropy.

  Args:
    t: The input float tensor.
    bound: The hard max value for tensor 't'. It has the same length as shape.
    min_value: Minimal value for the quantization bound.
    max_value: Maximal value for the quantization bound.
    p_value: Exponent of the p-mean error metric. Default to 1.0 which is MAE.
    per_channel: if get best bound for entire tensor or per channel.

  Returns:
    The best bound values for 't', that minimize p-mean error.
  """
  scanner = (
      _get_best_bound_per_channel if per_channel else _get_best_bound_per_tensor
  )
  return scanner(t, bound, min_value, max_value, p_value)
|
186
|
+
|
187
|
+
|
188
|
+
def get_min_max(
    bits: int = 8,
    unsigned: bool = False,
    use_fp: bool = False,
) -> Tuple[float, float]:
  """Returns the representable (min, max) range for a bit width.

  Args:
    bits: Target number of bits for quantization.
    unsigned: If True compute min and max for unsigned number, else for signed.
    use_fp: in floating point.

  Returns:
    min/max values for the provided number of bits.
  """
  if use_fp:
    # Only the float8_e4m3 range is supported for now.
    return -448.0, 448.0
  # Computed arithmetically (instead of via jax.iinfo) so any bit width
  # works, not just 4 and 8.
  if unsigned:
    # e.g. 8 bits -> [0, 255].
    return 0, 2**bits - 1
  # e.g. 8 bits -> [-128, 127].
  half = 2 ** (bits - 1)
  return -half, half - 1
|
213
|
+
|
214
|
+
|
215
|
+
def pass_through(x: JTensor, fn: Any) -> JTensor:
  """Applies `fn` in the forward pass with a straight-through gradient.

  By the Sterbenz lemma, `x - stop_gradient(x)` is exactly zero in value yet
  carries an exactly-one gradient, so adding `stop_gradient(fn(x))` gives
  fn(x) forward while gradients flow through as identity.
  """
  exact_zero_unit_grad = x - jax.lax.stop_gradient(x)
  return exact_zero_unit_grad + jax.lax.stop_gradient(fn(x))
|
219
|
+
|
220
|
+
|
221
|
+
def reduce_precision(
    t: JTensor,
    contract_dims: Optional[Sequence[int]],
    need_gradient: bool = False,
    bits: int = 8,
    optimization_on_bound: bool = False,
    p_value: float = 1.0,
    percentile: float = 1.0,
    use_symmetric: bool = True,
    use_fp: bool = False,
    add_scale_eps: bool = False,
    per_channel: bool = False,
    random_rounding: bool = False,
    key: Optional[jax.Array] = None,
) -> Tuple[JTensor, JTensor, Optional[JTensor]]:
  """Reduce the precision of a tensor.

  Generic for all tensors.

  Args:
    t: Input tensor.
    contract_dims: Specifies contracting dimensions of the input tensor.
    need_gradient: If gradient is needed out of this function.
    bits: Target number of bits.
    optimization_on_bound: If MAE bound optimizer is used.
    p_value: Exponent of the p-mean error metric. Default to 1.0 which is MAE.
    percentile: Percentile Factor to apply on the min/max range. Setting this to
      other than 1.0 disables optimization_on_bound.
    use_symmetric: If the input tensor is quantized symmetrically.
    use_fp: Use floating point.
    add_scale_eps: Add eps value or replace zero value by 1 to avoid division by
      zero.
    per_channel: use per-channel clipping optimization.
    random_rounding: round with uniform random.
    key: rng key for rounding.

  Returns:
    A tuple of quantized tensor, quantization scale
    and quantization zero point (optional).
  """
  min_value, max_value = get_min_max(bits, use_fp=use_fp)

  if use_symmetric:
    # Symmetric: range is [-bound, bound], so the scale maps bound -> max.
    bound = jnp.max(jnp.abs(t), axis=contract_dims, keepdims=True)
    scale_bound = max_value
  else:
    # Asymmetric: the full [t_min, t_max] span is mapped onto [min, max].
    t_max = jnp.max(t, axis=contract_dims, keepdims=True)
    t_min = jnp.min(t, axis=contract_dims, keepdims=True)
    bound = t_max - t_min
    scale_bound = max_value - min_value

  # percentile < 1.0 shrinks the clipping range directly and takes
  # precedence over the scan-based bound optimizer (see docstring).
  if percentile < 1.0:
    bound = jnp.multiply(bound, percentile)
  elif optimization_on_bound:
    bound = get_best_bound(
        t, bound, min_value, max_value, p_value, per_channel=per_channel
    )

  scale = bound / scale_bound

  if add_scale_eps:
    # Add epsilon to avoid divide-by-zero.
    scale = scale + jnp.finfo(t.dtype).eps
  else:
    # Replace zero scales by 1 so the division below is a no-op for them.
    scale = jnp.where(scale == 0.0, 1.0, scale)

  if use_symmetric:
    zp = None
    t = jnp.divide(t, scale)
  else:
    # zp is first computed in the quantized domain (so t_min maps onto
    # min_value), then rescaled back to the float domain before returning.
    zp = min_value - t_min / scale
    t = jnp.divide(t, scale) + zp
    zp = jnp.multiply(scale, zp)

  if use_fp:
    # No need to round.
    t = jnp.clip(t, min_value, max_value).astype(jnp.float8_e4m3fn)
    # TODO: refactor to remove this logic.
    # Reinterprets the fp8 bytes as int8 so downstream code can treat every
    # quantized output as an integer container.
    t = jax.lax.bitcast_convert_type(t, new_dtype=jnp.int8)
  else:
    if need_gradient:
      # Straight-through estimator: rounds forward, identity gradient.
      t = pass_through(t, jnp.round)
      t = jnp.clip(t, min_value, max_value)
    else:
      if random_rounding:
        # NOTE(review): assumes `key` is supplied whenever random_rounding
        # is set; jax.random.uniform fails on key=None — confirm callers.
        t = t + jax.random.uniform(
            key=key, shape=t.shape, minval=-0.5, maxval=0.5
        )
      t = jnp.round(t)
      # Smallest signed integer container wide enough for `bits`.
      container_dtype = (
          jnp.int8 if bits <= 8 else jnp.int16 if bits <= 16 else jnp.int32
      )
      t = jnp.clip(t, min_value, max_value).astype(container_dtype)

  return t, scale, zp
|
316
|
+
|
317
|
+
|
318
|
+
def quantize_tensor(
    var: np.ndarray,
    axis: List[int],
    factor: float = 1.0,
    sym: bool = True,
    number_bits: int = 8,
    use_fp: bool = False,
    add_scale_eps: bool = False,
    optimization_on_bound: bool = False,
    p_value: float = 1.0,
    per_channel: bool = False,
    block_size: int = 0,
) -> Union[
    Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray]
]:
  """Quantize a tensor.

  Args:
    var: The variable to be quantized.
    axis: The axis along which variable will be quantized.
    factor: The clipping factor.
    sym: Symmetric or asymmetric quantize the variable.
    number_bits: Number of bits for quantized value.
    use_fp: do fp with number of bits (i.e. fp8)
    add_scale_eps: add epsilon to scale to avoid division by zero, else it will
      replace zero scale by 1.
    optimization_on_bound: If p-mean bound optimizer is used.
    p_value: Exponent of the p-mean error metric. Default to 1.0 which is MAE.
    per_channel: use per-channel clipping optimization.
    block_size: block size for sub-channel quantization. Defaults to 0, which
      means off.

  Returns:
    Quantized tensors, along with scales and zero point.
  """
  # TODO: support jnp.float8_e5m2
  assert number_bits == 8 or number_bits == 4
  jnp_var = jnp.asarray(var)
  # When using sub-channel, the contracting dim is split into a sub-channel
  # dim followed by the block dim. Therefore the contracting dim
  # (quantize_axis) should increment by one, and the corresponding pack_dim
  # should also increment by one.
  if block_size > 0:
    shape = list(jnp_var.shape)
    assert len(axis) == 1, 'Only support 1D sub-channel quantization'
    sub_channels, rem = divmod(shape[axis[0]], block_size)
    assert rem == 0
    shape.insert(axis[0], sub_channels)
    # NOTE(review): this mutates the caller's `axis` list in place. The
    # comment above suggests callers rely on observing the shifted axis (to
    # bump their pack_dim accordingly) — confirm before changing this to
    # operate on a local copy.
    axis[0] += 1
    shape[axis[0]] = block_size
    jnp_var = jnp.reshape(jnp_var, shape)

  qvar, scale, zp = reduce_precision(
      jnp_var,
      contract_dims=axis,
      need_gradient=False,
      bits=number_bits,
      optimization_on_bound=optimization_on_bound,
      percentile=factor,
      use_symmetric=sym,
      use_fp=use_fp,
      add_scale_eps=add_scale_eps,
      p_value=p_value,
      per_channel=per_channel,
  )
  if sym:
    # Symmetric path: no zero point; squeeze away the kept dims that
    # reduce_precision introduced along the quantization axis only.
    return np.array(qvar), np.array(jnp.squeeze(scale, axis=axis))  # pytype: disable=wrong-arg-types # jnp-type
  else:
    return (
        np.array(qvar),
        # CAVEAT: the following squeezes should squeeze along the quantization
        # axis only.
        np.array(jnp.squeeze(scale)),
        np.array(jnp.squeeze(zp)),
    )
|
393
|
+
|
394
|
+
|
395
|
+
def pack_4bit(
    x: np.ndarray, pack_dim: int, packed_dtype: jnp.dtype = jnp.int32
) -> np.ndarray:
  """Packs 4-bit values (stored in int8/uint8) into nibble words along pack_dim.

  Every group of consecutive nibbles along `pack_dim` (8 when packing to
  int32, 2 when packing to int8) is folded into one word of `packed_dtype`,
  shrinking that dimension by the same ratio.

  Args:
    x: Original int8 or uint8 tensor to pack.
    pack_dim: Dimension to pack along. x.shape[pack_dim] must be divisible by 8,
      when packed_dtype is int32 and divisible by 2 when target_type is int8.
      Also pack_dim must be < x.ndim - 1.
    packed_dtype: Target type to pack to, int32 or int8.

  Returns:
    int32 or int8 packed tensor where the pack_dim size is divided by 8 (or 2)
    relative to the original tensor x.
  """
  x = jnp.asarray(x)
  if packed_dtype == jnp.int8 and x.dtype == jnp.uint8:
    # uint8 values above 7 cannot survive packing as signed nibbles; only
    # the overlap [0..7] would be representable.
    raise ValueError(
        'only int8 input dtype is supported when packing into int8. '
        f'Given {x.dtype}'
    )
  if x.dtype != jnp.int8 and x.dtype != jnp.uint8:
    raise ValueError(
        f'input dtype must be either int8 or uint8. Given {x.dtype}'
    )
  if pack_dim >= x.ndim - 1:
    raise ValueError(
        f'pack_dim must be < input ndim - 1. input shape {x.shape} and pack_dim'
        f' {pack_dim}'
    )
  if packed_dtype != jnp.int32 and packed_dtype != jnp.int8:
    raise ValueError(
        f'packed_dtype must be either int32 or int8. Given {packed_dtype}'
    )
  if packed_dtype == jnp.int32 and x.shape[pack_dim] % 8 != 0:
    raise ValueError(
        'input shape[pack_dim] must be divisible by 8 when target_type '
        f'is int32. Given shape {x.shape}'
    )
  if packed_dtype == jnp.int8 and x.shape[pack_dim] % 2 != 0:
    raise ValueError(
        'input shape[pack_dim] must be divisible by 2 when target_type '
        f'is int8. Given shape {x.shape}'
    )

  nibbles_per_word = 8 if packed_dtype == jnp.int32 else 2

  # Split pack_dim into (words, nibble-within-word).
  grouped_shape = list(x.shape)
  grouped_shape.insert(pack_dim + 1, nibbles_per_word)
  grouped_shape[pack_dim] //= nibbles_per_word

  # Bit offset of each nibble inside its packed word: 0, 4, 8, ...
  nibble_index = lax.broadcasted_iota(packed_dtype, grouped_shape, pack_dim + 1)
  shifts = nibble_index << 2

  # Masking with 0x0F keeps the low nibble and promotes x to packed_dtype.
  nibbles = lax.reshape(x & jnp.array(0x0F, packed_dtype), grouped_shape)
  words = lax.reduce(
      nibbles << shifts, jnp.array(0x0, packed_dtype), lax.add, [pack_dim + 1]
  )
  return np.asarray(words)
|
459
|
+
|
460
|
+
|
461
|
+
def update_to_uint4(
    qx: np.ndarray, scale: np.ndarray, zp: Optional[np.ndarray] = None
):
  """Updates the quantized weights from int4 to uint4.

  This is a conversion function designed for XNNPack as it expects the 4-bit
  quantized weight to be represented differently from the original Pax setting.
  Specifically, the differences are:
  1) The dynamic range of weight values: int4 (Pax) vs. uint4 (XNNPack).
  2) The dynamic range of zero-point: float (Pax) vs. uint4 (XNNPack).
  3) The number of zero-point: per-channel (Pax) vs. per-tensor (XNNPack).

  Args:
    qx: np.array of shape [..., channel], which is the quantized weight values
      from Pax in the shape of. The values are in the dynamic range of int4 but
      are hosted as int8 type. Note that if the first dimension is 3, it means
      the qkv matrices are concatenated together and should be treated
      differently.
    scale: np.array of shape [1(3), channel] as np.float type, which are the
      scaling factors for dequantization per channel.
    zp: (optional) np.array of shape [1 (or 3), channel] as np.float type, which
      are the zero points for dequantization per channel.

  Returns:
    A tuple (qx, scale, zp):
      qx: The updated np.array of shape [..., channel] as np.int8 type with
        updated dynamic range as uint4 (with 8 as the default zero points).
      scale: Same as the input scale.
      zp: (optional) np.array of shape [1 (or 3)] as np.int8 type with the
        updated zero point values in the dynamic range as uint4.
  """
  if qx.dtype != np.int8 or ('float' not in str(scale.dtype)):
    raise ValueError(
        'Unexpected dtype qx:' + str(qx.dtype) + ' scale:' + str(scale.dtype)
    )

  scale = scale.astype(np.float32)

  def get_new_zp(old_zp):
    # Converts the float zero point into the quantized domain, then collapses
    # it to a single per-tensor value as XNNPack expects.
    new_zp = old_zp / (scale + np.finfo(np.float32).eps)
    per_tensor_zp = np.mean(new_zp)
    # astype(np.int8) truncates toward zero before re-centering at the
    # uint4 midpoint (_UINT4_ZP == 8).
    per_tensor_zp = per_tensor_zp.astype(np.int8) + _UINT4_ZP
    return per_tensor_zp

  if zp is not None:
    if qx.shape[0] == 3:
      # Fused qkv case: one per-tensor zero point per matrix.
      # NOTE(review): each row `szp` is divided by the FULL `scale` array
      # (all three rows, via broadcasting) before the mean — not just its
      # own scale row. Looks suspicious; confirm intended shapes/broadcast
      # against the callers before changing.
      per_tensor_zp = np.stack([get_new_zp(szp) for szp in zp], axis=0)
    else:
      per_tensor_zp = get_new_zp(zp)
  else:
    # No zero point given (symmetric weights): use the uint4 midpoint.
    per_tensor_zp = (
        _UINT4_ZP * np.ones(shape=(3)) if qx.shape[0] == 3 else _UINT4_ZP
    )

  # Shift values from int4 [-8, 7] into uint4 [0, 15] (still stored as int8).
  qx = qx + _UINT4_ZP
  return qx, scale, np.array(per_tensor_zp, dtype=np.int32)
|