paddlex 2.0.0rc4__py3-none-any.whl → 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- paddlex/.version +1 -0
- paddlex/__init__.py +35 -18
- paddlex/__main__.py +39 -0
- paddlex/configs/modules/3d_bev_detection/BEVFusion.yaml +38 -0
- paddlex/configs/modules/chart_parsing/PP-Chart2Table.yaml +13 -0
- paddlex/configs/modules/doc_text_orientation/PP-LCNet_x1_0_doc_ori.yaml +41 -0
- paddlex/configs/modules/doc_vlm/PP-DocBee-2B.yaml +14 -0
- paddlex/configs/modules/doc_vlm/PP-DocBee-7B.yaml +14 -0
- paddlex/configs/modules/doc_vlm/PP-DocBee2-3B.yaml +14 -0
- paddlex/configs/modules/face_detection/BlazeFace-FPN-SSH.yaml +40 -0
- paddlex/configs/modules/face_detection/BlazeFace.yaml +40 -0
- paddlex/configs/modules/face_detection/PP-YOLOE_plus-S_face.yaml +40 -0
- paddlex/configs/modules/face_detection/PicoDet_LCNet_x2_5_face.yaml +40 -0
- paddlex/configs/modules/face_feature/MobileFaceNet.yaml +41 -0
- paddlex/configs/modules/face_feature/ResNet50_face.yaml +41 -0
- paddlex/configs/modules/formula_recognition/LaTeX_OCR_rec.yaml +40 -0
- paddlex/configs/modules/formula_recognition/PP-FormulaNet-L.yaml +40 -0
- paddlex/configs/modules/formula_recognition/PP-FormulaNet-S.yaml +40 -0
- paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-L.yaml +40 -0
- paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-M.yaml +40 -0
- paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-S.yaml +40 -0
- paddlex/configs/modules/formula_recognition/UniMERNet.yaml +40 -0
- paddlex/configs/modules/human_detection/PP-YOLOE-L_human.yaml +42 -0
- paddlex/configs/modules/human_detection/PP-YOLOE-S_human.yaml +42 -0
- paddlex/configs/modules/image_anomaly_detection/STFPM.yaml +41 -0
- paddlex/configs/modules/image_classification/CLIP_vit_base_patch16_224.yaml +41 -0
- paddlex/configs/modules/image_classification/CLIP_vit_large_patch14_224.yaml +41 -0
- paddlex/configs/modules/image_classification/ConvNeXt_base_224.yaml +41 -0
- paddlex/configs/modules/image_classification/ConvNeXt_base_384.yaml +41 -0
- paddlex/configs/modules/image_classification/ConvNeXt_large_224.yaml +41 -0
- paddlex/configs/modules/image_classification/ConvNeXt_large_384.yaml +41 -0
- paddlex/configs/modules/image_classification/ConvNeXt_small.yaml +41 -0
- paddlex/configs/modules/image_classification/ConvNeXt_tiny.yaml +41 -0
- paddlex/configs/modules/image_classification/FasterNet-L.yaml +40 -0
- paddlex/configs/modules/image_classification/FasterNet-M.yaml +40 -0
- paddlex/configs/modules/image_classification/FasterNet-S.yaml +40 -0
- paddlex/configs/modules/image_classification/FasterNet-T0.yaml +40 -0
- paddlex/configs/modules/image_classification/FasterNet-T1.yaml +40 -0
- paddlex/configs/modules/image_classification/FasterNet-T2.yaml +40 -0
- paddlex/configs/modules/image_classification/MobileNetV1_x0_25.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV1_x0_5.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV1_x0_75.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV1_x1_0.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV2_x0_25.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV2_x0_5.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV2_x1_0.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV2_x1_5.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV2_x2_0.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV3_large_x0_35.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV3_large_x0_5.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV3_large_x0_75.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV3_large_x1_0.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV3_large_x1_25.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV3_small_x0_35.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV3_small_x0_5.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV3_small_x0_75.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV3_small_x1_0.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV3_small_x1_25.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV4_conv_large.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV4_conv_medium.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV4_conv_small.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV4_hybrid_large.yaml +41 -0
- paddlex/configs/modules/image_classification/MobileNetV4_hybrid_medium.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-HGNetV2-B0.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-HGNetV2-B1.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-HGNetV2-B2.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-HGNetV2-B3.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-HGNetV2-B4.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-HGNetV2-B5.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-HGNetV2-B6.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-HGNet_base.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-HGNet_small.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-HGNet_tiny.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-LCNetV2_base.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-LCNetV2_large.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-LCNetV2_small.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-LCNet_x0_25.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-LCNet_x0_35.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-LCNet_x0_5.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-LCNet_x0_75.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-LCNet_x1_0.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-LCNet_x1_5.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-LCNet_x2_0.yaml +41 -0
- paddlex/configs/modules/image_classification/PP-LCNet_x2_5.yaml +41 -0
- paddlex/configs/modules/image_classification/ResNet101.yaml +41 -0
- paddlex/configs/modules/image_classification/ResNet101_vd.yaml +41 -0
- paddlex/configs/modules/image_classification/ResNet152.yaml +41 -0
- paddlex/configs/modules/image_classification/ResNet152_vd.yaml +41 -0
- paddlex/configs/modules/image_classification/ResNet18.yaml +41 -0
- paddlex/configs/modules/image_classification/ResNet18_vd.yaml +41 -0
- paddlex/configs/modules/image_classification/ResNet200_vd.yaml +41 -0
- paddlex/configs/modules/image_classification/ResNet34.yaml +41 -0
- paddlex/configs/modules/image_classification/ResNet34_vd.yaml +41 -0
- paddlex/configs/modules/image_classification/ResNet50.yaml +41 -0
- paddlex/configs/modules/image_classification/ResNet50_vd.yaml +41 -0
- paddlex/configs/modules/image_classification/StarNet-S1.yaml +41 -0
- paddlex/configs/modules/image_classification/StarNet-S2.yaml +41 -0
- paddlex/configs/modules/image_classification/StarNet-S3.yaml +41 -0
- paddlex/configs/modules/image_classification/StarNet-S4.yaml +41 -0
- paddlex/configs/modules/image_classification/SwinTransformer_base_patch4_window12_384.yaml +41 -0
- paddlex/configs/modules/image_classification/SwinTransformer_base_patch4_window7_224.yaml +41 -0
- paddlex/configs/modules/image_classification/SwinTransformer_large_patch4_window12_384.yaml +41 -0
- paddlex/configs/modules/image_classification/SwinTransformer_large_patch4_window7_224.yaml +41 -0
- paddlex/configs/modules/image_classification/SwinTransformer_small_patch4_window7_224.yaml +41 -0
- paddlex/configs/modules/image_classification/SwinTransformer_tiny_patch4_window7_224.yaml +41 -0
- paddlex/configs/modules/image_feature/PP-ShiTuV2_rec.yaml +42 -0
- paddlex/configs/modules/image_feature/PP-ShiTuV2_rec_CLIP_vit_base.yaml +42 -0
- paddlex/configs/modules/image_feature/PP-ShiTuV2_rec_CLIP_vit_large.yaml +41 -0
- paddlex/configs/modules/image_multilabel_classification/CLIP_vit_base_patch16_448_ML.yaml +41 -0
- paddlex/configs/modules/image_multilabel_classification/PP-HGNetV2-B0_ML.yaml +41 -0
- paddlex/configs/modules/image_multilabel_classification/PP-HGNetV2-B4_ML.yaml +41 -0
- paddlex/configs/modules/image_multilabel_classification/PP-HGNetV2-B6_ML.yaml +41 -0
- paddlex/configs/modules/image_multilabel_classification/PP-LCNet_x1_0_ML.yaml +41 -0
- paddlex/configs/modules/image_multilabel_classification/ResNet50_ML.yaml +41 -0
- paddlex/configs/modules/image_unwarping/UVDoc.yaml +12 -0
- paddlex/configs/modules/instance_segmentation/Cascade-MaskRCNN-ResNet50-FPN.yaml +40 -0
- paddlex/configs/modules/instance_segmentation/Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN.yaml +40 -0
- paddlex/configs/modules/instance_segmentation/Mask-RT-DETR-H.yaml +40 -0
- paddlex/configs/modules/instance_segmentation/Mask-RT-DETR-L.yaml +40 -0
- paddlex/configs/modules/instance_segmentation/Mask-RT-DETR-M.yaml +40 -0
- paddlex/configs/modules/instance_segmentation/Mask-RT-DETR-S.yaml +40 -0
- paddlex/configs/modules/instance_segmentation/Mask-RT-DETR-X.yaml +40 -0
- paddlex/configs/modules/instance_segmentation/MaskRCNN-ResNeXt101-vd-FPN.yaml +39 -0
- paddlex/configs/modules/instance_segmentation/MaskRCNN-ResNet101-FPN.yaml +40 -0
- paddlex/configs/modules/instance_segmentation/MaskRCNN-ResNet101-vd-FPN.yaml +40 -0
- paddlex/configs/modules/instance_segmentation/MaskRCNN-ResNet50-FPN.yaml +40 -0
- paddlex/configs/modules/instance_segmentation/MaskRCNN-ResNet50-vd-FPN.yaml +40 -0
- paddlex/configs/modules/instance_segmentation/MaskRCNN-ResNet50.yaml +40 -0
- paddlex/configs/modules/instance_segmentation/PP-YOLOE_seg-S.yaml +40 -0
- paddlex/configs/modules/instance_segmentation/SOLOv2.yaml +40 -0
- paddlex/configs/modules/keypoint_detection/PP-TinyPose_128x96.yaml +40 -0
- paddlex/configs/modules/keypoint_detection/PP-TinyPose_256x192.yaml +40 -0
- paddlex/configs/modules/layout_detection/PP-DocBlockLayout.yaml +40 -0
- paddlex/configs/modules/layout_detection/PP-DocLayout-L.yaml +40 -0
- paddlex/configs/modules/layout_detection/PP-DocLayout-M.yaml +40 -0
- paddlex/configs/modules/layout_detection/PP-DocLayout-S.yaml +40 -0
- paddlex/configs/modules/layout_detection/PP-DocLayout_plus-L.yaml +40 -0
- paddlex/configs/modules/layout_detection/PicoDet-L_layout_17cls.yaml +40 -0
- paddlex/configs/modules/layout_detection/PicoDet-L_layout_3cls.yaml +40 -0
- paddlex/configs/modules/layout_detection/PicoDet-S_layout_17cls.yaml +40 -0
- paddlex/configs/modules/layout_detection/PicoDet-S_layout_3cls.yaml +40 -0
- paddlex/configs/modules/layout_detection/PicoDet_layout_1x.yaml +40 -0
- paddlex/configs/modules/layout_detection/PicoDet_layout_1x_table.yaml +40 -0
- paddlex/configs/modules/layout_detection/RT-DETR-H_layout_17cls.yaml +40 -0
- paddlex/configs/modules/layout_detection/RT-DETR-H_layout_3cls.yaml +40 -0
- paddlex/configs/modules/mainbody_detection/PP-ShiTuV2_det.yaml +41 -0
- paddlex/configs/modules/multilingual_speech_recognition/whisper_base.yaml +12 -0
- paddlex/configs/modules/multilingual_speech_recognition/whisper_large.yaml +12 -0
- paddlex/configs/modules/multilingual_speech_recognition/whisper_medium.yaml +12 -0
- paddlex/configs/modules/multilingual_speech_recognition/whisper_small.yaml +12 -0
- paddlex/configs/modules/multilingual_speech_recognition/whisper_tiny.yaml +12 -0
- paddlex/configs/modules/object_detection/Cascade-FasterRCNN-ResNet50-FPN.yaml +41 -0
- paddlex/configs/modules/object_detection/Cascade-FasterRCNN-ResNet50-vd-SSLDv2-FPN.yaml +42 -0
- paddlex/configs/modules/object_detection/CenterNet-DLA-34.yaml +41 -0
- paddlex/configs/modules/object_detection/CenterNet-ResNet50.yaml +41 -0
- paddlex/configs/modules/object_detection/Co-DINO-R50.yaml +40 -0
- paddlex/configs/modules/object_detection/Co-DINO-Swin-L.yaml +40 -0
- paddlex/configs/modules/object_detection/Co-Deformable-DETR-R50.yaml +40 -0
- paddlex/configs/modules/object_detection/Co-Deformable-DETR-Swin-T.yaml +40 -0
- paddlex/configs/modules/object_detection/DETR-R50.yaml +42 -0
- paddlex/configs/modules/object_detection/FCOS-ResNet50.yaml +41 -0
- paddlex/configs/modules/object_detection/FasterRCNN-ResNeXt101-vd-FPN.yaml +42 -0
- paddlex/configs/modules/object_detection/FasterRCNN-ResNet101-FPN.yaml +42 -0
- paddlex/configs/modules/object_detection/FasterRCNN-ResNet101.yaml +42 -0
- paddlex/configs/modules/object_detection/FasterRCNN-ResNet34-FPN.yaml +42 -0
- paddlex/configs/modules/object_detection/FasterRCNN-ResNet50-FPN.yaml +42 -0
- paddlex/configs/modules/object_detection/FasterRCNN-ResNet50-vd-FPN.yaml +42 -0
- paddlex/configs/modules/object_detection/FasterRCNN-ResNet50-vd-SSLDv2-FPN.yaml +42 -0
- paddlex/configs/modules/object_detection/FasterRCNN-ResNet50.yaml +42 -0
- paddlex/configs/modules/object_detection/FasterRCNN-Swin-Tiny-FPN.yaml +42 -0
- paddlex/configs/modules/object_detection/PP-YOLOE_plus-L.yaml +40 -0
- paddlex/configs/modules/object_detection/PP-YOLOE_plus-M.yaml +40 -0
- paddlex/configs/modules/object_detection/PP-YOLOE_plus-S.yaml +40 -0
- paddlex/configs/modules/object_detection/PP-YOLOE_plus-X.yaml +40 -0
- paddlex/configs/modules/object_detection/PicoDet-L.yaml +40 -0
- paddlex/configs/modules/object_detection/PicoDet-M.yaml +42 -0
- paddlex/configs/modules/object_detection/PicoDet-S.yaml +40 -0
- paddlex/configs/modules/object_detection/PicoDet-XS.yaml +42 -0
- paddlex/configs/modules/object_detection/RT-DETR-H.yaml +40 -0
- paddlex/configs/modules/object_detection/RT-DETR-L.yaml +40 -0
- paddlex/configs/modules/object_detection/RT-DETR-R18.yaml +40 -0
- paddlex/configs/modules/object_detection/RT-DETR-R50.yaml +40 -0
- paddlex/configs/modules/object_detection/RT-DETR-X.yaml +40 -0
- paddlex/configs/modules/object_detection/YOLOX-L.yaml +40 -0
- paddlex/configs/modules/object_detection/YOLOX-M.yaml +40 -0
- paddlex/configs/modules/object_detection/YOLOX-N.yaml +40 -0
- paddlex/configs/modules/object_detection/YOLOX-S.yaml +40 -0
- paddlex/configs/modules/object_detection/YOLOX-T.yaml +40 -0
- paddlex/configs/modules/object_detection/YOLOX-X.yaml +40 -0
- paddlex/configs/modules/object_detection/YOLOv3-DarkNet53.yaml +40 -0
- paddlex/configs/modules/object_detection/YOLOv3-MobileNetV3.yaml +40 -0
- paddlex/configs/modules/object_detection/YOLOv3-ResNet50_vd_DCN.yaml +40 -0
- paddlex/configs/modules/open_vocabulary_detection/GroundingDINO-T.yaml +13 -0
- paddlex/configs/modules/open_vocabulary_detection/YOLO-Worldv2-L.yaml +13 -0
- paddlex/configs/modules/open_vocabulary_segmentation/SAM-H_box.yaml +17 -0
- paddlex/configs/modules/open_vocabulary_segmentation/SAM-H_point.yaml +15 -0
- paddlex/configs/modules/pedestrian_attribute_recognition/PP-LCNet_x1_0_pedestrian_attribute.yaml +41 -0
- paddlex/configs/modules/rotated_object_detection/PP-YOLOE-R-L.yaml +40 -0
- paddlex/configs/modules/seal_text_detection/PP-OCRv4_mobile_seal_det.yaml +40 -0
- paddlex/configs/modules/seal_text_detection/PP-OCRv4_server_seal_det.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/Deeplabv3-R101.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/Deeplabv3-R50.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/Deeplabv3_Plus-R101.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/Deeplabv3_Plus-R50.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/MaskFormer_small.yaml +42 -0
- paddlex/configs/modules/semantic_segmentation/MaskFormer_tiny.yaml +42 -0
- paddlex/configs/modules/semantic_segmentation/OCRNet_HRNet-W18.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/OCRNet_HRNet-W48.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/PP-LiteSeg-B.yaml +41 -0
- paddlex/configs/modules/semantic_segmentation/PP-LiteSeg-T.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/SeaFormer_base.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/SeaFormer_large.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/SeaFormer_small.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/SeaFormer_tiny.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/SegFormer-B0.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/SegFormer-B1.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/SegFormer-B2.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/SegFormer-B3.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/SegFormer-B4.yaml +40 -0
- paddlex/configs/modules/semantic_segmentation/SegFormer-B5.yaml +40 -0
- paddlex/configs/modules/small_object_detection/PP-YOLOE_plus_SOD-L.yaml +42 -0
- paddlex/configs/modules/small_object_detection/PP-YOLOE_plus_SOD-S.yaml +42 -0
- paddlex/configs/modules/small_object_detection/PP-YOLOE_plus_SOD-largesize-L.yaml +42 -0
- paddlex/configs/modules/table_cells_detection/RT-DETR-L_wired_table_cell_det.yaml +40 -0
- paddlex/configs/modules/table_cells_detection/RT-DETR-L_wireless_table_cell_det.yaml +40 -0
- paddlex/configs/modules/table_classification/PP-LCNet_x1_0_table_cls.yaml +41 -0
- paddlex/configs/modules/table_structure_recognition/SLANeXt_wired.yaml +39 -0
- paddlex/configs/modules/table_structure_recognition/SLANeXt_wireless.yaml +39 -0
- paddlex/configs/modules/table_structure_recognition/SLANet.yaml +39 -0
- paddlex/configs/modules/table_structure_recognition/SLANet_plus.yaml +39 -0
- paddlex/configs/modules/text_detection/PP-OCRv3_mobile_det.yaml +40 -0
- paddlex/configs/modules/text_detection/PP-OCRv3_server_det.yaml +40 -0
- paddlex/configs/modules/text_detection/PP-OCRv4_mobile_det.yaml +40 -0
- paddlex/configs/modules/text_detection/PP-OCRv4_server_det.yaml +40 -0
- paddlex/configs/modules/text_detection/PP-OCRv5_mobile_det.yaml +40 -0
- paddlex/configs/modules/text_detection/PP-OCRv5_server_det.yaml +40 -0
- paddlex/configs/modules/text_recognition/PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/PP-OCRv4_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/PP-OCRv4_server_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/PP-OCRv4_server_rec_doc.yaml +39 -0
- paddlex/configs/modules/text_recognition/PP-OCRv5_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/PP-OCRv5_server_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/arabic_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/ch_RepSVTR_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/ch_SVTRv2_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/chinese_cht_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/cyrillic_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/devanagari_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/en_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/en_PP-OCRv4_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/japan_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/ka_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/korean_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/latin_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/ta_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/text_recognition/te_PP-OCRv3_mobile_rec.yaml +39 -0
- paddlex/configs/modules/textline_orientation/PP-LCNet_x0_25_textline_ori.yaml +41 -0
- paddlex/configs/modules/ts_anomaly_detection/AutoEncoder_ad.yaml +37 -0
- paddlex/configs/modules/ts_anomaly_detection/DLinear_ad.yaml +37 -0
- paddlex/configs/modules/ts_anomaly_detection/Nonstationary_ad.yaml +37 -0
- paddlex/configs/modules/ts_anomaly_detection/PatchTST_ad.yaml +37 -0
- paddlex/configs/modules/ts_anomaly_detection/TimesNet_ad.yaml +37 -0
- paddlex/configs/modules/ts_classification/TimesNet_cls.yaml +37 -0
- paddlex/configs/modules/ts_forecast/DLinear.yaml +38 -0
- paddlex/configs/modules/ts_forecast/NLinear.yaml +38 -0
- paddlex/configs/modules/ts_forecast/Nonstationary.yaml +38 -0
- paddlex/configs/modules/ts_forecast/PatchTST.yaml +38 -0
- paddlex/configs/modules/ts_forecast/RLinear.yaml +38 -0
- paddlex/configs/modules/ts_forecast/TiDE.yaml +38 -0
- paddlex/configs/modules/ts_forecast/TimesNet.yaml +38 -0
- paddlex/configs/modules/vehicle_attribute_recognition/PP-LCNet_x1_0_vehicle_attribute.yaml +41 -0
- paddlex/configs/modules/vehicle_detection/PP-YOLOE-L_vehicle.yaml +41 -0
- paddlex/configs/modules/vehicle_detection/PP-YOLOE-S_vehicle.yaml +42 -0
- paddlex/configs/modules/video_classification/PP-TSM-R50_8frames_uniform.yaml +42 -0
- paddlex/configs/modules/video_classification/PP-TSMv2-LCNetV2_16frames_uniform.yaml +42 -0
- paddlex/configs/modules/video_classification/PP-TSMv2-LCNetV2_8frames_uniform.yaml +42 -0
- paddlex/configs/modules/video_detection/YOWO.yaml +40 -0
- paddlex/configs/pipelines/3d_bev_detection.yaml +9 -0
- paddlex/configs/pipelines/OCR.yaml +45 -0
- paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml +151 -0
- paddlex/configs/pipelines/PP-ChatOCRv4-doc.yaml +237 -0
- paddlex/configs/pipelines/PP-ShiTuV2.yaml +18 -0
- paddlex/configs/pipelines/PP-StructureV3.yaml +226 -0
- paddlex/configs/pipelines/anomaly_detection.yaml +8 -0
- paddlex/configs/pipelines/doc_preprocessor.yaml +15 -0
- paddlex/configs/pipelines/doc_understanding.yaml +9 -0
- paddlex/configs/pipelines/face_recognition.yaml +18 -0
- paddlex/configs/pipelines/formula_recognition.yaml +39 -0
- paddlex/configs/pipelines/human_keypoint_detection.yaml +17 -0
- paddlex/configs/pipelines/image_classification.yaml +10 -0
- paddlex/configs/pipelines/image_multilabel_classification.yaml +9 -0
- paddlex/configs/pipelines/instance_segmentation.yaml +10 -0
- paddlex/configs/pipelines/layout_parsing.yaml +102 -0
- paddlex/configs/pipelines/multilingual_speech_recognition.yaml +9 -0
- paddlex/configs/pipelines/object_detection.yaml +10 -0
- paddlex/configs/pipelines/open_vocabulary_detection.yaml +12 -0
- paddlex/configs/pipelines/open_vocabulary_segmentation.yaml +13 -0
- paddlex/configs/pipelines/pedestrian_attribute_recognition.yaml +15 -0
- paddlex/configs/pipelines/rotated_object_detection.yaml +10 -0
- paddlex/configs/pipelines/seal_recognition.yaml +52 -0
- paddlex/configs/pipelines/semantic_segmentation.yaml +10 -0
- paddlex/configs/pipelines/small_object_detection.yaml +10 -0
- paddlex/configs/pipelines/table_recognition.yaml +57 -0
- paddlex/configs/pipelines/table_recognition_v2.yaml +82 -0
- paddlex/configs/pipelines/ts_anomaly_detection.yaml +8 -0
- paddlex/configs/pipelines/ts_classification.yaml +8 -0
- paddlex/configs/pipelines/ts_forecast.yaml +8 -0
- paddlex/configs/pipelines/vehicle_attribute_recognition.yaml +15 -0
- paddlex/configs/pipelines/video_classification.yaml +9 -0
- paddlex/configs/pipelines/video_detection.yaml +10 -0
- paddlex/constants.py +17 -0
- paddlex/engine.py +56 -0
- paddlex/hpip_links.html +31 -0
- paddlex/inference/__init__.py +19 -0
- paddlex/inference/common/__init__.py +13 -0
- paddlex/inference/common/batch_sampler/__init__.py +21 -0
- paddlex/inference/common/batch_sampler/audio_batch_sampler.py +83 -0
- paddlex/inference/common/batch_sampler/base_batch_sampler.py +94 -0
- paddlex/inference/common/batch_sampler/det_3d_batch_sampler.py +144 -0
- paddlex/inference/common/batch_sampler/doc_vlm_batch_sampler.py +87 -0
- paddlex/inference/common/batch_sampler/image_batch_sampler.py +121 -0
- paddlex/inference/common/batch_sampler/ts_batch_sampler.py +109 -0
- paddlex/inference/common/batch_sampler/video_batch_sampler.py +74 -0
- paddlex/inference/common/reader/__init__.py +19 -0
- paddlex/inference/common/reader/audio_reader.py +46 -0
- paddlex/inference/common/reader/det_3d_reader.py +241 -0
- paddlex/inference/common/reader/image_reader.py +73 -0
- paddlex/inference/common/reader/ts_reader.py +46 -0
- paddlex/inference/common/reader/video_reader.py +42 -0
- paddlex/inference/common/result/__init__.py +29 -0
- paddlex/inference/common/result/base_cv_result.py +41 -0
- paddlex/inference/common/result/base_result.py +72 -0
- paddlex/inference/common/result/base_ts_result.py +41 -0
- paddlex/inference/common/result/base_video_result.py +36 -0
- paddlex/inference/common/result/mixin.py +709 -0
- paddlex/inference/models/__init__.py +86 -0
- paddlex/inference/models/anomaly_detection/__init__.py +15 -0
- paddlex/inference/models/anomaly_detection/predictor.py +135 -0
- paddlex/inference/models/anomaly_detection/processors.py +53 -0
- paddlex/inference/models/anomaly_detection/result.py +71 -0
- paddlex/inference/models/base/__init__.py +15 -0
- paddlex/inference/models/base/predictor/__init__.py +15 -0
- paddlex/inference/models/base/predictor/base_predictor.py +414 -0
- paddlex/inference/models/common/__init__.py +26 -0
- paddlex/inference/models/common/static_infer.py +801 -0
- paddlex/inference/models/common/tokenizer/__init__.py +21 -0
- paddlex/inference/models/common/tokenizer/bert_tokenizer.py +655 -0
- paddlex/inference/models/common/tokenizer/clip_tokenizer.py +609 -0
- paddlex/inference/models/common/tokenizer/gpt_tokenizer.py +453 -0
- paddlex/inference/models/common/tokenizer/qwen2_5_tokenizer.py +112 -0
- paddlex/inference/models/common/tokenizer/qwen2_tokenizer.py +438 -0
- paddlex/inference/models/common/tokenizer/qwen_tokenizer.py +288 -0
- paddlex/inference/models/common/tokenizer/tokenizer_utils.py +2149 -0
- paddlex/inference/models/common/tokenizer/tokenizer_utils_base.py +3720 -0
- paddlex/inference/models/common/tokenizer/utils.py +66 -0
- paddlex/inference/models/common/tokenizer/vocab.py +647 -0
- paddlex/inference/models/common/ts/__init__.py +15 -0
- paddlex/inference/models/common/ts/funcs.py +540 -0
- paddlex/inference/models/common/ts/processors.py +322 -0
- paddlex/inference/models/common/vision/__init__.py +23 -0
- paddlex/inference/models/common/vision/funcs.py +98 -0
- paddlex/inference/models/common/vision/processors.py +285 -0
- paddlex/inference/models/common/vlm/__init__.py +13 -0
- paddlex/inference/models/common/vlm/activations.py +189 -0
- paddlex/inference/models/common/vlm/bert_padding.py +127 -0
- paddlex/inference/models/common/vlm/conversion_utils.py +99 -0
- paddlex/inference/models/common/vlm/distributed.py +229 -0
- paddlex/inference/models/common/vlm/flash_attn_utils.py +119 -0
- paddlex/inference/models/common/vlm/fusion_ops.py +205 -0
- paddlex/inference/models/common/vlm/generation/__init__.py +34 -0
- paddlex/inference/models/common/vlm/generation/configuration_utils.py +533 -0
- paddlex/inference/models/common/vlm/generation/logits_process.py +730 -0
- paddlex/inference/models/common/vlm/generation/stopping_criteria.py +106 -0
- paddlex/inference/models/common/vlm/generation/utils.py +2162 -0
- paddlex/inference/models/common/vlm/transformers/__init__.py +16 -0
- paddlex/inference/models/common/vlm/transformers/configuration_utils.py +1037 -0
- paddlex/inference/models/common/vlm/transformers/conversion_utils.py +408 -0
- paddlex/inference/models/common/vlm/transformers/model_outputs.py +1612 -0
- paddlex/inference/models/common/vlm/transformers/model_utils.py +2014 -0
- paddlex/inference/models/common/vlm/transformers/utils.py +178 -0
- paddlex/inference/models/common/vlm/utils.py +109 -0
- paddlex/inference/models/doc_vlm/__init__.py +15 -0
- paddlex/inference/models/doc_vlm/modeling/GOT_ocr_2_0.py +830 -0
- paddlex/inference/models/doc_vlm/modeling/__init__.py +17 -0
- paddlex/inference/models/doc_vlm/modeling/qwen2.py +1606 -0
- paddlex/inference/models/doc_vlm/modeling/qwen2_5_vl.py +3006 -0
- paddlex/inference/models/doc_vlm/modeling/qwen2_vl.py +2495 -0
- paddlex/inference/models/doc_vlm/predictor.py +253 -0
- paddlex/inference/models/doc_vlm/processors/GOT_ocr_2_0.py +97 -0
- paddlex/inference/models/doc_vlm/processors/__init__.py +17 -0
- paddlex/inference/models/doc_vlm/processors/common.py +561 -0
- paddlex/inference/models/doc_vlm/processors/qwen2_5_vl.py +548 -0
- paddlex/inference/models/doc_vlm/processors/qwen2_vl.py +543 -0
- paddlex/inference/models/doc_vlm/result.py +21 -0
- paddlex/inference/models/face_feature/__init__.py +15 -0
- paddlex/inference/models/face_feature/predictor.py +66 -0
- paddlex/inference/models/formula_recognition/__init__.py +15 -0
- paddlex/inference/models/formula_recognition/predictor.py +193 -0
- paddlex/inference/models/formula_recognition/processors.py +1015 -0
- paddlex/inference/models/formula_recognition/result.py +411 -0
- paddlex/inference/models/image_classification/__init__.py +15 -0
- paddlex/inference/models/image_classification/predictor.py +172 -0
- paddlex/inference/models/image_classification/processors.py +89 -0
- paddlex/inference/models/image_classification/result.py +93 -0
- paddlex/inference/models/image_feature/__init__.py +15 -0
- paddlex/inference/models/image_feature/predictor.py +146 -0
- paddlex/inference/models/image_feature/processors.py +31 -0
- paddlex/inference/models/image_feature/result.py +32 -0
- paddlex/inference/models/image_multilabel_classification/__init__.py +15 -0
- paddlex/inference/models/image_multilabel_classification/predictor.py +95 -0
- paddlex/inference/models/image_multilabel_classification/processors.py +89 -0
- paddlex/inference/models/image_multilabel_classification/result.py +96 -0
- paddlex/inference/models/image_unwarping/__init__.py +15 -0
- paddlex/inference/models/image_unwarping/predictor.py +97 -0
- paddlex/inference/models/image_unwarping/processors.py +92 -0
- paddlex/inference/models/image_unwarping/result.py +47 -0
- paddlex/inference/models/instance_segmentation/__init__.py +15 -0
- paddlex/inference/models/instance_segmentation/predictor.py +202 -0
- paddlex/inference/models/instance_segmentation/processors.py +102 -0
- paddlex/inference/models/instance_segmentation/result.py +162 -0
- paddlex/inference/models/keypoint_detection/__init__.py +15 -0
- paddlex/inference/models/keypoint_detection/predictor.py +190 -0
- paddlex/inference/models/keypoint_detection/processors.py +367 -0
- paddlex/inference/models/keypoint_detection/result.py +197 -0
- paddlex/inference/models/m_3d_bev_detection/__init__.py +15 -0
- paddlex/inference/models/m_3d_bev_detection/predictor.py +303 -0
- paddlex/inference/models/m_3d_bev_detection/processors.py +990 -0
- paddlex/inference/models/m_3d_bev_detection/result.py +68 -0
- paddlex/inference/models/m_3d_bev_detection/visualizer_3d.py +169 -0
- paddlex/inference/models/multilingual_speech_recognition/__init__.py +15 -0
- paddlex/inference/models/multilingual_speech_recognition/predictor.py +137 -0
- paddlex/inference/models/multilingual_speech_recognition/processors.py +1933 -0
- paddlex/inference/models/multilingual_speech_recognition/result.py +21 -0
- paddlex/inference/models/object_detection/__init__.py +15 -0
- paddlex/inference/models/object_detection/predictor.py +344 -0
- paddlex/inference/models/object_detection/processors.py +885 -0
- paddlex/inference/models/object_detection/result.py +114 -0
- paddlex/inference/models/object_detection/utils.py +70 -0
- paddlex/inference/models/open_vocabulary_detection/__init__.py +15 -0
- paddlex/inference/models/open_vocabulary_detection/predictor.py +172 -0
- paddlex/inference/models/open_vocabulary_detection/processors/__init__.py +16 -0
- paddlex/inference/models/open_vocabulary_detection/processors/common.py +114 -0
- paddlex/inference/models/open_vocabulary_detection/processors/groundingdino_processors.py +496 -0
- paddlex/inference/models/open_vocabulary_detection/processors/yoloworld_processors.py +209 -0
- paddlex/inference/models/open_vocabulary_segmentation/__init__.py +15 -0
- paddlex/inference/models/open_vocabulary_segmentation/predictor.py +113 -0
- paddlex/inference/models/open_vocabulary_segmentation/processors/__init__.py +15 -0
- paddlex/inference/models/open_vocabulary_segmentation/processors/sam_processer.py +249 -0
- paddlex/inference/models/open_vocabulary_segmentation/results/__init__.py +15 -0
- paddlex/inference/models/open_vocabulary_segmentation/results/sam_result.py +149 -0
- paddlex/inference/models/semantic_segmentation/__init__.py +15 -0
- paddlex/inference/models/semantic_segmentation/predictor.py +158 -0
- paddlex/inference/models/semantic_segmentation/processors.py +117 -0
- paddlex/inference/models/semantic_segmentation/result.py +73 -0
- paddlex/inference/models/table_structure_recognition/__init__.py +15 -0
- paddlex/inference/models/table_structure_recognition/predictor.py +161 -0
- paddlex/inference/models/table_structure_recognition/processors.py +229 -0
- paddlex/inference/models/table_structure_recognition/result.py +63 -0
- paddlex/inference/models/text_detection/__init__.py +15 -0
- paddlex/inference/models/text_detection/predictor.py +191 -0
- paddlex/inference/models/text_detection/processors.py +538 -0
- paddlex/inference/models/text_detection/result.py +46 -0
- paddlex/inference/models/text_recognition/__init__.py +15 -0
- paddlex/inference/models/text_recognition/predictor.py +98 -0
- paddlex/inference/models/text_recognition/processors.py +245 -0
- paddlex/inference/models/text_recognition/result.py +76 -0
- paddlex/inference/models/ts_anomaly_detection/__init__.py +15 -0
- paddlex/inference/models/ts_anomaly_detection/predictor.py +141 -0
- paddlex/inference/models/ts_anomaly_detection/processors.py +98 -0
- paddlex/inference/models/ts_anomaly_detection/result.py +83 -0
- paddlex/inference/models/ts_classification/__init__.py +15 -0
- paddlex/inference/models/ts_classification/predictor.py +122 -0
- paddlex/inference/models/ts_classification/processors.py +122 -0
- paddlex/inference/models/ts_classification/result.py +87 -0
- paddlex/inference/models/ts_forecasting/__init__.py +15 -0
- paddlex/inference/models/ts_forecasting/predictor.py +154 -0
- paddlex/inference/models/ts_forecasting/processors.py +158 -0
- paddlex/inference/models/ts_forecasting/result.py +96 -0
- paddlex/inference/models/video_classification/__init__.py +15 -0
- paddlex/inference/models/video_classification/predictor.py +141 -0
- paddlex/inference/models/video_classification/processors.py +409 -0
- paddlex/inference/models/video_classification/result.py +96 -0
- paddlex/inference/models/video_detection/__init__.py +15 -0
- paddlex/inference/models/video_detection/predictor.py +129 -0
- paddlex/inference/models/video_detection/processors.py +463 -0
- paddlex/inference/models/video_detection/result.py +109 -0
- paddlex/inference/pipelines/__init__.py +239 -0
- paddlex/inference/pipelines/_parallel.py +172 -0
- paddlex/inference/pipelines/anomaly_detection/__init__.py +15 -0
- paddlex/inference/pipelines/anomaly_detection/pipeline.py +82 -0
- paddlex/inference/pipelines/attribute_recognition/__init__.py +15 -0
- paddlex/inference/pipelines/attribute_recognition/pipeline.py +120 -0
- paddlex/inference/pipelines/attribute_recognition/result.py +102 -0
- paddlex/inference/pipelines/base.py +156 -0
- paddlex/inference/pipelines/components/__init__.py +29 -0
- paddlex/inference/pipelines/components/chat_server/__init__.py +16 -0
- paddlex/inference/pipelines/components/chat_server/base.py +39 -0
- paddlex/inference/pipelines/components/chat_server/openai_bot_chat.py +236 -0
- paddlex/inference/pipelines/components/common/__init__.py +19 -0
- paddlex/inference/pipelines/components/common/base_operator.py +37 -0
- paddlex/inference/pipelines/components/common/base_result.py +66 -0
- paddlex/inference/pipelines/components/common/convert_points_and_boxes.py +45 -0
- paddlex/inference/pipelines/components/common/crop_image_regions.py +556 -0
- paddlex/inference/pipelines/components/common/seal_det_warp.py +972 -0
- paddlex/inference/pipelines/components/common/sort_boxes.py +85 -0
- paddlex/inference/pipelines/components/common/warp_image.py +50 -0
- paddlex/inference/pipelines/components/faisser.py +357 -0
- paddlex/inference/pipelines/components/prompt_engineering/__init__.py +16 -0
- paddlex/inference/pipelines/components/prompt_engineering/base.py +35 -0
- paddlex/inference/pipelines/components/prompt_engineering/generate_ensemble_prompt.py +128 -0
- paddlex/inference/pipelines/components/prompt_engineering/generate_kie_prompt.py +148 -0
- paddlex/inference/pipelines/components/retriever/__init__.py +16 -0
- paddlex/inference/pipelines/components/retriever/base.py +228 -0
- paddlex/inference/pipelines/components/retriever/openai_bot_retriever.py +70 -0
- paddlex/inference/pipelines/components/retriever/qianfan_bot_retriever.py +166 -0
- paddlex/inference/pipelines/components/utils/__init__.py +13 -0
- paddlex/inference/pipelines/components/utils/mixin.py +206 -0
- paddlex/inference/pipelines/doc_preprocessor/__init__.py +15 -0
- paddlex/inference/pipelines/doc_preprocessor/pipeline.py +209 -0
- paddlex/inference/pipelines/doc_preprocessor/result.py +98 -0
- paddlex/inference/pipelines/doc_understanding/__init__.py +15 -0
- paddlex/inference/pipelines/doc_understanding/pipeline.py +71 -0
- paddlex/inference/pipelines/face_recognition/__init__.py +15 -0
- paddlex/inference/pipelines/face_recognition/pipeline.py +63 -0
- paddlex/inference/pipelines/face_recognition/result.py +44 -0
- paddlex/inference/pipelines/formula_recognition/__init__.py +15 -0
- paddlex/inference/pipelines/formula_recognition/pipeline.py +347 -0
- paddlex/inference/pipelines/formula_recognition/result.py +282 -0
- paddlex/inference/pipelines/image_classification/__init__.py +15 -0
- paddlex/inference/pipelines/image_classification/pipeline.py +90 -0
- paddlex/inference/pipelines/image_multilabel_classification/__init__.py +15 -0
- paddlex/inference/pipelines/image_multilabel_classification/pipeline.py +97 -0
- paddlex/inference/pipelines/instance_segmentation/__init__.py +15 -0
- paddlex/inference/pipelines/instance_segmentation/pipeline.py +91 -0
- paddlex/inference/pipelines/keypoint_detection/__init__.py +15 -0
- paddlex/inference/pipelines/keypoint_detection/pipeline.py +158 -0
- paddlex/inference/pipelines/layout_parsing/__init__.py +16 -0
- paddlex/inference/pipelines/layout_parsing/pipeline.py +568 -0
- paddlex/inference/pipelines/layout_parsing/pipeline_v2.py +1382 -0
- paddlex/inference/pipelines/layout_parsing/result.py +191 -0
- paddlex/inference/pipelines/layout_parsing/result_v2.py +745 -0
- paddlex/inference/pipelines/layout_parsing/setting.py +87 -0
- paddlex/inference/pipelines/layout_parsing/utils.py +951 -0
- paddlex/inference/pipelines/layout_parsing/xycut_enhanced/__init__.py +16 -0
- paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py +1143 -0
- paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py +562 -0
- paddlex/inference/pipelines/m_3d_bev_detection/__init__.py +15 -0
- paddlex/inference/pipelines/m_3d_bev_detection/pipeline.py +74 -0
- paddlex/inference/pipelines/multilingual_speech_recognition/__init__.py +15 -0
- paddlex/inference/pipelines/multilingual_speech_recognition/pipeline.py +78 -0
- paddlex/inference/pipelines/object_detection/__init__.py +15 -0
- paddlex/inference/pipelines/object_detection/pipeline.py +115 -0
- paddlex/inference/pipelines/ocr/__init__.py +15 -0
- paddlex/inference/pipelines/ocr/pipeline.py +463 -0
- paddlex/inference/pipelines/ocr/result.py +255 -0
- paddlex/inference/pipelines/open_vocabulary_detection/__init__.py +15 -0
- paddlex/inference/pipelines/open_vocabulary_detection/pipeline.py +86 -0
- paddlex/inference/pipelines/open_vocabulary_segmentation/__init__.py +15 -0
- paddlex/inference/pipelines/open_vocabulary_segmentation/pipeline.py +100 -0
- paddlex/inference/pipelines/pp_chatocr/__init__.py +16 -0
- paddlex/inference/pipelines/pp_chatocr/pipeline_base.py +111 -0
- paddlex/inference/pipelines/pp_chatocr/pipeline_v3.py +781 -0
- paddlex/inference/pipelines/pp_chatocr/pipeline_v4.py +992 -0
- paddlex/inference/pipelines/pp_shitu_v2/__init__.py +15 -0
- paddlex/inference/pipelines/pp_shitu_v2/pipeline.py +156 -0
- paddlex/inference/pipelines/pp_shitu_v2/result.py +126 -0
- paddlex/inference/pipelines/rotated_object_detection/__init__.py +15 -0
- paddlex/inference/pipelines/rotated_object_detection/pipeline.py +95 -0
- paddlex/inference/pipelines/seal_recognition/__init__.py +15 -0
- paddlex/inference/pipelines/seal_recognition/pipeline.py +335 -0
- paddlex/inference/pipelines/seal_recognition/result.py +89 -0
- paddlex/inference/pipelines/semantic_segmentation/__init__.py +15 -0
- paddlex/inference/pipelines/semantic_segmentation/pipeline.py +95 -0
- paddlex/inference/pipelines/small_object_detection/__init__.py +15 -0
- paddlex/inference/pipelines/small_object_detection/pipeline.py +95 -0
- paddlex/inference/pipelines/table_recognition/__init__.py +16 -0
- paddlex/inference/pipelines/table_recognition/pipeline.py +486 -0
- paddlex/inference/pipelines/table_recognition/pipeline_v2.py +1395 -0
- paddlex/inference/pipelines/table_recognition/result.py +218 -0
- paddlex/inference/pipelines/table_recognition/table_recognition_post_processing.py +366 -0
- paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py +488 -0
- paddlex/inference/pipelines/table_recognition/utils.py +44 -0
- paddlex/inference/pipelines/ts_anomaly_detection/__init__.py +15 -0
- paddlex/inference/pipelines/ts_anomaly_detection/pipeline.py +72 -0
- paddlex/inference/pipelines/ts_classification/__init__.py +15 -0
- paddlex/inference/pipelines/ts_classification/pipeline.py +72 -0
- paddlex/inference/pipelines/ts_forecasting/__init__.py +15 -0
- paddlex/inference/pipelines/ts_forecasting/pipeline.py +72 -0
- paddlex/inference/pipelines/video_classification/__init__.py +15 -0
- paddlex/inference/pipelines/video_classification/pipeline.py +79 -0
- paddlex/inference/pipelines/video_detection/__init__.py +15 -0
- paddlex/inference/pipelines/video_detection/pipeline.py +86 -0
- paddlex/inference/serving/__init__.py +17 -0
- paddlex/inference/serving/basic_serving/__init__.py +18 -0
- paddlex/inference/serving/basic_serving/_app.py +221 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/__init__.py +44 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/_common/__init__.py +13 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/_common/common.py +104 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/_common/image_recognition.py +36 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/_common/ocr.py +95 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/anomaly_detection.py +67 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/doc_preprocessor.py +100 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/doc_understanding.py +153 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/face_recognition.py +226 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/formula_recognition.py +100 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/human_keypoint_detection.py +81 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/image_classification.py +69 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/image_multilabel_classification.py +73 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/instance_segmentation.py +87 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/layout_parsing.py +117 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/m_3d_bev_detection.py +79 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/multilingual_speech_recognition.py +92 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/object_detection.py +77 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/ocr.py +102 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/open_vocabulary_detection.py +81 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/open_vocabulary_segmentation.py +91 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/pedestrian_attribute_recognition.py +84 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv3_doc.py +193 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv4_doc.py +223 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/pp_shituv2.py +221 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/pp_structurev3.py +143 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/rotated_object_detection.py +81 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/seal_recognition.py +106 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/semantic_segmentation.py +67 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/small_object_detection.py +72 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition.py +108 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition_v2.py +113 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/ts_anomaly_detection.py +65 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/ts_classification.py +64 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/ts_forecast.py +65 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/vehicle_attribute_recognition.py +84 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/video_classification.py +76 -0
- paddlex/inference/serving/basic_serving/_pipeline_apps/video_detection.py +92 -0
- paddlex/inference/serving/basic_serving/_server.py +40 -0
- paddlex/inference/serving/infra/__init__.py +13 -0
- paddlex/inference/serving/infra/config.py +36 -0
- paddlex/inference/serving/infra/models.py +79 -0
- paddlex/inference/serving/infra/storage.py +180 -0
- paddlex/inference/serving/infra/utils.py +285 -0
- paddlex/inference/serving/schemas/__init__.py +13 -0
- paddlex/inference/serving/schemas/anomaly_detection.py +39 -0
- paddlex/inference/serving/schemas/doc_preprocessor.py +54 -0
- paddlex/inference/serving/schemas/doc_understanding.py +78 -0
- paddlex/inference/serving/schemas/face_recognition.py +124 -0
- paddlex/inference/serving/schemas/formula_recognition.py +56 -0
- paddlex/inference/serving/schemas/human_keypoint_detection.py +55 -0
- paddlex/inference/serving/schemas/image_classification.py +45 -0
- paddlex/inference/serving/schemas/image_multilabel_classification.py +47 -0
- paddlex/inference/serving/schemas/instance_segmentation.py +53 -0
- paddlex/inference/serving/schemas/layout_parsing.py +71 -0
- paddlex/inference/serving/schemas/m_3d_bev_detection.py +48 -0
- paddlex/inference/serving/schemas/multilingual_speech_recognition.py +57 -0
- paddlex/inference/serving/schemas/object_detection.py +52 -0
- paddlex/inference/serving/schemas/ocr.py +60 -0
- paddlex/inference/serving/schemas/open_vocabulary_detection.py +52 -0
- paddlex/inference/serving/schemas/open_vocabulary_segmentation.py +52 -0
- paddlex/inference/serving/schemas/pedestrian_attribute_recognition.py +61 -0
- paddlex/inference/serving/schemas/pp_chatocrv3_doc.py +133 -0
- paddlex/inference/serving/schemas/pp_chatocrv4_doc.py +150 -0
- paddlex/inference/serving/schemas/pp_shituv2.py +124 -0
- paddlex/inference/serving/schemas/pp_structurev3.py +88 -0
- paddlex/inference/serving/schemas/rotated_object_detection.py +52 -0
- paddlex/inference/serving/schemas/seal_recognition.py +62 -0
- paddlex/inference/serving/schemas/semantic_segmentation.py +45 -0
- paddlex/inference/serving/schemas/shared/__init__.py +13 -0
- paddlex/inference/serving/schemas/shared/classification.py +23 -0
- paddlex/inference/serving/schemas/shared/image_segmentation.py +28 -0
- paddlex/inference/serving/schemas/shared/object_detection.py +24 -0
- paddlex/inference/serving/schemas/shared/ocr.py +25 -0
- paddlex/inference/serving/schemas/small_object_detection.py +52 -0
- paddlex/inference/serving/schemas/table_recognition.py +64 -0
- paddlex/inference/serving/schemas/table_recognition_v2.py +69 -0
- paddlex/inference/serving/schemas/ts_anomaly_detection.py +37 -0
- paddlex/inference/serving/schemas/ts_classification.py +38 -0
- paddlex/inference/serving/schemas/ts_forecast.py +37 -0
- paddlex/inference/serving/schemas/vehicle_attribute_recognition.py +61 -0
- paddlex/inference/serving/schemas/video_classification.py +44 -0
- paddlex/inference/serving/schemas/video_detection.py +56 -0
- paddlex/inference/utils/__init__.py +13 -0
- paddlex/inference/utils/benchmark.py +379 -0
- paddlex/inference/utils/color_map.py +123 -0
- paddlex/inference/utils/get_pipeline_path.py +27 -0
- paddlex/inference/utils/hpi.py +254 -0
- paddlex/inference/utils/hpi_model_info_collection.json +2331 -0
- paddlex/inference/utils/io/__init__.py +36 -0
- paddlex/inference/utils/io/readers.py +504 -0
- paddlex/inference/utils/io/style.py +381 -0
- paddlex/inference/utils/io/tablepyxl.py +157 -0
- paddlex/inference/utils/io/writers.py +458 -0
- paddlex/inference/utils/model_paths.py +48 -0
- paddlex/inference/utils/new_ir_blocklist.py +27 -0
- paddlex/inference/utils/official_models.py +367 -0
- paddlex/inference/utils/pp_option.py +339 -0
- paddlex/inference/utils/trt_blocklist.py +43 -0
- paddlex/inference/utils/trt_config.py +420 -0
- paddlex/model.py +131 -0
- paddlex/modules/__init__.py +115 -0
- paddlex/modules/anomaly_detection/__init__.py +18 -0
- paddlex/modules/anomaly_detection/dataset_checker/__init__.py +94 -0
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/__init__.py +19 -0
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/analyse_dataset.py +82 -0
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/check_dataset.py +91 -0
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/convert_dataset.py +233 -0
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/split_dataset.py +87 -0
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/utils/__init__.py +13 -0
- paddlex/modules/anomaly_detection/dataset_checker/dataset_src/utils/visualizer.py +76 -0
- paddlex/modules/anomaly_detection/evaluator.py +58 -0
- paddlex/modules/anomaly_detection/exportor.py +22 -0
- paddlex/modules/anomaly_detection/model_list.py +16 -0
- paddlex/modules/anomaly_detection/trainer.py +70 -0
- paddlex/modules/base/__init__.py +18 -0
- paddlex/modules/base/build_model.py +33 -0
- paddlex/modules/base/dataset_checker/__init__.py +16 -0
- paddlex/modules/base/dataset_checker/dataset_checker.py +169 -0
- paddlex/modules/base/dataset_checker/utils.py +108 -0
- paddlex/modules/base/evaluator.py +170 -0
- paddlex/modules/base/exportor.py +145 -0
- paddlex/modules/base/trainer.py +144 -0
- paddlex/modules/base/utils/__init__.py +13 -0
- paddlex/modules/base/utils/cinn_setting.py +89 -0
- paddlex/modules/base/utils/coco_eval.py +94 -0
- paddlex/modules/base/utils/topk_eval.py +118 -0
- paddlex/modules/doc_vlm/__init__.py +18 -0
- paddlex/modules/doc_vlm/dataset_checker.py +29 -0
- paddlex/modules/doc_vlm/evaluator.py +29 -0
- paddlex/modules/doc_vlm/exportor.py +29 -0
- paddlex/modules/doc_vlm/model_list.py +16 -0
- paddlex/modules/doc_vlm/trainer.py +41 -0
- paddlex/modules/face_recognition/__init__.py +18 -0
- paddlex/modules/face_recognition/dataset_checker/__init__.py +71 -0
- paddlex/modules/face_recognition/dataset_checker/dataset_src/__init__.py +16 -0
- paddlex/modules/face_recognition/dataset_checker/dataset_src/check_dataset.py +172 -0
- paddlex/modules/face_recognition/dataset_checker/dataset_src/utils/__init__.py +13 -0
- paddlex/modules/face_recognition/dataset_checker/dataset_src/utils/visualizer.py +153 -0
- paddlex/modules/face_recognition/evaluator.py +52 -0
- paddlex/modules/face_recognition/exportor.py +22 -0
- paddlex/modules/face_recognition/model_list.py +15 -0
- paddlex/modules/face_recognition/trainer.py +75 -0
- paddlex/modules/formula_recognition/__init__.py +18 -0
- paddlex/modules/formula_recognition/dataset_checker/__init__.py +113 -0
- paddlex/modules/formula_recognition/dataset_checker/dataset_src/__init__.py +19 -0
- paddlex/modules/formula_recognition/dataset_checker/dataset_src/analyse_dataset.py +158 -0
- paddlex/modules/formula_recognition/dataset_checker/dataset_src/check_dataset.py +76 -0
- paddlex/modules/formula_recognition/dataset_checker/dataset_src/convert_dataset.py +95 -0
- paddlex/modules/formula_recognition/dataset_checker/dataset_src/split_dataset.py +80 -0
- paddlex/modules/formula_recognition/evaluator.py +80 -0
- paddlex/modules/formula_recognition/exportor.py +22 -0
- paddlex/modules/formula_recognition/model_list.py +23 -0
- paddlex/modules/formula_recognition/trainer.py +123 -0
- paddlex/modules/general_recognition/__init__.py +18 -0
- paddlex/modules/general_recognition/dataset_checker/__init__.py +107 -0
- paddlex/modules/general_recognition/dataset_checker/dataset_src/__init__.py +19 -0
- paddlex/modules/general_recognition/dataset_checker/dataset_src/analyse_dataset.py +96 -0
- paddlex/modules/general_recognition/dataset_checker/dataset_src/check_dataset.py +99 -0
- paddlex/modules/general_recognition/dataset_checker/dataset_src/convert_dataset.py +100 -0
- paddlex/modules/general_recognition/dataset_checker/dataset_src/split_dataset.py +82 -0
- paddlex/modules/general_recognition/dataset_checker/dataset_src/utils/__init__.py +13 -0
- paddlex/modules/general_recognition/dataset_checker/dataset_src/utils/visualizer.py +147 -0
- paddlex/modules/general_recognition/evaluator.py +31 -0
- paddlex/modules/general_recognition/exportor.py +22 -0
- paddlex/modules/general_recognition/model_list.py +19 -0
- paddlex/modules/general_recognition/trainer.py +52 -0
- paddlex/modules/image_classification/__init__.py +18 -0
- paddlex/modules/image_classification/dataset_checker/__init__.py +104 -0
- paddlex/modules/image_classification/dataset_checker/dataset_src/__init__.py +19 -0
- paddlex/modules/image_classification/dataset_checker/dataset_src/analyse_dataset.py +92 -0
- paddlex/modules/image_classification/dataset_checker/dataset_src/check_dataset.py +132 -0
- paddlex/modules/image_classification/dataset_checker/dataset_src/convert_dataset.py +51 -0
- paddlex/modules/image_classification/dataset_checker/dataset_src/split_dataset.py +81 -0
- paddlex/modules/image_classification/dataset_checker/dataset_src/utils/__init__.py +13 -0
- paddlex/modules/image_classification/dataset_checker/dataset_src/utils/visualizer.py +153 -0
- paddlex/modules/image_classification/evaluator.py +43 -0
- paddlex/modules/image_classification/exportor.py +22 -0
- paddlex/modules/image_classification/model_list.py +99 -0
- paddlex/modules/image_classification/trainer.py +82 -0
- paddlex/modules/image_unwarping/__init__.py +13 -0
- paddlex/modules/image_unwarping/model_list.py +17 -0
- paddlex/modules/instance_segmentation/__init__.py +18 -0
- paddlex/modules/instance_segmentation/dataset_checker/__init__.py +107 -0
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/__init__.py +19 -0
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/analyse_dataset.py +82 -0
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/check_dataset.py +95 -0
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/convert_dataset.py +241 -0
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/split_dataset.py +122 -0
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/utils/__init__.py +13 -0
- paddlex/modules/instance_segmentation/dataset_checker/dataset_src/utils/visualizer.py +223 -0
- paddlex/modules/instance_segmentation/evaluator.py +32 -0
- paddlex/modules/instance_segmentation/exportor.py +22 -0
- paddlex/modules/instance_segmentation/model_list.py +33 -0
- paddlex/modules/instance_segmentation/trainer.py +31 -0
- paddlex/modules/keypoint_detection/__init__.py +18 -0
- paddlex/modules/keypoint_detection/dataset_checker/__init__.py +56 -0
- paddlex/modules/keypoint_detection/dataset_checker/dataset_src/__init__.py +15 -0
- paddlex/modules/keypoint_detection/dataset_checker/dataset_src/check_dataset.py +91 -0
- paddlex/modules/keypoint_detection/dataset_checker/dataset_src/utils/__init__.py +13 -0
- paddlex/modules/keypoint_detection/dataset_checker/dataset_src/utils/visualizer.py +124 -0
- paddlex/modules/keypoint_detection/evaluator.py +41 -0
- paddlex/modules/keypoint_detection/exportor.py +22 -0
- paddlex/modules/keypoint_detection/model_list.py +16 -0
- paddlex/modules/keypoint_detection/trainer.py +39 -0
- paddlex/modules/m_3d_bev_detection/__init__.py +18 -0
- paddlex/modules/m_3d_bev_detection/dataset_checker/__init__.py +95 -0
- paddlex/modules/m_3d_bev_detection/dataset_checker/dataset_src/__init__.py +17 -0
- paddlex/modules/m_3d_bev_detection/dataset_checker/dataset_src/analyse_dataset.py +106 -0
- paddlex/modules/m_3d_bev_detection/dataset_checker/dataset_src/check_dataset.py +101 -0
- paddlex/modules/m_3d_bev_detection/evaluator.py +46 -0
- paddlex/modules/m_3d_bev_detection/exportor.py +22 -0
- paddlex/modules/m_3d_bev_detection/model_list.py +18 -0
- paddlex/modules/m_3d_bev_detection/trainer.py +68 -0
- paddlex/modules/multilabel_classification/__init__.py +18 -0
- paddlex/modules/multilabel_classification/dataset_checker/__init__.py +106 -0
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/__init__.py +19 -0
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/analyse_dataset.py +94 -0
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/check_dataset.py +132 -0
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/convert_dataset.py +120 -0
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/split_dataset.py +81 -0
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/utils/__init__.py +13 -0
- paddlex/modules/multilabel_classification/dataset_checker/dataset_src/utils/visualizer.py +149 -0
- paddlex/modules/multilabel_classification/evaluator.py +43 -0
- paddlex/modules/multilabel_classification/exportor.py +22 -0
- paddlex/modules/multilabel_classification/model_list.py +24 -0
- paddlex/modules/multilabel_classification/trainer.py +85 -0
- paddlex/modules/multilingual_speech_recognition/__init__.py +18 -0
- paddlex/modules/multilingual_speech_recognition/dataset_checker.py +27 -0
- paddlex/modules/multilingual_speech_recognition/evaluator.py +27 -0
- paddlex/modules/multilingual_speech_recognition/exportor.py +27 -0
- paddlex/modules/multilingual_speech_recognition/model_list.py +22 -0
- paddlex/modules/multilingual_speech_recognition/trainer.py +42 -0
- paddlex/modules/object_detection/__init__.py +18 -0
- paddlex/modules/object_detection/dataset_checker/__init__.py +106 -0
- paddlex/modules/object_detection/dataset_checker/dataset_src/__init__.py +19 -0
- paddlex/modules/object_detection/dataset_checker/dataset_src/analyse_dataset.py +82 -0
- paddlex/modules/object_detection/dataset_checker/dataset_src/check_dataset.py +91 -0
- paddlex/modules/object_detection/dataset_checker/dataset_src/convert_dataset.py +438 -0
- paddlex/modules/object_detection/dataset_checker/dataset_src/split_dataset.py +123 -0
- paddlex/modules/object_detection/dataset_checker/dataset_src/utils/__init__.py +13 -0
- paddlex/modules/object_detection/dataset_checker/dataset_src/utils/visualizer.py +193 -0
- paddlex/modules/object_detection/evaluator.py +57 -0
- paddlex/modules/object_detection/exportor.py +22 -0
- paddlex/modules/object_detection/model_list.py +86 -0
- paddlex/modules/object_detection/trainer.py +98 -0
- paddlex/modules/open_vocabulary_detection/__init__.py +18 -0
- paddlex/modules/open_vocabulary_detection/dataset_checker.py +29 -0
- paddlex/modules/open_vocabulary_detection/evaluator.py +29 -0
- paddlex/modules/open_vocabulary_detection/exportor.py +29 -0
- paddlex/modules/open_vocabulary_detection/model_list.py +16 -0
- paddlex/modules/open_vocabulary_detection/trainer.py +44 -0
- paddlex/modules/open_vocabulary_segmentation/__init__.py +18 -0
- paddlex/modules/open_vocabulary_segmentation/dataset_checker.py +29 -0
- paddlex/modules/open_vocabulary_segmentation/evaluator.py +29 -0
- paddlex/modules/open_vocabulary_segmentation/exportor.py +29 -0
- paddlex/modules/open_vocabulary_segmentation/model_list.py +19 -0
- paddlex/modules/open_vocabulary_segmentation/trainer.py +44 -0
- paddlex/modules/semantic_segmentation/__init__.py +18 -0
- paddlex/modules/semantic_segmentation/dataset_checker/__init__.py +109 -0
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/__init__.py +19 -0
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/analyse_dataset.py +76 -0
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/check_dataset.py +80 -0
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/convert_dataset.py +165 -0
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/split_dataset.py +87 -0
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/utils/__init__.py +13 -0
- paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/utils/visualizer.py +75 -0
- paddlex/modules/semantic_segmentation/evaluator.py +58 -0
- paddlex/modules/semantic_segmentation/exportor.py +31 -0
- paddlex/modules/semantic_segmentation/model_list.py +37 -0
- paddlex/modules/semantic_segmentation/trainer.py +72 -0
- paddlex/modules/table_recognition/__init__.py +18 -0
- paddlex/modules/table_recognition/dataset_checker/__init__.py +98 -0
- paddlex/modules/table_recognition/dataset_checker/dataset_src/__init__.py +18 -0
- paddlex/modules/table_recognition/dataset_checker/dataset_src/analyse_dataset.py +59 -0
- paddlex/modules/table_recognition/dataset_checker/dataset_src/check_dataset.py +87 -0
- paddlex/modules/table_recognition/dataset_checker/dataset_src/split_dataset.py +80 -0
- paddlex/modules/table_recognition/evaluator.py +43 -0
- paddlex/modules/table_recognition/exportor.py +22 -0
- paddlex/modules/table_recognition/model_list.py +21 -0
- paddlex/modules/table_recognition/trainer.py +67 -0
- paddlex/modules/text_detection/__init__.py +18 -0
- paddlex/modules/text_detection/dataset_checker/__init__.py +107 -0
- paddlex/modules/text_detection/dataset_checker/dataset_src/__init__.py +18 -0
- paddlex/modules/text_detection/dataset_checker/dataset_src/analyse_dataset.py +220 -0
- paddlex/modules/text_detection/dataset_checker/dataset_src/check_dataset.py +106 -0
- paddlex/modules/text_detection/dataset_checker/dataset_src/split_dataset.py +140 -0
- paddlex/modules/text_detection/evaluator.py +41 -0
- paddlex/modules/text_detection/exportor.py +22 -0
- paddlex/modules/text_detection/model_list.py +26 -0
- paddlex/modules/text_detection/trainer.py +65 -0
- paddlex/modules/text_recognition/__init__.py +18 -0
- paddlex/modules/text_recognition/dataset_checker/__init__.py +125 -0
- paddlex/modules/text_recognition/dataset_checker/dataset_src/__init__.py +19 -0
- paddlex/modules/text_recognition/dataset_checker/dataset_src/analyse_dataset.py +162 -0
- paddlex/modules/text_recognition/dataset_checker/dataset_src/check_dataset.py +104 -0
- paddlex/modules/text_recognition/dataset_checker/dataset_src/convert_dataset.py +95 -0
- paddlex/modules/text_recognition/dataset_checker/dataset_src/split_dataset.py +80 -0
- paddlex/modules/text_recognition/evaluator.py +64 -0
- paddlex/modules/text_recognition/exportor.py +22 -0
- paddlex/modules/text_recognition/model_list.py +36 -0
- paddlex/modules/text_recognition/trainer.py +105 -0
- paddlex/modules/ts_anomaly_detection/__init__.py +19 -0
- paddlex/modules/ts_anomaly_detection/dataset_checker/__init__.py +111 -0
- paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/__init__.py +19 -0
- paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/analyse_dataset.py +19 -0
- paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/check_dataset.py +64 -0
- paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/convert_dataset.py +74 -0
- paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/split_dataset.py +63 -0
- paddlex/modules/ts_anomaly_detection/evaluator.py +67 -0
- paddlex/modules/ts_anomaly_detection/exportor.py +44 -0
- paddlex/modules/ts_anomaly_detection/model_list.py +22 -0
- paddlex/modules/ts_anomaly_detection/trainer.py +113 -0
- paddlex/modules/ts_classification/__init__.py +19 -0
- paddlex/modules/ts_classification/dataset_checker/__init__.py +111 -0
- paddlex/modules/ts_classification/dataset_checker/dataset_src/__init__.py +19 -0
- paddlex/modules/ts_classification/dataset_checker/dataset_src/analyse_dataset.py +77 -0
- paddlex/modules/ts_classification/dataset_checker/dataset_src/check_dataset.py +64 -0
- paddlex/modules/ts_classification/dataset_checker/dataset_src/convert_dataset.py +74 -0
- paddlex/modules/ts_classification/dataset_checker/dataset_src/split_dataset.py +88 -0
- paddlex/modules/ts_classification/evaluator.py +66 -0
- paddlex/modules/ts_classification/exportor.py +44 -0
- paddlex/modules/ts_classification/model_list.py +18 -0
- paddlex/modules/ts_classification/trainer.py +108 -0
- paddlex/modules/ts_forecast/__init__.py +19 -0
- paddlex/modules/ts_forecast/dataset_checker/__init__.py +111 -0
- paddlex/modules/ts_forecast/dataset_checker/dataset_src/__init__.py +19 -0
- paddlex/modules/ts_forecast/dataset_checker/dataset_src/analyse_dataset.py +19 -0
- paddlex/modules/ts_forecast/dataset_checker/dataset_src/check_dataset.py +64 -0
- paddlex/modules/ts_forecast/dataset_checker/dataset_src/convert_dataset.py +73 -0
- paddlex/modules/ts_forecast/dataset_checker/dataset_src/split_dataset.py +63 -0
- paddlex/modules/ts_forecast/evaluator.py +66 -0
- paddlex/modules/ts_forecast/exportor.py +44 -0
- paddlex/modules/ts_forecast/model_list.py +24 -0
- paddlex/modules/ts_forecast/trainer.py +108 -0
- paddlex/modules/video_classification/__init__.py +18 -0
- paddlex/modules/video_classification/dataset_checker/__init__.py +93 -0
- paddlex/modules/video_classification/dataset_checker/dataset_src/__init__.py +18 -0
- paddlex/modules/video_classification/dataset_checker/dataset_src/analyse_dataset.py +93 -0
- paddlex/modules/video_classification/dataset_checker/dataset_src/check_dataset.py +120 -0
- paddlex/modules/video_classification/dataset_checker/dataset_src/split_dataset.py +82 -0
- paddlex/modules/video_classification/evaluator.py +44 -0
- paddlex/modules/video_classification/exportor.py +22 -0
- paddlex/modules/video_classification/model_list.py +19 -0
- paddlex/modules/video_classification/trainer.py +88 -0
- paddlex/modules/video_detection/__init__.py +18 -0
- paddlex/modules/video_detection/dataset_checker/__init__.py +86 -0
- paddlex/modules/video_detection/dataset_checker/dataset_src/__init__.py +17 -0
- paddlex/modules/video_detection/dataset_checker/dataset_src/analyse_dataset.py +100 -0
- paddlex/modules/video_detection/dataset_checker/dataset_src/check_dataset.py +132 -0
- paddlex/modules/video_detection/evaluator.py +42 -0
- paddlex/modules/video_detection/exportor.py +22 -0
- paddlex/modules/video_detection/model_list.py +15 -0
- paddlex/modules/video_detection/trainer.py +82 -0
- paddlex/ops/__init__.py +152 -0
- paddlex/ops/iou3d_nms/iou3d_cpu.cpp +266 -0
- paddlex/ops/iou3d_nms/iou3d_cpu.h +28 -0
- paddlex/ops/iou3d_nms/iou3d_nms.cpp +206 -0
- paddlex/ops/iou3d_nms/iou3d_nms.h +35 -0
- paddlex/ops/iou3d_nms/iou3d_nms_api.cpp +114 -0
- paddlex/ops/iou3d_nms/iou3d_nms_kernel.cu +484 -0
- paddlex/ops/setup.py +37 -0
- paddlex/ops/voxel/voxelize_op.cc +194 -0
- paddlex/ops/voxel/voxelize_op.cu +346 -0
- paddlex/paddlex_cli.py +476 -0
- paddlex/repo_apis/Paddle3D_api/__init__.py +17 -0
- paddlex/repo_apis/Paddle3D_api/bev_fusion/__init__.py +18 -0
- paddlex/repo_apis/Paddle3D_api/bev_fusion/config.py +118 -0
- paddlex/repo_apis/Paddle3D_api/bev_fusion/model.py +238 -0
- paddlex/repo_apis/Paddle3D_api/bev_fusion/register.py +55 -0
- paddlex/repo_apis/Paddle3D_api/bev_fusion/runner.py +104 -0
- paddlex/repo_apis/Paddle3D_api/pp3d_config.py +145 -0
- paddlex/repo_apis/PaddleClas_api/__init__.py +17 -0
- paddlex/repo_apis/PaddleClas_api/cls/__init__.py +19 -0
- paddlex/repo_apis/PaddleClas_api/cls/config.py +595 -0
- paddlex/repo_apis/PaddleClas_api/cls/model.py +355 -0
- paddlex/repo_apis/PaddleClas_api/cls/register.py +907 -0
- paddlex/repo_apis/PaddleClas_api/cls/runner.py +218 -0
- paddlex/repo_apis/PaddleClas_api/shitu_rec/__init__.py +18 -0
- paddlex/repo_apis/PaddleClas_api/shitu_rec/config.py +141 -0
- paddlex/repo_apis/PaddleClas_api/shitu_rec/model.py +20 -0
- paddlex/repo_apis/PaddleClas_api/shitu_rec/register.py +68 -0
- paddlex/repo_apis/PaddleClas_api/shitu_rec/runner.py +50 -0
- paddlex/repo_apis/PaddleDetection_api/__init__.py +17 -0
- paddlex/repo_apis/PaddleDetection_api/config_helper.py +280 -0
- paddlex/repo_apis/PaddleDetection_api/instance_seg/__init__.py +18 -0
- paddlex/repo_apis/PaddleDetection_api/instance_seg/config.py +457 -0
- paddlex/repo_apis/PaddleDetection_api/instance_seg/model.py +403 -0
- paddlex/repo_apis/PaddleDetection_api/instance_seg/register.py +262 -0
- paddlex/repo_apis/PaddleDetection_api/instance_seg/runner.py +225 -0
- paddlex/repo_apis/PaddleDetection_api/object_det/__init__.py +19 -0
- paddlex/repo_apis/PaddleDetection_api/object_det/config.py +540 -0
- paddlex/repo_apis/PaddleDetection_api/object_det/model.py +429 -0
- paddlex/repo_apis/PaddleDetection_api/object_det/official_categories.py +245 -0
- paddlex/repo_apis/PaddleDetection_api/object_det/register.py +1135 -0
- paddlex/repo_apis/PaddleDetection_api/object_det/runner.py +225 -0
- paddlex/repo_apis/PaddleNLP_api/__init__.py +13 -0
- paddlex/repo_apis/PaddleOCR_api/__init__.py +22 -0
- paddlex/repo_apis/PaddleOCR_api/config_utils.py +53 -0
- paddlex/repo_apis/PaddleOCR_api/formula_rec/__init__.py +16 -0
- paddlex/repo_apis/PaddleOCR_api/formula_rec/config.py +571 -0
- paddlex/repo_apis/PaddleOCR_api/formula_rec/model.py +398 -0
- paddlex/repo_apis/PaddleOCR_api/formula_rec/register.py +99 -0
- paddlex/repo_apis/PaddleOCR_api/formula_rec/runner.py +239 -0
- paddlex/repo_apis/PaddleOCR_api/table_rec/__init__.py +16 -0
- paddlex/repo_apis/PaddleOCR_api/table_rec/config.py +64 -0
- paddlex/repo_apis/PaddleOCR_api/table_rec/model.py +126 -0
- paddlex/repo_apis/PaddleOCR_api/table_rec/register.py +70 -0
- paddlex/repo_apis/PaddleOCR_api/table_rec/runner.py +51 -0
- paddlex/repo_apis/PaddleOCR_api/text_det/__init__.py +16 -0
- paddlex/repo_apis/PaddleOCR_api/text_det/config.py +62 -0
- paddlex/repo_apis/PaddleOCR_api/text_det/model.py +72 -0
- paddlex/repo_apis/PaddleOCR_api/text_det/register.py +107 -0
- paddlex/repo_apis/PaddleOCR_api/text_det/runner.py +53 -0
- paddlex/repo_apis/PaddleOCR_api/text_rec/__init__.py +16 -0
- paddlex/repo_apis/PaddleOCR_api/text_rec/config.py +564 -0
- paddlex/repo_apis/PaddleOCR_api/text_rec/model.py +398 -0
- paddlex/repo_apis/PaddleOCR_api/text_rec/register.py +216 -0
- paddlex/repo_apis/PaddleOCR_api/text_rec/runner.py +239 -0
- paddlex/repo_apis/PaddleSeg_api/__init__.py +16 -0
- paddlex/repo_apis/PaddleSeg_api/base_seg_config.py +134 -0
- paddlex/repo_apis/PaddleSeg_api/seg/__init__.py +16 -0
- paddlex/repo_apis/PaddleSeg_api/seg/config.py +183 -0
- paddlex/repo_apis/PaddleSeg_api/seg/model.py +491 -0
- paddlex/repo_apis/PaddleSeg_api/seg/register.py +272 -0
- paddlex/repo_apis/PaddleSeg_api/seg/runner.py +261 -0
- paddlex/repo_apis/PaddleTS_api/__init__.py +20 -0
- paddlex/repo_apis/PaddleTS_api/ts_ad/__init__.py +16 -0
- paddlex/repo_apis/PaddleTS_api/ts_ad/config.py +88 -0
- paddlex/repo_apis/PaddleTS_api/ts_ad/register.py +146 -0
- paddlex/repo_apis/PaddleTS_api/ts_ad/runner.py +158 -0
- paddlex/repo_apis/PaddleTS_api/ts_base/__init__.py +13 -0
- paddlex/repo_apis/PaddleTS_api/ts_base/config.py +244 -0
- paddlex/repo_apis/PaddleTS_api/ts_base/model.py +276 -0
- paddlex/repo_apis/PaddleTS_api/ts_base/runner.py +158 -0
- paddlex/repo_apis/PaddleTS_api/ts_cls/__init__.py +16 -0
- paddlex/repo_apis/PaddleTS_api/ts_cls/config.py +72 -0
- paddlex/repo_apis/PaddleTS_api/ts_cls/register.py +59 -0
- paddlex/repo_apis/PaddleTS_api/ts_cls/runner.py +158 -0
- paddlex/repo_apis/PaddleTS_api/ts_fc/__init__.py +16 -0
- paddlex/repo_apis/PaddleTS_api/ts_fc/config.py +136 -0
- paddlex/repo_apis/PaddleTS_api/ts_fc/register.py +186 -0
- paddlex/repo_apis/PaddleVideo_api/__init__.py +17 -0
- paddlex/repo_apis/PaddleVideo_api/config_utils.py +51 -0
- paddlex/repo_apis/PaddleVideo_api/video_cls/__init__.py +19 -0
- paddlex/repo_apis/PaddleVideo_api/video_cls/config.py +548 -0
- paddlex/repo_apis/PaddleVideo_api/video_cls/model.py +346 -0
- paddlex/repo_apis/PaddleVideo_api/video_cls/register.py +70 -0
- paddlex/repo_apis/PaddleVideo_api/video_cls/runner.py +204 -0
- paddlex/repo_apis/PaddleVideo_api/video_det/__init__.py +19 -0
- paddlex/repo_apis/PaddleVideo_api/video_det/config.py +549 -0
- paddlex/repo_apis/PaddleVideo_api/video_det/model.py +298 -0
- paddlex/repo_apis/PaddleVideo_api/video_det/register.py +44 -0
- paddlex/repo_apis/PaddleVideo_api/video_det/runner.py +199 -0
- paddlex/repo_apis/__init__.py +13 -0
- paddlex/repo_apis/base/__init__.py +22 -0
- paddlex/repo_apis/base/config.py +237 -0
- paddlex/repo_apis/base/model.py +563 -0
- paddlex/repo_apis/base/register.py +135 -0
- paddlex/repo_apis/base/runner.py +390 -0
- paddlex/repo_apis/base/utils/__init__.py +13 -0
- paddlex/repo_apis/base/utils/arg.py +64 -0
- paddlex/repo_apis/base/utils/subprocess.py +107 -0
- paddlex/repo_manager/__init__.py +17 -0
- paddlex/repo_manager/core.py +253 -0
- paddlex/repo_manager/meta.py +180 -0
- paddlex/repo_manager/repo.py +425 -0
- paddlex/repo_manager/utils.py +148 -0
- paddlex/utils/__init__.py +1 -12
- paddlex/utils/cache.py +146 -0
- paddlex/utils/config.py +216 -0
- paddlex/utils/custom_device_list.py +311 -0
- paddlex/utils/deps.py +249 -0
- paddlex/utils/device.py +195 -0
- paddlex/utils/download.py +168 -182
- paddlex/utils/env.py +31 -48
- paddlex/utils/errors/__init__.py +17 -0
- paddlex/utils/errors/dataset_checker.py +78 -0
- paddlex/utils/errors/others.py +138 -0
- paddlex/utils/file_interface.py +211 -0
- paddlex/utils/flags.py +70 -0
- paddlex/utils/fonts/__init__.py +97 -0
- paddlex/utils/func_register.py +41 -0
- paddlex/utils/install.py +87 -0
- paddlex/utils/interactive_get_pipeline.py +55 -0
- paddlex/utils/lazy_loader.py +68 -0
- paddlex/utils/logging.py +140 -33
- paddlex/utils/misc.py +201 -0
- paddlex/utils/pipeline_arguments.py +719 -0
- paddlex/utils/result_saver.py +58 -0
- paddlex/utils/subclass_register.py +99 -0
- paddlex/version.py +55 -0
- paddlex-3.0.0.dist-info/METADATA +1168 -0
- paddlex-3.0.0.dist-info/RECORD +1093 -0
- paddlex-3.0.0.dist-info/WHEEL +5 -0
- paddlex-3.0.0.dist-info/entry_points.txt +2 -0
- paddlex-3.0.0.dist-info/licenses/LICENSE +169 -0
- paddlex-3.0.0.dist-info/top_level.txt +1 -0
- PaddleClas/__init__.py +0 -16
- PaddleClas/paddleclas.py +0 -375
- PaddleClas/ppcls/__init__.py +0 -20
- PaddleClas/ppcls/data/__init__.py +0 -15
- PaddleClas/ppcls/data/imaug/__init__.py +0 -94
- PaddleClas/ppcls/data/imaug/autoaugment.py +0 -264
- PaddleClas/ppcls/data/imaug/batch_operators.py +0 -117
- PaddleClas/ppcls/data/imaug/cutout.py +0 -41
- PaddleClas/ppcls/data/imaug/fmix.py +0 -217
- PaddleClas/ppcls/data/imaug/grid.py +0 -89
- PaddleClas/ppcls/data/imaug/hide_and_seek.py +0 -44
- PaddleClas/ppcls/data/imaug/operators.py +0 -244
- PaddleClas/ppcls/data/imaug/randaugment.py +0 -106
- PaddleClas/ppcls/data/imaug/random_erasing.py +0 -55
- PaddleClas/ppcls/data/reader.py +0 -318
- PaddleClas/ppcls/modeling/__init__.py +0 -20
- PaddleClas/ppcls/modeling/architectures/__init__.py +0 -51
- PaddleClas/ppcls/modeling/architectures/alexnet.py +0 -132
- PaddleClas/ppcls/modeling/architectures/darknet.py +0 -161
- PaddleClas/ppcls/modeling/architectures/densenet.py +0 -308
- PaddleClas/ppcls/modeling/architectures/distillation_models.py +0 -65
- PaddleClas/ppcls/modeling/architectures/distilled_vision_transformer.py +0 -196
- PaddleClas/ppcls/modeling/architectures/dpn.py +0 -425
- PaddleClas/ppcls/modeling/architectures/efficientnet.py +0 -901
- PaddleClas/ppcls/modeling/architectures/ghostnet.py +0 -331
- PaddleClas/ppcls/modeling/architectures/googlenet.py +0 -207
- PaddleClas/ppcls/modeling/architectures/hrnet.py +0 -742
- PaddleClas/ppcls/modeling/architectures/inception_v3.py +0 -481
- PaddleClas/ppcls/modeling/architectures/inception_v4.py +0 -455
- PaddleClas/ppcls/modeling/architectures/mixnet.py +0 -782
- PaddleClas/ppcls/modeling/architectures/mobilenet_v1.py +0 -266
- PaddleClas/ppcls/modeling/architectures/mobilenet_v2.py +0 -248
- PaddleClas/ppcls/modeling/architectures/mobilenet_v3.py +0 -359
- PaddleClas/ppcls/modeling/architectures/regnet.py +0 -383
- PaddleClas/ppcls/modeling/architectures/repvgg.py +0 -339
- PaddleClas/ppcls/modeling/architectures/res2net.py +0 -272
- PaddleClas/ppcls/modeling/architectures/res2net_vd.py +0 -295
- PaddleClas/ppcls/modeling/architectures/resnest.py +0 -705
- PaddleClas/ppcls/modeling/architectures/resnet.py +0 -316
- PaddleClas/ppcls/modeling/architectures/resnet_vc.py +0 -309
- PaddleClas/ppcls/modeling/architectures/resnet_vd.py +0 -354
- PaddleClas/ppcls/modeling/architectures/resnext.py +0 -253
- PaddleClas/ppcls/modeling/architectures/resnext101_wsl.py +0 -447
- PaddleClas/ppcls/modeling/architectures/resnext_vd.py +0 -266
- PaddleClas/ppcls/modeling/architectures/rexnet.py +0 -240
- PaddleClas/ppcls/modeling/architectures/se_resnet_vd.py +0 -378
- PaddleClas/ppcls/modeling/architectures/se_resnext.py +0 -290
- PaddleClas/ppcls/modeling/architectures/se_resnext_vd.py +0 -285
- PaddleClas/ppcls/modeling/architectures/shufflenet_v2.py +0 -320
- PaddleClas/ppcls/modeling/architectures/squeezenet.py +0 -154
- PaddleClas/ppcls/modeling/architectures/vgg.py +0 -152
- PaddleClas/ppcls/modeling/architectures/vision_transformer.py +0 -402
- PaddleClas/ppcls/modeling/architectures/xception.py +0 -345
- PaddleClas/ppcls/modeling/architectures/xception_deeplab.py +0 -386
- PaddleClas/ppcls/modeling/loss.py +0 -154
- PaddleClas/ppcls/modeling/utils.py +0 -53
- PaddleClas/ppcls/optimizer/__init__.py +0 -19
- PaddleClas/ppcls/optimizer/learning_rate.py +0 -159
- PaddleClas/ppcls/optimizer/optimizer.py +0 -165
- PaddleClas/ppcls/utils/__init__.py +0 -27
- PaddleClas/ppcls/utils/check.py +0 -151
- PaddleClas/ppcls/utils/config.py +0 -201
- PaddleClas/ppcls/utils/logger.py +0 -120
- PaddleClas/ppcls/utils/metrics.py +0 -107
- PaddleClas/ppcls/utils/misc.py +0 -62
- PaddleClas/ppcls/utils/model_zoo.py +0 -213
- PaddleClas/ppcls/utils/save_load.py +0 -163
- PaddleClas/setup.py +0 -55
- PaddleClas/tools/__init__.py +0 -15
- PaddleClas/tools/download.py +0 -50
- PaddleClas/tools/ema.py +0 -58
- PaddleClas/tools/eval.py +0 -112
- PaddleClas/tools/export_model.py +0 -85
- PaddleClas/tools/export_serving_model.py +0 -76
- PaddleClas/tools/infer/__init__.py +0 -16
- PaddleClas/tools/infer/infer.py +0 -94
- PaddleClas/tools/infer/predict.py +0 -117
- PaddleClas/tools/infer/utils.py +0 -233
- PaddleClas/tools/program.py +0 -444
- PaddleClas/tools/test_hubserving.py +0 -113
- PaddleClas/tools/train.py +0 -141
- paddlex/cls.py +0 -76
- paddlex/command.py +0 -215
- paddlex/cv/__init__.py +0 -17
- paddlex/cv/datasets/__init__.py +0 -18
- paddlex/cv/datasets/coco.py +0 -169
- paddlex/cv/datasets/imagenet.py +0 -88
- paddlex/cv/datasets/seg_dataset.py +0 -91
- paddlex/cv/datasets/voc.py +0 -301
- paddlex/cv/models/__init__.py +0 -18
- paddlex/cv/models/base.py +0 -623
- paddlex/cv/models/classifier.py +0 -814
- paddlex/cv/models/detector.py +0 -1747
- paddlex/cv/models/load_model.py +0 -126
- paddlex/cv/models/segmenter.py +0 -673
- paddlex/cv/models/slim/__init__.py +0 -13
- paddlex/cv/models/slim/prune.py +0 -55
- paddlex/cv/models/utils/__init__.py +0 -13
- paddlex/cv/models/utils/det_metrics/__init__.py +0 -15
- paddlex/cv/models/utils/det_metrics/coco_utils.py +0 -217
- paddlex/cv/models/utils/det_metrics/metrics.py +0 -220
- paddlex/cv/models/utils/ema.py +0 -48
- paddlex/cv/models/utils/seg_metrics.py +0 -62
- paddlex/cv/models/utils/visualize.py +0 -394
- paddlex/cv/transforms/__init__.py +0 -46
- paddlex/cv/transforms/batch_operators.py +0 -286
- paddlex/cv/transforms/box_utils.py +0 -41
- paddlex/cv/transforms/functions.py +0 -193
- paddlex/cv/transforms/operators.py +0 -1402
- paddlex/det.py +0 -43
- paddlex/paddleseg/__init__.py +0 -17
- paddlex/paddleseg/core/__init__.py +0 -20
- paddlex/paddleseg/core/infer.py +0 -289
- paddlex/paddleseg/core/predict.py +0 -145
- paddlex/paddleseg/core/train.py +0 -258
- paddlex/paddleseg/core/val.py +0 -172
- paddlex/paddleseg/cvlibs/__init__.py +0 -17
- paddlex/paddleseg/cvlibs/callbacks.py +0 -279
- paddlex/paddleseg/cvlibs/config.py +0 -359
- paddlex/paddleseg/cvlibs/manager.py +0 -142
- paddlex/paddleseg/cvlibs/param_init.py +0 -91
- paddlex/paddleseg/datasets/__init__.py +0 -21
- paddlex/paddleseg/datasets/ade.py +0 -112
- paddlex/paddleseg/datasets/cityscapes.py +0 -86
- paddlex/paddleseg/datasets/cocostuff.py +0 -79
- paddlex/paddleseg/datasets/dataset.py +0 -164
- paddlex/paddleseg/datasets/mini_deep_globe_road_extraction.py +0 -95
- paddlex/paddleseg/datasets/optic_disc_seg.py +0 -97
- paddlex/paddleseg/datasets/pascal_context.py +0 -80
- paddlex/paddleseg/datasets/voc.py +0 -113
- paddlex/paddleseg/models/__init__.py +0 -39
- paddlex/paddleseg/models/ann.py +0 -436
- paddlex/paddleseg/models/attention_unet.py +0 -189
- paddlex/paddleseg/models/backbones/__init__.py +0 -18
- paddlex/paddleseg/models/backbones/hrnet.py +0 -815
- paddlex/paddleseg/models/backbones/mobilenetv3.py +0 -365
- paddlex/paddleseg/models/backbones/resnet_vd.py +0 -364
- paddlex/paddleseg/models/backbones/xception_deeplab.py +0 -415
- paddlex/paddleseg/models/bisenet.py +0 -311
- paddlex/paddleseg/models/danet.py +0 -220
- paddlex/paddleseg/models/decoupled_segnet.py +0 -233
- paddlex/paddleseg/models/deeplab.py +0 -258
- paddlex/paddleseg/models/dnlnet.py +0 -231
- paddlex/paddleseg/models/emanet.py +0 -219
- paddlex/paddleseg/models/fast_scnn.py +0 -318
- paddlex/paddleseg/models/fcn.py +0 -135
- paddlex/paddleseg/models/gcnet.py +0 -223
- paddlex/paddleseg/models/gscnn.py +0 -357
- paddlex/paddleseg/models/hardnet.py +0 -309
- paddlex/paddleseg/models/isanet.py +0 -202
- paddlex/paddleseg/models/layers/__init__.py +0 -19
- paddlex/paddleseg/models/layers/activation.py +0 -73
- paddlex/paddleseg/models/layers/attention.py +0 -146
- paddlex/paddleseg/models/layers/layer_libs.py +0 -168
- paddlex/paddleseg/models/layers/nonlocal2d.py +0 -155
- paddlex/paddleseg/models/layers/pyramid_pool.py +0 -182
- paddlex/paddleseg/models/losses/__init__.py +0 -27
- paddlex/paddleseg/models/losses/binary_cross_entropy_loss.py +0 -174
- paddlex/paddleseg/models/losses/bootstrapped_cross_entropy.py +0 -73
- paddlex/paddleseg/models/losses/cross_entropy_loss.py +0 -94
- paddlex/paddleseg/models/losses/decoupledsegnet_relax_boundary_loss.py +0 -129
- paddlex/paddleseg/models/losses/dice_loss.py +0 -61
- paddlex/paddleseg/models/losses/edge_attention_loss.py +0 -78
- paddlex/paddleseg/models/losses/gscnn_dual_task_loss.py +0 -141
- paddlex/paddleseg/models/losses/l1_loss.py +0 -76
- paddlex/paddleseg/models/losses/lovasz_loss.py +0 -222
- paddlex/paddleseg/models/losses/mean_square_error_loss.py +0 -65
- paddlex/paddleseg/models/losses/mixed_loss.py +0 -58
- paddlex/paddleseg/models/losses/ohem_cross_entropy_loss.py +0 -99
- paddlex/paddleseg/models/losses/ohem_edge_attention_loss.py +0 -114
- paddlex/paddleseg/models/ocrnet.py +0 -248
- paddlex/paddleseg/models/pspnet.py +0 -147
- paddlex/paddleseg/models/sfnet.py +0 -236
- paddlex/paddleseg/models/shufflenet_slim.py +0 -268
- paddlex/paddleseg/models/u2net.py +0 -574
- paddlex/paddleseg/models/unet.py +0 -155
- paddlex/paddleseg/models/unet_3plus.py +0 -316
- paddlex/paddleseg/models/unet_plusplus.py +0 -237
- paddlex/paddleseg/transforms/__init__.py +0 -16
- paddlex/paddleseg/transforms/functional.py +0 -161
- paddlex/paddleseg/transforms/transforms.py +0 -937
- paddlex/paddleseg/utils/__init__.py +0 -22
- paddlex/paddleseg/utils/config_check.py +0 -60
- paddlex/paddleseg/utils/download.py +0 -163
- paddlex/paddleseg/utils/env/__init__.py +0 -16
- paddlex/paddleseg/utils/env/seg_env.py +0 -56
- paddlex/paddleseg/utils/env/sys_env.py +0 -122
- paddlex/paddleseg/utils/logger.py +0 -48
- paddlex/paddleseg/utils/metrics.py +0 -146
- paddlex/paddleseg/utils/progbar.py +0 -212
- paddlex/paddleseg/utils/timer.py +0 -53
- paddlex/paddleseg/utils/utils.py +0 -120
- paddlex/paddleseg/utils/visualize.py +0 -90
- paddlex/ppcls/__init__.py +0 -20
- paddlex/ppcls/data/__init__.py +0 -15
- paddlex/ppcls/data/imaug/__init__.py +0 -94
- paddlex/ppcls/data/imaug/autoaugment.py +0 -264
- paddlex/ppcls/data/imaug/batch_operators.py +0 -117
- paddlex/ppcls/data/imaug/cutout.py +0 -41
- paddlex/ppcls/data/imaug/fmix.py +0 -217
- paddlex/ppcls/data/imaug/grid.py +0 -89
- paddlex/ppcls/data/imaug/hide_and_seek.py +0 -44
- paddlex/ppcls/data/imaug/operators.py +0 -256
- paddlex/ppcls/data/imaug/randaugment.py +0 -106
- paddlex/ppcls/data/imaug/random_erasing.py +0 -55
- paddlex/ppcls/data/reader.py +0 -318
- paddlex/ppcls/modeling/__init__.py +0 -20
- paddlex/ppcls/modeling/architectures/__init__.py +0 -51
- paddlex/ppcls/modeling/architectures/alexnet.py +0 -132
- paddlex/ppcls/modeling/architectures/darknet.py +0 -161
- paddlex/ppcls/modeling/architectures/densenet.py +0 -308
- paddlex/ppcls/modeling/architectures/distillation_models.py +0 -65
- paddlex/ppcls/modeling/architectures/distilled_vision_transformer.py +0 -196
- paddlex/ppcls/modeling/architectures/dpn.py +0 -425
- paddlex/ppcls/modeling/architectures/efficientnet.py +0 -901
- paddlex/ppcls/modeling/architectures/ghostnet.py +0 -331
- paddlex/ppcls/modeling/architectures/googlenet.py +0 -207
- paddlex/ppcls/modeling/architectures/hrnet.py +0 -742
- paddlex/ppcls/modeling/architectures/inception_v3.py +0 -541
- paddlex/ppcls/modeling/architectures/inception_v4.py +0 -455
- paddlex/ppcls/modeling/architectures/mixnet.py +0 -782
- paddlex/ppcls/modeling/architectures/mobilenet_v1.py +0 -266
- paddlex/ppcls/modeling/architectures/mobilenet_v2.py +0 -248
- paddlex/ppcls/modeling/architectures/mobilenet_v3.py +0 -359
- paddlex/ppcls/modeling/architectures/regnet.py +0 -383
- paddlex/ppcls/modeling/architectures/repvgg.py +0 -339
- paddlex/ppcls/modeling/architectures/res2net.py +0 -272
- paddlex/ppcls/modeling/architectures/res2net_vd.py +0 -295
- paddlex/ppcls/modeling/architectures/resnest.py +0 -705
- paddlex/ppcls/modeling/architectures/resnet.py +0 -317
- paddlex/ppcls/modeling/architectures/resnet_vc.py +0 -309
- paddlex/ppcls/modeling/architectures/resnet_vd.py +0 -354
- paddlex/ppcls/modeling/architectures/resnext.py +0 -259
- paddlex/ppcls/modeling/architectures/resnext101_wsl.py +0 -447
- paddlex/ppcls/modeling/architectures/resnext_vd.py +0 -266
- paddlex/ppcls/modeling/architectures/rexnet.py +0 -240
- paddlex/ppcls/modeling/architectures/se_resnet_vd.py +0 -378
- paddlex/ppcls/modeling/architectures/se_resnext.py +0 -290
- paddlex/ppcls/modeling/architectures/se_resnext_vd.py +0 -285
- paddlex/ppcls/modeling/architectures/shufflenet_v2.py +0 -320
- paddlex/ppcls/modeling/architectures/squeezenet.py +0 -154
- paddlex/ppcls/modeling/architectures/vgg.py +0 -152
- paddlex/ppcls/modeling/architectures/vision_transformer.py +0 -402
- paddlex/ppcls/modeling/architectures/xception.py +0 -345
- paddlex/ppcls/modeling/architectures/xception_deeplab.py +0 -386
- paddlex/ppcls/modeling/loss.py +0 -158
- paddlex/ppcls/modeling/utils.py +0 -53
- paddlex/ppcls/optimizer/__init__.py +0 -19
- paddlex/ppcls/optimizer/learning_rate.py +0 -159
- paddlex/ppcls/optimizer/optimizer.py +0 -165
- paddlex/ppcls/utils/__init__.py +0 -27
- paddlex/ppcls/utils/check.py +0 -151
- paddlex/ppcls/utils/config.py +0 -201
- paddlex/ppcls/utils/logger.py +0 -120
- paddlex/ppcls/utils/metrics.py +0 -112
- paddlex/ppcls/utils/misc.py +0 -62
- paddlex/ppcls/utils/model_zoo.py +0 -213
- paddlex/ppcls/utils/save_load.py +0 -163
- paddlex/ppdet/__init__.py +0 -16
- paddlex/ppdet/core/__init__.py +0 -15
- paddlex/ppdet/core/config/__init__.py +0 -13
- paddlex/ppdet/core/config/schema.py +0 -248
- paddlex/ppdet/core/config/yaml_helpers.py +0 -118
- paddlex/ppdet/core/workspace.py +0 -279
- paddlex/ppdet/data/__init__.py +0 -21
- paddlex/ppdet/data/reader.py +0 -304
- paddlex/ppdet/data/shm_utils.py +0 -67
- paddlex/ppdet/data/source/__init__.py +0 -27
- paddlex/ppdet/data/source/category.py +0 -823
- paddlex/ppdet/data/source/coco.py +0 -243
- paddlex/ppdet/data/source/dataset.py +0 -192
- paddlex/ppdet/data/source/keypoint_coco.py +0 -656
- paddlex/ppdet/data/source/mot.py +0 -360
- paddlex/ppdet/data/source/voc.py +0 -204
- paddlex/ppdet/data/source/widerface.py +0 -180
- paddlex/ppdet/data/transform/__init__.py +0 -28
- paddlex/ppdet/data/transform/autoaugment_utils.py +0 -1593
- paddlex/ppdet/data/transform/batch_operators.py +0 -758
- paddlex/ppdet/data/transform/gridmask_utils.py +0 -83
- paddlex/ppdet/data/transform/keypoint_operators.py +0 -665
- paddlex/ppdet/data/transform/mot_operators.py +0 -636
- paddlex/ppdet/data/transform/op_helper.py +0 -468
- paddlex/ppdet/data/transform/operators.py +0 -2103
- paddlex/ppdet/engine/__init__.py +0 -29
- paddlex/ppdet/engine/callbacks.py +0 -262
- paddlex/ppdet/engine/env.py +0 -47
- paddlex/ppdet/engine/export_utils.py +0 -118
- paddlex/ppdet/engine/tracker.py +0 -425
- paddlex/ppdet/engine/trainer.py +0 -535
- paddlex/ppdet/metrics/__init__.py +0 -23
- paddlex/ppdet/metrics/coco_utils.py +0 -184
- paddlex/ppdet/metrics/json_results.py +0 -151
- paddlex/ppdet/metrics/keypoint_metrics.py +0 -202
- paddlex/ppdet/metrics/map_utils.py +0 -396
- paddlex/ppdet/metrics/metrics.py +0 -300
- paddlex/ppdet/metrics/mot_eval_utils.py +0 -192
- paddlex/ppdet/metrics/mot_metrics.py +0 -184
- paddlex/ppdet/metrics/widerface_utils.py +0 -393
- paddlex/ppdet/model_zoo/__init__.py +0 -18
- paddlex/ppdet/model_zoo/model_zoo.py +0 -86
- paddlex/ppdet/model_zoo/tests/__init__.py +0 -13
- paddlex/ppdet/model_zoo/tests/test_get_model.py +0 -48
- paddlex/ppdet/model_zoo/tests/test_list_model.py +0 -68
- paddlex/ppdet/modeling/__init__.py +0 -41
- paddlex/ppdet/modeling/architectures/__init__.py +0 -40
- paddlex/ppdet/modeling/architectures/cascade_rcnn.py +0 -144
- paddlex/ppdet/modeling/architectures/centernet.py +0 -103
- paddlex/ppdet/modeling/architectures/deepsort.py +0 -111
- paddlex/ppdet/modeling/architectures/fairmot.py +0 -107
- paddlex/ppdet/modeling/architectures/faster_rcnn.py +0 -106
- paddlex/ppdet/modeling/architectures/fcos.py +0 -105
- paddlex/ppdet/modeling/architectures/jde.py +0 -125
- paddlex/ppdet/modeling/architectures/keypoint_hrhrnet.py +0 -286
- paddlex/ppdet/modeling/architectures/keypoint_hrnet.py +0 -203
- paddlex/ppdet/modeling/architectures/mask_rcnn.py +0 -135
- paddlex/ppdet/modeling/architectures/meta_arch.py +0 -45
- paddlex/ppdet/modeling/architectures/s2anet.py +0 -103
- paddlex/ppdet/modeling/architectures/solov2.py +0 -110
- paddlex/ppdet/modeling/architectures/ssd.py +0 -84
- paddlex/ppdet/modeling/architectures/ttfnet.py +0 -98
- paddlex/ppdet/modeling/architectures/yolo.py +0 -104
- paddlex/ppdet/modeling/backbones/__init__.py +0 -37
- paddlex/ppdet/modeling/backbones/blazenet.py +0 -322
- paddlex/ppdet/modeling/backbones/darknet.py +0 -341
- paddlex/ppdet/modeling/backbones/dla.py +0 -244
- paddlex/ppdet/modeling/backbones/ghostnet.py +0 -476
- paddlex/ppdet/modeling/backbones/hrnet.py +0 -724
- paddlex/ppdet/modeling/backbones/mobilenet_v1.py +0 -410
- paddlex/ppdet/modeling/backbones/mobilenet_v3.py +0 -497
- paddlex/ppdet/modeling/backbones/name_adapter.py +0 -69
- paddlex/ppdet/modeling/backbones/res2net.py +0 -358
- paddlex/ppdet/modeling/backbones/resnet.py +0 -606
- paddlex/ppdet/modeling/backbones/senet.py +0 -140
- paddlex/ppdet/modeling/backbones/vgg.py +0 -216
- paddlex/ppdet/modeling/bbox_utils.py +0 -464
- paddlex/ppdet/modeling/heads/__init__.py +0 -41
- paddlex/ppdet/modeling/heads/bbox_head.py +0 -379
- paddlex/ppdet/modeling/heads/cascade_head.py +0 -285
- paddlex/ppdet/modeling/heads/centernet_head.py +0 -194
- paddlex/ppdet/modeling/heads/face_head.py +0 -113
- paddlex/ppdet/modeling/heads/fcos_head.py +0 -270
- paddlex/ppdet/modeling/heads/keypoint_hrhrnet_head.py +0 -108
- paddlex/ppdet/modeling/heads/mask_head.py +0 -253
- paddlex/ppdet/modeling/heads/roi_extractor.py +0 -111
- paddlex/ppdet/modeling/heads/s2anet_head.py +0 -845
- paddlex/ppdet/modeling/heads/solov2_head.py +0 -537
- paddlex/ppdet/modeling/heads/ssd_head.py +0 -175
- paddlex/ppdet/modeling/heads/ttf_head.py +0 -314
- paddlex/ppdet/modeling/heads/yolo_head.py +0 -124
- paddlex/ppdet/modeling/keypoint_utils.py +0 -302
- paddlex/ppdet/modeling/layers.py +0 -1142
- paddlex/ppdet/modeling/losses/__init__.py +0 -35
- paddlex/ppdet/modeling/losses/ctfocal_loss.py +0 -67
- paddlex/ppdet/modeling/losses/fairmot_loss.py +0 -41
- paddlex/ppdet/modeling/losses/fcos_loss.py +0 -225
- paddlex/ppdet/modeling/losses/iou_aware_loss.py +0 -48
- paddlex/ppdet/modeling/losses/iou_loss.py +0 -210
- paddlex/ppdet/modeling/losses/jde_loss.py +0 -182
- paddlex/ppdet/modeling/losses/keypoint_loss.py +0 -228
- paddlex/ppdet/modeling/losses/solov2_loss.py +0 -101
- paddlex/ppdet/modeling/losses/ssd_loss.py +0 -163
- paddlex/ppdet/modeling/losses/yolo_loss.py +0 -212
- paddlex/ppdet/modeling/mot/__init__.py +0 -25
- paddlex/ppdet/modeling/mot/matching/__init__.py +0 -19
- paddlex/ppdet/modeling/mot/matching/deepsort_matching.py +0 -382
- paddlex/ppdet/modeling/mot/matching/jde_matching.py +0 -145
- paddlex/ppdet/modeling/mot/motion/__init__.py +0 -17
- paddlex/ppdet/modeling/mot/motion/kalman_filter.py +0 -270
- paddlex/ppdet/modeling/mot/tracker/__init__.py +0 -23
- paddlex/ppdet/modeling/mot/tracker/base_jde_tracker.py +0 -267
- paddlex/ppdet/modeling/mot/tracker/base_sde_tracker.py +0 -145
- paddlex/ppdet/modeling/mot/tracker/deepsort_tracker.py +0 -165
- paddlex/ppdet/modeling/mot/tracker/jde_tracker.py +0 -262
- paddlex/ppdet/modeling/mot/utils.py +0 -181
- paddlex/ppdet/modeling/mot/visualization.py +0 -130
- paddlex/ppdet/modeling/necks/__init__.py +0 -25
- paddlex/ppdet/modeling/necks/centernet_fpn.py +0 -185
- paddlex/ppdet/modeling/necks/fpn.py +0 -233
- paddlex/ppdet/modeling/necks/hrfpn.py +0 -131
- paddlex/ppdet/modeling/necks/ttf_fpn.py +0 -243
- paddlex/ppdet/modeling/necks/yolo_fpn.py +0 -1034
- paddlex/ppdet/modeling/ops.py +0 -1599
- paddlex/ppdet/modeling/post_process.py +0 -449
- paddlex/ppdet/modeling/proposal_generator/__init__.py +0 -2
- paddlex/ppdet/modeling/proposal_generator/anchor_generator.py +0 -135
- paddlex/ppdet/modeling/proposal_generator/proposal_generator.py +0 -81
- paddlex/ppdet/modeling/proposal_generator/rpn_head.py +0 -269
- paddlex/ppdet/modeling/proposal_generator/target.py +0 -671
- paddlex/ppdet/modeling/proposal_generator/target_layer.py +0 -476
- paddlex/ppdet/modeling/reid/__init__.py +0 -23
- paddlex/ppdet/modeling/reid/fairmot_embedding_head.py +0 -117
- paddlex/ppdet/modeling/reid/jde_embedding_head.py +0 -189
- paddlex/ppdet/modeling/reid/pyramidal_embedding.py +0 -151
- paddlex/ppdet/modeling/reid/resnet.py +0 -320
- paddlex/ppdet/modeling/shape_spec.py +0 -33
- paddlex/ppdet/modeling/tests/__init__.py +0 -13
- paddlex/ppdet/modeling/tests/test_architectures.py +0 -59
- paddlex/ppdet/modeling/tests/test_base.py +0 -75
- paddlex/ppdet/modeling/tests/test_ops.py +0 -839
- paddlex/ppdet/modeling/tests/test_yolov3_loss.py +0 -420
- paddlex/ppdet/optimizer.py +0 -285
- paddlex/ppdet/slim/__init__.py +0 -62
- paddlex/ppdet/slim/distill.py +0 -111
- paddlex/ppdet/slim/prune.py +0 -85
- paddlex/ppdet/slim/quant.py +0 -52
- paddlex/ppdet/utils/__init__.py +0 -13
- paddlex/ppdet/utils/check.py +0 -93
- paddlex/ppdet/utils/checkpoint.py +0 -216
- paddlex/ppdet/utils/cli.py +0 -151
- paddlex/ppdet/utils/colormap.py +0 -56
- paddlex/ppdet/utils/download.py +0 -477
- paddlex/ppdet/utils/logger.py +0 -71
- paddlex/ppdet/utils/stats.py +0 -95
- paddlex/ppdet/utils/visualizer.py +0 -292
- paddlex/ppdet/utils/voc_utils.py +0 -87
- paddlex/seg.py +0 -38
- paddlex/tools/__init__.py +0 -16
- paddlex/tools/convert.py +0 -52
- paddlex/tools/dataset_conversion/__init__.py +0 -24
- paddlex/tools/dataset_conversion/x2coco.py +0 -379
- paddlex/tools/dataset_conversion/x2imagenet.py +0 -82
- paddlex/tools/dataset_conversion/x2seg.py +0 -343
- paddlex/tools/dataset_conversion/x2voc.py +0 -230
- paddlex/tools/dataset_split/__init__.py +0 -23
- paddlex/tools/dataset_split/coco_split.py +0 -69
- paddlex/tools/dataset_split/imagenet_split.py +0 -75
- paddlex/tools/dataset_split/seg_split.py +0 -96
- paddlex/tools/dataset_split/utils.py +0 -75
- paddlex/tools/dataset_split/voc_split.py +0 -91
- paddlex/tools/split.py +0 -41
- paddlex/utils/checkpoint.py +0 -439
- paddlex/utils/shm.py +0 -67
- paddlex/utils/stats.py +0 -68
- paddlex/utils/utils.py +0 -140
- paddlex-2.0.0rc4.dist-info/LICENSE +0 -201
- paddlex-2.0.0rc4.dist-info/METADATA +0 -29
- paddlex-2.0.0rc4.dist-info/RECORD +0 -445
- paddlex-2.0.0rc4.dist-info/WHEEL +0 -5
- paddlex-2.0.0rc4.dist-info/entry_points.txt +0 -3
- paddlex-2.0.0rc4.dist-info/top_level.txt +0 -2
@@ -0,0 +1,2014 @@
|
|
1
|
+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
from __future__ import annotations
|
15
|
+
|
16
|
+
import gc
|
17
|
+
import os
|
18
|
+
import re
|
19
|
+
import warnings
|
20
|
+
from contextlib import contextmanager
|
21
|
+
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
22
|
+
|
23
|
+
import numpy as np
|
24
|
+
import paddle
|
25
|
+
import paddle.nn as nn
|
26
|
+
from paddle import Tensor
|
27
|
+
from paddle.distributed.fleet.meta_parallel.parallel_layers import PipelineLayer
|
28
|
+
|
29
|
+
try:
|
30
|
+
from paddle.distributed.fleet.meta_parallel import LocalSharedLayerDesc
|
31
|
+
except:
|
32
|
+
LocalSharedLayerDesc = None
|
33
|
+
from paddle.nn import Layer
|
34
|
+
|
35
|
+
from ......utils import logging
|
36
|
+
from ...tokenizer.tokenizer_utils import InitTrackerMeta, adapt_stale_fwd_patch
|
37
|
+
from ..generation import GenerationConfig, GenerationMixin
|
38
|
+
from ..utils import (
|
39
|
+
CONFIG_NAME,
|
40
|
+
LEGACY_CONFIG_NAME,
|
41
|
+
PADDLE_WEIGHTS_INDEX_NAME,
|
42
|
+
PADDLE_WEIGHTS_NAME,
|
43
|
+
PYTORCH_WEIGHTS_INDEX_NAME,
|
44
|
+
PYTORCH_WEIGHTS_NAME,
|
45
|
+
SAFE_WEIGHTS_INDEX_NAME,
|
46
|
+
SAFE_WEIGHTS_NAME,
|
47
|
+
device_guard,
|
48
|
+
resolve_file_path,
|
49
|
+
)
|
50
|
+
from .configuration_utils import PretrainedConfig
|
51
|
+
from .conversion_utils import ConversionMixin
|
52
|
+
from .utils import (
|
53
|
+
ContextManagers,
|
54
|
+
fn_args_to_dict,
|
55
|
+
get_checkpoint_shard_files,
|
56
|
+
is_safetensors_available,
|
57
|
+
paddlenlp_load,
|
58
|
+
weight_name_suffix,
|
59
|
+
)
|
60
|
+
|
61
|
+
__all__ = [
|
62
|
+
"PretrainedModel",
|
63
|
+
]
|
64
|
+
|
65
|
+
|
66
|
+
def _add_variant(weights_name: str, variant=None) -> str:
    """Insert a variant tag in front of the last dot-separated part of a file name.

    Args:
        weights_name (str): File name such as ``model_state.pdparams``.
        variant (str, optional): Tag to insert. ``None`` or an empty value
            leaves ``weights_name`` untouched.

    Returns:
        str: The file name with ``variant`` placed before its final
        component, e.g. ``model_state.tp00.pdparams``.
    """
    if variant is None or len(variant) == 0:
        return weights_name

    *leading_parts, last_part = weights_name.split(".")
    return ".".join([*leading_parts, variant, last_part])
