paddlex 3.0.0b2__py3-none-any.whl → 3.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- paddlex/.version +1 -1
- paddlex/__init__.py +17 -33
- paddlex/__main__.py +4 -5
- paddlex/configs/modules/3d_bev_detection/BEVFusion.yaml +38 -0
- paddlex/configs/modules/doc_vlm/PP-DocBee-2B.yaml +14 -0
- paddlex/configs/modules/doc_vlm/PP-DocBee-7B.yaml +14 -0
- paddlex/configs/modules/face_feature/MobileFaceNet.yaml +41 -0
- paddlex/configs/modules/face_feature/ResNet50_face.yaml +41 -0
- paddlex/configs/modules/formula_recognition/LaTeX_OCR_rec.yaml +40 -0
- paddlex/configs/modules/formula_recognition/PP-FormulaNet-L.yaml +40 -0
- paddlex/configs/modules/formula_recognition/PP-FormulaNet-S.yaml +40 -0
- paddlex/configs/modules/formula_recognition/UniMERNet.yaml +40 -0
- paddlex/configs/modules/image_classification/CLIP_vit_base_patch16_224.yaml +41 -0
- paddlex/configs/modules/image_classification/CLIP_vit_large_patch14_224.yaml +41 -0
- paddlex/configs/modules/image_classification/ConvNeXt_large_384.yaml +41 -0
- paddlex/configs/modules/keypoint_detection/PP-TinyPose_128x96.yaml +40 -0
- paddlex/configs/modules/keypoint_detection/PP-TinyPose_256x192.yaml +40 -0
- paddlex/configs/modules/layout_detection/PP-DocLayout-L.yaml +40 -0
- paddlex/configs/modules/layout_detection/PP-DocLayout-M.yaml +40 -0
- paddlex/configs/modules/layout_detection/PP-DocLayout-S.yaml +40 -0
- paddlex/configs/modules/multilingual_speech_recognition/whisper_base.yaml +12 -0
- paddlex/configs/modules/multilingual_speech_recognition/whisper_large.yaml +12 -0
- paddlex/configs/modules/multilingual_speech_recognition/whisper_medium.yaml +12 -0
- paddlex/configs/modules/multilingual_speech_recognition/whisper_small.yaml +12 -0
- paddlex/configs/modules/multilingual_speech_recognition/whisper_tiny.yaml +12 -0
- paddlex/configs/modules/object_detection/Co-DINO-R50.yaml +40 -0
- paddlex/configs/modules/object_detection/Co-DINO-Swin-L.yaml +40 -0
- paddlex/configs/modules/object_detection/Co-Deformable-DETR-R50.yaml +40 -0
- paddlex/configs/modules/object_detection/Co-Deformable-DETR-Swin-T.yaml +40 -0
- paddlex/configs/modules/object_detection/YOLOX-X.yaml +40 -0
- paddlex/configs/modules/open_vocabulary_detection/GroundingDINO-T.yaml +13 -0
- paddlex/configs/modules/open_vocabulary_detection/YOLO-Worldv2-L.yaml +13 -0
- paddlex/configs/modules/open_vocabulary_segmentation/SAM-H_box.yaml +17 -0
- paddlex/configs/modules/open_vocabulary_segmentation/SAM-H_point.yaml +15 -0
- paddlex/configs/modules/rotated_object_detection/PP-YOLOE-R-L.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/MaskFormer_small.yaml +42 -0
- paddlex/configs/modules/semantic_segmentation/MaskFormer_tiny.yaml +42 -0
- paddlex/configs/modules/semantic_segmentation/SeaFormer_base.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/SeaFormer_large.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/SeaFormer_small.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/SeaFormer_tiny.yaml +40 -0
- paddlex/configs/modules/table_cells_detection/RT-DETR-L_wired_table_cell_det.yaml +40 -0
- paddlex/configs/modules/table_cells_detection/RT-DETR-L_wireless_table_cell_det.yaml +40 -0
- paddlex/configs/modules/table_classification/PP-LCNet_x1_0_table_cls.yaml +41 -0
- paddlex/configs/modules/table_structure_recognition/SLANeXt_wired.yaml +39 -0
- paddlex/configs/modules/table_structure_recognition/SLANeXt_wireless.yaml +39 -0
- paddlex/configs/modules/text_detection/PP-OCRv3_mobile_det.yaml +40 -0
- paddlex/configs/modules/text_detection/PP-OCRv3_server_det.yaml +40 -0
- paddlex/configs/modules/text_recognition/PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/PP-OCRv4_server_rec_doc.yaml +39 -0
- paddlex/configs/modules/text_recognition/arabic_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/chinese_cht_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/cyrillic_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/devanagari_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/en_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/en_PP-OCRv4_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/japan_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/ka_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/korean_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/latin_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/ta_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/te_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/textline_orientation/PP-LCNet_x0_25_textline_ori.yaml +41 -0
- paddlex/configs/modules/video_classification/PP-TSM-R50_8frames_uniform.yaml +42 -0
- paddlex/configs/modules/video_classification/PP-TSMv2-LCNetV2_16frames_uniform.yaml +42 -0
- paddlex/configs/modules/video_classification/PP-TSMv2-LCNetV2_8frames_uniform.yaml +42 -0
- paddlex/configs/modules/video_detection/YOWO.yaml +40 -0
- paddlex/configs/pipelines/3d_bev_detection.yaml +9 -0
- paddlex/configs/pipelines/OCR.yaml +44 -0
- paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml +149 -0
- paddlex/configs/pipelines/PP-ChatOCRv4-doc.yaml +184 -0
- paddlex/configs/pipelines/PP-ShiTuV2.yaml +18 -0
- paddlex/configs/pipelines/PP-StructureV3.yaml +226 -0
- paddlex/configs/pipelines/anomaly_detection.yaml +8 -0
- paddlex/configs/pipelines/doc_preprocessor.yaml +15 -0
- paddlex/configs/pipelines/doc_understanding.yaml +9 -0
- paddlex/configs/pipelines/face_recognition.yaml +18 -0
- paddlex/configs/pipelines/formula_recognition.yaml +39 -0
- paddlex/configs/pipelines/human_keypoint_detection.yaml +17 -0
- paddlex/configs/pipelines/image_classification.yaml +10 -0
- paddlex/configs/pipelines/image_multilabel_classification.yaml +9 -0
- paddlex/configs/pipelines/instance_segmentation.yaml +10 -0
- paddlex/configs/pipelines/layout_parsing.yaml +101 -0
- paddlex/configs/pipelines/multilingual_speech_recognition.yaml +9 -0
- paddlex/configs/pipelines/object_detection.yaml +10 -0
- paddlex/configs/pipelines/open_vocabulary_detection.yaml +12 -0
- paddlex/configs/pipelines/open_vocabulary_segmentation.yaml +13 -0
- paddlex/configs/pipelines/pedestrian_attribute_recognition.yaml +15 -0
- paddlex/configs/pipelines/rotated_object_detection.yaml +10 -0
- paddlex/configs/pipelines/seal_recognition.yaml +51 -0
- paddlex/configs/pipelines/semantic_segmentation.yaml +10 -0
- paddlex/configs/pipelines/small_object_detection.yaml +10 -0
- paddlex/configs/pipelines/table_recognition.yaml +56 -0
- paddlex/configs/pipelines/table_recognition_v2.yaml +76 -0
- paddlex/configs/pipelines/ts_anomaly_detection.yaml +8 -0
- paddlex/configs/pipelines/ts_classification.yaml +8 -0
- paddlex/configs/pipelines/ts_forecast.yaml +8 -0
- paddlex/configs/pipelines/vehicle_attribute_recognition.yaml +15 -0
- paddlex/configs/pipelines/video_classification.yaml +9 -0
- paddlex/configs/pipelines/video_detection.yaml +10 -0
- paddlex/constants.py +17 -0
- paddlex/engine.py +8 -6
- paddlex/hpip_links.html +31 -0
- paddlex/inference/__init__.py +4 -2
- paddlex/inference/common/__init__.py +13 -0
- paddlex/inference/common/batch_sampler/__init__.py +21 -0
- paddlex/inference/common/batch_sampler/audio_batch_sampler.py +83 -0
- paddlex/inference/common/batch_sampler/base_batch_sampler.py +94 -0
- paddlex/inference/common/batch_sampler/det_3d_batch_sampler.py +144 -0
- paddlex/inference/common/batch_sampler/doc_vlm_batch_sampler.py +64 -0
- paddlex/inference/common/batch_sampler/image_batch_sampler.py +112 -0
- paddlex/inference/common/batch_sampler/ts_batch_sampler.py +109 -0
- paddlex/inference/common/batch_sampler/video_batch_sampler.py +74 -0
- paddlex/inference/common/reader/__init__.py +19 -0
- paddlex/inference/common/reader/audio_reader.py +46 -0
- paddlex/inference/common/reader/det_3d_reader.py +241 -0
- paddlex/inference/common/reader/image_reader.py +73 -0
- paddlex/inference/common/reader/ts_reader.py +46 -0
- paddlex/inference/common/reader/video_reader.py +42 -0
- paddlex/inference/common/result/__init__.py +29 -0
- paddlex/inference/common/result/base_cv_result.py +41 -0
- paddlex/inference/common/result/base_result.py +72 -0
- paddlex/inference/common/result/base_ts_result.py +41 -0
- paddlex/inference/common/result/base_video_result.py +36 -0
- paddlex/inference/common/result/mixin.py +702 -0
- paddlex/inference/models/__init__.py +55 -75
- paddlex/inference/models/anomaly_detection/__init__.py +15 -0
- paddlex/inference/models/anomaly_detection/predictor.py +135 -0
- paddlex/inference/models/anomaly_detection/processors.py +53 -0
- paddlex/inference/models/anomaly_detection/result.py +71 -0
- paddlex/inference/models/base/__init__.py +2 -3
- paddlex/inference/models/base/predictor/__init__.py +15 -0
- paddlex/inference/models/base/predictor/base_predictor.py +420 -0
- paddlex/inference/models/common/__init__.py +26 -0
- paddlex/inference/models/common/static_infer.py +850 -0
- paddlex/inference/models/common/tokenizer/__init__.py +19 -0
- paddlex/inference/models/common/tokenizer/bert_tokenizer.py +655 -0
- paddlex/inference/models/common/tokenizer/clip_tokenizer.py +609 -0
- paddlex/inference/models/common/tokenizer/gpt_tokenizer.py +453 -0
- paddlex/inference/models/common/tokenizer/qwen2_tokenizer.py +432 -0
- paddlex/inference/models/common/tokenizer/tokenizer_utils.py +2149 -0
- paddlex/inference/models/common/tokenizer/tokenizer_utils_base.py +3720 -0
- paddlex/inference/models/common/tokenizer/utils.py +66 -0
- paddlex/inference/models/common/tokenizer/vocab.py +647 -0
- paddlex/inference/models/common/ts/__init__.py +15 -0
- paddlex/inference/models/common/ts/funcs.py +540 -0
- paddlex/inference/models/common/ts/processors.py +322 -0
- paddlex/inference/models/common/vision/__init__.py +23 -0
- paddlex/inference/models/common/vision/funcs.py +98 -0
- paddlex/inference/models/common/vision/processors.py +285 -0
- paddlex/inference/models/common/vlm/__init__.py +13 -0
- paddlex/inference/models/common/vlm/activations.py +189 -0
- paddlex/inference/models/common/vlm/bert_padding.py +127 -0
- paddlex/inference/models/common/vlm/distributed.py +229 -0
- paddlex/inference/models/common/vlm/flash_attn_utils.py +119 -0
- paddlex/inference/models/common/vlm/generation/__init__.py +34 -0
- paddlex/inference/models/common/vlm/generation/configuration_utils.py +533 -0
- paddlex/inference/models/common/vlm/generation/logits_process.py +730 -0
- paddlex/inference/models/common/vlm/generation/stopping_criteria.py +106 -0
- paddlex/inference/models/common/vlm/generation/utils.py +2162 -0
- paddlex/inference/models/common/vlm/transformers/__init__.py +16 -0
- paddlex/inference/models/common/vlm/transformers/configuration_utils.py +1037 -0
- paddlex/inference/models/common/vlm/transformers/conversion_utils.py +408 -0
- paddlex/inference/models/common/vlm/transformers/model_outputs.py +1612 -0
- paddlex/inference/models/common/vlm/transformers/model_utils.py +2038 -0
- paddlex/inference/models/common/vlm/transformers/utils.py +178 -0
- paddlex/inference/models/common/vlm/utils.py +109 -0
- paddlex/inference/models/doc_vlm/__init__.py +15 -0
- paddlex/inference/models/doc_vlm/modeling/__init__.py +15 -0
- paddlex/inference/models/doc_vlm/modeling/qwen2_vl.py +2600 -0
- paddlex/inference/models/doc_vlm/predictor.py +198 -0
- paddlex/inference/models/doc_vlm/processors/__init__.py +15 -0
- paddlex/inference/models/doc_vlm/processors/common.py +372 -0
- paddlex/inference/models/doc_vlm/processors/qwen2_vl.py +698 -0
- paddlex/inference/models/doc_vlm/result.py +21 -0
- paddlex/inference/models/face_feature/__init__.py +15 -0
- paddlex/inference/models/face_feature/predictor.py +66 -0
- paddlex/inference/models/formula_recognition/__init__.py +15 -0
- paddlex/inference/models/formula_recognition/predictor.py +187 -0
- paddlex/inference/models/formula_recognition/processors.py +1002 -0
- paddlex/inference/models/formula_recognition/result.py +410 -0
- paddlex/inference/models/image_classification/__init__.py +15 -0
- paddlex/inference/models/image_classification/predictor.py +172 -0
- paddlex/inference/models/image_classification/processors.py +89 -0
- paddlex/inference/models/image_classification/result.py +93 -0
- paddlex/inference/models/image_feature/__init__.py +15 -0
- paddlex/inference/models/image_feature/predictor.py +146 -0
- paddlex/inference/models/image_feature/processors.py +32 -0
- paddlex/inference/models/image_feature/result.py +32 -0
- paddlex/inference/models/image_multilabel_classification/__init__.py +15 -0
- paddlex/inference/models/image_multilabel_classification/predictor.py +95 -0
- paddlex/inference/models/image_multilabel_classification/processors.py +89 -0
- paddlex/inference/models/image_multilabel_classification/result.py +96 -0
- paddlex/inference/models/image_unwarping/__init__.py +15 -0
- paddlex/inference/models/image_unwarping/predictor.py +97 -0
- paddlex/inference/models/image_unwarping/processors.py +92 -0
- paddlex/inference/models/image_unwarping/result.py +47 -0
- paddlex/inference/models/instance_segmentation/__init__.py +15 -0
- paddlex/inference/models/instance_segmentation/predictor.py +202 -0
- paddlex/inference/models/instance_segmentation/processors.py +102 -0
- paddlex/inference/models/instance_segmentation/result.py +162 -0
- paddlex/inference/models/keypoint_detection/__init__.py +15 -0
- paddlex/inference/models/keypoint_detection/predictor.py +187 -0
- paddlex/inference/models/keypoint_detection/processors.py +367 -0
- paddlex/inference/models/keypoint_detection/result.py +197 -0
- paddlex/inference/models/m_3d_bev_detection/__init__.py +15 -0
- paddlex/inference/models/m_3d_bev_detection/predictor.py +303 -0
- paddlex/inference/models/m_3d_bev_detection/processors.py +990 -0
- paddlex/inference/models/m_3d_bev_detection/result.py +68 -0
- paddlex/inference/models/m_3d_bev_detection/visualizer_3d.py +169 -0
- paddlex/inference/models/multilingual_speech_recognition/__init__.py +15 -0
- paddlex/inference/models/multilingual_speech_recognition/predictor.py +137 -0
- paddlex/inference/models/multilingual_speech_recognition/processors.py +1933 -0
- paddlex/inference/models/multilingual_speech_recognition/result.py +21 -0
- paddlex/inference/models/object_detection/__init__.py +15 -0
- paddlex/inference/models/object_detection/predictor.py +342 -0
- paddlex/inference/models/object_detection/processors.py +860 -0
- paddlex/inference/models/object_detection/result.py +114 -0
- paddlex/inference/models/object_detection/utils.py +68 -0
- paddlex/inference/models/open_vocabulary_detection/__init__.py +15 -0
- paddlex/inference/models/open_vocabulary_detection/predictor.py +172 -0
- paddlex/inference/models/open_vocabulary_detection/processors/__init__.py +16 -0
- paddlex/inference/models/open_vocabulary_detection/processors/common.py +114 -0
- paddlex/inference/models/open_vocabulary_detection/processors/groundingdino_processors.py +496 -0
- paddlex/inference/models/open_vocabulary_detection/processors/yoloworld_processors.py +209 -0
- paddlex/inference/models/open_vocabulary_segmentation/__init__.py +15 -0
- paddlex/inference/models/open_vocabulary_segmentation/predictor.py +113 -0
- paddlex/inference/models/open_vocabulary_segmentation/processors/__init__.py +15 -0
- paddlex/inference/models/open_vocabulary_segmentation/processors/sam_processer.py +249 -0
- paddlex/inference/models/open_vocabulary_segmentation/results/__init__.py +15 -0
- paddlex/inference/models/open_vocabulary_segmentation/results/sam_result.py +149 -0
- paddlex/inference/models/semantic_segmentation/__init__.py +15 -0
- paddlex/inference/models/semantic_segmentation/predictor.py +158 -0
- paddlex/inference/models/semantic_segmentation/processors.py +117 -0
- paddlex/inference/models/semantic_segmentation/result.py +73 -0
- paddlex/inference/models/table_structure_recognition/__init__.py +15 -0
- paddlex/inference/models/table_structure_recognition/predictor.py +161 -0
- paddlex/inference/models/table_structure_recognition/processors.py +229 -0
- paddlex/inference/models/table_structure_recognition/result.py +73 -0
- paddlex/inference/models/text_detection/__init__.py +15 -0
- paddlex/inference/models/text_detection/predictor.py +183 -0
- paddlex/inference/models/text_detection/processors.py +504 -0
- paddlex/inference/models/text_detection/result.py +56 -0
- paddlex/inference/models/text_recognition/__init__.py +15 -0
- paddlex/inference/models/text_recognition/predictor.py +98 -0
- paddlex/inference/models/text_recognition/processors.py +245 -0
- paddlex/inference/models/text_recognition/result.py +76 -0
- paddlex/inference/models/ts_anomaly_detection/__init__.py +15 -0
- paddlex/inference/models/ts_anomaly_detection/predictor.py +141 -0
- paddlex/inference/models/ts_anomaly_detection/processors.py +98 -0
- paddlex/inference/models/ts_anomaly_detection/result.py +83 -0
- paddlex/inference/models/ts_classification/__init__.py +15 -0
- paddlex/inference/models/ts_classification/predictor.py +122 -0
- paddlex/inference/models/ts_classification/processors.py +122 -0
- paddlex/inference/models/ts_classification/result.py +87 -0
- paddlex/inference/models/ts_forecasting/__init__.py +15 -0
- paddlex/inference/models/ts_forecasting/predictor.py +154 -0
- paddlex/inference/models/ts_forecasting/processors.py +158 -0
- paddlex/inference/models/ts_forecasting/result.py +96 -0
- paddlex/inference/models/video_classification/__init__.py +15 -0
- paddlex/inference/models/video_classification/predictor.py +141 -0
- paddlex/inference/models/video_classification/processors.py +409 -0
- paddlex/inference/models/video_classification/result.py +96 -0
- paddlex/inference/models/video_detection/__init__.py +15 -0
- paddlex/inference/models/video_detection/predictor.py +129 -0
- paddlex/inference/models/video_detection/processors.py +463 -0
- paddlex/inference/models/video_detection/result.py +109 -0
- paddlex/inference/pipelines/__init__.py +186 -78
- paddlex/inference/pipelines/anomaly_detection/__init__.py +15 -0
- paddlex/inference/pipelines/anomaly_detection/pipeline.py +72 -0
- paddlex/inference/pipelines/attribute_recognition/__init__.py +15 -0
- paddlex/inference/pipelines/attribute_recognition/pipeline.py +110 -0
- paddlex/inference/pipelines/attribute_recognition/result.py +102 -0
- paddlex/inference/pipelines/base.py +125 -59
- paddlex/inference/pipelines/components/__init__.py +29 -0
- paddlex/inference/pipelines/components/chat_server/__init__.py +16 -0
- paddlex/inference/pipelines/components/chat_server/base.py +39 -0
- paddlex/inference/pipelines/components/chat_server/openai_bot_chat.py +236 -0
- paddlex/inference/pipelines/components/common/__init__.py +19 -0
- paddlex/inference/pipelines/components/common/base_operator.py +37 -0
- paddlex/inference/pipelines/components/common/base_result.py +66 -0
- paddlex/inference/pipelines/components/common/convert_points_and_boxes.py +45 -0
- paddlex/inference/pipelines/components/common/crop_image_regions.py +556 -0
- paddlex/inference/pipelines/components/common/seal_det_warp.py +972 -0
- paddlex/inference/pipelines/components/common/sort_boxes.py +85 -0
- paddlex/inference/pipelines/components/common/warp_image.py +50 -0
- paddlex/inference/pipelines/components/faisser.py +357 -0
- paddlex/inference/pipelines/components/prompt_engineering/__init__.py +16 -0
- paddlex/inference/pipelines/components/prompt_engineering/base.py +35 -0
- paddlex/inference/pipelines/components/prompt_engineering/generate_ensemble_prompt.py +128 -0
- paddlex/inference/pipelines/components/prompt_engineering/generate_kie_prompt.py +148 -0
- paddlex/inference/pipelines/components/retriever/__init__.py +16 -0
- paddlex/inference/pipelines/components/retriever/base.py +228 -0
- paddlex/inference/pipelines/components/retriever/openai_bot_retriever.py +70 -0
- paddlex/inference/pipelines/components/retriever/qianfan_bot_retriever.py +166 -0
- paddlex/inference/pipelines/components/utils/__init__.py +13 -0
- paddlex/inference/pipelines/components/utils/mixin.py +206 -0
- paddlex/inference/pipelines/doc_preprocessor/__init__.py +15 -0
- paddlex/inference/pipelines/doc_preprocessor/pipeline.py +183 -0
- paddlex/inference/pipelines/doc_preprocessor/result.py +98 -0
- paddlex/inference/pipelines/doc_understanding/__init__.py +15 -0
- paddlex/inference/pipelines/doc_understanding/pipeline.py +71 -0
- paddlex/inference/pipelines/face_recognition/__init__.py +15 -0
- paddlex/inference/pipelines/face_recognition/pipeline.py +63 -0
- paddlex/inference/pipelines/face_recognition/result.py +44 -0
- paddlex/inference/pipelines/formula_recognition/__init__.py +15 -0
- paddlex/inference/pipelines/formula_recognition/pipeline.py +309 -0
- paddlex/inference/pipelines/formula_recognition/result.py +292 -0
- paddlex/inference/pipelines/image_classification/__init__.py +15 -0
- paddlex/inference/pipelines/image_classification/pipeline.py +80 -0
- paddlex/inference/pipelines/image_multilabel_classification/__init__.py +15 -0
- paddlex/inference/pipelines/image_multilabel_classification/pipeline.py +87 -0
- paddlex/inference/pipelines/instance_segmentation/__init__.py +15 -0
- paddlex/inference/pipelines/instance_segmentation/pipeline.py +81 -0
- paddlex/inference/pipelines/keypoint_detection/__init__.py +15 -0
- paddlex/inference/pipelines/keypoint_detection/pipeline.py +148 -0
- paddlex/inference/pipelines/layout_parsing/__init__.py +3 -2
- paddlex/inference/pipelines/layout_parsing/pipeline.py +581 -0
- paddlex/inference/pipelines/layout_parsing/pipeline_v2.py +749 -0
- paddlex/inference/pipelines/layout_parsing/result.py +204 -0
- paddlex/inference/pipelines/layout_parsing/result_v2.py +467 -0
- paddlex/inference/pipelines/layout_parsing/utils.py +2384 -0
- paddlex/inference/pipelines/m_3d_bev_detection/__init__.py +15 -0
- paddlex/inference/pipelines/m_3d_bev_detection/pipeline.py +74 -0
- paddlex/inference/pipelines/multilingual_speech_recognition/__init__.py +15 -0
- paddlex/inference/pipelines/multilingual_speech_recognition/pipeline.py +78 -0
- paddlex/inference/pipelines/object_detection/__init__.py +15 -0
- paddlex/inference/pipelines/object_detection/pipeline.py +105 -0
- paddlex/inference/pipelines/ocr/__init__.py +15 -0
- paddlex/inference/pipelines/ocr/pipeline.py +406 -0
- paddlex/inference/pipelines/ocr/result.py +252 -0
- paddlex/inference/pipelines/open_vocabulary_detection/__init__.py +15 -0
- paddlex/inference/pipelines/open_vocabulary_detection/pipeline.py +86 -0
- paddlex/inference/pipelines/open_vocabulary_segmentation/__init__.py +15 -0
- paddlex/inference/pipelines/open_vocabulary_segmentation/pipeline.py +100 -0
- paddlex/inference/pipelines/pp_chatocr/__init__.py +16 -0
- paddlex/inference/pipelines/pp_chatocr/pipeline_base.py +111 -0
- paddlex/inference/pipelines/pp_chatocr/pipeline_v3.py +784 -0
- paddlex/inference/pipelines/pp_chatocr/pipeline_v4.py +995 -0
- paddlex/inference/pipelines/pp_shitu_v2/__init__.py +15 -0
- paddlex/inference/pipelines/pp_shitu_v2/pipeline.py +156 -0
- paddlex/inference/pipelines/pp_shitu_v2/result.py +126 -0
- paddlex/inference/pipelines/rotated_object_detection/__init__.py +15 -0
- paddlex/inference/pipelines/rotated_object_detection/pipeline.py +85 -0
- paddlex/inference/pipelines/seal_recognition/__init__.py +15 -0
- paddlex/inference/pipelines/seal_recognition/pipeline.py +279 -0
- paddlex/inference/pipelines/seal_recognition/result.py +89 -0
- paddlex/inference/pipelines/semantic_segmentation/__init__.py +15 -0
- paddlex/inference/pipelines/semantic_segmentation/pipeline.py +85 -0
- paddlex/inference/pipelines/small_object_detection/__init__.py +15 -0
- paddlex/inference/pipelines/small_object_detection/pipeline.py +85 -0
- paddlex/inference/pipelines/table_recognition/__init__.py +3 -2
- paddlex/inference/pipelines/table_recognition/pipeline.py +478 -0
- paddlex/inference/pipelines/table_recognition/pipeline_v2.py +824 -0
- paddlex/inference/pipelines/table_recognition/result.py +218 -0
- paddlex/inference/pipelines/table_recognition/table_recognition_post_processing.py +366 -0
- paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py +484 -0
- paddlex/inference/pipelines/table_recognition/utils.py +24 -437
- paddlex/inference/pipelines/ts_anomaly_detection/__init__.py +15 -0
- paddlex/inference/pipelines/ts_anomaly_detection/pipeline.py +72 -0
- paddlex/inference/pipelines/ts_classification/__init__.py +15 -0
- paddlex/inference/pipelines/ts_classification/pipeline.py +72 -0
- paddlex/inference/pipelines/ts_forecasting/__init__.py +15 -0
- paddlex/inference/pipelines/ts_forecasting/pipeline.py +72 -0
- paddlex/inference/pipelines/video_classification/__init__.py +15 -0
- paddlex/inference/pipelines/video_classification/pipeline.py +79 -0
- paddlex/inference/pipelines/video_detection/__init__.py +15 -0
- paddlex/inference/pipelines/video_detection/pipeline.py +86 -0
- paddlex/inference/serving/__init__.py +17 -0
- paddlex/inference/serving/basic_serving/__init__.py +18 -0
- paddlex/inference/serving/basic_serving/_app.py +221 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/__init__.py +44 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/_common/__init__.py +13 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/_common/common.py +100 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/_common/image_recognition.py +36 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/_common/ocr.py +95 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/anomaly_detection.py +67 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/doc_preprocessor.py +100 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/doc_understanding.py +153 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/face_recognition.py +226 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/formula_recognition.py +100 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/human_keypoint_detection.py +81 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/image_classification.py +69 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/image_multilabel_classification.py +73 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/instance_segmentation.py +87 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/layout_parsing.py +118 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/m_3d_bev_detection.py +79 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/multilingual_speech_recognition.py +92 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/object_detection.py +77 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/ocr.py +102 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/open_vocabulary_detection.py +81 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/open_vocabulary_segmentation.py +91 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/pedestrian_attribute_recognition.py +84 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv3_doc.py +194 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv4_doc.py +224 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/pp_shituv2.py +221 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/pp_structurev3.py +139 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/rotated_object_detection.py +81 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/seal_recognition.py +106 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/semantic_segmentation.py +67 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/small_object_detection.py +72 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition.py +108 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition_v2.py +110 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/ts_anomaly_detection.py +65 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/ts_classification.py +64 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/ts_forecast.py +65 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/vehicle_attribute_recognition.py +84 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/video_classification.py +76 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/video_detection.py +92 -0
- paddlex/inference/serving/basic_serving/_server.py +40 -0
- paddlex/inference/serving/infra/__init__.py +13 -0
- paddlex/inference/serving/infra/config.py +36 -0
- paddlex/inference/serving/infra/models.py +79 -0
- paddlex/inference/serving/infra/storage.py +180 -0
- paddlex/inference/serving/infra/utils.py +287 -0
- paddlex/inference/serving/schemas/__init__.py +13 -0
- paddlex/inference/serving/schemas/anomaly_detection.py +39 -0
- paddlex/inference/serving/schemas/doc_preprocessor.py +54 -0
- paddlex/inference/serving/schemas/doc_understanding.py +78 -0
- paddlex/inference/serving/schemas/face_recognition.py +124 -0
- paddlex/inference/serving/schemas/formula_recognition.py +56 -0
- paddlex/inference/serving/schemas/human_keypoint_detection.py +55 -0
- paddlex/inference/serving/schemas/image_classification.py +45 -0
- paddlex/inference/serving/schemas/image_multilabel_classification.py +47 -0
- paddlex/inference/serving/schemas/instance_segmentation.py +53 -0
- paddlex/inference/serving/schemas/layout_parsing.py +72 -0
- paddlex/inference/serving/schemas/m_3d_bev_detection.py +48 -0
- paddlex/inference/serving/schemas/multilingual_speech_recognition.py +57 -0
- paddlex/inference/serving/schemas/object_detection.py +52 -0
- paddlex/inference/serving/schemas/ocr.py +60 -0
- paddlex/inference/serving/schemas/open_vocabulary_detection.py +52 -0
- paddlex/inference/serving/schemas/open_vocabulary_segmentation.py +52 -0
- paddlex/inference/serving/schemas/pedestrian_attribute_recognition.py +61 -0
- paddlex/inference/serving/schemas/pp_chatocrv3_doc.py +134 -0
- paddlex/inference/serving/schemas/pp_chatocrv4_doc.py +151 -0
- paddlex/inference/serving/schemas/pp_shituv2.py +124 -0
- paddlex/inference/serving/schemas/pp_structurev3.py +84 -0
- paddlex/inference/serving/schemas/rotated_object_detection.py +52 -0
- paddlex/inference/serving/schemas/seal_recognition.py +62 -0
- paddlex/inference/serving/schemas/semantic_segmentation.py +45 -0
- paddlex/inference/serving/schemas/shared/__init__.py +13 -0
- paddlex/inference/serving/schemas/shared/classification.py +23 -0
- paddlex/inference/serving/schemas/shared/image_segmentation.py +28 -0
- paddlex/inference/serving/schemas/shared/object_detection.py +24 -0
- paddlex/inference/serving/schemas/shared/ocr.py +25 -0
- paddlex/inference/serving/schemas/small_object_detection.py +52 -0
- paddlex/inference/serving/schemas/table_recognition.py +64 -0
- paddlex/inference/serving/schemas/table_recognition_v2.py +66 -0
- paddlex/inference/serving/schemas/ts_anomaly_detection.py +37 -0
- paddlex/inference/serving/schemas/ts_classification.py +38 -0
- paddlex/inference/serving/schemas/ts_forecast.py +37 -0
- paddlex/inference/serving/schemas/vehicle_attribute_recognition.py +61 -0
- paddlex/inference/serving/schemas/video_classification.py +44 -0
- paddlex/inference/serving/schemas/video_detection.py +56 -0
- paddlex/inference/utils/__init__.py +1 -1
- paddlex/inference/utils/benchmark.py +333 -168
- paddlex/inference/utils/color_map.py +1 -1
- paddlex/inference/utils/get_pipeline_path.py +3 -2
- paddlex/inference/utils/hpi.py +251 -0
- paddlex/inference/utils/hpi_model_info_collection.json +2252 -0
- paddlex/inference/utils/io/__init__.py +11 -8
- paddlex/inference/utils/io/readers.py +178 -27
- paddlex/inference/utils/io/style.py +21 -14
- paddlex/inference/utils/io/tablepyxl.py +13 -5
- paddlex/inference/utils/io/writers.py +92 -10
- paddlex/inference/utils/model_paths.py +48 -0
- paddlex/inference/utils/new_ir_blocklist.py +27 -0
- paddlex/inference/utils/official_models.py +281 -213
- paddlex/inference/utils/pp_option.py +168 -77
- paddlex/inference/utils/trt_blocklist.py +43 -0
- paddlex/inference/utils/trt_config.py +420 -0
- paddlex/model.py +39 -14
- paddlex/modules/__init__.py +67 -57
- paddlex/modules/anomaly_detection/__init__.py +2 -2
- paddlex/modules/anomaly_detection/dataset_checker/__init__.py +2 -3
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/analyse_dataset.py +6 -3
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/check_dataset.py +8 -4
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/convert_dataset.py +7 -4
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/split_dataset.py +2 -2
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/utils/visualizer.py +7 -2
- paddlex/modules/anomaly_detection/evaluator.py +1 -1
- paddlex/modules/anomaly_detection/exportor.py +1 -1
- paddlex/modules/anomaly_detection/model_list.py +1 -1
- paddlex/modules/anomaly_detection/trainer.py +3 -4
- paddlex/modules/base/__init__.py +5 -5
- paddlex/modules/base/build_model.py +2 -3
- paddlex/modules/base/dataset_checker/__init__.py +2 -2
- paddlex/modules/base/dataset_checker/dataset_checker.py +9 -4
- paddlex/modules/base/dataset_checker/utils.py +1 -3
- paddlex/modules/base/evaluator.py +24 -8
- paddlex/modules/base/exportor.py +36 -12
- paddlex/modules/base/trainer.py +43 -10
- paddlex/modules/base/utils/__init__.py +13 -0
- paddlex/modules/base/utils/cinn_setting.py +89 -0
- paddlex/modules/base/utils/coco_eval.py +94 -0
- paddlex/modules/base/utils/topk_eval.py +118 -0
- paddlex/modules/doc_vlm/__init__.py +18 -0
- paddlex/modules/doc_vlm/dataset_checker.py +29 -0
- paddlex/modules/doc_vlm/evaluator.py +29 -0
- paddlex/modules/doc_vlm/exportor.py +29 -0
- paddlex/modules/doc_vlm/model_list.py +16 -0
- paddlex/modules/doc_vlm/trainer.py +41 -0
- paddlex/modules/face_recognition/__init__.py +2 -2
- paddlex/modules/face_recognition/dataset_checker/__init__.py +2 -2
- paddlex/modules/face_recognition/dataset_checker/dataset_src/__init__.py +1 -1
- paddlex/modules/face_recognition/dataset_checker/dataset_src/check_dataset.py +3 -5
- paddlex/modules/face_recognition/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/face_recognition/dataset_checker/dataset_src/utils/visualizer.py +2 -5
- paddlex/modules/face_recognition/evaluator.py +1 -1
- paddlex/modules/face_recognition/exportor.py +1 -1
- paddlex/modules/face_recognition/model_list.py +1 -1
- paddlex/modules/face_recognition/trainer.py +2 -24
- paddlex/modules/formula_recognition/__init__.py +6 -1
- paddlex/modules/formula_recognition/dataset_checker/__init__.py +113 -0
- paddlex/modules/formula_recognition/dataset_checker/dataset_src/__init__.py +19 -0
- paddlex/modules/formula_recognition/dataset_checker/dataset_src/analyse_dataset.py +158 -0
- paddlex/modules/formula_recognition/dataset_checker/dataset_src/check_dataset.py +76 -0
- paddlex/modules/formula_recognition/dataset_checker/dataset_src/convert_dataset.py +95 -0
- paddlex/modules/formula_recognition/dataset_checker/dataset_src/split_dataset.py +80 -0
- paddlex/modules/formula_recognition/evaluator.py +77 -0
- paddlex/modules/formula_recognition/exportor.py +22 -0
- paddlex/modules/formula_recognition/model_list.py +4 -1
- paddlex/modules/formula_recognition/trainer.py +120 -0
- paddlex/modules/general_recognition/__init__.py +2 -2
- paddlex/modules/general_recognition/dataset_checker/__init__.py +2 -2
- paddlex/modules/general_recognition/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/general_recognition/dataset_checker/dataset_src/analyse_dataset.py +7 -9
- paddlex/modules/general_recognition/dataset_checker/dataset_src/check_dataset.py +4 -5
- paddlex/modules/general_recognition/dataset_checker/dataset_src/convert_dataset.py +6 -5
- paddlex/modules/general_recognition/dataset_checker/dataset_src/split_dataset.py +1 -1
- paddlex/modules/general_recognition/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/general_recognition/dataset_checker/dataset_src/utils/visualizer.py +2 -5
- paddlex/modules/general_recognition/evaluator.py +1 -1
- paddlex/modules/general_recognition/exportor.py +1 -1
- paddlex/modules/general_recognition/model_list.py +1 -1
- paddlex/modules/general_recognition/trainer.py +1 -1
- paddlex/modules/image_classification/__init__.py +2 -2
- paddlex/modules/image_classification/dataset_checker/__init__.py +2 -2
- paddlex/modules/image_classification/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/image_classification/dataset_checker/dataset_src/analyse_dataset.py +8 -9
- paddlex/modules/image_classification/dataset_checker/dataset_src/check_dataset.py +4 -3
- paddlex/modules/image_classification/dataset_checker/dataset_src/convert_dataset.py +4 -4
- paddlex/modules/image_classification/dataset_checker/dataset_src/split_dataset.py +1 -1
- paddlex/modules/image_classification/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/image_classification/dataset_checker/dataset_src/utils/visualizer.py +2 -5
- paddlex/modules/image_classification/evaluator.py +1 -1
- paddlex/modules/image_classification/exportor.py +1 -1
- paddlex/modules/image_classification/model_list.py +3 -1
- paddlex/modules/image_classification/trainer.py +3 -3
- paddlex/modules/image_unwarping/__init__.py +1 -1
- paddlex/modules/image_unwarping/model_list.py +1 -1
- paddlex/modules/instance_segmentation/__init__.py +2 -2
- paddlex/modules/instance_segmentation/dataset_checker/__init__.py +17 -3
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/analyse_dataset.py +9 -5
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/check_dataset.py +8 -5
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/convert_dataset.py +8 -8
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/split_dataset.py +7 -4
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/utils/visualizer.py +10 -8
- paddlex/modules/instance_segmentation/evaluator.py +1 -1
- paddlex/modules/instance_segmentation/exportor.py +1 -1
- paddlex/modules/instance_segmentation/model_list.py +1 -1
- paddlex/modules/instance_segmentation/trainer.py +1 -1
- paddlex/modules/keypoint_detection/__init__.py +18 -0
- paddlex/modules/keypoint_detection/dataset_checker/__init__.py +56 -0
- paddlex/modules/keypoint_detection/dataset_checker/dataset_src/__init__.py +15 -0
- paddlex/modules/keypoint_detection/dataset_checker/dataset_src/check_dataset.py +91 -0
- paddlex/modules/keypoint_detection/dataset_checker/dataset_src/utils/__init__.py +13 -0
- paddlex/modules/keypoint_detection/dataset_checker/dataset_src/utils/visualizer.py +124 -0
- paddlex/modules/keypoint_detection/evaluator.py +41 -0
- paddlex/modules/keypoint_detection/exportor.py +22 -0
- paddlex/modules/keypoint_detection/model_list.py +16 -0
- paddlex/modules/keypoint_detection/trainer.py +39 -0
- paddlex/modules/m_3d_bev_detection/__init__.py +18 -0
- paddlex/modules/m_3d_bev_detection/dataset_checker/__init__.py +95 -0
- paddlex/modules/m_3d_bev_detection/dataset_checker/dataset_src/__init__.py +17 -0
- paddlex/modules/m_3d_bev_detection/dataset_checker/dataset_src/analyse_dataset.py +106 -0
- paddlex/modules/m_3d_bev_detection/dataset_checker/dataset_src/check_dataset.py +101 -0
- paddlex/modules/m_3d_bev_detection/evaluator.py +46 -0
- paddlex/modules/m_3d_bev_detection/exportor.py +22 -0
- paddlex/modules/m_3d_bev_detection/model_list.py +18 -0
- paddlex/modules/m_3d_bev_detection/trainer.py +68 -0
- paddlex/modules/multilabel_classification/__init__.py +2 -2
- paddlex/modules/multilabel_classification/dataset_checker/__init__.py +2 -2
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/analyse_dataset.py +8 -9
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/check_dataset.py +4 -3
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/convert_dataset.py +10 -7
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/split_dataset.py +1 -1
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/utils/visualizer.py +1 -5
- paddlex/modules/multilabel_classification/evaluator.py +1 -1
- paddlex/modules/multilabel_classification/exportor.py +1 -1
- paddlex/modules/multilabel_classification/model_list.py +1 -1
- paddlex/modules/multilabel_classification/trainer.py +3 -3
- paddlex/modules/multilingual_speech_recognition/__init__.py +18 -0
- paddlex/modules/multilingual_speech_recognition/dataset_checker.py +27 -0
- paddlex/modules/multilingual_speech_recognition/evaluator.py +27 -0
- paddlex/modules/multilingual_speech_recognition/exportor.py +27 -0
- paddlex/modules/multilingual_speech_recognition/model_list.py +22 -0
- paddlex/modules/multilingual_speech_recognition/trainer.py +42 -0
- paddlex/modules/object_detection/__init__.py +2 -2
- paddlex/modules/object_detection/dataset_checker/__init__.py +2 -11
- paddlex/modules/object_detection/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/object_detection/dataset_checker/dataset_src/analyse_dataset.py +10 -8
- paddlex/modules/object_detection/dataset_checker/dataset_src/check_dataset.py +10 -5
- paddlex/modules/object_detection/dataset_checker/dataset_src/convert_dataset.py +13 -8
- paddlex/modules/object_detection/dataset_checker/dataset_src/split_dataset.py +8 -4
- paddlex/modules/object_detection/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/object_detection/dataset_checker/dataset_src/utils/visualizer.py +9 -8
- paddlex/modules/object_detection/evaluator.py +18 -2
- paddlex/modules/object_detection/exportor.py +1 -1
- paddlex/modules/object_detection/model_list.py +11 -1
- paddlex/modules/object_detection/trainer.py +19 -6
- paddlex/modules/open_vocabulary_detection/__init__.py +18 -0
- paddlex/modules/open_vocabulary_detection/dataset_checker.py +29 -0
- paddlex/modules/open_vocabulary_detection/evaluator.py +29 -0
- paddlex/modules/open_vocabulary_detection/exportor.py +29 -0
- paddlex/modules/open_vocabulary_detection/model_list.py +16 -0
- paddlex/modules/open_vocabulary_detection/trainer.py +44 -0
- paddlex/modules/open_vocabulary_segmentation/__init__.py +18 -0
- paddlex/modules/open_vocabulary_segmentation/dataset_checker.py +29 -0
- paddlex/modules/open_vocabulary_segmentation/evaluator.py +29 -0
- paddlex/modules/open_vocabulary_segmentation/exportor.py +29 -0
- paddlex/modules/open_vocabulary_segmentation/model_list.py +19 -0
- paddlex/modules/open_vocabulary_segmentation/trainer.py +44 -0
- paddlex/modules/semantic_segmentation/__init__.py +2 -2
- paddlex/modules/semantic_segmentation/dataset_checker/__init__.py +17 -3
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/analyse_dataset.py +6 -3
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/check_dataset.py +2 -2
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/convert_dataset.py +7 -4
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/split_dataset.py +2 -2
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/utils/visualizer.py +6 -2
- paddlex/modules/semantic_segmentation/evaluator.py +1 -1
- paddlex/modules/semantic_segmentation/exportor.py +10 -1
- paddlex/modules/semantic_segmentation/model_list.py +3 -1
- paddlex/modules/semantic_segmentation/trainer.py +5 -4
- paddlex/modules/table_recognition/__init__.py +2 -2
- paddlex/modules/table_recognition/dataset_checker/__init__.py +21 -6
- paddlex/modules/table_recognition/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/table_recognition/dataset_checker/dataset_src/analyse_dataset.py +3 -2
- paddlex/modules/table_recognition/dataset_checker/dataset_src/check_dataset.py +20 -20
- paddlex/modules/table_recognition/dataset_checker/dataset_src/split_dataset.py +2 -1
- paddlex/modules/table_recognition/evaluator.py +1 -1
- paddlex/modules/table_recognition/exportor.py +1 -1
- paddlex/modules/table_recognition/model_list.py +3 -1
- paddlex/modules/table_recognition/trainer.py +2 -5
- paddlex/modules/text_detection/__init__.py +2 -2
- paddlex/modules/text_detection/dataset_checker/__init__.py +20 -7
- paddlex/modules/text_detection/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/text_detection/dataset_checker/dataset_src/analyse_dataset.py +12 -9
- paddlex/modules/text_detection/dataset_checker/dataset_src/check_dataset.py +15 -5
- paddlex/modules/text_detection/dataset_checker/dataset_src/split_dataset.py +3 -3
- paddlex/modules/text_detection/evaluator.py +1 -1
- paddlex/modules/text_detection/exportor.py +1 -1
- paddlex/modules/text_detection/model_list.py +3 -1
- paddlex/modules/text_detection/trainer.py +2 -5
- paddlex/modules/text_recognition/__init__.py +2 -2
- paddlex/modules/text_recognition/dataset_checker/__init__.py +20 -9
- paddlex/modules/text_recognition/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/text_recognition/dataset_checker/dataset_src/analyse_dataset.py +13 -12
- paddlex/modules/text_recognition/dataset_checker/dataset_src/check_dataset.py +15 -8
- paddlex/modules/text_recognition/dataset_checker/dataset_src/convert_dataset.py +11 -10
- paddlex/modules/text_recognition/dataset_checker/dataset_src/split_dataset.py +1 -2
- paddlex/modules/text_recognition/evaluator.py +5 -4
- paddlex/modules/text_recognition/exportor.py +1 -4
- paddlex/modules/text_recognition/model_list.py +15 -1
- paddlex/modules/text_recognition/trainer.py +6 -6
- paddlex/modules/ts_anomaly_detection/__init__.py +2 -2
- paddlex/modules/ts_anomaly_detection/dataset_checker/__init__.py +19 -5
- paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/analyse_dataset.py +1 -9
- paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/check_dataset.py +2 -2
- paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/convert_dataset.py +2 -6
- paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/split_dataset.py +4 -4
- paddlex/modules/ts_anomaly_detection/evaluator.py +1 -1
- paddlex/modules/ts_anomaly_detection/exportor.py +2 -3
- paddlex/modules/ts_anomaly_detection/model_list.py +1 -1
- paddlex/modules/ts_anomaly_detection/trainer.py +22 -6
- paddlex/modules/ts_classification/__init__.py +2 -2
- paddlex/modules/ts_classification/dataset_checker/__init__.py +19 -5
- paddlex/modules/ts_classification/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/ts_classification/dataset_checker/dataset_src/analyse_dataset.py +8 -5
- paddlex/modules/ts_classification/dataset_checker/dataset_src/check_dataset.py +2 -2
- paddlex/modules/ts_classification/dataset_checker/dataset_src/convert_dataset.py +2 -6
- paddlex/modules/ts_classification/dataset_checker/dataset_src/split_dataset.py +4 -4
- paddlex/modules/ts_classification/evaluator.py +1 -1
- paddlex/modules/ts_classification/exportor.py +2 -3
- paddlex/modules/ts_classification/model_list.py +1 -1
- paddlex/modules/ts_classification/trainer.py +21 -5
- paddlex/modules/ts_forecast/__init__.py +2 -2
- paddlex/modules/ts_forecast/dataset_checker/__init__.py +19 -5
- paddlex/modules/ts_forecast/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/ts_forecast/dataset_checker/dataset_src/analyse_dataset.py +1 -9
- paddlex/modules/ts_forecast/dataset_checker/dataset_src/check_dataset.py +2 -2
- paddlex/modules/ts_forecast/dataset_checker/dataset_src/convert_dataset.py +2 -6
- paddlex/modules/ts_forecast/dataset_checker/dataset_src/split_dataset.py +4 -4
- paddlex/modules/ts_forecast/evaluator.py +1 -1
- paddlex/modules/ts_forecast/exportor.py +2 -3
- paddlex/modules/ts_forecast/model_list.py +1 -1
- paddlex/modules/ts_forecast/trainer.py +21 -5
- paddlex/modules/video_classification/__init__.py +18 -0
- paddlex/modules/video_classification/dataset_checker/__init__.py +93 -0
- paddlex/modules/video_classification/dataset_checker/dataset_src/__init__.py +18 -0
- paddlex/modules/video_classification/dataset_checker/dataset_src/analyse_dataset.py +93 -0
- paddlex/modules/video_classification/dataset_checker/dataset_src/check_dataset.py +120 -0
- paddlex/modules/video_classification/dataset_checker/dataset_src/split_dataset.py +82 -0
- paddlex/modules/video_classification/evaluator.py +44 -0
- paddlex/modules/video_classification/exportor.py +22 -0
- paddlex/modules/video_classification/model_list.py +19 -0
- paddlex/modules/video_classification/trainer.py +88 -0
- paddlex/modules/video_detection/__init__.py +18 -0
- paddlex/modules/video_detection/dataset_checker/__init__.py +86 -0
- paddlex/modules/video_detection/dataset_checker/dataset_src/__init__.py +17 -0
- paddlex/modules/video_detection/dataset_checker/dataset_src/analyse_dataset.py +100 -0
- paddlex/modules/video_detection/dataset_checker/dataset_src/check_dataset.py +132 -0
- paddlex/modules/video_detection/evaluator.py +42 -0
- paddlex/modules/video_detection/exportor.py +22 -0
- paddlex/modules/video_detection/model_list.py +15 -0
- paddlex/modules/video_detection/trainer.py +82 -0
- paddlex/ops/__init__.py +152 -0
- paddlex/ops/iou3d_nms/iou3d_cpu.cpp +266 -0
- paddlex/ops/iou3d_nms/iou3d_cpu.h +28 -0
- paddlex/ops/iou3d_nms/iou3d_nms.cpp +206 -0
- paddlex/ops/iou3d_nms/iou3d_nms.h +35 -0
- paddlex/ops/iou3d_nms/iou3d_nms_api.cpp +114 -0
- paddlex/ops/iou3d_nms/iou3d_nms_kernel.cu +484 -0
- paddlex/ops/setup.py +37 -0
- paddlex/ops/voxel/voxelize_op.cc +194 -0
- paddlex/ops/voxel/voxelize_op.cu +346 -0
- paddlex/paddlex_cli.py +352 -74
- paddlex/repo_apis/Paddle3D_api/__init__.py +17 -0
- paddlex/repo_apis/Paddle3D_api/bev_fusion/__init__.py +18 -0
- paddlex/repo_apis/Paddle3D_api/bev_fusion/config.py +118 -0
- paddlex/repo_apis/Paddle3D_api/bev_fusion/model.py +238 -0
- paddlex/repo_apis/Paddle3D_api/bev_fusion/register.py +55 -0
- paddlex/repo_apis/Paddle3D_api/bev_fusion/runner.py +104 -0
- paddlex/repo_apis/Paddle3D_api/pp3d_config.py +145 -0
- paddlex/repo_apis/PaddleClas_api/__init__.py +1 -1
- paddlex/repo_apis/PaddleClas_api/cls/__init__.py +3 -3
- paddlex/repo_apis/PaddleClas_api/cls/config.py +4 -3
- paddlex/repo_apis/PaddleClas_api/cls/model.py +9 -3
- paddlex/repo_apis/PaddleClas_api/cls/register.py +22 -5
- paddlex/repo_apis/PaddleClas_api/cls/runner.py +1 -2
- paddlex/repo_apis/PaddleClas_api/shitu_rec/__init__.py +2 -2
- paddlex/repo_apis/PaddleClas_api/shitu_rec/config.py +2 -2
- paddlex/repo_apis/PaddleClas_api/shitu_rec/model.py +1 -4
- paddlex/repo_apis/PaddleClas_api/shitu_rec/register.py +2 -2
- paddlex/repo_apis/PaddleClas_api/shitu_rec/runner.py +1 -6
- paddlex/repo_apis/PaddleDetection_api/__init__.py +2 -2
- paddlex/repo_apis/PaddleDetection_api/config_helper.py +3 -3
- paddlex/repo_apis/PaddleDetection_api/instance_seg/__init__.py +2 -2
- paddlex/repo_apis/PaddleDetection_api/instance_seg/config.py +10 -7
- paddlex/repo_apis/PaddleDetection_api/instance_seg/model.py +9 -3
- paddlex/repo_apis/PaddleDetection_api/instance_seg/register.py +2 -3
- paddlex/repo_apis/PaddleDetection_api/instance_seg/runner.py +1 -2
- paddlex/repo_apis/PaddleDetection_api/object_det/__init__.py +3 -3
- paddlex/repo_apis/PaddleDetection_api/object_det/config.py +31 -8
- paddlex/repo_apis/PaddleDetection_api/object_det/model.py +11 -6
- paddlex/repo_apis/PaddleDetection_api/object_det/official_categories.py +82 -1
- paddlex/repo_apis/PaddleDetection_api/object_det/register.py +184 -6
- paddlex/repo_apis/PaddleDetection_api/object_det/runner.py +1 -2
- paddlex/repo_apis/PaddleNLP_api/__init__.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/__init__.py +4 -2
- paddlex/repo_apis/PaddleOCR_api/config_utils.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/formula_rec/__init__.py +16 -0
- paddlex/repo_apis/PaddleOCR_api/formula_rec/config.py +571 -0
- paddlex/repo_apis/PaddleOCR_api/formula_rec/model.py +402 -0
- paddlex/repo_apis/PaddleOCR_api/formula_rec/register.py +72 -0
- paddlex/repo_apis/PaddleOCR_api/formula_rec/runner.py +239 -0
- paddlex/repo_apis/PaddleOCR_api/table_rec/__init__.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/table_rec/config.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/table_rec/model.py +3 -3
- paddlex/repo_apis/PaddleOCR_api/table_rec/register.py +20 -3
- paddlex/repo_apis/PaddleOCR_api/table_rec/runner.py +2 -2
- paddlex/repo_apis/PaddleOCR_api/text_det/__init__.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/text_det/config.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/text_det/model.py +3 -3
- paddlex/repo_apis/PaddleOCR_api/text_det/register.py +20 -3
- paddlex/repo_apis/PaddleOCR_api/text_det/runner.py +2 -2
- paddlex/repo_apis/PaddleOCR_api/text_rec/__init__.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/text_rec/config.py +25 -3
- paddlex/repo_apis/PaddleOCR_api/text_rec/model.py +10 -4
- paddlex/repo_apis/PaddleOCR_api/text_rec/register.py +128 -10
- paddlex/repo_apis/PaddleOCR_api/text_rec/runner.py +1 -2
- paddlex/repo_apis/PaddleSeg_api/__init__.py +1 -1
- paddlex/repo_apis/PaddleSeg_api/base_seg_config.py +2 -2
- paddlex/repo_apis/PaddleSeg_api/seg/__init__.py +1 -1
- paddlex/repo_apis/PaddleSeg_api/seg/config.py +12 -6
- paddlex/repo_apis/PaddleSeg_api/seg/model.py +15 -5
- paddlex/repo_apis/PaddleSeg_api/seg/register.py +22 -3
- paddlex/repo_apis/PaddleSeg_api/seg/runner.py +1 -2
- paddlex/repo_apis/PaddleTS_api/__init__.py +4 -3
- paddlex/repo_apis/PaddleTS_api/ts_ad/__init__.py +1 -1
- paddlex/repo_apis/PaddleTS_api/ts_ad/config.py +2 -3
- paddlex/repo_apis/PaddleTS_api/ts_ad/register.py +2 -2
- paddlex/repo_apis/PaddleTS_api/ts_ad/runner.py +2 -2
- paddlex/repo_apis/PaddleTS_api/ts_base/__init__.py +1 -1
- paddlex/repo_apis/PaddleTS_api/ts_base/config.py +25 -3
- paddlex/repo_apis/PaddleTS_api/ts_base/model.py +15 -11
- paddlex/repo_apis/PaddleTS_api/ts_base/runner.py +2 -2
- paddlex/repo_apis/PaddleTS_api/ts_cls/__init__.py +1 -1
- paddlex/repo_apis/PaddleTS_api/ts_cls/config.py +2 -3
- paddlex/repo_apis/PaddleTS_api/ts_cls/register.py +2 -2
- paddlex/repo_apis/PaddleTS_api/ts_cls/runner.py +2 -2
- paddlex/repo_apis/PaddleTS_api/ts_fc/__init__.py +1 -1
- paddlex/repo_apis/PaddleTS_api/ts_fc/config.py +2 -3
- paddlex/repo_apis/PaddleTS_api/ts_fc/register.py +1 -1
- paddlex/repo_apis/PaddleVideo_api/__init__.py +17 -0
- paddlex/repo_apis/PaddleVideo_api/config_utils.py +51 -0
- paddlex/repo_apis/PaddleVideo_api/video_cls/__init__.py +19 -0
- paddlex/repo_apis/PaddleVideo_api/video_cls/config.py +548 -0
- paddlex/repo_apis/PaddleVideo_api/video_cls/model.py +346 -0
- paddlex/repo_apis/PaddleVideo_api/video_cls/register.py +70 -0
- paddlex/repo_apis/PaddleVideo_api/video_cls/runner.py +204 -0
- paddlex/repo_apis/PaddleVideo_api/video_det/__init__.py +19 -0
- paddlex/repo_apis/PaddleVideo_api/video_det/config.py +549 -0
- paddlex/repo_apis/PaddleVideo_api/video_det/model.py +298 -0
- paddlex/repo_apis/PaddleVideo_api/video_det/register.py +44 -0
- paddlex/repo_apis/PaddleVideo_api/video_det/runner.py +199 -0
- paddlex/repo_apis/__init__.py +1 -1
- paddlex/repo_apis/base/__init__.py +4 -5
- paddlex/repo_apis/base/config.py +2 -3
- paddlex/repo_apis/base/model.py +11 -19
- paddlex/repo_apis/base/register.py +1 -1
- paddlex/repo_apis/base/runner.py +13 -13
- paddlex/repo_apis/base/utils/__init__.py +1 -1
- paddlex/repo_apis/base/utils/arg.py +1 -1
- paddlex/repo_apis/base/utils/subprocess.py +1 -1
- paddlex/repo_manager/__init__.py +2 -9
- paddlex/repo_manager/core.py +9 -27
- paddlex/repo_manager/meta.py +58 -25
- paddlex/repo_manager/repo.py +169 -141
- paddlex/repo_manager/utils.py +72 -222
- paddlex/utils/__init__.py +1 -1
- paddlex/utils/cache.py +8 -10
- paddlex/utils/config.py +10 -8
- paddlex/utils/custom_device_list.py +287 -0
- paddlex/utils/deps.py +249 -0
- paddlex/utils/device.py +125 -33
- paddlex/utils/download.py +4 -4
- paddlex/utils/env.py +54 -0
- paddlex/utils/errors/__init__.py +1 -1
- paddlex/utils/errors/dataset_checker.py +1 -1
- paddlex/utils/errors/others.py +2 -16
- paddlex/utils/file_interface.py +4 -5
- paddlex/utils/flags.py +22 -11
- paddlex/utils/fonts/__init__.py +50 -6
- paddlex/utils/func_register.py +1 -1
- paddlex/utils/install.py +87 -0
- paddlex/utils/interactive_get_pipeline.py +3 -3
- paddlex/utils/lazy_loader.py +4 -2
- paddlex/utils/logging.py +11 -3
- paddlex/utils/misc.py +5 -5
- paddlex/utils/pipeline_arguments.py +719 -0
- paddlex/utils/result_saver.py +4 -5
- paddlex/utils/subclass_register.py +2 -4
- paddlex/version.py +2 -1
- paddlex-3.0.0rc1.dist-info/METADATA +1174 -0
- paddlex-3.0.0rc1.dist-info/RECORD +1068 -0
- paddlex-3.0.0rc1.dist-info/WHEEL +5 -0
- paddlex/configs/face_recognition/MobileFaceNet.yaml +0 -44
- paddlex/configs/face_recognition/ResNet50_face.yaml +0 -44
- paddlex/configs/formula_recognition/LaTeX_OCR_rec.yaml +0 -40
- paddlex/configs/image_classification/CLIP_vit_base_patch16_224.yaml +0 -41
- paddlex/configs/image_classification/CLIP_vit_large_patch14_224.yaml +0 -41
- paddlex/configs/image_classification/ConvNeXt_large_384.yaml +0 -41
- paddlex/configs/object_detection/YOLOX-X.yaml +0 -40
- paddlex/configs/semantic_segmentation/SeaFormer_base.yaml +0 -40
- paddlex/configs/semantic_segmentation/SeaFormer_large.yaml +0 -40
- paddlex/configs/semantic_segmentation/SeaFormer_small.yaml +0 -40
- paddlex/configs/semantic_segmentation/SeaFormer_tiny.yaml +0 -40
- paddlex/inference/components/__init__.py +0 -18
- paddlex/inference/components/base.py +0 -292
- paddlex/inference/components/llm/__init__.py +0 -25
- paddlex/inference/components/llm/base.py +0 -65
- paddlex/inference/components/llm/erniebot.py +0 -212
- paddlex/inference/components/paddle_predictor/__init__.py +0 -20
- paddlex/inference/components/paddle_predictor/predictor.py +0 -332
- paddlex/inference/components/retrieval/__init__.py +0 -15
- paddlex/inference/components/retrieval/faiss.py +0 -359
- paddlex/inference/components/task_related/__init__.py +0 -33
- paddlex/inference/components/task_related/clas.py +0 -124
- paddlex/inference/components/task_related/det.py +0 -284
- paddlex/inference/components/task_related/instance_seg.py +0 -89
- paddlex/inference/components/task_related/seal_det_warp.py +0 -940
- paddlex/inference/components/task_related/seg.py +0 -40
- paddlex/inference/components/task_related/table_rec.py +0 -191
- paddlex/inference/components/task_related/text_det.py +0 -895
- paddlex/inference/components/task_related/text_rec.py +0 -353
- paddlex/inference/components/task_related/warp.py +0 -43
- paddlex/inference/components/transforms/__init__.py +0 -16
- paddlex/inference/components/transforms/image/__init__.py +0 -15
- paddlex/inference/components/transforms/image/common.py +0 -598
- paddlex/inference/components/transforms/image/funcs.py +0 -58
- paddlex/inference/components/transforms/read_data.py +0 -67
- paddlex/inference/components/transforms/ts/__init__.py +0 -15
- paddlex/inference/components/transforms/ts/common.py +0 -393
- paddlex/inference/components/transforms/ts/funcs.py +0 -424
- paddlex/inference/models/anomaly_detection.py +0 -87
- paddlex/inference/models/base/base_predictor.py +0 -76
- paddlex/inference/models/base/basic_predictor.py +0 -122
- paddlex/inference/models/face_recognition.py +0 -21
- paddlex/inference/models/formula_recognition.py +0 -55
- paddlex/inference/models/general_recognition.py +0 -99
- paddlex/inference/models/image_classification.py +0 -101
- paddlex/inference/models/image_unwarping.py +0 -43
- paddlex/inference/models/instance_segmentation.py +0 -66
- paddlex/inference/models/multilabel_classification.py +0 -33
- paddlex/inference/models/object_detection.py +0 -129
- paddlex/inference/models/semantic_segmentation.py +0 -86
- paddlex/inference/models/table_recognition.py +0 -106
- paddlex/inference/models/text_detection.py +0 -105
- paddlex/inference/models/text_recognition.py +0 -78
- paddlex/inference/models/ts_ad.py +0 -68
- paddlex/inference/models/ts_cls.py +0 -57
- paddlex/inference/models/ts_fc.py +0 -73
- paddlex/inference/pipelines/attribute_recognition.py +0 -92
- paddlex/inference/pipelines/face_recognition.py +0 -49
- paddlex/inference/pipelines/formula_recognition.py +0 -102
- paddlex/inference/pipelines/layout_parsing/layout_parsing.py +0 -362
- paddlex/inference/pipelines/ocr.py +0 -80
- paddlex/inference/pipelines/pp_shitu_v2.py +0 -152
- paddlex/inference/pipelines/ppchatocrv3/__init__.py +0 -15
- paddlex/inference/pipelines/ppchatocrv3/ch_prompt.yaml +0 -14
- paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py +0 -717
- paddlex/inference/pipelines/ppchatocrv3/utils.py +0 -168
- paddlex/inference/pipelines/seal_recognition.py +0 -152
- paddlex/inference/pipelines/serving/__init__.py +0 -17
- paddlex/inference/pipelines/serving/_pipeline_apps/__init__.py +0 -205
- paddlex/inference/pipelines/serving/_pipeline_apps/anomaly_detection.py +0 -80
- paddlex/inference/pipelines/serving/_pipeline_apps/face_recognition.py +0 -317
- paddlex/inference/pipelines/serving/_pipeline_apps/formula_recognition.py +0 -119
- paddlex/inference/pipelines/serving/_pipeline_apps/image_classification.py +0 -101
- paddlex/inference/pipelines/serving/_pipeline_apps/instance_segmentation.py +0 -112
- paddlex/inference/pipelines/serving/_pipeline_apps/layout_parsing.py +0 -205
- paddlex/inference/pipelines/serving/_pipeline_apps/multi_label_image_classification.py +0 -90
- paddlex/inference/pipelines/serving/_pipeline_apps/object_detection.py +0 -90
- paddlex/inference/pipelines/serving/_pipeline_apps/ocr.py +0 -98
- paddlex/inference/pipelines/serving/_pipeline_apps/pedestrian_attribute_recognition.py +0 -102
- paddlex/inference/pipelines/serving/_pipeline_apps/pp_shitu_v2.py +0 -319
- paddlex/inference/pipelines/serving/_pipeline_apps/ppchatocrv3.py +0 -445
- paddlex/inference/pipelines/serving/_pipeline_apps/seal_recognition.py +0 -110
- paddlex/inference/pipelines/serving/_pipeline_apps/semantic_segmentation.py +0 -82
- paddlex/inference/pipelines/serving/_pipeline_apps/small_object_detection.py +0 -92
- paddlex/inference/pipelines/serving/_pipeline_apps/table_recognition.py +0 -110
- paddlex/inference/pipelines/serving/_pipeline_apps/ts_ad.py +0 -68
- paddlex/inference/pipelines/serving/_pipeline_apps/ts_cls.py +0 -68
- paddlex/inference/pipelines/serving/_pipeline_apps/ts_fc.py +0 -68
- paddlex/inference/pipelines/serving/_pipeline_apps/vehicle_attribute_recognition.py +0 -102
- paddlex/inference/pipelines/serving/app.py +0 -164
- paddlex/inference/pipelines/serving/models.py +0 -30
- paddlex/inference/pipelines/serving/server.py +0 -25
- paddlex/inference/pipelines/serving/storage.py +0 -161
- paddlex/inference/pipelines/serving/utils.py +0 -190
- paddlex/inference/pipelines/single_model_pipeline.py +0 -76
- paddlex/inference/pipelines/table_recognition/table_recognition.py +0 -193
- paddlex/inference/results/__init__.py +0 -31
- paddlex/inference/results/attribute_rec.py +0 -89
- paddlex/inference/results/base.py +0 -43
- paddlex/inference/results/chat_ocr.py +0 -158
- paddlex/inference/results/clas.py +0 -133
- paddlex/inference/results/det.py +0 -86
- paddlex/inference/results/face_rec.py +0 -34
- paddlex/inference/results/formula_rec.py +0 -363
- paddlex/inference/results/instance_seg.py +0 -152
- paddlex/inference/results/ocr.py +0 -157
- paddlex/inference/results/seal_rec.py +0 -50
- paddlex/inference/results/seg.py +0 -72
- paddlex/inference/results/shitu.py +0 -35
- paddlex/inference/results/table_rec.py +0 -109
- paddlex/inference/results/text_det.py +0 -33
- paddlex/inference/results/text_rec.py +0 -66
- paddlex/inference/results/ts.py +0 -37
- paddlex/inference/results/utils/__init__.py +0 -13
- paddlex/inference/results/utils/mixin.py +0 -204
- paddlex/inference/results/warp.py +0 -31
- paddlex/inference/utils/new_ir_blacklist.py +0 -22
- paddlex/inference/utils/process_hook.py +0 -54
- paddlex/pipelines/OCR.yaml +0 -8
- paddlex/pipelines/PP-ChatOCRv3-doc.yaml +0 -27
- paddlex/pipelines/PP-ShiTuV2.yaml +0 -13
- paddlex/pipelines/anomaly_detection.yaml +0 -7
- paddlex/pipelines/face_recognition.yaml +0 -13
- paddlex/pipelines/formula_recognition.yaml +0 -8
- paddlex/pipelines/image_classification.yaml +0 -7
- paddlex/pipelines/instance_segmentation.yaml +0 -7
- paddlex/pipelines/layout_parsing.yaml +0 -14
- paddlex/pipelines/multi_label_image_classification.yaml +0 -7
- paddlex/pipelines/object_detection.yaml +0 -7
- paddlex/pipelines/pedestrian_attribute_recognition.yaml +0 -7
- paddlex/pipelines/seal_recognition.yaml +0 -10
- paddlex/pipelines/semantic_segmentation.yaml +0 -7
- paddlex/pipelines/small_object_detection.yaml +0 -7
- paddlex/pipelines/table_recognition.yaml +0 -12
- paddlex/pipelines/ts_ad.yaml +0 -7
- paddlex/pipelines/ts_cls.yaml +0 -7
- paddlex/pipelines/ts_fc.yaml +0 -7
- paddlex/pipelines/vehicle_attribute_recognition.yaml +0 -7
- paddlex/repo_manager/requirements.txt +0 -18
- paddlex/utils/fonts/PingFang-SC-Regular.ttf +0 -0
- paddlex-3.0.0b2.dist-info/METADATA +0 -760
- paddlex-3.0.0b2.dist-info/RECORD +0 -646
- paddlex-3.0.0b2.dist-info/WHEEL +0 -5
- /paddlex/configs/{doc_text_orientation → modules/doc_text_orientation}/PP-LCNet_x1_0_doc_ori.yaml +0 -0
- /paddlex/configs/{face_detection → modules/face_detection}/BlazeFace-FPN-SSH.yaml +0 -0
- /paddlex/configs/{face_detection → modules/face_detection}/BlazeFace.yaml +0 -0
- /paddlex/configs/{face_detection → modules/face_detection}/PP-YOLOE_plus-S_face.yaml +0 -0
- /paddlex/configs/{face_detection → modules/face_detection}/PicoDet_LCNet_x2_5_face.yaml +0 -0
- /paddlex/configs/{human_detection → modules/human_detection}/PP-YOLOE-L_human.yaml +0 -0
- /paddlex/configs/{human_detection → modules/human_detection}/PP-YOLOE-S_human.yaml +0 -0
- /paddlex/configs/{anomaly_detection → modules/image_anomaly_detection}/STFPM.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/ConvNeXt_base_224.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/ConvNeXt_base_384.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/ConvNeXt_large_224.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/ConvNeXt_small.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/ConvNeXt_tiny.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/FasterNet-L.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/FasterNet-M.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/FasterNet-S.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/FasterNet-T0.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/FasterNet-T1.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/FasterNet-T2.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV1_x0_25.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV1_x0_5.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV1_x0_75.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV1_x1_0.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV2_x0_25.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV2_x0_5.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV2_x1_0.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV2_x1_5.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV2_x2_0.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV3_large_x0_35.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV3_large_x0_5.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV3_large_x0_75.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV3_large_x1_0.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV3_large_x1_25.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV3_small_x0_35.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV3_small_x0_5.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV3_small_x0_75.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV3_small_x1_0.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV3_small_x1_25.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV4_conv_large.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV4_conv_medium.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV4_conv_small.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV4_hybrid_large.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV4_hybrid_medium.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-HGNetV2-B0.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-HGNetV2-B1.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-HGNetV2-B2.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-HGNetV2-B3.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-HGNetV2-B4.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-HGNetV2-B5.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-HGNetV2-B6.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-HGNet_base.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-HGNet_small.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-HGNet_tiny.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNetV2_base.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNetV2_large.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNetV2_small.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNet_x0_25.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNet_x0_35.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNet_x0_5.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNet_x0_75.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNet_x1_0.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNet_x1_5.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNet_x2_0.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNet_x2_5.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/ResNet101.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/ResNet101_vd.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/ResNet152.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/ResNet152_vd.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/ResNet18.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/ResNet18_vd.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/ResNet200_vd.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/ResNet34.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/ResNet34_vd.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/ResNet50.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/ResNet50_vd.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/StarNet-S1.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/StarNet-S2.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/StarNet-S3.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/StarNet-S4.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/SwinTransformer_base_patch4_window12_384.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/SwinTransformer_base_patch4_window7_224.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/SwinTransformer_large_patch4_window12_384.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/SwinTransformer_large_patch4_window7_224.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/SwinTransformer_small_patch4_window7_224.yaml +0 -0
- /paddlex/configs/{image_classification → modules/image_classification}/SwinTransformer_tiny_patch4_window7_224.yaml +0 -0
- /paddlex/configs/{general_recognition → modules/image_feature}/PP-ShiTuV2_rec.yaml +0 -0
- /paddlex/configs/{general_recognition → modules/image_feature}/PP-ShiTuV2_rec_CLIP_vit_base.yaml +0 -0
- /paddlex/configs/{general_recognition → modules/image_feature}/PP-ShiTuV2_rec_CLIP_vit_large.yaml +0 -0
- /paddlex/configs/{multilabel_classification → modules/image_multilabel_classification}/CLIP_vit_base_patch16_448_ML.yaml +0 -0
- /paddlex/configs/{multilabel_classification → modules/image_multilabel_classification}/PP-HGNetV2-B0_ML.yaml +0 -0
- /paddlex/configs/{multilabel_classification → modules/image_multilabel_classification}/PP-HGNetV2-B4_ML.yaml +0 -0
- /paddlex/configs/{multilabel_classification → modules/image_multilabel_classification}/PP-HGNetV2-B6_ML.yaml +0 -0
- /paddlex/configs/{multilabel_classification → modules/image_multilabel_classification}/PP-LCNet_x1_0_ML.yaml +0 -0
- /paddlex/configs/{multilabel_classification → modules/image_multilabel_classification}/ResNet50_ML.yaml +0 -0
- /paddlex/configs/{image_unwarping → modules/image_unwarping}/UVDoc.yaml +0 -0
- /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/Cascade-MaskRCNN-ResNet50-FPN.yaml +0 -0
- /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN.yaml +0 -0
- /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/Mask-RT-DETR-H.yaml +0 -0
- /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/Mask-RT-DETR-L.yaml +0 -0
- /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/Mask-RT-DETR-M.yaml +0 -0
- /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/Mask-RT-DETR-S.yaml +0 -0
- /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/Mask-RT-DETR-X.yaml +0 -0
- /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/MaskRCNN-ResNeXt101-vd-FPN.yaml +0 -0
- /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/MaskRCNN-ResNet101-FPN.yaml +0 -0
- /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/MaskRCNN-ResNet101-vd-FPN.yaml +0 -0
- /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/MaskRCNN-ResNet50-FPN.yaml +0 -0
- /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/MaskRCNN-ResNet50-vd-FPN.yaml +0 -0
- /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/MaskRCNN-ResNet50.yaml +0 -0
- /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/PP-YOLOE_seg-S.yaml +0 -0
- /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/SOLOv2.yaml +0 -0
- /paddlex/configs/{structure_analysis → modules/layout_detection}/PicoDet-L_layout_17cls.yaml +0 -0
- /paddlex/configs/{structure_analysis → modules/layout_detection}/PicoDet-L_layout_3cls.yaml +0 -0
- /paddlex/configs/{structure_analysis → modules/layout_detection}/PicoDet-S_layout_17cls.yaml +0 -0
- /paddlex/configs/{structure_analysis → modules/layout_detection}/PicoDet-S_layout_3cls.yaml +0 -0
- /paddlex/configs/{structure_analysis → modules/layout_detection}/PicoDet_layout_1x.yaml +0 -0
- /paddlex/configs/{structure_analysis → modules/layout_detection}/PicoDet_layout_1x_table.yaml +0 -0
- /paddlex/configs/{structure_analysis → modules/layout_detection}/RT-DETR-H_layout_17cls.yaml +0 -0
- /paddlex/configs/{structure_analysis → modules/layout_detection}/RT-DETR-H_layout_3cls.yaml +0 -0
- /paddlex/configs/{mainbody_detection → modules/mainbody_detection}/PP-ShiTuV2_det.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/Cascade-FasterRCNN-ResNet50-FPN.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/Cascade-FasterRCNN-ResNet50-vd-SSLDv2-FPN.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/CenterNet-DLA-34.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/CenterNet-ResNet50.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/DETR-R50.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/FCOS-ResNet50.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/FasterRCNN-ResNeXt101-vd-FPN.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/FasterRCNN-ResNet101-FPN.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/FasterRCNN-ResNet101.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/FasterRCNN-ResNet34-FPN.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/FasterRCNN-ResNet50-FPN.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/FasterRCNN-ResNet50-vd-FPN.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/FasterRCNN-ResNet50-vd-SSLDv2-FPN.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/FasterRCNN-ResNet50.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/FasterRCNN-Swin-Tiny-FPN.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/PP-YOLOE_plus-L.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/PP-YOLOE_plus-M.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/PP-YOLOE_plus-S.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/PP-YOLOE_plus-X.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/PicoDet-L.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/PicoDet-M.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/PicoDet-S.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/PicoDet-XS.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/RT-DETR-H.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/RT-DETR-L.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/RT-DETR-R18.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/RT-DETR-R50.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/RT-DETR-X.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/YOLOX-L.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/YOLOX-M.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/YOLOX-N.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/YOLOX-S.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/YOLOX-T.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/YOLOv3-DarkNet53.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/YOLOv3-MobileNetV3.yaml +0 -0
- /paddlex/configs/{object_detection → modules/object_detection}/YOLOv3-ResNet50_vd_DCN.yaml +0 -0
- /paddlex/configs/{pedestrian_attribute → modules/pedestrian_attribute_recognition}/PP-LCNet_x1_0_pedestrian_attribute.yaml +0 -0
- /paddlex/configs/{text_detection_seal → modules/seal_text_detection}/PP-OCRv4_mobile_seal_det.yaml +0 -0
- /paddlex/configs/{text_detection_seal → modules/seal_text_detection}/PP-OCRv4_server_seal_det.yaml +0 -0
- /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/Deeplabv3-R101.yaml +0 -0
- /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/Deeplabv3-R50.yaml +0 -0
- /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/Deeplabv3_Plus-R101.yaml +0 -0
- /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/Deeplabv3_Plus-R50.yaml +0 -0
- /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/OCRNet_HRNet-W18.yaml +0 -0
- /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/OCRNet_HRNet-W48.yaml +0 -0
- /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/PP-LiteSeg-B.yaml +0 -0
- /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/PP-LiteSeg-T.yaml +0 -0
- /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/SegFormer-B0.yaml +0 -0
- /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/SegFormer-B1.yaml +0 -0
- /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/SegFormer-B2.yaml +0 -0
- /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/SegFormer-B3.yaml +0 -0
- /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/SegFormer-B4.yaml +0 -0
- /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/SegFormer-B5.yaml +0 -0
- /paddlex/configs/{small_object_detection → modules/small_object_detection}/PP-YOLOE_plus_SOD-L.yaml +0 -0
- /paddlex/configs/{small_object_detection → modules/small_object_detection}/PP-YOLOE_plus_SOD-S.yaml +0 -0
- /paddlex/configs/{small_object_detection → modules/small_object_detection}/PP-YOLOE_plus_SOD-largesize-L.yaml +0 -0
- /paddlex/configs/{table_recognition → modules/table_structure_recognition}/SLANet.yaml +0 -0
- /paddlex/configs/{table_recognition → modules/table_structure_recognition}/SLANet_plus.yaml +0 -0
- /paddlex/configs/{text_detection → modules/text_detection}/PP-OCRv4_mobile_det.yaml +0 -0
- /paddlex/configs/{text_detection → modules/text_detection}/PP-OCRv4_server_det.yaml +0 -0
- /paddlex/configs/{text_recognition → modules/text_recognition}/PP-OCRv4_mobile_rec.yaml +0 -0
- /paddlex/configs/{text_recognition → modules/text_recognition}/PP-OCRv4_server_rec.yaml +0 -0
- /paddlex/configs/{text_recognition → modules/text_recognition}/ch_RepSVTR_rec.yaml +0 -0
- /paddlex/configs/{text_recognition → modules/text_recognition}/ch_SVTRv2_rec.yaml +0 -0
- /paddlex/configs/{ts_anomaly_detection → modules/ts_anomaly_detection}/AutoEncoder_ad.yaml +0 -0
- /paddlex/configs/{ts_anomaly_detection → modules/ts_anomaly_detection}/DLinear_ad.yaml +0 -0
- /paddlex/configs/{ts_anomaly_detection → modules/ts_anomaly_detection}/Nonstationary_ad.yaml +0 -0
- /paddlex/configs/{ts_anomaly_detection → modules/ts_anomaly_detection}/PatchTST_ad.yaml +0 -0
- /paddlex/configs/{ts_anomaly_detection → modules/ts_anomaly_detection}/TimesNet_ad.yaml +0 -0
- /paddlex/configs/{ts_classification → modules/ts_classification}/TimesNet_cls.yaml +0 -0
- /paddlex/configs/{ts_forecast → modules/ts_forecast}/DLinear.yaml +0 -0
- /paddlex/configs/{ts_forecast → modules/ts_forecast}/NLinear.yaml +0 -0
- /paddlex/configs/{ts_forecast → modules/ts_forecast}/Nonstationary.yaml +0 -0
- /paddlex/configs/{ts_forecast → modules/ts_forecast}/PatchTST.yaml +0 -0
- /paddlex/configs/{ts_forecast → modules/ts_forecast}/RLinear.yaml +0 -0
- /paddlex/configs/{ts_forecast → modules/ts_forecast}/TiDE.yaml +0 -0
- /paddlex/configs/{ts_forecast → modules/ts_forecast}/TimesNet.yaml +0 -0
- /paddlex/configs/{vehicle_attribute → modules/vehicle_attribute_recognition}/PP-LCNet_x1_0_vehicle_attribute.yaml +0 -0
- /paddlex/configs/{vehicle_detection → modules/vehicle_detection}/PP-YOLOE-L_vehicle.yaml +0 -0
- /paddlex/configs/{vehicle_detection → modules/vehicle_detection}/PP-YOLOE-S_vehicle.yaml +0 -0
- {paddlex-3.0.0b2.dist-info → paddlex-3.0.0rc1.dist-info}/entry_points.txt +0 -0
- {paddlex-3.0.0b2.dist-info → paddlex-3.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {paddlex-3.0.0b2.dist-info → paddlex-3.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1933 @@
|
|
1
|
+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper)
|
15
|
+
import os
|
16
|
+
import zlib
|
17
|
+
from dataclasses import dataclass, field
|
18
|
+
from functools import lru_cache
|
19
|
+
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
20
|
+
|
21
|
+
import numpy as np
|
22
|
+
import paddle
|
23
|
+
|
24
|
+
from ....utils.deps import function_requires_deps, is_dep_available
|
25
|
+
from ...utils.benchmark import (
|
26
|
+
benchmark,
|
27
|
+
get_inference_operations,
|
28
|
+
set_inference_operations,
|
29
|
+
)
|
30
|
+
from ..common.tokenizer import GPTTokenizer
|
31
|
+
|
32
|
+
if is_dep_available("soundfile"):
|
33
|
+
import soundfile
|
34
|
+
if is_dep_available("tqdm"):
|
35
|
+
import tqdm
|
36
|
+
|
37
|
+
__all__ = [
|
38
|
+
"Whisper",
|
39
|
+
"Tokenizer",
|
40
|
+
]
|
41
|
+
|
42
|
+
|
43
|
+
def exact_div(x, y):
|
44
|
+
assert x % y == 0
|
45
|
+
return x // y
|
46
|
+
|
47
|
+
|
48
|
+
_MODELS = ["large"]
|
49
|
+
SAMPLE_RATE = 16000
|
50
|
+
N_FFT = 400
|
51
|
+
N_MELS = 80
|
52
|
+
HOP_LENGTH = 160
|
53
|
+
CHUNK_LENGTH = 30
|
54
|
+
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
|
55
|
+
N_FRAMES = exact_div(
|
56
|
+
N_SAMPLES, HOP_LENGTH
|
57
|
+
) # 3000: number of frames in a mel spectrogram input
|
58
|
+
|
59
|
+
|
60
|
+
@dataclass
|
61
|
+
class ModelDimensions:
|
62
|
+
n_mels: int
|
63
|
+
n_audio_ctx: int
|
64
|
+
n_audio_state: int
|
65
|
+
n_audio_head: int
|
66
|
+
n_audio_layer: int
|
67
|
+
n_vocab: int
|
68
|
+
n_text_ctx: int
|
69
|
+
n_text_state: int
|
70
|
+
n_text_head: int
|
71
|
+
n_text_layer: int
|
72
|
+
|
73
|
+
|
74
|
+
LANGUAGES = {
|
75
|
+
"en": "english",
|
76
|
+
"zh": "chinese",
|
77
|
+
"de": "german",
|
78
|
+
"es": "spanish",
|
79
|
+
"ru": "russian",
|
80
|
+
"ko": "korean",
|
81
|
+
"fr": "french",
|
82
|
+
"ja": "japanese",
|
83
|
+
"pt": "portuguese",
|
84
|
+
"tr": "turkish",
|
85
|
+
"pl": "polish",
|
86
|
+
"ca": "catalan",
|
87
|
+
"nl": "dutch",
|
88
|
+
"ar": "arabic",
|
89
|
+
"sv": "swedish",
|
90
|
+
"it": "italian",
|
91
|
+
"id": "indonesian",
|
92
|
+
"hi": "hindi",
|
93
|
+
"fi": "finnish",
|
94
|
+
"vi": "vietnamese",
|
95
|
+
"iw": "hebrew",
|
96
|
+
"uk": "ukrainian",
|
97
|
+
"el": "greek",
|
98
|
+
"ms": "malay",
|
99
|
+
"cs": "czech",
|
100
|
+
"ro": "romanian",
|
101
|
+
"da": "danish",
|
102
|
+
"hu": "hungarian",
|
103
|
+
"ta": "tamil",
|
104
|
+
"no": "norwegian",
|
105
|
+
"th": "thai",
|
106
|
+
"ur": "urdu",
|
107
|
+
"hr": "croatian",
|
108
|
+
"bg": "bulgarian",
|
109
|
+
"lt": "lithuanian",
|
110
|
+
"la": "latin",
|
111
|
+
"mi": "maori",
|
112
|
+
"ml": "malayalam",
|
113
|
+
"cy": "welsh",
|
114
|
+
"sk": "slovak",
|
115
|
+
"te": "telugu",
|
116
|
+
"fa": "persian",
|
117
|
+
"lv": "latvian",
|
118
|
+
"bn": "bengali",
|
119
|
+
"sr": "serbian",
|
120
|
+
"az": "azerbaijani",
|
121
|
+
"sl": "slovenian",
|
122
|
+
"kn": "kannada",
|
123
|
+
"et": "estonian",
|
124
|
+
"mk": "macedonian",
|
125
|
+
"br": "breton",
|
126
|
+
"eu": "basque",
|
127
|
+
"is": "icelandic",
|
128
|
+
"hy": "armenian",
|
129
|
+
"ne": "nepali",
|
130
|
+
"mn": "mongolian",
|
131
|
+
"bs": "bosnian",
|
132
|
+
"kk": "kazakh",
|
133
|
+
"sq": "albanian",
|
134
|
+
"sw": "swahili",
|
135
|
+
"gl": "galician",
|
136
|
+
"mr": "marathi",
|
137
|
+
"pa": "punjabi",
|
138
|
+
"si": "sinhala",
|
139
|
+
"km": "khmer",
|
140
|
+
"sn": "shona",
|
141
|
+
"yo": "yoruba",
|
142
|
+
"so": "somali",
|
143
|
+
"af": "afrikaans",
|
144
|
+
"oc": "occitan",
|
145
|
+
"ka": "georgian",
|
146
|
+
"be": "belarusian",
|
147
|
+
"tg": "tajik",
|
148
|
+
"sd": "sindhi",
|
149
|
+
"gu": "gujarati",
|
150
|
+
"am": "amharic",
|
151
|
+
"yi": "yiddish",
|
152
|
+
"lo": "lao",
|
153
|
+
"uz": "uzbek",
|
154
|
+
"fo": "faroese",
|
155
|
+
"ht": "haitian creole",
|
156
|
+
"ps": "pashto",
|
157
|
+
"tk": "turkmen",
|
158
|
+
"nn": "nynorsk",
|
159
|
+
"mt": "maltese",
|
160
|
+
"sa": "sanskrit",
|
161
|
+
"lb": "luxembourgish",
|
162
|
+
"my": "myanmar",
|
163
|
+
"bo": "tibetan",
|
164
|
+
"tl": "tagalog",
|
165
|
+
"mg": "malagasy",
|
166
|
+
"as": "assamese",
|
167
|
+
"tt": "tatar",
|
168
|
+
"haw": "hawaiian",
|
169
|
+
"ln": "lingala",
|
170
|
+
"ha": "hausa",
|
171
|
+
"ba": "bashkir",
|
172
|
+
"jw": "javanese",
|
173
|
+
"su": "sundanese",
|
174
|
+
}
|
175
|
+
|
176
|
+
# language code lookup by name, with a few language aliases
|
177
|
+
TO_LANGUAGE_CODE = {
|
178
|
+
**{language: code for code, language in LANGUAGES.items()},
|
179
|
+
"burmese": "my",
|
180
|
+
"valencian": "ca",
|
181
|
+
"flemish": "nl",
|
182
|
+
"haitian": "ht",
|
183
|
+
"letzeburgesch": "lb",
|
184
|
+
"pushto": "ps",
|
185
|
+
"panjabi": "pa",
|
186
|
+
"moldavian": "ro",
|
187
|
+
"moldovan": "ro",
|
188
|
+
"sinhalese": "si",
|
189
|
+
"castilian": "es",
|
190
|
+
}
|
191
|
+
|
192
|
+
|
193
|
+
def compression_ratio(text) -> float:
|
194
|
+
return len(text) / len(zlib.compress(text.encode("utf-8")))
|
195
|
+
|
196
|
+
|
197
|
+
def format_timestamp(
|
198
|
+
seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
|
199
|
+
):
|
200
|
+
assert seconds >= 0, "non-negative timestamp expected"
|
201
|
+
milliseconds = round(seconds * 1000.0)
|
202
|
+
|
203
|
+
hours = milliseconds // 3_600_000
|
204
|
+
milliseconds -= hours * 3_600_000
|
205
|
+
|
206
|
+
minutes = milliseconds // 60_000
|
207
|
+
milliseconds -= minutes * 60_000
|
208
|
+
|
209
|
+
seconds = milliseconds // 1_000
|
210
|
+
milliseconds -= seconds * 1_000
|
211
|
+
|
212
|
+
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
|
213
|
+
return (
|
214
|
+
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
215
|
+
)
|
216
|
+
|
217
|
+
|
218
|
+
@dataclass(frozen=True)
|
219
|
+
class Tokenizer:
|
220
|
+
"""A thin wrapper around `GPTTokenizer` providing quick access to special tokens"""
|
221
|
+
|
222
|
+
tokenizer: "GPTTokenizer"
|
223
|
+
language: Optional[str]
|
224
|
+
sot_sequence: Tuple[int]
|
225
|
+
|
226
|
+
def encode(self, text, **kwargs):
|
227
|
+
return self.tokenizer.encode(text, **kwargs)
|
228
|
+
|
229
|
+
def decode(
|
230
|
+
self, token_ids: Union[int, List[int], np.ndarray, paddle.Tensor], **kwargs
|
231
|
+
):
|
232
|
+
if len(token_ids) > 1:
|
233
|
+
ids_list = []
|
234
|
+
for ids in token_ids:
|
235
|
+
if paddle.is_tensor(ids):
|
236
|
+
ids = ids.item()
|
237
|
+
if ids < len(self.tokenizer):
|
238
|
+
ids_list.append(ids)
|
239
|
+
token_ids = ids_list
|
240
|
+
elif len(token_ids) == 1:
|
241
|
+
token_ids = token_ids[0]
|
242
|
+
else:
|
243
|
+
raise ValueError(f"token_ids {token_ids} load error.")
|
244
|
+
|
245
|
+
return self.tokenizer.decode(token_ids, **kwargs)
|
246
|
+
|
247
|
+
def decode_with_timestamps(self, tokens) -> str:
|
248
|
+
"""
|
249
|
+
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
|
250
|
+
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
251
|
+
"""
|
252
|
+
outputs = [[]]
|
253
|
+
for token in tokens:
|
254
|
+
if token >= self.timestamp_begin:
|
255
|
+
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
|
256
|
+
outputs.append(timestamp)
|
257
|
+
outputs.append([])
|
258
|
+
else:
|
259
|
+
outputs[-1].append(token)
|
260
|
+
outputs = [
|
261
|
+
s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs
|
262
|
+
]
|
263
|
+
return "".join(outputs)
|
264
|
+
|
265
|
+
@property
|
266
|
+
@lru_cache()
|
267
|
+
def eot(self) -> int:
|
268
|
+
return self.tokenizer.eos_token_id
|
269
|
+
|
270
|
+
@property
|
271
|
+
@lru_cache()
|
272
|
+
def sot(self) -> int:
|
273
|
+
return self._get_single_token_id("<|startoftranscript|>")
|
274
|
+
|
275
|
+
@property
|
276
|
+
@lru_cache()
|
277
|
+
def sot_lm(self) -> int:
|
278
|
+
return self._get_single_token_id("<|startoflm|>")
|
279
|
+
|
280
|
+
@property
|
281
|
+
@lru_cache()
|
282
|
+
def sot_prev(self) -> int:
|
283
|
+
return self._get_single_token_id("<|startofprev|>")
|
284
|
+
|
285
|
+
@property
|
286
|
+
@lru_cache()
|
287
|
+
def no_speech(self) -> int:
|
288
|
+
return self._get_single_token_id("<|nospeech|>")
|
289
|
+
|
290
|
+
@property
|
291
|
+
@lru_cache()
|
292
|
+
def no_timestamps(self) -> int:
|
293
|
+
return self._get_single_token_id("<|notimestamps|>")
|
294
|
+
|
295
|
+
@property
|
296
|
+
@lru_cache()
|
297
|
+
def timestamp_begin(self) -> int:
|
298
|
+
return self.tokenizer.all_special_ids[-1] + 1
|
299
|
+
|
300
|
+
@property
@lru_cache()
def language_token(self) -> int:
    """
    Id of the ``<|{language}|>`` special token for this tokenizer's language.

    Raises
    ------
    ValueError
        If no language is configured (``self.language is None``).
    KeyError
        If ``<|{language}|>`` is not among the additional special tokens.
    """
    if self.language is None:
        raise ValueError("This tokenizer does not have language token configured")

    id_by_token = dict(
        zip(
            self.tokenizer.additional_special_tokens,
            self.tokenizer.additional_special_tokens_ids,
        )
    )
    wanted = f"<|{self.language}|>"
    if wanted not in id_by_token:
        raise KeyError(f"Language {self.language} not found in tokenizer.")
    return id_by_token[wanted]
|
318
|
+
|
319
|
+
@property
@lru_cache()
def all_language_tokens(self) -> Tuple[int]:
    """Ids of every additional special token of the form ``<|code|>`` whose
    code is a key of ``LANGUAGES``, in the tokenizer's registration order."""
    pairs = zip(
        self.tokenizer.additional_special_tokens,
        self.tokenizer.additional_special_tokens_ids,
    )
    return tuple(
        token_id for text, token_id in pairs if text.strip("<|>") in LANGUAGES
    )
|
330
|
+
|
331
|
+
@property
@lru_cache()
def all_language_codes(self) -> Tuple[str]:
    """Language codes (e.g. "en") aligned index-for-index with
    ``all_language_tokens``, obtained by decoding each id and stripping the
    ``<|`` / ``|>`` delimiters."""
    decoded = (self.decode([token_id]) for token_id in self.all_language_tokens)
    return tuple(text.strip("<|>") for text in decoded)
|
335
|
+
|
336
|
+
@property
@lru_cache()
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
    """``sot_sequence`` followed by the ``<|notimestamps|>`` id, as a tuple."""
    return (*self.sot_sequence, self.no_timestamps)
|
340
|
+
|
341
|
+
@property
@lru_cache()
def non_speech_tokens(self) -> Tuple[int]:
    """
    Sorted ids of tokens to suppress so that the model does not sample
    speaker tags or other non-speech annotations, such as

    - ♪♪♪
    - ( SPEAKING FOREIGN LANGUAGE )
    - [DAVID] Hey there,

    Ordinary punctuation (commas, periods, question marks, exclamation
    points, etc.) is left alone.
    """
    encode = self.tokenizer.encode

    symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
    symbols += (
        "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
    )

    # These symbols may be one token or several, depending on the tokenizer.
    # When they are several, suppressing only the first is safe: they all lie
    # in U+2640..U+267F (miscellaneous symbols), whose 3-byte UTF-8 encodings
    # share the same first two bytes.
    miscellaneous = set("♩♪♫♬♭♮♯")
    assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

    # Hyphens and single quotes stay allowed between words; only their
    # word-initial (space-prefixed) forms are suppressed.
    suppressed = {encode(" -").input_ids[0], encode(" '").input_ids[0]}

    for symbol in symbols + list(miscellaneous):
        first_token_suffices = symbol in miscellaneous
        for spelling in (symbol, " " + symbol):
            ids = encode(spelling).input_ids
            if first_token_suffices or len(ids) == 1:
                suppressed.add(ids[0])

    return tuple(sorted(suppressed))
|
378
|
+
|
379
|
+
def _get_single_token_id(self, text) -> int:
    """Encode ``text`` and return its id; asserts it encodes to exactly one token."""
    ids = self.tokenizer.encode(text).input_ids
    assert len(ids) == 1, f"{text} is not encoded as a single token"
    return ids[0]
|
383
|
+
|
384
|
+
|
385
|
+
@lru_cache(maxsize=None)
def build_tokenizer(resource_path: str, name: str = "gpt2"):
    """
    Load the tokenizer stored at ``<resource_path>/assets/<name>`` and register
    the Whisper special tokens on it.

    The result is cached per ``(resource_path, name)``. As a side effect this
    sets the ``TOKENIZERS_PARALLELISM`` environment variable to ``"false"``.

    The specials are added in a fixed order: start-of-transcript, one token per
    ``LANGUAGES`` key, then translate/transcribe, start-of-LM, start-of-prev,
    no-speech and no-timestamps. ``get_tokenizer`` indexes into
    ``all_special_ids`` assuming this order.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    tokenizer = GPTTokenizer.from_pretrained(
        os.path.join(resource_path, "assets", name)
    )

    language_markers = [f"<|{code}|>" for code in LANGUAGES.keys()]
    specials = [
        "<|startoftranscript|>",
        *language_markers,
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
    ]
    tokenizer.add_special_tokens({"additional_special_tokens": specials})
    return tokenizer
|
404
|
+
|
405
|
+
|
406
|
+
@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    resource_path: str,
    *,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
    language: Optional[str] = None,
) -> Tokenizer:
    """
    Build (and cache) a ``Tokenizer`` together with its start-of-transcript
    token prefix.

    Parameters
    ----------
    multilingual : bool
        Use the "multilingual" vocabulary. Otherwise the "gpt2" vocabulary is
        used, and ``task`` and ``language`` are forced to None.
    resource_path : str
        Directory whose ``assets/`` subfolder holds the tokenizer files.
    task : str, optional
        "transcribe" or "translate". Defaults to "transcribe" when multilingual.
    language : str, optional
        Case-insensitive key of ``LANGUAGES`` or of ``TO_LANGUAGE_CODE``.
        Defaults to "en" when multilingual.

    Raises
    ------
    ValueError
        If ``language`` is found in neither mapping.
    """
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language not in TO_LANGUAGE_CODE:
                raise ValueError(f"Unsupported language: {language}")
            language = TO_LANGUAGE_CODE[language]

    if multilingual:
        vocab_name = "multilingual"
        task = task or "transcribe"
        language = language or "en"
    else:
        vocab_name, task, language = "gpt2", None, None

    tokenizer = build_tokenizer(resource_path=resource_path, name=vocab_name)

    # NOTE(review): these positional lookups assume all_special_ids holds one
    # leading token followed by the specials in build_tokenizer's registration
    # order; confirm against the tokenizer implementation.
    special_ids: List[int] = tokenizer.all_special_ids
    sot_id: int = special_ids[1]
    translate_id: int = special_ids[-6]
    transcribe_id: int = special_ids[-5]

    prefix = [sot_id]
    if language is not None:
        # Language tokens are registered consecutively right after sot.
        prefix.append(sot_id + 1 + tuple(LANGUAGES).index(language))
    if task is not None:
        prefix.append(transcribe_id if task == "transcribe" else translate_id)

    return Tokenizer(
        tokenizer=tokenizer, language=language, sot_sequence=tuple(prefix)
    )
|
447
|
+
|
448
|
+
|
449
|
+
class MultiHeadAttention(paddle.nn.Layer):
    """Multi-head attention, used for self-attention (``xa is None``) or for
    cross-attention over ``xa``, with optional key/value caching."""

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        # The key projection has no bias; query, value and output do.
        self.query = paddle.nn.Linear(n_state, n_state, bias_attr=True)
        self.key = paddle.nn.Linear(n_state, n_state, bias_attr=False)
        self.value = paddle.nn.Linear(n_state, n_state, bias_attr=True)
        self.out = paddle.nn.Linear(n_state, n_state, bias_attr=True)

    def forward(
        self,
        x: paddle.Tensor,
        xa: Optional[paddle.Tensor] = None,
        mask: Optional[paddle.Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """
        Attend from ``x`` to itself, or to ``xa`` when given.

        ``kv_cache`` is keyed by the ``self.key`` / ``self.value`` layers. For
        cross-attention (``xa`` given), once an entry for ``self.key`` exists,
        the cached keys and values are reused instead of being re-projected.
        """
        q = self.query(x)

        reuse_cached = (
            kv_cache is not None and xa is not None and self.key in kv_cache
        )
        if reuse_cached:
            k, v = kv_cache[self.key], kv_cache[self.value]
        else:
            # Hooks installed on key/value (if any) may extend these
            # projections with cached tensors.
            source = x if xa is None else xa
            k = self.key(source)
            v = self.value(source)

        return self.out(self.qkv_attention(q, k, v, mask))

    def qkv_attention(
        self,
        q: paddle.Tensor,
        k: paddle.Tensor,
        v: paddle.Tensor,
        mask: Optional[paddle.Tensor] = None,
    ):
        """
        Scaled dot-product attention over ``n_head`` heads.

        ``q`` is unpacked as ``(batch, n_ctx, n_state)``. Both ``q`` and ``k``
        are scaled by ``head_dim ** -0.25``, so their product carries the usual
        ``head_dim ** -0.5`` factor. ``mask`` is sliced to ``[:n_ctx, :n_ctx]``
        and added to the scores.
        """
        _, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25

        def split_heads(t, perm):
            # (batch, len, n_state) -> (batch, len, n_head, head_dim), then permute.
            return paddle.transpose(t.reshape([*t.shape[:2], self.n_head, -1]), perm)

        q = split_heads(q, (0, 2, 1, 3)) * scale  # (batch, head, q_len, dim)
        k = split_heads(k, (0, 2, 3, 1)) * scale  # (batch, head, dim, k_len)
        v = split_heads(v, (0, 2, 1, 3))  # (batch, head, k_len, dim)

        scores = q @ k
        if mask is not None:
            scores = scores + mask[:n_ctx, :n_ctx]

        weights = paddle.nn.functional.softmax(scores.astype(q.dtype), axis=-1)
        merged = paddle.transpose((weights @ v), (0, 2, 1, 3))
        return merged.flatten(start_axis=2)
|
505
|
+
|
506
|
+
|
507
|
+
class ResidualAttentionBlock(paddle.nn.Layer):
    """Pre-LayerNorm transformer block.

    The block applies self-attention, then (optionally) cross-attention, then
    a GELU MLP with a 4x hidden width. Each sub-layer is added back to its
    input as a residual.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = paddle.nn.LayerNorm(n_state)

        if cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head)
            self.cross_attn_ln = paddle.nn.LayerNorm(n_state)
        else:
            self.cross_attn = None
            self.cross_attn_ln = None

        hidden = 4 * n_state
        self.mlp = paddle.nn.Sequential(
            paddle.nn.Linear(n_state, hidden, bias_attr=True),
            paddle.nn.GELU(),
            paddle.nn.Linear(hidden, n_state, bias_attr=True),
        )
        self.mlp_ln = paddle.nn.LayerNorm(n_state)

    def forward(
        self,
        x: paddle.Tensor,
        xa: Optional[paddle.Tensor] = None,
        mask: Optional[paddle.Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """Run the block. ``mask`` applies to self-attention only, and ``xa``
        is attended to only when the block was built with cross-attention."""
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
        if self.cross_attn is not None:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
        return x + self.mlp(self.mlp_ln(x))
|
539
|
+
|
540
|
+
|
541
|
+
def sinusoids(length, channels, max_timescale=10000):
    """
    Build a ``(length, channels)`` float32 sinusoidal positional-embedding table.

    The first ``channels // 2`` columns are sines and the rest are cosines,
    over geometrically spaced inverse timescales from 1 down to
    ``1 / max_timescale``. ``channels`` must be even.
    """
    assert channels % 2 == 0
    half = channels // 2
    step = np.log(max_timescale) / (half - 1)
    inv_timescales = paddle.exp(-step * paddle.arange(half, dtype=paddle.float32))

    positions = paddle.arange(length, dtype=paddle.float32)
    # Outer product: (length, 1) * (1, half) -> (length, half)
    scaled_time = positions.unsqueeze(1) * inv_timescales.unsqueeze(0)

    table = paddle.concat([paddle.sin(scaled_time), paddle.cos(scaled_time)], axis=1)
    return paddle.to_tensor(table)
|
555
|
+
|
556
|
+
|
557
|
+
class AudioEncoder(paddle.nn.Layer):
    """Audio encoder.

    Two 1-D convolutions (the second with stride 2) are followed by a stack of
    self-attention blocks over the resulting frame sequence.
    """

    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias_attr=True)
        self.conv1 = paddle.nn.Conv1D(n_mels, n_state, stride=1, **conv_kwargs)
        self.conv2 = paddle.nn.Conv1D(n_state, n_state, stride=2, **conv_kwargs)
        # Sinusoidal position table, registered as a buffer (not a parameter).
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = paddle.nn.LayerNorm(n_state)

    def forward(self, x: paddle.Tensor):
        """
        x : paddle.Tensor, shape = (batch_size, n_mels, n_frames)
            The mel spectrogram of the audio. After the stride-2 convolution,
            the frame axis must equal ``n_ctx`` (this is asserted).

        Returns the encoded features with shape ``(batch_size, n_ctx, n_state)``.
        """
        gelu = paddle.nn.functional.gelu
        hidden = gelu(self.conv2(gelu(self.conv1(x))))
        hidden = paddle.transpose(hidden, (0, 2, 1))  # (batch, time, channels)

        assert (
            hidden.shape[1:] == self.positional_embedding.shape
        ), "incorrect audio shape"
        hidden = hidden + self.positional_embedding

        for block in self.blocks:
            hidden = block(hidden)
        return self.ln_post(hidden)
|
592
|
+
|
593
|
+
|
594
|
+
class TextDecoder(paddle.nn.Layer):
    """Text decoder.

    Token embeddings plus a learned positional embedding feed a stack of
    causal self-attention blocks that also cross-attend to the encoded audio.
    The output is projected onto the vocabulary with the transposed
    token-embedding matrix.
    """

    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()

        self.token_embedding = paddle.nn.Embedding(n_vocab, n_state)
        self.positional_embedding = paddle.create_parameter(
            shape=[n_ctx, n_state], dtype="float32"
        )

        self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = paddle.nn.LayerNorm(n_state)

        # Causal mask over token positions. Attention slices it as the square
        # mask[:q_len, :q_len], so it must be (n_ctx, n_ctx). It was previously
        # (n_ctx, n_state), which only worked while n_state >= q_len.
        # It is not persisted, so checkpoints are unaffected.
        mask = paddle.full(shape=[n_ctx, n_ctx], fill_value=-np.inf, dtype="float32")
        mask = paddle.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistable=False)

    def forward(
        self, x: paddle.Tensor, xa: paddle.Tensor, kv_cache: Optional[dict] = None
    ):
        """
        x : paddle.Tensor (integer ids), shape = (batch_size, <= n_ctx)
            the text tokens
        xa : paddle.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        kv_cache : dict, optional
            Cached key/value tensors. When non-empty, the length of axis 1 of
            its first entry is used as the position offset of ``x``
            (presumably the number of tokens already decoded; confirm against
            the cache hooks).

        Returns logits of shape ``(batch_size, x.shape[-1], n_vocab)``.
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        # Weight tying: reuse the embedding matrix as the output projection.
        logits = x @ paddle.transpose(self.token_embedding.weight, (1, 0))

        return logits
|
640
|
+
|
641
|
+
|
642
|
+
@dataclass(frozen=True)
class DecodingOptions:
    """Immutable settings for a single decoding pass."""

    # Whether to perform X->X "transcribe" or X->English "translate".
    task: str = "transcribe"
    # Language that the audio is in; uses the detected language if None.
    language: Optional[str] = None

    # --- sampling-related options ---
    temperature: float = 0.0
    # Maximum number of tokens to sample.
    sample_len: Optional[int] = None
    # Number of independent samples to collect, when t > 0.
    best_of: Optional[int] = None
    # Number of beams in beam search, when t == 0.
    beam_size: Optional[int] = None
    # Patience in beam search (https://arxiv.org/abs/2204.05424).
    patience: Optional[float] = None

    # --- ranking generations (either beams or best-of-N samples) ---
    # "alpha" in Google NMT; None defaults to length normalization.
    length_penalty: Optional[float] = None

    # --- prompt, prefix, and token suppression ---
    # Text or tokens for the previous context.
    prompt: Optional[Union[str, List[int]]] = None
    # Text or tokens to prefix the current context.
    prefix: Optional[Union[str, List[int]]] = None
    # Suppress blank outputs.
    suppress_blank: bool = True
    # List of token ids (or comma-separated token ids) to suppress; "-1"
    # selects the symbol set defined by `tokenizer.non_speech_tokens`.
    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"

    # --- timestamp sampling options ---
    # Use <|notimestamps|> to sample text tokens only.
    without_timestamps: bool = False
    # The initial timestamp cannot be later than this.
    max_initial_timestamp: Optional[float] = 1.0

    # --- implementation details ---
    # Use fp16 for most of the calculation.
    fp16: bool = False
|
687
|
+
|
688
|
+
|
689
|
+
@dataclass(frozen=True)
class DecodingResult:
    """Immutable outcome of decoding one audio segment."""

    audio_features: paddle.Tensor
    language: str
    # Probability per language code; None when not provided.
    language_probs: Optional[Dict[str, float]] = None
    tokens: List[int] = field(default_factory=list)
    text: str = ""
    # Quality statistics below default to NaN.
    avg_logprob: float = np.nan
    no_speech_prob: float = np.nan
    temperature: float = np.nan
    compression_ratio: float = np.nan
|
700
|
+
|
701
|
+
|
702
|
+
class Inference:
    """Interface for a decoder back end used during token-by-token decoding."""

    def logits(
        self, tokens: paddle.Tensor, audio_features: paddle.Tensor
    ) -> paddle.Tensor:
        """Run the decoder forward on ``tokens`` given ``audio_features`` and
        return per-token logits. Subclasses must override this."""
        raise NotImplementedError

    def rearrange_kv_cache(self, source_indices) -> None:
        """Reorder the key/value cache to follow the selected beams.
        Subclasses must override this."""
        raise NotImplementedError

    def cleanup_caching(self) -> None:
        """Release caches or hooks once decoding has finished (no-op here)."""
|
715
|
+
|
716
|
+
|
717
|
+
class WhisperInference(Inference):
    """``Inference`` back end that drives ``model.decoder`` with a key/value
    cache populated through hooks obtained from ``model.install_kv_cache_hooks``."""

    def __init__(self, model: "Whisper", initial_token_length: int):
        self.model: "Whisper" = model
        # Length of the initial prompt; longer inputs mean we are past step 0.
        self.initial_token_length = initial_token_length
        self.kv_cache = {}
        self.hooks = []

    def logits(
        self, tokens: paddle.Tensor, audio_features: paddle.Tensor
    ) -> paddle.Tensor:
        """Return decoder logits for ``tokens``.

        Hooks are installed lazily while the cache is empty. After the first
        pass, only the newest token column is fed to the decoder.
        """
        if not self.kv_cache:
            self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()

        past_first_pass = tokens.shape[-1] > self.initial_token_length
        step_tokens = tokens[:, -1:] if past_first_pass else tokens

        return self.model.decoder(step_tokens, audio_features, kv_cache=self.kv_cache)

    def cleanup_caching(self):
        """Remove every installed hook and drop the cache."""
        for handle in self.hooks:
            handle.remove()

        self.kv_cache = {}
        self.hooks = []

    def rearrange_kv_cache(self, source_indices):
        """Keep only the cached rows selected by ``source_indices``.

        The dict is updated in place rather than rebuilt, since the installed
        hooks presumably hold a reference to this same dict.
        """
        for module in list(self.kv_cache):
            self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
|
747
|
+
|
748
|
+
|
749
|
+
@paddle.no_grad()
def detect_language(
    model: "Whisper",
    mel: paddle.Tensor,
    resource_path: str,
    tokenizer: Tokenizer = None,
) -> Tuple[paddle.Tensor, List[dict]]:
    """
    Detect the spoken language of the audio.

    Runs one decoder step from ``<|startoftranscript|>`` and restricts the
    logits to the language tokens. This is performed outside the main decode
    loop in order not to interfere with kv-caching.

    Parameters
    ----------
    model : Whisper
    mel : paddle.Tensor
        A mel spectrogram, or audio features that are already encoded (last two
        dims ``(n_audio_ctx, n_audio_state)``), in which case the encoder is
        skipped. A 2-D input is treated as a single example.
    resource_path : str
        Passed to ``get_tokenizer`` when ``tokenizer`` is None.
    tokenizer : Tokenizer, optional

    Returns
    -------
    language_tokens : Tensor, shape = (batch_size,)
        Ids of the most probable language tokens, i.e. the token that follows
        startoftranscript. Reduced to its first element for a 2-D input.
    language_probs : List[Dict[str, float]], length = batch_size
        Probability distribution over language codes per example; a single
        dict for a 2-D input.

    Raises
    ------
    ValueError
        If the tokenizer has no language token in its ``sot_sequence``.
    """
    if tokenizer is None:
        tokenizer = get_tokenizer(model.is_multilingual, resource_path=resource_path)
    if (
        tokenizer.language is None
        or tokenizer.language_token not in tokenizer.sot_sequence
    ):
        raise ValueError(
            "This model doesn't have language tokens so it can't perform lang id"
        )

    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    # Skip the encoder forward pass if already-encoded audio features were
    # given. Paddle's Tensor.shape is a list, and a list never equals a tuple,
    # so convert before comparing (otherwise the encoder always ran).
    encoded_shape = (model.dims.n_audio_ctx, model.dims.n_audio_state)
    if tuple(mel.shape[-2:]) != encoded_shape:
        mel = model.encoder(mel)

    # forward pass using a single token, startoftranscript
    batch_size = mel.shape[0]
    x = paddle.to_tensor([[tokenizer.sot]] * batch_size)  # [batch_size, 1]
    logits = model.logits(x, mel)[:, 0]

    # collect detected languages; suppress all non-language tokens
    mask = paddle.ones([logits.shape[-1]], dtype="bool")
    mask[list(tokenizer.all_language_tokens)] = False
    logits[:, mask] = -np.inf
    language_tokens = paddle.argmax(logits, axis=-1)
    language_token_probs = paddle.nn.functional.softmax(logits, axis=-1)
    language_probs = [
        {
            c: language_token_probs[i, j].tolist()
            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
        }
        for i in range(batch_size)
    ]

    if single:
        language_tokens = language_tokens[0]
        language_probs = language_probs[0]

    return language_tokens, language_probs
|
809
|
+
|
810
|
+
|
811
|
+
@function_requires_deps("tqdm")
def transcribe(
    model: "Whisper",
    mel: paddle.Tensor,
    resource_path: str,
    *,
    verbose: Optional[bool] = None,
    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = True,
    **decode_options,
):
    """
    Transcribe an audio file using Whisper.

    Parameters
    ----------
    model: Whisper
        The Whisper model instance
    mel: paddle.Tensor
        The audio feature. It is sliced as ``mel[:, seek:]`` with frames on the
        last axis, so presumably of shape (n_mels, n_frames).
    resource_path: str
        Directory holding the tokenizer assets (passed to ``get_tokenizer``).
    verbose: bool
        Whether to display the text being decoded to the console. If True,
        displays all the details; if False, displays minimal details (a
        progress bar); if None, does not display anything.
    temperature: Union[float, Tuple[float, ...]]
        Temperature for sampling. It can be a tuple of temperatures, which will
        be used successively upon failures according to either
        `compression_ratio_threshold` or `logprob_threshold`.
    compression_ratio_threshold: float
        If the gzip compression ratio is above this value, treat as failed
    logprob_threshold: float
        If the average log probability over sampled tokens is below this value, treat as failed
    no_speech_threshold: float
        If the no_speech probability is higher than this value AND the average
        log probability over sampled tokens is not above `logprob_threshold`,
        consider the segment as silent
    condition_on_previous_text: bool
        If True, the previous output of the model is provided as a prompt for
        the next window; disabling may make the text inconsistent across
        windows, but the model becomes less prone to getting stuck in a
        failure loop, such as repetition looping or timestamps going out of sync.
    decode_options: dict
        Keyword arguments to construct `DecodingOptions` instances. An extra
        ``initial_prompt`` key (text) is popped and used to seed the prompt.

    Returns
    -------
    A dictionary containing the resulting text ("text"), segment-level details
    ("segments"), and the spoken language ("language"), which is detected when
    `decode_options["language"]` is None (or the string "None").
    """
    dtype = np.float32  # paddle only support float32

    if dtype == np.float32:
        decode_options["fp16"] = False

    # Resolve the language: English for English-only models, otherwise detect
    # it from the first N_FRAMES of audio.
    if (
        decode_options.get("language") == "None"
        or decode_options.get("language", None) is None
    ):
        if not model.is_multilingual:
            decode_options["language"] = "en"
        else:
            if verbose:
                print(
                    "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
                )
            segment = pad_or_trim(mel, N_FRAMES)
            _, probs = model.detect_language(segment, resource_path)
            decode_options["language"] = max(probs, key=probs.get)
            if verbose is not None:
                print(
                    f"Detected language: {LANGUAGES[decode_options['language']].title()}"
                )

    language = decode_options["language"]
    task = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(
        model.is_multilingual,
        resource_path=resource_path,
        language=language,
        task=task,
    )

    def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult:
        # Try each temperature in turn; stop at the first result that is
        # neither too repetitive nor too improbable (else keep the last one).
        temperatures = (
            [temperature] if isinstance(temperature, (int, float)) else temperature
        )
        decode_result = None

        for t in temperatures:
            kwargs = {**decode_options}
            if t > 0:
                # disable beam_size and patience when t > 0
                kwargs.pop("beam_size", None)
                kwargs.pop("patience", None)
            else:
                # disable best_of when t == 0
                kwargs.pop("best_of", None)

            options = DecodingOptions(**kwargs, temperature=t)
            decode_result = model.decode(segment, options, resource_path)

            needs_fallback = False
            if (
                compression_ratio_threshold is not None
                and decode_result.compression_ratio > compression_ratio_threshold
            ):
                needs_fallback = True  # too repetitive
            if (
                logprob_threshold is not None
                and decode_result.avg_logprob < logprob_threshold
            ):
                needs_fallback = True  # average log probability is too low

            if not needs_fallback:
                break

        return decode_result

    seek = 0
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    time_precision = (
        input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
    all_tokens = []
    all_segments = []
    # Index into all_tokens from which the prompt for the next window starts.
    prompt_reset_since = 0

    initial_prompt = decode_options.pop("initial_prompt", None)
    if initial_prompt and initial_prompt != "None":
        initial_prompt = tokenizer.encode(" " + initial_prompt.strip()).input_ids
        all_tokens.extend(initial_prompt)
    else:
        initial_prompt = []

    def add_segment(
        *,
        start: float,
        end: float,
        text_tokens: paddle.Tensor,
        result: DecodingResult,
    ):
        # Record one timed segment; ids >= eot (special/timestamp tokens) are
        # dropped from the text, and empty text is skipped.
        text = tokenizer.decode(
            [token for token in text_tokens if token < tokenizer.eot]
        )
        if len(text.strip()) == 0:  # skip empty text output
            return

        all_segments.append(
            {
                "id": len(all_segments),
                "seek": seek,
                "start": start,
                "end": end,
                "text": text,
                "tokens": result.tokens,
                "temperature": result.temperature,
                "avg_logprob": result.avg_logprob,
                "compression_ratio": result.compression_ratio,
                "no_speech_prob": result.no_speech_prob,
            }
        )
        if verbose:
            print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")

    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
    num_frames = mel.shape[-1]
    previous_seek_value = seek

    with tqdm.tqdm(
        total=num_frames, unit="frames", disable=verbose is not False
    ) as pbar:
        # Slide a window of N_FRAMES over the spectrogram; `seek` is the
        # current frame position and advances by the amount actually decoded.
        while seek < num_frames:
            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
            segment = pad_or_trim(mel[:, seek:], N_FRAMES)
            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE

            decode_options["prompt"] = all_tokens[prompt_reset_since:]
            result: DecodingResult = decode_with_fallback(segment)
            tokens = paddle.to_tensor(result.tokens)

            if no_speech_threshold is not None:
                # no voice activity check
                should_skip = result.no_speech_prob > no_speech_threshold
                if (
                    logprob_threshold is not None
                    and result.avg_logprob > logprob_threshold
                ):
                    # don't skip if the logprob is high enough, despite the no_speech_prob
                    should_skip = False

                if should_skip:
                    seek += segment.shape[
                        -1
                    ]  # fast-forward to the next segment boundary
                    continue

            timestamp_tokens: paddle.Tensor = tokens.greater_equal(
                paddle.to_tensor(tokenizer.timestamp_begin)
            )

            # Positions i where tokens i and i+1 are both timestamps; each such
            # pair closes one segment and opens the next.
            consecutive = paddle.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
            if (
                len(consecutive) > 0
            ):  # if the output contains two consecutive timestamp tokens
                consecutive = paddle.add(consecutive, paddle.to_tensor(1))
                last_slice = 0
                for current_slice in consecutive:
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_position = (
                        sliced_tokens[0].item() - tokenizer.timestamp_begin
                    )
                    end_timestamp_position = (
                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
                    )
                    add_segment(
                        start=timestamp_offset
                        + start_timestamp_position * time_precision,
                        end=timestamp_offset + end_timestamp_position * time_precision,
                        text_tokens=sliced_tokens[1:-1],
                        result=result,
                    )
                    last_slice = current_slice
                # Resume right after the last complete segment rather than at
                # the end of the window, so trailing partial text is re-decoded.
                last_timestamp_position = (
                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                )
                seek += last_timestamp_position * input_stride
                all_tokens.extend(tokens[: last_slice + 1].tolist())
            else:
                duration = segment_duration
                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                if (
                    len(timestamps) > 0
                    and timestamps[-1].item() != tokenizer.timestamp_begin
                ):
                    # no consecutive timestamps but it has a timestamp; use the last one.
                    # single timestamp at the end means no speech after the last timestamp.
                    last_timestamp_position = (
                        timestamps[-1].item() - tokenizer.timestamp_begin
                    )
                    duration = last_timestamp_position * time_precision

                add_segment(
                    start=timestamp_offset,
                    end=timestamp_offset + duration,
                    text_tokens=tokens,
                    result=result,
                )

                seek += segment.shape[-1]
                all_tokens.extend(tokens.tolist())

            if not condition_on_previous_text or result.temperature > 0.5:
                # do not feed the prompt tokens if a high temperature was used
                prompt_reset_since = len(all_tokens)

            # update progress bar
            pbar.update(min(num_frames, seek) - previous_seek_value)
            previous_seek_value = seek

    return dict(
        text=tokenizer.decode(all_tokens[len(initial_prompt) :]),
        segments=all_segments,
        language=language,
    )
|
1074
|
+
|
1075
|
+
|
1076
|
+
class SequenceRanker:
    """Interface for choosing one candidate sequence per group."""

    def rank(
        self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]
    ) -> List[int]:
        """
        Pick a winner in each group of candidate sequences.

        ``tokens[g]`` holds the candidates of group ``g`` and
        ``sum_logprobs[g]`` their cumulative log probabilities. Returns, for
        each group, the index of the selected candidate. Subclasses must
        override this.
        """
        raise NotImplementedError
|
1085
|
+
|
1086
|
+
|
1087
|
+
class MaximumLikelihoodRanker(SequenceRanker):
    """
    Rank candidates by cumulative log probability divided by a length penalty.

    The penalty is either the plain length (when ``length_penalty`` is None or
    the string "None") or the Google NMT penalty ``((5 + len) / 6) ** alpha``.
    """

    def __init__(self, length_penalty: Optional[float]):
        self.length_penalty = length_penalty

    def rank(self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]):
        """Return, per group, the ``np.argmax`` index of the penalized scores."""
        plain_normalization = (
            self.length_penalty is None or self.length_penalty == "None"
        )

        def penalized(logprob, length):
            if plain_normalization:
                return logprob / length
            # from the Google NMT paper
            return logprob / (((5 + length) / 6) ** self.length_penalty)

        winners = []
        for group_tokens, group_logprobs in zip(tokens, sum_logprobs):
            group_scores = [
                penalized(logprob, len(candidate))
                for logprob, candidate in zip(group_logprobs, group_tokens)
            ]
            winners.append(np.argmax(group_scores))
        return winners
|
1111
|
+
|
1112
|
+
|
1113
|
+
class TokenDecoder:
    """Strategy interface that chooses the next token at every decoding step."""

    def reset(self):
        """Clear per-sequence state before a new batch is decoded (no-op by default)."""

    def update(
        self,
        tokens: paddle.Tensor,
        logits: paddle.Tensor,
        sum_logprobs: paddle.Tensor,
    ) -> Tuple[paddle.Tensor, bool]:
        """Append one selected token to every sequence.

        Parameters
        ----------
        tokens : Tensor, shape = (n_batch, current_sequence_length)
            the context decoded so far, including prefix and sot_sequence tokens
        logits : Tensor, shape = (n_batch, vocab_size)
            logits of the next-token distribution for the current step
        sum_logprobs : Tensor, shape = (n_batch)
            running cumulative log probability of each sequence

        Returns
        -------
        tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
            the input tokens with the chosen next token appended
        completed : bool
            True once every sequence has reached end of text
        """
        raise NotImplementedError

    def finalize(
        self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor
    ) -> Tuple[Sequence[Sequence[paddle.Tensor]], List[List[float]]]:
        """Wrap up the search and hand back the candidate sequences.

        Parameters
        ----------
        tokens : Tensor, shape = (batch_size, beam_size, current_sequence_length)
            the context decoded so far, including prefix and sot_sequence
        sum_logprobs : Tensor, shape = (batch_size, beam_size)
            running cumulative log probability of each sequence

        Returns
        -------
        tokens : Sequence[Sequence[Tensor]], length = batch_size
            candidate token sequences for every audio input
        sum_logprobs : List[List[float]], length = batch_size
            cumulative log probabilities matching the candidates above
        """
        raise NotImplementedError
|
1159
|
+
|
1160
|
+
|
1161
|
+
class GreedyDecoder(TokenDecoder):
    """
    Select the next token by argmax (temperature 0) or by sampling.

    When sampling, the draw comes from a categorical distribution over
    ``logits / temperature``. Sequences that already ended with EOT keep
    emitting EOT, and their cumulative log probability is no longer updated.
    """

    def __init__(self, temperature: float, eot: int):
        self.temperature = temperature
        self.eot = eot

    def update(
        self,
        tokens: paddle.Tensor,
        logits: paddle.Tensor,
        sum_logprobs: paddle.Tensor,
    ) -> Tuple[paddle.Tensor, bool]:
        """
        Append one token per row and accumulate its log probability.

        ``sum_logprobs`` is updated in place: the caller
        (``DecodingTask._main_loop``) keeps using its own tensor and does not
        receive it back from this method.
        """
        temperature = self.temperature
        if temperature == 0:
            next_tokens = paddle.argmax(logits, axis=-1)
        else:
            next_tokens = paddle.distribution.Categorical(
                logits=logits / temperature
            ).sample([1])
            # flatten the sample to one token per row
            next_tokens = paddle.reshape(
                next_tokens,
                [
                    next_tokens.shape[0] * next_tokens.shape[1],
                ],
            )

        logprobs = paddle.nn.functional.log_softmax(
            logits, axis=-1, dtype=paddle.float32
        )
        current_logprobs = logprobs[paddle.arange(logprobs.shape[0]), next_tokens]
        # 1.0 for rows still decoding, 0.0 for rows that already ended with EOT
        still_running = (tokens[:, -1] != self.eot).astype(paddle.float32)
        # Slice assignment writes into the caller's tensor whether or not Paddle
        # implements `+=` in place (a plain `+=` may only rebind this local name).
        sum_logprobs[:] = sum_logprobs + current_logprobs * still_running

        # finished rows keep emitting EOT
        next_tokens[tokens[:, -1] == self.eot] = self.eot
        tokens = paddle.concat([tokens, next_tokens[:, None]], axis=-1)

        completed = paddle.all((tokens[:, -1] == self.eot))
        return tokens, completed

    def finalize(self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
        # make sure each sequence has at least one EOT token at the end
        tokens = paddle.nn.functional.pad(
            tokens, (0, 1), value=self.eot, data_format="NCL"
        )
        return tokens, sum_logprobs.tolist()
|
1206
|
+
|
1207
|
+
|
1208
|
+
class BeamSearchDecoder(TokenDecoder):
    """
    Beam search over ``beam_size`` hypotheses per audio input.

    Hypotheses that emit EOT are moved to a per-audio pool of finished
    sequences. Decoding is complete once every pool holds ``max_candidates``
    entries, where ``max_candidates = round(beam_size * patience)``.
    """

    def __init__(
        self,
        beam_size: int,
        eot: int,
        inference: Inference,
        patience: Optional[float] = None,
    ):
        self.beam_size = beam_size
        self.eot = eot
        self.inference = inference
        # configs may carry the literal string "None"; treat it like a missing value
        unset = patience is None or patience == "None"
        self.patience = 1.0 if unset else patience
        self.max_candidates: int = round(beam_size * self.patience)
        self.finished_sequences = None

        assert (
            self.max_candidates > 0
        ), f"Invalid beam size ({beam_size}) or patience ({patience})"

    def reset(self):
        self.finished_sequences = None

    def update(
        self,
        tokens: paddle.Tensor,
        logits: paddle.Tensor,
        sum_logprobs: paddle.Tensor,
    ) -> Tuple[paddle.Tensor, bool]:
        """
        Advance every beam by one token.

        ``sum_logprobs`` is overwritten in place with the scores of the kept
        hypotheses, and the kv cache is reordered to follow them.
        """
        n_rows = tokens.shape[0]
        if n_rows % self.beam_size != 0:
            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

        n_audio = n_rows // self.beam_size
        if self.finished_sequences is None:  # first call since reset()
            self.finished_sequences = [{} for _ in range(n_audio)]

        logprobs = paddle.nn.functional.log_softmax(logits, axis=-1, dtype="float32")
        kept_sequences, kept_sources, finished_per_audio = [], [], []
        for audio_idx in range(n_audio):
            candidate_scores, candidate_sources, newly_finished = {}, {}, {}

            # STEP 1: cumulative log probability of every expansion of every beam;
            # beam_size + 1 expansions guarantee enough non-EOT candidates remain
            for beam_idx in range(self.beam_size):
                row = audio_idx * self.beam_size + beam_idx
                prefix = tokens[row].tolist()
                top_logprobs, top_tokens = paddle.topk(
                    logprobs[row], k=self.beam_size + 1
                )
                for step_logprob, step_token in zip(top_logprobs, top_tokens):
                    score = (sum_logprobs[row] + step_logprob).item()
                    candidate = tuple(prefix + [step_token.item()])
                    candidate_scores[candidate] = score
                    candidate_sources[candidate] = row

            # STEP 2: best-first, route EOT candidates to the finished pool and
            # keep the first beam_size unfinished ones as the new beams
            n_kept = 0
            ranked = sorted(candidate_scores, key=candidate_scores.get, reverse=True)
            for candidate in ranked:
                if candidate[-1] == self.eot:
                    newly_finished[candidate] = candidate_scores[candidate]
                    continue
                sum_logprobs[len(kept_sequences)] = candidate_scores[candidate]
                kept_sequences.append(candidate)
                kept_sources.append(candidate_sources[candidate])

                n_kept += 1
                if n_kept == self.beam_size:
                    break

            finished_per_audio.append(newly_finished)

        tokens = paddle.to_tensor(kept_sequences)
        self.inference.rearrange_kv_cache(kept_sources)

        # merge newly finished sequences into the per-audio pools, best first
        assert len(self.finished_sequences) == len(finished_per_audio)
        for pool, newly_finished in zip(self.finished_sequences, finished_per_audio):
            for sequence in sorted(newly_finished, key=newly_finished.get, reverse=True):
                if len(pool) >= self.max_candidates:
                    break  # this pool is full
                pool[sequence] = newly_finished[sequence]

        # done once every audio has collected enough finished samples
        completed = all(
            len(pool) >= self.max_candidates for pool in self.finished_sequences
        )
        return tokens, completed

    def finalize(self, preceding_tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
        """
        Return the finished pools, topped up with unfinished beams if needed.

        A pool with fewer than ``beam_size`` finished sequences is padded with
        the best-scoring unfinished beams, each terminated with EOT.
        """
        sum_logprobs = sum_logprobs.cpu()
        for audio_idx, pool in enumerate(self.finished_sequences):
            if len(pool) >= self.beam_size:
                continue
            # highest cumulative log probability first
            for beam_idx in list(np.argsort(sum_logprobs[audio_idx]))[::-1]:
                sequence = preceding_tokens[audio_idx, beam_idx].tolist() + [self.eot]
                pool[tuple(sequence)] = sum_logprobs[audio_idx][beam_idx].item()
                if len(pool) >= self.beam_size:
                    break

        candidate_tokens = [
            [paddle.to_tensor(sequence) for sequence in pool.keys()]
            for pool in self.finished_sequences
        ]
        candidate_logprobs = [list(pool.values()) for pool in self.finished_sequences]
        return candidate_tokens, candidate_logprobs
|
1320
|
+
|
1321
|
+
|
1322
|
+
class LogitFilter:
    """Interface for rules that mask or adjust next-token logits."""

    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor) -> None:
        """Modify ``logits`` in place according to this filter's rule.

        Parameters
        ----------
        logits : Tensor, shape = (n_batch, vocab_size)
            logits of the next-token distribution for the current step

        tokens : Tensor, shape = (n_batch, current_sequence_length)
            the context decoded so far, including prefix and sot_sequence tokens

        """
        raise NotImplementedError
|
1336
|
+
|
1337
|
+
|
1338
|
+
class SuppressBlank(LogitFilter):
    """Ban the encoded blank (" ") and EOT at the very first sampled position."""

    def __init__(self, tokenizer: Tokenizer, sample_begin: int):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin

    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
        # only the first sampled step is affected
        if tokens.shape[1] != self.sample_begin:
            return
        banned_ids = self.tokenizer.encode(" ").input_ids + [self.tokenizer.eot]
        logits[:, banned_ids] = -np.inf
|
1348
|
+
|
1349
|
+
|
1350
|
+
class SuppressTokens(LogitFilter):
    """Ban a fixed set of token ids at every decoding step."""

    def __init__(self, suppress_tokens: Sequence[int]):
        # stored as a list so it can be used directly as a column index
        self.suppress_tokens = [token_id for token_id in suppress_tokens]

    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
        banned_ids = self.suppress_tokens
        logits[:, banned_ids] = -np.inf
|
1356
|
+
|
1357
|
+
|
1358
|
+
class ApplyTimestampRules(LogitFilter):
    """
    Enforce Whisper's timestamp-token grammar on the next-token logits.

    The rules are:
    * ``<|notimestamps|>`` is never sampled.
    * Timestamps come in pairs, except directly before EOT.
    * The first sampled timestamp is capped by ``max_initial_timestamp_index``.
    * Text tokens are banned when the total timestamp probability beats
      every single text token.
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        sample_begin: int,
        max_initial_timestamp_index: Optional[int],
    ):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin
        self.max_initial_timestamp_index = max_initial_timestamp_index

    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
        tok = self.tokenizer
        ts_begin = tok.timestamp_begin

        # suppress <|notimestamps|> which is handled by without_timestamps
        if tok.no_timestamps is not None:
            logits[:, tok.no_timestamps] = -np.inf

        # timestamps have to appear in pairs, except directly before EOT
        for row in range(tokens.shape[0]):
            sampled = list(tokens[row, self.sample_begin :].tolist())
            if not (len(sampled) >= 1 and sampled[-1] >= ts_begin):
                continue
            pair_closed = len(sampled) < 2 or sampled[-2] >= ts_begin
            if pair_closed:
                # has to be non-timestamp
                logits[row, ts_begin:] = -np.inf
            else:
                # cannot be normal text tokens
                logits[row, : tok.eot] = -np.inf

        # apply the `max_initial_timestamp` option
        at_first_step = tokens.shape[1] == self.sample_begin
        if at_first_step and self.max_initial_timestamp_index is not None:
            last_allowed = ts_begin + self.max_initial_timestamp_index
            logits[:, last_allowed + 1 :] = -np.inf

        # if sum of probability over timestamps is above any other token, sample timestamp
        logprobs = paddle.nn.functional.log_softmax(logits, axis=-1, dtype="float32")
        for row in range(tokens.shape[0]):
            # When using paddle.logsumexp on a 32GB Tesla-V100 GPU, we encountered CUDA error 700.
            # To bypass this issue in CI, we have decomposed the operation into separate steps.
            # It will raise 2e-6 difference in precision.
            # TODO: revert this after logsumexp been fixed.
            timestamp_probs = paddle.exp(logprobs[row, ts_begin:])
            timestamp_logprob = paddle.log(paddle.sum(timestamp_probs, axis=-1))
            max_text_token_logprob = paddle.max(logprobs[row, :ts_begin])
            if timestamp_logprob > max_text_token_logprob:
                logits[row, :ts_begin] = -np.inf
