paddlex 2.0.0rc4__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1533) hide show
  1. paddlex/.version +1 -0
  2. paddlex/__init__.py +35 -18
  3. paddlex/__main__.py +39 -0
  4. paddlex/configs/modules/3d_bev_detection/BEVFusion.yaml +38 -0
  5. paddlex/configs/modules/chart_parsing/PP-Chart2Table.yaml +13 -0
  6. paddlex/configs/modules/doc_text_orientation/PP-LCNet_x1_0_doc_ori.yaml +41 -0
  7. paddlex/configs/modules/doc_vlm/PP-DocBee-2B.yaml +14 -0
  8. paddlex/configs/modules/doc_vlm/PP-DocBee-7B.yaml +14 -0
  9. paddlex/configs/modules/doc_vlm/PP-DocBee2-3B.yaml +14 -0
  10. paddlex/configs/modules/face_detection/BlazeFace-FPN-SSH.yaml +40 -0
  11. paddlex/configs/modules/face_detection/BlazeFace.yaml +40 -0
  12. paddlex/configs/modules/face_detection/PP-YOLOE_plus-S_face.yaml +40 -0
  13. paddlex/configs/modules/face_detection/PicoDet_LCNet_x2_5_face.yaml +40 -0
  14. paddlex/configs/modules/face_feature/MobileFaceNet.yaml +41 -0
  15. paddlex/configs/modules/face_feature/ResNet50_face.yaml +41 -0
  16. paddlex/configs/modules/formula_recognition/LaTeX_OCR_rec.yaml +40 -0
  17. paddlex/configs/modules/formula_recognition/PP-FormulaNet-L.yaml +40 -0
  18. paddlex/configs/modules/formula_recognition/PP-FormulaNet-S.yaml +40 -0
  19. paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-L.yaml +40 -0
  20. paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-M.yaml +40 -0
  21. paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-S.yaml +40 -0
  22. paddlex/configs/modules/formula_recognition/UniMERNet.yaml +40 -0
  23. paddlex/configs/modules/human_detection/PP-YOLOE-L_human.yaml +42 -0
  24. paddlex/configs/modules/human_detection/PP-YOLOE-S_human.yaml +42 -0
  25. paddlex/configs/modules/image_anomaly_detection/STFPM.yaml +41 -0
  26. paddlex/configs/modules/image_classification/CLIP_vit_base_patch16_224.yaml +41 -0
  27. paddlex/configs/modules/image_classification/CLIP_vit_large_patch14_224.yaml +41 -0
  28. paddlex/configs/modules/image_classification/ConvNeXt_base_224.yaml +41 -0
  29. paddlex/configs/modules/image_classification/ConvNeXt_base_384.yaml +41 -0
  30. paddlex/configs/modules/image_classification/ConvNeXt_large_224.yaml +41 -0
  31. paddlex/configs/modules/image_classification/ConvNeXt_large_384.yaml +41 -0
  32. paddlex/configs/modules/image_classification/ConvNeXt_small.yaml +41 -0
  33. paddlex/configs/modules/image_classification/ConvNeXt_tiny.yaml +41 -0
  34. paddlex/configs/modules/image_classification/FasterNet-L.yaml +40 -0
  35. paddlex/configs/modules/image_classification/FasterNet-M.yaml +40 -0
  36. paddlex/configs/modules/image_classification/FasterNet-S.yaml +40 -0
  37. paddlex/configs/modules/image_classification/FasterNet-T0.yaml +40 -0
  38. paddlex/configs/modules/image_classification/FasterNet-T1.yaml +40 -0
  39. paddlex/configs/modules/image_classification/FasterNet-T2.yaml +40 -0
  40. paddlex/configs/modules/image_classification/MobileNetV1_x0_25.yaml +41 -0
  41. paddlex/configs/modules/image_classification/MobileNetV1_x0_5.yaml +41 -0
  42. paddlex/configs/modules/image_classification/MobileNetV1_x0_75.yaml +41 -0
  43. paddlex/configs/modules/image_classification/MobileNetV1_x1_0.yaml +41 -0
  44. paddlex/configs/modules/image_classification/MobileNetV2_x0_25.yaml +41 -0
  45. paddlex/configs/modules/image_classification/MobileNetV2_x0_5.yaml +41 -0
  46. paddlex/configs/modules/image_classification/MobileNetV2_x1_0.yaml +41 -0
  47. paddlex/configs/modules/image_classification/MobileNetV2_x1_5.yaml +41 -0
  48. paddlex/configs/modules/image_classification/MobileNetV2_x2_0.yaml +41 -0
  49. paddlex/configs/modules/image_classification/MobileNetV3_large_x0_35.yaml +41 -0
  50. paddlex/configs/modules/image_classification/MobileNetV3_large_x0_5.yaml +41 -0
  51. paddlex/configs/modules/image_classification/MobileNetV3_large_x0_75.yaml +41 -0
  52. paddlex/configs/modules/image_classification/MobileNetV3_large_x1_0.yaml +41 -0
  53. paddlex/configs/modules/image_classification/MobileNetV3_large_x1_25.yaml +41 -0
  54. paddlex/configs/modules/image_classification/MobileNetV3_small_x0_35.yaml +41 -0
  55. paddlex/configs/modules/image_classification/MobileNetV3_small_x0_5.yaml +41 -0
  56. paddlex/configs/modules/image_classification/MobileNetV3_small_x0_75.yaml +41 -0
  57. paddlex/configs/modules/image_classification/MobileNetV3_small_x1_0.yaml +41 -0
  58. paddlex/configs/modules/image_classification/MobileNetV3_small_x1_25.yaml +41 -0
  59. paddlex/configs/modules/image_classification/MobileNetV4_conv_large.yaml +41 -0
  60. paddlex/configs/modules/image_classification/MobileNetV4_conv_medium.yaml +41 -0
  61. paddlex/configs/modules/image_classification/MobileNetV4_conv_small.yaml +41 -0
  62. paddlex/configs/modules/image_classification/MobileNetV4_hybrid_large.yaml +41 -0
  63. paddlex/configs/modules/image_classification/MobileNetV4_hybrid_medium.yaml +41 -0
  64. paddlex/configs/modules/image_classification/PP-HGNetV2-B0.yaml +41 -0
  65. paddlex/configs/modules/image_classification/PP-HGNetV2-B1.yaml +41 -0
  66. paddlex/configs/modules/image_classification/PP-HGNetV2-B2.yaml +41 -0
  67. paddlex/configs/modules/image_classification/PP-HGNetV2-B3.yaml +41 -0
  68. paddlex/configs/modules/image_classification/PP-HGNetV2-B4.yaml +41 -0
  69. paddlex/configs/modules/image_classification/PP-HGNetV2-B5.yaml +41 -0
  70. paddlex/configs/modules/image_classification/PP-HGNetV2-B6.yaml +41 -0
  71. paddlex/configs/modules/image_classification/PP-HGNet_base.yaml +41 -0
  72. paddlex/configs/modules/image_classification/PP-HGNet_small.yaml +41 -0
  73. paddlex/configs/modules/image_classification/PP-HGNet_tiny.yaml +41 -0
  74. paddlex/configs/modules/image_classification/PP-LCNetV2_base.yaml +41 -0
  75. paddlex/configs/modules/image_classification/PP-LCNetV2_large.yaml +41 -0
  76. paddlex/configs/modules/image_classification/PP-LCNetV2_small.yaml +41 -0
  77. paddlex/configs/modules/image_classification/PP-LCNet_x0_25.yaml +41 -0
  78. paddlex/configs/modules/image_classification/PP-LCNet_x0_35.yaml +41 -0
  79. paddlex/configs/modules/image_classification/PP-LCNet_x0_5.yaml +41 -0
  80. paddlex/configs/modules/image_classification/PP-LCNet_x0_75.yaml +41 -0
  81. paddlex/configs/modules/image_classification/PP-LCNet_x1_0.yaml +41 -0
  82. paddlex/configs/modules/image_classification/PP-LCNet_x1_5.yaml +41 -0
  83. paddlex/configs/modules/image_classification/PP-LCNet_x2_0.yaml +41 -0
  84. paddlex/configs/modules/image_classification/PP-LCNet_x2_5.yaml +41 -0
  85. paddlex/configs/modules/image_classification/ResNet101.yaml +41 -0
  86. paddlex/configs/modules/image_classification/ResNet101_vd.yaml +41 -0
  87. paddlex/configs/modules/image_classification/ResNet152.yaml +41 -0
  88. paddlex/configs/modules/image_classification/ResNet152_vd.yaml +41 -0
  89. paddlex/configs/modules/image_classification/ResNet18.yaml +41 -0
  90. paddlex/configs/modules/image_classification/ResNet18_vd.yaml +41 -0
  91. paddlex/configs/modules/image_classification/ResNet200_vd.yaml +41 -0
  92. paddlex/configs/modules/image_classification/ResNet34.yaml +41 -0
  93. paddlex/configs/modules/image_classification/ResNet34_vd.yaml +41 -0
  94. paddlex/configs/modules/image_classification/ResNet50.yaml +41 -0
  95. paddlex/configs/modules/image_classification/ResNet50_vd.yaml +41 -0
  96. paddlex/configs/modules/image_classification/StarNet-S1.yaml +41 -0
  97. paddlex/configs/modules/image_classification/StarNet-S2.yaml +41 -0
  98. paddlex/configs/modules/image_classification/StarNet-S3.yaml +41 -0
  99. paddlex/configs/modules/image_classification/StarNet-S4.yaml +41 -0
  100. paddlex/configs/modules/image_classification/SwinTransformer_base_patch4_window12_384.yaml +41 -0
  101. paddlex/configs/modules/image_classification/SwinTransformer_base_patch4_window7_224.yaml +41 -0
  102. paddlex/configs/modules/image_classification/SwinTransformer_large_patch4_window12_384.yaml +41 -0
  103. paddlex/configs/modules/image_classification/SwinTransformer_large_patch4_window7_224.yaml +41 -0
  104. paddlex/configs/modules/image_classification/SwinTransformer_small_patch4_window7_224.yaml +41 -0
  105. paddlex/configs/modules/image_classification/SwinTransformer_tiny_patch4_window7_224.yaml +41 -0
  106. paddlex/configs/modules/image_feature/PP-ShiTuV2_rec.yaml +42 -0
  107. paddlex/configs/modules/image_feature/PP-ShiTuV2_rec_CLIP_vit_base.yaml +42 -0
  108. paddlex/configs/modules/image_feature/PP-ShiTuV2_rec_CLIP_vit_large.yaml +41 -0
  109. paddlex/configs/modules/image_multilabel_classification/CLIP_vit_base_patch16_448_ML.yaml +41 -0
  110. paddlex/configs/modules/image_multilabel_classification/PP-HGNetV2-B0_ML.yaml +41 -0
  111. paddlex/configs/modules/image_multilabel_classification/PP-HGNetV2-B4_ML.yaml +41 -0
  112. paddlex/configs/modules/image_multilabel_classification/PP-HGNetV2-B6_ML.yaml +41 -0
  113. paddlex/configs/modules/image_multilabel_classification/PP-LCNet_x1_0_ML.yaml +41 -0
  114. paddlex/configs/modules/image_multilabel_classification/ResNet50_ML.yaml +41 -0
  115. paddlex/configs/modules/image_unwarping/UVDoc.yaml +12 -0
  116. paddlex/configs/modules/instance_segmentation/Cascade-MaskRCNN-ResNet50-FPN.yaml +40 -0
  117. paddlex/configs/modules/instance_segmentation/Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN.yaml +40 -0
  118. paddlex/configs/modules/instance_segmentation/Mask-RT-DETR-H.yaml +40 -0
  119. paddlex/configs/modules/instance_segmentation/Mask-RT-DETR-L.yaml +40 -0
  120. paddlex/configs/modules/instance_segmentation/Mask-RT-DETR-M.yaml +40 -0
  121. paddlex/configs/modules/instance_segmentation/Mask-RT-DETR-S.yaml +40 -0
  122. paddlex/configs/modules/instance_segmentation/Mask-RT-DETR-X.yaml +40 -0
  123. paddlex/configs/modules/instance_segmentation/MaskRCNN-ResNeXt101-vd-FPN.yaml +39 -0
  124. paddlex/configs/modules/instance_segmentation/MaskRCNN-ResNet101-FPN.yaml +40 -0
  125. paddlex/configs/modules/instance_segmentation/MaskRCNN-ResNet101-vd-FPN.yaml +40 -0
  126. paddlex/configs/modules/instance_segmentation/MaskRCNN-ResNet50-FPN.yaml +40 -0
  127. paddlex/configs/modules/instance_segmentation/MaskRCNN-ResNet50-vd-FPN.yaml +40 -0
  128. paddlex/configs/modules/instance_segmentation/MaskRCNN-ResNet50.yaml +40 -0
  129. paddlex/configs/modules/instance_segmentation/PP-YOLOE_seg-S.yaml +40 -0
  130. paddlex/configs/modules/instance_segmentation/SOLOv2.yaml +40 -0
  131. paddlex/configs/modules/keypoint_detection/PP-TinyPose_128x96.yaml +40 -0
  132. paddlex/configs/modules/keypoint_detection/PP-TinyPose_256x192.yaml +40 -0
  133. paddlex/configs/modules/layout_detection/PP-DocBlockLayout.yaml +40 -0
  134. paddlex/configs/modules/layout_detection/PP-DocLayout-L.yaml +40 -0
  135. paddlex/configs/modules/layout_detection/PP-DocLayout-M.yaml +40 -0
  136. paddlex/configs/modules/layout_detection/PP-DocLayout-S.yaml +40 -0
  137. paddlex/configs/modules/layout_detection/PP-DocLayout_plus-L.yaml +40 -0
  138. paddlex/configs/modules/layout_detection/PicoDet-L_layout_17cls.yaml +40 -0
  139. paddlex/configs/modules/layout_detection/PicoDet-L_layout_3cls.yaml +40 -0
  140. paddlex/configs/modules/layout_detection/PicoDet-S_layout_17cls.yaml +40 -0
  141. paddlex/configs/modules/layout_detection/PicoDet-S_layout_3cls.yaml +40 -0
  142. paddlex/configs/modules/layout_detection/PicoDet_layout_1x.yaml +40 -0
  143. paddlex/configs/modules/layout_detection/PicoDet_layout_1x_table.yaml +40 -0
  144. paddlex/configs/modules/layout_detection/RT-DETR-H_layout_17cls.yaml +40 -0
  145. paddlex/configs/modules/layout_detection/RT-DETR-H_layout_3cls.yaml +40 -0
  146. paddlex/configs/modules/mainbody_detection/PP-ShiTuV2_det.yaml +41 -0
  147. paddlex/configs/modules/multilingual_speech_recognition/whisper_base.yaml +12 -0
  148. paddlex/configs/modules/multilingual_speech_recognition/whisper_large.yaml +12 -0
  149. paddlex/configs/modules/multilingual_speech_recognition/whisper_medium.yaml +12 -0
  150. paddlex/configs/modules/multilingual_speech_recognition/whisper_small.yaml +12 -0
  151. paddlex/configs/modules/multilingual_speech_recognition/whisper_tiny.yaml +12 -0
  152. paddlex/configs/modules/object_detection/Cascade-FasterRCNN-ResNet50-FPN.yaml +41 -0
  153. paddlex/configs/modules/object_detection/Cascade-FasterRCNN-ResNet50-vd-SSLDv2-FPN.yaml +42 -0
  154. paddlex/configs/modules/object_detection/CenterNet-DLA-34.yaml +41 -0
  155. paddlex/configs/modules/object_detection/CenterNet-ResNet50.yaml +41 -0
  156. paddlex/configs/modules/object_detection/Co-DINO-R50.yaml +40 -0
  157. paddlex/configs/modules/object_detection/Co-DINO-Swin-L.yaml +40 -0
  158. paddlex/configs/modules/object_detection/Co-Deformable-DETR-R50.yaml +40 -0
  159. paddlex/configs/modules/object_detection/Co-Deformable-DETR-Swin-T.yaml +40 -0
  160. paddlex/configs/modules/object_detection/DETR-R50.yaml +42 -0
  161. paddlex/configs/modules/object_detection/FCOS-ResNet50.yaml +41 -0
  162. paddlex/configs/modules/object_detection/FasterRCNN-ResNeXt101-vd-FPN.yaml +42 -0
  163. paddlex/configs/modules/object_detection/FasterRCNN-ResNet101-FPN.yaml +42 -0
  164. paddlex/configs/modules/object_detection/FasterRCNN-ResNet101.yaml +42 -0
  165. paddlex/configs/modules/object_detection/FasterRCNN-ResNet34-FPN.yaml +42 -0
  166. paddlex/configs/modules/object_detection/FasterRCNN-ResNet50-FPN.yaml +42 -0
  167. paddlex/configs/modules/object_detection/FasterRCNN-ResNet50-vd-FPN.yaml +42 -0
  168. paddlex/configs/modules/object_detection/FasterRCNN-ResNet50-vd-SSLDv2-FPN.yaml +42 -0
  169. paddlex/configs/modules/object_detection/FasterRCNN-ResNet50.yaml +42 -0
  170. paddlex/configs/modules/object_detection/FasterRCNN-Swin-Tiny-FPN.yaml +42 -0
  171. paddlex/configs/modules/object_detection/PP-YOLOE_plus-L.yaml +40 -0
  172. paddlex/configs/modules/object_detection/PP-YOLOE_plus-M.yaml +40 -0
  173. paddlex/configs/modules/object_detection/PP-YOLOE_plus-S.yaml +40 -0
  174. paddlex/configs/modules/object_detection/PP-YOLOE_plus-X.yaml +40 -0
  175. paddlex/configs/modules/object_detection/PicoDet-L.yaml +40 -0
  176. paddlex/configs/modules/object_detection/PicoDet-M.yaml +42 -0
  177. paddlex/configs/modules/object_detection/PicoDet-S.yaml +40 -0
  178. paddlex/configs/modules/object_detection/PicoDet-XS.yaml +42 -0
  179. paddlex/configs/modules/object_detection/RT-DETR-H.yaml +40 -0
  180. paddlex/configs/modules/object_detection/RT-DETR-L.yaml +40 -0
  181. paddlex/configs/modules/object_detection/RT-DETR-R18.yaml +40 -0
  182. paddlex/configs/modules/object_detection/RT-DETR-R50.yaml +40 -0
  183. paddlex/configs/modules/object_detection/RT-DETR-X.yaml +40 -0
  184. paddlex/configs/modules/object_detection/YOLOX-L.yaml +40 -0
  185. paddlex/configs/modules/object_detection/YOLOX-M.yaml +40 -0
  186. paddlex/configs/modules/object_detection/YOLOX-N.yaml +40 -0
  187. paddlex/configs/modules/object_detection/YOLOX-S.yaml +40 -0
  188. paddlex/configs/modules/object_detection/YOLOX-T.yaml +40 -0
  189. paddlex/configs/modules/object_detection/YOLOX-X.yaml +40 -0
  190. paddlex/configs/modules/object_detection/YOLOv3-DarkNet53.yaml +40 -0
  191. paddlex/configs/modules/object_detection/YOLOv3-MobileNetV3.yaml +40 -0
  192. paddlex/configs/modules/object_detection/YOLOv3-ResNet50_vd_DCN.yaml +40 -0
  193. paddlex/configs/modules/open_vocabulary_detection/GroundingDINO-T.yaml +13 -0
  194. paddlex/configs/modules/open_vocabulary_detection/YOLO-Worldv2-L.yaml +13 -0
  195. paddlex/configs/modules/open_vocabulary_segmentation/SAM-H_box.yaml +17 -0
  196. paddlex/configs/modules/open_vocabulary_segmentation/SAM-H_point.yaml +15 -0
  197. paddlex/configs/modules/pedestrian_attribute_recognition/PP-LCNet_x1_0_pedestrian_attribute.yaml +41 -0
  198. paddlex/configs/modules/rotated_object_detection/PP-YOLOE-R-L.yaml +40 -0
  199. paddlex/configs/modules/seal_text_detection/PP-OCRv4_mobile_seal_det.yaml +40 -0
  200. paddlex/configs/modules/seal_text_detection/PP-OCRv4_server_seal_det.yaml +40 -0
  201. paddlex/configs/modules/semantic_segmentation/Deeplabv3-R101.yaml +40 -0
  202. paddlex/configs/modules/semantic_segmentation/Deeplabv3-R50.yaml +40 -0
  203. paddlex/configs/modules/semantic_segmentation/Deeplabv3_Plus-R101.yaml +40 -0
  204. paddlex/configs/modules/semantic_segmentation/Deeplabv3_Plus-R50.yaml +40 -0
  205. paddlex/configs/modules/semantic_segmentation/MaskFormer_small.yaml +42 -0
  206. paddlex/configs/modules/semantic_segmentation/MaskFormer_tiny.yaml +42 -0
  207. paddlex/configs/modules/semantic_segmentation/OCRNet_HRNet-W18.yaml +40 -0
  208. paddlex/configs/modules/semantic_segmentation/OCRNet_HRNet-W48.yaml +40 -0
  209. paddlex/configs/modules/semantic_segmentation/PP-LiteSeg-B.yaml +41 -0
  210. paddlex/configs/modules/semantic_segmentation/PP-LiteSeg-T.yaml +40 -0
  211. paddlex/configs/modules/semantic_segmentation/SeaFormer_base.yaml +40 -0
  212. paddlex/configs/modules/semantic_segmentation/SeaFormer_large.yaml +40 -0
  213. paddlex/configs/modules/semantic_segmentation/SeaFormer_small.yaml +40 -0
  214. paddlex/configs/modules/semantic_segmentation/SeaFormer_tiny.yaml +40 -0
  215. paddlex/configs/modules/semantic_segmentation/SegFormer-B0.yaml +40 -0
  216. paddlex/configs/modules/semantic_segmentation/SegFormer-B1.yaml +40 -0
  217. paddlex/configs/modules/semantic_segmentation/SegFormer-B2.yaml +40 -0
  218. paddlex/configs/modules/semantic_segmentation/SegFormer-B3.yaml +40 -0
  219. paddlex/configs/modules/semantic_segmentation/SegFormer-B4.yaml +40 -0
  220. paddlex/configs/modules/semantic_segmentation/SegFormer-B5.yaml +40 -0
  221. paddlex/configs/modules/small_object_detection/PP-YOLOE_plus_SOD-L.yaml +42 -0
  222. paddlex/configs/modules/small_object_detection/PP-YOLOE_plus_SOD-S.yaml +42 -0
  223. paddlex/configs/modules/small_object_detection/PP-YOLOE_plus_SOD-largesize-L.yaml +42 -0
  224. paddlex/configs/modules/table_cells_detection/RT-DETR-L_wired_table_cell_det.yaml +40 -0
  225. paddlex/configs/modules/table_cells_detection/RT-DETR-L_wireless_table_cell_det.yaml +40 -0
  226. paddlex/configs/modules/table_classification/PP-LCNet_x1_0_table_cls.yaml +41 -0
  227. paddlex/configs/modules/table_structure_recognition/SLANeXt_wired.yaml +39 -0
  228. paddlex/configs/modules/table_structure_recognition/SLANeXt_wireless.yaml +39 -0
  229. paddlex/configs/modules/table_structure_recognition/SLANet.yaml +39 -0
  230. paddlex/configs/modules/table_structure_recognition/SLANet_plus.yaml +39 -0
  231. paddlex/configs/modules/text_detection/PP-OCRv3_mobile_det.yaml +40 -0
  232. paddlex/configs/modules/text_detection/PP-OCRv3_server_det.yaml +40 -0
  233. paddlex/configs/modules/text_detection/PP-OCRv4_mobile_det.yaml +40 -0
  234. paddlex/configs/modules/text_detection/PP-OCRv4_server_det.yaml +40 -0
  235. paddlex/configs/modules/text_detection/PP-OCRv5_mobile_det.yaml +40 -0
  236. paddlex/configs/modules/text_detection/PP-OCRv5_server_det.yaml +40 -0
  237. paddlex/configs/modules/text_recognition/PP-OCRv3_mobile_rec.yaml +39 -0
  238. paddlex/configs/modules/text_recognition/PP-OCRv4_mobile_rec.yaml +39 -0
  239. paddlex/configs/modules/text_recognition/PP-OCRv4_server_rec.yaml +39 -0
  240. paddlex/configs/modules/text_recognition/PP-OCRv4_server_rec_doc.yaml +39 -0
  241. paddlex/configs/modules/text_recognition/PP-OCRv5_mobile_rec.yaml +39 -0
  242. paddlex/configs/modules/text_recognition/PP-OCRv5_server_rec.yaml +39 -0
  243. paddlex/configs/modules/text_recognition/arabic_PP-OCRv3_mobile_rec.yaml +39 -0
  244. paddlex/configs/modules/text_recognition/ch_RepSVTR_rec.yaml +39 -0
  245. paddlex/configs/modules/text_recognition/ch_SVTRv2_rec.yaml +39 -0
  246. paddlex/configs/modules/text_recognition/chinese_cht_PP-OCRv3_mobile_rec.yaml +39 -0
  247. paddlex/configs/modules/text_recognition/cyrillic_PP-OCRv3_mobile_rec.yaml +39 -0
  248. paddlex/configs/modules/text_recognition/devanagari_PP-OCRv3_mobile_rec.yaml +39 -0
  249. paddlex/configs/modules/text_recognition/en_PP-OCRv3_mobile_rec.yaml +39 -0
  250. paddlex/configs/modules/text_recognition/en_PP-OCRv4_mobile_rec.yaml +39 -0
  251. paddlex/configs/modules/text_recognition/japan_PP-OCRv3_mobile_rec.yaml +39 -0
  252. paddlex/configs/modules/text_recognition/ka_PP-OCRv3_mobile_rec.yaml +39 -0
  253. paddlex/configs/modules/text_recognition/korean_PP-OCRv3_mobile_rec.yaml +39 -0
  254. paddlex/configs/modules/text_recognition/latin_PP-OCRv3_mobile_rec.yaml +39 -0
  255. paddlex/configs/modules/text_recognition/ta_PP-OCRv3_mobile_rec.yaml +39 -0
  256. paddlex/configs/modules/text_recognition/te_PP-OCRv3_mobile_rec.yaml +39 -0
  257. paddlex/configs/modules/textline_orientation/PP-LCNet_x0_25_textline_ori.yaml +41 -0
  258. paddlex/configs/modules/ts_anomaly_detection/AutoEncoder_ad.yaml +37 -0
  259. paddlex/configs/modules/ts_anomaly_detection/DLinear_ad.yaml +37 -0
  260. paddlex/configs/modules/ts_anomaly_detection/Nonstationary_ad.yaml +37 -0
  261. paddlex/configs/modules/ts_anomaly_detection/PatchTST_ad.yaml +37 -0
  262. paddlex/configs/modules/ts_anomaly_detection/TimesNet_ad.yaml +37 -0
  263. paddlex/configs/modules/ts_classification/TimesNet_cls.yaml +37 -0
  264. paddlex/configs/modules/ts_forecast/DLinear.yaml +38 -0
  265. paddlex/configs/modules/ts_forecast/NLinear.yaml +38 -0
  266. paddlex/configs/modules/ts_forecast/Nonstationary.yaml +38 -0
  267. paddlex/configs/modules/ts_forecast/PatchTST.yaml +38 -0
  268. paddlex/configs/modules/ts_forecast/RLinear.yaml +38 -0
  269. paddlex/configs/modules/ts_forecast/TiDE.yaml +38 -0
  270. paddlex/configs/modules/ts_forecast/TimesNet.yaml +38 -0
  271. paddlex/configs/modules/vehicle_attribute_recognition/PP-LCNet_x1_0_vehicle_attribute.yaml +41 -0
  272. paddlex/configs/modules/vehicle_detection/PP-YOLOE-L_vehicle.yaml +41 -0
  273. paddlex/configs/modules/vehicle_detection/PP-YOLOE-S_vehicle.yaml +42 -0
  274. paddlex/configs/modules/video_classification/PP-TSM-R50_8frames_uniform.yaml +42 -0
  275. paddlex/configs/modules/video_classification/PP-TSMv2-LCNetV2_16frames_uniform.yaml +42 -0
  276. paddlex/configs/modules/video_classification/PP-TSMv2-LCNetV2_8frames_uniform.yaml +42 -0
  277. paddlex/configs/modules/video_detection/YOWO.yaml +40 -0
  278. paddlex/configs/pipelines/3d_bev_detection.yaml +9 -0
  279. paddlex/configs/pipelines/OCR.yaml +45 -0
  280. paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml +151 -0
  281. paddlex/configs/pipelines/PP-ChatOCRv4-doc.yaml +237 -0
  282. paddlex/configs/pipelines/PP-ShiTuV2.yaml +18 -0
  283. paddlex/configs/pipelines/PP-StructureV3.yaml +226 -0
  284. paddlex/configs/pipelines/anomaly_detection.yaml +8 -0
  285. paddlex/configs/pipelines/doc_preprocessor.yaml +15 -0
  286. paddlex/configs/pipelines/doc_understanding.yaml +9 -0
  287. paddlex/configs/pipelines/face_recognition.yaml +18 -0
  288. paddlex/configs/pipelines/formula_recognition.yaml +39 -0
  289. paddlex/configs/pipelines/human_keypoint_detection.yaml +17 -0
  290. paddlex/configs/pipelines/image_classification.yaml +10 -0
  291. paddlex/configs/pipelines/image_multilabel_classification.yaml +9 -0
  292. paddlex/configs/pipelines/instance_segmentation.yaml +10 -0
  293. paddlex/configs/pipelines/layout_parsing.yaml +102 -0
  294. paddlex/configs/pipelines/multilingual_speech_recognition.yaml +9 -0
  295. paddlex/configs/pipelines/object_detection.yaml +10 -0
  296. paddlex/configs/pipelines/open_vocabulary_detection.yaml +12 -0
  297. paddlex/configs/pipelines/open_vocabulary_segmentation.yaml +13 -0
  298. paddlex/configs/pipelines/pedestrian_attribute_recognition.yaml +15 -0
  299. paddlex/configs/pipelines/rotated_object_detection.yaml +10 -0
  300. paddlex/configs/pipelines/seal_recognition.yaml +52 -0
  301. paddlex/configs/pipelines/semantic_segmentation.yaml +10 -0
  302. paddlex/configs/pipelines/small_object_detection.yaml +10 -0
  303. paddlex/configs/pipelines/table_recognition.yaml +57 -0
  304. paddlex/configs/pipelines/table_recognition_v2.yaml +82 -0
  305. paddlex/configs/pipelines/ts_anomaly_detection.yaml +8 -0
  306. paddlex/configs/pipelines/ts_classification.yaml +8 -0
  307. paddlex/configs/pipelines/ts_forecast.yaml +8 -0
  308. paddlex/configs/pipelines/vehicle_attribute_recognition.yaml +15 -0
  309. paddlex/configs/pipelines/video_classification.yaml +9 -0
  310. paddlex/configs/pipelines/video_detection.yaml +10 -0
  311. paddlex/constants.py +17 -0
  312. paddlex/engine.py +56 -0
  313. paddlex/hpip_links.html +31 -0
  314. paddlex/inference/__init__.py +19 -0
  315. paddlex/inference/common/__init__.py +13 -0
  316. paddlex/inference/common/batch_sampler/__init__.py +21 -0
  317. paddlex/inference/common/batch_sampler/audio_batch_sampler.py +83 -0
  318. paddlex/inference/common/batch_sampler/base_batch_sampler.py +94 -0
  319. paddlex/inference/common/batch_sampler/det_3d_batch_sampler.py +144 -0
  320. paddlex/inference/common/batch_sampler/doc_vlm_batch_sampler.py +87 -0
  321. paddlex/inference/common/batch_sampler/image_batch_sampler.py +121 -0
  322. paddlex/inference/common/batch_sampler/ts_batch_sampler.py +109 -0
  323. paddlex/inference/common/batch_sampler/video_batch_sampler.py +74 -0
  324. paddlex/inference/common/reader/__init__.py +19 -0
  325. paddlex/inference/common/reader/audio_reader.py +46 -0
  326. paddlex/inference/common/reader/det_3d_reader.py +241 -0
  327. paddlex/inference/common/reader/image_reader.py +73 -0
  328. paddlex/inference/common/reader/ts_reader.py +46 -0
  329. paddlex/inference/common/reader/video_reader.py +42 -0
  330. paddlex/inference/common/result/__init__.py +29 -0
  331. paddlex/inference/common/result/base_cv_result.py +41 -0
  332. paddlex/inference/common/result/base_result.py +72 -0
  333. paddlex/inference/common/result/base_ts_result.py +41 -0
  334. paddlex/inference/common/result/base_video_result.py +36 -0
  335. paddlex/inference/common/result/mixin.py +709 -0
  336. paddlex/inference/models/__init__.py +86 -0
  337. paddlex/inference/models/anomaly_detection/__init__.py +15 -0
  338. paddlex/inference/models/anomaly_detection/predictor.py +135 -0
  339. paddlex/inference/models/anomaly_detection/processors.py +53 -0
  340. paddlex/inference/models/anomaly_detection/result.py +71 -0
  341. paddlex/inference/models/base/__init__.py +15 -0
  342. paddlex/inference/models/base/predictor/__init__.py +15 -0
  343. paddlex/inference/models/base/predictor/base_predictor.py +414 -0
  344. paddlex/inference/models/common/__init__.py +26 -0
  345. paddlex/inference/models/common/static_infer.py +801 -0
  346. paddlex/inference/models/common/tokenizer/__init__.py +21 -0
  347. paddlex/inference/models/common/tokenizer/bert_tokenizer.py +655 -0
  348. paddlex/inference/models/common/tokenizer/clip_tokenizer.py +609 -0
  349. paddlex/inference/models/common/tokenizer/gpt_tokenizer.py +453 -0
  350. paddlex/inference/models/common/tokenizer/qwen2_5_tokenizer.py +112 -0
  351. paddlex/inference/models/common/tokenizer/qwen2_tokenizer.py +438 -0
  352. paddlex/inference/models/common/tokenizer/qwen_tokenizer.py +288 -0
  353. paddlex/inference/models/common/tokenizer/tokenizer_utils.py +2149 -0
  354. paddlex/inference/models/common/tokenizer/tokenizer_utils_base.py +3720 -0
  355. paddlex/inference/models/common/tokenizer/utils.py +66 -0
  356. paddlex/inference/models/common/tokenizer/vocab.py +647 -0
  357. paddlex/inference/models/common/ts/__init__.py +15 -0
  358. paddlex/inference/models/common/ts/funcs.py +540 -0
  359. paddlex/inference/models/common/ts/processors.py +322 -0
  360. paddlex/inference/models/common/vision/__init__.py +23 -0
  361. paddlex/inference/models/common/vision/funcs.py +98 -0
  362. paddlex/inference/models/common/vision/processors.py +285 -0
  363. paddlex/inference/models/common/vlm/__init__.py +13 -0
  364. paddlex/inference/models/common/vlm/activations.py +189 -0
  365. paddlex/inference/models/common/vlm/bert_padding.py +127 -0
  366. paddlex/inference/models/common/vlm/conversion_utils.py +99 -0
  367. paddlex/inference/models/common/vlm/distributed.py +229 -0
  368. paddlex/inference/models/common/vlm/flash_attn_utils.py +119 -0
  369. paddlex/inference/models/common/vlm/fusion_ops.py +205 -0
  370. paddlex/inference/models/common/vlm/generation/__init__.py +34 -0
  371. paddlex/inference/models/common/vlm/generation/configuration_utils.py +533 -0
  372. paddlex/inference/models/common/vlm/generation/logits_process.py +730 -0
  373. paddlex/inference/models/common/vlm/generation/stopping_criteria.py +106 -0
  374. paddlex/inference/models/common/vlm/generation/utils.py +2162 -0
  375. paddlex/inference/models/common/vlm/transformers/__init__.py +16 -0
  376. paddlex/inference/models/common/vlm/transformers/configuration_utils.py +1037 -0
  377. paddlex/inference/models/common/vlm/transformers/conversion_utils.py +408 -0
  378. paddlex/inference/models/common/vlm/transformers/model_outputs.py +1612 -0
  379. paddlex/inference/models/common/vlm/transformers/model_utils.py +2014 -0
  380. paddlex/inference/models/common/vlm/transformers/utils.py +178 -0
  381. paddlex/inference/models/common/vlm/utils.py +109 -0
  382. paddlex/inference/models/doc_vlm/__init__.py +15 -0
  383. paddlex/inference/models/doc_vlm/modeling/GOT_ocr_2_0.py +830 -0
  384. paddlex/inference/models/doc_vlm/modeling/__init__.py +17 -0
  385. paddlex/inference/models/doc_vlm/modeling/qwen2.py +1606 -0
  386. paddlex/inference/models/doc_vlm/modeling/qwen2_5_vl.py +3006 -0
  387. paddlex/inference/models/doc_vlm/modeling/qwen2_vl.py +2495 -0
  388. paddlex/inference/models/doc_vlm/predictor.py +253 -0
  389. paddlex/inference/models/doc_vlm/processors/GOT_ocr_2_0.py +97 -0
  390. paddlex/inference/models/doc_vlm/processors/__init__.py +17 -0
  391. paddlex/inference/models/doc_vlm/processors/common.py +561 -0
  392. paddlex/inference/models/doc_vlm/processors/qwen2_5_vl.py +548 -0
  393. paddlex/inference/models/doc_vlm/processors/qwen2_vl.py +543 -0
  394. paddlex/inference/models/doc_vlm/result.py +21 -0
  395. paddlex/inference/models/face_feature/__init__.py +15 -0
  396. paddlex/inference/models/face_feature/predictor.py +66 -0
  397. paddlex/inference/models/formula_recognition/__init__.py +15 -0
  398. paddlex/inference/models/formula_recognition/predictor.py +193 -0
  399. paddlex/inference/models/formula_recognition/processors.py +1015 -0
  400. paddlex/inference/models/formula_recognition/result.py +411 -0
  401. paddlex/inference/models/image_classification/__init__.py +15 -0
  402. paddlex/inference/models/image_classification/predictor.py +172 -0
  403. paddlex/inference/models/image_classification/processors.py +89 -0
  404. paddlex/inference/models/image_classification/result.py +93 -0
  405. paddlex/inference/models/image_feature/__init__.py +15 -0
  406. paddlex/inference/models/image_feature/predictor.py +146 -0
  407. paddlex/inference/models/image_feature/processors.py +31 -0
  408. paddlex/inference/models/image_feature/result.py +32 -0
  409. paddlex/inference/models/image_multilabel_classification/__init__.py +15 -0
  410. paddlex/inference/models/image_multilabel_classification/predictor.py +95 -0
  411. paddlex/inference/models/image_multilabel_classification/processors.py +89 -0
  412. paddlex/inference/models/image_multilabel_classification/result.py +96 -0
  413. paddlex/inference/models/image_unwarping/__init__.py +15 -0
  414. paddlex/inference/models/image_unwarping/predictor.py +97 -0
  415. paddlex/inference/models/image_unwarping/processors.py +92 -0
  416. paddlex/inference/models/image_unwarping/result.py +47 -0
  417. paddlex/inference/models/instance_segmentation/__init__.py +15 -0
  418. paddlex/inference/models/instance_segmentation/predictor.py +202 -0
  419. paddlex/inference/models/instance_segmentation/processors.py +102 -0
  420. paddlex/inference/models/instance_segmentation/result.py +162 -0
  421. paddlex/inference/models/keypoint_detection/__init__.py +15 -0
  422. paddlex/inference/models/keypoint_detection/predictor.py +190 -0
  423. paddlex/inference/models/keypoint_detection/processors.py +367 -0
  424. paddlex/inference/models/keypoint_detection/result.py +197 -0
  425. paddlex/inference/models/m_3d_bev_detection/__init__.py +15 -0
  426. paddlex/inference/models/m_3d_bev_detection/predictor.py +303 -0
  427. paddlex/inference/models/m_3d_bev_detection/processors.py +990 -0
  428. paddlex/inference/models/m_3d_bev_detection/result.py +68 -0
  429. paddlex/inference/models/m_3d_bev_detection/visualizer_3d.py +169 -0
  430. paddlex/inference/models/multilingual_speech_recognition/__init__.py +15 -0
  431. paddlex/inference/models/multilingual_speech_recognition/predictor.py +137 -0
  432. paddlex/inference/models/multilingual_speech_recognition/processors.py +1933 -0
  433. paddlex/inference/models/multilingual_speech_recognition/result.py +21 -0
  434. paddlex/inference/models/object_detection/__init__.py +15 -0
  435. paddlex/inference/models/object_detection/predictor.py +344 -0
  436. paddlex/inference/models/object_detection/processors.py +885 -0
  437. paddlex/inference/models/object_detection/result.py +114 -0
  438. paddlex/inference/models/object_detection/utils.py +70 -0
  439. paddlex/inference/models/open_vocabulary_detection/__init__.py +15 -0
  440. paddlex/inference/models/open_vocabulary_detection/predictor.py +172 -0
  441. paddlex/inference/models/open_vocabulary_detection/processors/__init__.py +16 -0
  442. paddlex/inference/models/open_vocabulary_detection/processors/common.py +114 -0
  443. paddlex/inference/models/open_vocabulary_detection/processors/groundingdino_processors.py +496 -0
  444. paddlex/inference/models/open_vocabulary_detection/processors/yoloworld_processors.py +209 -0
  445. paddlex/inference/models/open_vocabulary_segmentation/__init__.py +15 -0
  446. paddlex/inference/models/open_vocabulary_segmentation/predictor.py +113 -0
  447. paddlex/inference/models/open_vocabulary_segmentation/processors/__init__.py +15 -0
  448. paddlex/inference/models/open_vocabulary_segmentation/processors/sam_processer.py +249 -0
  449. paddlex/inference/models/open_vocabulary_segmentation/results/__init__.py +15 -0
  450. paddlex/inference/models/open_vocabulary_segmentation/results/sam_result.py +149 -0
  451. paddlex/inference/models/semantic_segmentation/__init__.py +15 -0
  452. paddlex/inference/models/semantic_segmentation/predictor.py +158 -0
  453. paddlex/inference/models/semantic_segmentation/processors.py +117 -0
  454. paddlex/inference/models/semantic_segmentation/result.py +73 -0
  455. paddlex/inference/models/table_structure_recognition/__init__.py +15 -0
  456. paddlex/inference/models/table_structure_recognition/predictor.py +161 -0
  457. paddlex/inference/models/table_structure_recognition/processors.py +229 -0
  458. paddlex/inference/models/table_structure_recognition/result.py +63 -0
  459. paddlex/inference/models/text_detection/__init__.py +15 -0
  460. paddlex/inference/models/text_detection/predictor.py +191 -0
  461. paddlex/inference/models/text_detection/processors.py +538 -0
  462. paddlex/inference/models/text_detection/result.py +46 -0
  463. paddlex/inference/models/text_recognition/__init__.py +15 -0
  464. paddlex/inference/models/text_recognition/predictor.py +98 -0
  465. paddlex/inference/models/text_recognition/processors.py +245 -0
  466. paddlex/inference/models/text_recognition/result.py +76 -0
  467. paddlex/inference/models/ts_anomaly_detection/__init__.py +15 -0
  468. paddlex/inference/models/ts_anomaly_detection/predictor.py +141 -0
  469. paddlex/inference/models/ts_anomaly_detection/processors.py +98 -0
  470. paddlex/inference/models/ts_anomaly_detection/result.py +83 -0
  471. paddlex/inference/models/ts_classification/__init__.py +15 -0
  472. paddlex/inference/models/ts_classification/predictor.py +122 -0
  473. paddlex/inference/models/ts_classification/processors.py +122 -0
  474. paddlex/inference/models/ts_classification/result.py +87 -0
  475. paddlex/inference/models/ts_forecasting/__init__.py +15 -0
  476. paddlex/inference/models/ts_forecasting/predictor.py +154 -0
  477. paddlex/inference/models/ts_forecasting/processors.py +158 -0
  478. paddlex/inference/models/ts_forecasting/result.py +96 -0
  479. paddlex/inference/models/video_classification/__init__.py +15 -0
  480. paddlex/inference/models/video_classification/predictor.py +141 -0
  481. paddlex/inference/models/video_classification/processors.py +409 -0
  482. paddlex/inference/models/video_classification/result.py +96 -0
  483. paddlex/inference/models/video_detection/__init__.py +15 -0
  484. paddlex/inference/models/video_detection/predictor.py +129 -0
  485. paddlex/inference/models/video_detection/processors.py +463 -0
  486. paddlex/inference/models/video_detection/result.py +109 -0
  487. paddlex/inference/pipelines/__init__.py +239 -0
  488. paddlex/inference/pipelines/_parallel.py +172 -0
  489. paddlex/inference/pipelines/anomaly_detection/__init__.py +15 -0
  490. paddlex/inference/pipelines/anomaly_detection/pipeline.py +82 -0
  491. paddlex/inference/pipelines/attribute_recognition/__init__.py +15 -0
  492. paddlex/inference/pipelines/attribute_recognition/pipeline.py +120 -0
  493. paddlex/inference/pipelines/attribute_recognition/result.py +102 -0
  494. paddlex/inference/pipelines/base.py +156 -0
  495. paddlex/inference/pipelines/components/__init__.py +29 -0
  496. paddlex/inference/pipelines/components/chat_server/__init__.py +16 -0
  497. paddlex/inference/pipelines/components/chat_server/base.py +39 -0
  498. paddlex/inference/pipelines/components/chat_server/openai_bot_chat.py +236 -0
  499. paddlex/inference/pipelines/components/common/__init__.py +19 -0
  500. paddlex/inference/pipelines/components/common/base_operator.py +37 -0
  501. paddlex/inference/pipelines/components/common/base_result.py +66 -0
  502. paddlex/inference/pipelines/components/common/convert_points_and_boxes.py +45 -0
  503. paddlex/inference/pipelines/components/common/crop_image_regions.py +556 -0
  504. paddlex/inference/pipelines/components/common/seal_det_warp.py +972 -0
  505. paddlex/inference/pipelines/components/common/sort_boxes.py +85 -0
  506. paddlex/inference/pipelines/components/common/warp_image.py +50 -0
  507. paddlex/inference/pipelines/components/faisser.py +357 -0
  508. paddlex/inference/pipelines/components/prompt_engineering/__init__.py +16 -0
  509. paddlex/inference/pipelines/components/prompt_engineering/base.py +35 -0
  510. paddlex/inference/pipelines/components/prompt_engineering/generate_ensemble_prompt.py +128 -0
  511. paddlex/inference/pipelines/components/prompt_engineering/generate_kie_prompt.py +148 -0
  512. paddlex/inference/pipelines/components/retriever/__init__.py +16 -0
  513. paddlex/inference/pipelines/components/retriever/base.py +228 -0
  514. paddlex/inference/pipelines/components/retriever/openai_bot_retriever.py +70 -0
  515. paddlex/inference/pipelines/components/retriever/qianfan_bot_retriever.py +166 -0
  516. paddlex/inference/pipelines/components/utils/__init__.py +13 -0
  517. paddlex/inference/pipelines/components/utils/mixin.py +206 -0
  518. paddlex/inference/pipelines/doc_preprocessor/__init__.py +15 -0
  519. paddlex/inference/pipelines/doc_preprocessor/pipeline.py +209 -0
  520. paddlex/inference/pipelines/doc_preprocessor/result.py +98 -0
  521. paddlex/inference/pipelines/doc_understanding/__init__.py +15 -0
  522. paddlex/inference/pipelines/doc_understanding/pipeline.py +71 -0
  523. paddlex/inference/pipelines/face_recognition/__init__.py +15 -0
  524. paddlex/inference/pipelines/face_recognition/pipeline.py +63 -0
  525. paddlex/inference/pipelines/face_recognition/result.py +44 -0
  526. paddlex/inference/pipelines/formula_recognition/__init__.py +15 -0
  527. paddlex/inference/pipelines/formula_recognition/pipeline.py +347 -0
  528. paddlex/inference/pipelines/formula_recognition/result.py +282 -0
  529. paddlex/inference/pipelines/image_classification/__init__.py +15 -0
  530. paddlex/inference/pipelines/image_classification/pipeline.py +90 -0
  531. paddlex/inference/pipelines/image_multilabel_classification/__init__.py +15 -0
  532. paddlex/inference/pipelines/image_multilabel_classification/pipeline.py +97 -0
  533. paddlex/inference/pipelines/instance_segmentation/__init__.py +15 -0
  534. paddlex/inference/pipelines/instance_segmentation/pipeline.py +91 -0
  535. paddlex/inference/pipelines/keypoint_detection/__init__.py +15 -0
  536. paddlex/inference/pipelines/keypoint_detection/pipeline.py +158 -0
  537. paddlex/inference/pipelines/layout_parsing/__init__.py +16 -0
  538. paddlex/inference/pipelines/layout_parsing/pipeline.py +568 -0
  539. paddlex/inference/pipelines/layout_parsing/pipeline_v2.py +1382 -0
  540. paddlex/inference/pipelines/layout_parsing/result.py +191 -0
  541. paddlex/inference/pipelines/layout_parsing/result_v2.py +745 -0
  542. paddlex/inference/pipelines/layout_parsing/setting.py +87 -0
  543. paddlex/inference/pipelines/layout_parsing/utils.py +951 -0
  544. paddlex/inference/pipelines/layout_parsing/xycut_enhanced/__init__.py +16 -0
  545. paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py +1143 -0
  546. paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py +562 -0
  547. paddlex/inference/pipelines/m_3d_bev_detection/__init__.py +15 -0
  548. paddlex/inference/pipelines/m_3d_bev_detection/pipeline.py +74 -0
  549. paddlex/inference/pipelines/multilingual_speech_recognition/__init__.py +15 -0
  550. paddlex/inference/pipelines/multilingual_speech_recognition/pipeline.py +78 -0
  551. paddlex/inference/pipelines/object_detection/__init__.py +15 -0
  552. paddlex/inference/pipelines/object_detection/pipeline.py +115 -0
  553. paddlex/inference/pipelines/ocr/__init__.py +15 -0
  554. paddlex/inference/pipelines/ocr/pipeline.py +463 -0
  555. paddlex/inference/pipelines/ocr/result.py +255 -0
  556. paddlex/inference/pipelines/open_vocabulary_detection/__init__.py +15 -0
  557. paddlex/inference/pipelines/open_vocabulary_detection/pipeline.py +86 -0
  558. paddlex/inference/pipelines/open_vocabulary_segmentation/__init__.py +15 -0
  559. paddlex/inference/pipelines/open_vocabulary_segmentation/pipeline.py +100 -0
  560. paddlex/inference/pipelines/pp_chatocr/__init__.py +16 -0
  561. paddlex/inference/pipelines/pp_chatocr/pipeline_base.py +111 -0
  562. paddlex/inference/pipelines/pp_chatocr/pipeline_v3.py +781 -0
  563. paddlex/inference/pipelines/pp_chatocr/pipeline_v4.py +992 -0
  564. paddlex/inference/pipelines/pp_shitu_v2/__init__.py +15 -0
  565. paddlex/inference/pipelines/pp_shitu_v2/pipeline.py +156 -0
  566. paddlex/inference/pipelines/pp_shitu_v2/result.py +126 -0
  567. paddlex/inference/pipelines/rotated_object_detection/__init__.py +15 -0
  568. paddlex/inference/pipelines/rotated_object_detection/pipeline.py +95 -0
  569. paddlex/inference/pipelines/seal_recognition/__init__.py +15 -0
  570. paddlex/inference/pipelines/seal_recognition/pipeline.py +335 -0
  571. paddlex/inference/pipelines/seal_recognition/result.py +89 -0
  572. paddlex/inference/pipelines/semantic_segmentation/__init__.py +15 -0
  573. paddlex/inference/pipelines/semantic_segmentation/pipeline.py +95 -0
  574. paddlex/inference/pipelines/small_object_detection/__init__.py +15 -0
  575. paddlex/inference/pipelines/small_object_detection/pipeline.py +95 -0
  576. paddlex/inference/pipelines/table_recognition/__init__.py +16 -0
  577. paddlex/inference/pipelines/table_recognition/pipeline.py +486 -0
  578. paddlex/inference/pipelines/table_recognition/pipeline_v2.py +1395 -0
  579. paddlex/inference/pipelines/table_recognition/result.py +218 -0
  580. paddlex/inference/pipelines/table_recognition/table_recognition_post_processing.py +366 -0
  581. paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py +488 -0
  582. paddlex/inference/pipelines/table_recognition/utils.py +44 -0
  583. paddlex/inference/pipelines/ts_anomaly_detection/__init__.py +15 -0
  584. paddlex/inference/pipelines/ts_anomaly_detection/pipeline.py +72 -0
  585. paddlex/inference/pipelines/ts_classification/__init__.py +15 -0
  586. paddlex/inference/pipelines/ts_classification/pipeline.py +72 -0
  587. paddlex/inference/pipelines/ts_forecasting/__init__.py +15 -0
  588. paddlex/inference/pipelines/ts_forecasting/pipeline.py +72 -0
  589. paddlex/inference/pipelines/video_classification/__init__.py +15 -0
  590. paddlex/inference/pipelines/video_classification/pipeline.py +79 -0
  591. paddlex/inference/pipelines/video_detection/__init__.py +15 -0
  592. paddlex/inference/pipelines/video_detection/pipeline.py +86 -0
  593. paddlex/inference/serving/__init__.py +17 -0
  594. paddlex/inference/serving/basic_serving/__init__.py +18 -0
  595. paddlex/inference/serving/basic_serving/_app.py +221 -0
  596. paddlex/inference/serving/basic_serving/_pipeline_apps/__init__.py +44 -0
  597. paddlex/inference/serving/basic_serving/_pipeline_apps/_common/__init__.py +13 -0
  598. paddlex/inference/serving/basic_serving/_pipeline_apps/_common/common.py +104 -0
  599. paddlex/inference/serving/basic_serving/_pipeline_apps/_common/image_recognition.py +36 -0
  600. paddlex/inference/serving/basic_serving/_pipeline_apps/_common/ocr.py +95 -0
  601. paddlex/inference/serving/basic_serving/_pipeline_apps/anomaly_detection.py +67 -0
  602. paddlex/inference/serving/basic_serving/_pipeline_apps/doc_preprocessor.py +100 -0
  603. paddlex/inference/serving/basic_serving/_pipeline_apps/doc_understanding.py +153 -0
  604. paddlex/inference/serving/basic_serving/_pipeline_apps/face_recognition.py +226 -0
  605. paddlex/inference/serving/basic_serving/_pipeline_apps/formula_recognition.py +100 -0
  606. paddlex/inference/serving/basic_serving/_pipeline_apps/human_keypoint_detection.py +81 -0
  607. paddlex/inference/serving/basic_serving/_pipeline_apps/image_classification.py +69 -0
  608. paddlex/inference/serving/basic_serving/_pipeline_apps/image_multilabel_classification.py +73 -0
  609. paddlex/inference/serving/basic_serving/_pipeline_apps/instance_segmentation.py +87 -0
  610. paddlex/inference/serving/basic_serving/_pipeline_apps/layout_parsing.py +117 -0
  611. paddlex/inference/serving/basic_serving/_pipeline_apps/m_3d_bev_detection.py +79 -0
  612. paddlex/inference/serving/basic_serving/_pipeline_apps/multilingual_speech_recognition.py +92 -0
  613. paddlex/inference/serving/basic_serving/_pipeline_apps/object_detection.py +77 -0
  614. paddlex/inference/serving/basic_serving/_pipeline_apps/ocr.py +102 -0
  615. paddlex/inference/serving/basic_serving/_pipeline_apps/open_vocabulary_detection.py +81 -0
  616. paddlex/inference/serving/basic_serving/_pipeline_apps/open_vocabulary_segmentation.py +91 -0
  617. paddlex/inference/serving/basic_serving/_pipeline_apps/pedestrian_attribute_recognition.py +84 -0
  618. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv3_doc.py +193 -0
  619. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv4_doc.py +223 -0
  620. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_shituv2.py +221 -0
  621. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_structurev3.py +143 -0
  622. paddlex/inference/serving/basic_serving/_pipeline_apps/rotated_object_detection.py +81 -0
  623. paddlex/inference/serving/basic_serving/_pipeline_apps/seal_recognition.py +106 -0
  624. paddlex/inference/serving/basic_serving/_pipeline_apps/semantic_segmentation.py +67 -0
  625. paddlex/inference/serving/basic_serving/_pipeline_apps/small_object_detection.py +72 -0
  626. paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition.py +108 -0
  627. paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition_v2.py +113 -0
  628. paddlex/inference/serving/basic_serving/_pipeline_apps/ts_anomaly_detection.py +65 -0
  629. paddlex/inference/serving/basic_serving/_pipeline_apps/ts_classification.py +64 -0
  630. paddlex/inference/serving/basic_serving/_pipeline_apps/ts_forecast.py +65 -0
  631. paddlex/inference/serving/basic_serving/_pipeline_apps/vehicle_attribute_recognition.py +84 -0
  632. paddlex/inference/serving/basic_serving/_pipeline_apps/video_classification.py +76 -0
  633. paddlex/inference/serving/basic_serving/_pipeline_apps/video_detection.py +92 -0
  634. paddlex/inference/serving/basic_serving/_server.py +40 -0
  635. paddlex/inference/serving/infra/__init__.py +13 -0
  636. paddlex/inference/serving/infra/config.py +36 -0
  637. paddlex/inference/serving/infra/models.py +79 -0
  638. paddlex/inference/serving/infra/storage.py +180 -0
  639. paddlex/inference/serving/infra/utils.py +285 -0
  640. paddlex/inference/serving/schemas/__init__.py +13 -0
  641. paddlex/inference/serving/schemas/anomaly_detection.py +39 -0
  642. paddlex/inference/serving/schemas/doc_preprocessor.py +54 -0
  643. paddlex/inference/serving/schemas/doc_understanding.py +78 -0
  644. paddlex/inference/serving/schemas/face_recognition.py +124 -0
  645. paddlex/inference/serving/schemas/formula_recognition.py +56 -0
  646. paddlex/inference/serving/schemas/human_keypoint_detection.py +55 -0
  647. paddlex/inference/serving/schemas/image_classification.py +45 -0
  648. paddlex/inference/serving/schemas/image_multilabel_classification.py +47 -0
  649. paddlex/inference/serving/schemas/instance_segmentation.py +53 -0
  650. paddlex/inference/serving/schemas/layout_parsing.py +71 -0
  651. paddlex/inference/serving/schemas/m_3d_bev_detection.py +48 -0
  652. paddlex/inference/serving/schemas/multilingual_speech_recognition.py +57 -0
  653. paddlex/inference/serving/schemas/object_detection.py +52 -0
  654. paddlex/inference/serving/schemas/ocr.py +60 -0
  655. paddlex/inference/serving/schemas/open_vocabulary_detection.py +52 -0
  656. paddlex/inference/serving/schemas/open_vocabulary_segmentation.py +52 -0
  657. paddlex/inference/serving/schemas/pedestrian_attribute_recognition.py +61 -0
  658. paddlex/inference/serving/schemas/pp_chatocrv3_doc.py +133 -0
  659. paddlex/inference/serving/schemas/pp_chatocrv4_doc.py +150 -0
  660. paddlex/inference/serving/schemas/pp_shituv2.py +124 -0
  661. paddlex/inference/serving/schemas/pp_structurev3.py +88 -0
  662. paddlex/inference/serving/schemas/rotated_object_detection.py +52 -0
  663. paddlex/inference/serving/schemas/seal_recognition.py +62 -0
  664. paddlex/inference/serving/schemas/semantic_segmentation.py +45 -0
  665. paddlex/inference/serving/schemas/shared/__init__.py +13 -0
  666. paddlex/inference/serving/schemas/shared/classification.py +23 -0
  667. paddlex/inference/serving/schemas/shared/image_segmentation.py +28 -0
  668. paddlex/inference/serving/schemas/shared/object_detection.py +24 -0
  669. paddlex/inference/serving/schemas/shared/ocr.py +25 -0
  670. paddlex/inference/serving/schemas/small_object_detection.py +52 -0
  671. paddlex/inference/serving/schemas/table_recognition.py +64 -0
  672. paddlex/inference/serving/schemas/table_recognition_v2.py +69 -0
  673. paddlex/inference/serving/schemas/ts_anomaly_detection.py +37 -0
  674. paddlex/inference/serving/schemas/ts_classification.py +38 -0
  675. paddlex/inference/serving/schemas/ts_forecast.py +37 -0
  676. paddlex/inference/serving/schemas/vehicle_attribute_recognition.py +61 -0
  677. paddlex/inference/serving/schemas/video_classification.py +44 -0
  678. paddlex/inference/serving/schemas/video_detection.py +56 -0
  679. paddlex/inference/utils/__init__.py +13 -0
  680. paddlex/inference/utils/benchmark.py +379 -0
  681. paddlex/inference/utils/color_map.py +123 -0
  682. paddlex/inference/utils/get_pipeline_path.py +27 -0
  683. paddlex/inference/utils/hpi.py +254 -0
  684. paddlex/inference/utils/hpi_model_info_collection.json +2331 -0
  685. paddlex/inference/utils/io/__init__.py +36 -0
  686. paddlex/inference/utils/io/readers.py +504 -0
  687. paddlex/inference/utils/io/style.py +381 -0
  688. paddlex/inference/utils/io/tablepyxl.py +157 -0
  689. paddlex/inference/utils/io/writers.py +458 -0
  690. paddlex/inference/utils/model_paths.py +48 -0
  691. paddlex/inference/utils/new_ir_blocklist.py +27 -0
  692. paddlex/inference/utils/official_models.py +367 -0
  693. paddlex/inference/utils/pp_option.py +339 -0
  694. paddlex/inference/utils/trt_blocklist.py +43 -0
  695. paddlex/inference/utils/trt_config.py +420 -0
  696. paddlex/model.py +131 -0
  697. paddlex/modules/__init__.py +115 -0
  698. paddlex/modules/anomaly_detection/__init__.py +18 -0
  699. paddlex/modules/anomaly_detection/dataset_checker/__init__.py +94 -0
  700. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/__init__.py +19 -0
  701. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/analyse_dataset.py +82 -0
  702. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/check_dataset.py +91 -0
  703. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/convert_dataset.py +233 -0
  704. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/split_dataset.py +87 -0
  705. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/utils/__init__.py +13 -0
  706. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/utils/visualizer.py +76 -0
  707. paddlex/modules/anomaly_detection/evaluator.py +58 -0
  708. paddlex/modules/anomaly_detection/exportor.py +22 -0
  709. paddlex/modules/anomaly_detection/model_list.py +16 -0
  710. paddlex/modules/anomaly_detection/trainer.py +70 -0
  711. paddlex/modules/base/__init__.py +18 -0
  712. paddlex/modules/base/build_model.py +33 -0
  713. paddlex/modules/base/dataset_checker/__init__.py +16 -0
  714. paddlex/modules/base/dataset_checker/dataset_checker.py +169 -0
  715. paddlex/modules/base/dataset_checker/utils.py +108 -0
  716. paddlex/modules/base/evaluator.py +170 -0
  717. paddlex/modules/base/exportor.py +145 -0
  718. paddlex/modules/base/trainer.py +144 -0
  719. paddlex/modules/base/utils/__init__.py +13 -0
  720. paddlex/modules/base/utils/cinn_setting.py +89 -0
  721. paddlex/modules/base/utils/coco_eval.py +94 -0
  722. paddlex/modules/base/utils/topk_eval.py +118 -0
  723. paddlex/modules/doc_vlm/__init__.py +18 -0
  724. paddlex/modules/doc_vlm/dataset_checker.py +29 -0
  725. paddlex/modules/doc_vlm/evaluator.py +29 -0
  726. paddlex/modules/doc_vlm/exportor.py +29 -0
  727. paddlex/modules/doc_vlm/model_list.py +16 -0
  728. paddlex/modules/doc_vlm/trainer.py +41 -0
  729. paddlex/modules/face_recognition/__init__.py +18 -0
  730. paddlex/modules/face_recognition/dataset_checker/__init__.py +71 -0
  731. paddlex/modules/face_recognition/dataset_checker/dataset_src/__init__.py +16 -0
  732. paddlex/modules/face_recognition/dataset_checker/dataset_src/check_dataset.py +172 -0
  733. paddlex/modules/face_recognition/dataset_checker/dataset_src/utils/__init__.py +13 -0
  734. paddlex/modules/face_recognition/dataset_checker/dataset_src/utils/visualizer.py +153 -0
  735. paddlex/modules/face_recognition/evaluator.py +52 -0
  736. paddlex/modules/face_recognition/exportor.py +22 -0
  737. paddlex/modules/face_recognition/model_list.py +15 -0
  738. paddlex/modules/face_recognition/trainer.py +75 -0
  739. paddlex/modules/formula_recognition/__init__.py +18 -0
  740. paddlex/modules/formula_recognition/dataset_checker/__init__.py +113 -0
  741. paddlex/modules/formula_recognition/dataset_checker/dataset_src/__init__.py +19 -0
  742. paddlex/modules/formula_recognition/dataset_checker/dataset_src/analyse_dataset.py +158 -0
  743. paddlex/modules/formula_recognition/dataset_checker/dataset_src/check_dataset.py +76 -0
  744. paddlex/modules/formula_recognition/dataset_checker/dataset_src/convert_dataset.py +95 -0
  745. paddlex/modules/formula_recognition/dataset_checker/dataset_src/split_dataset.py +80 -0
  746. paddlex/modules/formula_recognition/evaluator.py +80 -0
  747. paddlex/modules/formula_recognition/exportor.py +22 -0
  748. paddlex/modules/formula_recognition/model_list.py +23 -0
  749. paddlex/modules/formula_recognition/trainer.py +123 -0
  750. paddlex/modules/general_recognition/__init__.py +18 -0
  751. paddlex/modules/general_recognition/dataset_checker/__init__.py +107 -0
  752. paddlex/modules/general_recognition/dataset_checker/dataset_src/__init__.py +19 -0
  753. paddlex/modules/general_recognition/dataset_checker/dataset_src/analyse_dataset.py +96 -0
  754. paddlex/modules/general_recognition/dataset_checker/dataset_src/check_dataset.py +99 -0
  755. paddlex/modules/general_recognition/dataset_checker/dataset_src/convert_dataset.py +100 -0
  756. paddlex/modules/general_recognition/dataset_checker/dataset_src/split_dataset.py +82 -0
  757. paddlex/modules/general_recognition/dataset_checker/dataset_src/utils/__init__.py +13 -0
  758. paddlex/modules/general_recognition/dataset_checker/dataset_src/utils/visualizer.py +147 -0
  759. paddlex/modules/general_recognition/evaluator.py +31 -0
  760. paddlex/modules/general_recognition/exportor.py +22 -0
  761. paddlex/modules/general_recognition/model_list.py +19 -0
  762. paddlex/modules/general_recognition/trainer.py +52 -0
  763. paddlex/modules/image_classification/__init__.py +18 -0
  764. paddlex/modules/image_classification/dataset_checker/__init__.py +104 -0
  765. paddlex/modules/image_classification/dataset_checker/dataset_src/__init__.py +19 -0
  766. paddlex/modules/image_classification/dataset_checker/dataset_src/analyse_dataset.py +92 -0
  767. paddlex/modules/image_classification/dataset_checker/dataset_src/check_dataset.py +132 -0
  768. paddlex/modules/image_classification/dataset_checker/dataset_src/convert_dataset.py +51 -0
  769. paddlex/modules/image_classification/dataset_checker/dataset_src/split_dataset.py +81 -0
  770. paddlex/modules/image_classification/dataset_checker/dataset_src/utils/__init__.py +13 -0
  771. paddlex/modules/image_classification/dataset_checker/dataset_src/utils/visualizer.py +153 -0
  772. paddlex/modules/image_classification/evaluator.py +43 -0
  773. paddlex/modules/image_classification/exportor.py +22 -0
  774. paddlex/modules/image_classification/model_list.py +99 -0
  775. paddlex/modules/image_classification/trainer.py +82 -0
  776. paddlex/modules/image_unwarping/__init__.py +13 -0
  777. paddlex/modules/image_unwarping/model_list.py +17 -0
  778. paddlex/modules/instance_segmentation/__init__.py +18 -0
  779. paddlex/modules/instance_segmentation/dataset_checker/__init__.py +107 -0
  780. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/__init__.py +19 -0
  781. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/analyse_dataset.py +82 -0
  782. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/check_dataset.py +95 -0
  783. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/convert_dataset.py +241 -0
  784. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/split_dataset.py +122 -0
  785. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/utils/__init__.py +13 -0
  786. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/utils/visualizer.py +223 -0
  787. paddlex/modules/instance_segmentation/evaluator.py +32 -0
  788. paddlex/modules/instance_segmentation/exportor.py +22 -0
  789. paddlex/modules/instance_segmentation/model_list.py +33 -0
  790. paddlex/modules/instance_segmentation/trainer.py +31 -0
  791. paddlex/modules/keypoint_detection/__init__.py +18 -0
  792. paddlex/modules/keypoint_detection/dataset_checker/__init__.py +56 -0
  793. paddlex/modules/keypoint_detection/dataset_checker/dataset_src/__init__.py +15 -0
  794. paddlex/modules/keypoint_detection/dataset_checker/dataset_src/check_dataset.py +91 -0
  795. paddlex/modules/keypoint_detection/dataset_checker/dataset_src/utils/__init__.py +13 -0
  796. paddlex/modules/keypoint_detection/dataset_checker/dataset_src/utils/visualizer.py +124 -0
  797. paddlex/modules/keypoint_detection/evaluator.py +41 -0
  798. paddlex/modules/keypoint_detection/exportor.py +22 -0
  799. paddlex/modules/keypoint_detection/model_list.py +16 -0
  800. paddlex/modules/keypoint_detection/trainer.py +39 -0
  801. paddlex/modules/m_3d_bev_detection/__init__.py +18 -0
  802. paddlex/modules/m_3d_bev_detection/dataset_checker/__init__.py +95 -0
  803. paddlex/modules/m_3d_bev_detection/dataset_checker/dataset_src/__init__.py +17 -0
  804. paddlex/modules/m_3d_bev_detection/dataset_checker/dataset_src/analyse_dataset.py +106 -0
  805. paddlex/modules/m_3d_bev_detection/dataset_checker/dataset_src/check_dataset.py +101 -0
  806. paddlex/modules/m_3d_bev_detection/evaluator.py +46 -0
  807. paddlex/modules/m_3d_bev_detection/exportor.py +22 -0
  808. paddlex/modules/m_3d_bev_detection/model_list.py +18 -0
  809. paddlex/modules/m_3d_bev_detection/trainer.py +68 -0
  810. paddlex/modules/multilabel_classification/__init__.py +18 -0
  811. paddlex/modules/multilabel_classification/dataset_checker/__init__.py +106 -0
  812. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/__init__.py +19 -0
  813. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/analyse_dataset.py +94 -0
  814. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/check_dataset.py +132 -0
  815. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/convert_dataset.py +120 -0
  816. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/split_dataset.py +81 -0
  817. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/utils/__init__.py +13 -0
  818. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/utils/visualizer.py +149 -0
  819. paddlex/modules/multilabel_classification/evaluator.py +43 -0
  820. paddlex/modules/multilabel_classification/exportor.py +22 -0
  821. paddlex/modules/multilabel_classification/model_list.py +24 -0
  822. paddlex/modules/multilabel_classification/trainer.py +85 -0
  823. paddlex/modules/multilingual_speech_recognition/__init__.py +18 -0
  824. paddlex/modules/multilingual_speech_recognition/dataset_checker.py +27 -0
  825. paddlex/modules/multilingual_speech_recognition/evaluator.py +27 -0
  826. paddlex/modules/multilingual_speech_recognition/exportor.py +27 -0
  827. paddlex/modules/multilingual_speech_recognition/model_list.py +22 -0
  828. paddlex/modules/multilingual_speech_recognition/trainer.py +42 -0
  829. paddlex/modules/object_detection/__init__.py +18 -0
  830. paddlex/modules/object_detection/dataset_checker/__init__.py +106 -0
  831. paddlex/modules/object_detection/dataset_checker/dataset_src/__init__.py +19 -0
  832. paddlex/modules/object_detection/dataset_checker/dataset_src/analyse_dataset.py +82 -0
  833. paddlex/modules/object_detection/dataset_checker/dataset_src/check_dataset.py +91 -0
  834. paddlex/modules/object_detection/dataset_checker/dataset_src/convert_dataset.py +438 -0
  835. paddlex/modules/object_detection/dataset_checker/dataset_src/split_dataset.py +123 -0
  836. paddlex/modules/object_detection/dataset_checker/dataset_src/utils/__init__.py +13 -0
  837. paddlex/modules/object_detection/dataset_checker/dataset_src/utils/visualizer.py +193 -0
  838. paddlex/modules/object_detection/evaluator.py +57 -0
  839. paddlex/modules/object_detection/exportor.py +22 -0
  840. paddlex/modules/object_detection/model_list.py +86 -0
  841. paddlex/modules/object_detection/trainer.py +98 -0
  842. paddlex/modules/open_vocabulary_detection/__init__.py +18 -0
  843. paddlex/modules/open_vocabulary_detection/dataset_checker.py +29 -0
  844. paddlex/modules/open_vocabulary_detection/evaluator.py +29 -0
  845. paddlex/modules/open_vocabulary_detection/exportor.py +29 -0
  846. paddlex/modules/open_vocabulary_detection/model_list.py +16 -0
  847. paddlex/modules/open_vocabulary_detection/trainer.py +44 -0
  848. paddlex/modules/open_vocabulary_segmentation/__init__.py +18 -0
  849. paddlex/modules/open_vocabulary_segmentation/dataset_checker.py +29 -0
  850. paddlex/modules/open_vocabulary_segmentation/evaluator.py +29 -0
  851. paddlex/modules/open_vocabulary_segmentation/exportor.py +29 -0
  852. paddlex/modules/open_vocabulary_segmentation/model_list.py +19 -0
  853. paddlex/modules/open_vocabulary_segmentation/trainer.py +44 -0
  854. paddlex/modules/semantic_segmentation/__init__.py +18 -0
  855. paddlex/modules/semantic_segmentation/dataset_checker/__init__.py +109 -0
  856. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/__init__.py +19 -0
  857. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/analyse_dataset.py +76 -0
  858. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/check_dataset.py +80 -0
  859. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/convert_dataset.py +165 -0
  860. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/split_dataset.py +87 -0
  861. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/utils/__init__.py +13 -0
  862. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/utils/visualizer.py +75 -0
  863. paddlex/modules/semantic_segmentation/evaluator.py +58 -0
  864. paddlex/modules/semantic_segmentation/exportor.py +31 -0
  865. paddlex/modules/semantic_segmentation/model_list.py +37 -0
  866. paddlex/modules/semantic_segmentation/trainer.py +72 -0
  867. paddlex/modules/table_recognition/__init__.py +18 -0
  868. paddlex/modules/table_recognition/dataset_checker/__init__.py +98 -0
  869. paddlex/modules/table_recognition/dataset_checker/dataset_src/__init__.py +18 -0
  870. paddlex/modules/table_recognition/dataset_checker/dataset_src/analyse_dataset.py +59 -0
  871. paddlex/modules/table_recognition/dataset_checker/dataset_src/check_dataset.py +87 -0
  872. paddlex/modules/table_recognition/dataset_checker/dataset_src/split_dataset.py +80 -0
  873. paddlex/modules/table_recognition/evaluator.py +43 -0
  874. paddlex/modules/table_recognition/exportor.py +22 -0
  875. paddlex/modules/table_recognition/model_list.py +21 -0
  876. paddlex/modules/table_recognition/trainer.py +67 -0
  877. paddlex/modules/text_detection/__init__.py +18 -0
  878. paddlex/modules/text_detection/dataset_checker/__init__.py +107 -0
  879. paddlex/modules/text_detection/dataset_checker/dataset_src/__init__.py +18 -0
  880. paddlex/modules/text_detection/dataset_checker/dataset_src/analyse_dataset.py +220 -0
  881. paddlex/modules/text_detection/dataset_checker/dataset_src/check_dataset.py +106 -0
  882. paddlex/modules/text_detection/dataset_checker/dataset_src/split_dataset.py +140 -0
  883. paddlex/modules/text_detection/evaluator.py +41 -0
  884. paddlex/modules/text_detection/exportor.py +22 -0
  885. paddlex/modules/text_detection/model_list.py +26 -0
  886. paddlex/modules/text_detection/trainer.py +65 -0
  887. paddlex/modules/text_recognition/__init__.py +18 -0
  888. paddlex/modules/text_recognition/dataset_checker/__init__.py +125 -0
  889. paddlex/modules/text_recognition/dataset_checker/dataset_src/__init__.py +19 -0
  890. paddlex/modules/text_recognition/dataset_checker/dataset_src/analyse_dataset.py +162 -0
  891. paddlex/modules/text_recognition/dataset_checker/dataset_src/check_dataset.py +104 -0
  892. paddlex/modules/text_recognition/dataset_checker/dataset_src/convert_dataset.py +95 -0
  893. paddlex/modules/text_recognition/dataset_checker/dataset_src/split_dataset.py +80 -0
  894. paddlex/modules/text_recognition/evaluator.py +64 -0
  895. paddlex/modules/text_recognition/exportor.py +22 -0
  896. paddlex/modules/text_recognition/model_list.py +36 -0
  897. paddlex/modules/text_recognition/trainer.py +105 -0
  898. paddlex/modules/ts_anomaly_detection/__init__.py +19 -0
  899. paddlex/modules/ts_anomaly_detection/dataset_checker/__init__.py +111 -0
  900. paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/__init__.py +19 -0
  901. paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/analyse_dataset.py +19 -0
  902. paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/check_dataset.py +64 -0
  903. paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/convert_dataset.py +74 -0
  904. paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/split_dataset.py +63 -0
  905. paddlex/modules/ts_anomaly_detection/evaluator.py +67 -0
  906. paddlex/modules/ts_anomaly_detection/exportor.py +44 -0
  907. paddlex/modules/ts_anomaly_detection/model_list.py +22 -0
  908. paddlex/modules/ts_anomaly_detection/trainer.py +113 -0
  909. paddlex/modules/ts_classification/__init__.py +19 -0
  910. paddlex/modules/ts_classification/dataset_checker/__init__.py +111 -0
  911. paddlex/modules/ts_classification/dataset_checker/dataset_src/__init__.py +19 -0
  912. paddlex/modules/ts_classification/dataset_checker/dataset_src/analyse_dataset.py +77 -0
  913. paddlex/modules/ts_classification/dataset_checker/dataset_src/check_dataset.py +64 -0
  914. paddlex/modules/ts_classification/dataset_checker/dataset_src/convert_dataset.py +74 -0
  915. paddlex/modules/ts_classification/dataset_checker/dataset_src/split_dataset.py +88 -0
  916. paddlex/modules/ts_classification/evaluator.py +66 -0
  917. paddlex/modules/ts_classification/exportor.py +44 -0
  918. paddlex/modules/ts_classification/model_list.py +18 -0
  919. paddlex/modules/ts_classification/trainer.py +108 -0
  920. paddlex/modules/ts_forecast/__init__.py +19 -0
  921. paddlex/modules/ts_forecast/dataset_checker/__init__.py +111 -0
  922. paddlex/modules/ts_forecast/dataset_checker/dataset_src/__init__.py +19 -0
  923. paddlex/modules/ts_forecast/dataset_checker/dataset_src/analyse_dataset.py +19 -0
  924. paddlex/modules/ts_forecast/dataset_checker/dataset_src/check_dataset.py +64 -0
  925. paddlex/modules/ts_forecast/dataset_checker/dataset_src/convert_dataset.py +73 -0
  926. paddlex/modules/ts_forecast/dataset_checker/dataset_src/split_dataset.py +63 -0
  927. paddlex/modules/ts_forecast/evaluator.py +66 -0
  928. paddlex/modules/ts_forecast/exportor.py +44 -0
  929. paddlex/modules/ts_forecast/model_list.py +24 -0
  930. paddlex/modules/ts_forecast/trainer.py +108 -0
  931. paddlex/modules/video_classification/__init__.py +18 -0
  932. paddlex/modules/video_classification/dataset_checker/__init__.py +93 -0
  933. paddlex/modules/video_classification/dataset_checker/dataset_src/__init__.py +18 -0
  934. paddlex/modules/video_classification/dataset_checker/dataset_src/analyse_dataset.py +93 -0
  935. paddlex/modules/video_classification/dataset_checker/dataset_src/check_dataset.py +120 -0
  936. paddlex/modules/video_classification/dataset_checker/dataset_src/split_dataset.py +82 -0
  937. paddlex/modules/video_classification/evaluator.py +44 -0
  938. paddlex/modules/video_classification/exportor.py +22 -0
  939. paddlex/modules/video_classification/model_list.py +19 -0
  940. paddlex/modules/video_classification/trainer.py +88 -0
  941. paddlex/modules/video_detection/__init__.py +18 -0
  942. paddlex/modules/video_detection/dataset_checker/__init__.py +86 -0
  943. paddlex/modules/video_detection/dataset_checker/dataset_src/__init__.py +17 -0
  944. paddlex/modules/video_detection/dataset_checker/dataset_src/analyse_dataset.py +100 -0
  945. paddlex/modules/video_detection/dataset_checker/dataset_src/check_dataset.py +132 -0
  946. paddlex/modules/video_detection/evaluator.py +42 -0
  947. paddlex/modules/video_detection/exportor.py +22 -0
  948. paddlex/modules/video_detection/model_list.py +15 -0
  949. paddlex/modules/video_detection/trainer.py +82 -0
  950. paddlex/ops/__init__.py +152 -0
  951. paddlex/ops/iou3d_nms/iou3d_cpu.cpp +266 -0
  952. paddlex/ops/iou3d_nms/iou3d_cpu.h +28 -0
  953. paddlex/ops/iou3d_nms/iou3d_nms.cpp +206 -0
  954. paddlex/ops/iou3d_nms/iou3d_nms.h +35 -0
  955. paddlex/ops/iou3d_nms/iou3d_nms_api.cpp +114 -0
  956. paddlex/ops/iou3d_nms/iou3d_nms_kernel.cu +484 -0
  957. paddlex/ops/setup.py +37 -0
  958. paddlex/ops/voxel/voxelize_op.cc +194 -0
  959. paddlex/ops/voxel/voxelize_op.cu +346 -0
  960. paddlex/paddlex_cli.py +476 -0
  961. paddlex/repo_apis/Paddle3D_api/__init__.py +17 -0
  962. paddlex/repo_apis/Paddle3D_api/bev_fusion/__init__.py +18 -0
  963. paddlex/repo_apis/Paddle3D_api/bev_fusion/config.py +118 -0
  964. paddlex/repo_apis/Paddle3D_api/bev_fusion/model.py +238 -0
  965. paddlex/repo_apis/Paddle3D_api/bev_fusion/register.py +55 -0
  966. paddlex/repo_apis/Paddle3D_api/bev_fusion/runner.py +104 -0
  967. paddlex/repo_apis/Paddle3D_api/pp3d_config.py +145 -0
  968. paddlex/repo_apis/PaddleClas_api/__init__.py +17 -0
  969. paddlex/repo_apis/PaddleClas_api/cls/__init__.py +19 -0
  970. paddlex/repo_apis/PaddleClas_api/cls/config.py +595 -0
  971. paddlex/repo_apis/PaddleClas_api/cls/model.py +355 -0
  972. paddlex/repo_apis/PaddleClas_api/cls/register.py +907 -0
  973. paddlex/repo_apis/PaddleClas_api/cls/runner.py +218 -0
  974. paddlex/repo_apis/PaddleClas_api/shitu_rec/__init__.py +18 -0
  975. paddlex/repo_apis/PaddleClas_api/shitu_rec/config.py +141 -0
  976. paddlex/repo_apis/PaddleClas_api/shitu_rec/model.py +20 -0
  977. paddlex/repo_apis/PaddleClas_api/shitu_rec/register.py +68 -0
  978. paddlex/repo_apis/PaddleClas_api/shitu_rec/runner.py +50 -0
  979. paddlex/repo_apis/PaddleDetection_api/__init__.py +17 -0
  980. paddlex/repo_apis/PaddleDetection_api/config_helper.py +280 -0
  981. paddlex/repo_apis/PaddleDetection_api/instance_seg/__init__.py +18 -0
  982. paddlex/repo_apis/PaddleDetection_api/instance_seg/config.py +457 -0
  983. paddlex/repo_apis/PaddleDetection_api/instance_seg/model.py +403 -0
  984. paddlex/repo_apis/PaddleDetection_api/instance_seg/register.py +262 -0
  985. paddlex/repo_apis/PaddleDetection_api/instance_seg/runner.py +225 -0
  986. paddlex/repo_apis/PaddleDetection_api/object_det/__init__.py +19 -0
  987. paddlex/repo_apis/PaddleDetection_api/object_det/config.py +540 -0
  988. paddlex/repo_apis/PaddleDetection_api/object_det/model.py +429 -0
  989. paddlex/repo_apis/PaddleDetection_api/object_det/official_categories.py +245 -0
  990. paddlex/repo_apis/PaddleDetection_api/object_det/register.py +1135 -0
  991. paddlex/repo_apis/PaddleDetection_api/object_det/runner.py +225 -0
  992. paddlex/repo_apis/PaddleNLP_api/__init__.py +13 -0
  993. paddlex/repo_apis/PaddleOCR_api/__init__.py +22 -0
  994. paddlex/repo_apis/PaddleOCR_api/config_utils.py +53 -0
  995. paddlex/repo_apis/PaddleOCR_api/formula_rec/__init__.py +16 -0
  996. paddlex/repo_apis/PaddleOCR_api/formula_rec/config.py +571 -0
  997. paddlex/repo_apis/PaddleOCR_api/formula_rec/model.py +398 -0
  998. paddlex/repo_apis/PaddleOCR_api/formula_rec/register.py +99 -0
  999. paddlex/repo_apis/PaddleOCR_api/formula_rec/runner.py +239 -0
  1000. paddlex/repo_apis/PaddleOCR_api/table_rec/__init__.py +16 -0
  1001. paddlex/repo_apis/PaddleOCR_api/table_rec/config.py +64 -0
  1002. paddlex/repo_apis/PaddleOCR_api/table_rec/model.py +126 -0
  1003. paddlex/repo_apis/PaddleOCR_api/table_rec/register.py +70 -0
  1004. paddlex/repo_apis/PaddleOCR_api/table_rec/runner.py +51 -0
  1005. paddlex/repo_apis/PaddleOCR_api/text_det/__init__.py +16 -0
  1006. paddlex/repo_apis/PaddleOCR_api/text_det/config.py +62 -0
  1007. paddlex/repo_apis/PaddleOCR_api/text_det/model.py +72 -0
  1008. paddlex/repo_apis/PaddleOCR_api/text_det/register.py +107 -0
  1009. paddlex/repo_apis/PaddleOCR_api/text_det/runner.py +53 -0
  1010. paddlex/repo_apis/PaddleOCR_api/text_rec/__init__.py +16 -0
  1011. paddlex/repo_apis/PaddleOCR_api/text_rec/config.py +564 -0
  1012. paddlex/repo_apis/PaddleOCR_api/text_rec/model.py +398 -0
  1013. paddlex/repo_apis/PaddleOCR_api/text_rec/register.py +216 -0
  1014. paddlex/repo_apis/PaddleOCR_api/text_rec/runner.py +239 -0
  1015. paddlex/repo_apis/PaddleSeg_api/__init__.py +16 -0
  1016. paddlex/repo_apis/PaddleSeg_api/base_seg_config.py +134 -0
  1017. paddlex/repo_apis/PaddleSeg_api/seg/__init__.py +16 -0
  1018. paddlex/repo_apis/PaddleSeg_api/seg/config.py +183 -0
  1019. paddlex/repo_apis/PaddleSeg_api/seg/model.py +491 -0
  1020. paddlex/repo_apis/PaddleSeg_api/seg/register.py +272 -0
  1021. paddlex/repo_apis/PaddleSeg_api/seg/runner.py +261 -0
  1022. paddlex/repo_apis/PaddleTS_api/__init__.py +20 -0
  1023. paddlex/repo_apis/PaddleTS_api/ts_ad/__init__.py +16 -0
  1024. paddlex/repo_apis/PaddleTS_api/ts_ad/config.py +88 -0
  1025. paddlex/repo_apis/PaddleTS_api/ts_ad/register.py +146 -0
  1026. paddlex/repo_apis/PaddleTS_api/ts_ad/runner.py +158 -0
  1027. paddlex/repo_apis/PaddleTS_api/ts_base/__init__.py +13 -0
  1028. paddlex/repo_apis/PaddleTS_api/ts_base/config.py +244 -0
  1029. paddlex/repo_apis/PaddleTS_api/ts_base/model.py +276 -0
  1030. paddlex/repo_apis/PaddleTS_api/ts_base/runner.py +158 -0
  1031. paddlex/repo_apis/PaddleTS_api/ts_cls/__init__.py +16 -0
  1032. paddlex/repo_apis/PaddleTS_api/ts_cls/config.py +72 -0
  1033. paddlex/repo_apis/PaddleTS_api/ts_cls/register.py +59 -0
  1034. paddlex/repo_apis/PaddleTS_api/ts_cls/runner.py +158 -0
  1035. paddlex/repo_apis/PaddleTS_api/ts_fc/__init__.py +16 -0
  1036. paddlex/repo_apis/PaddleTS_api/ts_fc/config.py +136 -0
  1037. paddlex/repo_apis/PaddleTS_api/ts_fc/register.py +186 -0
  1038. paddlex/repo_apis/PaddleVideo_api/__init__.py +17 -0
  1039. paddlex/repo_apis/PaddleVideo_api/config_utils.py +51 -0
  1040. paddlex/repo_apis/PaddleVideo_api/video_cls/__init__.py +19 -0
  1041. paddlex/repo_apis/PaddleVideo_api/video_cls/config.py +548 -0
  1042. paddlex/repo_apis/PaddleVideo_api/video_cls/model.py +346 -0
  1043. paddlex/repo_apis/PaddleVideo_api/video_cls/register.py +70 -0
  1044. paddlex/repo_apis/PaddleVideo_api/video_cls/runner.py +204 -0
  1045. paddlex/repo_apis/PaddleVideo_api/video_det/__init__.py +19 -0
  1046. paddlex/repo_apis/PaddleVideo_api/video_det/config.py +549 -0
  1047. paddlex/repo_apis/PaddleVideo_api/video_det/model.py +298 -0
  1048. paddlex/repo_apis/PaddleVideo_api/video_det/register.py +44 -0
  1049. paddlex/repo_apis/PaddleVideo_api/video_det/runner.py +199 -0
  1050. paddlex/repo_apis/__init__.py +13 -0
  1051. paddlex/repo_apis/base/__init__.py +22 -0
  1052. paddlex/repo_apis/base/config.py +237 -0
  1053. paddlex/repo_apis/base/model.py +563 -0
  1054. paddlex/repo_apis/base/register.py +135 -0
  1055. paddlex/repo_apis/base/runner.py +390 -0
  1056. paddlex/repo_apis/base/utils/__init__.py +13 -0
  1057. paddlex/repo_apis/base/utils/arg.py +64 -0
  1058. paddlex/repo_apis/base/utils/subprocess.py +107 -0
  1059. paddlex/repo_manager/__init__.py +17 -0
  1060. paddlex/repo_manager/core.py +253 -0
  1061. paddlex/repo_manager/meta.py +180 -0
  1062. paddlex/repo_manager/repo.py +425 -0
  1063. paddlex/repo_manager/utils.py +148 -0
  1064. paddlex/utils/__init__.py +1 -12
  1065. paddlex/utils/cache.py +146 -0
  1066. paddlex/utils/config.py +216 -0
  1067. paddlex/utils/custom_device_list.py +311 -0
  1068. paddlex/utils/deps.py +249 -0
  1069. paddlex/utils/device.py +195 -0
  1070. paddlex/utils/download.py +168 -182
  1071. paddlex/utils/env.py +31 -48
  1072. paddlex/utils/errors/__init__.py +17 -0
  1073. paddlex/utils/errors/dataset_checker.py +78 -0
  1074. paddlex/utils/errors/others.py +138 -0
  1075. paddlex/utils/file_interface.py +211 -0
  1076. paddlex/utils/flags.py +70 -0
  1077. paddlex/utils/fonts/__init__.py +97 -0
  1078. paddlex/utils/func_register.py +41 -0
  1079. paddlex/utils/install.py +87 -0
  1080. paddlex/utils/interactive_get_pipeline.py +55 -0
  1081. paddlex/utils/lazy_loader.py +68 -0
  1082. paddlex/utils/logging.py +140 -33
  1083. paddlex/utils/misc.py +201 -0
  1084. paddlex/utils/pipeline_arguments.py +719 -0
  1085. paddlex/utils/result_saver.py +58 -0
  1086. paddlex/utils/subclass_register.py +99 -0
  1087. paddlex/version.py +55 -0
  1088. paddlex-3.0.0.dist-info/METADATA +1168 -0
  1089. paddlex-3.0.0.dist-info/RECORD +1093 -0
  1090. paddlex-3.0.0.dist-info/WHEEL +5 -0
  1091. paddlex-3.0.0.dist-info/entry_points.txt +2 -0
  1092. paddlex-3.0.0.dist-info/licenses/LICENSE +169 -0
  1093. paddlex-3.0.0.dist-info/top_level.txt +1 -0
  1094. PaddleClas/__init__.py +0 -16
  1095. PaddleClas/paddleclas.py +0 -375
  1096. PaddleClas/ppcls/__init__.py +0 -20
  1097. PaddleClas/ppcls/data/__init__.py +0 -15
  1098. PaddleClas/ppcls/data/imaug/__init__.py +0 -94
  1099. PaddleClas/ppcls/data/imaug/autoaugment.py +0 -264
  1100. PaddleClas/ppcls/data/imaug/batch_operators.py +0 -117
  1101. PaddleClas/ppcls/data/imaug/cutout.py +0 -41
  1102. PaddleClas/ppcls/data/imaug/fmix.py +0 -217
  1103. PaddleClas/ppcls/data/imaug/grid.py +0 -89
  1104. PaddleClas/ppcls/data/imaug/hide_and_seek.py +0 -44
  1105. PaddleClas/ppcls/data/imaug/operators.py +0 -244
  1106. PaddleClas/ppcls/data/imaug/randaugment.py +0 -106
  1107. PaddleClas/ppcls/data/imaug/random_erasing.py +0 -55
  1108. PaddleClas/ppcls/data/reader.py +0 -318
  1109. PaddleClas/ppcls/modeling/__init__.py +0 -20
  1110. PaddleClas/ppcls/modeling/architectures/__init__.py +0 -51
  1111. PaddleClas/ppcls/modeling/architectures/alexnet.py +0 -132
  1112. PaddleClas/ppcls/modeling/architectures/darknet.py +0 -161
  1113. PaddleClas/ppcls/modeling/architectures/densenet.py +0 -308
  1114. PaddleClas/ppcls/modeling/architectures/distillation_models.py +0 -65
  1115. PaddleClas/ppcls/modeling/architectures/distilled_vision_transformer.py +0 -196
  1116. PaddleClas/ppcls/modeling/architectures/dpn.py +0 -425
  1117. PaddleClas/ppcls/modeling/architectures/efficientnet.py +0 -901
  1118. PaddleClas/ppcls/modeling/architectures/ghostnet.py +0 -331
  1119. PaddleClas/ppcls/modeling/architectures/googlenet.py +0 -207
  1120. PaddleClas/ppcls/modeling/architectures/hrnet.py +0 -742
  1121. PaddleClas/ppcls/modeling/architectures/inception_v3.py +0 -481
  1122. PaddleClas/ppcls/modeling/architectures/inception_v4.py +0 -455
  1123. PaddleClas/ppcls/modeling/architectures/mixnet.py +0 -782
  1124. PaddleClas/ppcls/modeling/architectures/mobilenet_v1.py +0 -266
  1125. PaddleClas/ppcls/modeling/architectures/mobilenet_v2.py +0 -248
  1126. PaddleClas/ppcls/modeling/architectures/mobilenet_v3.py +0 -359
  1127. PaddleClas/ppcls/modeling/architectures/regnet.py +0 -383
  1128. PaddleClas/ppcls/modeling/architectures/repvgg.py +0 -339
  1129. PaddleClas/ppcls/modeling/architectures/res2net.py +0 -272
  1130. PaddleClas/ppcls/modeling/architectures/res2net_vd.py +0 -295
  1131. PaddleClas/ppcls/modeling/architectures/resnest.py +0 -705
  1132. PaddleClas/ppcls/modeling/architectures/resnet.py +0 -316
  1133. PaddleClas/ppcls/modeling/architectures/resnet_vc.py +0 -309
  1134. PaddleClas/ppcls/modeling/architectures/resnet_vd.py +0 -354
  1135. PaddleClas/ppcls/modeling/architectures/resnext.py +0 -253
  1136. PaddleClas/ppcls/modeling/architectures/resnext101_wsl.py +0 -447
  1137. PaddleClas/ppcls/modeling/architectures/resnext_vd.py +0 -266
  1138. PaddleClas/ppcls/modeling/architectures/rexnet.py +0 -240
  1139. PaddleClas/ppcls/modeling/architectures/se_resnet_vd.py +0 -378
  1140. PaddleClas/ppcls/modeling/architectures/se_resnext.py +0 -290
  1141. PaddleClas/ppcls/modeling/architectures/se_resnext_vd.py +0 -285
  1142. PaddleClas/ppcls/modeling/architectures/shufflenet_v2.py +0 -320
  1143. PaddleClas/ppcls/modeling/architectures/squeezenet.py +0 -154
  1144. PaddleClas/ppcls/modeling/architectures/vgg.py +0 -152
  1145. PaddleClas/ppcls/modeling/architectures/vision_transformer.py +0 -402
  1146. PaddleClas/ppcls/modeling/architectures/xception.py +0 -345
  1147. PaddleClas/ppcls/modeling/architectures/xception_deeplab.py +0 -386
  1148. PaddleClas/ppcls/modeling/loss.py +0 -154
  1149. PaddleClas/ppcls/modeling/utils.py +0 -53
  1150. PaddleClas/ppcls/optimizer/__init__.py +0 -19
  1151. PaddleClas/ppcls/optimizer/learning_rate.py +0 -159
  1152. PaddleClas/ppcls/optimizer/optimizer.py +0 -165
  1153. PaddleClas/ppcls/utils/__init__.py +0 -27
  1154. PaddleClas/ppcls/utils/check.py +0 -151
  1155. PaddleClas/ppcls/utils/config.py +0 -201
  1156. PaddleClas/ppcls/utils/logger.py +0 -120
  1157. PaddleClas/ppcls/utils/metrics.py +0 -107
  1158. PaddleClas/ppcls/utils/misc.py +0 -62
  1159. PaddleClas/ppcls/utils/model_zoo.py +0 -213
  1160. PaddleClas/ppcls/utils/save_load.py +0 -163
  1161. PaddleClas/setup.py +0 -55
  1162. PaddleClas/tools/__init__.py +0 -15
  1163. PaddleClas/tools/download.py +0 -50
  1164. PaddleClas/tools/ema.py +0 -58
  1165. PaddleClas/tools/eval.py +0 -112
  1166. PaddleClas/tools/export_model.py +0 -85
  1167. PaddleClas/tools/export_serving_model.py +0 -76
  1168. PaddleClas/tools/infer/__init__.py +0 -16
  1169. PaddleClas/tools/infer/infer.py +0 -94
  1170. PaddleClas/tools/infer/predict.py +0 -117
  1171. PaddleClas/tools/infer/utils.py +0 -233
  1172. PaddleClas/tools/program.py +0 -444
  1173. PaddleClas/tools/test_hubserving.py +0 -113
  1174. PaddleClas/tools/train.py +0 -141
  1175. paddlex/cls.py +0 -76
  1176. paddlex/command.py +0 -215
  1177. paddlex/cv/__init__.py +0 -17
  1178. paddlex/cv/datasets/__init__.py +0 -18
  1179. paddlex/cv/datasets/coco.py +0 -169
  1180. paddlex/cv/datasets/imagenet.py +0 -88
  1181. paddlex/cv/datasets/seg_dataset.py +0 -91
  1182. paddlex/cv/datasets/voc.py +0 -301
  1183. paddlex/cv/models/__init__.py +0 -18
  1184. paddlex/cv/models/base.py +0 -623
  1185. paddlex/cv/models/classifier.py +0 -814
  1186. paddlex/cv/models/detector.py +0 -1747
  1187. paddlex/cv/models/load_model.py +0 -126
  1188. paddlex/cv/models/segmenter.py +0 -673
  1189. paddlex/cv/models/slim/__init__.py +0 -13
  1190. paddlex/cv/models/slim/prune.py +0 -55
  1191. paddlex/cv/models/utils/__init__.py +0 -13
  1192. paddlex/cv/models/utils/det_metrics/__init__.py +0 -15
  1193. paddlex/cv/models/utils/det_metrics/coco_utils.py +0 -217
  1194. paddlex/cv/models/utils/det_metrics/metrics.py +0 -220
  1195. paddlex/cv/models/utils/ema.py +0 -48
  1196. paddlex/cv/models/utils/seg_metrics.py +0 -62
  1197. paddlex/cv/models/utils/visualize.py +0 -394
  1198. paddlex/cv/transforms/__init__.py +0 -46
  1199. paddlex/cv/transforms/batch_operators.py +0 -286
  1200. paddlex/cv/transforms/box_utils.py +0 -41
  1201. paddlex/cv/transforms/functions.py +0 -193
  1202. paddlex/cv/transforms/operators.py +0 -1402
  1203. paddlex/det.py +0 -43
  1204. paddlex/paddleseg/__init__.py +0 -17
  1205. paddlex/paddleseg/core/__init__.py +0 -20
  1206. paddlex/paddleseg/core/infer.py +0 -289
  1207. paddlex/paddleseg/core/predict.py +0 -145
  1208. paddlex/paddleseg/core/train.py +0 -258
  1209. paddlex/paddleseg/core/val.py +0 -172
  1210. paddlex/paddleseg/cvlibs/__init__.py +0 -17
  1211. paddlex/paddleseg/cvlibs/callbacks.py +0 -279
  1212. paddlex/paddleseg/cvlibs/config.py +0 -359
  1213. paddlex/paddleseg/cvlibs/manager.py +0 -142
  1214. paddlex/paddleseg/cvlibs/param_init.py +0 -91
  1215. paddlex/paddleseg/datasets/__init__.py +0 -21
  1216. paddlex/paddleseg/datasets/ade.py +0 -112
  1217. paddlex/paddleseg/datasets/cityscapes.py +0 -86
  1218. paddlex/paddleseg/datasets/cocostuff.py +0 -79
  1219. paddlex/paddleseg/datasets/dataset.py +0 -164
  1220. paddlex/paddleseg/datasets/mini_deep_globe_road_extraction.py +0 -95
  1221. paddlex/paddleseg/datasets/optic_disc_seg.py +0 -97
  1222. paddlex/paddleseg/datasets/pascal_context.py +0 -80
  1223. paddlex/paddleseg/datasets/voc.py +0 -113
  1224. paddlex/paddleseg/models/__init__.py +0 -39
  1225. paddlex/paddleseg/models/ann.py +0 -436
  1226. paddlex/paddleseg/models/attention_unet.py +0 -189
  1227. paddlex/paddleseg/models/backbones/__init__.py +0 -18
  1228. paddlex/paddleseg/models/backbones/hrnet.py +0 -815
  1229. paddlex/paddleseg/models/backbones/mobilenetv3.py +0 -365
  1230. paddlex/paddleseg/models/backbones/resnet_vd.py +0 -364
  1231. paddlex/paddleseg/models/backbones/xception_deeplab.py +0 -415
  1232. paddlex/paddleseg/models/bisenet.py +0 -311
  1233. paddlex/paddleseg/models/danet.py +0 -220
  1234. paddlex/paddleseg/models/decoupled_segnet.py +0 -233
  1235. paddlex/paddleseg/models/deeplab.py +0 -258
  1236. paddlex/paddleseg/models/dnlnet.py +0 -231
  1237. paddlex/paddleseg/models/emanet.py +0 -219
  1238. paddlex/paddleseg/models/fast_scnn.py +0 -318
  1239. paddlex/paddleseg/models/fcn.py +0 -135
  1240. paddlex/paddleseg/models/gcnet.py +0 -223
  1241. paddlex/paddleseg/models/gscnn.py +0 -357
  1242. paddlex/paddleseg/models/hardnet.py +0 -309
  1243. paddlex/paddleseg/models/isanet.py +0 -202
  1244. paddlex/paddleseg/models/layers/__init__.py +0 -19
  1245. paddlex/paddleseg/models/layers/activation.py +0 -73
  1246. paddlex/paddleseg/models/layers/attention.py +0 -146
  1247. paddlex/paddleseg/models/layers/layer_libs.py +0 -168
  1248. paddlex/paddleseg/models/layers/nonlocal2d.py +0 -155
  1249. paddlex/paddleseg/models/layers/pyramid_pool.py +0 -182
  1250. paddlex/paddleseg/models/losses/__init__.py +0 -27
  1251. paddlex/paddleseg/models/losses/binary_cross_entropy_loss.py +0 -174
  1252. paddlex/paddleseg/models/losses/bootstrapped_cross_entropy.py +0 -73
  1253. paddlex/paddleseg/models/losses/cross_entropy_loss.py +0 -94
  1254. paddlex/paddleseg/models/losses/decoupledsegnet_relax_boundary_loss.py +0 -129
  1255. paddlex/paddleseg/models/losses/dice_loss.py +0 -61
  1256. paddlex/paddleseg/models/losses/edge_attention_loss.py +0 -78
  1257. paddlex/paddleseg/models/losses/gscnn_dual_task_loss.py +0 -141
  1258. paddlex/paddleseg/models/losses/l1_loss.py +0 -76
  1259. paddlex/paddleseg/models/losses/lovasz_loss.py +0 -222
  1260. paddlex/paddleseg/models/losses/mean_square_error_loss.py +0 -65
  1261. paddlex/paddleseg/models/losses/mixed_loss.py +0 -58
  1262. paddlex/paddleseg/models/losses/ohem_cross_entropy_loss.py +0 -99
  1263. paddlex/paddleseg/models/losses/ohem_edge_attention_loss.py +0 -114
  1264. paddlex/paddleseg/models/ocrnet.py +0 -248
  1265. paddlex/paddleseg/models/pspnet.py +0 -147
  1266. paddlex/paddleseg/models/sfnet.py +0 -236
  1267. paddlex/paddleseg/models/shufflenet_slim.py +0 -268
  1268. paddlex/paddleseg/models/u2net.py +0 -574
  1269. paddlex/paddleseg/models/unet.py +0 -155
  1270. paddlex/paddleseg/models/unet_3plus.py +0 -316
  1271. paddlex/paddleseg/models/unet_plusplus.py +0 -237
  1272. paddlex/paddleseg/transforms/__init__.py +0 -16
  1273. paddlex/paddleseg/transforms/functional.py +0 -161
  1274. paddlex/paddleseg/transforms/transforms.py +0 -937
  1275. paddlex/paddleseg/utils/__init__.py +0 -22
  1276. paddlex/paddleseg/utils/config_check.py +0 -60
  1277. paddlex/paddleseg/utils/download.py +0 -163
  1278. paddlex/paddleseg/utils/env/__init__.py +0 -16
  1279. paddlex/paddleseg/utils/env/seg_env.py +0 -56
  1280. paddlex/paddleseg/utils/env/sys_env.py +0 -122
  1281. paddlex/paddleseg/utils/logger.py +0 -48
  1282. paddlex/paddleseg/utils/metrics.py +0 -146
  1283. paddlex/paddleseg/utils/progbar.py +0 -212
  1284. paddlex/paddleseg/utils/timer.py +0 -53
  1285. paddlex/paddleseg/utils/utils.py +0 -120
  1286. paddlex/paddleseg/utils/visualize.py +0 -90
  1287. paddlex/ppcls/__init__.py +0 -20
  1288. paddlex/ppcls/data/__init__.py +0 -15
  1289. paddlex/ppcls/data/imaug/__init__.py +0 -94
  1290. paddlex/ppcls/data/imaug/autoaugment.py +0 -264
  1291. paddlex/ppcls/data/imaug/batch_operators.py +0 -117
  1292. paddlex/ppcls/data/imaug/cutout.py +0 -41
  1293. paddlex/ppcls/data/imaug/fmix.py +0 -217
  1294. paddlex/ppcls/data/imaug/grid.py +0 -89
  1295. paddlex/ppcls/data/imaug/hide_and_seek.py +0 -44
  1296. paddlex/ppcls/data/imaug/operators.py +0 -256
  1297. paddlex/ppcls/data/imaug/randaugment.py +0 -106
  1298. paddlex/ppcls/data/imaug/random_erasing.py +0 -55
  1299. paddlex/ppcls/data/reader.py +0 -318
  1300. paddlex/ppcls/modeling/__init__.py +0 -20
  1301. paddlex/ppcls/modeling/architectures/__init__.py +0 -51
  1302. paddlex/ppcls/modeling/architectures/alexnet.py +0 -132
  1303. paddlex/ppcls/modeling/architectures/darknet.py +0 -161
  1304. paddlex/ppcls/modeling/architectures/densenet.py +0 -308
  1305. paddlex/ppcls/modeling/architectures/distillation_models.py +0 -65
  1306. paddlex/ppcls/modeling/architectures/distilled_vision_transformer.py +0 -196
  1307. paddlex/ppcls/modeling/architectures/dpn.py +0 -425
  1308. paddlex/ppcls/modeling/architectures/efficientnet.py +0 -901
  1309. paddlex/ppcls/modeling/architectures/ghostnet.py +0 -331
  1310. paddlex/ppcls/modeling/architectures/googlenet.py +0 -207
  1311. paddlex/ppcls/modeling/architectures/hrnet.py +0 -742
  1312. paddlex/ppcls/modeling/architectures/inception_v3.py +0 -541
  1313. paddlex/ppcls/modeling/architectures/inception_v4.py +0 -455
  1314. paddlex/ppcls/modeling/architectures/mixnet.py +0 -782
  1315. paddlex/ppcls/modeling/architectures/mobilenet_v1.py +0 -266
  1316. paddlex/ppcls/modeling/architectures/mobilenet_v2.py +0 -248
  1317. paddlex/ppcls/modeling/architectures/mobilenet_v3.py +0 -359
  1318. paddlex/ppcls/modeling/architectures/regnet.py +0 -383
  1319. paddlex/ppcls/modeling/architectures/repvgg.py +0 -339
  1320. paddlex/ppcls/modeling/architectures/res2net.py +0 -272
  1321. paddlex/ppcls/modeling/architectures/res2net_vd.py +0 -295
  1322. paddlex/ppcls/modeling/architectures/resnest.py +0 -705
  1323. paddlex/ppcls/modeling/architectures/resnet.py +0 -317
  1324. paddlex/ppcls/modeling/architectures/resnet_vc.py +0 -309
  1325. paddlex/ppcls/modeling/architectures/resnet_vd.py +0 -354
  1326. paddlex/ppcls/modeling/architectures/resnext.py +0 -259
  1327. paddlex/ppcls/modeling/architectures/resnext101_wsl.py +0 -447
  1328. paddlex/ppcls/modeling/architectures/resnext_vd.py +0 -266
  1329. paddlex/ppcls/modeling/architectures/rexnet.py +0 -240
  1330. paddlex/ppcls/modeling/architectures/se_resnet_vd.py +0 -378
  1331. paddlex/ppcls/modeling/architectures/se_resnext.py +0 -290
  1332. paddlex/ppcls/modeling/architectures/se_resnext_vd.py +0 -285
  1333. paddlex/ppcls/modeling/architectures/shufflenet_v2.py +0 -320
  1334. paddlex/ppcls/modeling/architectures/squeezenet.py +0 -154
  1335. paddlex/ppcls/modeling/architectures/vgg.py +0 -152
  1336. paddlex/ppcls/modeling/architectures/vision_transformer.py +0 -402
  1337. paddlex/ppcls/modeling/architectures/xception.py +0 -345
  1338. paddlex/ppcls/modeling/architectures/xception_deeplab.py +0 -386
  1339. paddlex/ppcls/modeling/loss.py +0 -158
  1340. paddlex/ppcls/modeling/utils.py +0 -53
  1341. paddlex/ppcls/optimizer/__init__.py +0 -19
  1342. paddlex/ppcls/optimizer/learning_rate.py +0 -159
  1343. paddlex/ppcls/optimizer/optimizer.py +0 -165
  1344. paddlex/ppcls/utils/__init__.py +0 -27
  1345. paddlex/ppcls/utils/check.py +0 -151
  1346. paddlex/ppcls/utils/config.py +0 -201
  1347. paddlex/ppcls/utils/logger.py +0 -120
  1348. paddlex/ppcls/utils/metrics.py +0 -112
  1349. paddlex/ppcls/utils/misc.py +0 -62
  1350. paddlex/ppcls/utils/model_zoo.py +0 -213
  1351. paddlex/ppcls/utils/save_load.py +0 -163
  1352. paddlex/ppdet/__init__.py +0 -16
  1353. paddlex/ppdet/core/__init__.py +0 -15
  1354. paddlex/ppdet/core/config/__init__.py +0 -13
  1355. paddlex/ppdet/core/config/schema.py +0 -248
  1356. paddlex/ppdet/core/config/yaml_helpers.py +0 -118
  1357. paddlex/ppdet/core/workspace.py +0 -279
  1358. paddlex/ppdet/data/__init__.py +0 -21
  1359. paddlex/ppdet/data/reader.py +0 -304
  1360. paddlex/ppdet/data/shm_utils.py +0 -67
  1361. paddlex/ppdet/data/source/__init__.py +0 -27
  1362. paddlex/ppdet/data/source/category.py +0 -823
  1363. paddlex/ppdet/data/source/coco.py +0 -243
  1364. paddlex/ppdet/data/source/dataset.py +0 -192
  1365. paddlex/ppdet/data/source/keypoint_coco.py +0 -656
  1366. paddlex/ppdet/data/source/mot.py +0 -360
  1367. paddlex/ppdet/data/source/voc.py +0 -204
  1368. paddlex/ppdet/data/source/widerface.py +0 -180
  1369. paddlex/ppdet/data/transform/__init__.py +0 -28
  1370. paddlex/ppdet/data/transform/autoaugment_utils.py +0 -1593
  1371. paddlex/ppdet/data/transform/batch_operators.py +0 -758
  1372. paddlex/ppdet/data/transform/gridmask_utils.py +0 -83
  1373. paddlex/ppdet/data/transform/keypoint_operators.py +0 -665
  1374. paddlex/ppdet/data/transform/mot_operators.py +0 -636
  1375. paddlex/ppdet/data/transform/op_helper.py +0 -468
  1376. paddlex/ppdet/data/transform/operators.py +0 -2103
  1377. paddlex/ppdet/engine/__init__.py +0 -29
  1378. paddlex/ppdet/engine/callbacks.py +0 -262
  1379. paddlex/ppdet/engine/env.py +0 -47
  1380. paddlex/ppdet/engine/export_utils.py +0 -118
  1381. paddlex/ppdet/engine/tracker.py +0 -425
  1382. paddlex/ppdet/engine/trainer.py +0 -535
  1383. paddlex/ppdet/metrics/__init__.py +0 -23
  1384. paddlex/ppdet/metrics/coco_utils.py +0 -184
  1385. paddlex/ppdet/metrics/json_results.py +0 -151
  1386. paddlex/ppdet/metrics/keypoint_metrics.py +0 -202
  1387. paddlex/ppdet/metrics/map_utils.py +0 -396
  1388. paddlex/ppdet/metrics/metrics.py +0 -300
  1389. paddlex/ppdet/metrics/mot_eval_utils.py +0 -192
  1390. paddlex/ppdet/metrics/mot_metrics.py +0 -184
  1391. paddlex/ppdet/metrics/widerface_utils.py +0 -393
  1392. paddlex/ppdet/model_zoo/__init__.py +0 -18
  1393. paddlex/ppdet/model_zoo/model_zoo.py +0 -86
  1394. paddlex/ppdet/model_zoo/tests/__init__.py +0 -13
  1395. paddlex/ppdet/model_zoo/tests/test_get_model.py +0 -48
  1396. paddlex/ppdet/model_zoo/tests/test_list_model.py +0 -68
  1397. paddlex/ppdet/modeling/__init__.py +0 -41
  1398. paddlex/ppdet/modeling/architectures/__init__.py +0 -40
  1399. paddlex/ppdet/modeling/architectures/cascade_rcnn.py +0 -144
  1400. paddlex/ppdet/modeling/architectures/centernet.py +0 -103
  1401. paddlex/ppdet/modeling/architectures/deepsort.py +0 -111
  1402. paddlex/ppdet/modeling/architectures/fairmot.py +0 -107
  1403. paddlex/ppdet/modeling/architectures/faster_rcnn.py +0 -106
  1404. paddlex/ppdet/modeling/architectures/fcos.py +0 -105
  1405. paddlex/ppdet/modeling/architectures/jde.py +0 -125
  1406. paddlex/ppdet/modeling/architectures/keypoint_hrhrnet.py +0 -286
  1407. paddlex/ppdet/modeling/architectures/keypoint_hrnet.py +0 -203
  1408. paddlex/ppdet/modeling/architectures/mask_rcnn.py +0 -135
  1409. paddlex/ppdet/modeling/architectures/meta_arch.py +0 -45
  1410. paddlex/ppdet/modeling/architectures/s2anet.py +0 -103
  1411. paddlex/ppdet/modeling/architectures/solov2.py +0 -110
  1412. paddlex/ppdet/modeling/architectures/ssd.py +0 -84
  1413. paddlex/ppdet/modeling/architectures/ttfnet.py +0 -98
  1414. paddlex/ppdet/modeling/architectures/yolo.py +0 -104
  1415. paddlex/ppdet/modeling/backbones/__init__.py +0 -37
  1416. paddlex/ppdet/modeling/backbones/blazenet.py +0 -322
  1417. paddlex/ppdet/modeling/backbones/darknet.py +0 -341
  1418. paddlex/ppdet/modeling/backbones/dla.py +0 -244
  1419. paddlex/ppdet/modeling/backbones/ghostnet.py +0 -476
  1420. paddlex/ppdet/modeling/backbones/hrnet.py +0 -724
  1421. paddlex/ppdet/modeling/backbones/mobilenet_v1.py +0 -410
  1422. paddlex/ppdet/modeling/backbones/mobilenet_v3.py +0 -497
  1423. paddlex/ppdet/modeling/backbones/name_adapter.py +0 -69
  1424. paddlex/ppdet/modeling/backbones/res2net.py +0 -358
  1425. paddlex/ppdet/modeling/backbones/resnet.py +0 -606
  1426. paddlex/ppdet/modeling/backbones/senet.py +0 -140
  1427. paddlex/ppdet/modeling/backbones/vgg.py +0 -216
  1428. paddlex/ppdet/modeling/bbox_utils.py +0 -464
  1429. paddlex/ppdet/modeling/heads/__init__.py +0 -41
  1430. paddlex/ppdet/modeling/heads/bbox_head.py +0 -379
  1431. paddlex/ppdet/modeling/heads/cascade_head.py +0 -285
  1432. paddlex/ppdet/modeling/heads/centernet_head.py +0 -194
  1433. paddlex/ppdet/modeling/heads/face_head.py +0 -113
  1434. paddlex/ppdet/modeling/heads/fcos_head.py +0 -270
  1435. paddlex/ppdet/modeling/heads/keypoint_hrhrnet_head.py +0 -108
  1436. paddlex/ppdet/modeling/heads/mask_head.py +0 -253
  1437. paddlex/ppdet/modeling/heads/roi_extractor.py +0 -111
  1438. paddlex/ppdet/modeling/heads/s2anet_head.py +0 -845
  1439. paddlex/ppdet/modeling/heads/solov2_head.py +0 -537
  1440. paddlex/ppdet/modeling/heads/ssd_head.py +0 -175
  1441. paddlex/ppdet/modeling/heads/ttf_head.py +0 -314
  1442. paddlex/ppdet/modeling/heads/yolo_head.py +0 -124
  1443. paddlex/ppdet/modeling/keypoint_utils.py +0 -302
  1444. paddlex/ppdet/modeling/layers.py +0 -1142
  1445. paddlex/ppdet/modeling/losses/__init__.py +0 -35
  1446. paddlex/ppdet/modeling/losses/ctfocal_loss.py +0 -67
  1447. paddlex/ppdet/modeling/losses/fairmot_loss.py +0 -41
  1448. paddlex/ppdet/modeling/losses/fcos_loss.py +0 -225
  1449. paddlex/ppdet/modeling/losses/iou_aware_loss.py +0 -48
  1450. paddlex/ppdet/modeling/losses/iou_loss.py +0 -210
  1451. paddlex/ppdet/modeling/losses/jde_loss.py +0 -182
  1452. paddlex/ppdet/modeling/losses/keypoint_loss.py +0 -228
  1453. paddlex/ppdet/modeling/losses/solov2_loss.py +0 -101
  1454. paddlex/ppdet/modeling/losses/ssd_loss.py +0 -163
  1455. paddlex/ppdet/modeling/losses/yolo_loss.py +0 -212
  1456. paddlex/ppdet/modeling/mot/__init__.py +0 -25
  1457. paddlex/ppdet/modeling/mot/matching/__init__.py +0 -19
  1458. paddlex/ppdet/modeling/mot/matching/deepsort_matching.py +0 -382
  1459. paddlex/ppdet/modeling/mot/matching/jde_matching.py +0 -145
  1460. paddlex/ppdet/modeling/mot/motion/__init__.py +0 -17
  1461. paddlex/ppdet/modeling/mot/motion/kalman_filter.py +0 -270
  1462. paddlex/ppdet/modeling/mot/tracker/__init__.py +0 -23
  1463. paddlex/ppdet/modeling/mot/tracker/base_jde_tracker.py +0 -267
  1464. paddlex/ppdet/modeling/mot/tracker/base_sde_tracker.py +0 -145
  1465. paddlex/ppdet/modeling/mot/tracker/deepsort_tracker.py +0 -165
  1466. paddlex/ppdet/modeling/mot/tracker/jde_tracker.py +0 -262
  1467. paddlex/ppdet/modeling/mot/utils.py +0 -181
  1468. paddlex/ppdet/modeling/mot/visualization.py +0 -130
  1469. paddlex/ppdet/modeling/necks/__init__.py +0 -25
  1470. paddlex/ppdet/modeling/necks/centernet_fpn.py +0 -185
  1471. paddlex/ppdet/modeling/necks/fpn.py +0 -233
  1472. paddlex/ppdet/modeling/necks/hrfpn.py +0 -131
  1473. paddlex/ppdet/modeling/necks/ttf_fpn.py +0 -243
  1474. paddlex/ppdet/modeling/necks/yolo_fpn.py +0 -1034
  1475. paddlex/ppdet/modeling/ops.py +0 -1599
  1476. paddlex/ppdet/modeling/post_process.py +0 -449
  1477. paddlex/ppdet/modeling/proposal_generator/__init__.py +0 -2
  1478. paddlex/ppdet/modeling/proposal_generator/anchor_generator.py +0 -135
  1479. paddlex/ppdet/modeling/proposal_generator/proposal_generator.py +0 -81
  1480. paddlex/ppdet/modeling/proposal_generator/rpn_head.py +0 -269
  1481. paddlex/ppdet/modeling/proposal_generator/target.py +0 -671
  1482. paddlex/ppdet/modeling/proposal_generator/target_layer.py +0 -476
  1483. paddlex/ppdet/modeling/reid/__init__.py +0 -23
  1484. paddlex/ppdet/modeling/reid/fairmot_embedding_head.py +0 -117
  1485. paddlex/ppdet/modeling/reid/jde_embedding_head.py +0 -189
  1486. paddlex/ppdet/modeling/reid/pyramidal_embedding.py +0 -151
  1487. paddlex/ppdet/modeling/reid/resnet.py +0 -320
  1488. paddlex/ppdet/modeling/shape_spec.py +0 -33
  1489. paddlex/ppdet/modeling/tests/__init__.py +0 -13
  1490. paddlex/ppdet/modeling/tests/test_architectures.py +0 -59
  1491. paddlex/ppdet/modeling/tests/test_base.py +0 -75
  1492. paddlex/ppdet/modeling/tests/test_ops.py +0 -839
  1493. paddlex/ppdet/modeling/tests/test_yolov3_loss.py +0 -420
  1494. paddlex/ppdet/optimizer.py +0 -285
  1495. paddlex/ppdet/slim/__init__.py +0 -62
  1496. paddlex/ppdet/slim/distill.py +0 -111
  1497. paddlex/ppdet/slim/prune.py +0 -85
  1498. paddlex/ppdet/slim/quant.py +0 -52
  1499. paddlex/ppdet/utils/__init__.py +0 -13
  1500. paddlex/ppdet/utils/check.py +0 -93
  1501. paddlex/ppdet/utils/checkpoint.py +0 -216
  1502. paddlex/ppdet/utils/cli.py +0 -151
  1503. paddlex/ppdet/utils/colormap.py +0 -56
  1504. paddlex/ppdet/utils/download.py +0 -477
  1505. paddlex/ppdet/utils/logger.py +0 -71
  1506. paddlex/ppdet/utils/stats.py +0 -95
  1507. paddlex/ppdet/utils/visualizer.py +0 -292
  1508. paddlex/ppdet/utils/voc_utils.py +0 -87
  1509. paddlex/seg.py +0 -38
  1510. paddlex/tools/__init__.py +0 -16
  1511. paddlex/tools/convert.py +0 -52
  1512. paddlex/tools/dataset_conversion/__init__.py +0 -24
  1513. paddlex/tools/dataset_conversion/x2coco.py +0 -379
  1514. paddlex/tools/dataset_conversion/x2imagenet.py +0 -82
  1515. paddlex/tools/dataset_conversion/x2seg.py +0 -343
  1516. paddlex/tools/dataset_conversion/x2voc.py +0 -230
  1517. paddlex/tools/dataset_split/__init__.py +0 -23
  1518. paddlex/tools/dataset_split/coco_split.py +0 -69
  1519. paddlex/tools/dataset_split/imagenet_split.py +0 -75
  1520. paddlex/tools/dataset_split/seg_split.py +0 -96
  1521. paddlex/tools/dataset_split/utils.py +0 -75
  1522. paddlex/tools/dataset_split/voc_split.py +0 -91
  1523. paddlex/tools/split.py +0 -41
  1524. paddlex/utils/checkpoint.py +0 -439
  1525. paddlex/utils/shm.py +0 -67
  1526. paddlex/utils/stats.py +0 -68
  1527. paddlex/utils/utils.py +0 -140
  1528. paddlex-2.0.0rc4.dist-info/LICENSE +0 -201
  1529. paddlex-2.0.0rc4.dist-info/METADATA +0 -29
  1530. paddlex-2.0.0rc4.dist-info/RECORD +0 -445
  1531. paddlex-2.0.0rc4.dist-info/WHEEL +0 -5
  1532. paddlex-2.0.0rc4.dist-info/entry_points.txt +0 -3
  1533. paddlex-2.0.0rc4.dist-info/top_level.txt +0 -2