|
73
|
+
|
74
|
+
|
75
|
+
@contextmanager
def dtype_guard(dtype="float32"):
    """Run the ``with`` body under a different paddle default dtype.

    The default dtype that was active on entry is restored on exit, also
    when the body raises.

    Args:
        dtype: Value handed to ``paddle.set_default_dtype`` for the duration
            of the block.
    """
    previous_dtype = paddle.get_default_dtype()
    paddle.set_default_dtype(dtype)
    try:
        yield
    finally:
        paddle.set_default_dtype(previous_dtype)
|
83
|
+
|
84
|
+
|
85
|
+
# Module-wide switch read by `PretrainedModel.init_weights`; flipped by `no_init_weights`.
_init_weights = True


@contextmanager
def no_init_weights(_enable=True):
    """
    Context manager that switches off weight initialization module-wide, which
    speeds up loading large models.

    The previous value of the flag is restored on exit, also when the body raises.

    TODO(Patrick): Delete safety argument `_enable=True` at next major version.
    """
    global _init_weights
    saved_flag = _init_weights
    # With `_enable=False` the flag keeps whatever value it had on entry.
    _init_weights = False if _enable else saved_flag
    try:
        yield
    finally:
        _init_weights = saved_flag
|
103
|
+
|
104
|
+
|
105
|
+
def _split_keys_evenly(keys: list, n: int) -> list:
    """Cut a list into ``n`` consecutive chunks whose sizes differ by at most one.

    When ``len(keys)`` is not divisible by ``n`` the leading chunks receive
    one extra element each.

    Args:
        keys (list): the list to be split
        n (int): number of chunks

    Returns:
        list: ``n`` lists that together contain ``keys`` in original order
    """
    quotient, remainder = divmod(len(keys), n)

    chunks = []
    start = 0
    for position in range(n):
        # The first `remainder` chunks absorb the leftover elements.
        stop = start + quotient + (1 if position < remainder else 0)
        chunks.append(keys[start:stop])
        start = stop

    return chunks
