paddlex 3.0.0rc0__py3-none-any.whl → 3.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- paddlex/.version +1 -1
- paddlex/__init__.py +17 -34
- paddlex/__main__.py +1 -1
- paddlex/configs/modules/doc_vlm/PP-DocBee-2B.yaml +14 -0
- paddlex/configs/modules/doc_vlm/PP-DocBee-7B.yaml +14 -0
- paddlex/configs/modules/open_vocabulary_detection/YOLO-Worldv2-L.yaml +13 -0
- paddlex/configs/pipelines/anomaly_detection.yaml +1 -1
- paddlex/configs/pipelines/doc_understanding.yaml +9 -0
- paddlex/configs/pipelines/ts_anomaly_detection.yaml +1 -1
- paddlex/configs/pipelines/ts_classification.yaml +1 -1
- paddlex/configs/pipelines/ts_forecast.yaml +1 -1
- paddlex/constants.py +17 -0
- paddlex/engine.py +7 -5
- paddlex/hpip_links.html +23 -11
- paddlex/inference/__init__.py +3 -3
- paddlex/inference/common/__init__.py +1 -1
- paddlex/inference/common/batch_sampler/__init__.py +5 -4
- paddlex/inference/common/batch_sampler/audio_batch_sampler.py +5 -6
- paddlex/inference/common/batch_sampler/base_batch_sampler.py +20 -16
- paddlex/inference/common/batch_sampler/det_3d_batch_sampler.py +4 -7
- paddlex/inference/common/batch_sampler/doc_vlm_batch_sampler.py +64 -0
- paddlex/inference/common/batch_sampler/image_batch_sampler.py +12 -36
- paddlex/inference/common/batch_sampler/ts_batch_sampler.py +9 -10
- paddlex/inference/common/batch_sampler/video_batch_sampler.py +2 -22
- paddlex/inference/common/reader/__init__.py +4 -4
- paddlex/inference/common/reader/audio_reader.py +3 -3
- paddlex/inference/common/reader/det_3d_reader.py +7 -5
- paddlex/inference/common/reader/image_reader.py +16 -12
- paddlex/inference/common/reader/ts_reader.py +3 -2
- paddlex/inference/common/reader/video_reader.py +3 -3
- paddlex/inference/common/result/__init__.py +7 -7
- paddlex/inference/common/result/base_cv_result.py +12 -2
- paddlex/inference/common/result/base_result.py +7 -5
- paddlex/inference/common/result/base_ts_result.py +1 -2
- paddlex/inference/common/result/base_video_result.py +2 -2
- paddlex/inference/common/result/mixin.py +12 -13
- paddlex/inference/models/__init__.py +41 -85
- paddlex/inference/models/anomaly_detection/__init__.py +1 -1
- paddlex/inference/models/anomaly_detection/predictor.py +9 -19
- paddlex/inference/models/anomaly_detection/processors.py +9 -2
- paddlex/inference/models/anomaly_detection/result.py +3 -2
- paddlex/inference/models/base/__init__.py +2 -2
- paddlex/inference/models/base/predictor/__init__.py +1 -2
- paddlex/inference/models/base/predictor/base_predictor.py +284 -39
- paddlex/inference/models/common/__init__.py +6 -15
- paddlex/inference/models/common/static_infer.py +764 -243
- paddlex/inference/models/common/tokenizer/__init__.py +5 -3
- paddlex/inference/models/common/tokenizer/bert_tokenizer.py +1 -1
- paddlex/inference/models/common/tokenizer/clip_tokenizer.py +609 -0
- paddlex/inference/models/common/tokenizer/gpt_tokenizer.py +7 -5
- paddlex/inference/models/common/tokenizer/qwen2_tokenizer.py +432 -0
- paddlex/inference/models/common/tokenizer/tokenizer_utils.py +72 -64
- paddlex/inference/models/common/tokenizer/tokenizer_utils_base.py +337 -121
- paddlex/inference/models/common/tokenizer/utils.py +1 -1
- paddlex/inference/models/common/tokenizer/vocab.py +1 -1
- paddlex/inference/models/common/ts/__init__.py +1 -1
- paddlex/inference/models/common/ts/funcs.py +13 -6
- paddlex/inference/models/common/ts/processors.py +14 -5
- paddlex/inference/models/common/vision/__init__.py +3 -3
- paddlex/inference/models/common/vision/funcs.py +17 -12
- paddlex/inference/models/common/vision/processors.py +61 -46
- paddlex/inference/models/common/vlm/__init__.py +13 -0
- paddlex/inference/models/common/vlm/activations.py +189 -0
- paddlex/inference/models/common/vlm/bert_padding.py +127 -0
- paddlex/inference/models/common/vlm/distributed.py +229 -0
- paddlex/inference/models/common/vlm/flash_attn_utils.py +119 -0
- paddlex/inference/models/common/vlm/generation/__init__.py +34 -0
- paddlex/inference/models/common/vlm/generation/configuration_utils.py +533 -0
- paddlex/inference/models/common/vlm/generation/logits_process.py +730 -0
- paddlex/inference/models/common/vlm/generation/stopping_criteria.py +106 -0
- paddlex/inference/models/common/vlm/generation/utils.py +2162 -0
- paddlex/inference/models/common/vlm/transformers/__init__.py +16 -0
- paddlex/inference/models/common/vlm/transformers/configuration_utils.py +1037 -0
- paddlex/inference/models/common/vlm/transformers/conversion_utils.py +408 -0
- paddlex/inference/models/common/vlm/transformers/model_outputs.py +1612 -0
- paddlex/inference/models/common/vlm/transformers/model_utils.py +2038 -0
- paddlex/inference/models/common/vlm/transformers/utils.py +178 -0
- paddlex/inference/models/common/vlm/utils.py +109 -0
- paddlex/inference/models/doc_vlm/__init__.py +15 -0
- paddlex/inference/models/doc_vlm/modeling/__init__.py +15 -0
- paddlex/inference/models/doc_vlm/modeling/qwen2_vl.py +2600 -0
- paddlex/inference/models/doc_vlm/predictor.py +198 -0
- paddlex/inference/models/doc_vlm/processors/__init__.py +15 -0
- paddlex/inference/models/doc_vlm/processors/common.py +372 -0
- paddlex/inference/models/doc_vlm/processors/qwen2_vl.py +698 -0
- paddlex/inference/models/doc_vlm/result.py +21 -0
- paddlex/inference/models/face_feature/__init__.py +1 -1
- paddlex/inference/models/face_feature/predictor.py +2 -1
- paddlex/inference/models/formula_recognition/__init__.py +1 -1
- paddlex/inference/models/formula_recognition/predictor.py +11 -27
- paddlex/inference/models/formula_recognition/processors.py +35 -19
- paddlex/inference/models/formula_recognition/result.py +19 -12
- paddlex/inference/models/image_classification/__init__.py +1 -1
- paddlex/inference/models/image_classification/predictor.py +9 -19
- paddlex/inference/models/image_classification/processors.py +4 -2
- paddlex/inference/models/image_classification/result.py +4 -3
- paddlex/inference/models/image_feature/__init__.py +1 -1
- paddlex/inference/models/image_feature/predictor.py +9 -19
- paddlex/inference/models/image_feature/processors.py +4 -1
- paddlex/inference/models/image_feature/result.py +2 -3
- paddlex/inference/models/image_multilabel_classification/__init__.py +1 -1
- paddlex/inference/models/image_multilabel_classification/predictor.py +7 -6
- paddlex/inference/models/image_multilabel_classification/processors.py +6 -2
- paddlex/inference/models/image_multilabel_classification/result.py +4 -3
- paddlex/inference/models/image_unwarping/__init__.py +1 -1
- paddlex/inference/models/image_unwarping/predictor.py +8 -16
- paddlex/inference/models/image_unwarping/processors.py +6 -2
- paddlex/inference/models/image_unwarping/result.py +4 -2
- paddlex/inference/models/instance_segmentation/__init__.py +1 -1
- paddlex/inference/models/instance_segmentation/predictor.py +7 -15
- paddlex/inference/models/instance_segmentation/processors.py +4 -7
- paddlex/inference/models/instance_segmentation/result.py +11 -10
- paddlex/inference/models/keypoint_detection/__init__.py +1 -1
- paddlex/inference/models/keypoint_detection/predictor.py +2 -3
- paddlex/inference/models/keypoint_detection/processors.py +11 -3
- paddlex/inference/models/keypoint_detection/result.py +9 -4
- paddlex/inference/models/{3d_bev_detection → m_3d_bev_detection}/__init__.py +1 -1
- paddlex/inference/models/{3d_bev_detection → m_3d_bev_detection}/predictor.py +15 -26
- paddlex/inference/models/{3d_bev_detection → m_3d_bev_detection}/processors.py +26 -14
- paddlex/inference/models/{3d_bev_detection → m_3d_bev_detection}/result.py +15 -12
- paddlex/inference/models/{3d_bev_detection → m_3d_bev_detection}/visualizer_3d.py +77 -39
- paddlex/inference/models/multilingual_speech_recognition/__init__.py +1 -1
- paddlex/inference/models/multilingual_speech_recognition/predictor.py +11 -15
- paddlex/inference/models/multilingual_speech_recognition/processors.py +45 -53
- paddlex/inference/models/multilingual_speech_recognition/result.py +1 -1
- paddlex/inference/models/object_detection/__init__.py +1 -1
- paddlex/inference/models/object_detection/predictor.py +6 -12
- paddlex/inference/models/object_detection/processors.py +36 -31
- paddlex/inference/models/object_detection/result.py +5 -4
- paddlex/inference/models/object_detection/utils.py +1 -1
- paddlex/inference/models/open_vocabulary_detection/__init__.py +1 -1
- paddlex/inference/models/open_vocabulary_detection/predictor.py +31 -14
- paddlex/inference/models/open_vocabulary_detection/processors/__init__.py +3 -2
- paddlex/inference/models/open_vocabulary_detection/processors/common.py +114 -0
- paddlex/inference/models/open_vocabulary_detection/processors/groundingdino_processors.py +19 -8
- paddlex/inference/models/open_vocabulary_detection/processors/yoloworld_processors.py +209 -0
- paddlex/inference/models/open_vocabulary_segmentation/__init__.py +1 -1
- paddlex/inference/models/open_vocabulary_segmentation/predictor.py +6 -13
- paddlex/inference/models/open_vocabulary_segmentation/processors/__init__.py +1 -1
- paddlex/inference/models/open_vocabulary_segmentation/processors/sam_processer.py +12 -12
- paddlex/inference/models/open_vocabulary_segmentation/results/__init__.py +1 -1
- paddlex/inference/models/open_vocabulary_segmentation/results/sam_result.py +11 -9
- paddlex/inference/models/semantic_segmentation/__init__.py +1 -1
- paddlex/inference/models/semantic_segmentation/predictor.py +9 -18
- paddlex/inference/models/semantic_segmentation/processors.py +11 -8
- paddlex/inference/models/semantic_segmentation/result.py +4 -3
- paddlex/inference/models/table_structure_recognition/__init__.py +1 -1
- paddlex/inference/models/table_structure_recognition/predictor.py +8 -18
- paddlex/inference/models/table_structure_recognition/processors.py +23 -29
- paddlex/inference/models/table_structure_recognition/result.py +9 -6
- paddlex/inference/models/text_detection/__init__.py +1 -1
- paddlex/inference/models/text_detection/predictor.py +16 -24
- paddlex/inference/models/text_detection/processors.py +74 -36
- paddlex/inference/models/text_detection/result.py +9 -4
- paddlex/inference/models/text_recognition/__init__.py +1 -1
- paddlex/inference/models/text_recognition/predictor.py +11 -19
- paddlex/inference/models/text_recognition/processors.py +27 -13
- paddlex/inference/models/text_recognition/result.py +3 -2
- paddlex/inference/models/ts_anomaly_detection/__init__.py +1 -1
- paddlex/inference/models/ts_anomaly_detection/predictor.py +12 -17
- paddlex/inference/models/ts_anomaly_detection/processors.py +6 -2
- paddlex/inference/models/ts_anomaly_detection/result.py +21 -10
- paddlex/inference/models/ts_classification/__init__.py +1 -1
- paddlex/inference/models/ts_classification/predictor.py +14 -27
- paddlex/inference/models/ts_classification/processors.py +7 -2
- paddlex/inference/models/ts_classification/result.py +21 -12
- paddlex/inference/models/ts_forecasting/__init__.py +1 -1
- paddlex/inference/models/ts_forecasting/predictor.py +13 -18
- paddlex/inference/models/ts_forecasting/processors.py +12 -3
- paddlex/inference/models/ts_forecasting/result.py +24 -11
- paddlex/inference/models/video_classification/__init__.py +1 -1
- paddlex/inference/models/video_classification/predictor.py +9 -15
- paddlex/inference/models/video_classification/processors.py +24 -24
- paddlex/inference/models/video_classification/result.py +7 -3
- paddlex/inference/models/video_detection/__init__.py +1 -1
- paddlex/inference/models/video_detection/predictor.py +8 -15
- paddlex/inference/models/video_detection/processors.py +24 -11
- paddlex/inference/models/video_detection/result.py +10 -5
- paddlex/inference/pipelines/__init__.py +44 -37
- paddlex/inference/pipelines/anomaly_detection/__init__.py +1 -1
- paddlex/inference/pipelines/anomaly_detection/pipeline.py +16 -6
- paddlex/inference/pipelines/attribute_recognition/__init__.py +1 -1
- paddlex/inference/pipelines/attribute_recognition/pipeline.py +13 -8
- paddlex/inference/pipelines/attribute_recognition/result.py +10 -8
- paddlex/inference/pipelines/base.py +31 -11
- paddlex/inference/pipelines/components/__init__.py +14 -8
- paddlex/inference/pipelines/components/chat_server/__init__.py +1 -1
- paddlex/inference/pipelines/components/chat_server/base.py +2 -2
- paddlex/inference/pipelines/components/chat_server/openai_bot_chat.py +8 -8
- paddlex/inference/pipelines/components/common/__init__.py +5 -4
- paddlex/inference/pipelines/components/common/base_operator.py +2 -1
- paddlex/inference/pipelines/components/common/base_result.py +3 -2
- paddlex/inference/pipelines/components/common/convert_points_and_boxes.py +1 -2
- paddlex/inference/pipelines/components/common/crop_image_regions.py +11 -5
- paddlex/inference/pipelines/components/common/seal_det_warp.py +44 -13
- paddlex/inference/pipelines/components/common/sort_boxes.py +4 -2
- paddlex/inference/pipelines/components/common/warp_image.py +50 -0
- paddlex/inference/pipelines/components/faisser.py +9 -4
- paddlex/inference/pipelines/components/prompt_engineering/__init__.py +2 -2
- paddlex/inference/pipelines/components/prompt_engineering/base.py +2 -2
- paddlex/inference/pipelines/components/prompt_engineering/generate_ensemble_prompt.py +2 -1
- paddlex/inference/pipelines/components/prompt_engineering/generate_kie_prompt.py +2 -2
- paddlex/inference/pipelines/components/retriever/__init__.py +2 -2
- paddlex/inference/pipelines/components/retriever/base.py +18 -16
- paddlex/inference/pipelines/components/retriever/openai_bot_retriever.py +2 -2
- paddlex/inference/pipelines/components/retriever/qianfan_bot_retriever.py +87 -84
- paddlex/inference/pipelines/components/utils/__init__.py +1 -1
- paddlex/inference/pipelines/components/utils/mixin.py +7 -7
- paddlex/inference/pipelines/doc_preprocessor/__init__.py +1 -1
- paddlex/inference/pipelines/doc_preprocessor/pipeline.py +21 -28
- paddlex/inference/pipelines/doc_preprocessor/result.py +5 -10
- paddlex/inference/pipelines/doc_understanding/__init__.py +15 -0
- paddlex/inference/pipelines/doc_understanding/pipeline.py +71 -0
- paddlex/inference/pipelines/face_recognition/__init__.py +1 -1
- paddlex/inference/pipelines/face_recognition/pipeline.py +3 -1
- paddlex/inference/pipelines/face_recognition/result.py +3 -2
- paddlex/inference/pipelines/formula_recognition/__init__.py +1 -1
- paddlex/inference/pipelines/formula_recognition/pipeline.py +22 -16
- paddlex/inference/pipelines/formula_recognition/result.py +20 -19
- paddlex/inference/pipelines/image_classification/__init__.py +1 -1
- paddlex/inference/pipelines/image_classification/pipeline.py +17 -8
- paddlex/inference/pipelines/image_multilabel_classification/__init__.py +1 -1
- paddlex/inference/pipelines/image_multilabel_classification/pipeline.py +18 -9
- paddlex/inference/pipelines/instance_segmentation/__init__.py +1 -1
- paddlex/inference/pipelines/instance_segmentation/pipeline.py +17 -6
- paddlex/inference/pipelines/keypoint_detection/__init__.py +1 -1
- paddlex/inference/pipelines/keypoint_detection/pipeline.py +17 -6
- paddlex/inference/pipelines/layout_parsing/__init__.py +1 -1
- paddlex/inference/pipelines/layout_parsing/pipeline.py +23 -12
- paddlex/inference/pipelines/layout_parsing/pipeline_v2.py +16 -6
- paddlex/inference/pipelines/layout_parsing/result.py +5 -4
- paddlex/inference/pipelines/layout_parsing/result_v2.py +5 -8
- paddlex/inference/pipelines/layout_parsing/utils.py +7 -8
- paddlex/inference/pipelines/{3d_bev_detection → m_3d_bev_detection}/__init__.py +1 -1
- paddlex/inference/pipelines/{3d_bev_detection → m_3d_bev_detection}/pipeline.py +17 -10
- paddlex/inference/pipelines/multilingual_speech_recognition/__init__.py +1 -1
- paddlex/inference/pipelines/multilingual_speech_recognition/pipeline.py +17 -6
- paddlex/inference/pipelines/object_detection/__init__.py +1 -1
- paddlex/inference/pipelines/object_detection/pipeline.py +16 -6
- paddlex/inference/pipelines/ocr/__init__.py +1 -1
- paddlex/inference/pipelines/ocr/pipeline.py +28 -11
- paddlex/inference/pipelines/ocr/result.py +13 -9
- paddlex/inference/pipelines/open_vocabulary_detection/__init__.py +1 -1
- paddlex/inference/pipelines/open_vocabulary_detection/pipeline.py +17 -6
- paddlex/inference/pipelines/open_vocabulary_segmentation/__init__.py +1 -1
- paddlex/inference/pipelines/open_vocabulary_segmentation/pipeline.py +17 -6
- paddlex/inference/pipelines/pp_chatocr/__init__.py +1 -1
- paddlex/inference/pipelines/pp_chatocr/pipeline_base.py +14 -5
- paddlex/inference/pipelines/pp_chatocr/pipeline_v3.py +22 -11
- paddlex/inference/pipelines/pp_chatocr/pipeline_v4.py +31 -13
- paddlex/inference/pipelines/pp_shitu_v2/__init__.py +1 -1
- paddlex/inference/pipelines/pp_shitu_v2/pipeline.py +12 -8
- paddlex/inference/pipelines/pp_shitu_v2/result.py +4 -4
- paddlex/inference/pipelines/rotated_object_detection/__init__.py +1 -1
- paddlex/inference/pipelines/rotated_object_detection/pipeline.py +17 -6
- paddlex/inference/pipelines/seal_recognition/__init__.py +1 -1
- paddlex/inference/pipelines/seal_recognition/pipeline.py +21 -13
- paddlex/inference/pipelines/seal_recognition/result.py +4 -2
- paddlex/inference/pipelines/semantic_segmentation/__init__.py +1 -1
- paddlex/inference/pipelines/semantic_segmentation/pipeline.py +17 -6
- paddlex/inference/pipelines/small_object_detection/__init__.py +1 -1
- paddlex/inference/pipelines/small_object_detection/pipeline.py +17 -6
- paddlex/inference/pipelines/table_recognition/__init__.py +1 -1
- paddlex/inference/pipelines/table_recognition/pipeline.py +41 -25
- paddlex/inference/pipelines/table_recognition/pipeline_v2.py +65 -33
- paddlex/inference/pipelines/table_recognition/result.py +11 -9
- paddlex/inference/pipelines/table_recognition/table_recognition_post_processing.py +12 -8
- paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py +46 -32
- paddlex/inference/pipelines/table_recognition/utils.py +1 -1
- paddlex/inference/pipelines/ts_anomaly_detection/__init__.py +1 -1
- paddlex/inference/pipelines/ts_anomaly_detection/pipeline.py +16 -6
- paddlex/inference/pipelines/ts_classification/__init__.py +1 -1
- paddlex/inference/pipelines/ts_classification/pipeline.py +16 -6
- paddlex/inference/pipelines/ts_forecasting/__init__.py +1 -1
- paddlex/inference/pipelines/ts_forecasting/pipeline.py +16 -6
- paddlex/inference/pipelines/video_classification/__init__.py +1 -1
- paddlex/inference/pipelines/video_classification/pipeline.py +17 -6
- paddlex/inference/pipelines/video_detection/__init__.py +1 -1
- paddlex/inference/pipelines/video_detection/pipeline.py +20 -7
- paddlex/inference/serving/__init__.py +5 -1
- paddlex/inference/serving/basic_serving/__init__.py +1 -1
- paddlex/inference/serving/basic_serving/_app.py +31 -19
- paddlex/inference/serving/basic_serving/_pipeline_apps/__init__.py +7 -4
- paddlex/inference/serving/basic_serving/_pipeline_apps/_common/__init__.py +1 -1
- paddlex/inference/serving/basic_serving/_pipeline_apps/_common/common.py +7 -3
- paddlex/inference/serving/basic_serving/_pipeline_apps/_common/image_recognition.py +1 -1
- paddlex/inference/serving/basic_serving/_pipeline_apps/_common/ocr.py +7 -2
- paddlex/inference/serving/basic_serving/_pipeline_apps/anomaly_detection.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/doc_preprocessor.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/doc_understanding.py +153 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/face_recognition.py +16 -13
- paddlex/inference/serving/basic_serving/_pipeline_apps/formula_recognition.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/human_keypoint_detection.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/image_classification.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/image_multilabel_classification.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/instance_segmentation.py +13 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/layout_parsing.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/m_3d_bev_detection.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/multilingual_speech_recognition.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/object_detection.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/ocr.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/open_vocabulary_detection.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/open_vocabulary_segmentation.py +13 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/pedestrian_attribute_recognition.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv3_doc.py +14 -11
- paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv4_doc.py +16 -13
- paddlex/inference/serving/basic_serving/_pipeline_apps/pp_shituv2.py +16 -13
- paddlex/inference/serving/basic_serving/_pipeline_apps/pp_structurev3.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/rotated_object_detection.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/seal_recognition.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/semantic_segmentation.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/small_object_detection.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition_v2.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/ts_anomaly_detection.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/ts_classification.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/ts_forecast.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/vehicle_attribute_recognition.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/video_classification.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/video_detection.py +10 -7
- paddlex/inference/serving/basic_serving/_server.py +9 -4
- paddlex/inference/serving/infra/__init__.py +1 -1
- paddlex/inference/serving/infra/config.py +1 -1
- paddlex/inference/serving/infra/models.py +13 -6
- paddlex/inference/serving/infra/storage.py +9 -4
- paddlex/inference/serving/infra/utils.py +37 -9
- paddlex/inference/serving/schemas/__init__.py +1 -1
- paddlex/inference/serving/schemas/anomaly_detection.py +1 -1
- paddlex/inference/serving/schemas/doc_preprocessor.py +1 -1
- paddlex/inference/serving/schemas/doc_understanding.py +78 -0
- paddlex/inference/serving/schemas/face_recognition.py +1 -1
- paddlex/inference/serving/schemas/formula_recognition.py +1 -1
- paddlex/inference/serving/schemas/human_keypoint_detection.py +1 -1
- paddlex/inference/serving/schemas/image_classification.py +1 -1
- paddlex/inference/serving/schemas/image_multilabel_classification.py +1 -1
- paddlex/inference/serving/schemas/instance_segmentation.py +1 -1
- paddlex/inference/serving/schemas/layout_parsing.py +1 -1
- paddlex/inference/serving/schemas/m_3d_bev_detection.py +1 -1
- paddlex/inference/serving/schemas/multilingual_speech_recognition.py +1 -1
- paddlex/inference/serving/schemas/object_detection.py +1 -1
- paddlex/inference/serving/schemas/ocr.py +1 -1
- paddlex/inference/serving/schemas/open_vocabulary_detection.py +1 -1
- paddlex/inference/serving/schemas/open_vocabulary_segmentation.py +1 -1
- paddlex/inference/serving/schemas/pedestrian_attribute_recognition.py +1 -1
- paddlex/inference/serving/schemas/pp_chatocrv3_doc.py +1 -1
- paddlex/inference/serving/schemas/pp_chatocrv4_doc.py +1 -1
- paddlex/inference/serving/schemas/pp_shituv2.py +1 -1
- paddlex/inference/serving/schemas/pp_structurev3.py +1 -1
- paddlex/inference/serving/schemas/rotated_object_detection.py +1 -1
- paddlex/inference/serving/schemas/seal_recognition.py +1 -1
- paddlex/inference/serving/schemas/semantic_segmentation.py +1 -1
- paddlex/inference/serving/schemas/shared/__init__.py +1 -1
- paddlex/inference/serving/schemas/shared/classification.py +1 -1
- paddlex/inference/serving/schemas/shared/image_segmentation.py +1 -1
- paddlex/inference/serving/schemas/shared/object_detection.py +1 -1
- paddlex/inference/serving/schemas/shared/ocr.py +1 -1
- paddlex/inference/serving/schemas/small_object_detection.py +1 -1
- paddlex/inference/serving/schemas/table_recognition.py +1 -1
- paddlex/inference/serving/schemas/table_recognition_v2.py +1 -1
- paddlex/inference/serving/schemas/ts_anomaly_detection.py +1 -1
- paddlex/inference/serving/schemas/ts_classification.py +1 -1
- paddlex/inference/serving/schemas/ts_forecast.py +1 -1
- paddlex/inference/serving/schemas/vehicle_attribute_recognition.py +1 -1
- paddlex/inference/serving/schemas/video_classification.py +1 -1
- paddlex/inference/serving/schemas/video_detection.py +1 -1
- paddlex/inference/utils/__init__.py +1 -1
- paddlex/inference/utils/benchmark.py +332 -179
- paddlex/inference/utils/color_map.py +1 -1
- paddlex/inference/utils/get_pipeline_path.py +1 -1
- paddlex/inference/utils/hpi.py +251 -0
- paddlex/inference/utils/hpi_model_info_collection.json +2252 -0
- paddlex/inference/utils/io/__init__.py +11 -11
- paddlex/inference/utils/io/readers.py +22 -18
- paddlex/inference/utils/io/style.py +21 -14
- paddlex/inference/utils/io/tablepyxl.py +13 -5
- paddlex/inference/utils/io/writers.py +9 -10
- paddlex/inference/utils/model_paths.py +48 -0
- paddlex/inference/utils/{new_ir_blacklist.py → new_ir_blocklist.py} +1 -2
- paddlex/inference/utils/official_models.py +264 -262
- paddlex/inference/utils/pp_option.py +164 -93
- paddlex/inference/utils/trt_blocklist.py +43 -0
- paddlex/inference/utils/trt_config.py +420 -0
- paddlex/model.py +28 -10
- paddlex/modules/__init__.py +57 -80
- paddlex/modules/anomaly_detection/__init__.py +2 -2
- paddlex/modules/anomaly_detection/dataset_checker/__init__.py +2 -3
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/analyse_dataset.py +6 -3
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/check_dataset.py +8 -4
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/convert_dataset.py +7 -4
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/split_dataset.py +2 -2
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/utils/visualizer.py +7 -2
- paddlex/modules/anomaly_detection/evaluator.py +1 -1
- paddlex/modules/anomaly_detection/exportor.py +1 -1
- paddlex/modules/anomaly_detection/model_list.py +1 -1
- paddlex/modules/anomaly_detection/trainer.py +3 -4
- paddlex/modules/base/__init__.py +5 -5
- paddlex/modules/base/build_model.py +1 -2
- paddlex/modules/base/dataset_checker/__init__.py +2 -2
- paddlex/modules/base/dataset_checker/dataset_checker.py +4 -4
- paddlex/modules/base/dataset_checker/utils.py +1 -3
- paddlex/modules/base/evaluator.py +8 -8
- paddlex/modules/base/exportor.py +12 -13
- paddlex/modules/base/trainer.py +21 -11
- paddlex/modules/base/utils/__init__.py +13 -0
- paddlex/modules/base/utils/cinn_setting.py +89 -0
- paddlex/modules/base/utils/coco_eval.py +94 -0
- paddlex/modules/base/utils/topk_eval.py +118 -0
- paddlex/modules/doc_vlm/__init__.py +18 -0
- paddlex/modules/doc_vlm/dataset_checker.py +29 -0
- paddlex/modules/doc_vlm/evaluator.py +29 -0
- paddlex/modules/doc_vlm/exportor.py +29 -0
- paddlex/modules/doc_vlm/model_list.py +16 -0
- paddlex/modules/doc_vlm/trainer.py +41 -0
- paddlex/modules/face_recognition/__init__.py +2 -2
- paddlex/modules/face_recognition/dataset_checker/__init__.py +2 -2
- paddlex/modules/face_recognition/dataset_checker/dataset_src/__init__.py +1 -1
- paddlex/modules/face_recognition/dataset_checker/dataset_src/check_dataset.py +3 -5
- paddlex/modules/face_recognition/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/face_recognition/dataset_checker/dataset_src/utils/visualizer.py +2 -5
- paddlex/modules/face_recognition/evaluator.py +1 -1
- paddlex/modules/face_recognition/exportor.py +1 -1
- paddlex/modules/face_recognition/model_list.py +1 -1
- paddlex/modules/face_recognition/trainer.py +1 -1
- paddlex/modules/formula_recognition/__init__.py +2 -2
- paddlex/modules/formula_recognition/dataset_checker/__init__.py +3 -3
- paddlex/modules/formula_recognition/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/formula_recognition/dataset_checker/dataset_src/analyse_dataset.py +13 -12
- paddlex/modules/formula_recognition/dataset_checker/dataset_src/check_dataset.py +2 -6
- paddlex/modules/formula_recognition/dataset_checker/dataset_src/convert_dataset.py +11 -10
- paddlex/modules/formula_recognition/dataset_checker/dataset_src/split_dataset.py +1 -2
- paddlex/modules/formula_recognition/evaluator.py +1 -1
- paddlex/modules/formula_recognition/exportor.py +1 -1
- paddlex/modules/formula_recognition/model_list.py +1 -1
- paddlex/modules/formula_recognition/trainer.py +2 -3
- paddlex/modules/general_recognition/__init__.py +2 -2
- paddlex/modules/general_recognition/dataset_checker/__init__.py +2 -2
- paddlex/modules/general_recognition/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/general_recognition/dataset_checker/dataset_src/analyse_dataset.py +7 -9
- paddlex/modules/general_recognition/dataset_checker/dataset_src/check_dataset.py +4 -5
- paddlex/modules/general_recognition/dataset_checker/dataset_src/convert_dataset.py +6 -5
- paddlex/modules/general_recognition/dataset_checker/dataset_src/split_dataset.py +1 -1
- paddlex/modules/general_recognition/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/general_recognition/dataset_checker/dataset_src/utils/visualizer.py +2 -5
- paddlex/modules/general_recognition/evaluator.py +1 -1
- paddlex/modules/general_recognition/exportor.py +1 -1
- paddlex/modules/general_recognition/model_list.py +1 -1
- paddlex/modules/general_recognition/trainer.py +1 -1
- paddlex/modules/image_classification/__init__.py +2 -2
- paddlex/modules/image_classification/dataset_checker/__init__.py +2 -2
- paddlex/modules/image_classification/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/image_classification/dataset_checker/dataset_src/analyse_dataset.py +8 -9
- paddlex/modules/image_classification/dataset_checker/dataset_src/check_dataset.py +4 -3
- paddlex/modules/image_classification/dataset_checker/dataset_src/convert_dataset.py +4 -4
- paddlex/modules/image_classification/dataset_checker/dataset_src/split_dataset.py +1 -1
- paddlex/modules/image_classification/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/image_classification/dataset_checker/dataset_src/utils/visualizer.py +2 -5
- paddlex/modules/image_classification/evaluator.py +1 -1
- paddlex/modules/image_classification/exportor.py +1 -1
- paddlex/modules/image_classification/model_list.py +1 -1
- paddlex/modules/image_classification/trainer.py +3 -3
- paddlex/modules/image_unwarping/__init__.py +1 -1
- paddlex/modules/image_unwarping/model_list.py +1 -1
- paddlex/modules/instance_segmentation/__init__.py +2 -2
- paddlex/modules/instance_segmentation/dataset_checker/__init__.py +2 -3
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/analyse_dataset.py +9 -5
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/check_dataset.py +8 -5
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/convert_dataset.py +8 -8
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/split_dataset.py +7 -4
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/utils/visualizer.py +10 -8
- paddlex/modules/instance_segmentation/evaluator.py +1 -1
- paddlex/modules/instance_segmentation/exportor.py +1 -1
- paddlex/modules/instance_segmentation/model_list.py +1 -1
- paddlex/modules/instance_segmentation/trainer.py +1 -1
- paddlex/modules/keypoint_detection/__init__.py +2 -2
- paddlex/modules/keypoint_detection/dataset_checker/__init__.py +2 -2
- paddlex/modules/keypoint_detection/dataset_checker/dataset_src/__init__.py +1 -1
- paddlex/modules/keypoint_detection/dataset_checker/dataset_src/check_dataset.py +10 -5
- paddlex/modules/keypoint_detection/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/keypoint_detection/dataset_checker/dataset_src/utils/visualizer.py +8 -3
- paddlex/modules/keypoint_detection/evaluator.py +1 -1
- paddlex/modules/keypoint_detection/exportor.py +1 -1
- paddlex/modules/keypoint_detection/model_list.py +1 -1
- paddlex/modules/keypoint_detection/trainer.py +2 -2
- paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/__init__.py +2 -2
- paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/dataset_checker/__init__.py +3 -3
- paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/dataset_checker/dataset_src/analyse_dataset.py +8 -8
- paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/dataset_checker/dataset_src/check_dataset.py +1 -2
- paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/evaluator.py +1 -1
- paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/exportor.py +1 -1
- paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/model_list.py +1 -1
- paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/trainer.py +5 -7
- paddlex/modules/multilabel_classification/__init__.py +2 -2
- paddlex/modules/multilabel_classification/dataset_checker/__init__.py +2 -2
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/analyse_dataset.py +8 -9
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/check_dataset.py +4 -3
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/convert_dataset.py +10 -7
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/split_dataset.py +1 -1
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/utils/visualizer.py +1 -5
- paddlex/modules/multilabel_classification/evaluator.py +1 -1
- paddlex/modules/multilabel_classification/exportor.py +1 -1
- paddlex/modules/multilabel_classification/model_list.py +1 -1
- paddlex/modules/multilabel_classification/trainer.py +3 -3
- paddlex/modules/multilingual_speech_recognition/__init__.py +2 -2
- paddlex/modules/multilingual_speech_recognition/dataset_checker.py +3 -3
- paddlex/modules/multilingual_speech_recognition/evaluator.py +3 -3
- paddlex/modules/multilingual_speech_recognition/exportor.py +3 -3
- paddlex/modules/multilingual_speech_recognition/model_list.py +1 -1
- paddlex/modules/multilingual_speech_recognition/trainer.py +7 -5
- paddlex/modules/object_detection/__init__.py +2 -2
- paddlex/modules/object_detection/dataset_checker/__init__.py +2 -11
- paddlex/modules/object_detection/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/object_detection/dataset_checker/dataset_src/analyse_dataset.py +10 -8
- paddlex/modules/object_detection/dataset_checker/dataset_src/check_dataset.py +10 -5
- paddlex/modules/object_detection/dataset_checker/dataset_src/convert_dataset.py +13 -8
- paddlex/modules/object_detection/dataset_checker/dataset_src/split_dataset.py +8 -4
- paddlex/modules/object_detection/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/object_detection/dataset_checker/dataset_src/utils/visualizer.py +9 -8
- paddlex/modules/object_detection/evaluator.py +9 -4
- paddlex/modules/object_detection/exportor.py +1 -1
- paddlex/modules/object_detection/model_list.py +1 -1
- paddlex/modules/object_detection/trainer.py +4 -5
- paddlex/modules/open_vocabulary_detection/__init__.py +2 -2
- paddlex/modules/open_vocabulary_detection/dataset_checker.py +3 -3
- paddlex/modules/open_vocabulary_detection/evaluator.py +3 -3
- paddlex/modules/open_vocabulary_detection/exportor.py +3 -3
- paddlex/modules/open_vocabulary_detection/model_list.py +2 -4
- paddlex/modules/open_vocabulary_detection/trainer.py +7 -5
- paddlex/modules/open_vocabulary_segmentation/__init__.py +2 -2
- paddlex/modules/open_vocabulary_segmentation/dataset_checker.py +3 -3
- paddlex/modules/open_vocabulary_segmentation/evaluator.py +3 -3
- paddlex/modules/open_vocabulary_segmentation/exportor.py +3 -3
- paddlex/modules/open_vocabulary_segmentation/model_list.py +1 -1
- paddlex/modules/open_vocabulary_segmentation/trainer.py +7 -5
- paddlex/modules/semantic_segmentation/__init__.py +2 -2
- paddlex/modules/semantic_segmentation/dataset_checker/__init__.py +2 -3
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/analyse_dataset.py +6 -3
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/check_dataset.py +2 -2
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/convert_dataset.py +7 -4
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/split_dataset.py +2 -2
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/utils/visualizer.py +6 -2
- paddlex/modules/semantic_segmentation/evaluator.py +1 -1
- paddlex/modules/semantic_segmentation/exportor.py +1 -1
- paddlex/modules/semantic_segmentation/model_list.py +1 -1
- paddlex/modules/semantic_segmentation/trainer.py +3 -4
- paddlex/modules/table_recognition/__init__.py +2 -2
- paddlex/modules/table_recognition/dataset_checker/__init__.py +5 -5
- paddlex/modules/table_recognition/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/table_recognition/dataset_checker/dataset_src/analyse_dataset.py +3 -2
- paddlex/modules/table_recognition/dataset_checker/dataset_src/check_dataset.py +8 -7
- paddlex/modules/table_recognition/dataset_checker/dataset_src/split_dataset.py +2 -1
- paddlex/modules/table_recognition/evaluator.py +1 -1
- paddlex/modules/table_recognition/exportor.py +1 -1
- paddlex/modules/table_recognition/model_list.py +1 -1
- paddlex/modules/table_recognition/trainer.py +2 -5
- paddlex/modules/text_detection/__init__.py +2 -2
- paddlex/modules/text_detection/dataset_checker/__init__.py +4 -6
- paddlex/modules/text_detection/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/text_detection/dataset_checker/dataset_src/analyse_dataset.py +12 -9
- paddlex/modules/text_detection/dataset_checker/dataset_src/check_dataset.py +3 -3
- paddlex/modules/text_detection/dataset_checker/dataset_src/split_dataset.py +3 -3
- paddlex/modules/text_detection/evaluator.py +1 -1
- paddlex/modules/text_detection/exportor.py +1 -1
- paddlex/modules/text_detection/model_list.py +1 -1
- paddlex/modules/text_detection/trainer.py +2 -5
- paddlex/modules/text_recognition/__init__.py +2 -2
- paddlex/modules/text_recognition/dataset_checker/__init__.py +4 -5
- paddlex/modules/text_recognition/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/text_recognition/dataset_checker/dataset_src/analyse_dataset.py +13 -12
- paddlex/modules/text_recognition/dataset_checker/dataset_src/check_dataset.py +2 -5
- paddlex/modules/text_recognition/dataset_checker/dataset_src/convert_dataset.py +11 -10
- paddlex/modules/text_recognition/dataset_checker/dataset_src/split_dataset.py +1 -2
- paddlex/modules/text_recognition/evaluator.py +1 -1
- paddlex/modules/text_recognition/exportor.py +1 -1
- paddlex/modules/text_recognition/model_list.py +1 -1
- paddlex/modules/text_recognition/trainer.py +2 -3
- paddlex/modules/ts_anomaly_detection/__init__.py +2 -2
- paddlex/modules/ts_anomaly_detection/dataset_checker/__init__.py +4 -5
- paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/analyse_dataset.py +1 -9
- paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/check_dataset.py +2 -2
- paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/convert_dataset.py +2 -6
- paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/split_dataset.py +4 -4
- paddlex/modules/ts_anomaly_detection/evaluator.py +1 -1
- paddlex/modules/ts_anomaly_detection/exportor.py +2 -3
- paddlex/modules/ts_anomaly_detection/model_list.py +1 -1
- paddlex/modules/ts_anomaly_detection/trainer.py +8 -8
- paddlex/modules/ts_classification/__init__.py +2 -2
- paddlex/modules/ts_classification/dataset_checker/__init__.py +4 -5
- paddlex/modules/ts_classification/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/ts_classification/dataset_checker/dataset_src/analyse_dataset.py +8 -5
- paddlex/modules/ts_classification/dataset_checker/dataset_src/check_dataset.py +2 -2
- paddlex/modules/ts_classification/dataset_checker/dataset_src/convert_dataset.py +2 -6
- paddlex/modules/ts_classification/dataset_checker/dataset_src/split_dataset.py +4 -4
- paddlex/modules/ts_classification/evaluator.py +1 -1
- paddlex/modules/ts_classification/exportor.py +2 -3
- paddlex/modules/ts_classification/model_list.py +1 -1
- paddlex/modules/ts_classification/trainer.py +7 -7
- paddlex/modules/ts_forecast/__init__.py +2 -2
- paddlex/modules/ts_forecast/dataset_checker/__init__.py +4 -5
- paddlex/modules/ts_forecast/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/ts_forecast/dataset_checker/dataset_src/analyse_dataset.py +1 -9
- paddlex/modules/ts_forecast/dataset_checker/dataset_src/check_dataset.py +2 -2
- paddlex/modules/ts_forecast/dataset_checker/dataset_src/convert_dataset.py +2 -6
- paddlex/modules/ts_forecast/dataset_checker/dataset_src/split_dataset.py +4 -4
- paddlex/modules/ts_forecast/evaluator.py +1 -1
- paddlex/modules/ts_forecast/exportor.py +2 -3
- paddlex/modules/ts_forecast/model_list.py +1 -1
- paddlex/modules/ts_forecast/trainer.py +7 -7
- paddlex/modules/video_classification/__init__.py +2 -2
- paddlex/modules/video_classification/dataset_checker/__init__.py +2 -2
- paddlex/modules/video_classification/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/video_classification/dataset_checker/dataset_src/analyse_dataset.py +9 -9
- paddlex/modules/video_classification/dataset_checker/dataset_src/check_dataset.py +2 -3
- paddlex/modules/video_classification/dataset_checker/dataset_src/split_dataset.py +1 -1
- paddlex/modules/video_classification/evaluator.py +1 -1
- paddlex/modules/video_classification/exportor.py +1 -1
- paddlex/modules/video_classification/model_list.py +1 -1
- paddlex/modules/video_classification/trainer.py +3 -3
- paddlex/modules/video_detection/__init__.py +2 -2
- paddlex/modules/video_detection/dataset_checker/__init__.py +2 -2
- paddlex/modules/video_detection/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/video_detection/dataset_checker/dataset_src/analyse_dataset.py +8 -9
- paddlex/modules/video_detection/dataset_checker/dataset_src/check_dataset.py +3 -5
- paddlex/modules/video_detection/evaluator.py +1 -1
- paddlex/modules/video_detection/exportor.py +1 -1
- paddlex/modules/video_detection/model_list.py +1 -1
- paddlex/modules/video_detection/trainer.py +3 -3
- paddlex/ops/__init__.py +5 -2
- paddlex/ops/iou3d_nms/iou3d_cpu.cpp +8 -6
- paddlex/ops/iou3d_nms/iou3d_cpu.h +3 -2
- paddlex/ops/iou3d_nms/iou3d_nms.cpp +8 -6
- paddlex/ops/iou3d_nms/iou3d_nms.h +6 -4
- paddlex/ops/iou3d_nms/iou3d_nms_api.cpp +24 -18
- paddlex/ops/iou3d_nms/iou3d_nms_kernel.cu +9 -7
- paddlex/ops/setup.py +3 -3
- paddlex/ops/voxel/voxelize_op.cc +22 -19
- paddlex/ops/voxel/voxelize_op.cu +25 -25
- paddlex/paddlex_cli.py +86 -75
- paddlex/repo_apis/Paddle3D_api/__init__.py +1 -1
- paddlex/repo_apis/Paddle3D_api/bev_fusion/__init__.py +1 -1
- paddlex/repo_apis/Paddle3D_api/bev_fusion/config.py +1 -1
- paddlex/repo_apis/Paddle3D_api/bev_fusion/model.py +4 -4
- paddlex/repo_apis/Paddle3D_api/bev_fusion/register.py +2 -2
- paddlex/repo_apis/Paddle3D_api/bev_fusion/runner.py +1 -1
- paddlex/repo_apis/Paddle3D_api/pp3d_config.py +3 -2
- paddlex/repo_apis/PaddleClas_api/__init__.py +1 -1
- paddlex/repo_apis/PaddleClas_api/cls/__init__.py +3 -3
- paddlex/repo_apis/PaddleClas_api/cls/config.py +4 -3
- paddlex/repo_apis/PaddleClas_api/cls/model.py +3 -3
- paddlex/repo_apis/PaddleClas_api/cls/register.py +2 -3
- paddlex/repo_apis/PaddleClas_api/cls/runner.py +1 -2
- paddlex/repo_apis/PaddleClas_api/shitu_rec/__init__.py +2 -2
- paddlex/repo_apis/PaddleClas_api/shitu_rec/config.py +2 -2
- paddlex/repo_apis/PaddleClas_api/shitu_rec/model.py +1 -4
- paddlex/repo_apis/PaddleClas_api/shitu_rec/register.py +2 -2
- paddlex/repo_apis/PaddleClas_api/shitu_rec/runner.py +1 -6
- paddlex/repo_apis/PaddleDetection_api/__init__.py +2 -2
- paddlex/repo_apis/PaddleDetection_api/config_helper.py +3 -3
- paddlex/repo_apis/PaddleDetection_api/instance_seg/__init__.py +2 -2
- paddlex/repo_apis/PaddleDetection_api/instance_seg/config.py +2 -3
- paddlex/repo_apis/PaddleDetection_api/instance_seg/model.py +3 -3
- paddlex/repo_apis/PaddleDetection_api/instance_seg/register.py +2 -3
- paddlex/repo_apis/PaddleDetection_api/instance_seg/runner.py +1 -2
- paddlex/repo_apis/PaddleDetection_api/object_det/__init__.py +3 -3
- paddlex/repo_apis/PaddleDetection_api/object_det/config.py +4 -3
- paddlex/repo_apis/PaddleDetection_api/object_det/model.py +5 -6
- paddlex/repo_apis/PaddleDetection_api/object_det/official_categories.py +1 -1
- paddlex/repo_apis/PaddleDetection_api/object_det/register.py +2 -3
- paddlex/repo_apis/PaddleDetection_api/object_det/runner.py +1 -2
- paddlex/repo_apis/PaddleNLP_api/__init__.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/__init__.py +4 -3
- paddlex/repo_apis/PaddleOCR_api/config_utils.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/formula_rec/__init__.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/formula_rec/config.py +4 -3
- paddlex/repo_apis/PaddleOCR_api/formula_rec/model.py +4 -4
- paddlex/repo_apis/PaddleOCR_api/formula_rec/register.py +2 -3
- paddlex/repo_apis/PaddleOCR_api/formula_rec/runner.py +1 -2
- paddlex/repo_apis/PaddleOCR_api/table_rec/__init__.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/table_rec/config.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/table_rec/model.py +3 -3
- paddlex/repo_apis/PaddleOCR_api/table_rec/register.py +2 -3
- paddlex/repo_apis/PaddleOCR_api/table_rec/runner.py +2 -2
- paddlex/repo_apis/PaddleOCR_api/text_det/__init__.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/text_det/config.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/text_det/model.py +3 -3
- paddlex/repo_apis/PaddleOCR_api/text_det/register.py +2 -3
- paddlex/repo_apis/PaddleOCR_api/text_det/runner.py +2 -2
- paddlex/repo_apis/PaddleOCR_api/text_rec/__init__.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/text_rec/config.py +4 -3
- paddlex/repo_apis/PaddleOCR_api/text_rec/model.py +4 -4
- paddlex/repo_apis/PaddleOCR_api/text_rec/register.py +2 -3
- paddlex/repo_apis/PaddleOCR_api/text_rec/runner.py +1 -2
- paddlex/repo_apis/PaddleSeg_api/__init__.py +1 -1
- paddlex/repo_apis/PaddleSeg_api/base_seg_config.py +2 -2
- paddlex/repo_apis/PaddleSeg_api/seg/__init__.py +1 -1
- paddlex/repo_apis/PaddleSeg_api/seg/config.py +3 -6
- paddlex/repo_apis/PaddleSeg_api/seg/model.py +5 -5
- paddlex/repo_apis/PaddleSeg_api/seg/register.py +2 -3
- paddlex/repo_apis/PaddleSeg_api/seg/runner.py +1 -2
- paddlex/repo_apis/PaddleTS_api/__init__.py +4 -3
- paddlex/repo_apis/PaddleTS_api/ts_ad/__init__.py +1 -1
- paddlex/repo_apis/PaddleTS_api/ts_ad/config.py +2 -3
- paddlex/repo_apis/PaddleTS_api/ts_ad/register.py +2 -2
- paddlex/repo_apis/PaddleTS_api/ts_ad/runner.py +2 -2
- paddlex/repo_apis/PaddleTS_api/ts_base/__init__.py +1 -1
- paddlex/repo_apis/PaddleTS_api/ts_base/config.py +2 -4
- paddlex/repo_apis/PaddleTS_api/ts_base/model.py +4 -4
- paddlex/repo_apis/PaddleTS_api/ts_base/runner.py +2 -2
- paddlex/repo_apis/PaddleTS_api/ts_cls/__init__.py +1 -1
- paddlex/repo_apis/PaddleTS_api/ts_cls/config.py +2 -3
- paddlex/repo_apis/PaddleTS_api/ts_cls/register.py +2 -2
- paddlex/repo_apis/PaddleTS_api/ts_cls/runner.py +2 -2
- paddlex/repo_apis/PaddleTS_api/ts_fc/__init__.py +1 -1
- paddlex/repo_apis/PaddleTS_api/ts_fc/config.py +2 -3
- paddlex/repo_apis/PaddleTS_api/ts_fc/register.py +1 -1
- paddlex/repo_apis/PaddleVideo_api/__init__.py +1 -1
- paddlex/repo_apis/PaddleVideo_api/config_utils.py +1 -1
- paddlex/repo_apis/PaddleVideo_api/video_cls/__init__.py +3 -3
- paddlex/repo_apis/PaddleVideo_api/video_cls/config.py +4 -3
- paddlex/repo_apis/PaddleVideo_api/video_cls/model.py +3 -3
- paddlex/repo_apis/PaddleVideo_api/video_cls/register.py +2 -3
- paddlex/repo_apis/PaddleVideo_api/video_cls/runner.py +1 -2
- paddlex/repo_apis/PaddleVideo_api/video_det/__init__.py +3 -3
- paddlex/repo_apis/PaddleVideo_api/video_det/config.py +4 -3
- paddlex/repo_apis/PaddleVideo_api/video_det/model.py +4 -4
- paddlex/repo_apis/PaddleVideo_api/video_det/register.py +2 -3
- paddlex/repo_apis/PaddleVideo_api/video_det/runner.py +1 -2
- paddlex/repo_apis/__init__.py +1 -1
- paddlex/repo_apis/base/__init__.py +4 -5
- paddlex/repo_apis/base/config.py +2 -3
- paddlex/repo_apis/base/model.py +11 -19
- paddlex/repo_apis/base/register.py +1 -1
- paddlex/repo_apis/base/runner.py +11 -12
- paddlex/repo_apis/base/utils/__init__.py +1 -1
- paddlex/repo_apis/base/utils/arg.py +1 -1
- paddlex/repo_apis/base/utils/subprocess.py +1 -1
- paddlex/repo_manager/__init__.py +2 -9
- paddlex/repo_manager/core.py +9 -27
- paddlex/repo_manager/meta.py +37 -31
- paddlex/repo_manager/repo.py +169 -160
- paddlex/repo_manager/utils.py +13 -224
- paddlex/utils/__init__.py +1 -1
- paddlex/utils/cache.py +8 -10
- paddlex/utils/config.py +6 -5
- paddlex/utils/{custom_device_whitelist.py → custom_device_list.py} +29 -199
- paddlex/utils/deps.py +249 -0
- paddlex/utils/device.py +73 -29
- paddlex/utils/download.py +4 -4
- paddlex/utils/env.py +33 -7
- paddlex/utils/errors/__init__.py +1 -1
- paddlex/utils/errors/dataset_checker.py +1 -1
- paddlex/utils/errors/others.py +2 -16
- paddlex/utils/file_interface.py +4 -5
- paddlex/utils/flags.py +19 -12
- paddlex/utils/fonts/__init__.py +2 -1
- paddlex/utils/func_register.py +1 -1
- paddlex/utils/install.py +87 -0
- paddlex/utils/interactive_get_pipeline.py +3 -3
- paddlex/utils/lazy_loader.py +3 -3
- paddlex/utils/logging.py +10 -1
- paddlex/utils/misc.py +5 -5
- paddlex/utils/pipeline_arguments.py +15 -7
- paddlex/utils/result_saver.py +4 -5
- paddlex/utils/subclass_register.py +2 -4
- paddlex/version.py +2 -1
- {paddlex-3.0.0rc0.dist-info → paddlex-3.0.0rc1.dist-info}/METADATA +212 -73
- paddlex-3.0.0rc1.dist-info/RECORD +1068 -0
- {paddlex-3.0.0rc0.dist-info → paddlex-3.0.0rc1.dist-info}/WHEEL +1 -1
- paddlex/inference/models/base/predictor/basic_predictor.py +0 -139
- paddlex/paddle2onnx_requirements.txt +0 -1
- paddlex/repo_manager/requirements.txt +0 -21
- paddlex/serving_requirements.txt +0 -9
- paddlex-3.0.0rc0.dist-info/RECORD +0 -1015
- {paddlex-3.0.0rc0.dist-info → paddlex-3.0.0rc1.dist-info}/entry_points.txt +0 -0
- {paddlex-3.0.0rc0.dist-info → paddlex-3.0.0rc1.dist-info/licenses}/LICENSE +0 -0
- {paddlex-3.0.0rc0.dist-info → paddlex-3.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,730 @@
|
|
1
|
+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import inspect
|
16
|
+
from abc import ABC
|
17
|
+
from collections import OrderedDict
|
18
|
+
from typing import Callable, Dict, List, Tuple, Union
|
19
|
+
|
20
|
+
import numpy as np
|
21
|
+
import paddle
|
22
|
+
from paddle.nn.layer.layers import in_declarative_mode
|
23
|
+
|
24
|
+
|
25
|
+
class LogitsProcessor(ABC):
    """
    Base class for the logits processors used during generation.

    A concrete processor overrides ``__call__`` to receive the token ids
    produced so far together with the current logits, and returns the
    (possibly modified) logits.
    """

    def __call__(self, input_ids: paddle.Tensor, logits: paddle.Tensor):
        # Calling the base class directly is always an error.
        message = (
            f"{self.__class__} is an abstract class. "
            "Only classes inheriting this class can be called."
        )
        raise NotImplementedError(message)
|
36
|
+
|
37
|
+
|
38
|
+
class LogitsProcessorList:
    """Ordered chain of logits processors, applied in insertion order."""

    def __init__(self, processors: List[LogitsProcessor] = None) -> None:
        # Keys are consecutive integers assigned by `append`, so iteration
        # order equals insertion order.
        self._processors = OrderedDict()
        for processor in processors or []:
            self.append(processor)

    def __call__(self, input_ids: paddle.Tensor, logits: paddle.Tensor, **kwargs):
        """
        Thread ``logits`` through every processor in order.

        A processor whose ``__call__`` declares parameters beyond
        ``(input_ids, logits)`` is called with ``**kwargs``. Every extra
        parameter name must then be present in ``kwargs``; otherwise an
        AssertionError is raised. Other processors get only the two
        positional arguments.
        """
        for processor in self._processors.values():
            params = inspect.signature(processor.__call__).parameters
            extra_names = list(params.keys())[2:]
            if not extra_names:
                logits = processor(input_ids, logits)
                continue
            assert all(
                name in kwargs for name in extra_names
            ), f"The parameters don't match for {processor.__class__}"
            logits = processor(input_ids, logits, **kwargs)
        return logits

    def append(self, processor: LogitsProcessor):
        """Add ``processor`` to the end of the chain."""
        next_index = len(self._processors)
        self._processors[next_index] = processor
|
61
|
+
|
62
|
+
|
63
|
+
class MinLengthLogitsProcessor(LogitsProcessor):
    r"""
    Enforcing a min-length by setting EOS probability to 0.

    While the sequence is shorter than ``min_length``, the logit of every
    end-of-sequence id is set to the lowest finite value of the logits dtype.

    Args:
        min_length (int): The minimum length of generation sequence.
        eos_token_id (int | List[int]): The id, or list of ids, of the
            `end-of-sequence` token(s). Every id must be a non-negative int.

    Raises:
        ValueError: If ``min_length`` is negative (checked only outside
            declarative mode), or if ``eos_token_id`` is not a non-negative
            int or a non-empty list/tuple of non-negative ints.
    """

    def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
        # The check is skipped in declarative (to-static) mode, presumably
        # because `min_length` need not be a plain Python int there.
        if min_length < 0 and not in_declarative_mode():
            raise ValueError(
                "`min_length` should be a positive integer, but get {}".format(
                    min_length
                )
            )

        # The annotation admits a list of ids, so validate each element
        # instead of rejecting lists outright.
        if isinstance(eos_token_id, (list, tuple)):
            eos_ids = list(eos_token_id)
        else:
            eos_ids = [eos_token_id]
        if not eos_ids or any(
            not isinstance(token_id, int) or token_id < 0 for token_id in eos_ids
        ):
            raise ValueError(
                "`eos_token_id` should be a positive integer, but get {}".format(
                    eos_token_id
                )
            )

        self.min_length = min_length
        # A single id is kept as an int, exactly as before. Multiple ids are
        # stored as a list so `logits[:, ids]` indexes every column.
        self.eos_token_id = eos_token_id if isinstance(eos_token_id, int) else eos_ids

    def __call__(self, input_ids: paddle.Tensor, logits: paddle.Tensor):
        cur_len = input_ids.shape[-1]
        if cur_len < self.min_length:
            logits[:, self.eos_token_id] = paddle.finfo(logits.dtype).min
        return logits
|
95
|
+
|
96
|
+
|
97
|
+
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    Penalize every token id that already occurs in ``input_ids``.

    For each id present in a row of ``input_ids``, the corresponding logit is
    multiplied by ``penalty`` when it is negative and divided by ``penalty``
    otherwise.

    Args:
        penalty (float):
            The repetition penalty; 1.0 means no penalty. See `this paper
            <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.

    Raises:
        ValueError: If ``penalty`` is not strictly positive (checked only
            outside declarative mode).
    """

    def __init__(self, penalty: float):
        if not (penalty > 0) and not in_declarative_mode():
            raise ValueError(
                f"`penalty` has to be a strictly positive float, but is {penalty}"
            )
        self.penalty = penalty

    def __call__(self, input_ids: paddle.Tensor, logits: paddle.Tensor):
        # Gather, for each row, the current logit of every already-seen id.
        seen_scores = paddle.index_sample(logits, input_ids)
        # With penalty > 1 this lowers both negative and positive logits.
        penalized = paddle.where(
            seen_scores < 0, seen_scores * self.penalty, seen_scores / self.penalty
        )

        # Convert per-row column ids into positions in the flattened logits.
        vocab_size = logits.shape[-1]
        row_offsets = (
            paddle.arange(logits.shape[0], dtype="int64").unsqueeze(-1) * vocab_size
        )
        flat_positions = input_ids + row_offsets

        flat_logits = logits.flatten()
        updated = paddle.scatter(
            flat_logits, flat_positions.flatten(), penalized.flatten()
        )
        return updated.reshape(logits.shape)
|
127
|
+
|
128
|
+
|
129
|
+
def _get_ngrams(ngram_size: int, prev_input_ids: paddle.Tensor, num_hypos: int):
    """
    Map, per hypothesis, every (ngram_size - 1)-token prefix to the tokens
    that followed it.

    Assume ngram_size=2 and prev_input_ids=tensor([[40, 2883, 2712, 4346]]). The
    output looks like [{(40,): [2883], (2883,): [2712], (2712,): [4346]}].

    Args:
        ngram_size (`int`):
            The number sequential tokens taken as a group which may only occur once before being banned.
        prev_input_ids (`paddle.Tensor`):
            Generated token ids; row ``idx`` belongs to hypothesis ``idx``.
        num_hypos (`int`):
            The number of hypotheses for which n-grams need to be generated.

    Returns:
        generated_ngrams (`list` of `dict`):
            One dictionary per hypothesis. Each key is a prefix tuple, and each
            value lists the following tokens in order of occurrence, with
            duplicates kept.
    """
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
            # Append in place. The former `get(...) + [tok]` copied the whole
            # value list on every occurrence of the same prefix.
            generated_ngram.setdefault(tuple(ngram[:-1]), []).append(ngram[-1])
    return generated_ngrams
|
156
|
+
|
157
|
+
|
158
|
+
def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
    """
    Look up which tokens would complete an already-seen n-gram.

    The last ``ngram_size - 1`` tokens before position ``cur_len`` form the
    lookup key into ``banned_ngrams``.

    Args:
        banned_ngrams (`dict`):
            Prefix tuple -> list of tokens that followed it, for one hypothesis.
        prev_input_ids (`paddle.Tensor`):
            Generated token ids for the current hypothesis (any sliceable
            object with ``.tolist()``).
        ngram_size (`int`):
            The n-gram length that may not repeat.
        cur_len (`int`):
            Current length of the token sequence.

    Returns:
        The list of banned token ids, or an empty list if the prefix is unseen.
    """
    prefix_start = cur_len + 1 - ngram_size
    prefix = tuple(prev_input_ids[prefix_start:cur_len].tolist())
    banned = banned_ngrams.get(prefix, [])
    return banned
|
178
|
+
|
179
|
+
|
180
|
+
def _calc_banned_ngram_tokens(
    ngram_size: int, prev_input_ids: paddle.Tensor, num_hypos: int, cur_len: int
):
    """
    Copied from fairseq for no_repeat_ngram in beam_search.

    Returns a list with one entry per hypothesis, each entry being the token
    ids that would repeat an n-gram already present in that hypothesis.
    """
    # Fewer than ngram_size - 1 tokens so far: no complete prefix to match yet.
    if cur_len + 1 < ngram_size:
        return [[] for _ in range(num_hypos)]

    ngrams_per_hypo = _get_ngrams(ngram_size, prev_input_ids, num_hypos)

    banned_tokens = []
    for hypo_idx, hypo_ngrams in enumerate(ngrams_per_hypo):
        banned_tokens.append(
            _get_generated_ngrams(
                hypo_ngrams, prev_input_ids[hypo_idx], ngram_size, cur_len
            )
        )
    return banned_tokens
|
197
|
+
|
198
|
+
|
199
|
+
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that forbids any n-gram from appearing twice. See
    [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).

    Args:
        ngram_size (`int`):
            All ngrams of size `ngram_size` can only occur once.

    Raises:
        ValueError: If ``ngram_size`` is not a strictly positive int.
    """

    def __init__(self, ngram_size: int):
        if not (isinstance(ngram_size, int) and ngram_size > 0):
            raise ValueError(
                f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}"
            )
        self.ngram_size = ngram_size

    def __call__(self, input_ids: paddle.Tensor, scores: paddle.Tensor):
        banned_per_row = _calc_banned_ngram_tokens(
            self.ngram_size, input_ids, scores.shape[0], input_ids.shape[-1]
        )
        # Banned tokens get the lowest finite value of the scores dtype.
        lowest = paddle.finfo(scores.dtype).min
        for row, banned in enumerate(banned_per_row):
            if banned:
                scores[row, banned] = lowest
        return scores
|
228
|
+
|
229
|
+
|
230
|
+
class HammingDiversityLogitsProcessor(LogitsProcessor):
    """
    This `LogitsProcessor` enforces diverse beam search. Note that this logits
    processor is only effective for `group_beam_search`. See
    `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.

    Args:
        diversity_rate (float): This value is subtracted from a beam's score if
            it generates a token same as any beam from other group at a particular
            time.
        num_beams (int): Number of beams used for group beam search.
        num_beam_groups (int): Number of groups to divide `num_beams` into in order
            to ensure diversity among different groups of beams.

    Raises:
        ValueError: If ``diversity_rate`` is not a positive float, if
            ``num_beams`` or ``num_beam_groups`` is not an int greater than 1,
            or if ``num_beam_groups`` exceeds ``num_beams``.
    """

    def __init__(self, diversity_rate: float, num_beams: int, num_beam_groups: int):
        if not isinstance(diversity_rate, float) or (not diversity_rate > 0.0):
            raise ValueError(
                "`diversity_rate` should be a float strictly larger than 0."
            )
        self._diversity_rate = diversity_rate
        if not isinstance(num_beams, int) or num_beams < 2:
            raise ValueError("`num_beams` should be an integer strictly larger than 1.")
        self._num_beams = num_beams
        if not isinstance(num_beam_groups, int) or num_beam_groups < 2:
            raise ValueError(
                "`num_beam_groups` should be an integer strictly larger than 1."
            )
        if num_beam_groups > num_beams:
            # Otherwise num_beams // num_beam_groups == 0, every group would
            # start at index 0, and the penalty would silently never apply.
            raise ValueError(
                "`num_beam_groups` should be smaller than or equal to `num_beams`."
            )
        self._num_sub_beams = num_beams // num_beam_groups

    def __call__(
        self,
        input_ids: paddle.Tensor,
        scores: paddle.Tensor,
        current_tokens: paddle.Tensor,
        beam_group_idx: int,
    ):
        # `current_tokens` holds `num_beams` entries per batch item.
        batch_size = current_tokens.shape[0] // self._num_beams
        group_start_idx = beam_group_idx * self._num_sub_beams
        group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams)
        group_size = group_end_idx - group_start_idx
        vocab_size = scores.shape[-1]

        # The first group has no earlier groups to diverge from.
        if group_start_idx == 0:
            return scores

        for batch_idx in range(batch_size):
            # Tokens chosen by the groups before this one, for this batch item
            # (presumably at the current step; confirm against the caller).
            previous_group_tokens = current_tokens[
                batch_idx * self._num_beams : batch_idx * self._num_beams
                + group_start_idx
            ]
            token_frequency = paddle.bincount(
                previous_group_tokens, minlength=vocab_size
            )
            # `scores` is sliced as `group_size` rows per batch item.
            scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= (
                self._diversity_rate * token_frequency
            )

        return scores
|
289
|
+
|
290
|
+
|
291
|
+
class ForcedBOSTokenLogitsProcessor(LogitsProcessor):
    """
    Force the first generated token to be ``forced_bos_token_id``.

    When ``input_ids`` has length 1, every logit is set to the dtype's lowest
    finite value except the forced id, which is set to 0.

    Args:
        forced_bos_token_id (:obj:`int`):
            The id of the token to be generated as the first token.
    """

    def __init__(self, forced_bos_token_id: int):
        self.forced_bos_token_id = forced_bos_token_id

    def __call__(self, input_ids: paddle.Tensor, scores: paddle.Tensor):
        if input_ids.shape[-1] != 1:
            return scores
        scores[:] = paddle.finfo(scores.dtype).min
        scores[:, self.forced_bos_token_id] = 0
        return scores
|
309
|
+
|
310
|
+
|
311
|
+
class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
    """
    `LogitsProcessor` that forces the token generated at the final position to be `forced_eos_token_id`.

    Args:
        max_length (int): The maximum length of the sequence to be generated.
        forced_eos_token_id (Union[int, List[int]]): Id, or list of ids, of the token(s) that remain
            selectable at the final position.
    """

    def __init__(self, max_length: int, forced_eos_token_id: Union[int, List[int]]):
        self.max_length = max_length
        self.forced_eos_token_id = forced_eos_token_id

    def __call__(self, input_ids, scores):
        # Acts on the step where `input_ids` holds max_length - 1 tokens, i.e. the
        # step that produces the last allowed token.
        at_final_step = input_ids.shape[-1] == self.max_length - 1
        if at_final_step:
            scores[:] = paddle.finfo(scores.dtype).min
            scores[:, self.forced_eos_token_id] = 0
        return scores
|
330
|
+
|
331
|
+
|
332
|
+
def TopKProcess(probs: paddle.Tensor, top_k: int, min_tokens_to_keep: int):
    """Zero out every probability smaller than the k-th largest one in its row.

    The effective ``k`` is ``max(top_k, min_tokens_to_keep)`` clipped to the size of
    the last axis. It is built from tensors rather than Python ints, presumably so
    the function also survives dynamic-to-static (d2s) conversion.
    NOTE(review): confirm that rationale.

    Args:
        probs (paddle.Tensor): Probabilities, one row per sequence. The ``[:, -1:]``
            indexing below assumes a 2-D tensor.
        top_k (int): Number of highest-probability tokens to keep.
        min_tokens_to_keep (int): Lower bound on the number of kept tokens.

    Returns:
        paddle.Tensor: ``probs`` with the filtered entries set to 0. For bfloat16
        input the result is float32, because filtering runs on a float32 copy.
    """
    top_k = paddle.minimum(
        paddle.maximum(paddle.to_tensor(top_k), paddle.to_tensor(min_tokens_to_keep)),
        paddle.to_tensor(probs.shape[-1]),
    )
    # bfloat16 inputs are cast to float32 before topk ("to support generation & d2s").
    # The threshold is deliberately left in float32 too, so the comparison below has
    # matching dtypes. Values that started as bfloat16 are exactly representable in
    # float32, so this yields the same mask as a bfloat16 threshold would.
    if probs.dtype == paddle.bfloat16:
        probs = paddle.cast(probs, paddle.float32)
    topk_probs, _ = paddle.topk(probs, k=top_k)

    # Remove all tokens with a probability less than the last token of the top-k.
    kth_largest = topk_probs[:, -1:]
    return paddle.where(probs >= kth_largest, probs, paddle.full_like(probs, 0.0))
|
350
|
+
|
351
|
+
|
352
|
+
def TopPProcess(probs: paddle.Tensor, top_p: float, min_tokens_to_keep: int):
    """Nucleus (top-p) filtering: zero out tokens outside the top-p set of each row.

    Tokens are sorted by descending probability. A token is removed when the
    cumulative probability of the tokens ranked *before* it already exceeds
    ``top_p``, so the token that crosses the threshold is kept. The most probable
    token is always kept, and at least ``min_tokens_to_keep`` tokens are kept.

    Args:
        probs (paddle.Tensor): Probabilities. The index arithmetic below
            (``probs.shape[0]`` rows times ``probs.shape[-1]`` columns) assumes a
            2-D ``[batch, vocab]`` tensor.
        top_p (float): Cumulative probability threshold.
        min_tokens_to_keep (int): Minimum number of tokens kept per row.

    Returns:
        paddle.Tensor: ``probs`` with removed entries set to 0. bfloat16 input is
        processed, and returned, as float32.
    """
    # Sort in float32 for bfloat16 inputs, and keep the cumulative sum in float32 as
    # well: accumulating in bfloat16 (8-bit significand) may round away small
    # probabilities and shift where the cumulative sum crosses top_p.
    if probs.dtype == paddle.bfloat16:
        probs = paddle.cast(probs, paddle.float32)

    sorted_indices = paddle.argsort(probs, descending=True)
    sorted_probs = paddle.sort(probs, descending=True)
    cumulative_probs = paddle.cumsum(sorted_probs, axis=-1)

    # Remove tokens with cumulative probs above top_p, but keep at least
    # min_tokens_to_keep tokens.
    sorted_indices_to_remove = cumulative_probs > top_p
    if min_tokens_to_keep > 1:
        # 'min_tokens_to_keep - 1' because the first token is always kept below.
        sorted_indices_to_remove[:, : min_tokens_to_keep - 1] = 0
    # Shift right by one so the token that crosses top_p survives, then always keep
    # the most probable token.
    sorted_indices_to_remove = paddle.cast(sorted_indices_to_remove, dtype="int64")
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0

    # Scatter the mask from sorted order back to vocabulary order. Offsetting each
    # row's indices by row * vocab_size makes them address the flattened tensor.
    sorted_indices = (
        sorted_indices
        + paddle.arange(probs.shape[0], dtype="int64").unsqueeze(-1) * probs.shape[-1]
    )
    condition = paddle.scatter(
        sorted_indices_to_remove.flatten(),
        sorted_indices.flatten(),
        sorted_indices_to_remove.flatten(),
    )
    condition = paddle.cast(condition, "bool").reshape(probs.shape)
    return paddle.where(condition, paddle.full_like(probs, 0.0), probs)
|
391
|
+
|
392
|
+
|
393
|
+
class LogitsWarper:
    """Base class for logits warpers applied during generation with multinomial sampling.

    Subclasses override ``__call__`` to return the warped ``scores``; calling the
    base class directly raises ``NotImplementedError``.
    """

    def __call__(self, input_ids: paddle.Tensor, scores: paddle.Tensor):
        owner = self.__class__
        raise NotImplementedError(
            f"{owner} is an abstract class. Only classes inheriting this class can be called."
        )
|
400
|
+
|
401
|
+
|
402
|
+
class TemperatureLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] for temperature scaling: the scores are divided by `temperature`.

    Args:
        temperature (`float`):
            Strictly positive Python `float` used as the divisor of the scores.
    """

    def __init__(self, temperature: float):
        # `temperature > 0` is False for NaN, so NaN is rejected along with
        # non-floats and non-positive values.
        is_valid = isinstance(temperature, float) and temperature > 0
        if not is_valid:
            raise ValueError(
                f"`temperature` has to be a strictly positive float, but is {temperature}"
            )

        self.temperature = temperature

    def __call__(self, input_ids: paddle.Tensor, scores: paddle.Tensor):
        return scores / self.temperature
|
421
|
+
|
422
|
+
|
423
|
+
class SequenceBiasLogitsProcessor(LogitsProcessor):
    """
    [`LogitsProcessor`] that applies an additive bias on sequences. The bias is applied to the last token of a sequence
    when the next generated token can complete it. Consequently, to make the most of biasing sequences with more than
    one token, consider using beam methods (to gracefully work around partially completed sequences that have a
    negative bias) and applying the bias to their prefixes (to ensure the bias is applied earlier).

    <Tip>

    In order to get the token ids of the sequences that you want to bias, make sure to set `add_prefix_space=True` when
    initializing the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The
    `add_prefix_space` argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours
    come from `pre tokenizers`.

    </Tip>

    Args:
        sequence_bias (`Dict[Tuple[int], float]`):
            Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
            sequence being selected, while negative biases do the opposite. If a sequence has a length of 1, its bias
            will always be applied. Otherwise, the bias will only be applied if the sequence in question is about to be
            completed (in the token selection step after this processor is applied), i.e. when the most recent tokens
            of a row equal every token of the sequence except the last one.

    Examples:

    ```python
    >>> from paddlenlp.transformers import AutoTokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("gpt2-en")
    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2-en")
    >>> inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pd")

    >>> summary_ids = model.generate(inputs["input_ids"], max_new_tokens=4)
    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald J. Trump Jr

    >>> # Now let's control generation through a bias. Please note that the tokenizer is initialized differently!
    >>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("gpt2-en", add_prefix_space=True)


    >>> def get_tokens_as_tuple(word):
    ...     return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0])


    >>> # If we add a negative bias without beam search, it may become "stuck" in a prefix without good continuations
    >>> sequence_bias = {get_tokens_as_tuple("Trump"): -10.0}
    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, sequence_bias=sequence_bias)
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald J. Donald,

    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald Rumsfeld,

    >>> # We can also add a positive bias to nudge the model towards specific tokens or continuations
    >>> sequence_bias = {get_tokens_as_tuple("Donald Duck"): 10.0}
    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald Duck.
    ```
    """

    def __init__(self, sequence_bias: Dict[Tuple[int], float]):
        self.sequence_bias = sequence_bias
        self._validate_arguments()

        # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size
        # is inferred in the first usage, which inhibits initializing here)
        self.length_1_bias = None
        self.prepared_bias_variables = False

    def __call__(self, input_ids, scores):
        # 1 - Prepare the bias tensors. This is only needed the first time the processor is called.
        if not self.prepared_bias_variables:
            self._prepare_bias_variables(scores)

        # 2 - Start from an all-zero bias with the same shape and dtype as `scores`.
        bias = paddle.zeros_like(scores)

        # 3 - Add the bias of length-1 sequences, which is applied unconditionally.
        if self.length_1_bias is not None:
            bias += self.length_1_bias

        # 4 - Add the bias of longer sequences, only on rows whose latest tokens match the
        #     sequence prefix (i.e. the next token could complete the sequence).
        for sequence_ids, sequence_bias in self.sequence_bias.items():
            if len(sequence_ids) == 1:  # the sequence is of length 1, already applied
                continue
            prefix_length = len(sequence_ids) - 1
            # Only the prefix has to fit in the context. A prefix exactly as long as the
            # context can still match, so only a strictly longer prefix is skipped.
            if prefix_length > input_ids.shape[1]:
                continue
            last_token = sequence_ids[-1]
            matching_rows = (
                paddle.equal(
                    input_ids[:, -prefix_length:],
                    paddle.to_tensor(sequence_ids[:-1], dtype=input_ids.dtype),
                )
                .astype(paddle.int64)
                .prod(axis=1)
            )
            # Build the scalars in the bias dtype so half-precision scores do not mix
            # with float32 tensors.
            bias[:, last_token] += paddle.where(
                matching_rows == 1,
                paddle.to_tensor(sequence_bias, dtype=bias.dtype),
                paddle.to_tensor(0.0, dtype=bias.dtype),
            )

        # 5 - Apply the bias to the scores.
        scores = scores + bias
        return scores

    def _prepare_bias_variables(self, scores):
        """Validate token ids against the vocabulary size and build the length-1 bias vector."""
        vocabulary_size = scores.shape[-1]

        # Check biased tokens out of bounds
        invalid_biases = []
        for sequence_ids in self.sequence_bias:
            for token_id in sequence_ids:
                if token_id >= vocabulary_size:
                    invalid_biases.append(token_id)
        if len(invalid_biases) > 0:
            raise ValueError(
                f"The model vocabulary size is {vocabulary_size}, but the following tokens were being biased: "
                f"{invalid_biases}"
            )

        # Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied
        # with simpler logic. The vector uses the scores dtype so it adds cleanly to `paddle.zeros_like(scores)`.
        self.length_1_bias = paddle.zeros((vocabulary_size,), dtype=scores.dtype)
        for sequence_ids, bias in self.sequence_bias.items():
            if len(sequence_ids) == 1:
                self.length_1_bias[sequence_ids[-1]] = bias

        self.prepared_bias_variables = True

    def _validate_arguments(self):
        """Raise ValueError unless `sequence_bias` is a non-empty dict of int tuples to floats."""
        sequence_bias = self.sequence_bias
        if not isinstance(sequence_bias, dict) or len(sequence_bias) == 0:
            raise ValueError(
                f"`sequence_bias` has to be a non-empty dictionary, but is {sequence_bias}."
            )
        if any(
            not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias.keys()
        ):
            raise ValueError(
                f"`sequence_bias` has to be a dict with tuples as keys, but is {sequence_bias}."
            )
        if any(
            any(
                (not isinstance(token_id, (int, np.integer)) or token_id < 0)
                for token_id in sequence_ids
            )
            or len(sequence_ids) == 0
            for sequence_ids in sequence_bias.keys()
        ):
            raise ValueError(
                f"Each key in `sequence_bias` has to be a non-empty tuple of positive integers, but is "
                f"{sequence_bias}."
            )
        if any(not isinstance(bias, float) for bias in sequence_bias.values()):
            raise ValueError(
                f"`sequence_bias` has to be a dict with floats as values, but is {sequence_bias}."
            )
|
586
|
+
|
587
|
+
|
588
|
+
class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
    """
    [`LogitsProcessor`] that enforces that specified sequences will never be selected.

    <Tip>

    In order to get the token ids of the words that should not appear in the generated text, make sure to set
    `add_prefix_space=True` when initializing the tokenizer, and use `tokenizer(bad_words,
    add_special_tokens=False).input_ids`. The `add_prefix_space` argument is only supported for some slow tokenizers,
    as fast tokenizers' prefixing behaviours come from `pre tokenizers`. Read more
    [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).

    </Tip>

    Args:
        bad_words_ids (`List[List[int]]`):
            List of list of token ids that are not allowed to be generated.
        eos_token_id (`Union[int, List[int]]`):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
            Single-token entries of `bad_words_ids` equal to one of these ids are dropped from the ban list. `None`
            disables this filtering.

    Examples:

    ```python
    >>> from paddlenlp.transformers import AutoTokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("gpt2-en")
    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2-en")
    >>> inputs = tokenizer(["In a word, the cake is a"], return_tensors="pd")

    >>> output_ids = model.generate(inputs["input_ids"], max_new_tokens=5, pad_token_id=tokenizer.eos_token_id)
    >>> print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
    In a word, the cake is a bit of a mess.

    >>> # Now let's take the bad words out. Please note that the tokenizer is initialized differently
    >>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("gpt2-en", add_prefix_space=True)


    >>> def get_tokens_as_list(word_list):
    ...     "Converts a sequence of words into a list of tokens"
    ...     tokens_list = []
    ...     for word in word_list:
    ...         tokenized_word = tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0]
    ...         tokens_list.append(tokenized_word)
    ...     return tokens_list


    >>> bad_words_ids = get_tokens_as_list(word_list=["mess"])
    >>> output_ids = model.generate(
    ...     inputs["input_ids"], max_new_tokens=5, bad_words_ids=bad_words_ids, pad_token_id=tokenizer.eos_token_id
    ... )
    >>> print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
    In a word, the cake is a bit of a surprise.

    >>> from paddlenlp.transformers.generation import NoBadWordsLogitsProcessor, LogitsProcessorList
    >>> logits_processors = LogitsProcessorList([NoBadWordsLogitsProcessor([[5,6]], eos_token_id=tokenizer.eos_token_id)])
    >>> output_ids = model.generate(
    ...     inputs["input_ids"], max_new_tokens=5, logits_processors=logits_processors, pad_token_id=tokenizer.eos_token_id
    ... )
    >>> print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
    In a word, the cake is a bit of a surprise.
    ```
    """

    def __init__(
        self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]
    ):
        self.bad_word_ids = bad_words_ids
        self._validate_arguments()

        # Filter EOS token from bad_words_ids
        if eos_token_id is None:
            eos_token_id = []
        # NumPy integer scalars are accepted as token ids by `_validate_arguments`, so
        # wrap them like plain ints; iterating over a bare scalar below would fail.
        if isinstance(eos_token_id, (int, np.integer)):
            eos_token_id = [eos_token_id]
        bad_words_ids = [
            bad_token_seq
            for bad_token_seq in bad_words_ids
            if all(bad_token_seq != [i] for i in eos_token_id)
        ]

        # Forbidding a sequence is equivalent to setting its bias to -inf.
        # Note: the parent __init__ calls `self._validate_arguments()`, which resolves to
        # this class's override and re-checks `bad_word_ids` rather than `sequence_bias`.
        sequence_bias = {tuple(sequence): float("-inf") for sequence in bad_words_ids}
        super().__init__(sequence_bias=sequence_bias)

    def _validate_arguments(self):
        """Raise ValueError unless `bad_word_ids` is a non-empty list of lists of non-negative ints."""
        bad_words_ids = self.bad_word_ids
        if not isinstance(bad_words_ids, list) or len(bad_words_ids) == 0:
            raise ValueError(
                f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}."
            )
        if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
            raise ValueError(
                f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}."
            )
        if any(
            any(
                (not isinstance(token_id, (int, np.integer)) or token_id < 0)
                for token_id in bad_word_ids
            )
            for bad_word_ids in bad_words_ids
        ):
            raise ValueError(
                f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
            )
