paddlex 3.0.0rc0__py3-none-any.whl → 3.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- paddlex/.version +1 -1
- paddlex/__init__.py +17 -34
- paddlex/__main__.py +1 -1
- paddlex/configs/modules/chart_parsing/PP-Chart2Table.yaml +13 -0
- paddlex/configs/modules/doc_vlm/PP-DocBee-2B.yaml +14 -0
- paddlex/configs/modules/doc_vlm/PP-DocBee-7B.yaml +14 -0
- paddlex/configs/modules/doc_vlm/PP-DocBee2-3B.yaml +14 -0
- paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-L.yaml +40 -0
- paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-M.yaml +40 -0
- paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-S.yaml +40 -0
- paddlex/configs/modules/layout_detection/PP-DocBlockLayout.yaml +40 -0
- paddlex/configs/modules/layout_detection/PP-DocLayout-L.yaml +2 -2
- paddlex/configs/modules/layout_detection/PP-DocLayout-M.yaml +2 -2
- paddlex/configs/modules/layout_detection/PP-DocLayout-S.yaml +2 -2
- paddlex/configs/modules/layout_detection/PP-DocLayout_plus-L.yaml +40 -0
- paddlex/configs/modules/open_vocabulary_detection/YOLO-Worldv2-L.yaml +13 -0
- paddlex/configs/modules/text_detection/PP-OCRv5_mobile_det.yaml +40 -0
- paddlex/configs/modules/text_detection/PP-OCRv5_server_det.yaml +40 -0
- paddlex/configs/modules/text_recognition/PP-OCRv5_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/PP-OCRv5_server_rec.yaml +39 -0
- paddlex/configs/modules/textline_orientation/PP-LCNet_x1_0_textline_ori.yaml +41 -0
- paddlex/configs/pipelines/OCR.yaml +7 -6
- paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml +3 -1
- paddlex/configs/pipelines/PP-ChatOCRv4-doc.yaml +91 -34
- paddlex/configs/pipelines/PP-StructureV3.yaml +72 -72
- paddlex/configs/pipelines/anomaly_detection.yaml +1 -1
- paddlex/configs/pipelines/doc_understanding.yaml +9 -0
- paddlex/configs/pipelines/formula_recognition.yaml +2 -2
- paddlex/configs/pipelines/layout_parsing.yaml +3 -2
- paddlex/configs/pipelines/seal_recognition.yaml +1 -0
- paddlex/configs/pipelines/table_recognition.yaml +2 -1
- paddlex/configs/pipelines/table_recognition_v2.yaml +7 -1
- paddlex/configs/pipelines/ts_anomaly_detection.yaml +1 -1
- paddlex/configs/pipelines/ts_classification.yaml +1 -1
- paddlex/configs/pipelines/ts_forecast.yaml +1 -1
- paddlex/constants.py +17 -0
- paddlex/engine.py +7 -5
- paddlex/hpip_links.html +23 -11
- paddlex/inference/__init__.py +3 -3
- paddlex/inference/common/__init__.py +1 -1
- paddlex/inference/common/batch_sampler/__init__.py +5 -4
- paddlex/inference/common/batch_sampler/audio_batch_sampler.py +5 -6
- paddlex/inference/common/batch_sampler/base_batch_sampler.py +20 -16
- paddlex/inference/common/batch_sampler/det_3d_batch_sampler.py +4 -7
- paddlex/inference/common/batch_sampler/doc_vlm_batch_sampler.py +87 -0
- paddlex/inference/common/batch_sampler/image_batch_sampler.py +45 -60
- paddlex/inference/common/batch_sampler/ts_batch_sampler.py +9 -10
- paddlex/inference/common/batch_sampler/video_batch_sampler.py +2 -22
- paddlex/inference/common/reader/__init__.py +4 -4
- paddlex/inference/common/reader/audio_reader.py +3 -3
- paddlex/inference/common/reader/det_3d_reader.py +7 -5
- paddlex/inference/common/reader/image_reader.py +16 -12
- paddlex/inference/common/reader/ts_reader.py +3 -2
- paddlex/inference/common/reader/video_reader.py +3 -3
- paddlex/inference/common/result/__init__.py +7 -7
- paddlex/inference/common/result/base_cv_result.py +12 -2
- paddlex/inference/common/result/base_result.py +7 -5
- paddlex/inference/common/result/base_ts_result.py +1 -2
- paddlex/inference/common/result/base_video_result.py +2 -2
- paddlex/inference/common/result/mixin.py +31 -25
- paddlex/inference/models/__init__.py +41 -85
- paddlex/inference/models/anomaly_detection/__init__.py +1 -1
- paddlex/inference/models/anomaly_detection/predictor.py +9 -19
- paddlex/inference/models/anomaly_detection/processors.py +9 -2
- paddlex/inference/models/anomaly_detection/result.py +3 -2
- paddlex/inference/models/base/__init__.py +2 -2
- paddlex/inference/models/base/predictor/__init__.py +1 -2
- paddlex/inference/models/base/predictor/base_predictor.py +278 -39
- paddlex/inference/models/common/__init__.py +6 -15
- paddlex/inference/models/common/static_infer.py +724 -251
- paddlex/inference/models/common/tokenizer/__init__.py +7 -3
- paddlex/inference/models/common/tokenizer/bert_tokenizer.py +1 -1
- paddlex/inference/models/common/tokenizer/clip_tokenizer.py +609 -0
- paddlex/inference/models/common/tokenizer/gpt_tokenizer.py +9 -7
- paddlex/inference/models/common/tokenizer/qwen2_5_tokenizer.py +112 -0
- paddlex/inference/models/common/tokenizer/qwen2_tokenizer.py +438 -0
- paddlex/inference/models/common/tokenizer/qwen_tokenizer.py +288 -0
- paddlex/inference/models/common/tokenizer/tokenizer_utils.py +85 -77
- paddlex/inference/models/common/tokenizer/tokenizer_utils_base.py +339 -123
- paddlex/inference/models/common/tokenizer/utils.py +1 -1
- paddlex/inference/models/common/tokenizer/vocab.py +8 -8
- paddlex/inference/models/common/ts/__init__.py +1 -1
- paddlex/inference/models/common/ts/funcs.py +13 -6
- paddlex/inference/models/common/ts/processors.py +14 -5
- paddlex/inference/models/common/vision/__init__.py +3 -3
- paddlex/inference/models/common/vision/funcs.py +17 -12
- paddlex/inference/models/common/vision/processors.py +61 -46
- paddlex/inference/models/common/vlm/__init__.py +13 -0
- paddlex/inference/models/common/vlm/activations.py +189 -0
- paddlex/inference/models/common/vlm/bert_padding.py +127 -0
- paddlex/inference/models/common/vlm/conversion_utils.py +99 -0
- paddlex/inference/models/common/vlm/distributed.py +229 -0
- paddlex/inference/models/common/vlm/flash_attn_utils.py +119 -0
- paddlex/inference/models/common/vlm/fusion_ops.py +205 -0
- paddlex/inference/models/common/vlm/generation/__init__.py +34 -0
- paddlex/inference/models/common/vlm/generation/configuration_utils.py +533 -0
- paddlex/inference/models/common/vlm/generation/logits_process.py +730 -0
- paddlex/inference/models/common/vlm/generation/stopping_criteria.py +106 -0
- paddlex/inference/models/common/vlm/generation/utils.py +2162 -0
- paddlex/inference/models/common/vlm/transformers/__init__.py +16 -0
- paddlex/inference/models/common/vlm/transformers/configuration_utils.py +1037 -0
- paddlex/inference/models/common/vlm/transformers/conversion_utils.py +408 -0
- paddlex/inference/models/common/vlm/transformers/model_outputs.py +1612 -0
- paddlex/inference/models/common/vlm/transformers/model_utils.py +2014 -0
- paddlex/inference/models/common/vlm/transformers/utils.py +178 -0
- paddlex/inference/models/common/vlm/utils.py +109 -0
- paddlex/inference/models/doc_vlm/__init__.py +15 -0
- paddlex/inference/models/doc_vlm/modeling/GOT_ocr_2_0.py +830 -0
- paddlex/inference/models/doc_vlm/modeling/__init__.py +17 -0
- paddlex/inference/models/doc_vlm/modeling/qwen2.py +1606 -0
- paddlex/inference/models/doc_vlm/modeling/qwen2_5_vl.py +3006 -0
- paddlex/inference/models/doc_vlm/modeling/qwen2_vl.py +2495 -0
- paddlex/inference/models/doc_vlm/predictor.py +253 -0
- paddlex/inference/models/doc_vlm/processors/GOT_ocr_2_0.py +97 -0
- paddlex/inference/models/doc_vlm/processors/__init__.py +17 -0
- paddlex/inference/models/doc_vlm/processors/common.py +561 -0
- paddlex/inference/models/doc_vlm/processors/qwen2_5_vl.py +548 -0
- paddlex/inference/models/doc_vlm/processors/qwen2_vl.py +543 -0
- paddlex/inference/models/doc_vlm/result.py +21 -0
- paddlex/inference/models/face_feature/__init__.py +1 -1
- paddlex/inference/models/face_feature/predictor.py +2 -1
- paddlex/inference/models/formula_recognition/__init__.py +1 -1
- paddlex/inference/models/formula_recognition/predictor.py +18 -28
- paddlex/inference/models/formula_recognition/processors.py +126 -97
- paddlex/inference/models/formula_recognition/result.py +43 -35
- paddlex/inference/models/image_classification/__init__.py +1 -1
- paddlex/inference/models/image_classification/predictor.py +9 -19
- paddlex/inference/models/image_classification/processors.py +4 -2
- paddlex/inference/models/image_classification/result.py +4 -3
- paddlex/inference/models/image_feature/__init__.py +1 -1
- paddlex/inference/models/image_feature/predictor.py +9 -19
- paddlex/inference/models/image_feature/processors.py +7 -5
- paddlex/inference/models/image_feature/result.py +2 -3
- paddlex/inference/models/image_multilabel_classification/__init__.py +1 -1
- paddlex/inference/models/image_multilabel_classification/predictor.py +7 -6
- paddlex/inference/models/image_multilabel_classification/processors.py +6 -2
- paddlex/inference/models/image_multilabel_classification/result.py +4 -3
- paddlex/inference/models/image_unwarping/__init__.py +1 -1
- paddlex/inference/models/image_unwarping/predictor.py +8 -16
- paddlex/inference/models/image_unwarping/processors.py +6 -2
- paddlex/inference/models/image_unwarping/result.py +4 -2
- paddlex/inference/models/instance_segmentation/__init__.py +1 -1
- paddlex/inference/models/instance_segmentation/predictor.py +7 -15
- paddlex/inference/models/instance_segmentation/processors.py +4 -7
- paddlex/inference/models/instance_segmentation/result.py +11 -10
- paddlex/inference/models/keypoint_detection/__init__.py +1 -1
- paddlex/inference/models/keypoint_detection/predictor.py +5 -3
- paddlex/inference/models/keypoint_detection/processors.py +11 -3
- paddlex/inference/models/keypoint_detection/result.py +9 -4
- paddlex/inference/models/{3d_bev_detection → m_3d_bev_detection}/__init__.py +1 -1
- paddlex/inference/models/{3d_bev_detection → m_3d_bev_detection}/predictor.py +15 -26
- paddlex/inference/models/{3d_bev_detection → m_3d_bev_detection}/processors.py +26 -14
- paddlex/inference/models/{3d_bev_detection → m_3d_bev_detection}/result.py +15 -12
- paddlex/inference/models/{3d_bev_detection → m_3d_bev_detection}/visualizer_3d.py +77 -39
- paddlex/inference/models/multilingual_speech_recognition/__init__.py +1 -1
- paddlex/inference/models/multilingual_speech_recognition/predictor.py +11 -15
- paddlex/inference/models/multilingual_speech_recognition/processors.py +45 -53
- paddlex/inference/models/multilingual_speech_recognition/result.py +1 -1
- paddlex/inference/models/object_detection/__init__.py +1 -1
- paddlex/inference/models/object_detection/predictor.py +8 -12
- paddlex/inference/models/object_detection/processors.py +63 -33
- paddlex/inference/models/object_detection/result.py +5 -4
- paddlex/inference/models/object_detection/utils.py +3 -1
- paddlex/inference/models/open_vocabulary_detection/__init__.py +1 -1
- paddlex/inference/models/open_vocabulary_detection/predictor.py +31 -14
- paddlex/inference/models/open_vocabulary_detection/processors/__init__.py +3 -2
- paddlex/inference/models/open_vocabulary_detection/processors/common.py +114 -0
- paddlex/inference/models/open_vocabulary_detection/processors/groundingdino_processors.py +19 -8
- paddlex/inference/models/open_vocabulary_detection/processors/yoloworld_processors.py +209 -0
- paddlex/inference/models/open_vocabulary_segmentation/__init__.py +1 -1
- paddlex/inference/models/open_vocabulary_segmentation/predictor.py +6 -13
- paddlex/inference/models/open_vocabulary_segmentation/processors/__init__.py +1 -1
- paddlex/inference/models/open_vocabulary_segmentation/processors/sam_processer.py +12 -12
- paddlex/inference/models/open_vocabulary_segmentation/results/__init__.py +1 -1
- paddlex/inference/models/open_vocabulary_segmentation/results/sam_result.py +11 -9
- paddlex/inference/models/semantic_segmentation/__init__.py +1 -1
- paddlex/inference/models/semantic_segmentation/predictor.py +9 -18
- paddlex/inference/models/semantic_segmentation/processors.py +11 -8
- paddlex/inference/models/semantic_segmentation/result.py +4 -3
- paddlex/inference/models/table_structure_recognition/__init__.py +1 -1
- paddlex/inference/models/table_structure_recognition/predictor.py +8 -18
- paddlex/inference/models/table_structure_recognition/processors.py +23 -29
- paddlex/inference/models/table_structure_recognition/result.py +8 -15
- paddlex/inference/models/text_detection/__init__.py +1 -1
- paddlex/inference/models/text_detection/predictor.py +24 -24
- paddlex/inference/models/text_detection/processors.py +116 -44
- paddlex/inference/models/text_detection/result.py +8 -13
- paddlex/inference/models/text_recognition/__init__.py +1 -1
- paddlex/inference/models/text_recognition/predictor.py +11 -19
- paddlex/inference/models/text_recognition/processors.py +27 -13
- paddlex/inference/models/text_recognition/result.py +3 -2
- paddlex/inference/models/ts_anomaly_detection/__init__.py +1 -1
- paddlex/inference/models/ts_anomaly_detection/predictor.py +12 -17
- paddlex/inference/models/ts_anomaly_detection/processors.py +6 -2
- paddlex/inference/models/ts_anomaly_detection/result.py +21 -10
- paddlex/inference/models/ts_classification/__init__.py +1 -1
- paddlex/inference/models/ts_classification/predictor.py +14 -27
- paddlex/inference/models/ts_classification/processors.py +7 -2
- paddlex/inference/models/ts_classification/result.py +21 -12
- paddlex/inference/models/ts_forecasting/__init__.py +1 -1
- paddlex/inference/models/ts_forecasting/predictor.py +13 -18
- paddlex/inference/models/ts_forecasting/processors.py +12 -3
- paddlex/inference/models/ts_forecasting/result.py +24 -11
- paddlex/inference/models/video_classification/__init__.py +1 -1
- paddlex/inference/models/video_classification/predictor.py +9 -15
- paddlex/inference/models/video_classification/processors.py +24 -24
- paddlex/inference/models/video_classification/result.py +7 -3
- paddlex/inference/models/video_detection/__init__.py +1 -1
- paddlex/inference/models/video_detection/predictor.py +8 -15
- paddlex/inference/models/video_detection/processors.py +24 -11
- paddlex/inference/models/video_detection/result.py +10 -5
- paddlex/inference/pipelines/__init__.py +48 -37
- paddlex/inference/pipelines/_parallel.py +172 -0
- paddlex/inference/pipelines/anomaly_detection/__init__.py +1 -1
- paddlex/inference/pipelines/anomaly_detection/pipeline.py +29 -9
- paddlex/inference/pipelines/attribute_recognition/__init__.py +1 -1
- paddlex/inference/pipelines/attribute_recognition/pipeline.py +24 -9
- paddlex/inference/pipelines/attribute_recognition/result.py +10 -8
- paddlex/inference/pipelines/base.py +43 -13
- paddlex/inference/pipelines/components/__init__.py +14 -8
- paddlex/inference/pipelines/components/chat_server/__init__.py +1 -1
- paddlex/inference/pipelines/components/chat_server/base.py +2 -2
- paddlex/inference/pipelines/components/chat_server/openai_bot_chat.py +8 -8
- paddlex/inference/pipelines/components/common/__init__.py +5 -4
- paddlex/inference/pipelines/components/common/base_operator.py +2 -1
- paddlex/inference/pipelines/components/common/base_result.py +3 -2
- paddlex/inference/pipelines/components/common/convert_points_and_boxes.py +1 -2
- paddlex/inference/pipelines/components/common/crop_image_regions.py +11 -5
- paddlex/inference/pipelines/components/common/seal_det_warp.py +44 -13
- paddlex/inference/pipelines/components/common/sort_boxes.py +4 -2
- paddlex/inference/pipelines/components/common/warp_image.py +50 -0
- paddlex/inference/pipelines/components/faisser.py +10 -5
- paddlex/inference/pipelines/components/prompt_engineering/__init__.py +2 -2
- paddlex/inference/pipelines/components/prompt_engineering/base.py +2 -2
- paddlex/inference/pipelines/components/prompt_engineering/generate_ensemble_prompt.py +2 -1
- paddlex/inference/pipelines/components/prompt_engineering/generate_kie_prompt.py +2 -2
- paddlex/inference/pipelines/components/retriever/__init__.py +2 -2
- paddlex/inference/pipelines/components/retriever/base.py +18 -16
- paddlex/inference/pipelines/components/retriever/openai_bot_retriever.py +2 -2
- paddlex/inference/pipelines/components/retriever/qianfan_bot_retriever.py +87 -84
- paddlex/inference/pipelines/components/utils/__init__.py +1 -1
- paddlex/inference/pipelines/components/utils/mixin.py +7 -7
- paddlex/inference/pipelines/doc_preprocessor/__init__.py +1 -1
- paddlex/inference/pipelines/doc_preprocessor/pipeline.py +70 -51
- paddlex/inference/pipelines/doc_preprocessor/result.py +5 -10
- paddlex/inference/pipelines/doc_understanding/__init__.py +15 -0
- paddlex/inference/pipelines/doc_understanding/pipeline.py +71 -0
- paddlex/inference/pipelines/face_recognition/__init__.py +1 -1
- paddlex/inference/pipelines/face_recognition/pipeline.py +3 -1
- paddlex/inference/pipelines/face_recognition/result.py +3 -2
- paddlex/inference/pipelines/formula_recognition/__init__.py +1 -1
- paddlex/inference/pipelines/formula_recognition/pipeline.py +137 -93
- paddlex/inference/pipelines/formula_recognition/result.py +20 -29
- paddlex/inference/pipelines/image_classification/__init__.py +1 -1
- paddlex/inference/pipelines/image_classification/pipeline.py +30 -11
- paddlex/inference/pipelines/image_multilabel_classification/__init__.py +1 -1
- paddlex/inference/pipelines/image_multilabel_classification/pipeline.py +31 -12
- paddlex/inference/pipelines/instance_segmentation/__init__.py +1 -1
- paddlex/inference/pipelines/instance_segmentation/pipeline.py +30 -9
- paddlex/inference/pipelines/keypoint_detection/__init__.py +1 -1
- paddlex/inference/pipelines/keypoint_detection/pipeline.py +30 -9
- paddlex/inference/pipelines/layout_parsing/__init__.py +1 -1
- paddlex/inference/pipelines/layout_parsing/pipeline.py +54 -56
- paddlex/inference/pipelines/layout_parsing/pipeline_v2.py +904 -261
- paddlex/inference/pipelines/layout_parsing/result.py +9 -21
- paddlex/inference/pipelines/layout_parsing/result_v2.py +525 -250
- paddlex/inference/pipelines/layout_parsing/setting.py +87 -0
- paddlex/inference/pipelines/layout_parsing/utils.py +570 -2004
- paddlex/inference/pipelines/layout_parsing/xycut_enhanced/__init__.py +16 -0
- paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py +1144 -0
- paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py +563 -0
- paddlex/inference/pipelines/{3d_bev_detection → m_3d_bev_detection}/__init__.py +1 -1
- paddlex/inference/pipelines/{3d_bev_detection → m_3d_bev_detection}/pipeline.py +17 -10
- paddlex/inference/pipelines/multilingual_speech_recognition/__init__.py +1 -1
- paddlex/inference/pipelines/multilingual_speech_recognition/pipeline.py +17 -6
- paddlex/inference/pipelines/object_detection/__init__.py +1 -1
- paddlex/inference/pipelines/object_detection/pipeline.py +29 -9
- paddlex/inference/pipelines/ocr/__init__.py +1 -1
- paddlex/inference/pipelines/ocr/pipeline.py +151 -77
- paddlex/inference/pipelines/ocr/result.py +31 -24
- paddlex/inference/pipelines/open_vocabulary_detection/__init__.py +1 -1
- paddlex/inference/pipelines/open_vocabulary_detection/pipeline.py +17 -6
- paddlex/inference/pipelines/open_vocabulary_segmentation/__init__.py +1 -1
- paddlex/inference/pipelines/open_vocabulary_segmentation/pipeline.py +17 -6
- paddlex/inference/pipelines/pp_chatocr/__init__.py +1 -1
- paddlex/inference/pipelines/pp_chatocr/pipeline_base.py +14 -5
- paddlex/inference/pipelines/pp_chatocr/pipeline_v3.py +22 -14
- paddlex/inference/pipelines/pp_chatocr/pipeline_v4.py +34 -16
- paddlex/inference/pipelines/pp_shitu_v2/__init__.py +1 -1
- paddlex/inference/pipelines/pp_shitu_v2/pipeline.py +12 -8
- paddlex/inference/pipelines/pp_shitu_v2/result.py +4 -4
- paddlex/inference/pipelines/rotated_object_detection/__init__.py +1 -1
- paddlex/inference/pipelines/rotated_object_detection/pipeline.py +30 -9
- paddlex/inference/pipelines/seal_recognition/__init__.py +1 -1
- paddlex/inference/pipelines/seal_recognition/pipeline.py +127 -63
- paddlex/inference/pipelines/seal_recognition/result.py +4 -2
- paddlex/inference/pipelines/semantic_segmentation/__init__.py +1 -1
- paddlex/inference/pipelines/semantic_segmentation/pipeline.py +30 -9
- paddlex/inference/pipelines/small_object_detection/__init__.py +1 -1
- paddlex/inference/pipelines/small_object_detection/pipeline.py +30 -9
- paddlex/inference/pipelines/table_recognition/__init__.py +1 -1
- paddlex/inference/pipelines/table_recognition/pipeline.py +61 -37
- paddlex/inference/pipelines/table_recognition/pipeline_v2.py +668 -65
- paddlex/inference/pipelines/table_recognition/result.py +12 -10
- paddlex/inference/pipelines/table_recognition/table_recognition_post_processing.py +12 -8
- paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py +55 -37
- paddlex/inference/pipelines/table_recognition/utils.py +1 -1
- paddlex/inference/pipelines/ts_anomaly_detection/__init__.py +1 -1
- paddlex/inference/pipelines/ts_anomaly_detection/pipeline.py +16 -6
- paddlex/inference/pipelines/ts_classification/__init__.py +1 -1
- paddlex/inference/pipelines/ts_classification/pipeline.py +16 -6
- paddlex/inference/pipelines/ts_forecasting/__init__.py +1 -1
- paddlex/inference/pipelines/ts_forecasting/pipeline.py +16 -6
- paddlex/inference/pipelines/video_classification/__init__.py +1 -1
- paddlex/inference/pipelines/video_classification/pipeline.py +17 -6
- paddlex/inference/pipelines/video_detection/__init__.py +1 -1
- paddlex/inference/pipelines/video_detection/pipeline.py +20 -7
- paddlex/inference/serving/__init__.py +5 -1
- paddlex/inference/serving/basic_serving/__init__.py +1 -1
- paddlex/inference/serving/basic_serving/_app.py +31 -19
- paddlex/inference/serving/basic_serving/_pipeline_apps/__init__.py +7 -4
- paddlex/inference/serving/basic_serving/_pipeline_apps/_common/__init__.py +1 -1
- paddlex/inference/serving/basic_serving/_pipeline_apps/_common/common.py +12 -4
- paddlex/inference/serving/basic_serving/_pipeline_apps/_common/image_recognition.py +1 -1
- paddlex/inference/serving/basic_serving/_pipeline_apps/_common/ocr.py +7 -2
- paddlex/inference/serving/basic_serving/_pipeline_apps/anomaly_detection.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/doc_preprocessor.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/doc_understanding.py +153 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/face_recognition.py +16 -13
- paddlex/inference/serving/basic_serving/_pipeline_apps/formula_recognition.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/human_keypoint_detection.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/image_classification.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/image_multilabel_classification.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/instance_segmentation.py +13 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/layout_parsing.py +10 -8
- paddlex/inference/serving/basic_serving/_pipeline_apps/m_3d_bev_detection.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/multilingual_speech_recognition.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/object_detection.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/ocr.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/open_vocabulary_detection.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/open_vocabulary_segmentation.py +13 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/pedestrian_attribute_recognition.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv3_doc.py +14 -12
- paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv4_doc.py +17 -14
- paddlex/inference/serving/basic_serving/_pipeline_apps/pp_shituv2.py +16 -13
- paddlex/inference/serving/basic_serving/_pipeline_apps/pp_structurev3.py +16 -9
- paddlex/inference/serving/basic_serving/_pipeline_apps/rotated_object_detection.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/seal_recognition.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/semantic_segmentation.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/small_object_detection.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition.py +11 -12
- paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition_v2.py +14 -12
- paddlex/inference/serving/basic_serving/_pipeline_apps/ts_anomaly_detection.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/ts_classification.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/ts_forecast.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/vehicle_attribute_recognition.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/video_classification.py +10 -7
- paddlex/inference/serving/basic_serving/_pipeline_apps/video_detection.py +10 -7
- paddlex/inference/serving/basic_serving/_server.py +9 -4
- paddlex/inference/serving/infra/__init__.py +1 -1
- paddlex/inference/serving/infra/config.py +1 -1
- paddlex/inference/serving/infra/models.py +13 -6
- paddlex/inference/serving/infra/storage.py +9 -4
- paddlex/inference/serving/infra/utils.py +54 -28
- paddlex/inference/serving/schemas/__init__.py +1 -1
- paddlex/inference/serving/schemas/anomaly_detection.py +1 -1
- paddlex/inference/serving/schemas/doc_preprocessor.py +1 -1
- paddlex/inference/serving/schemas/doc_understanding.py +78 -0
- paddlex/inference/serving/schemas/face_recognition.py +1 -1
- paddlex/inference/serving/schemas/formula_recognition.py +2 -2
- paddlex/inference/serving/schemas/human_keypoint_detection.py +1 -1
- paddlex/inference/serving/schemas/image_classification.py +1 -1
- paddlex/inference/serving/schemas/image_multilabel_classification.py +1 -1
- paddlex/inference/serving/schemas/instance_segmentation.py +1 -1
- paddlex/inference/serving/schemas/layout_parsing.py +2 -3
- paddlex/inference/serving/schemas/m_3d_bev_detection.py +1 -1
- paddlex/inference/serving/schemas/multilingual_speech_recognition.py +1 -1
- paddlex/inference/serving/schemas/object_detection.py +1 -1
- paddlex/inference/serving/schemas/ocr.py +1 -1
- paddlex/inference/serving/schemas/open_vocabulary_detection.py +1 -1
- paddlex/inference/serving/schemas/open_vocabulary_segmentation.py +1 -1
- paddlex/inference/serving/schemas/pedestrian_attribute_recognition.py +1 -1
- paddlex/inference/serving/schemas/pp_chatocrv3_doc.py +2 -3
- paddlex/inference/serving/schemas/pp_chatocrv4_doc.py +3 -3
- paddlex/inference/serving/schemas/pp_shituv2.py +1 -1
- paddlex/inference/serving/schemas/pp_structurev3.py +11 -7
- paddlex/inference/serving/schemas/rotated_object_detection.py +1 -1
- paddlex/inference/serving/schemas/seal_recognition.py +2 -2
- paddlex/inference/serving/schemas/semantic_segmentation.py +1 -1
- paddlex/inference/serving/schemas/shared/__init__.py +1 -1
- paddlex/inference/serving/schemas/shared/classification.py +1 -1
- paddlex/inference/serving/schemas/shared/image_segmentation.py +1 -1
- paddlex/inference/serving/schemas/shared/object_detection.py +1 -1
- paddlex/inference/serving/schemas/shared/ocr.py +1 -1
- paddlex/inference/serving/schemas/small_object_detection.py +1 -1
- paddlex/inference/serving/schemas/table_recognition.py +3 -7
- paddlex/inference/serving/schemas/table_recognition_v2.py +6 -7
- paddlex/inference/serving/schemas/ts_anomaly_detection.py +1 -1
- paddlex/inference/serving/schemas/ts_classification.py +1 -1
- paddlex/inference/serving/schemas/ts_forecast.py +1 -1
- paddlex/inference/serving/schemas/vehicle_attribute_recognition.py +1 -1
- paddlex/inference/serving/schemas/video_classification.py +1 -1
- paddlex/inference/serving/schemas/video_detection.py +1 -1
- paddlex/inference/utils/__init__.py +1 -1
- paddlex/inference/utils/benchmark.py +332 -179
- paddlex/inference/utils/color_map.py +1 -1
- paddlex/inference/utils/get_pipeline_path.py +1 -1
- paddlex/inference/utils/hpi.py +258 -0
- paddlex/inference/utils/hpi_model_info_collection.json +2331 -0
- paddlex/inference/utils/io/__init__.py +11 -11
- paddlex/inference/utils/io/readers.py +31 -27
- paddlex/inference/utils/io/style.py +21 -14
- paddlex/inference/utils/io/tablepyxl.py +13 -5
- paddlex/inference/utils/io/writers.py +9 -10
- paddlex/inference/utils/mkldnn_blocklist.py +25 -0
- paddlex/inference/utils/model_paths.py +48 -0
- paddlex/inference/utils/{new_ir_blacklist.py → new_ir_blocklist.py} +1 -2
- paddlex/inference/utils/official_models.py +278 -262
- paddlex/inference/utils/pp_option.py +184 -92
- paddlex/inference/utils/trt_blocklist.py +43 -0
- paddlex/inference/utils/trt_config.py +420 -0
- paddlex/model.py +30 -12
- paddlex/modules/__init__.py +57 -80
- paddlex/modules/anomaly_detection/__init__.py +2 -2
- paddlex/modules/anomaly_detection/dataset_checker/__init__.py +2 -3
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/analyse_dataset.py +6 -3
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/check_dataset.py +8 -4
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/convert_dataset.py +7 -4
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/split_dataset.py +2 -2
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/utils/visualizer.py +7 -2
- paddlex/modules/anomaly_detection/evaluator.py +3 -3
- paddlex/modules/anomaly_detection/exportor.py +1 -1
- paddlex/modules/anomaly_detection/model_list.py +1 -1
- paddlex/modules/anomaly_detection/trainer.py +3 -4
- paddlex/modules/base/__init__.py +5 -5
- paddlex/modules/base/build_model.py +1 -2
- paddlex/modules/base/dataset_checker/__init__.py +2 -2
- paddlex/modules/base/dataset_checker/dataset_checker.py +4 -4
- paddlex/modules/base/dataset_checker/utils.py +1 -3
- paddlex/modules/base/evaluator.py +13 -13
- paddlex/modules/base/exportor.py +12 -13
- paddlex/modules/base/trainer.py +21 -11
- paddlex/modules/base/utils/__init__.py +13 -0
- paddlex/modules/base/utils/cinn_setting.py +89 -0
- paddlex/modules/base/utils/coco_eval.py +94 -0
- paddlex/modules/base/utils/topk_eval.py +118 -0
- paddlex/modules/doc_vlm/__init__.py +18 -0
- paddlex/modules/doc_vlm/dataset_checker.py +29 -0
- paddlex/modules/doc_vlm/evaluator.py +29 -0
- paddlex/modules/doc_vlm/exportor.py +29 -0
- paddlex/modules/doc_vlm/model_list.py +16 -0
- paddlex/modules/doc_vlm/trainer.py +41 -0
- paddlex/modules/face_recognition/__init__.py +2 -2
- paddlex/modules/face_recognition/dataset_checker/__init__.py +2 -2
- paddlex/modules/face_recognition/dataset_checker/dataset_src/__init__.py +1 -1
- paddlex/modules/face_recognition/dataset_checker/dataset_src/check_dataset.py +3 -5
- paddlex/modules/face_recognition/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/face_recognition/dataset_checker/dataset_src/utils/visualizer.py +2 -5
- paddlex/modules/face_recognition/evaluator.py +3 -3
- paddlex/modules/face_recognition/exportor.py +1 -1
- paddlex/modules/face_recognition/model_list.py +1 -1
- paddlex/modules/face_recognition/trainer.py +1 -1
- paddlex/modules/formula_recognition/__init__.py +2 -2
- paddlex/modules/formula_recognition/dataset_checker/__init__.py +3 -3
- paddlex/modules/formula_recognition/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/formula_recognition/dataset_checker/dataset_src/analyse_dataset.py +13 -12
- paddlex/modules/formula_recognition/dataset_checker/dataset_src/check_dataset.py +2 -6
- paddlex/modules/formula_recognition/dataset_checker/dataset_src/convert_dataset.py +11 -10
- paddlex/modules/formula_recognition/dataset_checker/dataset_src/split_dataset.py +1 -2
- paddlex/modules/formula_recognition/evaluator.py +6 -3
- paddlex/modules/formula_recognition/exportor.py +1 -1
- paddlex/modules/formula_recognition/model_list.py +4 -1
- paddlex/modules/formula_recognition/trainer.py +5 -3
- paddlex/modules/general_recognition/__init__.py +2 -2
- paddlex/modules/general_recognition/dataset_checker/__init__.py +2 -2
- paddlex/modules/general_recognition/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/general_recognition/dataset_checker/dataset_src/analyse_dataset.py +7 -9
- paddlex/modules/general_recognition/dataset_checker/dataset_src/check_dataset.py +4 -5
- paddlex/modules/general_recognition/dataset_checker/dataset_src/convert_dataset.py +6 -5
- paddlex/modules/general_recognition/dataset_checker/dataset_src/split_dataset.py +1 -1
- paddlex/modules/general_recognition/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/general_recognition/dataset_checker/dataset_src/utils/visualizer.py +2 -5
- paddlex/modules/general_recognition/evaluator.py +2 -2
- paddlex/modules/general_recognition/exportor.py +1 -1
- paddlex/modules/general_recognition/model_list.py +1 -1
- paddlex/modules/general_recognition/trainer.py +1 -1
- paddlex/modules/image_classification/__init__.py +2 -2
- paddlex/modules/image_classification/dataset_checker/__init__.py +2 -2
- paddlex/modules/image_classification/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/image_classification/dataset_checker/dataset_src/analyse_dataset.py +8 -9
- paddlex/modules/image_classification/dataset_checker/dataset_src/check_dataset.py +4 -3
- paddlex/modules/image_classification/dataset_checker/dataset_src/convert_dataset.py +4 -4
- paddlex/modules/image_classification/dataset_checker/dataset_src/split_dataset.py +1 -1
- paddlex/modules/image_classification/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/image_classification/dataset_checker/dataset_src/utils/visualizer.py +2 -5
- paddlex/modules/image_classification/evaluator.py +3 -3
- paddlex/modules/image_classification/exportor.py +1 -1
- paddlex/modules/image_classification/model_list.py +2 -1
- paddlex/modules/image_classification/trainer.py +3 -3
- paddlex/modules/image_unwarping/__init__.py +1 -1
- paddlex/modules/image_unwarping/model_list.py +1 -1
- paddlex/modules/instance_segmentation/__init__.py +2 -2
- paddlex/modules/instance_segmentation/dataset_checker/__init__.py +2 -3
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/analyse_dataset.py +9 -5
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/check_dataset.py +8 -5
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/convert_dataset.py +8 -8
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/split_dataset.py +7 -4
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/utils/visualizer.py +10 -8
- paddlex/modules/instance_segmentation/evaluator.py +2 -2
- paddlex/modules/instance_segmentation/exportor.py +1 -1
- paddlex/modules/instance_segmentation/model_list.py +1 -1
- paddlex/modules/instance_segmentation/trainer.py +1 -1
- paddlex/modules/keypoint_detection/__init__.py +2 -2
- paddlex/modules/keypoint_detection/dataset_checker/__init__.py +2 -2
- paddlex/modules/keypoint_detection/dataset_checker/dataset_src/__init__.py +1 -1
- paddlex/modules/keypoint_detection/dataset_checker/dataset_src/check_dataset.py +10 -5
- paddlex/modules/keypoint_detection/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/keypoint_detection/dataset_checker/dataset_src/utils/visualizer.py +8 -3
- paddlex/modules/keypoint_detection/evaluator.py +2 -2
- paddlex/modules/keypoint_detection/exportor.py +1 -1
- paddlex/modules/keypoint_detection/model_list.py +1 -1
- paddlex/modules/keypoint_detection/trainer.py +2 -2
- paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/__init__.py +2 -2
- paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/dataset_checker/__init__.py +3 -3
- paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/dataset_checker/dataset_src/analyse_dataset.py +8 -8
- paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/dataset_checker/dataset_src/check_dataset.py +1 -2
- paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/evaluator.py +3 -3
- paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/exportor.py +1 -1
- paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/model_list.py +1 -1
- paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/trainer.py +5 -7
- paddlex/modules/multilabel_classification/__init__.py +2 -2
- paddlex/modules/multilabel_classification/dataset_checker/__init__.py +2 -2
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/analyse_dataset.py +8 -9
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/check_dataset.py +4 -3
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/convert_dataset.py +10 -7
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/split_dataset.py +1 -1
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/utils/visualizer.py +1 -5
- paddlex/modules/multilabel_classification/evaluator.py +3 -3
- paddlex/modules/multilabel_classification/exportor.py +1 -1
- paddlex/modules/multilabel_classification/model_list.py +1 -1
- paddlex/modules/multilabel_classification/trainer.py +3 -3
- paddlex/modules/multilingual_speech_recognition/__init__.py +2 -2
- paddlex/modules/multilingual_speech_recognition/dataset_checker.py +3 -3
- paddlex/modules/multilingual_speech_recognition/evaluator.py +3 -3
- paddlex/modules/multilingual_speech_recognition/exportor.py +3 -3
- paddlex/modules/multilingual_speech_recognition/model_list.py +1 -1
- paddlex/modules/multilingual_speech_recognition/trainer.py +7 -5
- paddlex/modules/object_detection/__init__.py +2 -2
- paddlex/modules/object_detection/dataset_checker/__init__.py +2 -11
- paddlex/modules/object_detection/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/object_detection/dataset_checker/dataset_src/analyse_dataset.py +10 -8
- paddlex/modules/object_detection/dataset_checker/dataset_src/check_dataset.py +10 -5
- paddlex/modules/object_detection/dataset_checker/dataset_src/convert_dataset.py +17 -12
- paddlex/modules/object_detection/dataset_checker/dataset_src/split_dataset.py +8 -4
- paddlex/modules/object_detection/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/object_detection/dataset_checker/dataset_src/utils/visualizer.py +9 -8
- paddlex/modules/object_detection/evaluator.py +11 -6
- paddlex/modules/object_detection/exportor.py +1 -1
- paddlex/modules/object_detection/model_list.py +3 -1
- paddlex/modules/object_detection/trainer.py +4 -5
- paddlex/modules/open_vocabulary_detection/__init__.py +2 -2
- paddlex/modules/open_vocabulary_detection/dataset_checker.py +3 -3
- paddlex/modules/open_vocabulary_detection/evaluator.py +3 -3
- paddlex/modules/open_vocabulary_detection/exportor.py +3 -3
- paddlex/modules/open_vocabulary_detection/model_list.py +2 -4
- paddlex/modules/open_vocabulary_detection/trainer.py +7 -5
- paddlex/modules/open_vocabulary_segmentation/__init__.py +2 -2
- paddlex/modules/open_vocabulary_segmentation/dataset_checker.py +3 -3
- paddlex/modules/open_vocabulary_segmentation/evaluator.py +3 -3
- paddlex/modules/open_vocabulary_segmentation/exportor.py +3 -3
- paddlex/modules/open_vocabulary_segmentation/model_list.py +1 -1
- paddlex/modules/open_vocabulary_segmentation/trainer.py +7 -5
- paddlex/modules/semantic_segmentation/__init__.py +2 -2
- paddlex/modules/semantic_segmentation/dataset_checker/__init__.py +2 -3
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/analyse_dataset.py +6 -3
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/check_dataset.py +2 -2
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/convert_dataset.py +7 -4
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/split_dataset.py +2 -2
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/utils/__init__.py +1 -1
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/utils/visualizer.py +6 -2
- paddlex/modules/semantic_segmentation/evaluator.py +3 -3
- paddlex/modules/semantic_segmentation/exportor.py +1 -1
- paddlex/modules/semantic_segmentation/model_list.py +1 -1
- paddlex/modules/semantic_segmentation/trainer.py +3 -4
- paddlex/modules/table_recognition/__init__.py +2 -2
- paddlex/modules/table_recognition/dataset_checker/__init__.py +5 -5
- paddlex/modules/table_recognition/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/table_recognition/dataset_checker/dataset_src/analyse_dataset.py +3 -2
- paddlex/modules/table_recognition/dataset_checker/dataset_src/check_dataset.py +8 -7
- paddlex/modules/table_recognition/dataset_checker/dataset_src/split_dataset.py +2 -1
- paddlex/modules/table_recognition/evaluator.py +3 -3
- paddlex/modules/table_recognition/exportor.py +1 -1
- paddlex/modules/table_recognition/model_list.py +1 -1
- paddlex/modules/table_recognition/trainer.py +2 -5
- paddlex/modules/text_detection/__init__.py +2 -2
- paddlex/modules/text_detection/dataset_checker/__init__.py +4 -6
- paddlex/modules/text_detection/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/text_detection/dataset_checker/dataset_src/analyse_dataset.py +12 -9
- paddlex/modules/text_detection/dataset_checker/dataset_src/check_dataset.py +3 -3
- paddlex/modules/text_detection/dataset_checker/dataset_src/split_dataset.py +3 -3
- paddlex/modules/text_detection/evaluator.py +3 -3
- paddlex/modules/text_detection/exportor.py +1 -1
- paddlex/modules/text_detection/model_list.py +3 -1
- paddlex/modules/text_detection/trainer.py +2 -5
- paddlex/modules/text_recognition/__init__.py +2 -2
- paddlex/modules/text_recognition/dataset_checker/__init__.py +4 -5
- paddlex/modules/text_recognition/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/text_recognition/dataset_checker/dataset_src/analyse_dataset.py +13 -12
- paddlex/modules/text_recognition/dataset_checker/dataset_src/check_dataset.py +2 -5
- paddlex/modules/text_recognition/dataset_checker/dataset_src/convert_dataset.py +11 -10
- paddlex/modules/text_recognition/dataset_checker/dataset_src/split_dataset.py +1 -2
- paddlex/modules/text_recognition/evaluator.py +3 -3
- paddlex/modules/text_recognition/exportor.py +1 -1
- paddlex/modules/text_recognition/model_list.py +3 -1
- paddlex/modules/text_recognition/trainer.py +2 -3
- paddlex/modules/ts_anomaly_detection/__init__.py +2 -2
- paddlex/modules/ts_anomaly_detection/dataset_checker/__init__.py +4 -5
- paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/analyse_dataset.py +1 -9
- paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/check_dataset.py +2 -2
- paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/convert_dataset.py +2 -6
- paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/split_dataset.py +4 -4
- paddlex/modules/ts_anomaly_detection/evaluator.py +3 -3
- paddlex/modules/ts_anomaly_detection/exportor.py +2 -3
- paddlex/modules/ts_anomaly_detection/model_list.py +1 -1
- paddlex/modules/ts_anomaly_detection/trainer.py +8 -8
- paddlex/modules/ts_classification/__init__.py +2 -2
- paddlex/modules/ts_classification/dataset_checker/__init__.py +4 -5
- paddlex/modules/ts_classification/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/ts_classification/dataset_checker/dataset_src/analyse_dataset.py +8 -5
- paddlex/modules/ts_classification/dataset_checker/dataset_src/check_dataset.py +2 -2
- paddlex/modules/ts_classification/dataset_checker/dataset_src/convert_dataset.py +2 -6
- paddlex/modules/ts_classification/dataset_checker/dataset_src/split_dataset.py +5 -5
- paddlex/modules/ts_classification/evaluator.py +3 -3
- paddlex/modules/ts_classification/exportor.py +2 -3
- paddlex/modules/ts_classification/model_list.py +1 -1
- paddlex/modules/ts_classification/trainer.py +7 -7
- paddlex/modules/ts_forecast/__init__.py +2 -2
- paddlex/modules/ts_forecast/dataset_checker/__init__.py +4 -5
- paddlex/modules/ts_forecast/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/ts_forecast/dataset_checker/dataset_src/analyse_dataset.py +1 -9
- paddlex/modules/ts_forecast/dataset_checker/dataset_src/check_dataset.py +2 -2
- paddlex/modules/ts_forecast/dataset_checker/dataset_src/convert_dataset.py +2 -6
- paddlex/modules/ts_forecast/dataset_checker/dataset_src/split_dataset.py +4 -4
- paddlex/modules/ts_forecast/evaluator.py +3 -3
- paddlex/modules/ts_forecast/exportor.py +2 -3
- paddlex/modules/ts_forecast/model_list.py +1 -1
- paddlex/modules/ts_forecast/trainer.py +7 -7
- paddlex/modules/video_classification/__init__.py +2 -2
- paddlex/modules/video_classification/dataset_checker/__init__.py +2 -2
- paddlex/modules/video_classification/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/video_classification/dataset_checker/dataset_src/analyse_dataset.py +9 -9
- paddlex/modules/video_classification/dataset_checker/dataset_src/check_dataset.py +2 -3
- paddlex/modules/video_classification/dataset_checker/dataset_src/split_dataset.py +1 -1
- paddlex/modules/video_classification/evaluator.py +3 -3
- paddlex/modules/video_classification/exportor.py +1 -1
- paddlex/modules/video_classification/model_list.py +1 -1
- paddlex/modules/video_classification/trainer.py +3 -3
- paddlex/modules/video_detection/__init__.py +2 -2
- paddlex/modules/video_detection/dataset_checker/__init__.py +2 -2
- paddlex/modules/video_detection/dataset_checker/dataset_src/__init__.py +2 -2
- paddlex/modules/video_detection/dataset_checker/dataset_src/analyse_dataset.py +8 -9
- paddlex/modules/video_detection/dataset_checker/dataset_src/check_dataset.py +3 -5
- paddlex/modules/video_detection/evaluator.py +3 -3
- paddlex/modules/video_detection/exportor.py +1 -1
- paddlex/modules/video_detection/model_list.py +1 -1
- paddlex/modules/video_detection/trainer.py +3 -3
- paddlex/ops/__init__.py +7 -4
- paddlex/ops/iou3d_nms/iou3d_cpu.cpp +8 -6
- paddlex/ops/iou3d_nms/iou3d_cpu.h +3 -2
- paddlex/ops/iou3d_nms/iou3d_nms.cpp +8 -6
- paddlex/ops/iou3d_nms/iou3d_nms.h +6 -4
- paddlex/ops/iou3d_nms/iou3d_nms_api.cpp +24 -18
- paddlex/ops/iou3d_nms/iou3d_nms_kernel.cu +9 -7
- paddlex/ops/setup.py +3 -3
- paddlex/ops/voxel/voxelize_op.cc +22 -19
- paddlex/ops/voxel/voxelize_op.cu +25 -25
- paddlex/paddlex_cli.py +104 -87
- paddlex/repo_apis/Paddle3D_api/__init__.py +1 -1
- paddlex/repo_apis/Paddle3D_api/bev_fusion/__init__.py +1 -1
- paddlex/repo_apis/Paddle3D_api/bev_fusion/config.py +1 -1
- paddlex/repo_apis/Paddle3D_api/bev_fusion/model.py +6 -6
- paddlex/repo_apis/Paddle3D_api/bev_fusion/register.py +2 -2
- paddlex/repo_apis/Paddle3D_api/bev_fusion/runner.py +1 -1
- paddlex/repo_apis/Paddle3D_api/pp3d_config.py +3 -2
- paddlex/repo_apis/PaddleClas_api/__init__.py +1 -1
- paddlex/repo_apis/PaddleClas_api/cls/__init__.py +3 -3
- paddlex/repo_apis/PaddleClas_api/cls/config.py +5 -4
- paddlex/repo_apis/PaddleClas_api/cls/model.py +4 -4
- paddlex/repo_apis/PaddleClas_api/cls/register.py +12 -3
- paddlex/repo_apis/PaddleClas_api/cls/runner.py +2 -3
- paddlex/repo_apis/PaddleClas_api/shitu_rec/__init__.py +2 -2
- paddlex/repo_apis/PaddleClas_api/shitu_rec/config.py +2 -2
- paddlex/repo_apis/PaddleClas_api/shitu_rec/model.py +1 -4
- paddlex/repo_apis/PaddleClas_api/shitu_rec/register.py +2 -2
- paddlex/repo_apis/PaddleClas_api/shitu_rec/runner.py +1 -6
- paddlex/repo_apis/PaddleDetection_api/__init__.py +2 -2
- paddlex/repo_apis/PaddleDetection_api/config_helper.py +3 -3
- paddlex/repo_apis/PaddleDetection_api/instance_seg/__init__.py +2 -2
- paddlex/repo_apis/PaddleDetection_api/instance_seg/config.py +2 -3
- paddlex/repo_apis/PaddleDetection_api/instance_seg/model.py +4 -4
- paddlex/repo_apis/PaddleDetection_api/instance_seg/register.py +2 -3
- paddlex/repo_apis/PaddleDetection_api/instance_seg/runner.py +2 -3
- paddlex/repo_apis/PaddleDetection_api/object_det/__init__.py +3 -3
- paddlex/repo_apis/PaddleDetection_api/object_det/config.py +5 -4
- paddlex/repo_apis/PaddleDetection_api/object_det/model.py +6 -7
- paddlex/repo_apis/PaddleDetection_api/object_det/official_categories.py +26 -1
- paddlex/repo_apis/PaddleDetection_api/object_det/register.py +32 -3
- paddlex/repo_apis/PaddleDetection_api/object_det/runner.py +2 -3
- paddlex/repo_apis/PaddleNLP_api/__init__.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/__init__.py +4 -3
- paddlex/repo_apis/PaddleOCR_api/config_utils.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/formula_rec/__init__.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/formula_rec/config.py +7 -6
- paddlex/repo_apis/PaddleOCR_api/formula_rec/model.py +9 -13
- paddlex/repo_apis/PaddleOCR_api/formula_rec/register.py +29 -3
- paddlex/repo_apis/PaddleOCR_api/formula_rec/runner.py +2 -3
- paddlex/repo_apis/PaddleOCR_api/table_rec/__init__.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/table_rec/config.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/table_rec/model.py +4 -4
- paddlex/repo_apis/PaddleOCR_api/table_rec/register.py +2 -3
- paddlex/repo_apis/PaddleOCR_api/table_rec/runner.py +3 -3
- paddlex/repo_apis/PaddleOCR_api/text_det/__init__.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/text_det/config.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/text_det/model.py +4 -4
- paddlex/repo_apis/PaddleOCR_api/text_det/register.py +20 -3
- paddlex/repo_apis/PaddleOCR_api/text_det/runner.py +3 -3
- paddlex/repo_apis/PaddleOCR_api/text_rec/__init__.py +1 -1
- paddlex/repo_apis/PaddleOCR_api/text_rec/config.py +7 -6
- paddlex/repo_apis/PaddleOCR_api/text_rec/model.py +9 -13
- paddlex/repo_apis/PaddleOCR_api/text_rec/register.py +20 -3
- paddlex/repo_apis/PaddleOCR_api/text_rec/runner.py +2 -3
- paddlex/repo_apis/PaddleSeg_api/__init__.py +1 -1
- paddlex/repo_apis/PaddleSeg_api/base_seg_config.py +2 -2
- paddlex/repo_apis/PaddleSeg_api/seg/__init__.py +1 -1
- paddlex/repo_apis/PaddleSeg_api/seg/config.py +3 -6
- paddlex/repo_apis/PaddleSeg_api/seg/model.py +6 -6
- paddlex/repo_apis/PaddleSeg_api/seg/register.py +2 -3
- paddlex/repo_apis/PaddleSeg_api/seg/runner.py +2 -3
- paddlex/repo_apis/PaddleTS_api/__init__.py +4 -3
- paddlex/repo_apis/PaddleTS_api/ts_ad/__init__.py +1 -1
- paddlex/repo_apis/PaddleTS_api/ts_ad/config.py +5 -6
- paddlex/repo_apis/PaddleTS_api/ts_ad/register.py +2 -2
- paddlex/repo_apis/PaddleTS_api/ts_ad/runner.py +2 -2
- paddlex/repo_apis/PaddleTS_api/ts_base/__init__.py +1 -1
- paddlex/repo_apis/PaddleTS_api/ts_base/config.py +2 -4
- paddlex/repo_apis/PaddleTS_api/ts_base/model.py +4 -4
- paddlex/repo_apis/PaddleTS_api/ts_base/runner.py +2 -2
- paddlex/repo_apis/PaddleTS_api/ts_cls/__init__.py +1 -1
- paddlex/repo_apis/PaddleTS_api/ts_cls/config.py +4 -5
- paddlex/repo_apis/PaddleTS_api/ts_cls/register.py +2 -2
- paddlex/repo_apis/PaddleTS_api/ts_cls/runner.py +2 -2
- paddlex/repo_apis/PaddleTS_api/ts_fc/__init__.py +1 -1
- paddlex/repo_apis/PaddleTS_api/ts_fc/config.py +6 -7
- paddlex/repo_apis/PaddleTS_api/ts_fc/register.py +1 -1
- paddlex/repo_apis/PaddleVideo_api/__init__.py +1 -1
- paddlex/repo_apis/PaddleVideo_api/config_utils.py +1 -1
- paddlex/repo_apis/PaddleVideo_api/video_cls/__init__.py +3 -3
- paddlex/repo_apis/PaddleVideo_api/video_cls/config.py +5 -4
- paddlex/repo_apis/PaddleVideo_api/video_cls/model.py +4 -4
- paddlex/repo_apis/PaddleVideo_api/video_cls/register.py +2 -3
- paddlex/repo_apis/PaddleVideo_api/video_cls/runner.py +2 -3
- paddlex/repo_apis/PaddleVideo_api/video_det/__init__.py +3 -3
- paddlex/repo_apis/PaddleVideo_api/video_det/config.py +5 -4
- paddlex/repo_apis/PaddleVideo_api/video_det/model.py +5 -5
- paddlex/repo_apis/PaddleVideo_api/video_det/register.py +2 -3
- paddlex/repo_apis/PaddleVideo_api/video_det/runner.py +2 -3
- paddlex/repo_apis/__init__.py +1 -1
- paddlex/repo_apis/base/__init__.py +4 -5
- paddlex/repo_apis/base/config.py +3 -4
- paddlex/repo_apis/base/model.py +11 -19
- paddlex/repo_apis/base/register.py +1 -1
- paddlex/repo_apis/base/runner.py +11 -12
- paddlex/repo_apis/base/utils/__init__.py +1 -1
- paddlex/repo_apis/base/utils/arg.py +1 -1
- paddlex/repo_apis/base/utils/subprocess.py +1 -1
- paddlex/repo_manager/__init__.py +2 -9
- paddlex/repo_manager/core.py +12 -30
- paddlex/repo_manager/meta.py +41 -31
- paddlex/repo_manager/repo.py +171 -161
- paddlex/repo_manager/utils.py +13 -224
- paddlex/utils/__init__.py +1 -1
- paddlex/utils/cache.py +8 -10
- paddlex/utils/config.py +6 -5
- paddlex/utils/{custom_device_whitelist.py → custom_device_list.py} +53 -199
- paddlex/utils/deps.py +249 -0
- paddlex/utils/device.py +87 -36
- paddlex/utils/download.py +4 -4
- paddlex/utils/env.py +37 -7
- paddlex/utils/errors/__init__.py +1 -1
- paddlex/utils/errors/dataset_checker.py +1 -1
- paddlex/utils/errors/others.py +2 -16
- paddlex/utils/file_interface.py +4 -5
- paddlex/utils/flags.py +17 -12
- paddlex/utils/fonts/__init__.py +36 -5
- paddlex/utils/func_register.py +1 -1
- paddlex/utils/install.py +87 -0
- paddlex/utils/interactive_get_pipeline.py +3 -3
- paddlex/utils/lazy_loader.py +3 -3
- paddlex/utils/logging.py +10 -1
- paddlex/utils/misc.py +6 -6
- paddlex/utils/pipeline_arguments.py +15 -7
- paddlex/utils/result_saver.py +4 -5
- paddlex/utils/subclass_register.py +2 -4
- paddlex/version.py +2 -1
- {paddlex-3.0.0rc0.dist-info → paddlex-3.0.1.dist-info}/METADATA +237 -102
- paddlex-3.0.1.dist-info/RECORD +1095 -0
- {paddlex-3.0.0rc0.dist-info → paddlex-3.0.1.dist-info}/WHEEL +1 -1
- paddlex/inference/models/base/predictor/basic_predictor.py +0 -139
- paddlex/paddle2onnx_requirements.txt +0 -1
- paddlex/repo_manager/requirements.txt +0 -21
- paddlex/serving_requirements.txt +0 -9
- paddlex-3.0.0rc0.dist-info/RECORD +0 -1015
- {paddlex-3.0.0rc0.dist-info → paddlex-3.0.1.dist-info}/entry_points.txt +0 -0
- {paddlex-3.0.0rc0.dist-info → paddlex-3.0.1.dist-info/licenses}/LICENSE +0 -0
- {paddlex-3.0.0rc0.dist-info → paddlex-3.0.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2162 @@
|
|
1
|
+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import copy
|
16
|
+
import inspect
|
17
|
+
from typing import Optional, Union
|
18
|
+
|
19
|
+
import paddle
|
20
|
+
import paddle.distributed as dist
|
21
|
+
import paddle.nn as nn
|
22
|
+
import paddle.nn.functional as F
|
23
|
+
from paddle import Tensor
|
24
|
+
from paddle.common_ops_import import convert_dtype
|
25
|
+
from paddle.utils import map_structure
|
26
|
+
|
27
|
+
from ......utils import logging
|
28
|
+
from ..transformers.model_outputs import ModelOutput
|
29
|
+
from .configuration_utils import DEFAULT_MAX_NEW_TOKENS, GenerationConfig
|
30
|
+
from .logits_process import (
|
31
|
+
ForcedBOSTokenLogitsProcessor,
|
32
|
+
ForcedEOSTokenLogitsProcessor,
|
33
|
+
HammingDiversityLogitsProcessor,
|
34
|
+
LogitsProcessor,
|
35
|
+
LogitsProcessorList,
|
36
|
+
MinLengthLogitsProcessor,
|
37
|
+
NoRepeatNGramLogitsProcessor,
|
38
|
+
RepetitionPenaltyLogitsProcessor,
|
39
|
+
TopKProcess,
|
40
|
+
TopPProcess,
|
41
|
+
)
|
42
|
+
from .stopping_criteria import (
|
43
|
+
StoppingCriteria,
|
44
|
+
StoppingCriteriaList,
|
45
|
+
validate_stopping_criteria,
|
46
|
+
)
|
47
|
+
|
48
|
+
# Public API of this module for ``from ... import *``. Several entries (the
# logits processors) are re-exported from ``.logits_process`` rather than
# defined here.
__all__ = [
    "GenerationMixin",
    "BeamSearchScorer",
    "BeamHypotheses",
    "LogitsProcessorList",
    "LogitsProcessor",
    "MinLengthLogitsProcessor",
    "RepetitionPenaltyLogitsProcessor",
    "TopKProcess",
    "TopPProcess",
    "get_unfinished_flag",
]
|
60
|
+
|
61
|
+
|
62
|
+
def get_scale_by_dtype(
    dtype: Optional[str] = None, return_positive: bool = True
) -> float:
    """Return a large finite magnitude appropriate for ``dtype``.

    The result is ``1e4`` for float16 and ``1e6`` for every other dtype,
    optionally negated.

    Args:
        dtype (Optional[str]): the dtype to look up, in any form accepted by
            ``convert_dtype``. Defaults to ``paddle.get_default_dtype()``.
        return_positive (bool): if True return the positive value, otherwise
            its negation. Defaults to True.

    Returns:
        float: ``1e4`` / ``1e6``, or their negatives.
    """
    if dtype is None:
        dtype = paddle.get_default_dtype()

    # 1e6 is not representable in float16 (largest finite value is 65504),
    # so float16 gets a smaller scale.
    # TODO(wj-Mcaf): support int8, int4 dtypes later
    scale_value = 1e4 if convert_dtype(dtype) == "float16" else 1e6

    return scale_value if return_positive else -scale_value
|
84
|
+
|
85
|
+
|
86
|
+
def get_unfinished_flag(
    input_ids: Tensor,
    unfinished_flag: Tensor,
    eos_token_id: Union[int, list[int], list[list[int]]],
) -> Tensor:
    """Update the "still generating" flag after a generation step.

    A row becomes finished when its last token (``input_ids[:, -1:]``) equals
    any of the end-of-sentence ids; rows whose flag is already False stay
    False.

    Args:
        input_ids (Tensor): the input_ids generated so far; only the last
            column is inspected.
        unfinished_flag (Tensor): the current unfinished flag, combined with
            the new check through ``paddle.logical_and``.
        eos_token_id (Union[int, list[int], list[list[int]]]): the end of
            sentence id(s), which can be:
            * a single token id, eg: 10
            * a list of alternative token ids, eg: [10, 20]
            * a nested list of token ids, eg: [[10], [20, 20], [30, 30, 30]]
            Each id is compared with the last token only, so a nested list
            behaves like the flat list of all its ids (multi-token stop
            sequences are not matched as sequences). An empty list leaves
            ``unfinished_flag`` unchanged.

    Returns:
        Tensor: the updated unfinished flag tensor
    """
    if isinstance(eos_token_id, int):
        return paddle.logical_and(
            unfinished_flag, input_ids[:, -1:] != eos_token_id
        )

    # "unfinished AND last token differs from every eos id" equals the negation
    # of "finished OR last token equals some eos id" (De Morgan), so the
    # per-id checks can simply be chained. Chaining also means an empty list
    # of ids leaves the flag untouched.
    for sub_eos_token_id in eos_token_id:
        unfinished_flag = get_unfinished_flag(
            input_ids, unfinished_flag, sub_eos_token_id
        )
    return unfinished_flag
|
124
|
+
|
125
|
+
|
126
|
+
class BeamHypotheses:
    """N-best list of finished hypotheses for a single batch entry.

    Hypotheses are ranked by a length-normalized score:
    ``sum_logprobs / ((length - origin_len + 5) / 6) ** length_penalty``.
    """

    def __init__(self, num_beams, length_penalty, early_stopping):
        """Create an empty n-best list.

        Args:
            num_beams: maximum number of hypotheses retained.
            length_penalty: exponent applied to the length normalizer.
            early_stopping: when truthy, ``is_done`` reports completion as
                soon as ``num_beams`` hypotheses are stored.
        """
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        # (score, hypothesis) pairs in insertion order.
        self.beams = []
        # Large positive start value so the first ``min`` always replaces it.
        self.worst_score = get_scale_by_dtype()

    def __len__(self):
        """Return how many hypotheses are currently stored."""
        return len(self.beams)

    def _normalized_score(self, logprobs, length, origin_len):
        """Divide ``logprobs`` by the length normalizer; the first
        ``origin_len`` tokens are excluded from ``length``."""
        normalizer = ((length - origin_len + 5) / 6) ** self.length_penalty
        return logprobs / normalizer

    def add(self, hyp, sum_logprobs, origin_len=0):
        """Offer a hypothesis; keep it only if it ranks among the best.

        On overflow the lowest-scoring entry (the earliest one on ties) is
        evicted and ``worst_score`` is refreshed.
        """
        candidate = self._normalized_score(sum_logprobs, hyp.shape[-1], origin_len)
        if not (len(self) < self.num_beams or candidate > self.worst_score):
            return

        self.beams.append((candidate, hyp))
        if len(self) <= self.num_beams:
            self.worst_score = min(candidate, self.worst_score)
            return

        ranked = sorted(
            (beam_score, pos) for pos, (beam_score, _) in enumerate(self.beams)
        )
        del self.beams[ranked[0][1]]
        self.worst_score = ranked[1][0]

    def is_done(self, best_sum_logprobs, cur_len, origin_len=0):
        """Report whether no hypothesis still being generated can beat the
        worst stored one (or, with early stopping, whether the list is full).
        """
        if len(self) < self.num_beams:
            return False
        if self.early_stopping:
            return True
        best_reachable = self._normalized_score(best_sum_logprobs, cur_len, origin_len)
        return self.worst_score >= best_reachable
|
178
|
+
|
179
|
+
|
180
|
+
class BeamSearchScorer(object):
|
181
|
+
"""
|
182
|
+
implementing standard beam search decoding.
|
183
|
+
"""
|
184
|
+
|
185
|
+
def __init__(
    self,
    batch_size,
    max_length,
    num_beams,
    length_penalty=1.0,
    do_early_stopping=False,
    num_beam_hyps_to_keep=1,
    num_beam_groups=1,
):
    """Set up one hypothesis list per batch entry for (group) beam search.

    Args:
        batch_size (int): number of inputs; one ``BeamHypotheses`` is created
            for each.
        max_length (int): stored as ``self.max_length``.
        num_beams (int): beams per input; must be an integer strictly
            greater than 1.
        length_penalty (float): forwarded to every ``BeamHypotheses``.
        do_early_stopping (bool): forwarded to every ``BeamHypotheses`` as
            ``early_stopping``.
        num_beam_hyps_to_keep (int): stored as ``self.num_beam_hyps_to_keep``.
        num_beam_groups (int): number of beam groups; must be a positive
            integer, not larger than ``num_beams``, that divides it.

    Raises:
        ValueError: if ``num_beams`` or ``num_beam_groups`` is invalid.
    """
    # Validate before any arithmetic on the arguments: ``group_size`` below
    # divides by ``num_beam_groups`` and the divisibility test itself needs a
    # non-zero divisor, so the ``< 1`` check must come before the modulo.
    if not isinstance(num_beams, int) or num_beams <= 1:
        raise ValueError(
            "`num_beams` has to be an integer strictly greater than 1, but "
            "received {}. For `num_beams` == 1, one should make use of "
            "`greedy_search` instead.".format(num_beams)
        )

    if (
        not isinstance(num_beam_groups, int)
        or num_beam_groups < 1
        or (num_beam_groups > num_beams)
        or (num_beams % num_beam_groups != 0)
    ):
        raise ValueError(
            "`num_beam_groups` has to be a positive integer smaller or equal than "
            "`num_beams` and `num_beams` has to be divisible by "
            "`num_beam_groups`, but received num_beam_groups={}, num_beams="
            "{}.".format(num_beam_groups, num_beams)
        )

    self.max_length = max_length
    self.num_beams = num_beams
    self.length_penalty = length_penalty
    self.do_early_stopping = do_early_stopping
    self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
    self.num_beam_groups = num_beam_groups
    self.group_size = self.num_beams // self.num_beam_groups

    self._is_init = False
    # One n-best list per batch entry.
    self._beam_hyps = [
        BeamHypotheses(
            num_beams=self.num_beams,
            length_penalty=self.length_penalty,
            early_stopping=self.do_early_stopping,
        )
        for _ in range(batch_size)
    ]
    # Per-batch completion flags, initialised to 0; compared against 1 in
    # ``process`` and ``is_done``.
    self._done = paddle.to_tensor([0 for _ in range(batch_size)], dtype="int64")
|
232
|
+
|
233
|
+
@property
def is_done(self):
    """Whether the smallest entry of ``self._done`` equals 1, i.e. every
    batch entry has been flagged done (assuming the flags hold 0/1)."""
    lowest_flag = paddle.min(self._done)
    return lowest_flag == 1
|
236
|
+
|
237
|
+
def process(
|
238
|
+
self,
|
239
|
+
input_ids,
|
240
|
+
next_scores,
|
241
|
+
next_tokens,
|
242
|
+
next_indices,
|
243
|
+
origin_len=0,
|
244
|
+
pad_token_id=None,
|
245
|
+
eos_token_id=None,
|
246
|
+
):
|
247
|
+
cur_len = input_ids.shape[-1]
|
248
|
+
batch_size = len(self._beam_hyps)
|
249
|
+
assert batch_size == (input_ids.shape[0] // self.group_size)
|
250
|
+
|
251
|
+
next_beam_scores = paddle.zeros(
|
252
|
+
[batch_size, self.group_size], dtype=next_scores.dtype
|
253
|
+
)
|
254
|
+
next_beam_tokens = paddle.zeros(
|
255
|
+
[batch_size, self.group_size], dtype=next_tokens.dtype
|
256
|
+
)
|
257
|
+
next_beam_indices = paddle.zeros(
|
258
|
+
[batch_size, self.group_size], dtype=next_indices.dtype
|
259
|
+
)
|
260
|
+
|
261
|
+
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
262
|
+
if self._done[batch_idx] == 1:
|
263
|
+
assert (
|
264
|
+
len(beam_hyp) >= self.num_beams
|
265
|
+
), "Batch can only be done if at least {} beams have been generated".format(
|
266
|
+
self.num_beams
|
267
|
+
)
|
268
|
+
assert (
|
269
|
+
eos_token_id is not None and pad_token_id is not None
|
270
|
+
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
|
271
|
+
# pad the batch
|
272
|
+
next_beam_scores[batch_idx, :] = 0
|
273
|
+
next_beam_tokens[batch_idx, :] = pad_token_id
|
274
|
+
next_beam_indices[batch_idx, :] = 0
|
275
|
+
continue
|
276
|
+
|
277
|
+
# next tokens for this sentence
|
278
|
+
beam_idx = 0
|
279
|
+
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
|
280
|
+
zip(
|
281
|
+
next_tokens[batch_idx],
|
282
|
+
next_scores[batch_idx],
|
283
|
+
next_indices[batch_idx],
|
284
|
+
)
|
285
|
+
):
|
286
|
+
batch_beam_idx = batch_idx * self.group_size + next_index
|
287
|
+
# add to generated hypotheses if end of sentence
|
288
|
+
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
|
289
|
+
# If beam_token does not belong to top num_beams tokens,
|
290
|
+
# it should not be added
|
291
|
+
is_beam_token_worse_than_top_num_beams = (
|
292
|
+
beam_token_rank >= self.group_size
|
293
|
+
)
|
294
|
+
if is_beam_token_worse_than_top_num_beams:
|
295
|
+
continue
|
296
|
+
beam_hyp.add(
|
297
|
+
input_ids[batch_beam_idx.item()].clone(),
|
298
|
+
next_score.item(),
|
299
|
+
origin_len,
|
300
|
+
)
|
301
|
+
|
302
|
+
else:
|
303
|
+
# add next predicted token since it is not eos_token
|
304
|
+
next_beam_scores[batch_idx, beam_idx] = next_score
|
305
|
+
next_beam_tokens[batch_idx, beam_idx] = next_token.item()
|
306
|
+
next_beam_indices[batch_idx, beam_idx] = batch_beam_idx.item()
|
307
|
+
beam_idx += 1
|
308
|
+
|
309
|
+
# once the beam for next step is full, don't add more tokens to it.
|
310
|
+
if beam_idx == self.group_size:
|
311
|
+
break
|
312
|
+
|
313
|
+
if beam_idx < self.group_size:
|
314
|
+
raise ValueError(
|
315
|
+
"At most {} tokens in `next_tokens[batch_idx]` can be equal "
|
316
|
+
"to `eos_token_id: {}`. Make sure `next_tokens[batch_idx]` "
|
317
|
+
"are corrected.".format(self.group_size, eos_token_id)
|
318
|
+
)
|
319
|
+
|
320
|
+
# Check if we are done so that we can save a pad step if all(done)
|
321
|
+
if beam_hyp.is_done(
|
322
|
+
next_scores[batch_idx].max().item(), cur_len, origin_len
|
323
|
+
):
|
324
|
+
self._done[batch_idx] = 1
|
325
|
+
|
326
|
+
return {
|
327
|
+
"next_beam_scores": next_beam_scores.reshape([-1]),
|
328
|
+
"next_beam_tokens": next_beam_tokens.reshape([-1]),
|
329
|
+
"next_beam_indices": next_beam_indices.reshape([-1]),
|
330
|
+
}
|
331
|
+
|
332
|
+
def finalize(
|
333
|
+
self,
|
334
|
+
input_ids,
|
335
|
+
final_beam_scores,
|
336
|
+
final_beam_tokens,
|
337
|
+
final_beam_indices,
|
338
|
+
origin_len=0,
|
339
|
+
pad_token_id=None,
|
340
|
+
eos_token_id=None,
|
341
|
+
):
|
342
|
+
batch_size = len(self._beam_hyps)
|
343
|
+
|
344
|
+
# finalize all open beam hypotheses and add to generated hypotheses
|
345
|
+
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
|
346
|
+
if self._done[batch_idx] == 1:
|
347
|
+
continue
|
348
|
+
|
349
|
+
# all open beam hypotheses are added to the beam hypothesis
|
350
|
+
# beam hypothesis class automatically keeps the best beams
|
351
|
+
for beam_id in range(self.num_beams):
|
352
|
+
batch_beam_idx = batch_idx * self.num_beams + beam_id
|
353
|
+
final_score = final_beam_scores[batch_beam_idx].item()
|
354
|
+
final_tokens = input_ids[batch_beam_idx]
|
355
|
+
beam_hyp.add(final_tokens, final_score, origin_len=origin_len)
|
356
|
+
|
357
|
+
# select the best hypotheses
|
358
|
+
sent_lengths = paddle.zeros(
|
359
|
+
[batch_size * self.num_beam_hyps_to_keep], dtype=input_ids.dtype
|
360
|
+
)
|
361
|
+
best = []
|
362
|
+
|
363
|
+
# retrieve best hypotheses
|
364
|
+
for i, beam_hyp in enumerate(self._beam_hyps):
|
365
|
+
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
|
366
|
+
for j in range(self.num_beam_hyps_to_keep):
|
367
|
+
best_score, best_hyp = sorted_hyps.pop()
|
368
|
+
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
|
369
|
+
best.append([best_hyp, best_score])
|
370
|
+
|
371
|
+
# prepare for adding eos
|
372
|
+
sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
|
373
|
+
decoded = paddle.zeros(
|
374
|
+
[batch_size * self.num_beam_hyps_to_keep, sent_max_len],
|
375
|
+
dtype=input_ids.dtype,
|
376
|
+
)
|
377
|
+
# shorter batches are padded if needed
|
378
|
+
if sent_lengths.min().item() != sent_lengths.max().item():
|
379
|
+
assert pad_token_id is not None, "`pad_token_id` has to be defined"
|
380
|
+
decoded[:, :] = pad_token_id
|
381
|
+
decoded_score = paddle.zeros([batch_size * self.num_beam_hyps_to_keep, 1])
|
382
|
+
|
383
|
+
# fill with hypotheses and eos_token_id if the latter fits in
|
384
|
+
for i, (hypo, score) in enumerate(best):
|
385
|
+
decoded[i, : sent_lengths[i].item()] = hypo.cpu().numpy()
|
386
|
+
decoded_score[i] = score
|
387
|
+
if sent_lengths[i] < self.max_length:
|
388
|
+
decoded[i, sent_lengths[i].item()] = eos_token_id
|
389
|
+
return decoded, decoded_score
|
390
|
+
|
391
|
+
|
392
|
+
class GenerationMixin(object):
|
393
|
+
r"""
|
394
|
+
This class implements the interface for generation task.
|
395
|
+
|
396
|
+
It's used as the base class of `paddlenlp.transformers.PretrainedModel
|
397
|
+
<https://paddlenlp.readthedocs.io/zh/latest/source/paddlenlp.transformers.model_utils.html>`__.
|
398
|
+
"""
|
399
|
+
|
400
|
+
# enable `to_static` method for CausalLM Model
|
401
|
+
enable_to_static_method = False
|
402
|
+
|
403
|
+
@staticmethod
|
404
|
+
def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
|
405
|
+
batch_size = 1
|
406
|
+
if bos_token_id is None:
|
407
|
+
raise ValueError(
|
408
|
+
"`bos_token_id` should be defined when no " "`input_ids` are provided."
|
409
|
+
)
|
410
|
+
if encoder_output is not None:
|
411
|
+
batch_size = encoder_output.shape[0]
|
412
|
+
return paddle.ones([batch_size, 1], dtype="int64") * bos_token_id
|
413
|
+
|
414
|
+
@staticmethod
def prepare_attention_mask_for_generation(input_ids, pad_token_id, eos_token_id):
    """Build an additive attention mask from the pad positions of ``input_ids``.

    Pad positions receive ``get_scale_by_dtype(return_positive=False)``
    (presumably a large negative value — confirm in that helper); everything
    else is 0. If ``input_ids`` contains no pad token, or the pad id equals
    the EOS id, the mask is all zeros.

    Returns:
        ``paddle.unsqueeze(mask, axis=[1, 2])`` in the default float dtype
        (for 2-D ``[batch, seq]`` input this is ``[batch, 1, 1, seq]`` —
        assumes 2-D input, TODO confirm).
    """
    has_pad_token = (pad_token_id is not None) and paddle.any(
        input_ids == pad_token_id
    ).item()
    pad_differs_from_eos = eos_token_id is None or pad_token_id != eos_token_id

    float_dtype = paddle.get_default_dtype()
    if not (has_pad_token and pad_differs_from_eos):
        mask = paddle.zeros_like(input_ids, dtype=float_dtype)
    else:
        pad_positions = (input_ids == pad_token_id).astype(float_dtype)
        mask = pad_positions * get_scale_by_dtype(return_positive=False)
    return paddle.unsqueeze(mask, axis=[1, 2])
|
431
|
+
|
432
|
+
@staticmethod
def prepare_seq_len_for_generation(input_ids, pad_token_id, eos_token_id):
    """Compute each row's sequence length as a ``[batch, 1]`` tensor.

    When ``input_ids`` contains the pad token and the pad id differs from
    the EOS id, the length is the count of non-pad ids per row (summed over
    axis 1). Otherwise every row gets ``input_ids.shape[1]`` as an int64
    value.
    """
    pad_present = (pad_token_id is not None) and paddle.any(
        input_ids == pad_token_id
    ).item()
    pad_is_distinct = eos_token_id is None or pad_token_id != eos_token_id

    if pad_present and pad_is_distinct:
        non_pad_counts = paddle.sum(input_ids != pad_token_id, axis=1)
        return non_pad_counts.unsqueeze(-1)

    rows, full_len = input_ids.shape[0], input_ids.shape[1]
    return paddle.full((rows, 1), full_len, dtype="int64")
|
447
|
+
|
448
|
+
def get_logits_processor(
    self,
    min_length=None,
    max_length=None,
    eos_token_id=None,
    forced_bos_token_id=None,
    forced_eos_token_id=None,
    num_beams=1,
    num_beam_groups=1,
    diversity_rate=0.0,
    repetition_penalty=None,
    no_repeat_ngram_size=None,
    logits_processors=None,
):
    """Assemble the logits processors implied by the generation arguments.

    Built-in processors are added in a fixed order: min-length, Hamming
    diversity, repetition penalty, no-repeat-ngram, forced BOS, forced EOS.
    Each is added only when its enabling condition holds.

    If ``logits_processors`` is given, every built-in whose exact type also
    appears in it is dropped. The caller's processors are then appended
    after the remaining built-ins.

    Returns:
        LogitsProcessorList: the processors to apply.
    """
    # (enabled?, factory) pairs; factories only run for enabled entries.
    candidates = [
        (
            min_length is not None and eos_token_id is not None and min_length > -1,
            lambda: MinLengthLogitsProcessor(min_length, eos_token_id),
        ),
        (
            num_beam_groups > 1 and diversity_rate > 0.0,
            lambda: HammingDiversityLogitsProcessor(
                diversity_rate=diversity_rate,
                num_beams=num_beams,
                num_beam_groups=num_beam_groups,
            ),
        ),
        (
            repetition_penalty is not None and repetition_penalty != 1.0,
            lambda: RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty),
        ),
        (
            no_repeat_ngram_size is not None and no_repeat_ngram_size > 0,
            lambda: NoRepeatNGramLogitsProcessor(no_repeat_ngram_size),
        ),
        (
            forced_bos_token_id is not None,
            lambda: ForcedBOSTokenLogitsProcessor(forced_bos_token_id),
        ),
        (
            forced_eos_token_id is not None,
            lambda: ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id),
        ),
    ]
    # TODO: add more pre-processing for distribution

    builtin = LogitsProcessorList()
    for enabled, make_processor in candidates:
        if enabled:
            builtin.append(make_processor())

    if logits_processors is None:
        return builtin

    overridden_types = [type(lp) for lp in logits_processors]
    merged = LogitsProcessorList()
    for processor in builtin:
        if type(processor) not in overridden_types:
            merged.append(processor)
    merged.extend(logits_processors)
    return merged