|
129
|
+
|
130
|
+
|
131
|
+
def load_state_dict(
    checkpoint_file: Union[str, os.PathLike],
    tensor_parallel_split_mapping=None,
    fliter_dict_keys=None,
    device="cpu",
    ckpt_quant_stage="O0",
):
    """
    Reads a PaddlePaddle checkpoint file through `paddlenlp_load`.

    Args:
        checkpoint_file: Path of the checkpoint file to read.
        tensor_parallel_split_mapping: Accepted for call compatibility; not
            used by this implementation.
        fliter_dict_keys: Accepted for call compatibility; not used by this
            implementation.
        device: Accepted for call compatibility; the file is always loaded
            with ``map_location="cpu"``.
        ckpt_quant_stage: Accepted for call compatibility; not used by this
            implementation.

    Returns:
        The object returned by `paddlenlp_load` for ``checkpoint_file``.
    """
    return paddlenlp_load(checkpoint_file, map_location="cpu")
|
147
|
+
|
148
|
+
|
149
|
+
# Matches a run of digits enclosed by dots (e.g. the ".3." in "layers.3.weight");
# group 1 captures the digits.
_re_layer_prefix = re.compile(r"\.(\d+)\.")
|
150
|
+
|
151
|
+
|
152
|
+
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
    """Load ``state_dict`` into ``model_to_load`` and report the warnings raised.

    Args:
        model_to_load: Object exposing ``state_dict()`` and ``set_state_dict()``.
        state_dict (dict): Checkpoint tensors keyed by parameter name. It is
            mutated in place: dtypes/shapes are aligned with the model and a
            leading ``start_prefix`` is stripped from matching keys.
        start_prefix (str): Prefix to strip from the front of checkpoint keys;
            an empty string disables stripping.

    Returns:
        list[str]: Messages of the warnings emitted by ``set_state_dict`` that
        were not filtered out below.
    """
    # torch will cast dtype in load_state_dict, but paddle strictly checks dtype
    _convert_state_dict_dtype_and_shape(state_dict, model_to_load)

    if len(start_prefix) > 0:
        prefix_len = len(start_prefix)
        # Iterate over a snapshot of the keys because the dict is re-keyed in the loop.
        for key in list(state_dict.keys()):
            if key.startswith(start_prefix):
                # Strip only the leading occurrence. `str.replace` would also
                # delete the prefix text wherever else it appears in the key.
                state_dict[key[prefix_len:]] = state_dict.pop(key)

    error_msgs = []

    # TODO: add return status to state_dict
    with warnings.catch_warnings(record=True) as caught:
        warnings.resetwarnings()
        # paddlenlp hold missing_keys , just ignore not found warnings.
        warnings.filterwarnings(
            "ignore", message=r".*is not found in the provided dict.*"
        )
        warnings.filterwarnings("ignore", message=r".*paddle.to_tensor.*")
        model_to_load.set_state_dict(state_dict)
        error_msgs.extend(str(item.message) for item in caught)

    # Drop this function's reference to the (possibly large) dict.
    del state_dict

    return error_msgs