@@ -0,0 +1,2162 @@
1
+ # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import copy
16
+ import inspect
17
+ from typing import Optional, Union
18
+
19
+ import paddle
20
+ import paddle.distributed as dist
21
+ import paddle.nn as nn
22
+ import paddle.nn.functional as F
23
+ from paddle import Tensor
24
+ from paddle.common_ops_import import convert_dtype
25
+ from paddle.utils import map_structure
26
+
27
+ from ......utils import logging
28
+ from ..transformers.model_outputs import ModelOutput
29
+ from .configuration_utils import DEFAULT_MAX_NEW_TOKENS, GenerationConfig
30
+ from .logits_process import (
31
+ ForcedBOSTokenLogitsProcessor,
32
+ ForcedEOSTokenLogitsProcessor,
33
+ HammingDiversityLogitsProcessor,
34
+ LogitsProcessor,
35
+ LogitsProcessorList,
36
+ MinLengthLogitsProcessor,
37
+ NoRepeatNGramLogitsProcessor,
38
+ RepetitionPenaltyLogitsProcessor,
39
+ TopKProcess,
40
+ TopPProcess,
41
+ )
42
+ from .stopping_criteria import (
43
+ StoppingCriteria,
44
+ StoppingCriteriaList,
45
+ validate_stopping_criteria,
46
+ )
47
+
48
# Public generation API of this module; the logits-process helpers are
# re-exported so callers can import them from here directly.
__all__ = [
    "GenerationMixin",
    "BeamSearchScorer",
    "BeamHypotheses",
    "LogitsProcessorList",
    "LogitsProcessor",
    "MinLengthLogitsProcessor",
    "RepetitionPenaltyLogitsProcessor",
    "TopKProcess",
    "TopPProcess",
    "get_unfinished_flag",
]
60
+
61
+
62
def get_scale_by_dtype(dtype: Optional[str] = None, return_positive: bool = True) -> float:
    """Return a large finite scale value representable in ``dtype``.

    The value is used as an additive-mask magnitude / sentinel score. It must
    be large but still representable: float16 overflows above ~65504, so a
    smaller magnitude is used for that dtype.

    Args:
        dtype (Optional[str]): dtype name (or anything ``convert_dtype``
            accepts). Defaults to ``paddle.get_default_dtype()``.
        return_positive (bool): if False, return the negated scale value.

    Returns:
        float: the (possibly negated) scale value.
    """
    if dtype is None:
        dtype = paddle.get_default_dtype()
    dtype = convert_dtype(dtype)

    # TODO(wj-Mcaf): support int8, int4 dtypes later
    # float16 cannot represent 1e6, so fall back to 1e4 there.
    scale_value = 1e4 if dtype == "float16" else 1e6
    return scale_value if return_positive else -scale_value