|
501
|
+
|
502
|
+
@staticmethod
def expand_inputs_for_generation(
    input_ids, expand_size, attention_mask=None, **model_kwargs
):
    """Repeat every batch row ``expand_size`` times (e.g. once per beam).

    The gather index is ``[0]*k + [1]*k + ...`` for ``k = expand_size``, so
    copies of the same row are adjacent. The same index is applied along
    axis 0 to ``input_ids``, to ``attention_mask`` (if given), and to any
    non-None ``token_type_ids``, ``position_ids``, ``seq_len``,
    ``encoder_output`` and ``role_ids`` found in ``model_kwargs``.

    Returns:
        tuple: ``(expanded_input_ids, model_kwargs)``.
    """
    row_ids = paddle.arange(input_ids.shape[0], dtype="int64").unsqueeze(-1)
    gather_index = paddle.tile(row_ids, [1, expand_size]).reshape([-1])

    expanded_ids = paddle.gather(input_ids, gather_index)

    if attention_mask is not None:
        model_kwargs["attention_mask"] = paddle.gather(attention_mask, gather_index)

    for key in (
        "token_type_ids",
        "position_ids",
        "seq_len",
        "encoder_output",
        "role_ids",
    ):
        value = model_kwargs.get(key)
        if value is not None:
            model_kwargs[key] = paddle.gather(value, gather_index)

    return expanded_ids, model_kwargs