|
177
|
+
|
178
|
+
|
179
|
+
def _convert_state_dict_dtype_and_shape(state_dict, model_to_load):
    """Align checkpoint tensors with the model's parameter dtypes and scalar shapes.

    For every model parameter that also exists in ``state_dict``:

    * a floating-point checkpoint tensor whose dtype differs from the model's
      is cast to the model's dtype;
    * when both tensors have shape ``[]`` or ``[1]`` but not the same one, the
      checkpoint tensor is reshaped to the model's shape.

    ``state_dict`` is modified in place; converted entries are popped and
    re-inserted.

    Args:
        state_dict (dict): Checkpoint tensors keyed by parameter name.
        model_to_load: Object whose ``state_dict()`` yields the target tensors.

    Raises:
        ValueError: If a matching checkpoint entry is a ``numpy.ndarray``.
    """

    def is_0d_or_1d(tensor):
        # True for shape [] and for shape [1].
        return len(tensor.shape) == 0 or list(tensor.shape) == [1]

    for key, value in model_to_load.state_dict().items():
        # Direct membership test; building `list(state_dict.keys())` for every
        # parameter made this loop quadratic in the number of keys.
        if key not in state_dict:
            continue

        if isinstance(state_dict[key], np.ndarray):
            raise ValueError(
                "convert_state_dict_dtype expected paddle.Tensor not numpy.ndarray, please convert numpy.ndarray to paddle.Tensor"
            )
        # confirm parameter cast is executed on the same device as model
        # TODO: cast(FP32 -> FP16) has diff on different devices, need to fix it
        if (
            state_dict[key].is_floating_point()
            and state_dict[key].dtype != value.dtype
        ):
            state_dict[key] = paddle.cast(state_dict.pop(key), value.dtype)
        # unified 0d and 1d tensor
        if is_0d_or_1d(value) and is_0d_or_1d(state_dict[key]):
            if list(value.shape) != list(state_dict[key].shape):
                state_dict[key] = paddle.reshape(state_dict.pop(key), value.shape)