84
+
85
+
86
def get_unfinished_flag(
    input_ids: Tensor,
    unfinished_flag: Tensor,
    eos_token_id: Union[int, list[int], list[list[int]]],
) -> Tensor:
    """get unfinished flag for generation step

    Args:
        input_ids (Tensor): the input_ids
        unfinished_flag (Tensor): per-row boolean flag; True while the row is
            still generating
        eos_token_id (Union[int, list[int], list[list[int]]]): the end of sentence flag, which can be:
            * single token id, eg: 10
            * multiple token ids to stop generation, eg: [10, 10]
            * some more tokens to stop generations, eg: [[10], [20, 20], [30, 30, 30]]

    Returns:
        Tensor: the unfinished flag tensor
    """
    if isinstance(eos_token_id, int):
        # A row stays unfinished only while its last token differs from EOS.
        unfinished_flag = paddle.logical_and(
            unfinished_flag, input_ids[:, -1:] != eos_token_id
        )
    else:
        # NOTE(review): every list element is handled by a recursive call, so
        # an entry like [20, 20] acts as alternative single-token stops, not
        # as a matched token *sequence* — confirm this is intended.
        batch_unfinish_flag = None
        for batch_eos_token_id in eos_token_id:
            # Accumulate a "finished" flag: finished if ANY eos spec matched.
            if batch_unfinish_flag is None:
                batch_unfinish_flag = ~get_unfinished_flag(
                    input_ids, unfinished_flag, batch_eos_token_id
                )
            else:
                batch_unfinish_flag = paddle.logical_or(
                    batch_unfinish_flag,
                    ~get_unfinished_flag(
                        input_ids, unfinished_flag, batch_eos_token_id
                    ),
                )

        # Invert back: unfinished = not finished-by-any-spec.
        unfinished_flag = ~batch_unfinish_flag
    return unfinished_flag