|
544
|
+
|
545
|
+
@staticmethod
def update_model_kwargs_for_generation(
    outputs, model_kwargs, is_encoder_decoder=False
):
    """Update ``model_kwargs`` in place for the next decoding step.

    - cache: taken from ``outputs[1]`` when ``outputs`` is a tuple whose
      second element is not a Tensor, or from ``outputs.past_key_values`` for
      a ``ModelOutput``. It is stored under both ``cache`` and
      ``past_key_values``.
    - ``token_type_ids`` / ``role_ids``: the last column is repeated once.
    - ``position_ids``: extended with ``last + 1``.
    - ``attention_mask`` (decoder-only models): grown by one step. 4-D masks
      gain one row and one column; other masks get a column of int64 ones.
      A ``None`` mask is left as-is.

    Note that if `token_type_ids` and `attention_mask` in `model_kwargs`
    contain pad values, the updated tensors may differ from what you expect.
    In that case, override this method.

    Returns:
        dict: the same ``model_kwargs`` object, updated.
    """
    # update cache
    if (
        isinstance(outputs, tuple)
        and len(outputs) > 1
        and not isinstance(outputs[1], paddle.Tensor)
    ):
        model_kwargs["cache"] = outputs[1]
        model_kwargs["past_key_values"] = outputs[1]

    if isinstance(outputs, ModelOutput) and "past_key_values" in outputs:
        model_kwargs["cache"] = outputs.past_key_values
        model_kwargs["past_key_values"] = outputs.past_key_values

    # update token_type_ids with last value
    if (
        "token_type_ids" in model_kwargs
        and model_kwargs["token_type_ids"] is not None
    ):
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = paddle.concat(
            [token_type_ids, token_type_ids[:, -1:]], axis=-1
        )

    # update position_ids
    if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None:
        position_ids = model_kwargs["position_ids"]
        model_kwargs["position_ids"] = paddle.concat(
            [position_ids, position_ids[..., -1:] + 1], axis=-1
        )

    # update attention_mask. A present-but-None mask used to crash on
    # `.dtype`; it is now left untouched.
    attention_mask = model_kwargs.get("attention_mask")
    if not is_encoder_decoder and attention_mask is not None:
        # nn.Pad2D don't support the data type `bool`
        if convert_dtype(attention_mask.dtype) == "bool":
            attention_mask = paddle.cast(attention_mask, "int64")
        if len(attention_mask.shape) == 4:
            cur_device = paddle.get_device()
            if cur_device.split(":")[0] == "npu":
                attention_mask = nn.Pad2D([0, 0, 0, 1], mode="constant")(
                    attention_mask
                )
                attention_mask = nn.Pad2D([0, 1, 0, 0], value=0)(attention_mask)
            else:
                # Replicate the last row for the new query position, then add
                # a new key column filled with get_scale_by_dtype(False).
                attention_mask = nn.Pad2D([0, 0, 0, 1], mode="replicate")(
                    attention_mask
                )
                attention_mask = nn.Pad2D(
                    [0, 1, 0, 0], value=get_scale_by_dtype(return_positive=False)
                )(attention_mask)

            # The new token always attends to itself.
            dtype = convert_dtype(attention_mask.dtype)
            if "int" in dtype:
                attention_mask[:, :, -1, -1] = 1
            elif "float" in dtype:
                attention_mask[:, :, -1, -1] = 0.0
            else:
                raise ValueError(
                    "The data type of input `attention_mask` must "
                    "be bool, int or float"
                )
        else:
            attention_mask = paddle.concat(
                [
                    attention_mask,
                    paddle.ones([attention_mask.shape[0], 1], dtype="int64"),
                ],
                axis=-1,
            )
        model_kwargs["attention_mask"] = attention_mask

    # update role_ids
    if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None:
        role_ids = model_kwargs["role_ids"]
        model_kwargs["role_ids"] = paddle.concat(
            [role_ids, role_ids[:, -1:]], axis=-1
        )

    return model_kwargs