|
201
|
+
|
202
|
+
|
203
|
+
def _load_state_dict_into_meta_model(
    model,
    state_dict,
    loaded_state_dict_keys,  # left for now but could be removed, see below
    start_prefix,
    expected_keys,
    dtype=None,
    is_safetensors=False,
    keep_in_fp32_modules=None,
):
    """
    This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
    params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
    params back to the normal device, but only for `loaded_state_dict_keys`.

    `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
    `bert.pooler.dense.weight`

    Args:
        model: Model whose `state_dict()` entries receive the loaded data.
        state_dict: Mapping of checkpoint parameter name to tensor.
        loaded_state_dict_keys: Names to load; other entries are skipped.
        start_prefix: Prefix removed from the front of each name before it is looked up in the model.
        expected_keys: Names the model expects; other entries are skipped.
        dtype: Target dtype for floating-point tensors, passed through `convert_np_dtype_to_dtype_`.
        is_safetensors: Not read in this function.
        keep_in_fp32_modules: Optional substrings; a floating-point tensor whose name contains one of them
            is cast to float32 when `dtype` is float16 or bfloat16.

    Returns:
        The `error_msgs` list, which nothing in this function appends to.
    """
    from paddle.common_ops_import import convert_np_dtype_to_dtype_

    # NOTE(review): the conversion also runs when `dtype` is None; confirm the helper passes None
    # through, otherwise the `dtype is None` branch further down cannot be reached.
    dtype = convert_np_dtype_to_dtype_(dtype)
    error_msgs = []
    model_state_dict = model.state_dict()
    for param_name, param in state_dict.items():
        # First part of the test is always true as loaded_state_dict_keys always contains state_dict keys.
        if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
            continue

        # The membership tests above use the unstripped checkpoint name; the model lookup below uses
        # the name without the leading prefix.
        if param_name.startswith(start_prefix):
            param_name = param_name[len(start_prefix) :]

        # Copy the tensor to the currently expected place when it lives elsewhere.
        if param.place != paddle.framework._current_expected_place():
            param = param._copy_to(paddle.framework._current_expected_place(), False)

        # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
        # in int/uint/bool and not cast them.
        if dtype is not None and paddle.is_floating_point(param):
            if (
                keep_in_fp32_modules is not None
                and any(
                    module_to_keep_in_fp32 in param_name
                    for module_to_keep_in_fp32 in keep_in_fp32_modules
                )
                and (dtype == paddle.float16 or dtype == paddle.bfloat16)
            ):
                param = param.astype(dtype=paddle.float32)
            else:
                param = param.astype(dtype=dtype)

        if dtype is None:
            # Walk the dotted name attribute by attribute to find the existing parameter and adopt
            # its dtype.
            old_param = model
            splits = param_name.split(".")
            for split in splits:
                old_param = getattr(old_param, split)
                if old_param is None:
                    break

            if old_param is not None:
                param = param.astype(dtype=old_param.dtype)
        with paddle.no_grad():
            # Make the model's tensor share the loaded tensor's data, then clear the source tensor.
            model_state_dict[param_name].get_tensor()._share_data_with(
                param.value().get_tensor()
            )
            param.value().get_tensor()._clear()
    return error_msgs