124
+
125
+
126
class BeamHypotheses:
    """Container keeping the ``num_beams`` best finished hypotheses."""

    def __init__(self, num_beams, length_penalty, early_stopping):
        """Initialize an empty n-best list of hypotheses."""
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.beams = []
        # Sentinel: any real score beats this initial worst score.
        self.worst_score = get_scale_by_dtype()

    def __len__(self):
        """Return the number of stored hypotheses."""
        return len(self.beams)

    def add(self, hyp, sum_logprobs, origin_len=0):
        """Insert ``hyp``; evict the current worst beam if over capacity."""
        generated_len = hyp.shape[-1] - origin_len
        # Length-normalized score using the (len + 5) / 6 length penalty.
        score = sum_logprobs / (((generated_len + 5) / 6) ** self.length_penalty)
        if len(self) >= self.num_beams and score <= self.worst_score:
            # Full and not better than the current worst: ignore.
            return
        self.beams.append((score, hyp))
        if len(self) > self.num_beams:
            ranked = sorted((s, idx) for idx, (s, _) in enumerate(self.beams))
            # Drop the single worst beam; the runner-up becomes the new worst.
            del self.beams[ranked[0][1]]
            self.worst_score = ranked[1][0]
        else:
            self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs, cur_len, origin_len=0):
        """Return True when no in-flight hypothesis can beat the stored worst.

        With ``early_stopping`` the search ends as soon as the list is full.
        """
        if len(self) < self.num_beams:
            return False
        if self.early_stopping:
            return True
        best_possible_score = (
            best_sum_logprobs
            / ((cur_len - origin_len + 5) / 6) ** self.length_penalty
        )
        return self.worst_score >= best_possible_score