|
634
|
+
|
635
|
+
@staticmethod
def update_scores_for_generation(scores, next_scores, length, unfinished_flag):
    """Fold ``next_scores`` into a running per-token average.

    For rows where ``unfinished_flag`` is true, the new score is
    ``(scores * length + next_scores) / (length + 1)``. Other rows keep
    their previous ``scores``.
    """
    length_t = paddle.to_tensor(length, dtype=scores.dtype)
    averaged = (scores * length_t + next_scores) / (length_t + 1)
    return paddle.where(unfinished_flag, averaged, scores)
|
644
|
+
|
645
|
+
def prepare_encoder_decoder_kwargs_for_generation(self, input_ids, model_kwargs):
    """Run the encoder once and store its result in ``model_kwargs``.

    If ``"encoder_output"`` is already present, ``model_kwargs`` is returned
    untouched. Otherwise the encoder from ``self.get_encoder()`` is called
    with every kwarg except ``decoder_*``, ``cross_attn*`` and ``use_cache``.
    ``input_ids`` is passed as well unless ``inputs_embeds`` is among those
    kwargs. The result is stored under ``"encoder_output"``.

    Returns:
        dict: the same ``model_kwargs`` object.
    """
    if "encoder_output" in model_kwargs:
        return model_kwargs

    # retrieve encoder hidden states
    encoder = self.get_encoder()
    skipped_prefixes = ("decoder_", "cross_attn")
    encoder_kwargs = {
        name: value
        for name, value in model_kwargs.items()
        if not name.startswith(skipped_prefixes) and name != "use_cache"
    }

    # inputs_embeds takes priority over input_ids when provided
    if "inputs_embeds" in encoder_kwargs:
        encoded = encoder(**encoder_kwargs)
    else:
        encoded = encoder(input_ids=input_ids, **encoder_kwargs)
    model_kwargs["encoder_output"] = encoded
    return model_kwargs