|
269
|
+
|
270
|
+
|
271
|
+
class PretrainedModel(
|
272
|
+
Layer, GenerationMixin, ConversionMixin, metaclass=InitTrackerMeta
|
273
|
+
):
|
274
|
+
"""
|
275
|
+
The base class for all pretrained models. It mainly provides common methods
|
276
|
+
for loading (construction and loading) and saving pretrained models. Loading
|
277
|
+
and saving also rely on the following class attributes which should be overridden
|
278
|
+
by derived classes accordingly:
|
279
|
+
|
280
|
+
- **model_config_file** (str): Represents the file name of model configuration
|
281
|
+
for configuration saving and loading in local file system. The value is
|
282
|
+
`model_config.json`.
|
283
|
+
- **resource_files_names** (dict): Name of local file where the model configuration
|
284
|
+
can be saved and loaded locally. Currently, resources only include the model state,
|
285
|
+
thus the dict only includes `'model_state'` as key with corresponding
|
286
|
+
value `'model_state.pdparams'` for model weights saving and loading.
|
287
|
+
- **pretrained_init_configuration** (dict): Provides the model configurations
|
288
|
+
of built-in pretrained models (contrasts to models in local file system).
|
289
|
+
It has pretrained model names as keys (such as `bert-base-uncased`), and
|
290
|
+
the values are dict preserving corresponding configuration for model initialization.
|
291
|
+
- **pretrained_resource_files_map** (dict): Provides resource URLs of built-in
|
292
|
+
pretrained models (contrasts to models in local file system).
|
293
|
+
It has the same key as resource_files_names (that is "model_state"),
|
294
|
+
and the corresponding value is a dict with specific model name to model weights URL mapping
|
295
|
+
(such as "bert-base-uncased" ->
|
296
|
+
"https://bj.bcebos.com/paddlenlp/models/transformers/bert-base-uncased.pdparams").
|
297
|
+
- **base_model_prefix** (str): Represents the attribute associated to the
|
298
|
+
base model in derived classes of the same architecture adding layers on
|
299
|
+
top of the base model. Note: A base model class is pretrained model class
|
300
|
+
decorated by `register_base_model`, such as `BertModel`; A derived model
|
301
|
+
class is a pretrained model class adding layers on top of the base model,
|
302
|
+
and it has a base model as attribute, such as `BertForSequenceClassification`.
|
303
|
+
|
304
|
+
Methods common to models for text generation are defined in `GenerationMixin`
|
305
|
+
and also inherited here.
|
306
|
+
|
307
|
+
Besides, metaclass `InitTrackerMeta` is used to create `PretrainedModel`,
|
308
|
+
by which subclasses can track arguments for initialization automatically.
|
309
|
+
"""
|
310
|
+
|
311
|
+
# Deprecated(wj-Mcat): after 2.6.* version
|
312
|
+
# save the old-school `LEGACY_CONFIG_NAME`, and will be changed to `CONFIG_NAME` after 2.6.* version
|
313
|
+
model_config_file = LEGACY_CONFIG_NAME
|
314
|
+
|
315
|
+
pretrained_init_configuration = {}
|
316
|
+
# TODO: more flexible resource handle, namedtuple with fields as:
|
317
|
+
# resource_name, saved_file, handle_name_for_load(None for used as __init__
|
318
|
+
# arguments), handle_name_for_save
|
319
|
+
resource_files_names = {"model_state": PADDLE_WEIGHTS_NAME}
|
320
|
+
pretrained_resource_files_map = {}
|
321
|
+
base_model_prefix = ""
|
322
|
+
main_input_name = "input_ids"
|
323
|
+
config_class = None
|
324
|
+
_keep_in_fp32_modules = None
|
325
|
+
|
326
|
+
# a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
|
327
|
+
# keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
|
328
|
+
_keys_to_ignore_on_load_missing = None
|
329
|
+
# a list of `re` patterns of `state_dict` keys that should be removed from the list of
|
330
|
+
# unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary
|
331
|
+
# warnings.
|
332
|
+
_keys_to_ignore_on_load_unexpected = None
|
333
|
+
# a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
|
334
|
+
# trained, but which are either deterministic or tied variables)
|
335
|
+
_keys_to_ignore_on_save = None
|
336
|
+
_tied_weights_keys = None
|
337
|
+
|
338
|
+
def __init__(self, *args, **kwargs):
    """Initialize the layer and attach the model configuration.

    For classes whose ``config_class`` is a ``PretrainedConfig`` subclass, the
    config is taken from the first positional argument that is a
    ``PretrainedConfig`` instance, or else from the ``config`` keyword.
    Other classes return right after the ``Layer`` initialization.

    Raises:
        ValueError: If no positional ``PretrainedConfig`` is given and the
            ``config`` keyword is missing.
        TypeError: If the ``config`` keyword is not a ``PretrainedConfig``.
    """
    super(PretrainedModel, self).__init__()

    if not self.constructed_from_pretrained_config():
        return

    # extract config from args, falling back to kwargs
    config = next((arg for arg in args if isinstance(arg, PretrainedConfig)), None)
    if config is None:
        if "config" not in kwargs:
            raise ValueError(
                "PretrainedConfig instance not found in the arguments, you can set it as args or kwargs with config field"
            )

        config = kwargs["config"]
        if not isinstance(config, PretrainedConfig):
            raise TypeError(
                "config parameter should be the instance of PretrainedConfig"
            )

    self.config: PretrainedConfig = config
    self.model_config_file = CONFIG_NAME
    self.generation_config = (
        GenerationConfig.from_model_config(self.config)
        if self.can_generate()
        else None
    )
    # Set on both paths; previously a positionally passed config left this
    # attribute undefined.
    self.warnings_issued = {}
|
380
|
+
|
381
|
+
def _post_init(self, original_init, *args, **kwargs):
    """
    It would be hooked after `__init__` to add a dict including arguments of
    `__init__` as a attribute named `config` of the pretrained model instance
    (only for models that are not constructed from a `PretrainedConfig`),
    and to run `init_weights()` when the class does not override it.

    Args:
        original_init: The `__init__` function this hook runs after.
        *args: Positional arguments that were passed to `original_init`.
        **kwargs: Keyword arguments that were passed to `original_init`.
    """
    if not self.constructed_from_pretrained_config():
        init_dict = fn_args_to_dict(original_init, *((self,) + args), **kwargs)
        self.config = init_dict

    # Skip weight initialization when the hook fires for an `__init__` defined
    # in this base module. The check used to compare against the upstream
    # module path only, which never matches once this file lives in another
    # package; the upstream name is still accepted.
    # NOTE(review): assumes the hook also runs after the base-class `__init__`;
    # confirm against `InitTrackerMeta`.
    base_modules = (__name__, "paddlenlp.transformers.model_utils")
    if (
        original_init.__module__ not in base_modules
        and self.__class__.init_weights is PretrainedModel.init_weights
    ):
        self.init_weights()

    # Note:
    # 1. PipelineLayer will create parameters for each layer and
    # call `_synchronize_shared_weights()` to synchronize the shared parameters.
    # 2. When setting the model `state_dict`, `_synchronize_shared_weights` will be called to
    # synchronize the shared parameters.
    # However, `self._init_weights` will re-initialize the parameters without
    # synchronizing the shared parameters. If the following step does not load a checkpoint,
    # the shared parameters will be different.

    if isinstance(self, PipelineLayer):
        self._synchronize_shared_weights()
|
408
|
+
|
409
|
+
def _init_weights(self, layer):
    """
    Hook that initializes the weights of a single ``layer``.

    The base implementation does nothing; derived classes override it.
    """
|
414
|
+
|
415
|
+
def _initialize_weights(self, layer):
    """
    Run `_init_weights` on ``layer`` unless it is already flagged as initialized,
    then set its ``_is_initialized`` flag.
    """
    already_initialized = getattr(layer, "_is_initialized", False)
    if not already_initialized:
        self._init_weights(layer)
        layer._is_initialized = True
|
423
|
+
|
424
|
+
def init_weights(self):
    """
    Apply `_initialize_weights` to the model's layers via `self.apply`, unless
    weight initialization is switched off module-wide. If using a custom
    `PreTrainedModel`, you need to implement any initialization logic in `_init_weights`.
    """
    # The bare name `_init_weights` is the module-level flag toggled by
    # `no_init_weights`, not the method of the same name.
    if not _init_weights:
        return

    # Initialize weights
    self.apply(self._initialize_weights)

    # Tie weights should be skipped when not initializing all weights
    # since from_pretrained(...) calls tie weights anyways

    # TODO(wj-Mcat): enable all tie-weights later
    # self.tie_weights()
|
439
|
+
|
440
|
+
@classmethod
def _from_config(cls, config, **kwargs):
    """
    Build a model from ``config`` while a chosen dtype is the paddle default.

    Args:
        config: Configuration handed to the class constructor.
        dtype (`paddle.dtype`, *optional*):
            Popped from ``kwargs``. When omitted, ``config.dtype`` is used, and
            if that is ``None`` too, the current paddle default dtype.
        **kwargs: Remaining keyword arguments are forwarded to the constructor.

    Returns:
        The constructed model instance.
    """
    requested_dtype = kwargs.pop("dtype", None)
    if requested_dtype is None:
        requested_dtype = (
            config.dtype if config.dtype is not None else paddle.get_default_dtype()
        )

    with dtype_guard(requested_dtype):
        return cls(config, **kwargs)
|
461
|
+
|
462
|
+
@classmethod
def from_config(cls, config, **kwargs):
    """
    Public entry point that delegates to `_from_config`.

    Args:
        config: Configuration handed to the class constructor.
        dtype (`paddle.dtype`, *optional*):
            Override the default `paddle.dtype` and load the model under this dtype.

    Returns:
        The model instance returned by `_from_config`.
    """
    model = cls._from_config(config, **kwargs)
    return model
|
472
|
+
|
473
|
+
@classmethod
def set_inference_config(cls, config, predictor_args, **kwargs):
    """
    Copy inference settings from ``predictor_args`` onto ``config`` (in place).

    NOTE(review): the block nesting below was reconstructed from a flattened
    listing of this file; confirm it against the upstream module.

    Args:
        config : PretrainedConfig
            The config of the model. Its attributes are overwritten.
        predictor_args : PredictorArgument
            The args of the predictor. ``weight_block_size``, ``quant_type`` and
            ``cachekv_int8_type`` may be overwritten from the model's quantization config.
        **kwargs:
            ``tensor_parallel_degree`` (default 1) and ``tensor_parallel_rank``
            (default 0) are popped from here.
    """
    tensor_parallel_degree = kwargs.pop("tensor_parallel_degree", 1)
    tensor_parallel_rank = kwargs.pop("tensor_parallel_rank", 0)

    if predictor_args.mode == "dynamic" or predictor_args.speculate_method in [
        "eagle",
        "mtp",
    ]:
        config.tensor_parallel_degree = tensor_parallel_degree
        config.tensor_parallel_rank = tensor_parallel_rank
        config.model_name_or_path = predictor_args.model_name_or_path
        config.quant_type = predictor_args.quant_type
        config.cachekv_int8_type = predictor_args.cachekv_int8_type
        config.use_fake_parameter = predictor_args.use_fake_parameter
        config.single_card_ptq = not predictor_args.use_fake_parameter
    config.append_attn = predictor_args.append_attn
    config.decode_strategy = predictor_args.decode_strategy
    config.mla_use_matrix_absorption = predictor_args.mla_use_matrix_absorption
    config.weightonly_group_size = predictor_args.weightonly_group_size
    config.weight_block_size = predictor_args.weight_block_size
    config.moe_quant_type = predictor_args.moe_quant_type
    # A quantization method recorded in the model config takes precedence for the block size.
    if config.quantization_config.quant_method is not None:
        predictor_args.weight_block_size = (
            config.quantization_config.weight_block_size
        )
        config.weight_block_size = predictor_args.weight_block_size

    if config.quantization_config.quant_type is not None:
        if predictor_args.mode == "dynamic":
            predictor_args.quant_type = config.quantization_config.quant_type
            config.quant_type = config.quantization_config.quant_type
        # A quant type containing "c8" switches the cache-KV int8 type to "static".
        if "c8" in config.quant_type:
            predictor_args.cachekv_int8_type = "static"
            if predictor_args.mode == "dynamic":
                config.cachekv_int8_type = "static"

    if predictor_args.mode == "dynamic":
        # Count the files in the model directory whose name starts with "act_scales_".
        ptq_multicards_num = 0
        if os.path.exists(config.model_name_or_path):
            prefix = "act_scales_"
            for filename in os.listdir(config.model_name_or_path):
                if filename.startswith(prefix):
                    ptq_multicards_num += 1

        logging.info(
            f"PTQ from {ptq_multicards_num} cards, so we will not split"
        )
        if ptq_multicards_num > 1:
            config.single_card_ptq = False

    if predictor_args.block_attn:
        config.block_size = predictor_args.block_size
        config.max_seq_len = predictor_args.total_max_length

    if predictor_args.speculate_method is not None:
        config.speculate_method = predictor_args.speculate_method
        config.speculate_max_draft_token_num = (
            predictor_args.speculate_max_draft_token_num
        )
        config.speculate_verify_window = predictor_args.speculate_verify_window
        config.speculate_max_candidate_len = (
            predictor_args.speculate_max_candidate_len
        )
        if predictor_args.speculate_method == "inference_with_reference":
            config.speculate_max_ngram_size = (
                predictor_args.speculate_max_ngram_size
            )
    if predictor_args.speculate_method is not None:
        # Unless the config's speculate model type is "eagle" or "mtp", force the
        # speculative decode strategy.
        if not config.get("speculate_model_type", "None") in ["eagle", "mtp"]:
            config.decode_strategy = "speculate_decoding"
    config.return_full_hidden_states = predictor_args.return_full_hidden_states
|
553
|
+
|
554
|
+
@classmethod
def confirm_inference_model(cls, predictor_args, **kwargs):
    """
    Return the class to use for inference.

    The base implementation ignores its arguments and returns ``cls`` unchanged.

    Args:
        predictor_args : PredictorArgument
            The args of the predictor.
        **kwargs: Not read by the base implementation.

    Returns:
        The class this method was called on.
    """
    return cls
|
565
|
+
|
566
|
+
@property
def base_model(self):
    """
    PretrainedModel: The attribute named by `base_model_prefix`, or the model
    itself when no such attribute exists.
    """
    prefix = self.base_model_prefix
    return getattr(self, prefix, self)
|
574
|
+
|
575
|
+
@property
def model_name_list(self):
    """
    list: The keys of `pretrained_init_configuration`, i.e. the built-in
    pretrained model names registered on the current class.
    """
    # Todo: return all model name
    registered_names = self.pretrained_init_configuration.keys()
    return list(registered_names)
|
583
|
+
|
584
|
+
def can_generate(self) -> bool:
    """
    Tell whether this model can generate sequences with `.generate()`.

    Returns:
        `bool`: `False` when the string form of `prepare_inputs_for_generation`
        still mentions `GenerationMixin` (the method was not overwritten),
        `True` otherwise.
    """
    # Overwriting `prepare_inputs_for_generation` is a requirement for generation.
    return "GenerationMixin" not in str(self.prepare_inputs_for_generation)
|
594
|
+
|
595
|
+
def recompute_enable(self):
    r"""
    Enable Recompute.
    Every layer visited by `self.apply` whose `enable_recompute` attribute is
    `False` or `0` gets it set to `True`.
    """

    def _switch_on(sublayer):
        if not hasattr(sublayer, "enable_recompute"):
            return
        if sublayer.enable_recompute is False or sublayer.enable_recompute == 0:
            sublayer.enable_recompute = True

    self.apply(_switch_on)
|
608
|
+
|
609
|
+
def recompute_disable(self):
    r"""
    Disable Recompute.
    Every layer visited by `self.apply` whose `enable_recompute` attribute is
    `True` or `1` gets it set to `False`. Layers without the attribute are
    left untouched.
    """

    def fn(layer):
        # Mirror image of `recompute_enable`; the previous body was a copy of
        # it and turned recompute on instead of off.
        if hasattr(layer, "enable_recompute") and (
            layer.enable_recompute is True or layer.enable_recompute == 1
        ):
            layer.enable_recompute = False

    self.apply(fn)
|
622
|
+
|
623
|
+
def tie_weights(self):
    """
    Tie the weights between the input embeddings and the output embeddings.

    Nothing happens unless `self.config.tie_word_embeddings` is truthy and both
    `get_output_embeddings()` and `get_input_embeddings()` return a layer.
    After the weight is shared, an existing output bias is zero-padded or
    truncated so that its length equals the first dimension of the shared weight.
    """
    if not self.config.tie_word_embeddings:
        return

    output_embeddings = self.get_output_embeddings()
    input_embeddings = self.get_input_embeddings()
    if output_embeddings is None or input_embeddings is None:
        return

    if input_embeddings.weight.shape != output_embeddings.weight.shape:
        logging.warning(
            f"The shape of input embeddings is {input_embeddings.weight.shape} and the shape of output embeddings is {output_embeddings.weight.shape}. "
            "This is only expected if you are calling the `resize_token_embeddings` method"
        )
    output_embeddings.weight = input_embeddings.weight

    if getattr(output_embeddings, "bias", None) is None:
        return

    def _fresh_bias():
        # New bias parameter sized to the first dimension of the tied weight.
        return output_embeddings.create_parameter(
            shape=[output_embeddings.weight.shape[0]],
            attr=output_embeddings._bias_attr,
            dtype=output_embeddings._dtype,
            is_bias=True,
        )

    num_rows = output_embeddings.weight.shape[0]
    old_bias = output_embeddings.bias
    old_len = old_bias.shape[0]

    if num_rows > old_len:
        # need to pad
        output_embeddings.bias = _fresh_bias()
        padding = paddle.zeros(
            [num_rows - old_len], dtype=output_embeddings.bias.dtype
        )
        output_embeddings.bias.set_value(paddle.concat([old_bias, padding]))
    elif num_rows < old_len:
        # need to trim
        trimmed = old_bias[:num_rows]
        output_embeddings.bias = _fresh_bias()
        output_embeddings.bias.set_value(trimmed)
|
677
|
+
|
678
|
+
def resize_position_embeddings(self, new_num_position_embeddings: int):
    """Resize the position embeddings; downstream models are expected to overwrite this method.

    Args:
        new_num_position_embeddings (int): the new position size

    Raises:
        NotImplementedError: always, in this base implementation
    """
    model_cls = self.__class__
    message = (
        f"`resize_position_embeddings` is not implemented for {model_cls}`. To implement it, you should "
        f"overwrite this method in the class {model_cls} in `{model_cls.__module__}.py`"
    )
    raise NotImplementedError(message)
|
691
|
+
|
692
|
+
@classmethod
def constructed_from_pretrained_config(cls, init_func=None) -> bool:
    """Check whether the class declares a `PretrainedConfig` subclass as its `config_class`.

    Args:
        init_func: Not read by this method.

    Returns:
        bool: `True` if `config_class` is set and is a subclass of `PretrainedConfig`
    """
    config_cls = cls.config_class
    if config_cls is None:
        return False
    return issubclass(config_cls, PretrainedConfig)