|
694
|
+
|
695
|
+
|
696
|
+
class PrefixConstrainedLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that enforces constrained generation and is useful for prefix-conditioned constrained
    generation. See [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904) for more information.

    Args:
        prefix_allowed_tokens_fn (`Callable[[int, paddle.Tensor], List[int]]`):
            Function that restricts every beam to a set of allowed tokens at each step. It is called as
            `prefix_allowed_tokens_fn(batch_id, input_ids)`, where `input_ids` is the token sequence of a single
            beam, and must return the token ids allowed as that beam's next token.
        num_beams (`int`):
            Number of beams per batch entry. Rows of `input_ids` and `scores` are reshaped into groups of
            `num_beams` consecutive rows, one group per batch entry.
    """

    def __init__(
        self,
        prefix_allowed_tokens_fn: Callable[[int, paddle.Tensor], List[int]],
        num_beams: int,
    ):
        self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
        self._num_beams = num_beams

    def __call__(
        self, input_ids: paddle.Tensor, scores: paddle.Tensor
    ) -> paddle.Tensor:
        # Start with every token disallowed (lowest finite score for the dtype), then
        # re-open the allowed tokens of each beam by zeroing their mask entries.
        mask = paddle.full_like(scores, paddle.finfo(scores.dtype).min)
        grouped = input_ids.reshape([-1, self._num_beams, input_ids.shape[-1]])
        for batch_id, beam_sent in enumerate(grouped):
            first_row = batch_id * self._num_beams
            for beam_id, sent in enumerate(beam_sent):
                allowed_tokens = self._prefix_allowed_tokens_fn(batch_id, sent)
                mask[first_row + beam_id, allowed_tokens] = 0

        return scores + mask
|