|
666
|
+
|
667
|
+
def prepare_decoder_input_ids_for_generation(
    self, input_ids, decoder_start_token_id=None, bos_token_id=None
):
    """Create the initial decoder ids for encoder-decoder generation.

    The start id is resolved in this order: ``decoder_start_token_id``, then
    ``self.config.decoder_start_token_id``, then ``bos_token_id``.

    Args:
        input_ids: encoder input; only ``shape[0]`` (batch size) is read.

    Returns:
        An int64 paddle tensor of shape ``[batch_size, 1]`` filled with the
        start id.

    Raises:
        ValueError: if no start id can be resolved (previously this surfaced
            as a TypeError from multiplying a tensor by ``None``).
    """
    if decoder_start_token_id is None:
        decoder_start_token_id = self.config.decoder_start_token_id
    if decoder_start_token_id is None:
        decoder_start_token_id = bos_token_id
    if decoder_start_token_id is None:
        raise ValueError(
            "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
        )

    decoder_input_ids = (
        paddle.ones([input_ids.shape[0], 1], dtype="int64") * decoder_start_token_id
    )
    return decoder_input_ids
|
686
|
+
|
687
|
+
def get_decoder_start_token_id(
    self, decoder_start_token_id=None, bos_token_id=None
):
    """Resolve the id used to start decoding.

    Priority: the ``decoder_start_token_id`` argument, then
    ``self.config.decoder_start_token_id``, then the ``bos_token_id``
    argument, then ``self.config.bos_token_id``.

    Raises:
        ValueError: if none of the four is set.
    """
    # The original's extra `elif self.config...` branches were unreachable:
    # each argument had already been defaulted from config.
    config = self.config
    if decoder_start_token_id is None:
        decoder_start_token_id = config.decoder_start_token_id
    if decoder_start_token_id is not None:
        return decoder_start_token_id

    if bos_token_id is None:
        bos_token_id = config.bos_token_id
    if bos_token_id is not None:
        return bos_token_id

    raise ValueError(
        "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
    )
|
710
|
+
|
711
|
+
def prepare_inputs_for_generation(self, input_ids, **kwargs):
    """Default hook mapping generation state to model forward inputs.

    Subclasses override this for custom behavior. The base version ignores
    ``kwargs`` and forwards only ``input_ids``.
    """
    return dict(input_ids=input_ids)
|
716
|
+
|
717
|
+
def adjust_logits_during_generation(self, logits):
    """Default hook for adjusting logits inside ``generate``.

    Subclasses may override it. The base version returns ``logits``
    unchanged (the same object).
    """
    unchanged = logits
    return unchanged
|
722
|
+
|
723
|
+
def prepare_fast_entry(self, kwargs):
    """Hook called by ``_build_fast``; the base version does nothing and returns False."""
    fast_available = False
    return fast_available
|
725
|
+
|
726
|
+
def _convert_to_fast(self, kwargs):
    """Placeholder for a general fast-generation conversion ("try general convert").

    ``generate`` calls it after the fast path raises. The base version is a
    no-op.
    """
    return None
|
729
|
+
|
730
|
+
def _build_fast(self, kwargs):
    """Try to set up the FastGeneration entry point.

    Marks ``self._fast_entry = False`` first. Group beam search is not
    supported. When paddle's default dtype is float16, ``use_fp16_decoding``
    is forced to True. Finally ``self.prepare_fast_entry(kwargs)`` is
    invoked.

    ``kwargs`` may lack ``num_beam_groups`` / ``use_fp16_decoding`` (the
    locals-derived dict built in ``generate`` does not visibly contain
    them). Missing values default to 1 and False instead of raising
    KeyError.

    Raises:
        AttributeError: if ``num_beam_groups`` is not 1.
    """
    self._fast_entry = False
    if kwargs.get("num_beam_groups", 1) != 1:
        # not support for group_beam_search yet in the fast version
        raise AttributeError(
            "'num_beam_groups != 1' is not supported yet in the fast version"
        )
    if (
        paddle.get_default_dtype() == "float16"
        and kwargs.get("use_fp16_decoding", False) is False
    ):
        logging.info(
            "Since the default dtype is float16, float16 would be used "
            "though 'use_fp16_decoding=False'."
        )
        kwargs["use_fp16_decoding"] = True
    self.prepare_fast_entry(kwargs)
|
747
|
+
|
748
|
+
def set_pad_token_id(self, pad_token_id, eos_token_id):
    """Return the pad id, falling back to the EOS id when no pad id is set.

    A warning is logged when the fallback is used. If ``eos_token_id`` is a
    list, its first element is used. When both ids are None, None is
    returned.
    """
    if pad_token_id is not None or eos_token_id is None:
        return pad_token_id
    logging.warning(
        "Setting `pad_token_id` to `eos_token_id`:{} for "
        "open-end generation.".format(eos_token_id)
    )
    return eos_token_id[0] if isinstance(eos_token_id, list) else eos_token_id