178
+
179
+
180
class BeamSearchScorer(object):
    """
    implementing standard beam search decoding.

    Keeps one `BeamHypotheses` container per batch example. `process` is
    called once per decoding step to route EOS-terminated candidates into the
    containers while keeping the top `group_size` open candidates per example;
    `finalize` pads and stacks the best finished hypotheses.
    """

    def __init__(
        self,
        batch_size,
        max_length,
        num_beams,
        length_penalty=1.0,
        do_early_stopping=False,
        num_beam_hyps_to_keep=1,
        num_beam_groups=1,
    ):
        # `max_length` bounds the width of the decoded tensor in `finalize`.
        self.max_length = max_length
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        # Number of finished hypotheses returned per batch example.
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
        self.num_beam_groups = num_beam_groups
        # Candidates are processed per diversity group of this size.
        self.group_size = self.num_beams // self.num_beam_groups

        self._is_init = False
        # One n-best container per batch example.
        self._beam_hyps = [
            BeamHypotheses(
                num_beams=self.num_beams,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
            )
            for _ in range(batch_size)
        ]
        # 1 marks a batch example whose beam search has finished.
        self._done = paddle.to_tensor([0 for _ in range(batch_size)], dtype="int64")

        # NOTE(review): validation runs after the attributes above are built;
        # harmless, since invalid arguments still raise before any use.
        if not isinstance(num_beams, int) or num_beams <= 1:
            raise ValueError(
                "`num_beams` has to be an integer strictly greater than 1, but "
                "received {}. For `num_beams` == 1, one should make use of "
                "`greedy_search` instead.".format(num_beams)
            )

        if (
            not isinstance(num_beam_groups, int)
            or (num_beam_groups > num_beams)
            or (num_beams % num_beam_groups != 0)
        ):
            raise ValueError(
                "`num_beam_groups` has to be an integer smaller or equal than "
                "`num_beams` and `num_beams` has to be divisible by "
                "`num_beam_groups`, but received num_beam_groups={}, num_beams="
                "{}.".format(num_beam_groups, num_beams)
            )

    @property
    def is_done(self):
        # Done only when every batch example is done.
        return paddle.min(self._done) == 1

    def process(
        self,
        input_ids,
        next_scores,
        next_tokens,
        next_indices,
        origin_len=0,
        pad_token_id=None,
        eos_token_id=None,
    ):
        """Advance one decoding step.

        Routes EOS-ending candidates into the per-example hypothesis
        containers and selects the top `group_size` open candidates per
        example. Returns flattened score/token/index tensors for the next
        step. Candidate lists in `next_tokens`/`next_scores`/`next_indices`
        are assumed sorted best-first per example.
        """
        cur_len = input_ids.shape[-1]
        batch_size = len(self._beam_hyps)
        assert batch_size == (input_ids.shape[0] // self.group_size)

        next_beam_scores = paddle.zeros(
            [batch_size, self.group_size], dtype=next_scores.dtype
        )
        next_beam_tokens = paddle.zeros(
            [batch_size, self.group_size], dtype=next_tokens.dtype
        )
        next_beam_indices = paddle.zeros(
            [batch_size, self.group_size], dtype=next_indices.dtype
        )

        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx] == 1:
                assert (
                    len(beam_hyp) >= self.num_beams
                ), "Batch can only be done if at least {} beams have been generated".format(
                    self.num_beams
                )
                assert (
                    eos_token_id is not None and pad_token_id is not None
                ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                # pad the batch
                next_beam_scores[batch_idx, :] = 0
                next_beam_tokens[batch_idx, :] = pad_token_id
                next_beam_indices[batch_idx, :] = 0
                continue

            # next tokens for this sentence
            beam_idx = 0
            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
                zip(
                    next_tokens[batch_idx],
                    next_scores[batch_idx],
                    next_indices[batch_idx],
                )
            ):
                # Flattened row of the candidate's source beam.
                batch_beam_idx = batch_idx * self.group_size + next_index
                # add to generated hypotheses if end of sentence
                if (eos_token_id is not None) and (next_token.item() == eos_token_id):
                    # If beam_token does not belong to top num_beams tokens,
                    # it should not be added
                    is_beam_token_worse_than_top_num_beams = (
                        beam_token_rank >= self.group_size
                    )
                    if is_beam_token_worse_than_top_num_beams:
                        continue
                    beam_hyp.add(
                        input_ids[batch_beam_idx.item()].clone(),
                        next_score.item(),
                        origin_len,
                    )

                else:
                    # add next predicted token since it is not eos_token
                    next_beam_scores[batch_idx, beam_idx] = next_score
                    next_beam_tokens[batch_idx, beam_idx] = next_token.item()
                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx.item()
                    beam_idx += 1

                # once the beam for next step is full, don't add more tokens to it.
                if beam_idx == self.group_size:
                    break

            if beam_idx < self.group_size:
                raise ValueError(
                    "At most {} tokens in `next_tokens[batch_idx]` can be equal "
                    "to `eos_token_id: {}`. Make sure `next_tokens[batch_idx]` "
                    "are corrected.".format(self.group_size, eos_token_id)
                )

            # Check if we are done so that we can save a pad step if all(done)
            if beam_hyp.is_done(
                next_scores[batch_idx].max().item(), cur_len, origin_len
            ):
                self._done[batch_idx] = 1

        return {
            "next_beam_scores": next_beam_scores.reshape([-1]),
            "next_beam_tokens": next_beam_tokens.reshape([-1]),
            "next_beam_indices": next_beam_indices.reshape([-1]),
        }

    def finalize(
        self,
        input_ids,
        final_beam_scores,
        final_beam_tokens,
        final_beam_indices,
        origin_len=0,
        pad_token_id=None,
        eos_token_id=None,
    ):
        """Flush still-open beams into the containers, then return the
        `num_beam_hyps_to_keep` best hypotheses per example as a padded
        tensor plus their scores.
        """
        batch_size = len(self._beam_hyps)

        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx] == 1:
                continue

            # all open beam hypotheses are added to the beam hypothesis
            # beam hypothesis class automatically keeps the best beams
            for beam_id in range(self.num_beams):
                batch_beam_idx = batch_idx * self.num_beams + beam_id
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                beam_hyp.add(final_tokens, final_score, origin_len=origin_len)

        # select the best hypotheses
        sent_lengths = paddle.zeros(
            [batch_size * self.num_beam_hyps_to_keep], dtype=input_ids.dtype
        )
        best = []

        # retrieve best hypotheses
        for i, beam_hyp in enumerate(self._beam_hyps):
            # Ascending sort; `pop()` takes the highest-scoring hypothesis.
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
            for j in range(self.num_beam_hyps_to_keep):
                best_score, best_hyp = sorted_hyps.pop()
                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
                best.append([best_hyp, best_score])

        # prepare for adding eos
        sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
        decoded = paddle.zeros(
            [batch_size * self.num_beam_hyps_to_keep, sent_max_len],
            dtype=input_ids.dtype,
        )
        # shorter batches are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            decoded[:, :] = pad_token_id
        decoded_score = paddle.zeros([batch_size * self.num_beam_hyps_to_keep, 1])

        # fill with hypotheses and eos_token_id if the latter fits in
        for i, (hypo, score) in enumerate(best):
            decoded[i, : sent_lengths[i].item()] = hypo.cpu().numpy()
            decoded_score[i] = score
            if sent_lengths[i] < self.max_length:
                decoded[i, sent_lengths[i].item()] = eos_token_id
        return decoded, decoded_score