|
701
|
+
|
702
|
+
def resize_token_embeddings(
    self, new_num_tokens: Optional[int] = None
) -> nn.Embedding:
    """
    Resize the model's input token embedding matrix to `new_num_tokens` rows.

    Args:
        new_num_tokens (Optional[int]):
            The number of tokens the embedding matrix should hold. Increasing the size will add newly
            initialized vectors at the end. Reducing the size will remove vectors from the end. When it is
            falsy (not provided, None or 0) or equal to the current size, the current input embedding
            module is returned without doing anything.

    Returns:
        paddle.nn.Embedding: The input tokens Embeddings Module of the model.
    """
    current_embeddings: nn.Embedding = self.get_input_embeddings()
    nothing_to_do = (
        not new_num_tokens or new_num_tokens == current_embeddings.weight.shape[0]
    )
    if nothing_to_do:
        return current_embeddings

    # Swap in the resized embedding module.
    resized_embeddings = self._get_resized_embeddings(
        current_embeddings, new_num_tokens
    )
    self.set_input_embeddings(resized_embeddings)

    # Record the new vocabulary size on the base model config and on the model.
    self.base_model.config["vocab_size"] = new_num_tokens
    self.vocab_size = new_num_tokens

    # update init_config
    self._update_init_config(self.init_config, "vocab_size", new_num_tokens)

    # Tie the weights between the input embeddings and the output embeddings if needed.
    self.tie_weights()

    return resized_embeddings
|
735
|
+
|
736
|
+
def _update_init_config(self, init_config: dict, key: str, value: Any):
    """Set ``key`` to ``value`` in ``init_config`` or in a nested model's init config.

    If ``init_config`` already has ``key`` it is overwritten and the search
    stops. Otherwise the update is retried on the ``init_config`` of every
    `PretrainedModel` found in ``init_config["init_args"]``.

    Args:
        init_config (dict): the init_config instance
        key (str): the key field
        value (Any): the new value of instance
    """
    if key in init_config:
        init_config[key] = value
        return

    nested_models = (
        arg
        for arg in init_config.get("init_args", [])
        if isinstance(arg, PretrainedModel)
    )
    for nested in nested_models:
        self._update_init_config(nested.init_config, key, value)
|
752
|
+
|
753
|
+
def _get_resized_embeddings(
    self, old_embeddings: nn.Embedding, new_num_tokens: Optional[int] = None
) -> nn.Embedding:
    """Return an embedding layer with ``new_num_tokens`` rows seeded from ``old_embeddings``.

    Rows shared by both sizes are copied over. When growing, the extra rows
    keep whatever values ``nn.Embedding`` initialised them with; when
    shrinking, the trailing rows of the old matrix are dropped.

    Args:
        old_embeddings (nn.Embedding):
            The embedding layer to resize.
        new_num_tokens (Optional[int]):
            Desired number of rows. ``None`` means "leave unchanged".

    Returns:
        paddle.nn.Embedding: ``old_embeddings`` itself when ``new_num_tokens``
        is ``None`` or already equals the current row count, otherwise a
        newly built layer.

    Raises:
        TypeError: If a resize is actually required and ``old_embeddings`` is
            not an ``nn.Embedding``.
    """
    if new_num_tokens is None:
        return old_embeddings

    current_rows, embedding_dim = old_embeddings.weight.shape
    if current_rows == new_num_tokens:
        return old_embeddings

    # The type is only checked once a real resize is needed, so the two
    # early returns above hand back any object exposing ``weight.shape``.
    if not isinstance(old_embeddings, nn.Embedding):
        raise TypeError(
            f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You"
            " should either use a different resize function or make sure that old_embeddings are an instance of"
            f" {nn.Embedding}."
        )

    resized = nn.Embedding(
        new_num_tokens,
        embedding_dim,
        padding_idx=old_embeddings._padding_idx,
        sparse=old_embeddings._sparse,
    )

    # Keep the new layer in the same dtype as the one it replaces.
    source_dtype = old_embeddings.weight.dtype
    if resized.weight.dtype != source_dtype:
        resized.to(dtype=source_dtype)

    # Copy only the rows that exist in both the old and the new matrix.
    rows_to_copy = min(current_rows, new_num_tokens)
    with paddle.no_grad():
        resized.weight[:rows_to_copy, :] = old_embeddings.weight[:rows_to_copy, :]

    return resized
|
803
|
+
|
804
|
+
def __setattr__(self, name, value):
    """Set attribute ``name`` after passing ``value`` through ``adapt_stale_fwd_patch``.

    Whatever ``adapt_stale_fwd_patch(self, name, value)`` returns is what
    gets stored, using the ``__setattr__`` of the class after
    ``PretrainedModel`` in the MRO.
    """
    patched_value = adapt_stale_fwd_patch(self, name, value)
    return super(PretrainedModel, self).__setattr__(name, patched_value)
|
807
|
+
|
808
|
+
@classmethod
def _resolve_model_file_path(
    cls: Type[PretrainedModel],
    pretrained_model_name_or_path: str,
    from_hf_hub: bool = False,
    from_aistudio: bool = False,
    cache_dir: str | None = None,
    subfolder: Optional[str] = "",
    config: PretrainedConfig = None,
    convert_from_torch: bool = False,
    use_safetensors: bool | None = None,
    variant=None,
) -> Tuple[Optional[str], Any, Any, bool]:
    """Resolve the weight file for `pretrained_model_name_or_path`.

    Resolution order:

    1. The argument is an existing file: it is used as the weight file.

    2. The argument is an existing directory: the first file found in
       ``<dir>/<subfolder>`` wins, probing in this order (each native name is
       tried with ``variant`` first, then with ``weight_name_suffix()``):

       2.1 safetensors index, then safetensors weights (both skipped when
           ``use_safetensors`` is ``False``)
       2.2 Paddle index, then Paddle weights
       2.3 PyTorch index, then PyTorch weights (``variant`` only). These are
           accepted only when ``from_hf_hub`` or ``convert_from_torch`` is
           set; otherwise a ``ValueError`` is raised.

    3. The argument is a key of ``cls.pretrained_init_configuration``: the
       URL is read from ``cls.pretrained_resource_files_map["model_state"]``
       and handed to ``resolve_file_path``.

    4. Anything else: a list of candidate file names (depending on
       ``use_safetensors``) is handed to ``resolve_file_path``.

    Args:
        cls (Type[PretrainedModel]): the inherited PretrainedModel class
        pretrained_model_name_or_path (str): model-name / url / local file /
            local dir. ``None`` short-circuits to an all-empty result.
        from_hf_hub (bool, optional): forwarded to the resolver helpers; also
            allows PyTorch weights to be picked up. Defaults to False.
        from_aistudio (bool, optional): forwarded to the resolver helpers.
            Defaults to False.
        cache_dir (Optional[str], optional): forwarded to the resolver
            helpers. Defaults to None.
        subfolder (Optional[str], optional): sub-directory holding the
            weights; ``None`` is treated as ``""``. Defaults to "".
        config (PretrainedConfig, optional): not used by this method.
        convert_from_torch (bool, optional): whether PyTorch weights may be
            selected. Defaults to False.
        use_safetensors (Optional[bool], optional): ``True`` restricts remote
            lookup to safetensors names, ``False`` skips safetensors
            entirely, ``None`` tries everything. Defaults to None.
        variant (optional): passed to ``_add_variant`` when building file
            names. Defaults to None.

    Returns:
        tuple: ``(resolved_archive_file, resolved_sharded_files,
        sharded_metadata, is_sharded)``. ``resolved_sharded_files`` and
        ``sharded_metadata`` are ``None`` unless the checkpoint is sharded,
        in which case they are what ``get_checkpoint_shard_files`` returned.

    Raises:
        ValueError: PyTorch weights were found but neither ``from_hf_hub``
            nor ``convert_from_torch`` is set.
        EnvironmentError: no weight file could be found.
    """
    is_sharded = False
    sharded_metadata = None

    if pretrained_model_name_or_path is not None:
        # the following code use a lot of os.path.join, hence setting subfolder to empty str if None
        if subfolder is None:
            subfolder = ""
        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        is_local = os.path.isdir(pretrained_model_name_or_path)

        def get_file_path(model_dir, sub_dir, weights_name, suffix):
            """Build ``<model_dir>/<sub_dir>/<weights_name with suffix>``."""
            return os.path.join(
                model_dir, sub_dir, _add_variant(weights_name, suffix)
            )

        def find_native_checkpoint():
            """Return ``(path, is_sharded)`` of the first safetensors/Paddle file, else ``(None, False)``."""
            native_weights = []
            if use_safetensors is not False:
                native_weights += [
                    (SAFE_WEIGHTS_INDEX_NAME, True),
                    (SAFE_WEIGHTS_NAME, False),
                ]
            native_weights += [
                (PADDLE_WEIGHTS_INDEX_NAME, True),
                (PADDLE_WEIGHTS_NAME, False),
            ]
            for weights_name, sharded in native_weights:
                # `weight_name_suffix` stays a callable so that, as before,
                # it is only invoked once the `variant` probe has failed.
                for get_suffix in (lambda: variant, weight_name_suffix):
                    candidate = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        weights_name,
                        get_suffix(),
                    )
                    if os.path.isfile(candidate):
                        return candidate, sharded
            return None, False

        # pretrained_model_name_or_path is file
        if os.path.isfile(pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
            is_local = True
        # pretrained_model_name_or_path is dir
        elif is_local:
            archive_file, is_sharded = find_native_checkpoint()
            if archive_file is None:
                # Fall back to PyTorch weights. A PyTorch index is not marked
                # sharded here; the ".json" check further down handles it.
                for torch_name in (PYTORCH_WEIGHTS_INDEX_NAME, PYTORCH_WEIGHTS_NAME):
                    candidate = get_file_path(
                        pretrained_model_name_or_path,
                        subfolder,
                        torch_name,
                        variant,
                    )
                    if not os.path.isfile(candidate):
                        continue
                    if not (from_hf_hub or convert_from_torch):
                        raise ValueError(
                            f"Found {_add_variant(torch_name, variant)} in directory"
                            f" {pretrained_model_name_or_path}. Please set convert_from_torch=True in from_pretrained. eg, Model.from_pretrained(model_name, convert_from_torch=True) "
                        )
                    archive_file = candidate
                    break
                else:
                    raise EnvironmentError(
                        f"Error no file named {_add_variant(PADDLE_WEIGHTS_NAME, variant)}, found in directory"
                        f" {pretrained_model_name_or_path}."
                    )

        elif pretrained_model_name_or_path in cls.pretrained_init_configuration:
            # fetch the weight url from the `pretrained_resource_files_map`
            resource_file_url = cls.pretrained_resource_files_map["model_state"][
                pretrained_model_name_or_path
            ]
            resolved_archive_file = resolve_file_path(
                pretrained_model_name_or_path,
                [resource_file_url],
                subfolder,
                cache_dir=cache_dir,
                from_aistudio=from_aistudio,
                from_hf_hub=from_hf_hub,
            )
        else:
            if use_safetensors is True:
                filenames = [
                    _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
                    _add_variant(SAFE_WEIGHTS_NAME, variant),
                ]
            elif use_safetensors is None:
                filenames = [
                    _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
                    _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant),
                    _add_variant(SAFE_WEIGHTS_NAME, variant),
                    _add_variant(PADDLE_WEIGHTS_NAME, variant),
                    _add_variant(PYTORCH_WEIGHTS_INDEX_NAME, variant),
                    _add_variant(PYTORCH_WEIGHTS_NAME, variant),
                ]
            else:
                filenames = [
                    _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant),
                    _add_variant(PADDLE_WEIGHTS_NAME, variant),
                    _add_variant(PYTORCH_WEIGHTS_INDEX_NAME, variant),
                    _add_variant(PYTORCH_WEIGHTS_NAME, variant),
                ]
            resolved_archive_file = resolve_file_path(
                pretrained_model_name_or_path,
                filenames,
                subfolder,
                cache_dir=cache_dir,
                from_aistudio=from_aistudio,
                from_hf_hub=from_hf_hub,
            )
            if resolved_archive_file is None:
                raise EnvironmentError(
                    f"Error no files {filenames} found in repo {pretrained_model_name_or_path}."
                )
            elif "pytorch_model.bin" in str(resolved_archive_file):
                if not from_hf_hub and not convert_from_torch:
                    raise ValueError(
                        f"Download pytorch weight in "
                        f" {resolved_archive_file}. Please set convert_from_torch=True in from_pretrained. eg, Model.from_pretrained(model_name, convert_from_torch=True) "
                    )

        if is_local:
            logging.info(f"Loading weights file {archive_file}")
            resolved_archive_file = archive_file
        else:
            logging.info(
                f"Loading weights file from cache at {resolved_archive_file}"
            )
    else:
        resolved_archive_file = None

    # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
    resolved_sharded_files = None
    # Any resolved ".json" file is treated as a shard index.
    if str(resolved_archive_file).endswith(".json"):
        is_sharded = True
    if is_sharded:
        # resolved_sharded_files is the list of shard files listed by the index in this case.
        resolved_sharded_files, sharded_metadata = get_checkpoint_shard_files(
            pretrained_model_name_or_path,
            resolved_archive_file,
            from_aistudio=from_aistudio,
            from_hf_hub=from_hf_hub,
            cache_dir=cache_dir,
            subfolder=subfolder,
        )

    return (
        resolved_archive_file,
        resolved_sharded_files,
        sharded_metadata,
        is_sharded,
    )
|
1118
|
+
|
1119
|
+
@classmethod
def _load_pretrained_model(
    cls,
    model: PretrainedModel,
    state_dict: Dict[str, Tensor],
    loaded_keys: List[str],
    resolved_archive_file: Union[str, List] = [],
    pretrained_model_name_or_path=None,
    config=None,
    ignore_mismatched_sizes=False,
    low_cpu_mem_usage=False,
    dtype=None,
    keep_in_fp32_modules=None,
    quantization_linear_list=None,
    sharded_metadata=None,
) -> Tuple[PretrainedModel, List[str], List[str], List[tuple]]:
    """Load checkpoint weights into ``model`` and report key mismatches.

    The method:

    * reconciles the base-model prefix between checkpoint keys and model keys;
    * when weight quantization is enabled, converts the state dict (or only
      the key list when ``state_dict`` is ``None``);
    * computes missing / unexpected keys, filtered by
      ``cls._keys_to_ignore_on_load_missing`` and
      ``cls._keys_to_ignore_on_load_unexpected``;
    * casts parameters matching ``keep_in_fp32_modules`` to float32;
    * loads either the single in-memory ``state_dict`` or, when it is
      ``None``, every file in ``resolved_archive_file`` one at a time;
    * logs a summary and raises if any loader reported an error.

    Args:
        model (PretrainedModel): the model instance to fill.
        state_dict (Dict[str, Tensor]): full state dict, or ``None`` to load
            from ``resolved_archive_file`` shard by shard.
        loaded_keys (List[str]): keys present in the checkpoint.
        resolved_archive_file (Union[str, List], optional): checkpoint file or
            list of shard files. The default list is only ever rebound, never
            mutated in place.
        pretrained_model_name_or_path (optional): only used in log messages.
        config (optional): model config; ``tensor_parallel_degree`` and
            (if present) ``quantization_config`` are read from it.
        ignore_mismatched_sizes (bool, optional): when True, checkpoint
            tensors whose shape differs from the model's are dropped from the
            state dict and reported instead of loaded. Defaults to False.
        low_cpu_mem_usage (bool, optional): in the shard-by-shard path, load
            through ``_load_state_dict_into_meta_model``. Defaults to False.
        dtype (optional): forwarded to the quantization converter and the
            meta-model loader. Defaults to None.
        keep_in_fp32_modules (optional): name fragments of parameters to keep
            in float32. Defaults to None.
        quantization_linear_list (optional): layer names forwarded to the
            quantization helpers. Defaults to None.
        sharded_metadata (optional): when given, its ``"file_map"`` is used to
            skip shard files holding none of the expected keys.

    Returns:
        tuple: ``(model, missing_keys, unexpected_keys, mismatched_keys)``
        where ``mismatched_keys`` holds ``(key, checkpoint_shape,
        model_shape)`` tuples.

    Raises:
        ImportError: quantization is requested but its helpers cannot be
            imported.
        ValueError: the checkpoint keys are inconsistent with the base model.
        RuntimeError: a loader returned error messages.
    """
    is_safetensors = False

    model_state_dict = model.state_dict()

    expected_keys = list(model_state_dict.keys())
    prefix = model.base_model_prefix

    if len(prefix) > 0:
        has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
        expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
    else:
        has_prefix_module = False
        expects_prefix_module = False

    # key re-naming operations are never done on the keys
    # that are loaded, but always on the keys of the newly initialized model
    remove_prefix_from_model = not has_prefix_module and expects_prefix_module
    add_prefix_to_model = has_prefix_module and not expects_prefix_module

    # Always bound, so the corruption check further down cannot hit an
    # unbound name when the prefix is not being removed.
    expected_keys_not_prefixed = []
    if remove_prefix_from_model:
        _prefix = f"{prefix}."
        expected_keys_not_prefixed = [
            s for s in expected_keys if not s.startswith(_prefix)
        ]
        expected_keys = [
            s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys
        ]
        if quantization_linear_list is not None:
            quantization_linear_list = [
                s[len(_prefix) :] if s.startswith(_prefix) else s
                for s in quantization_linear_list
            ]
    elif add_prefix_to_model:
        expected_keys = [".".join([prefix, s]) for s in expected_keys]
        if quantization_linear_list is not None:
            quantization_linear_list = [
                ".".join([prefix, s]) for s in quantization_linear_list
            ]

    # Evaluate the guarded quantization check once. Both loading paths reuse
    # it, so a config without `quantization_config` is handled the same way
    # in the sharded path as in the single-state-dict path.
    weight_quantize = (
        hasattr(config, "quantization_config")
        and config.quantization_config.is_weight_quantize()
    )

    # Weight quantization if not yet quantized & update loaded_keys
    if weight_quantize:
        try:
            from ..quantization.quantization_utils import (
                convert_to_quantize_state_dict,
                update_loaded_state_dict_keys,
            )
        except ImportError as err:
            raise ImportError(
                "Quantization features require `paddlepaddle >= 2.5.2`"
            ) from err
        if state_dict is not None:
            state_dict = convert_to_quantize_state_dict(
                state_dict,
                quantization_linear_list,
                config.quantization_config,
                dtype,
            )
            loaded_keys = [k for k in state_dict.keys()]
        else:
            loaded_keys = update_loaded_state_dict_keys(
                loaded_keys, quantization_linear_list, config.quantization_config
            )
        if keep_in_fp32_modules is None:
            keep_in_fp32_modules = (
                ["quant_scale"]
                if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
                else None
            )
        else:
            keep_in_fp32_modules = (
                keep_in_fp32_modules + ["quant_scale"]
                if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
                else keep_in_fp32_modules
            )

    missing_keys = list(set(expected_keys) - set(loaded_keys))
    unexpected_keys = list(set(loaded_keys) - set(expected_keys))

    # Skip shard files that hold none of the expected keys (very large models).
    if sharded_metadata is not None:
        assert isinstance(resolved_archive_file, list)
        new_archive_file = []
        skip_archive_file = []
        expected_keys_set = set(expected_keys)
        for file in resolved_archive_file:
            filename = os.path.split(file)[-1]
            if not expected_keys_set.isdisjoint(
                set(sharded_metadata["file_map"][filename])
            ):
                new_archive_file.append(file)
            else:
                skip_archive_file.append(filename)

        resolved_archive_file = new_archive_file
        if len(skip_archive_file) > 0:
            logging.info(
                f"Skip load files for not contrains expected key, {skip_archive_file}"
            )

    # Some models may have keys that are not in the state by design, removing them before needlessly warning
    # the user.
    if cls._keys_to_ignore_on_load_missing is not None:
        for pat in cls._keys_to_ignore_on_load_missing:
            missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

    if cls._keys_to_ignore_on_load_unexpected is not None:
        for pat in cls._keys_to_ignore_on_load_unexpected:
            unexpected_keys = [
                k for k in unexpected_keys if re.search(pat, k) is None
            ]

    # Set some modules to fp32 if any
    if keep_in_fp32_modules is not None:
        for name, param in model.named_parameters():
            if any(
                module_to_keep_in_fp32 in name
                for module_to_keep_in_fp32 in keep_in_fp32_modules
            ):
                if param.dtype != paddle.float32:
                    # Cast a copy and make the parameter share the float32 data.
                    param_fp32 = param.cast(dtype=paddle.float32)
                    param_fp32_tensor = param_fp32.value().get_tensor()
                    param_tensor = param.value().get_tensor()
                    param_tensor._share_data_with(param_fp32_tensor)

    # Make sure we are able to load base models as well as derived models (with heads)
    start_prefix = ""
    model_to_load = model
    if (
        len(cls.base_model_prefix) > 0
        and not hasattr(model, cls.base_model_prefix)
        and has_prefix_module
    ):
        start_prefix = cls.base_model_prefix + "."
    if (
        len(cls.base_model_prefix) > 0
        and hasattr(model, cls.base_model_prefix)
        and not has_prefix_module
    ):
        model_to_load = getattr(model, cls.base_model_prefix)
        base_model_expected_keys = list(model_to_load.state_dict().keys())
        if any(
            key in expected_keys_not_prefixed
            and key not in base_model_expected_keys
            for key in loaded_keys
        ):
            raise ValueError(
                "The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
                "properly saved?"
            )

    def _find_mismatched_keys(
        state_dict,
        model_state_dict,
        loaded_keys,
        add_prefix_to_model,
        remove_prefix_from_model,
        ignore_mismatched_sizes,
    ):
        """Drop shape-mismatched tensors from ``state_dict`` and return them as ``(key, ckpt_shape, model_shape)``."""
        mismatched_keys = []
        if ignore_mismatched_sizes:
            for checkpoint_key in loaded_keys:
                # If the checkpoint is sharded, we may not have the key here.
                if checkpoint_key not in state_dict:
                    continue
                model_key = checkpoint_key
                if remove_prefix_from_model:
                    # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
                    model_key = f"{prefix}.{checkpoint_key}"
                elif add_prefix_to_model:
                    # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
                    model_key = ".".join(checkpoint_key.split(".")[1:])

                if (
                    model_key in model_state_dict
                    and state_dict[checkpoint_key].shape
                    != model_state_dict[model_key].shape
                ):
                    mismatched_keys.append(
                        (
                            checkpoint_key,
                            state_dict[checkpoint_key].shape,
                            model_state_dict[model_key].shape,
                        )
                    )
                    del state_dict[checkpoint_key]
        return mismatched_keys

    def _fuse_or_split_keys(
        state_dict,
        config,
        loaded_keys,
        pre_tensor_parallel_split=False,
        resume_state_dict=None,
    ):
        """Run ``cls.convert_fuse_and_split`` and report which keys disappeared and which appeared."""
        if resume_state_dict is not None:
            # Merge in entries left over from a previous call.
            state_dict.update(resume_state_dict)

        before_fuse_keys = list(state_dict.keys())
        if pre_tensor_parallel_split:
            tp_actions = cls.get_tensor_parallel_convert_actions(
                config, loaded_keys, ignore_error=True
            )
        else:
            tp_actions = None
        state_dict, resume_state_dict = cls.convert_fuse_and_split(
            config, state_dict, tp_actions
        )
        after_fuse_keys = list(state_dict.keys())

        fused_keys = list(set(before_fuse_keys) - set(after_fuse_keys))
        new_keys = list(set(after_fuse_keys) - set(before_fuse_keys))

        return state_dict, resume_state_dict, fused_keys, new_keys

    if state_dict is not None:
        # have loaded all state_dict, no resume state_dict
        state_dict, _, fused_keys, new_keys = _fuse_or_split_keys(
            state_dict,
            config,
            loaded_keys,
            pre_tensor_parallel_split=(
                True
                if config is not None and config.tensor_parallel_degree > 1
                else False
            ),
        )
        missing_keys = list(set(missing_keys) - set(new_keys))
        unexpected_keys = list(set(unexpected_keys) - set(fused_keys))

        mismatched_keys = _find_mismatched_keys(
            state_dict,
            model_state_dict,
            loaded_keys,
            add_prefix_to_model,
            remove_prefix_from_model,
            ignore_mismatched_sizes,
        )

        if weight_quantize:
            error_msgs = _load_state_dict_into_meta_model(
                model_to_load,
                state_dict,
                loaded_keys,
                start_prefix,
                expected_keys,
                dtype=dtype,
                is_safetensors=is_safetensors,
                keep_in_fp32_modules=keep_in_fp32_modules,
            )
        else:
            error_msgs = _load_state_dict_into_model(
                model_to_load, state_dict, start_prefix
            )
    else:
        # Sharded checkpoint or whole but low_cpu_mem_usage==True

        # This should always be a list but, just to be sure.
        if not isinstance(resolved_archive_file, list):
            resolved_archive_file = [resolved_archive_file]

        error_msgs = []
        mismatched_keys = []
        resume_state_dict = {}
        # Built once: the per-shard loops below only test membership.
        expected_key_lookup = set(expected_keys)

        for shard_file in resolved_archive_file:
            pre_tensor_parallel_split = False
            if (
                shard_file.endswith(".safetensors")
                and config.tensor_parallel_degree > 1
                and "tp" not in os.path.split(shard_file)[-1]
            ):
                pre_tensor_parallel_split = True
                assert loaded_keys is not None, "loaded_keys is not None."
                tp_actions = cls.get_tensor_parallel_convert_actions(
                    config, loaded_keys, ignore_error=True
                )
            # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
            filter_dict_keys = set(expected_keys)
            fuse_actions, _ = cls.get_fuse_or_split_param_convert_actions(
                config, loaded_keys, is_fuse=True
            )
            split_actions, _ = cls.get_fuse_or_split_param_convert_actions(
                config, loaded_keys, is_fuse=False
            )
            for k in list(fuse_actions.keys()):
                # k[:-1] are the source keys, k[-1] the fused key: if the
                # model expects the fused key, its sources must be loaded.
                need_add_except_key = k[-1] in expected_key_lookup
                if need_add_except_key:
                    filter_dict_keys |= set(k[:-1])
                # remove pre_tensor_parallel_split function from tp_actions
                if pre_tensor_parallel_split:
                    for item in k[:-1]:
                        if item in tp_actions:
                            tp_actions.pop(item, None)

            for k in list(split_actions.keys()):
                # If the model expects any split-out key, the key they are
                # split from (k[-1]) must be loaded.
                need_add_except_key = any(
                    item in expected_key_lookup for item in k[:-1]
                )
                if need_add_except_key:
                    filter_dict_keys.add(k[-1])
                # remove pre_tensor_parallel_split function from tp_actions
                if pre_tensor_parallel_split:
                    if k[-1] in tp_actions:
                        # NOTE(review): this pops from `fuse_actions`, while
                        # the comment above and the fuse loop both target
                        # `tp_actions`. It looks like a typo for
                        # `tp_actions.pop(...)`; kept as-is until confirmed
                        # against a tensor-parallel checkpoint.
                        fuse_actions.pop(k[-1], None)

            if weight_quantize:
                # Quantized loading does not restrict which keys are read.
                filter_dict_keys = None
            state_dict = load_state_dict(
                shard_file,
                tp_actions if pre_tensor_parallel_split else None,
                filter_dict_keys,
            )

            # convert for fusing or splitting weights
            state_dict, resume_state_dict, fused_keys, new_keys = (
                _fuse_or_split_keys(
                    state_dict,
                    config,
                    loaded_keys,
                    pre_tensor_parallel_split=pre_tensor_parallel_split,
                    resume_state_dict=resume_state_dict,
                )
            )
            missing_keys = list(set(missing_keys) - set(new_keys))
            unexpected_keys = list(set(unexpected_keys) - set(fused_keys))

            if weight_quantize:
                state_dict = convert_to_quantize_state_dict(
                    state_dict,
                    quantization_linear_list,
                    config.quantization_config,
                    dtype,
                )

            # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
            # matching the weights in the model.
            mismatched_keys += _find_mismatched_keys(
                state_dict,
                model_state_dict,
                loaded_keys,
                add_prefix_to_model,
                remove_prefix_from_model,
                ignore_mismatched_sizes,
            )

            if (
                config.tensor_parallel_degree > 1
                and ".tp" not in shard_file
                and not pre_tensor_parallel_split
            ):
                logging.info("Converting state_dict to Tensor Parallel Format")
                # ignore error for multi shard, since only parts of data
                state_dict = cls.convert_tensor_parallel(
                    None,
                    config,
                    state_dict=state_dict,
                    ignore_error=len(resolved_archive_file) > 1,
                )
                logging.info("Converted state_dict to Tensor Parallel Format")

            if low_cpu_mem_usage or weight_quantize:
                new_error_msgs = _load_state_dict_into_meta_model(
                    model_to_load,
                    state_dict,
                    loaded_keys,
                    start_prefix,
                    expected_keys,
                    dtype=dtype,
                    is_safetensors=is_safetensors,
                    keep_in_fp32_modules=keep_in_fp32_modules,
                )
                error_msgs += new_error_msgs
            else:
                error_msgs += _load_state_dict_into_model(
                    model_to_load, state_dict, start_prefix
                )

            # force memory release
            del state_dict
            gc.collect()

    if len(error_msgs) > 0:
        error_msg = "\n\t".join(error_msgs)
        if " but the expected shape is" in error_msg:
            error_msg += "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
        raise RuntimeError(
            f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}"
        )

    if len(unexpected_keys) > 0:
        # The full key list is only printed when the logger level is below 20.
        if logging.logging.level < 20:
            logging.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                f" initializing {model.__class__.__name__}: {sorted(unexpected_keys)}\n- This IS expected if you are"
                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
                " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
                " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logging.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                f" initializing the model, - This IS expected if you are"
                f" initializing the model from a checkpoint of a model trained on another task or"
                " with another architecture."
            )
    else:
        logging.info(
            f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
        )

    if len(missing_keys) > 0:
        logging.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
            " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    elif len(mismatched_keys) == 0:
        logging.info(
            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
            f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
            " training."
        )
    if len(mismatched_keys) > 0:
        mismatched_warning = "\n".join(
            [
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                for key, shape1, shape2 in mismatched_keys
            ]
        )
        logging.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
            " to use it for predictions and inference."
        )

    return model, missing_keys, unexpected_keys, mismatched_keys