|
759
|
+
|
760
|
+
@paddle.no_grad()
|
761
|
+
def generate(
|
762
|
+
self,
|
763
|
+
input_ids: paddle.Tensor = None,
|
764
|
+
generation_config: GenerationConfig = None,
|
765
|
+
stopping_criteria: StoppingCriteria = None,
|
766
|
+
streamer=None,
|
767
|
+
synced_gpus: Optional[bool] = None,
|
768
|
+
**kwargs,
|
769
|
+
):
|
770
|
+
r"""
|
771
|
+
The interface for generation task. This method can generate sequences
|
772
|
+
by using decoding strategy. Currently, there are three decoding
|
773
|
+
strategies supported: "greedy_search", "sampling" and "beam_search".
|
774
|
+
|
775
|
+
Args:
|
776
|
+
input_ids (Tensor, optional): The input sequence ids for the
|
777
|
+
generation. It is a Tensor with shape [batch_size, sequence_length].
|
778
|
+
The data type should be int32 or int64. Default to None, which
|
779
|
+
we will initialize it as a Tensor with shape [1, 1], filled
|
780
|
+
with the value `bos_token_id`.
|
781
|
+
generation_config (`~generation.GenerationConfig`, *optional*):
|
782
|
+
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
|
783
|
+
passed to generate matching the attributes of `generation_config` will override them. If
|
784
|
+
`generation_config` is not provided, the default will be used, which had the following loading
|
785
|
+
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
|
786
|
+
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
|
787
|
+
default values, whose documentation should be checked to parameterize generation.
|
788
|
+
stopping_criteria (`StoppingCriteriaList`, *optional*):
|
789
|
+
Custom stopping criteria that complement the default stopping criteria built from arguments and a
|
790
|
+
generation config. If a stopping criteria is passed that is already created with the arguments or a
|
791
|
+
generation config an error is thrown. This feature is intended for advanced users.
|
792
|
+
streamer (`~streamer.BaseStreamer`, *optional*):
|
793
|
+
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
794
|
+
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
795
|
+
synced_gpus (`bool`, *optional*):
|
796
|
+
Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
|
797
|
+
`True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
|
798
|
+
generating before other GPUs. Otherwise it'll be set to `False`.
|
799
|
+
kwargs (dict): It can be used to specify additional kwargs
|
800
|
+
passed to the model.
|
801
|
+
|
802
|
+
Returns:
|
803
|
+
tuple[Tensor]: It is a tuple contains two elements: ids and scores.
|
804
|
+
Each element is a Tensor.
|
805
|
+
|
806
|
+
With the fields:
|
807
|
+
|
808
|
+
- ids (Tensor):
|
809
|
+
The ids of the generated sequences. It is a Tensor with shape
|
810
|
+
[batch_size * num_return_sequences, sequence_length]. The data
|
811
|
+
type is same as the input `input_ids`.
|
812
|
+
- scores (Tensor):
|
813
|
+
The scores of the generated sequences. It is a Tensor with shape
|
814
|
+
[batch_size * num_return_sequences, 1]. The data type is float32
|
815
|
+
or float64, which is the same as the parameters in the model.
|
816
|
+
|
817
|
+
Example:
|
818
|
+
.. code-block::
|
819
|
+
|
820
|
+
import paddle
|
821
|
+
from paddlenlp.transformers import (
|
822
|
+
UnifiedTransformerLMHeadModel,
|
823
|
+
UnifiedTransformerTokenizer
|
824
|
+
)
|
825
|
+
|
826
|
+
paddle.seed(2)
|
827
|
+
|
828
|
+
# Initialize the model and tokenizer
|
829
|
+
model_name_or_path = 'unified_transformer-12L-cn-luge'
|
830
|
+
model = UnifiedTransformerLMHeadModel.from_pretrained(model_name_or_path)
|
831
|
+
tokenizer = UnifiedTransformerTokenizer.from_pretrained(model_name_or_path)
|
832
|
+
|
833
|
+
# Prepare the model inputs.
|
834
|
+
history = "早上好,今天空气质量不错。"
|
835
|
+
inputs = tokenizer.dialogue_encode(history, task_type='chitchat',
|
836
|
+
add_start_token_as_response=True, return_tensors=True)
|
837
|
+
|
838
|
+
.. code-block::
|
839
|
+
|
840
|
+
# Generate the sequence by using "greedy_search" strategy
|
841
|
+
ids, scores = model.generate(
|
842
|
+
**inputs,
|
843
|
+
decode_strategy="greedy_search")
|
844
|
+
print(ids.shape, scores.shape)
|
845
|
+
# [1, 3] [1, 1]
|
846
|
+
sequence_ids = ids.cpu().numpy().tolist()[0]
|
847
|
+
sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
|
848
|
+
response = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
|
849
|
+
print(response)
|
850
|
+
# 是的
|
851
|
+
|
852
|
+
.. code-block::
|
853
|
+
|
854
|
+
# Generate 2 sequences by using "sampling" strategy (top_k=5)
|
855
|
+
generation_config = GenerationConfig(
|
856
|
+
decode_strategy="sampling",
|
857
|
+
top_k=5,
|
858
|
+
num_return_sequences=2
|
859
|
+
)
|
860
|
+
ids, scores = model.generate(
|
861
|
+
**inputs,
|
862
|
+
generation_config=generation_config,
|
863
|
+
)
|
864
|
+
print(ids.shape, scores.shape)
|
865
|
+
# [2, 7] [2, 1]
|
866
|
+
response = []
|
867
|
+
for sequence_ids in ids.cpu().numpy().tolist():
|
868
|
+
sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
|
869
|
+
text = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
|
870
|
+
response.append(text)
|
871
|
+
print(response)
|
872
|
+
# ['天气好,心情也好', '你也是']
|
873
|
+
|
874
|
+
.. code-block::
|
875
|
+
|
876
|
+
# Generate 2 sequences by using "beam_search" strategy (num_beams=5)
|
877
|
+
generation_config = GenerationConfig(
|
878
|
+
decode_strategy="beam_search",
|
879
|
+
num_beams=5,
|
880
|
+
num_return_sequences=2
|
881
|
+
)
|
882
|
+
ids, scores = model.generate(
|
883
|
+
**inputs,
|
884
|
+
generation_config=generation_config,
|
885
|
+
)
|
886
|
+
print(ids.shape, scores.shape)
|
887
|
+
# [2, 3] [2, 1]
|
888
|
+
response = []
|
889
|
+
for sequence_ids in ids.cpu().numpy().tolist():
|
890
|
+
sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
|
891
|
+
text = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
|
892
|
+
response.append(text)
|
893
|
+
print(response)
|
894
|
+
# ['是的', '嗯嗯']
|
895
|
+
"""
|
896
|
+
if generation_config is None:
|
897
|
+
if (
|
898
|
+
self.generation_config is None
|
899
|
+
or self.generation_config._from_model_config
|
900
|
+
):
|
901
|
+
new_generation_config = GenerationConfig.from_model_config(self.config)
|
902
|
+
if new_generation_config != self.generation_config:
|
903
|
+
logging.warning(
|
904
|
+
"model.generation_config is in conflict with model.config, "
|
905
|
+
"model.config is used."
|
906
|
+
)
|
907
|
+
self.generation_config = new_generation_config
|
908
|
+
generation_config = self.generation_config
|
909
|
+
|
910
|
+
# without update model.generation_config
|
911
|
+
generation_config = copy.deepcopy(generation_config)
|
912
|
+
model_kwargs = generation_config.update(**kwargs)
|
913
|
+
|
914
|
+
assert generation_config.decode_strategy in [
|
915
|
+
"greedy_search",
|
916
|
+
"sampling",
|
917
|
+
"beam_search",
|
918
|
+
], "`decode_strategy` must be one of 'greedy_search', 'sampling' or 'beam_search' but received {}.".format(
|
919
|
+
generation_config.decode_strategy
|
920
|
+
)
|
921
|
+
|
922
|
+
if getattr(self, "deprecated_warnings", None) is None:
|
923
|
+
self.deprecated_warnings = {}
|
924
|
+
|
925
|
+
use_fast = False
|
926
|
+
if "use_faster" in model_kwargs:
|
927
|
+
raise ValueError("`use_faster` is deprecated now.")
|
928
|
+
|
929
|
+
if "use_fast" in model_kwargs:
|
930
|
+
raise ValueError("`use_fast` is deprecated now.")
|
931
|
+
|
932
|
+
bos_token_id = (
|
933
|
+
generation_config.bos_token_id
|
934
|
+
if generation_config.bos_token_id is not None
|
935
|
+
else self.config.bos_token_id
|
936
|
+
)
|
937
|
+
eos_token_id = (
|
938
|
+
generation_config.eos_token_id
|
939
|
+
if generation_config.eos_token_id is not None
|
940
|
+
else self.config.eos_token_id
|
941
|
+
)
|
942
|
+
pad_token_id = (
|
943
|
+
generation_config.pad_token_id
|
944
|
+
if generation_config.pad_token_id is not None
|
945
|
+
else self.config.pad_token_id
|
946
|
+
)
|
947
|
+
forced_bos_token_id = (
|
948
|
+
generation_config.forced_bos_token_id
|
949
|
+
if generation_config.forced_bos_token_id is not None
|
950
|
+
else self.config.forced_bos_token_id
|
951
|
+
)
|
952
|
+
forced_eos_token_id = (
|
953
|
+
generation_config.forced_eos_token_id
|
954
|
+
if generation_config.forced_eos_token_id is not None
|
955
|
+
else self.config.forced_eos_token_id
|
956
|
+
)
|
957
|
+
decoder_start_token_id = (
|
958
|
+
generation_config.decoder_start_token_id
|
959
|
+
if generation_config.decoder_start_token_id is not None
|
960
|
+
else self.config.decoder_start_token_id
|
961
|
+
)
|
962
|
+
no_repeat_ngram_size = (
|
963
|
+
generation_config.no_repeat_ngram_size
|
964
|
+
if generation_config.no_repeat_ngram_size is not None
|
965
|
+
else self.config.no_repeat_ngram_size
|
966
|
+
)
|
967
|
+
|
968
|
+
if getattr(self, "_fast_entry", None) is not False and use_fast:
|
969
|
+
fg_args = locals()
|
970
|
+
fg_args.pop("self")
|
971
|
+
fg_args.pop("__class__", None)
|
972
|
+
model_kwargs = fg_args.pop("model_kwargs")
|
973
|
+
fg_args.update(model_kwargs)
|
974
|
+
try:
|
975
|
+
if getattr(self, "_fast_entry", None) is None:
|
976
|
+
self._build_fast(fg_args)
|
977
|
+
if self._fast_entry:
|
978
|
+
output = self._fast_entry(**fg_args)
|
979
|
+
if isinstance(output, tuple):
|
980
|
+
output_ids, dummy_srore = output
|
981
|
+
else:
|
982
|
+
output_ids = output
|
983
|
+
# make result and fast result oneconsistent
|
984
|
+
dummy_srore = None
|
985
|
+
if generation_config.decode_strategy == "beam_search":
|
986
|
+
output_ids = output_ids.transpose([1, 2, 0])
|
987
|
+
output_ids = output_ids[
|
988
|
+
:, : generation_config.num_return_sequences, :
|
989
|
+
].reshape([-1, output_ids.shape[-1]])
|
990
|
+
if dummy_srore is not None:
|
991
|
+
dummy_srore = dummy_srore[
|
992
|
+
:, : generation_config.num_return_sequences
|
993
|
+
].flatten()
|
994
|
+
else:
|
995
|
+
output_ids = output_ids.transpose([1, 0])
|
996
|
+
return output_ids, dummy_srore
|
997
|
+
|
998
|
+
except Exception as e:
|
999
|
+
fg_args["model_kwargs"] = model_kwargs
|
1000
|
+
# TODO
|
1001
|
+
# Prevent self._convert_to_fast to throw Exception
|
1002
|
+
self._convert_to_fast(fg_args)
|
1003
|
+
logging.warning(e)
|
1004
|
+
logging.warning(
|
1005
|
+
"FastGeneration is not available, "
|
1006
|
+
"and the original version would be used instead."
|
1007
|
+
)
|
1008
|
+
|
1009
|
+
# input_ids in model_kwargs is supported
|
1010
|
+
if "input_ids" in model_kwargs:
|
1011
|
+
_input_ids = model_kwargs.pop("input_ids")
|
1012
|
+
if input_ids is None:
|
1013
|
+
input_ids = _input_ids
|
1014
|
+
|
1015
|
+
# params check
|
1016
|
+
if input_ids is None and "inputs_embeds" not in model_kwargs:
|
1017
|
+
# Init `input_ids` with bos_token_id
|
1018
|
+
input_ids = self.prepare_input_ids_for_generation(bos_token_id)
|
1019
|
+
elif "inputs_embeds" in model_kwargs:
|
1020
|
+
# Add input embeds support
|
1021
|
+
input_ids = self.prepare_input_ids_for_generation(
|
1022
|
+
bos_token_id, encoder_output=model_kwargs["inputs_embeds"]
|
1023
|
+
)
|
1024
|
+
|
1025
|
+
if model_kwargs.get("attention_mask", None) is None:
|
1026
|
+
# TODO
|
1027
|
+
# Init `attention_mask` depending on `pad_token_id`
|
1028
|
+
model_kwargs["attention_mask"] = self.prepare_attention_mask_for_generation(
|
1029
|
+
input_ids, pad_token_id, eos_token_id
|
1030
|
+
)
|
1031
|
+
self.is_encoder_decoder = self.config.is_encoder_decoder
|
1032
|
+
|
1033
|
+
if self.is_encoder_decoder:
|
1034
|
+
model_kwargs = self.prepare_encoder_decoder_kwargs_for_generation(
|
1035
|
+
input_ids, model_kwargs
|
1036
|
+
)
|
1037
|
+
# set input_ids as decoder_input_ids
|
1038
|
+
if "decoder_input_ids" in model_kwargs:
|
1039
|
+
input_ids = model_kwargs.pop("decoder_input_ids")
|
1040
|
+
else:
|
1041
|
+
input_ids = self.prepare_decoder_input_ids_for_generation(
|
1042
|
+
input_ids, decoder_start_token_id, bos_token_id
|
1043
|
+
)
|
1044
|
+
# streamer
|
1045
|
+
if streamer is not None:
|
1046
|
+
# streamer couldn't support beam_search strategy
|
1047
|
+
if (
|
1048
|
+
generation_config.decode_strategy == "beam_search"
|
1049
|
+
or generation_config.num_beams > 1
|
1050
|
+
):
|
1051
|
+
raise ValueError(
|
1052
|
+
"`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
|
1053
|
+
)
|
1054
|
+
|
1055
|
+
pad_token_id = self.set_pad_token_id(pad_token_id, eos_token_id)
|
1056
|
+
|
1057
|
+
if (
|
1058
|
+
generation_config.max_length != 0
|
1059
|
+
and generation_config.max_new_tokens == DEFAULT_MAX_NEW_TOKENS
|
1060
|
+
):
|
1061
|
+
logging.warning(
|
1062
|
+
"`max_length` will be deprecated in future releases, use `max_new_tokens` instead."
|
1063
|
+
)
|
1064
|
+
generation_config.max_new_tokens = generation_config.max_length
|
1065
|
+
|
1066
|
+
if generation_config.min_length != 0 and generation_config.min_new_tokens == 0:
|
1067
|
+
logging.warning(
|
1068
|
+
"`min_length` will be deprecated in future releases, use `min_new_tokens` instead."
|
1069
|
+
)
|
1070
|
+
generation_config.min_new_tokens = generation_config.min_length
|
1071
|
+
|
1072
|
+
max_length = generation_config.max_new_tokens
|
1073
|
+
min_length = generation_config.min_new_tokens
|
1074
|
+
|
1075
|
+
input_len = input_ids.shape[-1]
|
1076
|
+
min_len = input_len + min_length
|
1077
|
+
max_len = input_len + max_length
|
1078
|
+
|
1079
|
+
logits_processors = self.get_logits_processor(
|
1080
|
+
min_length=min_len if min_length > 0 else None,
|
1081
|
+
max_length=max_len,
|
1082
|
+
eos_token_id=eos_token_id,
|
1083
|
+
forced_bos_token_id=forced_bos_token_id,
|
1084
|
+
forced_eos_token_id=forced_eos_token_id,
|
1085
|
+
num_beams=generation_config.num_beams,
|
1086
|
+
num_beam_groups=generation_config.num_beam_groups,
|
1087
|
+
diversity_rate=generation_config.diversity_rate,
|
1088
|
+
repetition_penalty=generation_config.repetition_penalty,
|
1089
|
+
no_repeat_ngram_size=generation_config.no_repeat_ngram_size,
|
1090
|
+
logits_processors=(
|
1091
|
+
model_kwargs["logits_processors"]
|
1092
|
+
if "logits_processors" in model_kwargs
|
1093
|
+
and isinstance(model_kwargs["logits_processors"], LogitsProcessorList)
|
1094
|
+
else None
|
1095
|
+
),
|
1096
|
+
)
|
1097
|
+
if "logits_processors" in model_kwargs:
|
1098
|
+
model_kwargs.pop("logits_processors")
|
1099
|
+
|
1100
|
+
stopping_criteria = (
|
1101
|
+
stopping_criteria
|
1102
|
+
if stopping_criteria is not None
|
1103
|
+
else StoppingCriteriaList()
|
1104
|
+
)
|
1105
|
+
|
1106
|
+
if generation_config.decode_strategy == "greedy_search":
|
1107
|
+
if generation_config.num_return_sequences > 1:
|
1108
|
+
raise ValueError(
|
1109
|
+
"`num_return_sequences` has to be 1, but is {} "
|
1110
|
+
"when doing greedy search.".format(
|
1111
|
+
generation_config.num_return_sequences
|
1112
|
+
)
|
1113
|
+
)
|
1114
|
+
return self.greedy_search(
|
1115
|
+
input_ids,
|
1116
|
+
logits_processors,
|
1117
|
+
max_len,
|
1118
|
+
pad_token_id,
|
1119
|
+
eos_token_id,
|
1120
|
+
stopping_criteria=stopping_criteria,
|
1121
|
+
streamer=streamer,
|
1122
|
+
fast_ptq_sampling=generation_config.fast_ptq_sampling,
|
1123
|
+
trunc_input=generation_config.trunc_input,
|
1124
|
+
synced_gpus=synced_gpus,
|
1125
|
+
**model_kwargs,
|
1126
|
+
)
|
1127
|
+
|
1128
|
+
elif generation_config.decode_strategy == "sampling":
|
1129
|
+
if generation_config.num_return_sequences > 1:
|
1130
|
+
input_ids, model_kwargs = self.expand_inputs_for_generation(
|
1131
|
+
input_ids,
|
1132
|
+
expand_size=generation_config.num_return_sequences,
|
1133
|
+
**model_kwargs,
|
1134
|
+
)
|
1135
|
+
|
1136
|
+
return self.sample(
|
1137
|
+
input_ids,
|
1138
|
+
logits_processors,
|
1139
|
+
max_len,
|
1140
|
+
pad_token_id,
|
1141
|
+
eos_token_id,
|
1142
|
+
generation_config.top_k,
|
1143
|
+
generation_config.top_p,
|
1144
|
+
generation_config.temperature,
|
1145
|
+
stopping_criteria=stopping_criteria,
|
1146
|
+
streamer=streamer,
|
1147
|
+
fast_ptq_sampling=generation_config.fast_ptq_sampling,
|
1148
|
+
trunc_input=generation_config.trunc_input,
|
1149
|
+
synced_gpus=synced_gpus,
|
1150
|
+
**model_kwargs,
|
1151
|
+
)
|
1152
|
+
|
1153
|
+
elif generation_config.decode_strategy == "beam_search":
|
1154
|
+
batch_size = input_ids.shape[0]
|
1155
|
+
if generation_config.num_return_sequences > generation_config.num_beams:
|
1156
|
+
raise ValueError(
|
1157
|
+
"`num_return_sequences` has to be smaller or equal to "
|
1158
|
+
"`num_beams`. But received `num_return_sequences` is {}, "
|
1159
|
+
"`num_beams` is {}".format(
|
1160
|
+
generation_config.num_return_sequences,
|
1161
|
+
generation_config.num_beams,
|
1162
|
+
)
|
1163
|
+
)
|
1164
|
+
if generation_config.num_beams <= 1:
|
1165
|
+
raise ValueError(
|
1166
|
+
"`num_beams` has to be bigger than 1. But received "
|
1167
|
+
"`num_beams` is {}. If `num_beams` is 1, `decode_strategy` "
|
1168
|
+
"should be 'greedy_search'".format(generation_config.num_beams)
|
1169
|
+
)
|
1170
|
+
if generation_config.num_beam_groups > 1:
|
1171
|
+
diverse_beam_scorer = BeamSearchScorer(
|
1172
|
+
batch_size=batch_size,
|
1173
|
+
max_length=max_len,
|
1174
|
+
num_beams=generation_config.num_beams,
|
1175
|
+
length_penalty=generation_config.length_penalty,
|
1176
|
+
do_early_stopping=generation_config.early_stopping,
|
1177
|
+
num_beam_hyps_to_keep=generation_config.num_return_sequences,
|
1178
|
+
num_beam_groups=generation_config.num_beam_groups,
|
1179
|
+
)
|
1180
|
+
|
1181
|
+
# interleave with `num_beams`
|
1182
|
+
input_ids, model_kwargs = self.expand_inputs_for_generation(
|
1183
|
+
input_ids, expand_size=generation_config.num_beams, **model_kwargs
|
1184
|
+
)
|
1185
|
+
|
1186
|
+
return self.group_beam_search(
|
1187
|
+
input_ids,
|
1188
|
+
diverse_beam_scorer,
|
1189
|
+
logits_processors,
|
1190
|
+
max_len,
|
1191
|
+
pad_token_id,
|
1192
|
+
eos_token_id,
|
1193
|
+
stopping_criteria=stopping_criteria,
|
1194
|
+
fast_ptq_sampling=generation_config.fast_ptq_sampling,
|
1195
|
+
trunc_input=generation_config.trunc_input,
|
1196
|
+
synced_gpus=synced_gpus,
|
1197
|
+
**model_kwargs,
|
1198
|
+
)
|
1199
|
+
else:
|
1200
|
+
beam_scorer = BeamSearchScorer(
|
1201
|
+
batch_size=batch_size,
|
1202
|
+
max_length=max_len,
|
1203
|
+
num_beams=generation_config.num_beams,
|
1204
|
+
length_penalty=generation_config.length_penalty,
|
1205
|
+
do_early_stopping=generation_config.early_stopping,
|
1206
|
+
num_beam_hyps_to_keep=generation_config.num_return_sequences,
|
1207
|
+
)
|
1208
|
+
|
1209
|
+
input_ids, model_kwargs = self.expand_inputs_for_generation(
|
1210
|
+
input_ids, expand_size=generation_config.num_beams, **model_kwargs
|
1211
|
+
)
|
1212
|
+
|
1213
|
+
return self.beam_search(
|
1214
|
+
input_ids,
|
1215
|
+
beam_scorer,
|
1216
|
+
logits_processors,
|
1217
|
+
max_len,
|
1218
|
+
generation_config.diversity_rate,
|
1219
|
+
pad_token_id,
|
1220
|
+
eos_token_id,
|
1221
|
+
stopping_criteria=stopping_criteria,
|
1222
|
+
fast_ptq_sampling=generation_config.fast_ptq_sampling,
|
1223
|
+
trunc_input=generation_config.trunc_input,
|
1224
|
+
synced_gpus=synced_gpus,
|
1225
|
+
**model_kwargs,
|
1226
|
+
)
|
1227
|
+
|
1228
|
+
def greedy_search(
    self,
    input_ids,
    logits_processors,
    max_length,
    pad_token_id,
    eos_token_id,
    stopping_criteria=None,
    streamer=None,
    fast_ptq_sampling=False,
    trunc_input=True,
    synced_gpus=False,
    **model_kwargs,
):
    """Greedy decoding: repeatedly append the arg-max token of the processed logits.

    Args:
        input_ids: prompt token ids, unpacked as ``(batch_size, seq_len)``.
        logits_processors: callable ``(input_ids, logits) -> logits`` applied to
            the last-position logits; ``None`` means an empty
            ``LogitsProcessorList``.
        max_length: when not ``None``, folded into ``stopping_criteria`` via
            ``validate_stopping_criteria``.
        pad_token_id: token written for rows that already finished (only when
            ``eos_token_id`` is not ``None``).
        eos_token_id: end-of-sequence id(s) passed to ``get_unfinished_flag``;
            ``None`` disables per-row finish tracking.
        stopping_criteria: extra stopping criteria; ``None`` means an empty list.
        streamer: optional object that receives each new token column through
            ``put`` (on tensor-parallel rank 0 only) and ``end()`` at the end.
        fast_ptq_sampling: when true, stop after a single decoding step.
        trunc_input: when true, the prompt is stripped from the returned ids.
        synced_gpus: keep running forward passes until every peer reports it
            is finished (checked with ``dist.all_reduce``).
        **model_kwargs: forwarded to ``prepare_inputs_for_generation``;
            ``use_cache`` defaults to ``True``.

    Returns:
        ``(ids, scores)``. ``ids`` are the token ids (prompt removed when
        ``trunc_input``). ``scores`` is the ``(batch_size, 1)`` tensor
        accumulated by ``update_scores_for_generation`` from values picked out
        of the processed logits (no softmax is applied in this method).
    """
    model_kwargs["use_cache"] = model_kwargs.get("use_cache", True)
    if logits_processors is None:
        logits_processors = LogitsProcessorList()
    if stopping_criteria is None:
        stopping_criteria = StoppingCriteriaList()
    # max_length is turned into a length-based stopping criterion.
    if max_length is not None:
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)

    batch_size, prompt_len = input_ids.shape
    num_generated = 0
    alive = paddle.full([batch_size, 1], True, dtype="bool")
    scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype())
    finished = False

    while True:
        if synced_gpus:
            # Each peer contributes 1.0 while still generating and 0.0 once
            # done; everyone may leave only when the reduced sum is 0.0.
            peers_active = paddle.to_tensor(1.0 if not finished else 0.0)
            dist.all_reduce(peers_active, op=dist.ReduceOp.SUM)
            if peers_active.item() == 0.0:
                break

        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = self(**model_inputs)

        # A finished peer keeps running the forward pass to stay in lock-step
        # with the others, but skips all token bookkeeping.
        if synced_gpus and finished:
            continue

        logits = outputs
        if isinstance(outputs, tuple):
            logits = outputs[0]
        elif isinstance(outputs, ModelOutput):
            logits = outputs.logits

        # Only the last position matters for the next token.
        step_logits = self.adjust_logits_during_generation(logits[:, -1, :])
        processed = logits_processors(input_ids, step_logits)
        next_tokens = paddle.argmax(processed, axis=-1).unsqueeze(-1)
        next_scores = paddle.index_sample(processed, next_tokens)

        if eos_token_id is not None:
            # Rows that are no longer alive receive padding instead.
            padding = paddle.full_like(next_tokens, pad_token_id)
            next_tokens = paddle.where(alive, next_tokens, padding)

        scores = self.update_scores_for_generation(
            scores, next_scores, num_generated, alive
        )
        num_generated += 1
        input_ids = paddle.concat([input_ids, next_tokens], axis=1)

        if streamer is not None and self.config.tensor_parallel_rank == 0:
            streamer.put(next_tokens.cpu())

        if stopping_criteria(input_ids, scores):
            finished = True
        if eos_token_id is not None:
            alive = get_unfinished_flag(input_ids, alive, eos_token_id)
            if not paddle.any(alive):
                finished = True

        # Without synced GPUs we can leave as soon as this rank is done.
        if finished and not synced_gpus:
            break

        model_kwargs = self.update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        if fast_ptq_sampling:
            break

    if streamer is not None:
        streamer.end()

    generated = input_ids[:, prompt_len:] if trunc_input else input_ids
    return generated, scores