390
+
391
+
392
+ class GenerationMixin(object):
393
+ r"""
394
+ This class implements the interface for generation task.
395
+
396
+ It's used as the base class of `paddlenlp.transformers.PretrainedModel
397
+ <https://paddlenlp.readthedocs.io/zh/latest/source/paddlenlp.transformers.model_utils.html>`__.
398
+ """
399
+
400
+ # enable `to_static` method for CausalLM Model
401
+ enable_to_static_method = False
402
+
403
+ @staticmethod
404
+ def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
405
+ batch_size = 1
406
+ if bos_token_id is None:
407
+ raise ValueError(
408
+ "`bos_token_id` should be defined when no " "`input_ids` are provided."
409
+ )
410
+ if encoder_output is not None:
411
+ batch_size = encoder_output.shape[0]
412
+ return paddle.ones([batch_size, 1], dtype="int64") * bos_token_id
413
+
414
+ @staticmethod
415
+ def prepare_attention_mask_for_generation(input_ids, pad_token_id, eos_token_id):
416
+ is_pad_token_in_inputs_ids = (pad_token_id is not None) and paddle.any(
417
+ input_ids == pad_token_id
418
+ ).item()
419
+ is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
420
+ (eos_token_id is not None) and (pad_token_id != eos_token_id)
421
+ )
422
+ if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
423
+ attention_mask = (input_ids == pad_token_id).astype(
424
+ paddle.get_default_dtype()
425
+ ) * get_scale_by_dtype(return_positive=False)
426
+ else:
427
+ attention_mask = paddle.zeros_like(
428
+ input_ids, dtype=paddle.get_default_dtype()
429
+ )
430
+ return paddle.unsqueeze(attention_mask, axis=[1, 2])
431
+
432
+ @staticmethod
433
+ def prepare_seq_len_for_generation(input_ids, pad_token_id, eos_token_id):
434
+ is_pad_token_in_inputs_ids = (pad_token_id is not None) and paddle.any(
435
+ input_ids == pad_token_id
436
+ ).item()
437
+ is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
438
+ (eos_token_id is not None) and (pad_token_id != eos_token_id)
439
+ )
440
+ if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
441
+ seq_len = paddle.sum(input_ids != pad_token_id, axis=1).unsqueeze(-1)
442
+ else:
443
+ seq_len = paddle.full(
444
+ (input_ids.shape[0], 1), input_ids.shape[1], dtype="int64"
445
+ )
446
+ return seq_len
447
+
448
+ def get_logits_processor(
449
+ self,
450
+ min_length=None,
451
+ max_length=None,
452
+ eos_token_id=None,
453
+ forced_bos_token_id=None,
454
+ forced_eos_token_id=None,
455
+ num_beams=1,
456
+ num_beam_groups=1,
457
+ diversity_rate=0.0,
458
+ repetition_penalty=None,
459
+ no_repeat_ngram_size=None,
460
+ logits_processors=None,
461
+ ):
462
+ processors = LogitsProcessorList()
463
+
464
+ if min_length is not None and eos_token_id is not None and min_length > -1:
465
+ processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
466
+ if num_beam_groups > 1 and diversity_rate > 0.0:
467
+ processors.append(
468
+ HammingDiversityLogitsProcessor(
469
+ diversity_rate=diversity_rate,
470
+ num_beams=num_beams,
471
+ num_beam_groups=num_beam_groups,
472
+ )
473
+ )
474
+ if repetition_penalty is not None and repetition_penalty != 1.0:
475
+ processors.append(
476
+ RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
477
+ )
478
+ if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
479
+ processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
480
+ if forced_bos_token_id is not None:
481
+ processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
482
+ if forced_eos_token_id is not None:
483
+ processors.append(
484
+ ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)
485
+ )
486
+ # TODO
487
+ # Add more pre_processing for distribution
488
+
489
+ if logits_processors is not None:
490
+ custom_processors = LogitsProcessorList()
491
+ custom_processors_type = [type(lp) for lp in logits_processors]
492
+
493
+ for processor in processors:
494
+ if type(processor) not in custom_processors_type:
495
+ custom_processors.append(processor)
496
+ custom_processors.extend(logits_processors)
497
+
498
+ return custom_processors
499
+ else:
500
+ return processors
501
+
502
+ @staticmethod
503
+ def expand_inputs_for_generation(
504
+ input_ids, expand_size, attention_mask=None, **model_kwargs
505
+ ):
506
+
507
+ index = paddle.tile(
508
+ paddle.arange(input_ids.shape[0], dtype="int64").unsqueeze(-1),
509
+ [1, expand_size],
510
+ ).reshape([-1])
511
+
512
+ input_ids = paddle.gather(input_ids, index)
513
+
514
+ if attention_mask is not None:
515
+ model_kwargs["attention_mask"] = paddle.gather(attention_mask, index)
516
+
517
+ if (
518
+ "token_type_ids" in model_kwargs
519
+ and model_kwargs["token_type_ids"] is not None
520
+ ):
521
+ token_type_ids = model_kwargs["token_type_ids"]
522
+ model_kwargs["token_type_ids"] = paddle.gather(token_type_ids, index)
523
+
524
+ if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None:
525
+ position_ids = model_kwargs["position_ids"]
526
+ model_kwargs["position_ids"] = paddle.gather(position_ids, index)
527
+
528
+ if "seq_len" in model_kwargs and model_kwargs["seq_len"] is not None:
529
+ seq_len = model_kwargs["seq_len"]
530
+ model_kwargs["seq_len"] = paddle.gather(seq_len, index)
531
+
532
+ if (
533
+ "encoder_output" in model_kwargs
534
+ and model_kwargs["encoder_output"] is not None
535
+ ):
536
+ encoder_output = model_kwargs["encoder_output"]
537
+ model_kwargs["encoder_output"] = paddle.gather(encoder_output, index)
538
+
539
+ if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None:
540
+ role_ids = model_kwargs["role_ids"]
541
+ model_kwargs["role_ids"] = paddle.gather(role_ids, index)
542
+
543
+ return input_ids, model_kwargs
544
+
545
    @staticmethod
    def update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=False
    ):
        # Update the model inputs during generation.
        # Note that If `token_type_ids` and `attention_mask` in `model_kwargs`
        # and they contain pad value, the result vectors updated by this method
        # may be different from expected. In this case, you need to rewrite the
        # method.

        # update cache
        # Legacy tuple-style outputs: element 1 (when not a plain Tensor) is
        # the cache / past key-values.
        if (
            isinstance(outputs, tuple)
            and len(outputs) > 1
            and not isinstance(outputs[1], paddle.Tensor)
        ):
            model_kwargs["cache"] = outputs[1]
            model_kwargs["past_key_values"] = outputs[1]

        # ModelOutput-style outputs carry the cache under `past_key_values`.
        if isinstance(outputs, ModelOutput) and "past_key_values" in outputs:
            model_kwargs["cache"] = outputs.past_key_values
            model_kwargs["past_key_values"] = outputs.past_key_values

        # update token_type_ids with last value
        if (
            "token_type_ids" in model_kwargs
            and model_kwargs["token_type_ids"] is not None
        ):
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = paddle.concat(
                [token_type_ids, token_type_ids[:, -1:]], axis=-1
            )

        # update position_ids
        # Append last position + 1 for the newly generated token.
        if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None:
            position_ids = model_kwargs["position_ids"]
            model_kwargs["position_ids"] = paddle.concat(
                [position_ids, position_ids[..., -1:] + 1], axis=-1
            )

        # update attention_mask
        if not is_encoder_decoder and "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            # nn.Pad2D don't support the data type `bool`
            if convert_dtype(attention_mask.dtype) == "bool":
                attention_mask = paddle.cast(attention_mask, "int64")
            if len(attention_mask.shape) == 4:
                # 4-D additive mask: grow it by one row and one column for the
                # new token, then open attention at the new [last, last] slot.
                cur_device = paddle.get_device()
                if cur_device.split(":")[0] == "npu":
                    # NPU lacks "replicate" padding; pad with constants instead.
                    attention_mask = nn.Pad2D([0, 0, 0, 1], mode="constant")(
                        attention_mask
                    )
                    attention_mask = nn.Pad2D([0, 1, 0, 0], value=0)(attention_mask)
                else:
                    attention_mask = nn.Pad2D([0, 0, 0, 1], mode="replicate")(
                        attention_mask
                    )
                    # New column is masked out with a large negative value.
                    attention_mask = nn.Pad2D(
                        [0, 1, 0, 0], value=get_scale_by_dtype(return_positive=False)
                    )(attention_mask)

                dtype = convert_dtype(attention_mask.dtype)
                if "int" in dtype:
                    attention_mask[:, :, -1, -1] = 1
                elif "float" in dtype:
                    attention_mask[:, :, -1, -1] = 0.0
                else:
                    raise ValueError(
                        "The data type of input `attention_mask` must "
                        "be bool, int or float"
                    )
            else:
                # 2-D (0/1) mask: the new token is simply visible.
                attention_mask = paddle.concat(
                    [
                        attention_mask,
                        paddle.ones([attention_mask.shape[0], 1], dtype="int64"),
                    ],
                    axis=-1,
                )
            model_kwargs["attention_mask"] = attention_mask

        # update role_ids
        # New token inherits the last role id.
        if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None:
            role_ids = model_kwargs["role_ids"]
            model_kwargs["role_ids"] = paddle.concat(
                [role_ids, role_ids[:, -1:]], axis=-1
            )

        return model_kwargs