|
1597
|
+
|
1598
|
+
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
    """
    Create an instance of `PretrainedModel` and load pretrained weights into it.

    Weights are located by the name of a built-in pretrained model, a
    pretrained model from HF Hub, a community-contributed model, or a local
    directory path.

    Args:
        pretrained_model_name_or_path (str): Name of the pretrained model or
            directory path to load from. The string can be:

            - Name of a built-in pretrained model
            - Name of a pretrained model from HF Hub
            - Name of a community-contributed pretrained model.
            - Local directory path which contains the model weights file
              ("model_state.pdparams") and the model config file
              ("model_config.json").
        *args (tuple): Positional arguments for model `__init__`.
            NOTE(review): this method body never references `args`; the
            positional arguments actually passed to the model come from
            `config["init_args"]` — confirm whether `args` is still needed.
        **kwargs (dict): Keyword arguments. The keys listed below are
            consumed here; whatever remains is handed to
            `cls.config_class.from_pretrained` (when `config` is not already
            a `PretrainedConfig`) and the unused part is passed on to the
            model `__init__`.

            - config: a `PretrainedConfig` instance, or a name/path the
              config is loaded from. Defaults to
              `pretrained_model_name_or_path`.
            - state_dict: an already-loaded state dict; when given, no
              weight file is read from disk.
            - cache_dir: forwarded to config / weight-file resolution.
            - force_download (bool): read (not removed from `kwargs`) and
              forwarded to `GenerationConfig.from_pretrained`.
            - ignore_mismatched_sizes (bool): forwarded to
              `_load_pretrained_model`. Defaults to `False`.
            - dtype: dtype to initialise the model under. Defaults to
              `config.dtype`.
            - from_hf_hub (bool): load the model from Hugging Face Hub.
              Defaults to `False`.
            - from_aistudio (bool): forwarded to the config / weight
              resolution helpers. Defaults to `False`.
            - subfolder (str, optional): folder inside the repo.
              `None` is treated as "".
            - variant: forwarded to weight-file resolution.
            - use_safetensors: defaults to `None` when safetensors is
              available, otherwise `False`.
            - low_cpu_mem_usage (bool): when true, the eagerly loaded state
              dict is dropped before `_load_pretrained_model` is called.
              Defaults to `False`.
            - convert_from_torch (bool, optional): whether torch weights
              should be converted. Defaults to `True` when `from_hf_hub` is
              set or the `from_modelscope` environment variable is set,
              otherwise `False`.
            - load_state_as_np: deprecated; only triggers a warning.

    Returns:
        PretrainedModel: the loaded model when Paddle is in dynamic mode;
        otherwise a `(model, state_dict)` tuple.

    Raises:
        ValueError: if torch conversion is requested and the resolved weight
            file is neither a torch/safetensors file nor a Paddle weight file.
        NotImplementedError: for a non-sharded "model.safetensors" file when
            `config.tensor_parallel_degree > 1`.

    Example:
        .. code-block::

            from paddlenlp.transformers import BertForSequenceClassification

            # Name of built-in pretrained model
            model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

            # Name of community-contributed pretrained model
            model = BertForSequenceClassification.from_pretrained('yingyibiao/bert-base-uncased-sst-2-finetuned', num_labels=3)

            # Load from local directory path
            model = BertForSequenceClassification.from_pretrained('./my_bert/')
    """
    config = kwargs.pop("config", None)
    state_dict = kwargs.pop("state_dict", None)
    cache_dir = kwargs.pop("cache_dir", None)
    # Read without popping, so the key also stays inside `kwargs`.
    force_download = kwargs.get("force_download", False)
    ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
    dtype = kwargs.pop("dtype", None)
    from_hf_hub = kwargs.pop("from_hf_hub", False)
    from_aistudio = kwargs.pop("from_aistudio", False)
    subfolder = kwargs.pop("subfolder", None)
    if subfolder is None:
        subfolder = ""
    variant = kwargs.pop("variant", None)
    use_safetensors = kwargs.pop(
        "use_safetensors", None if is_safetensors_available() else False
    )

    low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
    convert_from_torch = kwargs.pop("convert_from_torch", None)
    load_state_as_np = kwargs.pop("load_state_as_np", None)
    if load_state_as_np is not None:
        logging.warning("`load_state_as_np` is deprecated, please delete it!")

    model_kwargs = kwargs

    if convert_from_torch is None and os.environ.get("from_modelscope", False):
        logging.warning(
            "If you are attempting to load weights from ModelScope Hub and want to disable the default behavior of considering torch weights,"
            " you can set ·convert_from_torch=False·. By default, `convert_from_torch` is set to `True`. "
        )
        convert_from_torch = True

    # from_hf_hub default enable convert_from_torch
    if from_hf_hub and convert_from_torch is None:
        logging.warning(
            "If you are attempting to load weights from Hugging Face Hub and want to disable the default behavior of considering torch weights,"
            " you can set ·convert_from_torch=False·. By default, `convert_from_torch` is set to `True`. "
        )
        convert_from_torch = True
    # convert_from_torch default is False
    if convert_from_torch is None:
        convert_from_torch = False

    # 1. get the PretrainedConfig to init model
    if not isinstance(config, PretrainedConfig):
        config_path = config if config is not None else pretrained_model_name_or_path
        config, model_kwargs = cls.config_class.from_pretrained(
            config_path,
            cache_dir=cache_dir,
            from_hf_hub=from_hf_hub,
            from_aistudio=from_aistudio,
            subfolder=subfolder,
            return_unused_kwargs=True,
            **kwargs,
        )
    if "from_aistudio" in model_kwargs:
        model_kwargs.pop("from_aistudio")

    if dtype is None:
        dtype = config.dtype
    config.dtype = dtype

    init_contexts = []

    if dtype:
        init_contexts.append(dtype_guard(dtype))

    # Keep in fp32 modules
    keep_in_fp32_modules = None
    use_keep_in_fp32_modules = False

    # resolve model_weight file
    # NOTE(review): a literal `False` is passed for `convert_from_torch` here
    # rather than the value resolved above — confirm this is intended.
    resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = (
        cls._resolve_model_file_path(
            pretrained_model_name_or_path,
            cache_dir=cache_dir,
            subfolder=subfolder,
            from_hf_hub=from_hf_hub,
            from_aistudio=from_aistudio,
            config=config,
            convert_from_torch=False,
            use_safetensors=use_safetensors,
            variant=variant,
        )
    )

    if convert_from_torch and state_dict is None:
        if resolved_archive_file.endswith(
            (
                PYTORCH_WEIGHTS_NAME,
                PYTORCH_WEIGHTS_INDEX_NAME,
                SAFE_WEIGHTS_NAME,
                SAFE_WEIGHTS_INDEX_NAME,
            )
        ):
            # try to get the name-mapping info
            convert_dir = os.path.dirname(resolved_archive_file)
            logging.info(
                f"Starting to convert pytorch weight file<{resolved_archive_file}> to "
                f"paddle weight file<{convert_dir}> ..."
            )
            state_dict = cls.convert(
                resolved_archive_file,
                config,
                cache_dir=convert_dir,
            )
        elif resolved_archive_file.endswith(
            (PADDLE_WEIGHTS_NAME, PADDLE_WEIGHTS_INDEX_NAME, ".pdparams")
        ):
            # Already a Paddle weight file: nothing to convert. Reported through
            # the logger like every other message in this method.
            logging.info(f"file: {resolved_archive_file} is paddle weight.")
        else:
            raise ValueError(
                f"Unexpected file: {resolved_archive_file} for weight conversion."
            )
    # load pt weights early so that we know which dtype to init the model under
    if not is_sharded and state_dict is None:
        # 4. loading non-sharded ckpt from the state dict
        if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith(
            "model_state.pdparams"
        ):
            state_dict = cls.convert_tensor_parallel(resolved_archive_file, config)
        elif config.tensor_parallel_degree > 1 and resolved_archive_file.endswith(
            "model.safetensors"
        ):
            raise NotImplementedError
        else:
            state_dict = load_state_dict(resolved_archive_file)

        logging.info("Loaded weights file from disk, setting weights to model.")

    # Check if `_keep_in_fp32_modules` is not None
    use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
        dtype == "float16" or dtype == "bfloat16"
    )

    if state_dict is not None:
        loaded_state_dict_keys = list(state_dict.keys())
        # will only support load paddle.Tensor to model.
        for k in list(state_dict.keys()):
            if not isinstance(state_dict[k], paddle.Tensor):
                with device_guard():
                    state_dict[k] = paddle.Tensor.__call__(
                        state_dict.pop(k), zero_copy=True
                    )
    else:
        if is_sharded:
            loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
        else:
            # NOTE(review): `state_dict` is None on this path, so this line
            # would raise AttributeError if it were ever reached — verify
            # that the non-sharded load above always populates `state_dict`.
            loaded_state_dict_keys = [k for k in state_dict.keys()]

    if low_cpu_mem_usage:  # or use_keep_in_fp32_modules:
        state_dict = None

    # will only support load paddle.Tensor to model.
    # NOTE(review): looks redundant with the conversion loop above whenever
    # that loop already ran — confirm before removing.
    if state_dict is not None:
        for k in list(state_dict.keys()):
            if not isinstance(state_dict[k], paddle.Tensor):
                with device_guard():
                    state_dict[k] = paddle.Tensor.__call__(
                        state_dict.pop(k), zero_copy=True
                    )
    # 3. init the model
    init_args = config["init_args"] or ()
    with ContextManagers(init_contexts):
        model = cls(config, *init_args, **model_kwargs)

    if use_keep_in_fp32_modules:
        # low_cpu_mem_usage = True
        keep_in_fp32_modules = model._keep_in_fp32_modules
    else:
        keep_in_fp32_modules = []

    quantization_linear_list = None

    model, missing_keys, unexpected_keys, mismatched_keys = (
        cls._load_pretrained_model(
            model=model,
            state_dict=state_dict,
            loaded_keys=loaded_state_dict_keys,
            resolved_archive_file=(
                resolved_sharded_files if is_sharded else resolved_archive_file
            ),
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            config=config,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            low_cpu_mem_usage=low_cpu_mem_usage,
            dtype=dtype,
            keep_in_fp32_modules=keep_in_fp32_modules,
            quantization_linear_list=quantization_linear_list,
            sharded_metadata=sharded_metadata if is_sharded else None,
        )
    )

    # load generation_config.json
    if model.can_generate() and pretrained_model_name_or_path is not None:
        try:
            model.generation_config = GenerationConfig.from_pretrained(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                force_download=force_download,
                from_hf_hub=from_hf_hub,
                from_aistudio=from_aistudio,
                subfolder=subfolder,
                **kwargs,
            )
        except Exception:
            # Best effort: any failure keeps the generation config the model
            # already has. `Exception` (not a bare `except`) so that
            # KeyboardInterrupt / SystemExit still propagate.
            logging.info(
                "Generation config file not found, using a generation config created from the model config."
            )

    # Note:
    # 1. PipelineLayer will create parameters for each layer and
    # call `_synchronize_shared_weights()` to synchronize the shared parameters.
    # 2. When setting the model `state_dict`, `_synchronize_shared_weights` will be called to
    # synchronize the shared parameters.
    # However, when state dict only contains the one piece of shared parameters, the shared parameters
    # will be different from the original shared parameters.

    if isinstance(model, PipelineLayer):
        model._synchronize_shared_weights()

    if paddle.in_dynamic_mode():
        return model

    return model, state_dict
|
1883
|
+
|
1884
|
+
def merge_auto_dist_configs(self, configs):
    """
    Merge several auto-dist configs into a single config.

    `configs` is either one config dict, which is returned unchanged, or a
    list of config dicts, each describing the auto-dist config of a model or
    sublayer:

        [
            {
                mp_config (dict): {
                    "parallelize_plan": dict, the plan to shard the layer.
                }
                sp_config (dict): {
                    "parallelize_plan": dict, the plan to shard the layer.
                }
                pp_config (dict): {
                    "split_spec": OrderedDict|dict|str|list(str), The pipeline parallel split point.
                    "global_spec": str|list(str), make the output tensor of specific layers on global mesh.
                }
            },
            ...
        ]

    Merge rules:
        - mp_config / sp_config: the first non-None section is adopted; the
          "parallelize_plan" entries of later sections are added to it. A
          key that is already present is rejected.
        - pp_config: the first non-None section whose "split_spec" is a str,
          list or tuple is adopted; later "split_spec" values are appended in
          order. A str counts as a one-element list. A "split_spec" of any
          other type (for example a dict) is skipped, as before.
          If the merged "split_spec" holds exactly one item, that item is
          returned instead of a one-element list.

    The input configs are never modified: adopted sections are copied and the
    merged "split_spec" is always a fresh list (so a tuple followed by a list
    merges instead of raising TypeError).

    Args:
        configs (dict | list[dict]): a single config or a list of configs.

    Returns:
        dict: `configs` itself when it is a dict; otherwise a new dict with
        the keys "mp_config", "sp_config" and "pp_config" (each None when no
        input provided it).

    Raises:
        AssertionError: if `configs` is neither a dict nor a list, or if two
            configs define the same "parallelize_plan" key.
    """
    assert isinstance(configs, (dict, list))
    if isinstance(configs, dict):
        return configs

    final_config = {
        "mp_config": None,
        "sp_config": None,
        "pp_config": None,
    }
    for config in configs:
        # mp_config and sp_config share the same merge rule.
        for section in ("mp_config", "sp_config"):
            incoming = config.get(section)
            if incoming is None:
                continue
            if final_config[section] is None:
                # Copy on adoption so later merges never write into the
                # caller's dicts.
                adopted = dict(incoming)
                plan = adopted.get("parallelize_plan")
                if isinstance(plan, dict):
                    adopted["parallelize_plan"] = dict(plan)
                final_config[section] = adopted
                continue
            merged_plan = final_config[section]["parallelize_plan"]
            for k, v in incoming["parallelize_plan"].items():
                assert (
                    k not in merged_plan
                ), f"sublayer {section} should be a subset of model but got sublayer config {incoming} and model config {final_config[section]}."
                merged_plan[k] = v

        pp_incoming = config.get("pp_config")
        if pp_incoming is None:
            continue
        split_spec = pp_incoming["split_spec"]
        if isinstance(split_spec, str):
            spec_items = [split_spec]
        elif isinstance(split_spec, (tuple, list)):
            spec_items = list(split_spec)
        else:
            # Other split_spec types are not merged (unchanged behaviour).
            continue
        if final_config["pp_config"] is None:
            adopted_pp = dict(pp_incoming)
            adopted_pp["split_spec"] = spec_items
            final_config["pp_config"] = adopted_pp
        else:
            final_config["pp_config"]["split_spec"].extend(spec_items)

    # A single split point is reported as the bare item, not a 1-element list.
    if (
        final_config["pp_config"] is not None
        and len(final_config["pp_config"]["split_spec"]) == 1
    ):
        final_config["pp_config"]["split_spec"] = final_config["pp_config"][
            "split_spec"
        ][0]

    return final_config
|
1965
|
+
|
1966
|
+
def _generate_auto_dist_config(self, auto_dist_degree):
|
1967
|
+
merged_config = {
|
1968
|
+
"sp_config": None,
|
1969
|
+
"mp_config": None,
|
1970
|
+
"pp_config": None,
|
1971
|
+
}
|
1972
|
+
for name, layer in self.named_sublayers(include_self=True):
|
1973
|
+
if hasattr(layer, "auto_dist_config"):
|
1974
|
+
if name != "":
|
1975
|
+
prefix = name + "."
|
1976
|
+
else:
|
1977
|
+
prefix = ""
|
1978
|
+
layer_config = layer.auto_dist_config(prefix)
|
1979
|
+
merged_config = self.merge_auto_dist_configs(
|
1980
|
+
[merged_config, layer_config]
|
1981
|
+
)
|
1982
|
+
for _, deeper_layer in layer.named_sublayers():
|
1983
|
+
if hasattr(deeper_layer, "auto_dist_config"):
|
1984
|
+
# mask all `auto_dist_config` methods in deeper layer
|
1985
|
+
deeper_layer.auto_dist_config = lambda x: {}
|
1986
|
+
|
1987
|
+
final_config = {
|
1988
|
+
"dp_config": None,
|
1989
|
+
"mp_config": None,
|
1990
|
+
"pp_config": None,
|
1991
|
+
}
|
1992
|
+
|
1993
|
+
if (
|
1994
|
+
"tensor_parallel" in auto_dist_degree
|
1995
|
+
and auto_dist_degree["tensor_parallel"]
|
1996
|
+
):
|
1997
|
+
merged_config["mp_config"] is not None
|
1998
|
+
final_config["mp_config"] = merged_config["mp_config"]
|
1999
|
+
|
2000
|
+
if (
|
2001
|
+
"sequence_parallel" in auto_dist_degree
|
2002
|
+
and auto_dist_degree["sequence_parallel"]
|
2003
|
+
):
|
2004
|
+
merged_config["sp_config"] is not None
|
2005
|
+
final_config["mp_config"] = merged_config["sp_config"]
|
2006
|
+
|
2007
|
+
if (
|
2008
|
+
"pipeline_parallel" in auto_dist_degree
|
2009
|
+
and auto_dist_degree["pipeline_parallel"]
|
2010
|
+
):
|
2011
|
+
merged_config["pp_config"] is not None
|
2012
|
+
final_config["pp_config"] = merged_config["pp_config"]
|
2013
|
+
|
2014
|
+
return final_config
|