|
1346
|
+
|
1347
|
+
def sample(
    self,
    input_ids,
    logits_processors,
    max_length,
    pad_token_id,
    eos_token_id,
    top_k=None,
    top_p=None,
    temperature=None,
    min_tokens_to_keep=1,
    stopping_criteria=None,
    streamer=None,
    fast_ptq_sampling=False,
    trunc_input=True,
    synced_gpus=False,
    **model_kwargs,
):
    """Multinomial sampling with optional temperature, top-k and top-p filtering.

    Args:
        input_ids: prompt token ids, unpacked as ``(batch_size, seq_len)``.
        logits_processors: callable ``(input_ids, logits) -> logits`` applied to
            the last-position logits; ``None`` means an empty
            ``LogitsProcessorList``.
        max_length: when not ``None``, folded into ``stopping_criteria`` via
            ``validate_stopping_criteria``.
        pad_token_id: token written for rows that already finished (only when
            ``eos_token_id`` is not ``None``).
        eos_token_id: end-of-sequence id(s) passed to ``get_unfinished_flag``;
            ``None`` disables per-row finish tracking.
        top_k: when not ``None`` and non-zero, ``TopKProcess`` is applied.
        top_p: when not ``None`` and ``< 1.0``, ``TopPProcess`` is applied.
        temperature: when not ``None`` and ``!= 1.0``, logits are divided by it
            before filtering. Scores are computed from the un-tempered logits.
        min_tokens_to_keep: forwarded to ``TopKProcess`` / ``TopPProcess``.
        stopping_criteria: extra stopping criteria; ``None`` means an empty list.
        streamer: optional object that receives each new token column through
            ``put`` (on tensor-parallel rank 0 only) and ``end()`` at the end.
        fast_ptq_sampling: when true, stop after a single decoding step.
        trunc_input: when true, the prompt is stripped from the returned ids.
        synced_gpus: keep running forward passes until every peer reports it
            is finished (checked with ``dist.all_reduce``).
        **model_kwargs: forwarded to ``prepare_inputs_for_generation``;
            ``use_cache`` defaults to ``True``.

    Returns:
        ``(ids, scores)``. ``ids`` are the token ids (prompt removed when
        ``trunc_input``). ``scores`` is accumulated by
        ``update_scores_for_generation`` from the log-softmax of the processed,
        un-tempered logits at each sampled token.
    """
    model_kwargs["use_cache"] = model_kwargs.get("use_cache", True)

    logits_processors = (
        logits_processors if logits_processors is not None else LogitsProcessorList()
    )

    # max_length will be converted to a MaxLengthCriteria
    stopping_criteria = (
        stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    )
    if max_length is not None:
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)

    batch_size, cur_len = input_ids.shape
    origin_len = cur_len
    unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool")
    scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype())

    generate_end = False
    while True:
        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus
            # complete their sequence. Send 0.0 if we finished, 1.0 otherwise;
            # the reduced sum is 0.0 only when every peer has finished.
            this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            if this_peer_finished_flag.item() == 0.0:
                break

        # prepare model inputs & get model output
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        # NOTE: drop our references to the previous cache before the forward
        # pass so outdated cache tensors can be released in time.
        model_kwargs["cache"] = None
        model_kwargs["past_key_values"] = None
        outputs = self(**model_inputs)
        if synced_gpus and generate_end:
            continue  # don't waste resources running the code we don't need

        if isinstance(outputs, tuple):
            logits = outputs[0]
        elif isinstance(outputs, ModelOutput):
            logits = outputs.logits
        else:
            logits = outputs

        # Only the last position matters for the next token.
        logits = logits[:, -1, :]

        # pre-process distribution
        logits = self.adjust_logits_during_generation(logits)
        logits = logits_processors(input_ids, logits)

        # Log-probabilities of the un-tempered distribution, used for scoring.
        # log_softmax is numerically stable; log(softmax(x)) yields -inf as
        # soon as a probability underflows to zero.
        origin_probs = F.log_softmax(logits)
        if temperature is not None and temperature != 1.0:
            logits = logits / temperature
        probs = F.softmax(logits)
        if top_k is not None and top_k != 0:
            probs = TopKProcess(probs, top_k, min_tokens_to_keep)
        if top_p is not None and top_p < 1.0:
            probs = TopPProcess(probs, top_p, min_tokens_to_keep)
        if paddle.device.is_compiled_with_custom_device("gcu"):
            probs = paddle.cast(probs, "float32")
        if paddle.device.is_compiled_with_xpu():
            probs = paddle.cast(probs, "float32")

        # multinomial already supports fp16 and bf16, see
        # https://github.com/PaddlePaddle/Paddle/issues/51852
        next_tokens = paddle.multinomial(probs)

        if self.config.tensor_parallel_degree > 1:
            # Maybe no need to broadcast if seed is set correctly.
            from paddle.distributed import fleet

            try:
                hcg = fleet.get_hybrid_communicate_group()
                group = hcg.get_model_parallel_group()
                src = hcg.get_model_parallel_group_src_rank()
            except Exception:
                # No usable hybrid communicate group: broadcast with
                # group=None from rank 0 instead.
                group, src = None, 0
            paddle.distributed.broadcast(next_tokens, src=src, group=group)
        # config does not include pipeline_parallel_degree, and pipeline parallel
        # uses trainer.model_wrapped to run in both train and predict mode
        # which has pp_group as an attribute
        # TODO(guosheng): only let the last stage of pipeline do softmax
        # and sampling, and then broadcast to avoid broadcasting logits.
        if getattr(self, "pp_group", None) is not None:
            paddle.distributed.broadcast(
                next_tokens,
                src=self.pp_group.ranks[0],
                group=self.pp_group,  # use rank 0 for same seed to check
            )

        next_scores = paddle.index_sample(origin_probs, next_tokens)
        if eos_token_id is not None:
            # Rows that already finished receive padding instead.
            next_tokens = paddle.where(
                unfinished_flag,
                next_tokens,
                paddle.full_like(next_tokens, pad_token_id),
            )

        scores = self.update_scores_for_generation(
            scores, next_scores, cur_len - origin_len, unfinished_flag
        )

        cur_len += 1
        input_ids = paddle.concat([input_ids, next_tokens], axis=1)
        if streamer is not None:
            if self.config.tensor_parallel_rank == 0:
                streamer.put(next_tokens.cpu())

        if stopping_criteria(input_ids, scores):
            generate_end = True

        if eos_token_id is not None:
            unfinished_flag = get_unfinished_flag(
                input_ids, unfinished_flag, eos_token_id
            )
            if not paddle.any(unfinished_flag):
                generate_end = True

        # Stop when there is a </s> in all sentences
        if generate_end and not synced_gpus:
            break

        # Read the flag from config (as greedy_search does): the
        # `self.is_encoder_decoder` attribute is only set inside `generate`.
        model_kwargs = self.update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        if fast_ptq_sampling:
            break

    if streamer is not None:
        streamer.end()

    return input_ids[:, origin_len:] if trunc_input else input_ids, scores
|
1510
|
+
|
1511
|
+
def _get_model_inputs_spec(self, dtype: str):
    """Return the ``InputSpec`` mapping for the model inputs used by ``to_static``.

    ``input_ids`` and ``attention_mask`` are always included. ``position_ids`` is
    added only when that name appears in
    ``inspect.getfullargspec(self.forward).args``. Every spec is a 2-D
    ``[None, None]`` int64 spec. ``dtype`` is accepted but not used here.
    """

    def _dynamic_int64_matrix():
        # Both dimensions are left dynamic.
        return paddle.static.InputSpec(shape=[None, None], dtype="int64")

    spec = {
        "input_ids": _dynamic_int64_matrix(),
        "attention_mask": _dynamic_int64_matrix(),
    }
    forward_params = inspect.getfullargspec(self.forward).args
    if "position_ids" in forward_params:
        spec["position_ids"] = _dynamic_int64_matrix()
    return spec
|
1523
|
+
|
1524
|
+
def to_static(self, path: str, config: dict):
    """Export ``sample_d2s`` as a static-graph model and save it to ``path``.

    Args:
        path (str): path passed to ``paddle.jit.save`` for the inference model.
        config (dict): export options. Keys read here:
            use_top_p (bool, default True): when true, export ``top_p`` and
                ``temperature`` as float32 ``InputSpec``s and fix ``top_k`` to 0;
                otherwise export ``top_k`` as an int64 ``InputSpec`` and fix
                ``top_p`` and ``temperature`` to 1.0.
            pad_token_id (int): used only when
                ``self.generation_config.pad_token_id`` is ``None``.
            eos_token_id (int): used only when
                ``self.generation_config.eos_token_id`` is ``None``.
            dtype: forwarded to ``_get_model_inputs_spec``.
            logits_processors: passed to ``sample_d2s`` as a constant.
            A ``bos_token_id`` entry is not read by this method.
    """
    use_top_p = config.get("use_top_p", True)

    top_k_spec = (
        paddle.static.InputSpec(shape=[1], dtype="int64") if not use_top_p else 0
    )

    top_p_spec = (
        paddle.static.InputSpec(shape=[1], dtype="float32") if use_top_p else 1.0
    )
    temperature = (
        paddle.static.InputSpec(shape=[1], dtype="float32") if use_top_p else 1.0
    )
    dtype = config.get("dtype", None)

    logits_processors = config.get("logits_processors", None)
    model_inputs_spec = self._get_model_inputs_spec(dtype)

    def _token_id(name):
        # Prefer generation_config and fall back to `config` only when it is
        # None. An explicit None check keeps a valid id of 0, which `a or b`
        # would discard.
        value = getattr(self.generation_config, name)
        return value if value is not None else config.get(name, None)

    input_spec = [
        model_inputs_spec["input_ids"],  # input_ids
        model_inputs_spec["attention_mask"],  # attention_mask
        model_inputs_spec.get("position_ids", None),  # position_ids
        logits_processors,  # logits_processors
        paddle.static.InputSpec(shape=[1], dtype="int64"),  # max_new_tokens
        _token_id("pad_token_id"),  # pad_token_id
        _token_id("eos_token_id"),  # eos_token_id
        top_k_spec,  # top_k
        top_p_spec,  # top_p
        temperature,  # temperature
        1,  # min_tokens_to_keep
    ]

    model = paddle.jit.to_static(self.sample_d2s, input_spec=input_spec)

    paddle.jit.save(model, path)