634
+
635
+ @staticmethod
636
+ def update_scores_for_generation(scores, next_scores, length, unfinished_flag):
637
+ # update scores
638
+
639
+ unfinished_scores = (
640
+ scores * paddle.to_tensor(length, dtype=scores.dtype) + next_scores
641
+ ) / (paddle.to_tensor(length, dtype=scores.dtype) + 1)
642
+ scores = paddle.where(unfinished_flag, unfinished_scores, scores)
643
+ return scores
644
+
645
+ def prepare_encoder_decoder_kwargs_for_generation(self, input_ids, model_kwargs):
646
+ if "encoder_output" not in model_kwargs:
647
+ # retrieve encoder hidden states
648
+ encoder = self.get_encoder()
649
+ encoder_kwargs = {
650
+ argument: value
651
+ for argument, value in model_kwargs.items()
652
+ if not (
653
+ argument.startswith("decoder_")
654
+ or argument.startswith("cross_attn")
655
+ or argument == "use_cache"
656
+ )
657
+ }
658
+ # Use inputs_embeds as the priority if inputs_embeds exists
659
+ if "inputs_embeds" in encoder_kwargs:
660
+ model_kwargs["encoder_output"] = encoder(**encoder_kwargs)
661
+ else:
662
+ model_kwargs["encoder_output"] = encoder(
663
+ input_ids=input_ids, **encoder_kwargs
664
+ )
665
+ return model_kwargs
666
+
667
+ def prepare_decoder_input_ids_for_generation(
668
+ self, input_ids, decoder_start_token_id=None, bos_token_id=None
669
+ ):
670
+ decoder_start_token_id = (
671
+ decoder_start_token_id
672
+ if decoder_start_token_id is not None
673
+ else self.config.decoder_start_token_id
674
+ )
675
+ decoder_start_token_id = (
676
+ decoder_start_token_id
677
+ if decoder_start_token_id is not None
678
+ else bos_token_id
679
+ )
680
+
681
+ decoder_input_ids = (
682
+ paddle.ones([input_ids.shape[0], 1], dtype="int64") * decoder_start_token_id
683
+ )
684
+
685
+ return decoder_input_ids
686
+
687
+ def get_decoder_start_token_id(
688
+ self, decoder_start_token_id=None, bos_token_id=None
689
+ ):
690
+ decoder_start_token_id = (
691
+ decoder_start_token_id
692
+ if decoder_start_token_id is not None
693
+ else self.config.decoder_start_token_id
694
+ )
695
+ bos_token_id = (
696
+ bos_token_id if bos_token_id is not None else self.config.bos_token_id
697
+ )
698
+
699
+ if decoder_start_token_id is not None:
700
+ return decoder_start_token_id
701
+ elif self.config.decoder_start_token_id is not None:
702
+ return self.config.decoder_start_token_id
703
+ elif bos_token_id is not None:
704
+ return bos_token_id
705
+ elif self.config.bos_token_id is not None:
706
+ return self.config.bos_token_id
707
+ raise ValueError(
708
+ "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
709
+ )
710
+
711
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
712
+ # Implement in subclasses for custom behavior to prepare inputs in the
713
+ # generate method.
714
+
715
+ return {"input_ids": input_ids}
716
+
717
+ def adjust_logits_during_generation(self, logits):
718
+ # Implement in subclasses for custom behavior to adjust the logits in
719
+ # the generate method.
720
+
721
+ return logits
722
+
723
+ def prepare_fast_entry(self, kwargs):
724
+ return False
725
+
726
+ def _convert_to_fast(self, kwargs):
727
+ # try general convert
728
+ pass
729
+
730
+ def _build_fast(self, kwargs):
731
+ self._fast_entry = False
732
+ if kwargs["num_beam_groups"] != 1:
733
+ # not support for group_beam_search yet in the fast version
734
+ raise AttributeError(
735
+ "'num_beam_groups != 1' is not supported yet in the fast version"
736
+ )
737
+ if (
738
+ paddle.get_default_dtype() == "float16"
739
+ and kwargs["use_fp16_decoding"] is False
740
+ ):
741
+ logging.info(
742
+ "Since the default dtype is float16, float16 would be used "
743
+ "though 'use_fp16_decoding=False'."
744
+ )
745
+ kwargs["use_fp16_decoding"] = True
746
+ self.prepare_fast_entry(kwargs)
747
+
748
+ def set_pad_token_id(self, pad_token_id, eos_token_id):
749
+ if pad_token_id is None and eos_token_id is not None:
750
+ logging.warning(
751
+ "Setting `pad_token_id` to `eos_token_id`:{} for "
752
+ "open-end generation.".format(eos_token_id)
753
+ )
754
+ if isinstance(eos_token_id, list):
755
+ pad_token_id = eos_token_id[0]
756
+ else:
757
+ pad_token_id = eos_token_id
758
+ return pad_token_id
759
+
760
    @paddle.no_grad()
    def generate(
        self,
        input_ids: paddle.Tensor = None,
        generation_config: GenerationConfig = None,
        stopping_criteria: StoppingCriteria = None,
        streamer=None,
        synced_gpus: Optional[bool] = None,
        **kwargs,
    ):
        r"""
        The interface for generation task. This method can generate sequences
        by using decoding strategy. Currently, there are three decoding
        strategies supported: "greedy_search", "sampling" and "beam_search".

        Args:
            input_ids (Tensor, optional): The input sequence ids for the
                generation. It is a Tensor with shape [batch_size, sequence_length].
                The data type should be int32 or int64. Default to None, which
                we will initialize it as a Tensor with shape [1, 1], filled
                with the value `bos_token_id`.
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which had the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                Custom stopping criteria that complement the default stopping criteria built from arguments and a
                generation config. If a stopping criteria is passed that is already created with the arguments or a
                generation config an error is thrown. This feature is intended for advanced users.
            streamer (`~streamer.BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            synced_gpus (`bool`, *optional*):
                Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
                `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
                generating before other GPUs. Otherwise it'll be set to `False`.
            kwargs (dict): It can be used to specify additional kwargs
                passed to the model.

        Returns:
            tuple[Tensor]: It is a tuple contains two elements: ids and scores.
            Each element is a Tensor.

            With the fields:

            - ids (Tensor):
                The ids of the generated sequences. It is a Tensor with shape
                [batch_size * num_return_sequences, sequence_length]. The data
                type is same as the input `input_ids`.
            - scores (Tensor):
                The scores of the generated sequences. It is a Tensor with shape
                [batch_size * num_return_sequences, 1]. The data type is float32
                or float64, which is the same as the parameters in the model.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import (
                    UnifiedTransformerLMHeadModel,
                    UnifiedTransformerTokenizer
                )

                paddle.seed(2)

                # Initialize the model and tokenizer
                model_name_or_path = 'unified_transformer-12L-cn-luge'
                model = UnifiedTransformerLMHeadModel.from_pretrained(model_name_or_path)
                tokenizer = UnifiedTransformerTokenizer.from_pretrained(model_name_or_path)

                # Prepare the model inputs.
                history = "早上好,今天空气质量不错。"
                inputs = tokenizer.dialogue_encode(history, task_type='chitchat',
                    add_start_token_as_response=True, return_tensors=True)

            .. code-block::

                # Generate the sequence by using "greedy_search" strategy
                ids, scores = model.generate(
                    **inputs,
                    decode_strategy="greedy_search")
                print(ids.shape, scores.shape)
                # [1, 3] [1, 1]
                sequence_ids = ids.cpu().numpy().tolist()[0]
                sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
                response = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
                print(response)
                # 是的

            .. code-block::

                # Generate 2 sequences by using "sampling" strategy (top_k=5)
                generation_config = GenerationConfig(
                    decode_strategy="sampling",
                    top_k=5,
                    num_return_sequences=2
                )
                ids, scores = model.generate(
                    **inputs,
                    generation_config=generation_config,
                )
                print(ids.shape, scores.shape)
                # [2, 7] [2, 1]
                response = []
                for sequence_ids in ids.cpu().numpy().tolist():
                    sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
                    text = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
                    response.append(text)
                print(response)
                # ['天气好,心情也好', '你也是']

            .. code-block::

                # Generate 2 sequences by using "beam_search" strategy (num_beams=5)
                generation_config = GenerationConfig(
                    decode_strategy="beam_search",
                    num_beams=5,
                    num_return_sequences=2
                )
                ids, scores = model.generate(
                    **inputs,
                    generation_config=generation_config,
                )
                print(ids.shape, scores.shape)
                # [2, 3] [2, 1]
                response = []
                for sequence_ids in ids.cpu().numpy().tolist():
                    sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
                    text = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
                    response.append(text)
                print(response)
                # ['是的', '嗯嗯']
        """
        if generation_config is None:
            # No explicit config: fall back to the model's own. When that was
            # auto-derived from model.config, rebuild it so later edits to
            # model.config still win.
            if (
                self.generation_config is None
                or self.generation_config._from_model_config
            ):
                new_generation_config = GenerationConfig.from_model_config(self.config)
                if new_generation_config != self.generation_config:
                    logging.warning(
                        "model.generation_config is in conflict with model.config, "
                        "model.config is used."
                    )
                    self.generation_config = new_generation_config
            generation_config = self.generation_config

        # Work on a copy so per-call kwargs never mutate model.generation_config;
        # `update` folds matching kwargs into the config and returns the rest.
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)

        assert generation_config.decode_strategy in [
            "greedy_search",
            "sampling",
            "beam_search",
        ], "`decode_strategy` must be one of 'greedy_search', 'sampling' or 'beam_search' but received {}.".format(
            generation_config.decode_strategy
        )

        if getattr(self, "deprecated_warnings", None) is None:
            self.deprecated_warnings = {}

        # NOTE(review): use_fast is hard-wired False here, which makes the
        # FastGeneration branch below dead code; the deprecated flags raise.
        use_fast = False
        if "use_faster" in model_kwargs:
            raise ValueError("`use_faster` is deprecated now.")

        if "use_fast" in model_kwargs:
            raise ValueError("`use_fast` is deprecated now.")

        # Special-token resolution: the generation config wins, model.config
        # is the fallback.
        bos_token_id = (
            generation_config.bos_token_id
            if generation_config.bos_token_id is not None
            else self.config.bos_token_id
        )
        eos_token_id = (
            generation_config.eos_token_id
            if generation_config.eos_token_id is not None
            else self.config.eos_token_id
        )
        pad_token_id = (
            generation_config.pad_token_id
            if generation_config.pad_token_id is not None
            else self.config.pad_token_id
        )
        forced_bos_token_id = (
            generation_config.forced_bos_token_id
            if generation_config.forced_bos_token_id is not None
            else self.config.forced_bos_token_id
        )
        forced_eos_token_id = (
            generation_config.forced_eos_token_id
            if generation_config.forced_eos_token_id is not None
            else self.config.forced_eos_token_id
        )
        decoder_start_token_id = (
            generation_config.decoder_start_token_id
            if generation_config.decoder_start_token_id is not None
            else self.config.decoder_start_token_id
        )
        # NOTE(review): this local is never read below — get_logits_processor
        # uses generation_config.no_repeat_ngram_size directly; confirm intent.
        no_repeat_ngram_size = (
            generation_config.no_repeat_ngram_size
            if generation_config.no_repeat_ngram_size is not None
            else self.config.no_repeat_ngram_size
        )

        # FastGeneration fast path (currently unreachable, see use_fast above):
        # pack all locals as the entry's kwargs and fall back on any failure.
        if getattr(self, "_fast_entry", None) is not False and use_fast:
            fg_args = locals()
            fg_args.pop("self")
            fg_args.pop("__class__", None)
            model_kwargs = fg_args.pop("model_kwargs")
            fg_args.update(model_kwargs)
            try:
                if getattr(self, "_fast_entry", None) is None:
                    self._build_fast(fg_args)
                if self._fast_entry:
                    output = self._fast_entry(**fg_args)
                    if isinstance(output, tuple):
                        output_ids, dummy_srore = output
                    else:
                        output_ids = output
                        # keep the fast result shape-consistent with the slow path
                        dummy_srore = None
                    if generation_config.decode_strategy == "beam_search":
                        # fast kernels emit [time, beam, batch]; reorder and keep
                        # only num_return_sequences beams per sample
                        output_ids = output_ids.transpose([1, 2, 0])
                        output_ids = output_ids[
                            :, : generation_config.num_return_sequences, :
                        ].reshape([-1, output_ids.shape[-1]])
                        if dummy_srore is not None:
                            dummy_srore = dummy_srore[
                                :, : generation_config.num_return_sequences
                            ].flatten()
                    else:
                        output_ids = output_ids.transpose([1, 0])
                    return output_ids, dummy_srore

            except Exception as e:
                fg_args["model_kwargs"] = model_kwargs
                # TODO
                # Prevent self._convert_to_fast to throw Exception
                self._convert_to_fast(fg_args)
                logging.warning(e)
                logging.warning(
                    "FastGeneration is not available, "
                    "and the original version would be used instead."
                )

        # input_ids in model_kwargs is supported
        if "input_ids" in model_kwargs:
            _input_ids = model_kwargs.pop("input_ids")
            if input_ids is None:
                input_ids = _input_ids

        # params check
        if input_ids is None and "inputs_embeds" not in model_kwargs:
            # Init `input_ids` with bos_token_id
            input_ids = self.prepare_input_ids_for_generation(bos_token_id)
        elif "inputs_embeds" in model_kwargs:
            # Add input embeds support
            input_ids = self.prepare_input_ids_for_generation(
                bos_token_id, encoder_output=model_kwargs["inputs_embeds"]
            )

        if model_kwargs.get("attention_mask", None) is None:
            # TODO
            # Init `attention_mask` depending on `pad_token_id`
            model_kwargs["attention_mask"] = self.prepare_attention_mask_for_generation(
                input_ids, pad_token_id, eos_token_id
            )
        self.is_encoder_decoder = self.config.is_encoder_decoder

        if self.is_encoder_decoder:
            # run the encoder once, then switch input_ids to the decoder side
            model_kwargs = self.prepare_encoder_decoder_kwargs_for_generation(
                input_ids, model_kwargs
            )
            # set input_ids as decoder_input_ids
            if "decoder_input_ids" in model_kwargs:
                input_ids = model_kwargs.pop("decoder_input_ids")
            else:
                input_ids = self.prepare_decoder_input_ids_for_generation(
                    input_ids, decoder_start_token_id, bos_token_id
                )
        # streamer
        if streamer is not None:
            # streamer couldn't support beam_search strategy
            if (
                generation_config.decode_strategy == "beam_search"
                or generation_config.num_beams > 1
            ):
                raise ValueError(
                    "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
                )

        pad_token_id = self.set_pad_token_id(pad_token_id, eos_token_id)

        # Back-compat: fold deprecated max_length/min_length into the
        # *_new_tokens fields when the latter still hold their defaults.
        if (
            generation_config.max_length != 0
            and generation_config.max_new_tokens == DEFAULT_MAX_NEW_TOKENS
        ):
            logging.warning(
                "`max_length` will be deprecated in future releases, use `max_new_tokens` instead."
            )
            generation_config.max_new_tokens = generation_config.max_length

        if generation_config.min_length != 0 and generation_config.min_new_tokens == 0:
            logging.warning(
                "`min_length` will be deprecated in future releases, use `min_new_tokens` instead."
            )
            generation_config.min_new_tokens = generation_config.min_length

        max_length = generation_config.max_new_tokens
        min_length = generation_config.min_new_tokens

        # absolute lengths = prompt length + new-token budget
        input_len = input_ids.shape[-1]
        min_len = input_len + min_length
        max_len = input_len + max_length

        logits_processors = self.get_logits_processor(
            min_length=min_len if min_length > 0 else None,
            max_length=max_len,
            eos_token_id=eos_token_id,
            forced_bos_token_id=forced_bos_token_id,
            forced_eos_token_id=forced_eos_token_id,
            num_beams=generation_config.num_beams,
            num_beam_groups=generation_config.num_beam_groups,
            diversity_rate=generation_config.diversity_rate,
            repetition_penalty=generation_config.repetition_penalty,
            no_repeat_ngram_size=generation_config.no_repeat_ngram_size,
            logits_processors=(
                model_kwargs["logits_processors"]
                if "logits_processors" in model_kwargs
                and isinstance(model_kwargs["logits_processors"], LogitsProcessorList)
                else None
            ),
        )
        if "logits_processors" in model_kwargs:
            model_kwargs.pop("logits_processors")

        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )

        # Dispatch to the concrete decoding strategy.
        if generation_config.decode_strategy == "greedy_search":
            if generation_config.num_return_sequences > 1:
                raise ValueError(
                    "`num_return_sequences` has to be 1, but is {} "
                    "when doing greedy search.".format(
                        generation_config.num_return_sequences
                    )
                )
            return self.greedy_search(
                input_ids,
                logits_processors,
                max_len,
                pad_token_id,
                eos_token_id,
                stopping_criteria=stopping_criteria,
                streamer=streamer,
                fast_ptq_sampling=generation_config.fast_ptq_sampling,
                trunc_input=generation_config.trunc_input,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif generation_config.decode_strategy == "sampling":
            if generation_config.num_return_sequences > 1:
                # tile inputs so each return sequence is sampled independently
                input_ids, model_kwargs = self.expand_inputs_for_generation(
                    input_ids,
                    expand_size=generation_config.num_return_sequences,
                    **model_kwargs,
                )

            return self.sample(
                input_ids,
                logits_processors,
                max_len,
                pad_token_id,
                eos_token_id,
                generation_config.top_k,
                generation_config.top_p,
                generation_config.temperature,
                stopping_criteria=stopping_criteria,
                streamer=streamer,
                fast_ptq_sampling=generation_config.fast_ptq_sampling,
                trunc_input=generation_config.trunc_input,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif generation_config.decode_strategy == "beam_search":
            batch_size = input_ids.shape[0]
            if generation_config.num_return_sequences > generation_config.num_beams:
                raise ValueError(
                    "`num_return_sequences` has to be smaller or equal to "
                    "`num_beams`. But received `num_return_sequences` is {}, "
                    "`num_beams` is {}".format(
                        generation_config.num_return_sequences,
                        generation_config.num_beams,
                    )
                )
            if generation_config.num_beams <= 1:
                raise ValueError(
                    "`num_beams` has to be bigger than 1. But received "
                    "`num_beams` is {}. If `num_beams` is 1, `decode_strategy` "
                    "should be 'greedy_search'".format(generation_config.num_beams)
                )
            if generation_config.num_beam_groups > 1:
                # diverse (group) beam search
                diverse_beam_scorer = BeamSearchScorer(
                    batch_size=batch_size,
                    max_length=max_len,
                    num_beams=generation_config.num_beams,
                    length_penalty=generation_config.length_penalty,
                    do_early_stopping=generation_config.early_stopping,
                    num_beam_hyps_to_keep=generation_config.num_return_sequences,
                    num_beam_groups=generation_config.num_beam_groups,
                )

                # interleave with `num_beams`
                input_ids, model_kwargs = self.expand_inputs_for_generation(
                    input_ids, expand_size=generation_config.num_beams, **model_kwargs
                )

                return self.group_beam_search(
                    input_ids,
                    diverse_beam_scorer,
                    logits_processors,
                    max_len,
                    pad_token_id,
                    eos_token_id,
                    stopping_criteria=stopping_criteria,
                    fast_ptq_sampling=generation_config.fast_ptq_sampling,
                    trunc_input=generation_config.trunc_input,
                    synced_gpus=synced_gpus,
                    **model_kwargs,
                )
            else:
                # standard beam search
                beam_scorer = BeamSearchScorer(
                    batch_size=batch_size,
                    max_length=max_len,
                    num_beams=generation_config.num_beams,
                    length_penalty=generation_config.length_penalty,
                    do_early_stopping=generation_config.early_stopping,
                    num_beam_hyps_to_keep=generation_config.num_return_sequences,
                )

                input_ids, model_kwargs = self.expand_inputs_for_generation(
                    input_ids, expand_size=generation_config.num_beams, **model_kwargs
                )

                return self.beam_search(
                    input_ids,
                    beam_scorer,
                    logits_processors,
                    max_len,
                    generation_config.diversity_rate,
                    pad_token_id,
                    eos_token_id,
                    stopping_criteria=stopping_criteria,
                    fast_ptq_sampling=generation_config.fast_ptq_sampling,
                    trunc_input=generation_config.trunc_input,
                    synced_gpus=synced_gpus,
                    **model_kwargs,
                )
1227
+
1228
    def greedy_search(
        self,
        input_ids,
        logits_processors,
        max_length,
        pad_token_id,
        eos_token_id,
        stopping_criteria=None,
        streamer=None,
        fast_ptq_sampling=False,
        trunc_input=True,
        synced_gpus=False,
        **model_kwargs,
    ):
        """Greedy decoding loop: at each step pick the argmax token.

        Returns a tuple ``(ids, scores)`` where ``ids`` excludes the prompt
        when ``trunc_input`` is True, and ``scores`` is the running mean of
        the selected tokens' processed logits per sequence.
        """
        model_kwargs["use_cache"] = model_kwargs.get("use_cache", True)
        logits_processors = (
            logits_processors
            if logits_processors is not None
            else LogitsProcessorList()
        )

        # max_length will be convert to MaxLengthCriteria
        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )
        if max_length is not None:
            # logging.warning(
            #     "`max_length` is deprecated in this function, use"
            #     " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead."
            # )
            stopping_criteria = validate_stopping_criteria(
                stopping_criteria, max_length
            )

        batch_size, cur_len = input_ids.shape
        origin_len = cur_len
        # per-sequence "still generating" flag and running mean score
        unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool")
        scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype())
        generate_end = False
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs & get model output
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(**model_inputs)

            if synced_gpus and generate_end:
                continue  # don't waste resources running the code we don't need

            # forward() may return a tuple, a ModelOutput, or raw logits
            if isinstance(outputs, tuple):
                logits = outputs[0]
            elif isinstance(outputs, ModelOutput):
                logits = outputs.logits
            else:
                logits = outputs

            # [batch_size, vocab_size]
            next_token_logits = logits[:, -1, :]

            # pre-process distribution
            next_token_logits = self.adjust_logits_during_generation(next_token_logits)
            # NOTE(review): despite the name, `probs` here holds processed
            # logits, not softmax probabilities.
            probs = logits_processors(input_ids, next_token_logits)
            # greedy
            next_tokens = paddle.argmax(probs, axis=-1).unsqueeze(-1)
            next_scores = paddle.index_sample(probs, next_tokens)

            if eos_token_id is not None:
                # finished rows keep emitting pad_token_id
                next_tokens = paddle.where(
                    unfinished_flag,
                    next_tokens,
                    paddle.full_like(next_tokens, pad_token_id),
                )

            scores = self.update_scores_for_generation(
                scores, next_scores, cur_len - origin_len, unfinished_flag
            )
            cur_len += 1

            input_ids = paddle.concat([input_ids, next_tokens], axis=1)
            if streamer is not None:
                # only rank 0 of the tensor-parallel group streams tokens
                if self.config.tensor_parallel_rank == 0:
                    streamer.put(next_tokens.cpu())

            if stopping_criteria(input_ids, scores):
                generate_end = True

            if eos_token_id is not None:
                unfinished_flag = get_unfinished_flag(
                    input_ids, unfinished_flag, eos_token_id
                )
                if not paddle.any(unfinished_flag):
                    generate_end = True

            # Stop when there is a </s> in all sentences
            if generate_end and not synced_gpus:
                break

            model_kwargs = self.update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if fast_ptq_sampling:
                # PTQ calibration only needs a single decoding step
                break

        if streamer is not None:
            streamer.end()

        return input_ids[:, origin_len:] if trunc_input else input_ids, scores