|
1417
|
+
|
1418
|
+
|
1419
|
+
class DecodingTask:
    """
    Decode a batch of 30-second mel segments with a Whisper model.

    Combines four parts:
    * the tokenizer;
    * a kv-cached decoder forward pass (``inference``);
    * a next-token strategy (``decoder``: greedy/sampling or beam search);
    * a list of logit filters, plus a ``sequence_ranker`` that selects the
      final candidate of each group.
    """

    inference: Inference
    sequence_ranker: SequenceRanker
    decoder: TokenDecoder
    logit_filters: List[LogitFilter]

    def __init__(self, model: "Whisper", options: DecodingOptions, resource_path: str):
        self.model = model

        # The tokenizer needs some language. "en" is used when none is given;
        # _detect_language() then overwrites the language token in the prompt.
        language = options.language or "en"
        tokenizer = get_tokenizer(
            model.is_multilingual,
            resource_path=resource_path,
            language=language,
            task=options.task,
        )
        self.tokenizer: Tokenizer = tokenizer
        self.options: DecodingOptions = self._verify_options(options)
        self.resource_path: str = resource_path

        self.beam_size: int = options.beam_size or options.best_of or 1
        self.n_ctx: int = model.dims.n_text_ctx
        self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2

        self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
        if self.options.without_timestamps:
            self.sot_sequence = tokenizer.sot_sequence_including_notimestamps

        self.initial_tokens: Tuple[int] = self._get_initial_tokens()
        self.sample_begin: int = len(self.initial_tokens)
        self.sot_index: int = self.initial_tokens.index(tokenizer.sot)

        # inference: implements the forward pass through the decoder, including kv caching
        self.inference = WhisperInference(model, len(self.initial_tokens))

        # sequence ranker: implements how to rank a group of sampled sequences
        self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)

        # decoder: implements how to select the next tokens, given the autoregressive distribution
        if options.beam_size is not None:
            self.decoder = BeamSearchDecoder(
                options.beam_size, tokenizer.eot, self.inference, options.patience
            )
        else:
            self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)

        # logit filters: applies various rules to suppress or penalize certain tokens
        self.logit_filters = []
        if self.options.suppress_blank:
            self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
        if self.options.suppress_tokens:
            self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
        if not options.without_timestamps:
            precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
            max_initial_timestamp_index = None
            if options.max_initial_timestamp:
                max_initial_timestamp_index = round(
                    self.options.max_initial_timestamp / precision
                )
            self.logit_filters.append(
                ApplyTimestampRules(
                    tokenizer, self.sample_begin, max_initial_timestamp_index
                )
            )

    def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
        """Reject incompatible option combinations; return ``options`` unchanged."""
        if options.beam_size is not None and options.best_of is not None:
            raise ValueError("beam_size and best_of can't be given together")
        if options.temperature == 0:
            if options.best_of is not None:
                raise ValueError("best_of with greedy sampling (T=0) is not compatible")
        if options.patience is not None and options.beam_size is None:
            raise ValueError("patience requires beam_size to be given")
        # configs may carry the literal string "None" instead of None
        if options.length_penalty is not None and options.length_penalty != "None":
            if not (0 <= options.length_penalty <= 1):
                raise ValueError(
                    "length_penalty (alpha) should be a value between 0 and 1"
                )

        return options

    def _get_initial_tokens(self) -> Tuple[int]:
        """Build the decoder prompt: [sot_prev + prompt] + sot_sequence + prefix."""
        tokens = list(self.sot_sequence)
        prefix = self.options.prefix
        prompt = self.options.prompt

        if prefix:
            prefix_tokens = (
                self.tokenizer.encode(" " + prefix.strip()).input_ids
                if isinstance(prefix, str)
                else prefix
            )
            if self.sample_len is not None:
                # NOTE(review): with the default sample_len (n_ctx // 2) this is 0,
                # and `[-0:]` keeps the whole prefix rather than none of it.
                max_prefix_len = self.n_ctx // 2 - self.sample_len
                prefix_tokens = prefix_tokens[-max_prefix_len:]
            tokens = tokens + prefix_tokens

        if prompt:
            prompt_tokens = (
                self.tokenizer.encode(" " + prompt.strip()).input_ids
                if isinstance(prompt, str)
                else prompt
            )
            # keep at most n_ctx // 2 - 1 prompt tokens (one slot is sot_prev)
            tokens = (
                [self.tokenizer.sot_prev]
                + prompt_tokens[-(self.n_ctx // 2 - 1) :]
                + tokens
            )

        return tuple(tokens)

    def _get_suppress_tokens(self) -> Tuple[int]:
        """
        Resolve ``options.suppress_tokens`` into a sorted tuple of unique ids.

        The option may be a list/tuple of ints or a comma-separated string.
        ``-1`` expands to the tokenizer's non-speech tokens. sot, sot_prev,
        sot_lm and (if defined) no_speech are always added. The caller's
        options object is never mutated.
        """
        suppress_tokens = self.options.suppress_tokens

        if isinstance(suppress_tokens, str):
            # blank entries (e.g. "" or "1,,2") are ignored instead of crashing int()
            suppress_tokens = [int(t) for t in suppress_tokens.split(",") if t.strip()]

        if suppress_tokens is None or len(suppress_tokens) == 0:
            suppress_tokens = []  # interpret empty string as an empty list
        elif -1 in suppress_tokens:
            suppress_tokens = [t for t in suppress_tokens if t >= 0]
            suppress_tokens.extend(self.tokenizer.non_speech_tokens)
        else:
            assert isinstance(
                suppress_tokens, (list, tuple)
            ), "suppress_tokens must be a list"
            # copy so the extend() calls below don't modify the caller's options
            suppress_tokens = list(suppress_tokens)

        suppress_tokens.extend(
            [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
        )
        if self.tokenizer.no_speech is not None:
            # no-speech probability is collected separately
            suppress_tokens.append(self.tokenizer.no_speech)

        return tuple(sorted(set(suppress_tokens)))

    def _get_audio_features(self, mel: paddle.Tensor):
        """Run the encoder on ``mel``, unless it already looks like encoder output."""
        if mel.shape[-2:] == (
            self.model.dims.n_audio_ctx,
            self.model.dims.n_audio_state,
        ):
            # encoded audio features are given; skip audio encoding
            audio_features = mel
        else:
            audio_features = self.model.encoder(mel)

        return audio_features

    def _detect_language(
        self,
        audio_features: paddle.Tensor,
        tokens: paddle.Tensor,
        resource_path: str,
    ):
        """
        Return per-audio languages and, if detection ran, their probabilities.

        When ``options.language`` is None, the detected language tokens are
        written into ``tokens`` in place, right after the SOT token.
        """
        languages = [self.options.language] * audio_features.shape[0]
        lang_probs = None

        if self.options.language is None or self.options.task == "lang_id":
            lang_tokens, lang_probs = self.model.detect_language(
                audio_features, self.tokenizer, self.resource_path
            )
            languages = [max(probs, key=probs.get) for probs in lang_probs]
            if self.options.language is None:
                tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens

        return languages, lang_probs

    def _main_loop(self, audio_features: paddle.Tensor, tokens: paddle.Tensor):
        """Autoregressively extend ``tokens`` until completion or the context limit."""
        assert audio_features.shape[0] == tokens.shape[0]
        n_batch = tokens.shape[0]
        # accumulated in place by self.decoder.update()
        sum_logprobs: paddle.Tensor = paddle.zeros([n_batch], dtype=paddle.float32)
        no_speech_probs = [np.nan] * n_batch

        try:
            for i in range(self.sample_len):
                logits = self.inference.logits(tokens, audio_features)

                if (
                    i == 0 and self.tokenizer.no_speech is not None
                ):  # save no_speech_probs
                    probs_at_sot = paddle.nn.functional.softmax(
                        logits[:, self.sot_index], axis=-1, dtype=paddle.float32
                    )
                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

                # now we need to consider the logits at the last token only
                logits = logits[:, -1]

                # apply the logit filters, e.g. for suppressing or applying penalty to
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)

                # expand the tokens tensor with the selected next tokens
                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
                if completed or tokens.shape[-1] > self.n_ctx:
                    break
        finally:
            self.inference.cleanup_caching()

        return tokens, sum_logprobs, no_speech_probs

    @paddle.no_grad()
    def run(self, mel: paddle.Tensor) -> List[DecodingResult]:
        """Decode every segment in ``mel``; return one DecodingResult per segment."""
        self.decoder.reset()
        tokenizer: Tokenizer = self.tokenizer
        batch_size: int = mel.shape[0]

        audio_features: paddle.Tensor = self._get_audio_features(
            mel
        )  # encoder forward pass

        # one row of initial tokens per audio segment in the batch
        tokens: paddle.Tensor = paddle.to_tensor([self.initial_tokens] * batch_size)

        # detect language if requested, overwriting the language token.
        # `tokens` is passed as-is: _detect_language writes into it in place, and
        # paddle.to_tensor(tokens) would hand over a copy and lose that write.
        languages, language_probs = self._detect_language(
            audio_features,
            tokens,
            self.resource_path,
        )

        if self.options.task == "lang_id":
            return [
                DecodingResult(
                    audio_features=features, language=language, language_probs=probs
                )
                for features, language, probs in zip(
                    audio_features, languages, language_probs
                )
            ]

        # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
        audio_features = paddle.repeat_interleave(
            audio_features, self.beam_size, axis=0
        )
        tokens = paddle.repeat_interleave(tokens, self.beam_size, axis=0)
        # call the main sampling loop
        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
        # reshape the tensors to have (batch_size, beam_size) as the first two dimensions
        audio_features = audio_features[:: self.beam_size]
        no_speech_probs = no_speech_probs[:: self.beam_size]
        assert audio_features.shape[0] == len(no_speech_probs) == batch_size
        tokens = tokens.reshape([batch_size, self.beam_size, -1])
        sum_logprobs = sum_logprobs.reshape([batch_size, self.beam_size])

        # get the final candidates for each group, and slice between the first sampled token and EOT
        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
        tokens: List[List[paddle.Tensor]] = [
            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
            for s in tokens
        ]

        # select the top-ranked sample in each group
        selected = self.sequence_ranker.rank(tokens, sum_logprobs)
        tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
        texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]

        sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
        # +1 accounts for the EOT token stripped from each sequence above
        avg_logprobs: List[float] = [
            lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
        ]

        fields = (
            texts,
            languages,
            tokens,
            audio_features,
            avg_logprobs,
            no_speech_probs,
        )
        if len(set(map(len, fields))) != 1:
            raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

        return [
            DecodingResult(
                audio_features=features,
                language=language,
                tokens=tokens,
                text=text,
                avg_logprob=avg_logprob,
                no_speech_prob=no_speech_prob,
                temperature=self.options.temperature,
                compression_ratio=compression_ratio(text),
            )
            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
                *fields
            )
        ]