|
1570
|
+
|
1571
|
+
def sample_d2s(
    self,
    input_ids,
    attention_mask,
    position_ids,
    logits_processors,
    max_new_tokens,
    pad_token_id,
    eos_token_id,
    top_k=None,
    top_p=None,
    temperature=None,
    min_tokens_to_keep=1,
):
    """Sampling loop written to be traced by ``paddle.jit.to_static`` (see ``to_static``).

    The sampling strategy is chosen from which of ``top_k`` / ``top_p`` is a
    tensor (an ``InputSpec`` at export time):

    * ``top_p`` tensor, ``top_k`` not: nucleus sampling via
      ``paddle.tensor.top_p_sampling``;
    * ``top_k`` tensor, ``top_p`` not: ``TopKProcess``, then arg-max when
      ``top_k == 1``, otherwise ``paddle.multinomial``;
    * neither a tensor: a float ``top_p`` or int ``top_k`` selects the top-p
      path.

    Args:
        input_ids: prompt token ids, unpacked as ``(batch_size, seq_len)``.
        attention_mask: initial attention mask passed to the model.
        position_ids: initial position ids passed to the model (may be ``None``).
        logits_processors: callable ``(input_ids, logits) -> logits``; ``None``
            means an empty ``LogitsProcessorList``.
        max_new_tokens: token budget. The first step always runs, and the loop
            then continues while fewer than ``max_new_tokens`` tokens have been
            appended and some row is unfinished.
        pad_token_id: resolved through ``set_pad_token_id``; written for rows
            that already finished when ``eos_token_id`` is not ``None``.
        eos_token_id: end-of-sequence id(s) passed to ``get_unfinished_flag``.
        top_k, top_p: see above.
        temperature: in the top-p path, logits are divided by it before
            computing the sampled distribution; ``None`` leaves them unscaled.
            It is not applied in the top-k path.
        min_tokens_to_keep: forwarded to ``TopKProcess``.

    Returns:
        ``(ids, scores)``. ``ids`` are the generated ids without the prompt.
        ``scores`` is accumulated by ``update_scores_for_generation`` from the
        log-softmax of the processed, un-tempered logits.

    Raises:
        ValueError: when both ``top_k`` and ``top_p`` are ``None``, or when
            neither of the strategy rules above applies.
    """
    pad_token_id = self.set_pad_token_id(pad_token_id, eos_token_id)
    logits_processors = (
        logits_processors if logits_processors is not None else LogitsProcessorList()
    )

    if paddle.is_tensor(top_k) and not paddle.is_tensor(top_p):
        use_top_p = False
    elif not paddle.is_tensor(top_k) and paddle.is_tensor(top_p):
        use_top_p = True

    # top_k and top_p are constant values
    elif isinstance(top_p, float) or isinstance(top_k, int):
        use_top_p = True
    else:
        if top_p is None and top_k is None:
            raise ValueError("top_k and top_p should not be None")
        raise ValueError(
            "you should not specify InputSpec for top_k and top_p parameters, one of InputSpec is expected"
        )

    batch_size, cur_len = input_ids.shape
    # used for compute on gpu, avoid memcpy D2H
    cur_len_gpu = paddle.full([1], cur_len, dtype="int64")

    origin_len = input_ids.shape[1]
    # used for compute on gpu, avoid memcpy D2H
    origin_len_gpu = paddle.full([1], origin_len, dtype="int64")

    unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool")

    scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype())

    # use_cache is immutable, we split it off other mutable kwargs.
    immutable = {"use_cache": True}
    model_kwargs = {"attention_mask": attention_mask, "position_ids": position_ids}

    def _forward_(**args):
        # Reads the enclosing `input_ids` at call time, so it sees the ids
        # appended by the previous step.
        model_inputs = self.prepare_inputs_for_generation(
            input_ids, **args, **immutable
        )
        assert "use_cache" in model_inputs
        del model_inputs["use_cache"]
        return self(**model_inputs, **immutable)

    def _post_process_(
        outputs,
        input_ids,
        cur_len,
        origin_len,
        scores,
        unfinished_flag,
        model_kwargs,
        pad_token_id,
    ):
        if isinstance(outputs, tuple):
            logits = outputs[0]
        elif isinstance(outputs, ModelOutput):
            logits = outputs.logits
        else:
            logits = outputs

        # Only the last position matters for the next token.
        logits = logits[:, -1, :]

        # pre-process distribution
        logits = self.adjust_logits_during_generation(logits)

        logits = logits_processors(input_ids, logits)
        probs = F.softmax(logits)

        # Scores use log-probabilities of the un-tempered distribution.
        origin_probs = F.log_softmax(logits)
        # compute next_tokens
        if use_top_p:
            if temperature is not None:
                # Sample from the temperature-scaled distribution, so that
                # `temperature` actually affects the sampled tokens.
                probs = F.softmax(logits / temperature)
            top_ps_tensor = paddle.full(
                shape=[probs.shape[0], 1], fill_value=top_p, dtype=probs.dtype
            )
            _, next_tokens = paddle.tensor.top_p_sampling(probs, top_ps_tensor)
        else:
            probs = TopKProcess(probs, top_k, min_tokens_to_keep)
            if top_k == 1:
                next_tokens = paddle.unsqueeze_(paddle.argmax(probs, axis=-1), -1)
            else:
                next_tokens = paddle.multinomial(probs)

        next_scores = paddle.index_sample(origin_probs, next_tokens)
        scores = self.update_scores_for_generation(
            scores, next_scores, cur_len - origin_len, unfinished_flag
        )
        if eos_token_id is not None:
            # Rows that already finished receive padding instead.
            next_tokens = paddle.where(
                unfinished_flag,
                next_tokens,
                paddle.full_like(next_tokens, pad_token_id),
            )

        input_ids = paddle.concat([input_ids, next_tokens], axis=1)

        if eos_token_id is not None:
            unfinished_flag = get_unfinished_flag(
                input_ids, unfinished_flag, eos_token_id
            )

        model_kwargs = self.update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )

        return input_ids, scores, unfinished_flag, model_kwargs

    # The first step runs unconditionally.
    outputs = _forward_(**model_kwargs)
    input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
        outputs,
        input_ids,
        cur_len_gpu,
        origin_len_gpu,
        scores,
        unfinished_flag,
        model_kwargs,
        pad_token_id,
    )

    cur_len += 1
    cur_len_gpu += 1

    attn_mask = model_kwargs["attention_mask"]
    # make the shape of attention_mask = (-1, -1, -1, -1) in dy2static.
    model_kwargs["attention_mask"] = paddle.reshape(attn_mask, attn_mask.shape)
    model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None
    # cur_len already counts the first generated token, so this bound equals
    # origin_len + max_new_tokens.
    max_new_tokens = paddle.full([1], max_new_tokens + cur_len - 1, dtype="int64")

    while cur_len < max_new_tokens and paddle.any(unfinished_flag):
        input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
            _forward_(**model_kwargs),
            input_ids,
            cur_len_gpu,
            origin_len_gpu,
            scores,
            unfinished_flag,
            model_kwargs,
            pad_token_id,
        )
        cur_len += 1
        cur_len_gpu += 1

    return input_ids[:, origin_len:], scores
|
1734
|
+
|
1735
|
+
def reorder_cache(self, cache, beam_idx):
    """Return ``cache`` with every leaf tensor gathered by ``beam_idx``.

    Each leaf of the (possibly nested) ``cache`` structure is passed to
    ``paddle.index_select`` with its default axis, and ``map_structure``
    rebuilds the same nesting around the results.
    """

    def _gather(tensor):
        return paddle.index_select(tensor, beam_idx)

    return map_structure(_gather, cache)
|
1738
|
+
|
1739
|
+
def beam_search(
    self,
    input_ids,
    beam_scorer,
    logits_processors,
    max_length,
    diversity_rate,
    pad_token_id,
    eos_token_id,
    stopping_criteria=None,
    fast_ptq_sampling=False,
    trunc_input=True,
    synced_gpus=False,
    **model_kwargs,
):
    """Generate sequences with (optionally diverse) beam search.

    Args:
        input_ids: Prompt token ids whose first dimension must equal
            ``batch_size * num_beams`` (checked by the assert below).
        beam_scorer: Scorer providing ``_beam_hyps``, ``num_beams``,
            ``is_done``, ``process()`` and ``finalize()``.
        logits_processors: Applied to the per-token log-probabilities;
            ``None`` means an empty ``LogitsProcessorList``.
        max_length: When not ``None``, folded into ``stopping_criteria``
            through ``validate_stopping_criteria``.
        diversity_rate: ``0.0`` selects plain beam search. Any other value
            penalises the k-th best token of each beam by
            ``k * diversity_rate`` before candidates are re-ranked.
        pad_token_id: Forwarded to the beam scorer.
        eos_token_id: Forwarded to the beam scorer.
        stopping_criteria: Extra stopping criteria (default: none).
        fast_ptq_sampling: If True, stop after a single decoding step.
        trunc_input: If True, strip the prompt from the returned ids.
        synced_gpus: Keep running the model until every rank has finished,
            using an ``all_reduce`` to detect global completion.
        **model_kwargs: Forwarded to ``prepare_inputs_for_generation``.
            ``use_cache`` defaults to True.

    Returns:
        ``(ids, scores)``. ``ids`` is ``pred_ids[:, origin_len:]`` when
        ``trunc_input`` is True; otherwise it is the running ``input_ids``.
        ``scores`` comes from ``beam_scorer.finalize``.
    """
    model_kwargs["use_cache"] = model_kwargs.get("use_cache", True)

    logits_processors = (
        logits_processors
        if logits_processors is not None
        else LogitsProcessorList()
    )

    # max_length will be converted to a MaxLengthCriteria
    stopping_criteria = (
        stopping_criteria
        if stopping_criteria is not None
        else StoppingCriteriaList()
    )
    if max_length is not None:
        stopping_criteria = validate_stopping_criteria(
            stopping_criteria, max_length
        )

    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams
    batch_beam_size, cur_len = input_ids.shape
    origin_len = cur_len

    assert (
        num_beams * batch_size == batch_beam_size
    ), "Batch dimension of `input_ids` should be {}, but received {}.".format(
        num_beams * batch_size, batch_beam_size
    )

    # Only the first beam of every example starts at score 0. The others get
    # get_scale_by_dtype(return_positive=False), presumably a large negative
    # value, so the identical initial beams are not all expanded at step one.
    beam_scores = paddle.zeros(
        (batch_size, num_beams), dtype=paddle.get_default_dtype()
    )
    beam_scores[:, 1:] = get_scale_by_dtype(return_positive=False)
    beam_scores = paddle.reshape(beam_scores, [-1])

    generate_end = False
    while True:
        if synced_gpus:
            # Under synced_gpus the forward call must continue until all ranks
            # have finished. Each rank sends 0.0 if it is done and 1.0
            # otherwise; a reduced sum of 0.0 means every rank is done.
            this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            if this_peer_finished_flag.item() == 0.0:
                break

        # prepare model inputs & get model output
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = self(**model_inputs)
        if synced_gpus and generate_end:
            # This rank is done but keeps stepping the model for its peers.
            cur_len = cur_len + 1
            continue

        if isinstance(outputs, tuple):
            logits = outputs[0]
        elif isinstance(outputs, ModelOutput):
            logits = outputs.logits
        else:
            logits = outputs

        # Logits of the last position, one row per beam:
        # [batch_size * num_beams, vocab_size]
        logits = logits[:, -1, :]

        # pre-process distribution
        logits = self.adjust_logits_during_generation(logits)

        # log_softmax computes the log-probabilities in one numerically stable
        # step. log(softmax(x)) underflows to -inf for unlikely tokens,
        # especially in half precision.
        next_scores = F.log_softmax(logits)
        next_scores = logits_processors(input_ids, next_scores)
        next_scores = next_scores + beam_scores.unsqueeze(-1)

        vocab_size = next_scores.shape[-1]
        if diversity_rate == 0.0:
            # Plain beam search: rank all num_beams * vocab_size candidates of
            # an example together and keep the best 2 * num_beams.
            next_scores = next_scores.reshape([batch_size, num_beams * vocab_size])
            next_scores, next_tokens = paddle.topk(
                next_scores, 2 * num_beams, axis=1
            )
            next_indices = next_tokens // vocab_size
            next_tokens = next_tokens % vocab_size
        else:
            # Diverse beam search:
            #   1. Take the 2 * num_beams best tokens of every beam.
            #   2. Penalise the k-th of them by k * diversity_rate.
            #   3. Keep the best 2 * num_beams candidates per example, ranked
            #      by the penalised score but reported with the original score.
            next_scores, next_tokens = paddle.topk(
                next_scores, 2 * num_beams, axis=1
            )

            sibling_score = (
                paddle.arange(1, 2 * num_beams + 1, dtype="int64").unsqueeze(0)
                * diversity_rate
            )
            diversed_score = next_scores - sibling_score

            candidates_per_example = 2 * num_beams * num_beams
            next_scores = next_scores.reshape([batch_size, candidates_per_example])
            next_tokens = next_tokens.reshape([batch_size, candidates_per_example])
            diversed_score = diversed_score.reshape(
                [batch_size, candidates_per_example]
            )
            _, diversed_tokens = paddle.topk(diversed_score, 2 * num_beams, axis=1)

            # Gather the original score/token of every selected candidate with
            # one batched op instead of a per-example Python loop.
            next_scores = paddle.take_along_axis(next_scores, diversed_tokens, axis=1)
            next_tokens = paddle.take_along_axis(next_tokens, diversed_tokens, axis=1)

            # Each beam contributed 2 * num_beams consecutive candidates.
            next_indices = diversed_tokens // (2 * num_beams)

        # stateless
        beam_outputs = beam_scorer.process(
            input_ids,
            next_scores,
            next_tokens,
            next_indices,
            origin_len=origin_len,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )
        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]
        # beam_idx may contain element -1 and cause an error, so clamp to 0.
        # PR: https://github.com/PaddlePaddle/Paddle/issues/57366
        beam_idx = paddle.maximum(beam_idx, paddle.full_like(beam_idx, 0))

        cur_len += 1
        input_ids = paddle.concat(
            [
                paddle.index_select(input_ids, beam_idx),
                beam_next_tokens.unsqueeze(-1),
            ],
            axis=-1,
        )

        if beam_scorer.is_done or stopping_criteria(input_ids, beam_scores):
            if not synced_gpus:
                break
            generate_end = True

        model_kwargs = self.update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder
        )
        # Reorder whichever cache the model uses so it follows the chosen beams.
        for cache_key in ("cache", "past_key_values"):
            if cache_key in model_kwargs:
                model_kwargs[cache_key] = self.reorder_cache(
                    model_kwargs[cache_key], beam_idx
                )
        if fast_ptq_sampling:
            break

    pred_ids, scores = beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        origin_len=origin_len,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
    )
    # NOTE(review): with trunc_input=False this returns the running,
    # beam-expanded input_ids rather than the finalized pred_ids, while
    # `scores` comes from finalize alongside pred_ids. This is kept as-is for
    # backward compatibility; confirm whether `pred_ids` was intended.
    return pred_ids[:, origin_len:] if trunc_input else input_ids, scores
|
1946
|
+
|
1947
|
+
def group_beam_search(
    self,
    input_ids,
    beam_scorer,
    logits_processors,
    max_length,
    pad_token_id,
    eos_token_id,
    stopping_criteria=None,
    fast_ptq_sampling=False,
    trunc_input=True,
    synced_gpus=False,
    **model_kwargs,
):
    """Generate sequences with group beam search.

    The ``num_beams`` beams of every example are split into
    ``beam_scorer.num_beam_groups`` groups of ``num_beams // num_beam_groups``
    beams (the last group is clipped at ``num_beams``). Within a step the
    groups are processed in order. ``logits_processors`` is called with
    ``current_tokens``, which holds the tokens already chosen this step by
    earlier groups (zeros elsewhere), and with ``beam_group_idx``.

    Args:
        input_ids: Prompt token ids whose first dimension must equal
            ``batch_size * num_beams`` (checked by the assert below).
        beam_scorer: Scorer providing ``_beam_hyps``, ``num_beams``,
            ``num_beam_groups``, ``is_done``, ``process()`` and ``finalize()``.
        logits_processors: Applied to each group's log-probabilities;
            ``None`` means an empty ``LogitsProcessorList``.
        max_length: When not ``None``, folded into ``stopping_criteria``
            through ``validate_stopping_criteria``.
        pad_token_id: Forwarded to the beam scorer.
        eos_token_id: Forwarded to the beam scorer.
        stopping_criteria: Extra stopping criteria (default: none).
        fast_ptq_sampling: If True, stop after a single decoding step.
        trunc_input: If True, strip the prompt from the returned ids.
        synced_gpus: Keep running the model until every rank has finished,
            using an ``all_reduce`` to detect global completion.
        **model_kwargs: Forwarded to ``prepare_inputs_for_generation``.
            ``use_cache`` defaults to True.

    Returns:
        ``(ids, scores)``. ``ids`` is ``pred_ids[:, origin_len:]`` when
        ``trunc_input`` is True; otherwise it is the running ``input_ids``.
        ``scores`` comes from ``beam_scorer.finalize``.
    """
    model_kwargs["use_cache"] = model_kwargs.get("use_cache", True)
    logits_processors = (
        logits_processors
        if logits_processors is not None
        else LogitsProcessorList()
    )

    # max_length will be converted to a MaxLengthCriteria
    stopping_criteria = (
        stopping_criteria
        if stopping_criteria is not None
        else StoppingCriteriaList()
    )
    if max_length is not None:
        stopping_criteria = validate_stopping_criteria(
            stopping_criteria, max_length
        )

    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams
    num_beam_groups = beam_scorer.num_beam_groups
    num_sub_beams = num_beams // num_beam_groups

    batch_beam_size, cur_len = input_ids.shape
    origin_len = cur_len

    assert (
        num_beams * batch_size == batch_beam_size
    ), "Batch dimension of `input_ids` should be {}, but received {}.".format(
        num_beams * batch_size, batch_beam_size
    )

    # Initialise the score of the first beam of each group with 0 and the rest
    # with get_scale_by_dtype(return_positive=False), presumably a large
    # negative value. This keeps the beams of one group from producing the
    # same tokens every time.
    beam_scores = paddle.full(
        (batch_size, num_beams),
        get_scale_by_dtype(return_positive=False),
        dtype="float32",
    )
    beam_scores[:, ::num_sub_beams] = 0
    beam_scores = paddle.reshape(beam_scores, [-1])

    generate_end = False
    while True:
        if synced_gpus:
            # Under synced_gpus the forward call must continue until all ranks
            # have finished. Each rank sends 0.0 if it is done and 1.0
            # otherwise; a reduced sum of 0.0 means every rank is done.
            this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            if this_peer_finished_flag.item() == 0.0:
                break

        # tokens predicted at this step, filled in group by group
        current_tokens = paddle.zeros(
            shape=[batch_size * num_beams], dtype=input_ids.dtype
        )
        # indices which will form the beams in the next time step
        reordering_indices = paddle.zeros(
            shape=[batch_size * num_beams], dtype="int64"
        )

        # prepare model inputs & get model output
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = self(**model_inputs)
        if synced_gpus and generate_end:
            # This rank is done but keeps stepping the model for its peers.
            cur_len = cur_len + 1
            continue

        # The model output is the same for every group, so extract the
        # last-position logits once per step instead of once per group.
        if isinstance(outputs, tuple):
            step_logits = outputs[0]
        elif isinstance(outputs, ModelOutput):
            step_logits = outputs.logits
        else:
            step_logits = outputs
        step_logits = step_logits[:, -1, :]

        for beam_group_idx in range(num_beam_groups):
            group_start_idx = beam_group_idx * num_sub_beams
            group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
            group_size = group_end_idx - group_start_idx

            # Rows of this group's beams for every example, ordered
            # example-major with group_size rows per example.
            batch_group_indices = [
                batch_idx * num_beams + idx
                for batch_idx in range(batch_size)
                for idx in range(group_start_idx, group_end_idx)
            ]

            group_input_ids = input_ids[batch_group_indices]

            logits = paddle.index_select(
                step_logits, paddle.to_tensor(batch_group_indices)
            )
            logits = self.adjust_logits_during_generation(logits)

            # Numerically stable replacement for log(softmax(x)).
            next_scores = F.log_softmax(logits)
            vocab_size = next_scores.shape[-1]

            next_scores = logits_processors(
                group_input_ids,
                next_scores,
                current_tokens=current_tokens,
                beam_group_idx=beam_group_idx,
            )

            next_scores = next_scores + beam_scores[batch_group_indices].unsqueeze(
                -1
            )

            # reshape for beam search: rank all candidates of an example's group
            next_scores = next_scores.reshape([batch_size, group_size * vocab_size])
            next_scores, next_tokens = paddle.topk(
                next_scores, 2 * group_size, axis=1
            )
            next_indices = next_tokens // vocab_size
            next_tokens = next_tokens % vocab_size

            beam_outputs = beam_scorer.process(
                group_input_ids,
                next_scores,
                next_tokens,
                next_indices,
                origin_len=origin_len,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )

            beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]
            # beam_idx may contain element -1 and cause an error, so clamp to 0.
            # PR: https://github.com/PaddlePaddle/Paddle/issues/57366
            beam_idx = paddle.maximum(beam_idx, paddle.full_like(beam_idx, 0))

            input_ids[batch_group_indices] = group_input_ids[beam_idx]
            current_tokens[batch_group_indices] = beam_next_tokens

            # Map the group-local row (example * group_size + position) back
            # to the global beam row (example * num_beams + group offset +
            # position).
            reordering_indices[batch_group_indices] = (
                num_beams * (beam_idx // group_size)
                + group_start_idx
                + (beam_idx % group_size)
            )

        input_ids = paddle.concat(
            [input_ids, current_tokens.unsqueeze(-1)], axis=-1
        )

        cur_len += 1

        if beam_scorer.is_done or stopping_criteria(input_ids, beam_scores):
            if not synced_gpus:
                break
            generate_end = True

        model_kwargs = self.update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder
        )

        # Reorder whichever cache the model uses so it follows the chosen beams.
        for cache_key in ("cache", "past_key_values"):
            if cache_key in model_kwargs:
                model_kwargs[cache_key] = self.reorder_cache(
                    model_kwargs[cache_key], reordering_indices
                )

        if fast_ptq_sampling:
            break

    pred_ids, scores = beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        origin_len=origin_len,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
    )
    # NOTE(review): with trunc_input=False this returns the running,
    # beam-expanded input_ids rather than the finalized pred_ids, while
    # `scores` comes from finalize alongside pred_ids. This is kept as-is for
    # backward compatibility; confirm whether `pred_ids` was intended.
    return pred_ids[:, origin_len:] if trunc_input else input_ids, scores
|