1346
+
1347
    def sample(
        self,
        input_ids,
        logits_processors,
        max_length,
        pad_token_id,
        eos_token_id,
        top_k=None,
        top_p=None,
        temperature=None,
        min_tokens_to_keep=1,
        stopping_criteria=None,
        streamer=None,
        fast_ptq_sampling=False,
        trunc_input=True,
        synced_gpus=False,
        **model_kwargs,
    ):
        """Stochastic decoding loop with temperature / top-k / top-p filtering.

        Returns ``(ids, scores)``; ``scores`` is the running mean of the
        sampled tokens' log-probabilities (taken before temperature and
        top-k/top-p are applied). The prompt is stripped from ``ids`` when
        ``trunc_input`` is True.
        """
        model_kwargs["use_cache"] = model_kwargs.get("use_cache", True)

        logits_processors = (
            logits_processors
            if logits_processors is not None
            else LogitsProcessorList()
        )

        # max_length will be convert to MaxLengthCriteria
        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )
        if max_length is not None:
            # logging.warning(
            #     "`max_length` is deprecated in this function, use"
            #     " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead."
            # )
            stopping_criteria = validate_stopping_criteria(
                stopping_criteria, max_length
            )

        batch_size, cur_len = input_ids.shape
        origin_len = cur_len
        # per-sequence "still generating" flag and running mean log-prob
        unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool")
        scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype())

        generate_end = False
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break
            # prepare model inputs & get model output
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # NOTE: to decrease ref-count and clear outdate cache in-time
            model_kwargs["cache"] = None
            model_kwargs["past_key_values"] = None
            outputs = self(**model_inputs)
            if synced_gpus and generate_end:
                continue  # don't waste resources running the code we don't need

            # forward() may return a tuple, a ModelOutput, or raw logits
            if isinstance(outputs, tuple):
                logits = outputs[0]
            elif isinstance(outputs, ModelOutput):
                logits = outputs.logits
            else:
                logits = outputs

            # [batch_size, vocab_size]
            logits = logits[:, -1, :]

            # pre-process distribution
            logits = self.adjust_logits_during_generation(logits)
            logits = logits_processors(input_ids, logits)

            # sample
            # origin_probs: untempered log-probs used for score bookkeeping
            origin_probs = F.softmax(logits)
            origin_probs = paddle.log(origin_probs)
            if temperature is not None and temperature != 1.0:
                logits = logits / temperature
            probs = F.softmax(logits)
            if top_k is not None and top_k != 0:
                probs = TopKProcess(probs, top_k, min_tokens_to_keep)
            if top_p is not None and top_p < 1.0:
                probs = TopPProcess(probs, top_p, min_tokens_to_keep)
            # some accelerators only support float32 multinomial sampling
            if paddle.device.is_compiled_with_custom_device("gcu"):
                probs = paddle.cast(probs, "float32")
            if paddle.device.is_compiled_with_xpu():
                probs = paddle.cast(probs, "float32")

            # multinomial already support fp16 and bf16 currently, fix issue: https://github.com/PaddlePaddle/Paddle/issues/51852
            next_tokens = paddle.multinomial(probs)

            if self.config.tensor_parallel_degree > 1:
                # Maybe no need to broadcast if seed is set correctly.
                from paddle.distributed import fleet

                # NOTE(review): bare except silently falls back to the global
                # group when no hybrid communicate group is initialized.
                try:
                    hcg = fleet.get_hybrid_communicate_group()
                    group = hcg.get_model_parallel_group()
                    src = hcg.get_model_parallel_group_src_rank()
                except:
                    group, src = None, 0
                paddle.distributed.broadcast(next_tokens, src=src, group=group)
            # config does not include pipeline_parallel_degree, and pipeline parallel
            # uses trainer.model_wrapped to run in both train and predict mode
            # which has pp_group as a attribute
            # TODO(guosheng): only let the last stage of pipeline to do softmax
            # and sampling, and then broadcast to avoid broadcast logits.
            if getattr(self, "pp_group", None) is not None:
                paddle.distributed.broadcast(
                    next_tokens,
                    src=self.pp_group.ranks[0],
                    group=self.pp_group,  # use rank 0 for same seed to check
                )

            next_scores = paddle.index_sample(origin_probs, next_tokens)
            if eos_token_id is not None:
                # finished rows keep emitting pad_token_id
                next_tokens = paddle.where(
                    unfinished_flag,
                    next_tokens,
                    paddle.full_like(next_tokens, pad_token_id),
                )

            scores = self.update_scores_for_generation(
                scores, next_scores, cur_len - origin_len, unfinished_flag
            )

            cur_len += 1
            input_ids = paddle.concat([input_ids, next_tokens], axis=1)
            if streamer is not None:
                # only rank 0 of the tensor-parallel group streams tokens
                if self.config.tensor_parallel_rank == 0:
                    streamer.put(next_tokens.cpu())

            if stopping_criteria(input_ids, scores):
                generate_end = True

            if eos_token_id is not None:
                unfinished_flag = get_unfinished_flag(
                    input_ids, unfinished_flag, eos_token_id
                )
                if not paddle.any(unfinished_flag):
                    generate_end = True

            # Stop when there is a </s> in all sentences
            if generate_end and not synced_gpus:
                break

            model_kwargs = self.update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder
            )
            if fast_ptq_sampling:
                # PTQ calibration only needs a single decoding step
                break

        if streamer is not None:
            streamer.end()

        return input_ids[:, origin_len:] if trunc_input else input_ids, scores
1510
+
1511
+ def _get_model_inputs_spec(self, dtype: str):
1512
+ spec = {
1513
+ "input_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"),
1514
+ "attention_mask": paddle.static.InputSpec(
1515
+ shape=[None, None], dtype="int64"
1516
+ ),
1517
+ }
1518
+ if "position_ids" in inspect.getfullargspec(self.forward).args:
1519
+ spec["position_ids"] = paddle.static.InputSpec(
1520
+ shape=[None, None], dtype="int64"
1521
+ )
1522
+ return spec
1523
+
1524
    def to_static(self, path: str, config: dict):
        """export generation model to static

        Args:
            path (str): path of saved inference model
            config (dict): configuration for generation
                bos_token_id (int): token id of begin-of-sentence
                eos_token_id (int): token id of end-of-sentence
                pad_token_id (int): token id of pad token
                use_top_p (bool): whether use top_p decoding strategy
        """

        use_top_p = config.get("use_top_p", True)

        # Only the active sampling knob becomes a runtime input; the other
        # is frozen to a no-op constant (top_k=0 / top_p=1.0).
        top_k_spec = (
            paddle.static.InputSpec(shape=[1], dtype="int64") if not use_top_p else 0
        )

        top_p_spec = (
            paddle.static.InputSpec(shape=[1], dtype="float32") if use_top_p else 1.0
        )
        temperature = (
            paddle.static.InputSpec(shape=[1], dtype="float32") if use_top_p else 1.0
        )
        dtype = config.get("dtype", None)

        logits_processors = config.get("logits_processors", None)
        model_inputs_spec = self._get_model_inputs_spec(dtype)

        # Positional specs must mirror the sample_d2s parameter order.
        input_spec = [
            model_inputs_spec["input_ids"],  # input_ids
            model_inputs_spec["attention_mask"],  # attention_mask
            model_inputs_spec.get("position_ids", None),  # position_ids
            logits_processors,
            paddle.static.InputSpec(shape=[1], dtype="int64"),  # max_length
            self.generation_config.pad_token_id or config.get("pad_token_id", None),
            self.generation_config.eos_token_id or config.get("eos_token_id", None),
            top_k_spec,  # top_k
            top_p_spec,  # top_p
            temperature,  # temperature
            1,  # min_tokens_to_keep
        ]

        model = paddle.jit.to_static(self.sample_d2s, input_spec=input_spec)

        paddle.jit.save(model, path)
1570
+
1571
    def sample_d2s(
        self,
        input_ids,
        attention_mask,
        position_ids,
        logits_processors,
        max_new_tokens,
        pad_token_id,
        eos_token_id,
        top_k=None,
        top_p=None,
        temperature=None,
        min_tokens_to_keep=1,
    ):
        """Sampling loop written to be traceable by ``paddle.jit.to_static``.

        Exactly one of ``top_k`` / ``top_p`` must be a tensor (the dynamic
        input declared in ``to_static``); the other is a Python constant.
        That asymmetry selects the decoding strategy at trace time.

        Returns:
            tuple: ``(input_ids[:, origin_len:], scores)`` — the newly
            generated tokens (prompt stripped) and per-sequence scores.
        """

        pad_token_id = self.set_pad_token_id(pad_token_id, eos_token_id)
        logits_processors = (
            logits_processors
            if logits_processors is not None
            else LogitsProcessorList()
        )

        # Strategy selection: a tensor argument marks the exported dynamic
        # input, so tensor-ness (not value) decides top-k vs top-p here.
        if paddle.is_tensor(top_k) and not paddle.is_tensor(top_p):
            use_top_p = False
        elif not paddle.is_tensor(top_k) and paddle.is_tensor(top_p):
            use_top_p = True

        # top_k and top_p are the const value
        elif isinstance(top_p, float) or isinstance(top_k, int):
            use_top_p = True
        else:
            if top_p is None and top_k is None:
                raise ValueError("top_k and top_p should not be None")
            raise ValueError(
                "you should not specify InputSpec for top_k and top_p parameters, one of InputSpec is expected"
            )

        batch_size, cur_len = input_ids.shape
        # used for compute on gpu, avoid memcpy D2H
        cur_len_gpu = paddle.full([1], cur_len, dtype="int64")

        origin_len = input_ids.shape[1]
        # used for compute on gpu, avoid memcpy D2H
        origin_len_gpu = paddle.full([1], origin_len, dtype="int64")

        # One bool per sequence; flips to False once a sequence finishes.
        unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool")

        scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype())

        # use_cache is immutable, we split it off other mutable kwargs.
        immutable = {"use_cache": True}
        model_kwargs = {"attention_mask": attention_mask, "position_ids": position_ids}

        def _forward_(**args):
            # Single decoder step; closes over `input_ids` from the outer scope.
            model_inputs = self.prepare_inputs_for_generation(
                input_ids, **args, **immutable
            )
            assert "use_cache" in model_inputs
            del model_inputs["use_cache"]
            return self(**model_inputs, **immutable)

        def _post_process_(
            outputs,
            input_ids,
            cur_len,
            origin_len,
            scores,
            unfinished_flag,
            model_kwargs,
            pad_token_id,
        ):
            # Sample one token from `outputs`, append it, and update the
            # per-sequence score/finished bookkeeping.
            if isinstance(outputs, tuple):
                logits = outputs[0]
            elif isinstance(outputs, ModelOutput):
                logits = outputs.logits
            else:
                logits = outputs

            # [batch_size, vocab_size]
            logits = logits[:, -1, :]

            # pre-process distribution
            logits = self.adjust_logits_during_generation(logits)

            logits = logits_processors(input_ids, logits)
            probs = F.softmax(logits)

            # sample
            origin_probs = F.log_softmax(logits)
            # compute next_tokens
            if use_top_p:
                # NOTE(review): `temperature` is applied after `probs` was
                # already computed, so it does not affect the top-p sampling
                # below — confirm whether this is intentional.
                logits = logits / temperature
                top_ps_tensor = paddle.full(
                    shape=[probs.shape[0], 1], fill_value=top_p, dtype=probs.dtype
                )
                _, next_tokens = paddle.tensor.top_p_sampling(probs, top_ps_tensor)
            else:
                probs = TopKProcess(probs, top_k, min_tokens_to_keep)
                # top_k == 1 degenerates to greedy argmax.
                if top_k == 1:
                    next_tokens = paddle.unsqueeze_(paddle.argmax(probs, axis=-1), -1)
                else:
                    next_tokens = paddle.multinomial(probs)

            # Log-prob of the chosen token, accumulated into `scores`.
            next_scores = paddle.index_sample(origin_probs, next_tokens)
            scores = self.update_scores_for_generation(
                scores, next_scores, cur_len - origin_len, unfinished_flag
            )
            if eos_token_id is not None:
                # Finished sequences keep emitting pad tokens.
                next_tokens = paddle.where(
                    unfinished_flag,
                    next_tokens,
                    paddle.full_like(next_tokens, pad_token_id),
                )

            input_ids = paddle.concat([input_ids, next_tokens], axis=1)

            if eos_token_id is not None:
                unfinished_flag = get_unfinished_flag(
                    input_ids, unfinished_flag, eos_token_id
                )

            model_kwargs = self.update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )

            return input_ids, scores, unfinished_flag, model_kwargs

        # First step runs outside the while-loop so the cache produced by the
        # initial forward pass can be installed into `model_kwargs` below.
        outputs = _forward_(**model_kwargs)
        input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
            outputs,
            input_ids,
            cur_len_gpu,
            origin_len_gpu,
            scores,
            unfinished_flag,
            model_kwargs,
            pad_token_id,
        )

        cur_len += 1
        cur_len_gpu += 1

        attn_mask = model_kwargs["attention_mask"]
        # make the shape of attention_mask = (-1, -1, -1, -1) in dy2static.
        model_kwargs["attention_mask"] = paddle.reshape(attn_mask, attn_mask.shape)
        model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None
        # Convert the remaining budget into an absolute length bound tensor.
        max_new_tokens = paddle.full([1], max_new_tokens + cur_len - 1, dtype="int64")

        while cur_len < max_new_tokens and paddle.any(unfinished_flag):
            input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
                _forward_(**model_kwargs),
                input_ids,
                cur_len_gpu,
                origin_len_gpu,
                scores,
                unfinished_flag,
                model_kwargs,
                pad_token_id,
            )
            cur_len += 1
            cur_len_gpu += 1

        return input_ids[:, origin_len:], scores
1734
+
1735
+ def reorder_cache(self, cache, beam_idx):
1736
+ cache = map_structure(lambda x: paddle.index_select(x, beam_idx), cache)
1737
+ return cache
1738
+
1739
    def beam_search(
        self,
        input_ids,
        beam_scorer,
        logits_processors,
        max_length,
        diversity_rate,
        pad_token_id,
        eos_token_id,
        stopping_criteria=None,
        fast_ptq_sampling=False,
        trunc_input=True,
        synced_gpus=False,
        **model_kwargs,
    ):
        """Beam-search decoding loop, with optional diverse-sibling scoring.

        Args:
            input_ids: [batch_size * num_beams, seq_len] prompt tokens,
                already tiled per beam.
            beam_scorer: object managing beam hypotheses (``process`` /
                ``finalize``).
            logits_processors: optional ``LogitsProcessorList``.
            max_length: converted into a ``MaxLengthCriteria`` if given.
            diversity_rate: when non-zero, applies a sibling-rank penalty
                (diverse beam search); 0.0 is vanilla beam search.
            pad_token_id / eos_token_id: special token ids.
            stopping_criteria: optional extra stop conditions.
            fast_ptq_sampling: stop after a single step (PTQ calibration).
            trunc_input: strip the prompt from the returned ids.
            synced_gpus: keep stepping in lockstep with peer ranks until all
                ranks finish (required for e.g. ZeRO-style sharding).

        Returns:
            tuple: (prediction ids, per-hypothesis scores).
        """
        model_kwargs["use_cache"] = model_kwargs.get("use_cache", True)

        logits_processors = (
            logits_processors
            if logits_processors is not None
            else LogitsProcessorList()
        )

        # max_length will be convert to MaxLengthCriteria
        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )
        if max_length is not None:
            # logging.warning(
            # "`max_length` is deprecated in this function, use"
            # " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead."
            # )
            stopping_criteria = validate_stopping_criteria(
                stopping_criteria, max_length
            )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        batch_beam_size, cur_len = input_ids.shape
        origin_len = cur_len

        assert (
            num_beams * batch_size == batch_beam_size
        ), "Batch dimension of `input_ids` should be {}, but received {}.".format(
            num_beams * batch_size, batch_beam_size
        )

        beam_scores = paddle.zeros(
            (batch_size, num_beams), dtype=paddle.get_default_dtype()
        )

        # Only beam 0 starts at score 0; the rest start at a large negative
        # value so the first topk does not pick duplicates across beams.
        beam_scores[:, 1:] = get_scale_by_dtype(return_positive=False)
        beam_scores = paddle.reshape(beam_scores, [-1])

        generate_end = False
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break
            # prepare model inputs & get model output
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(**model_inputs)
            if synced_gpus and generate_end:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            if isinstance(outputs, tuple):
                logits = outputs[0]
            elif isinstance(outputs, ModelOutput):
                logits = outputs.logits
            else:
                logits = outputs

            # [batch_size, vocab_size]
            logits = logits[:, -1, :]

            # pre-process distribution
            logits = self.adjust_logits_during_generation(logits)
            # beam search
            # [batch_size * num_beams, vocab_size]
            next_scores = F.softmax(logits)
            next_scores = paddle.log(next_scores)
            next_scores = logits_processors(input_ids, next_scores)
            # Cumulative log-prob: token log-prob + running beam score.
            next_scores = next_scores + beam_scores.unsqueeze(-1)

            vocab_size = next_scores.shape[-1]
            if diversity_rate == 0.0:
                # reshape for beam search
                next_scores = next_scores.reshape([batch_size, num_beams * vocab_size])

                # 2*num_beams candidates per batch item, so finished (EOS)
                # hypotheses still leave enough live continuations.
                next_scores, next_tokens = paddle.topk(
                    next_scores, 2 * num_beams, axis=1
                )

                # Flat index -> (source beam, token id).
                next_indices = next_tokens // vocab_size
                next_tokens = next_tokens % vocab_size

            else:
                # Diverse beam search: penalize each candidate by its rank
                # within its own beam (sibling rank * diversity_rate).
                next_scores, next_tokens = paddle.topk(
                    next_scores, 2 * num_beams, axis=1
                )

                sibling_score = (
                    paddle.arange(1, 2 * num_beams + 1, dtype="int64").unsqueeze(0)
                    * diversity_rate
                )

                diversed_score = next_scores - sibling_score

                next_scores = next_scores.reshape(
                    [batch_size, 2 * num_beams * num_beams]
                )
                next_tokens = next_tokens.reshape(
                    [batch_size, 2 * num_beams * num_beams]
                )

                diversed_score = diversed_score.reshape(
                    [batch_size, 2 * num_beams * num_beams]
                )
                # Select candidates by the penalized score, then recover the
                # original (unpenalized) scores/tokens at those positions.
                diversed_score, diversed_tokens = paddle.topk(
                    diversed_score, 2 * num_beams, axis=1
                )

                # TODO
                # Use gather_nd() to select origan token and score
                next_scores = paddle.stack(
                    [
                        paddle.index_select(next_scores[i], diversed_tokens[i])
                        for i in range(next_scores.shape[0])
                    ]
                )
                next_tokens = paddle.stack(
                    [
                        paddle.index_select(next_tokens[i], diversed_tokens[i])
                        for i in range(next_tokens.shape[0])
                    ]
                )

                next_indices = diversed_tokens // (2 * num_beams)

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_scores,
                next_tokens,
                next_indices,
                origin_len=origin_len,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]
            # beam_idx may contain element -1 and cause error
            # PR: https://github.com/PaddlePaddle/Paddle/issues/57366
            beam_idx = paddle.maximum(beam_idx, paddle.full_like(beam_idx, 0))

            cur_len += 1
            # Re-gather rows to follow the surviving beams, then append the
            # chosen tokens.
            input_ids = paddle.concat(
                [
                    paddle.index_select(input_ids, beam_idx),
                    beam_next_tokens.unsqueeze(-1),
                ],
                axis=-1,
            )

            if beam_scorer.is_done or stopping_criteria(input_ids, beam_scores):
                if not synced_gpus:
                    break
                else:
                    # Keep looping (forward passes only) until peers finish.
                    generate_end = True

            model_kwargs = self.update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder
            )
            if "cache" in model_kwargs:
                # reorder the cache
                model_kwargs["cache"] = self.reorder_cache(
                    model_kwargs["cache"], beam_idx
                )
            if "past_key_values" in model_kwargs:
                # reorder the cache
                model_kwargs["past_key_values"] = self.reorder_cache(
                    model_kwargs["past_key_values"], beam_idx
                )
            if fast_ptq_sampling:
                break

        pred_ids, scores = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            origin_len=origin_len,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )
        return pred_ids[:, origin_len:] if trunc_input else input_ids, scores
1946
+
1947
    def group_beam_search(
        self,
        input_ids,
        beam_scorer,
        logits_processors,
        max_length,
        pad_token_id,
        eos_token_id,
        stopping_criteria=None,
        fast_ptq_sampling=False,
        trunc_input=True,
        synced_gpus=False,
        **model_kwargs,
    ):
        """Group (diverse) beam-search decoding loop.

        Beams are partitioned into ``beam_scorer.num_beam_groups`` groups of
        ``num_beams // num_beam_groups`` sub-beams; each group is advanced
        independently within a step, with the tokens already chosen by
        earlier groups (``current_tokens``) fed to the logits processors so
        groups can diverge from one another.

        Args mirror :meth:`beam_search` (minus ``diversity_rate``).

        Returns:
            tuple: (prediction ids, per-hypothesis scores).
        """
        model_kwargs["use_cache"] = model_kwargs.get("use_cache", True)
        logits_processors = (
            logits_processors
            if logits_processors is not None
            else LogitsProcessorList()
        )

        # max_length will be convert to MaxLengthCriteria
        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )
        if max_length is not None:
            # logging.warning(
            # "`max_length` is deprecated in this function, use"
            # " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead."
            # )
            stopping_criteria = validate_stopping_criteria(
                stopping_criteria, max_length
            )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups

        batch_beam_size, cur_len = input_ids.shape
        origin_len = cur_len

        assert (
            num_beams * batch_size == batch_beam_size
        ), "Batch dimension of `input_ids` should be {}, but received {}.".format(
            num_beams * batch_size, batch_beam_size
        )

        beam_scores = paddle.full(
            (batch_size, num_beams),
            get_scale_by_dtype(return_positive=False),
            dtype="float32",
        )
        # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
        # the same group don't produce same tokens everytime.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = paddle.reshape(beam_scores, [-1])

        generate_end = False
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break
            # predicted tokens in cur_len step
            current_tokens = paddle.zeros(
                shape=[batch_size * num_beams], dtype=input_ids.dtype
            )

            # indices which will form the beams in the next time step
            reordering_indices = paddle.zeros(
                shape=[batch_size * num_beams], dtype="int64"
            )
            # prepare model inputs & get model output
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(**model_inputs)
            if synced_gpus and generate_end:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [
                            batch_idx * num_beams + idx
                            for idx in range(group_start_idx, group_end_idx)
                        ]
                    )

                group_input_ids = input_ids[batch_group_indices]

                if isinstance(outputs, tuple):
                    logits = outputs[0]
                elif isinstance(outputs, ModelOutput):
                    logits = outputs.logits
                else:
                    logits = outputs

                # Last-step logits, restricted to this group's beam rows.
                logits = logits[:, -1, :]
                logits = paddle.index_select(
                    logits, paddle.to_tensor(batch_group_indices)
                )
                logits = self.adjust_logits_during_generation(logits)

                next_scores = F.softmax(logits)
                next_scores = paddle.log(next_scores)
                vocab_size = next_scores.shape[-1]

                # `current_tokens` / `beam_group_idx` let group-aware
                # processors (e.g. a Hamming-diversity penalty) see what
                # earlier groups already chose this step.
                next_scores = logits_processors(
                    group_input_ids,
                    next_scores,
                    current_tokens=current_tokens,
                    beam_group_idx=beam_group_idx,
                )

                # Cumulative log-prob: token log-prob + running beam score.
                next_scores = next_scores + beam_scores[batch_group_indices].unsqueeze(
                    -1
                )

                # reshape for beam search
                next_scores = next_scores.reshape([batch_size, group_size * vocab_size])

                # 2*group_size candidates so finished (EOS) hypotheses still
                # leave enough live continuations.
                next_scores, next_tokens = paddle.topk(
                    next_scores, 2 * group_size, axis=1
                )

                # Flat index -> (source sub-beam, token id).
                next_indices = next_tokens // vocab_size
                next_tokens = next_tokens % vocab_size

                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_scores,
                    next_tokens,
                    next_indices,
                    origin_len=origin_len,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                )

                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]
                # beam_idx may contain element -1 and cause error
                # PR: https://github.com/PaddlePaddle/Paddle/issues/57366
                beam_idx = paddle.maximum(beam_idx, paddle.full_like(beam_idx, 0))

                # Rewrite this group's rows of the global tensors in place.
                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = paddle.concat(
                    [
                        paddle.index_select(group_input_ids, index=beam_idx),
                        beam_next_tokens.unsqueeze(-1),
                    ],
                    axis=-1,
                )
                current_tokens[batch_group_indices] = beam_next_tokens

                # Map the group-local beam index back to a global beam index
                # for the cache reorder after all groups are processed.
                reordering_indices[batch_group_indices] = (
                    num_beams * (beam_idx // group_size)
                    + group_start_idx
                    + (beam_idx % group_size)
                )

            # Append this step's token for every beam at once.
            input_ids = paddle.concat(
                [input_ids, current_tokens.unsqueeze(-1)], axis=-1
            )

            cur_len += 1

            if beam_scorer.is_done or stopping_criteria(input_ids, beam_scores):
                if not synced_gpus:
                    break
                else:
                    # Keep looping (forward passes only) until peers finish.
                    generate_end = True

            model_kwargs = self.update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder
            )

            if "cache" in model_kwargs:
                # reorder the cache
                model_kwargs["cache"] = self.reorder_cache(
                    model_kwargs["cache"], reordering_indices
                )
            if "past_key_values" in model_kwargs:
                # reorder the cache
                model_kwargs["past_key_values"] = self.reorder_cache(
                    model_kwargs["past_key_values"], reordering_indices
                )

            if fast_ptq_sampling:
                break

        pred_ids, scores = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            origin_len=origin_len,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )
        return pred_ids[:, origin_len:] if trunc_input else input_ids, scores