|
1718
|
+
|
1719
|
+
|
1720
|
+
@paddle.no_grad()
def decode(
    model: "Whisper",
    mel: paddle.Tensor,
    options: DecodingOptions = DecodingOptions(),
    resource_path: Optional[str] = None,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
    Parameters
    ----------
    model: Whisper
        the Whisper model instance
    mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000)
        A tensor containing the Mel spectrogram(s)
    options: DecodingOptions
        A dataclass that contains all necessary options for decoding 30-second segments
        (NOTE(review): the default instance is created once at import time and shared;
        this assumes DecodingOptions is immutable — confirm)
    resource_path: str
        Directory holding the tokenizer / asset files; forwarded to DecodingTask.
        Previously the default was the `str` type object itself, not a path.
    Returns
    -------
    result: Union[DecodingResult, List[DecodingResult]]
        The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
    """
    # a 2-D mel is a single segment: add a batch axis and unwrap the result at the end
    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    result = DecodingTask(model, options, resource_path).run(mel)

    if single:
        result = result[0]

    return result
|
1752
|
+
|
1753
|
+
|
1754
|
+
class Whisper(paddle.nn.Layer):
    """
    Whisper speech model: an AudioEncoder followed by a TextDecoder.

    The layer sizes come from ``dims``. The class also exposes the
    module-level ``detect_language``, ``transcribe`` (wrapped with a
    benchmark timer) and ``decode`` functions as methods.
    """

    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )

    def embed_audio(self, mel: paddle.Tensor):
        """Run only the audio encoder on ``mel``."""
        return self.encoder.forward(mel)

    def logits(self, tokens: paddle.Tensor, audio_features: paddle.Tensor):
        """Run only the text decoder on ``tokens`` given precomputed ``audio_features``."""
        return self.decoder.forward(tokens, audio_features)

    # NOTE(review): the return annotation says Dict[str, Tensor], but the body returns
    # whatever TextDecoder.__call__ returns (presumably a logits tensor) — confirm.
    def forward(
        self, mel: paddle.Tensor, tokens: paddle.Tensor
    ) -> Dict[str, paddle.Tensor]:
        """Encode ``mel`` and decode ``tokens`` against the resulting audio features."""
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        """The current global Paddle device string (not necessarily where this layer's parameters live)."""
        return paddle.device.get_device()

    @property
    def is_multilingual(self):
        """True when the vocabulary size is 51865, taken here to mean the multilingual tokenizer."""
        return self.dims.n_vocab == 51865

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations.

        Parameters
        ----------
        cache : Optional[dict]
            existing cache entries to start from; the dict is shallow-copied, not modified
        Returns
        -------
        cache : Dict[nn.Layer, paddle.Tensor]
            A dictionary object mapping the key/value projection modules to its cache
        hooks : List
            The handles returned by Paddle's `register_forward_post_hook`; call `.remove()`
            on each to stop the hooks from being called
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            # Forward post-hook on a key/value projection; its return value replaces the
            # projection's output, so attention sees the accumulated keys/values.
            if (
                module not in cache
                or output.shape[1] > self.decoder.positional_embedding.shape[0]
            ):
                cache[module] = (
                    output  # save as-is, for the first token or cross attention
                )
            else:
                # append this step's projections along axis 1 of the cached tensor
                cache[module] = paddle.concat([cache[module], output], axis=1).detach()
            return cache[module]

        def install_hooks(layer: paddle.nn.Layer):
            # hook the key and value projections of every attention block in the decoder
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_post_hook(save_to_cache))
                hooks.append(layer.value.register_forward_post_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    # Bind the module-level functions as methods. These statements run once, when the
    # class body executes at import time: "speech_transcribe" is registered as an
    # inference operation, and `transcribe` is wrapped with the benchmark timer under
    # that name.
    detect_language = detect_language
    set_inference_operations(get_inference_operations() + ["speech_transcribe"])
    transcribe = benchmark.timeit_with_options(name="speech_transcribe")(transcribe)
    decode = decode
|
1836
|
+
|
1837
|
+
|
1838
|
+
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.

    Accepts a ``paddle.Tensor`` or a NumPy array and returns the same kind.
    Along ``axis``, data longer than ``length`` is truncated to its first
    ``length`` entries; shorter data is zero-padded at the end up to ``length``.

    NOTE(review): the Tensor padding path transposes with perm (1, 0), so it
    only supports 2-D tensors; the NumPy path handles arrays of any rank.
    """
    if paddle.is_tensor(array):
        if array.shape[axis] > length:
            array = array.index_select(axis=axis, index=paddle.arange(length))

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            # The 2-D tensor is transposed, so the per-axis widths are given in
            # reversed order to line up with the transposed dimensions.
            array = paddle.transpose(array, (1, 0))
            array = paddle.nn.functional.pad(
                array,
                [pad for sizes in pad_widths[::-1] for pad in sizes],
                data_format="NLC",
            )
            array = paddle.transpose(array, (1, 0))
    else:
        if array.shape[axis] > length:
            array = array.take(indices=range(length), axis=axis)

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            # np.pad applies pad_widths axis-by-axis; no transpose is needed (a
            # transpose here would make the widths pad the wrong axis).
            array = np.pad(array, pad_widths)

    return array
|
1868
|
+
|
1869
|
+
|
1870
|
+
def hann_window(n_fft: int = N_FFT):
|
1871
|
+
"""
|
1872
|
+
hanning window
|
1873
|
+
n_fft: The number of frequency components of the discrete Fourier transform.
|
1874
|
+
"""
|
1875
|
+
return paddle.to_tensor(
|
1876
|
+
[0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft) for n in range(n_fft)],
|
1877
|
+
dtype=paddle.float32,
|
1878
|
+
)
|
1879
|
+
|
1880
|
+
|
1881
|
+
@lru_cache(maxsize=None)
|
1882
|
+
def mel_filters(resource_path: str, n_mels: int = N_MELS) -> paddle.Tensor:
|
1883
|
+
"""
|
1884
|
+
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
1885
|
+
Allows decoupling librosa dependency; saved using:
|
1886
|
+
np.savez_compressed(
|
1887
|
+
"mel_filters.npz",
|
1888
|
+
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
|
1889
|
+
)
|
1890
|
+
"""
|
1891
|
+
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
|
1892
|
+
with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
|
1893
|
+
return paddle.to_tensor(f[f"mel_{n_mels}"])
|
1894
|
+
|
1895
|
+
|
1896
|
+
@function_requires_deps("soundfile")
|
1897
|
+
def log_mel_spectrogram(
|
1898
|
+
audio: Union[str, np.ndarray, paddle.Tensor],
|
1899
|
+
n_mels: int = N_MELS,
|
1900
|
+
resource_path: str = None,
|
1901
|
+
):
|
1902
|
+
"""
|
1903
|
+
Compute the log-Mel spectrogram of
|
1904
|
+
Parameters
|
1905
|
+
----------
|
1906
|
+
audio: Union[str, np.ndarray, paddle.Tensor], shape = (*)
|
1907
|
+
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
|
1908
|
+
n_mels: int
|
1909
|
+
The number of Mel-frequency filters, only 80 is supported
|
1910
|
+
Returns
|
1911
|
+
-------
|
1912
|
+
paddle.Tensor, shape = (80, n_frames)
|
1913
|
+
A Tensor that contains the Mel spectrogram
|
1914
|
+
"""
|
1915
|
+
if not paddle.is_tensor(audio):
|
1916
|
+
if isinstance(audio, str):
|
1917
|
+
audio, _ = soundfile.read(audio, dtype="float32", always_2d=True)
|
1918
|
+
audio = audio[:, 0]
|
1919
|
+
audio = paddle.to_tensor(audio)
|
1920
|
+
|
1921
|
+
window = hann_window(N_FFT)
|
1922
|
+
stft = paddle.signal.stft(audio, N_FFT, HOP_LENGTH, window=window)
|
1923
|
+
|
1924
|
+
magnitudes = stft[:, :-1].abs() ** 2
|
1925
|
+
|
1926
|
+
filters = mel_filters(resource_path, n_mels)
|
1927
|
+
mel_spec = filters @ magnitudes
|
1928
|
+
mel_spec = paddle.to_tensor(mel_spec.numpy().tolist())
|
1929
|
+
|
1930
|
+
log_spec = paddle.clip(mel_spec, min=1e-10).log10()
|
1931
|
+
log_spec = paddle.maximum(log_spec, log_spec.max() - 8.0)
|
1932
|
+
log_spec = (log_spec + 4.0) / 4.0
|
1933
|
+
return log_spec
|