paddlex 3.0.0rc0__py3-none-any.whl → 3.0.1__py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
Files changed (824)
  1. paddlex/.version +1 -1
  2. paddlex/__init__.py +17 -34
  3. paddlex/__main__.py +1 -1
  4. paddlex/configs/modules/chart_parsing/PP-Chart2Table.yaml +13 -0
  5. paddlex/configs/modules/doc_vlm/PP-DocBee-2B.yaml +14 -0
  6. paddlex/configs/modules/doc_vlm/PP-DocBee-7B.yaml +14 -0
  7. paddlex/configs/modules/doc_vlm/PP-DocBee2-3B.yaml +14 -0
  8. paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-L.yaml +40 -0
  9. paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-M.yaml +40 -0
  10. paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-S.yaml +40 -0
  11. paddlex/configs/modules/layout_detection/PP-DocBlockLayout.yaml +40 -0
  12. paddlex/configs/modules/layout_detection/PP-DocLayout-L.yaml +2 -2
  13. paddlex/configs/modules/layout_detection/PP-DocLayout-M.yaml +2 -2
  14. paddlex/configs/modules/layout_detection/PP-DocLayout-S.yaml +2 -2
  15. paddlex/configs/modules/layout_detection/PP-DocLayout_plus-L.yaml +40 -0
  16. paddlex/configs/modules/open_vocabulary_detection/YOLO-Worldv2-L.yaml +13 -0
  17. paddlex/configs/modules/text_detection/PP-OCRv5_mobile_det.yaml +40 -0
  18. paddlex/configs/modules/text_detection/PP-OCRv5_server_det.yaml +40 -0
  19. paddlex/configs/modules/text_recognition/PP-OCRv5_mobile_rec.yaml +39 -0
  20. paddlex/configs/modules/text_recognition/PP-OCRv5_server_rec.yaml +39 -0
  21. paddlex/configs/modules/textline_orientation/PP-LCNet_x1_0_textline_ori.yaml +41 -0
  22. paddlex/configs/pipelines/OCR.yaml +7 -6
  23. paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml +3 -1
  24. paddlex/configs/pipelines/PP-ChatOCRv4-doc.yaml +91 -34
  25. paddlex/configs/pipelines/PP-StructureV3.yaml +72 -72
  26. paddlex/configs/pipelines/anomaly_detection.yaml +1 -1
  27. paddlex/configs/pipelines/doc_understanding.yaml +9 -0
  28. paddlex/configs/pipelines/formula_recognition.yaml +2 -2
  29. paddlex/configs/pipelines/layout_parsing.yaml +3 -2
  30. paddlex/configs/pipelines/seal_recognition.yaml +1 -0
  31. paddlex/configs/pipelines/table_recognition.yaml +2 -1
  32. paddlex/configs/pipelines/table_recognition_v2.yaml +7 -1
  33. paddlex/configs/pipelines/ts_anomaly_detection.yaml +1 -1
  34. paddlex/configs/pipelines/ts_classification.yaml +1 -1
  35. paddlex/configs/pipelines/ts_forecast.yaml +1 -1
  36. paddlex/constants.py +17 -0
  37. paddlex/engine.py +7 -5
  38. paddlex/hpip_links.html +23 -11
  39. paddlex/inference/__init__.py +3 -3
  40. paddlex/inference/common/__init__.py +1 -1
  41. paddlex/inference/common/batch_sampler/__init__.py +5 -4
  42. paddlex/inference/common/batch_sampler/audio_batch_sampler.py +5 -6
  43. paddlex/inference/common/batch_sampler/base_batch_sampler.py +20 -16
  44. paddlex/inference/common/batch_sampler/det_3d_batch_sampler.py +4 -7
  45. paddlex/inference/common/batch_sampler/doc_vlm_batch_sampler.py +87 -0
  46. paddlex/inference/common/batch_sampler/image_batch_sampler.py +45 -60
  47. paddlex/inference/common/batch_sampler/ts_batch_sampler.py +9 -10
  48. paddlex/inference/common/batch_sampler/video_batch_sampler.py +2 -22
  49. paddlex/inference/common/reader/__init__.py +4 -4
  50. paddlex/inference/common/reader/audio_reader.py +3 -3
  51. paddlex/inference/common/reader/det_3d_reader.py +7 -5
  52. paddlex/inference/common/reader/image_reader.py +16 -12
  53. paddlex/inference/common/reader/ts_reader.py +3 -2
  54. paddlex/inference/common/reader/video_reader.py +3 -3
  55. paddlex/inference/common/result/__init__.py +7 -7
  56. paddlex/inference/common/result/base_cv_result.py +12 -2
  57. paddlex/inference/common/result/base_result.py +7 -5
  58. paddlex/inference/common/result/base_ts_result.py +1 -2
  59. paddlex/inference/common/result/base_video_result.py +2 -2
  60. paddlex/inference/common/result/mixin.py +31 -25
  61. paddlex/inference/models/__init__.py +41 -85
  62. paddlex/inference/models/anomaly_detection/__init__.py +1 -1
  63. paddlex/inference/models/anomaly_detection/predictor.py +9 -19
  64. paddlex/inference/models/anomaly_detection/processors.py +9 -2
  65. paddlex/inference/models/anomaly_detection/result.py +3 -2
  66. paddlex/inference/models/base/__init__.py +2 -2
  67. paddlex/inference/models/base/predictor/__init__.py +1 -2
  68. paddlex/inference/models/base/predictor/base_predictor.py +278 -39
  69. paddlex/inference/models/common/__init__.py +6 -15
  70. paddlex/inference/models/common/static_infer.py +724 -251
  71. paddlex/inference/models/common/tokenizer/__init__.py +7 -3
  72. paddlex/inference/models/common/tokenizer/bert_tokenizer.py +1 -1
  73. paddlex/inference/models/common/tokenizer/clip_tokenizer.py +609 -0
  74. paddlex/inference/models/common/tokenizer/gpt_tokenizer.py +9 -7
  75. paddlex/inference/models/common/tokenizer/qwen2_5_tokenizer.py +112 -0
  76. paddlex/inference/models/common/tokenizer/qwen2_tokenizer.py +438 -0
  77. paddlex/inference/models/common/tokenizer/qwen_tokenizer.py +288 -0
  78. paddlex/inference/models/common/tokenizer/tokenizer_utils.py +85 -77
  79. paddlex/inference/models/common/tokenizer/tokenizer_utils_base.py +339 -123
  80. paddlex/inference/models/common/tokenizer/utils.py +1 -1
  81. paddlex/inference/models/common/tokenizer/vocab.py +8 -8
  82. paddlex/inference/models/common/ts/__init__.py +1 -1
  83. paddlex/inference/models/common/ts/funcs.py +13 -6
  84. paddlex/inference/models/common/ts/processors.py +14 -5
  85. paddlex/inference/models/common/vision/__init__.py +3 -3
  86. paddlex/inference/models/common/vision/funcs.py +17 -12
  87. paddlex/inference/models/common/vision/processors.py +61 -46
  88. paddlex/inference/models/common/vlm/__init__.py +13 -0
  89. paddlex/inference/models/common/vlm/activations.py +189 -0
  90. paddlex/inference/models/common/vlm/bert_padding.py +127 -0
  91. paddlex/inference/models/common/vlm/conversion_utils.py +99 -0
  92. paddlex/inference/models/common/vlm/distributed.py +229 -0
  93. paddlex/inference/models/common/vlm/flash_attn_utils.py +119 -0
  94. paddlex/inference/models/common/vlm/fusion_ops.py +205 -0
  95. paddlex/inference/models/common/vlm/generation/__init__.py +34 -0
  96. paddlex/inference/models/common/vlm/generation/configuration_utils.py +533 -0
  97. paddlex/inference/models/common/vlm/generation/logits_process.py +730 -0
  98. paddlex/inference/models/common/vlm/generation/stopping_criteria.py +106 -0
  99. paddlex/inference/models/common/vlm/generation/utils.py +2162 -0
  100. paddlex/inference/models/common/vlm/transformers/__init__.py +16 -0
  101. paddlex/inference/models/common/vlm/transformers/configuration_utils.py +1037 -0
  102. paddlex/inference/models/common/vlm/transformers/conversion_utils.py +408 -0
  103. paddlex/inference/models/common/vlm/transformers/model_outputs.py +1612 -0
  104. paddlex/inference/models/common/vlm/transformers/model_utils.py +2014 -0
  105. paddlex/inference/models/common/vlm/transformers/utils.py +178 -0
  106. paddlex/inference/models/common/vlm/utils.py +109 -0
  107. paddlex/inference/models/doc_vlm/__init__.py +15 -0
  108. paddlex/inference/models/doc_vlm/modeling/GOT_ocr_2_0.py +830 -0
  109. paddlex/inference/models/doc_vlm/modeling/__init__.py +17 -0
  110. paddlex/inference/models/doc_vlm/modeling/qwen2.py +1606 -0
  111. paddlex/inference/models/doc_vlm/modeling/qwen2_5_vl.py +3006 -0
  112. paddlex/inference/models/doc_vlm/modeling/qwen2_vl.py +2495 -0
  113. paddlex/inference/models/doc_vlm/predictor.py +253 -0
  114. paddlex/inference/models/doc_vlm/processors/GOT_ocr_2_0.py +97 -0
  115. paddlex/inference/models/doc_vlm/processors/__init__.py +17 -0
  116. paddlex/inference/models/doc_vlm/processors/common.py +561 -0
  117. paddlex/inference/models/doc_vlm/processors/qwen2_5_vl.py +548 -0
  118. paddlex/inference/models/doc_vlm/processors/qwen2_vl.py +543 -0
  119. paddlex/inference/models/doc_vlm/result.py +21 -0
  120. paddlex/inference/models/face_feature/__init__.py +1 -1
  121. paddlex/inference/models/face_feature/predictor.py +2 -1
  122. paddlex/inference/models/formula_recognition/__init__.py +1 -1
  123. paddlex/inference/models/formula_recognition/predictor.py +18 -28
  124. paddlex/inference/models/formula_recognition/processors.py +126 -97
  125. paddlex/inference/models/formula_recognition/result.py +43 -35
  126. paddlex/inference/models/image_classification/__init__.py +1 -1
  127. paddlex/inference/models/image_classification/predictor.py +9 -19
  128. paddlex/inference/models/image_classification/processors.py +4 -2
  129. paddlex/inference/models/image_classification/result.py +4 -3
  130. paddlex/inference/models/image_feature/__init__.py +1 -1
  131. paddlex/inference/models/image_feature/predictor.py +9 -19
  132. paddlex/inference/models/image_feature/processors.py +7 -5
  133. paddlex/inference/models/image_feature/result.py +2 -3
  134. paddlex/inference/models/image_multilabel_classification/__init__.py +1 -1
  135. paddlex/inference/models/image_multilabel_classification/predictor.py +7 -6
  136. paddlex/inference/models/image_multilabel_classification/processors.py +6 -2
  137. paddlex/inference/models/image_multilabel_classification/result.py +4 -3
  138. paddlex/inference/models/image_unwarping/__init__.py +1 -1
  139. paddlex/inference/models/image_unwarping/predictor.py +8 -16
  140. paddlex/inference/models/image_unwarping/processors.py +6 -2
  141. paddlex/inference/models/image_unwarping/result.py +4 -2
  142. paddlex/inference/models/instance_segmentation/__init__.py +1 -1
  143. paddlex/inference/models/instance_segmentation/predictor.py +7 -15
  144. paddlex/inference/models/instance_segmentation/processors.py +4 -7
  145. paddlex/inference/models/instance_segmentation/result.py +11 -10
  146. paddlex/inference/models/keypoint_detection/__init__.py +1 -1
  147. paddlex/inference/models/keypoint_detection/predictor.py +5 -3
  148. paddlex/inference/models/keypoint_detection/processors.py +11 -3
  149. paddlex/inference/models/keypoint_detection/result.py +9 -4
  150. paddlex/inference/models/{3d_bev_detection → m_3d_bev_detection}/__init__.py +1 -1
  151. paddlex/inference/models/{3d_bev_detection → m_3d_bev_detection}/predictor.py +15 -26
  152. paddlex/inference/models/{3d_bev_detection → m_3d_bev_detection}/processors.py +26 -14
  153. paddlex/inference/models/{3d_bev_detection → m_3d_bev_detection}/result.py +15 -12
  154. paddlex/inference/models/{3d_bev_detection → m_3d_bev_detection}/visualizer_3d.py +77 -39
  155. paddlex/inference/models/multilingual_speech_recognition/__init__.py +1 -1
  156. paddlex/inference/models/multilingual_speech_recognition/predictor.py +11 -15
  157. paddlex/inference/models/multilingual_speech_recognition/processors.py +45 -53
  158. paddlex/inference/models/multilingual_speech_recognition/result.py +1 -1
  159. paddlex/inference/models/object_detection/__init__.py +1 -1
  160. paddlex/inference/models/object_detection/predictor.py +8 -12
  161. paddlex/inference/models/object_detection/processors.py +63 -33
  162. paddlex/inference/models/object_detection/result.py +5 -4
  163. paddlex/inference/models/object_detection/utils.py +3 -1
  164. paddlex/inference/models/open_vocabulary_detection/__init__.py +1 -1
  165. paddlex/inference/models/open_vocabulary_detection/predictor.py +31 -14
  166. paddlex/inference/models/open_vocabulary_detection/processors/__init__.py +3 -2
  167. paddlex/inference/models/open_vocabulary_detection/processors/common.py +114 -0
  168. paddlex/inference/models/open_vocabulary_detection/processors/groundingdino_processors.py +19 -8
  169. paddlex/inference/models/open_vocabulary_detection/processors/yoloworld_processors.py +209 -0
  170. paddlex/inference/models/open_vocabulary_segmentation/__init__.py +1 -1
  171. paddlex/inference/models/open_vocabulary_segmentation/predictor.py +6 -13
  172. paddlex/inference/models/open_vocabulary_segmentation/processors/__init__.py +1 -1
  173. paddlex/inference/models/open_vocabulary_segmentation/processors/sam_processer.py +12 -12
  174. paddlex/inference/models/open_vocabulary_segmentation/results/__init__.py +1 -1
  175. paddlex/inference/models/open_vocabulary_segmentation/results/sam_result.py +11 -9
  176. paddlex/inference/models/semantic_segmentation/__init__.py +1 -1
  177. paddlex/inference/models/semantic_segmentation/predictor.py +9 -18
  178. paddlex/inference/models/semantic_segmentation/processors.py +11 -8
  179. paddlex/inference/models/semantic_segmentation/result.py +4 -3
  180. paddlex/inference/models/table_structure_recognition/__init__.py +1 -1
  181. paddlex/inference/models/table_structure_recognition/predictor.py +8 -18
  182. paddlex/inference/models/table_structure_recognition/processors.py +23 -29
  183. paddlex/inference/models/table_structure_recognition/result.py +8 -15
  184. paddlex/inference/models/text_detection/__init__.py +1 -1
  185. paddlex/inference/models/text_detection/predictor.py +24 -24
  186. paddlex/inference/models/text_detection/processors.py +116 -44
  187. paddlex/inference/models/text_detection/result.py +8 -13
  188. paddlex/inference/models/text_recognition/__init__.py +1 -1
  189. paddlex/inference/models/text_recognition/predictor.py +11 -19
  190. paddlex/inference/models/text_recognition/processors.py +27 -13
  191. paddlex/inference/models/text_recognition/result.py +3 -2
  192. paddlex/inference/models/ts_anomaly_detection/__init__.py +1 -1
  193. paddlex/inference/models/ts_anomaly_detection/predictor.py +12 -17
  194. paddlex/inference/models/ts_anomaly_detection/processors.py +6 -2
  195. paddlex/inference/models/ts_anomaly_detection/result.py +21 -10
  196. paddlex/inference/models/ts_classification/__init__.py +1 -1
  197. paddlex/inference/models/ts_classification/predictor.py +14 -27
  198. paddlex/inference/models/ts_classification/processors.py +7 -2
  199. paddlex/inference/models/ts_classification/result.py +21 -12
  200. paddlex/inference/models/ts_forecasting/__init__.py +1 -1
  201. paddlex/inference/models/ts_forecasting/predictor.py +13 -18
  202. paddlex/inference/models/ts_forecasting/processors.py +12 -3
  203. paddlex/inference/models/ts_forecasting/result.py +24 -11
  204. paddlex/inference/models/video_classification/__init__.py +1 -1
  205. paddlex/inference/models/video_classification/predictor.py +9 -15
  206. paddlex/inference/models/video_classification/processors.py +24 -24
  207. paddlex/inference/models/video_classification/result.py +7 -3
  208. paddlex/inference/models/video_detection/__init__.py +1 -1
  209. paddlex/inference/models/video_detection/predictor.py +8 -15
  210. paddlex/inference/models/video_detection/processors.py +24 -11
  211. paddlex/inference/models/video_detection/result.py +10 -5
  212. paddlex/inference/pipelines/__init__.py +48 -37
  213. paddlex/inference/pipelines/_parallel.py +172 -0
  214. paddlex/inference/pipelines/anomaly_detection/__init__.py +1 -1
  215. paddlex/inference/pipelines/anomaly_detection/pipeline.py +29 -9
  216. paddlex/inference/pipelines/attribute_recognition/__init__.py +1 -1
  217. paddlex/inference/pipelines/attribute_recognition/pipeline.py +24 -9
  218. paddlex/inference/pipelines/attribute_recognition/result.py +10 -8
  219. paddlex/inference/pipelines/base.py +43 -13
  220. paddlex/inference/pipelines/components/__init__.py +14 -8
  221. paddlex/inference/pipelines/components/chat_server/__init__.py +1 -1
  222. paddlex/inference/pipelines/components/chat_server/base.py +2 -2
  223. paddlex/inference/pipelines/components/chat_server/openai_bot_chat.py +8 -8
  224. paddlex/inference/pipelines/components/common/__init__.py +5 -4
  225. paddlex/inference/pipelines/components/common/base_operator.py +2 -1
  226. paddlex/inference/pipelines/components/common/base_result.py +3 -2
  227. paddlex/inference/pipelines/components/common/convert_points_and_boxes.py +1 -2
  228. paddlex/inference/pipelines/components/common/crop_image_regions.py +11 -5
  229. paddlex/inference/pipelines/components/common/seal_det_warp.py +44 -13
  230. paddlex/inference/pipelines/components/common/sort_boxes.py +4 -2
  231. paddlex/inference/pipelines/components/common/warp_image.py +50 -0
  232. paddlex/inference/pipelines/components/faisser.py +10 -5
  233. paddlex/inference/pipelines/components/prompt_engineering/__init__.py +2 -2
  234. paddlex/inference/pipelines/components/prompt_engineering/base.py +2 -2
  235. paddlex/inference/pipelines/components/prompt_engineering/generate_ensemble_prompt.py +2 -1
  236. paddlex/inference/pipelines/components/prompt_engineering/generate_kie_prompt.py +2 -2
  237. paddlex/inference/pipelines/components/retriever/__init__.py +2 -2
  238. paddlex/inference/pipelines/components/retriever/base.py +18 -16
  239. paddlex/inference/pipelines/components/retriever/openai_bot_retriever.py +2 -2
  240. paddlex/inference/pipelines/components/retriever/qianfan_bot_retriever.py +87 -84
  241. paddlex/inference/pipelines/components/utils/__init__.py +1 -1
  242. paddlex/inference/pipelines/components/utils/mixin.py +7 -7
  243. paddlex/inference/pipelines/doc_preprocessor/__init__.py +1 -1
  244. paddlex/inference/pipelines/doc_preprocessor/pipeline.py +70 -51
  245. paddlex/inference/pipelines/doc_preprocessor/result.py +5 -10
  246. paddlex/inference/pipelines/doc_understanding/__init__.py +15 -0
  247. paddlex/inference/pipelines/doc_understanding/pipeline.py +71 -0
  248. paddlex/inference/pipelines/face_recognition/__init__.py +1 -1
  249. paddlex/inference/pipelines/face_recognition/pipeline.py +3 -1
  250. paddlex/inference/pipelines/face_recognition/result.py +3 -2
  251. paddlex/inference/pipelines/formula_recognition/__init__.py +1 -1
  252. paddlex/inference/pipelines/formula_recognition/pipeline.py +137 -93
  253. paddlex/inference/pipelines/formula_recognition/result.py +20 -29
  254. paddlex/inference/pipelines/image_classification/__init__.py +1 -1
  255. paddlex/inference/pipelines/image_classification/pipeline.py +30 -11
  256. paddlex/inference/pipelines/image_multilabel_classification/__init__.py +1 -1
  257. paddlex/inference/pipelines/image_multilabel_classification/pipeline.py +31 -12
  258. paddlex/inference/pipelines/instance_segmentation/__init__.py +1 -1
  259. paddlex/inference/pipelines/instance_segmentation/pipeline.py +30 -9
  260. paddlex/inference/pipelines/keypoint_detection/__init__.py +1 -1
  261. paddlex/inference/pipelines/keypoint_detection/pipeline.py +30 -9
  262. paddlex/inference/pipelines/layout_parsing/__init__.py +1 -1
  263. paddlex/inference/pipelines/layout_parsing/pipeline.py +54 -56
  264. paddlex/inference/pipelines/layout_parsing/pipeline_v2.py +904 -261
  265. paddlex/inference/pipelines/layout_parsing/result.py +9 -21
  266. paddlex/inference/pipelines/layout_parsing/result_v2.py +525 -250
  267. paddlex/inference/pipelines/layout_parsing/setting.py +87 -0
  268. paddlex/inference/pipelines/layout_parsing/utils.py +570 -2004
  269. paddlex/inference/pipelines/layout_parsing/xycut_enhanced/__init__.py +16 -0
  270. paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py +1144 -0
  271. paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py +563 -0
  272. paddlex/inference/pipelines/{3d_bev_detection → m_3d_bev_detection}/__init__.py +1 -1
  273. paddlex/inference/pipelines/{3d_bev_detection → m_3d_bev_detection}/pipeline.py +17 -10
  274. paddlex/inference/pipelines/multilingual_speech_recognition/__init__.py +1 -1
  275. paddlex/inference/pipelines/multilingual_speech_recognition/pipeline.py +17 -6
  276. paddlex/inference/pipelines/object_detection/__init__.py +1 -1
  277. paddlex/inference/pipelines/object_detection/pipeline.py +29 -9
  278. paddlex/inference/pipelines/ocr/__init__.py +1 -1
  279. paddlex/inference/pipelines/ocr/pipeline.py +151 -77
  280. paddlex/inference/pipelines/ocr/result.py +31 -24
  281. paddlex/inference/pipelines/open_vocabulary_detection/__init__.py +1 -1
  282. paddlex/inference/pipelines/open_vocabulary_detection/pipeline.py +17 -6
  283. paddlex/inference/pipelines/open_vocabulary_segmentation/__init__.py +1 -1
  284. paddlex/inference/pipelines/open_vocabulary_segmentation/pipeline.py +17 -6
  285. paddlex/inference/pipelines/pp_chatocr/__init__.py +1 -1
  286. paddlex/inference/pipelines/pp_chatocr/pipeline_base.py +14 -5
  287. paddlex/inference/pipelines/pp_chatocr/pipeline_v3.py +22 -14
  288. paddlex/inference/pipelines/pp_chatocr/pipeline_v4.py +34 -16
  289. paddlex/inference/pipelines/pp_shitu_v2/__init__.py +1 -1
  290. paddlex/inference/pipelines/pp_shitu_v2/pipeline.py +12 -8
  291. paddlex/inference/pipelines/pp_shitu_v2/result.py +4 -4
  292. paddlex/inference/pipelines/rotated_object_detection/__init__.py +1 -1
  293. paddlex/inference/pipelines/rotated_object_detection/pipeline.py +30 -9
  294. paddlex/inference/pipelines/seal_recognition/__init__.py +1 -1
  295. paddlex/inference/pipelines/seal_recognition/pipeline.py +127 -63
  296. paddlex/inference/pipelines/seal_recognition/result.py +4 -2
  297. paddlex/inference/pipelines/semantic_segmentation/__init__.py +1 -1
  298. paddlex/inference/pipelines/semantic_segmentation/pipeline.py +30 -9
  299. paddlex/inference/pipelines/small_object_detection/__init__.py +1 -1
  300. paddlex/inference/pipelines/small_object_detection/pipeline.py +30 -9
  301. paddlex/inference/pipelines/table_recognition/__init__.py +1 -1
  302. paddlex/inference/pipelines/table_recognition/pipeline.py +61 -37
  303. paddlex/inference/pipelines/table_recognition/pipeline_v2.py +668 -65
  304. paddlex/inference/pipelines/table_recognition/result.py +12 -10
  305. paddlex/inference/pipelines/table_recognition/table_recognition_post_processing.py +12 -8
  306. paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py +55 -37
  307. paddlex/inference/pipelines/table_recognition/utils.py +1 -1
  308. paddlex/inference/pipelines/ts_anomaly_detection/__init__.py +1 -1
  309. paddlex/inference/pipelines/ts_anomaly_detection/pipeline.py +16 -6
  310. paddlex/inference/pipelines/ts_classification/__init__.py +1 -1
  311. paddlex/inference/pipelines/ts_classification/pipeline.py +16 -6
  312. paddlex/inference/pipelines/ts_forecasting/__init__.py +1 -1
  313. paddlex/inference/pipelines/ts_forecasting/pipeline.py +16 -6
  314. paddlex/inference/pipelines/video_classification/__init__.py +1 -1
  315. paddlex/inference/pipelines/video_classification/pipeline.py +17 -6
  316. paddlex/inference/pipelines/video_detection/__init__.py +1 -1
  317. paddlex/inference/pipelines/video_detection/pipeline.py +20 -7
  318. paddlex/inference/serving/__init__.py +5 -1
  319. paddlex/inference/serving/basic_serving/__init__.py +1 -1
  320. paddlex/inference/serving/basic_serving/_app.py +31 -19
  321. paddlex/inference/serving/basic_serving/_pipeline_apps/__init__.py +7 -4
  322. paddlex/inference/serving/basic_serving/_pipeline_apps/_common/__init__.py +1 -1
  323. paddlex/inference/serving/basic_serving/_pipeline_apps/_common/common.py +12 -4
  324. paddlex/inference/serving/basic_serving/_pipeline_apps/_common/image_recognition.py +1 -1
  325. paddlex/inference/serving/basic_serving/_pipeline_apps/_common/ocr.py +7 -2
  326. paddlex/inference/serving/basic_serving/_pipeline_apps/anomaly_detection.py +10 -7
  327. paddlex/inference/serving/basic_serving/_pipeline_apps/doc_preprocessor.py +10 -7
  328. paddlex/inference/serving/basic_serving/_pipeline_apps/doc_understanding.py +153 -0
  329. paddlex/inference/serving/basic_serving/_pipeline_apps/face_recognition.py +16 -13
  330. paddlex/inference/serving/basic_serving/_pipeline_apps/formula_recognition.py +10 -7
  331. paddlex/inference/serving/basic_serving/_pipeline_apps/human_keypoint_detection.py +10 -7
  332. paddlex/inference/serving/basic_serving/_pipeline_apps/image_classification.py +10 -7
  333. paddlex/inference/serving/basic_serving/_pipeline_apps/image_multilabel_classification.py +10 -7
  334. paddlex/inference/serving/basic_serving/_pipeline_apps/instance_segmentation.py +13 -7
  335. paddlex/inference/serving/basic_serving/_pipeline_apps/layout_parsing.py +10 -8
  336. paddlex/inference/serving/basic_serving/_pipeline_apps/m_3d_bev_detection.py +10 -7
  337. paddlex/inference/serving/basic_serving/_pipeline_apps/multilingual_speech_recognition.py +10 -7
  338. paddlex/inference/serving/basic_serving/_pipeline_apps/object_detection.py +10 -7
  339. paddlex/inference/serving/basic_serving/_pipeline_apps/ocr.py +10 -7
  340. paddlex/inference/serving/basic_serving/_pipeline_apps/open_vocabulary_detection.py +10 -7
  341. paddlex/inference/serving/basic_serving/_pipeline_apps/open_vocabulary_segmentation.py +13 -7
  342. paddlex/inference/serving/basic_serving/_pipeline_apps/pedestrian_attribute_recognition.py +10 -7
  343. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv3_doc.py +14 -12
  344. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv4_doc.py +17 -14
  345. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_shituv2.py +16 -13
  346. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_structurev3.py +16 -9
  347. paddlex/inference/serving/basic_serving/_pipeline_apps/rotated_object_detection.py +10 -7
  348. paddlex/inference/serving/basic_serving/_pipeline_apps/seal_recognition.py +10 -7
  349. paddlex/inference/serving/basic_serving/_pipeline_apps/semantic_segmentation.py +10 -7
  350. paddlex/inference/serving/basic_serving/_pipeline_apps/small_object_detection.py +10 -7
  351. paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition.py +11 -12
  352. paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition_v2.py +14 -12
  353. paddlex/inference/serving/basic_serving/_pipeline_apps/ts_anomaly_detection.py +10 -7
  354. paddlex/inference/serving/basic_serving/_pipeline_apps/ts_classification.py +10 -7
  355. paddlex/inference/serving/basic_serving/_pipeline_apps/ts_forecast.py +10 -7
  356. paddlex/inference/serving/basic_serving/_pipeline_apps/vehicle_attribute_recognition.py +10 -7
  357. paddlex/inference/serving/basic_serving/_pipeline_apps/video_classification.py +10 -7
  358. paddlex/inference/serving/basic_serving/_pipeline_apps/video_detection.py +10 -7
  359. paddlex/inference/serving/basic_serving/_server.py +9 -4
  360. paddlex/inference/serving/infra/__init__.py +1 -1
  361. paddlex/inference/serving/infra/config.py +1 -1
  362. paddlex/inference/serving/infra/models.py +13 -6
  363. paddlex/inference/serving/infra/storage.py +9 -4
  364. paddlex/inference/serving/infra/utils.py +54 -28
  365. paddlex/inference/serving/schemas/__init__.py +1 -1
  366. paddlex/inference/serving/schemas/anomaly_detection.py +1 -1
  367. paddlex/inference/serving/schemas/doc_preprocessor.py +1 -1
  368. paddlex/inference/serving/schemas/doc_understanding.py +78 -0
  369. paddlex/inference/serving/schemas/face_recognition.py +1 -1
  370. paddlex/inference/serving/schemas/formula_recognition.py +2 -2
  371. paddlex/inference/serving/schemas/human_keypoint_detection.py +1 -1
  372. paddlex/inference/serving/schemas/image_classification.py +1 -1
  373. paddlex/inference/serving/schemas/image_multilabel_classification.py +1 -1
  374. paddlex/inference/serving/schemas/instance_segmentation.py +1 -1
  375. paddlex/inference/serving/schemas/layout_parsing.py +2 -3
  376. paddlex/inference/serving/schemas/m_3d_bev_detection.py +1 -1
  377. paddlex/inference/serving/schemas/multilingual_speech_recognition.py +1 -1
  378. paddlex/inference/serving/schemas/object_detection.py +1 -1
  379. paddlex/inference/serving/schemas/ocr.py +1 -1
  380. paddlex/inference/serving/schemas/open_vocabulary_detection.py +1 -1
  381. paddlex/inference/serving/schemas/open_vocabulary_segmentation.py +1 -1
  382. paddlex/inference/serving/schemas/pedestrian_attribute_recognition.py +1 -1
  383. paddlex/inference/serving/schemas/pp_chatocrv3_doc.py +2 -3
  384. paddlex/inference/serving/schemas/pp_chatocrv4_doc.py +3 -3
  385. paddlex/inference/serving/schemas/pp_shituv2.py +1 -1
  386. paddlex/inference/serving/schemas/pp_structurev3.py +11 -7
  387. paddlex/inference/serving/schemas/rotated_object_detection.py +1 -1
  388. paddlex/inference/serving/schemas/seal_recognition.py +2 -2
  389. paddlex/inference/serving/schemas/semantic_segmentation.py +1 -1
  390. paddlex/inference/serving/schemas/shared/__init__.py +1 -1
  391. paddlex/inference/serving/schemas/shared/classification.py +1 -1
  392. paddlex/inference/serving/schemas/shared/image_segmentation.py +1 -1
  393. paddlex/inference/serving/schemas/shared/object_detection.py +1 -1
  394. paddlex/inference/serving/schemas/shared/ocr.py +1 -1
  395. paddlex/inference/serving/schemas/small_object_detection.py +1 -1
  396. paddlex/inference/serving/schemas/table_recognition.py +3 -7
  397. paddlex/inference/serving/schemas/table_recognition_v2.py +6 -7
  398. paddlex/inference/serving/schemas/ts_anomaly_detection.py +1 -1
  399. paddlex/inference/serving/schemas/ts_classification.py +1 -1
  400. paddlex/inference/serving/schemas/ts_forecast.py +1 -1
  401. paddlex/inference/serving/schemas/vehicle_attribute_recognition.py +1 -1
  402. paddlex/inference/serving/schemas/video_classification.py +1 -1
  403. paddlex/inference/serving/schemas/video_detection.py +1 -1
  404. paddlex/inference/utils/__init__.py +1 -1
  405. paddlex/inference/utils/benchmark.py +332 -179
  406. paddlex/inference/utils/color_map.py +1 -1
  407. paddlex/inference/utils/get_pipeline_path.py +1 -1
  408. paddlex/inference/utils/hpi.py +258 -0
  409. paddlex/inference/utils/hpi_model_info_collection.json +2331 -0
  410. paddlex/inference/utils/io/__init__.py +11 -11
  411. paddlex/inference/utils/io/readers.py +31 -27
  412. paddlex/inference/utils/io/style.py +21 -14
  413. paddlex/inference/utils/io/tablepyxl.py +13 -5
  414. paddlex/inference/utils/io/writers.py +9 -10
  415. paddlex/inference/utils/mkldnn_blocklist.py +25 -0
  416. paddlex/inference/utils/model_paths.py +48 -0
  417. paddlex/inference/utils/{new_ir_blacklist.py → new_ir_blocklist.py} +1 -2
  418. paddlex/inference/utils/official_models.py +278 -262
  419. paddlex/inference/utils/pp_option.py +184 -92
  420. paddlex/inference/utils/trt_blocklist.py +43 -0
  421. paddlex/inference/utils/trt_config.py +420 -0
  422. paddlex/model.py +30 -12
  423. paddlex/modules/__init__.py +57 -80
  424. paddlex/modules/anomaly_detection/__init__.py +2 -2
  425. paddlex/modules/anomaly_detection/dataset_checker/__init__.py +2 -3
  426. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/__init__.py +2 -2
  427. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/analyse_dataset.py +6 -3
  428. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/check_dataset.py +8 -4
  429. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/convert_dataset.py +7 -4
  430. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/split_dataset.py +2 -2
  431. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/utils/__init__.py +1 -1
  432. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/utils/visualizer.py +7 -2
  433. paddlex/modules/anomaly_detection/evaluator.py +3 -3
  434. paddlex/modules/anomaly_detection/exportor.py +1 -1
  435. paddlex/modules/anomaly_detection/model_list.py +1 -1
  436. paddlex/modules/anomaly_detection/trainer.py +3 -4
  437. paddlex/modules/base/__init__.py +5 -5
  438. paddlex/modules/base/build_model.py +1 -2
  439. paddlex/modules/base/dataset_checker/__init__.py +2 -2
  440. paddlex/modules/base/dataset_checker/dataset_checker.py +4 -4
  441. paddlex/modules/base/dataset_checker/utils.py +1 -3
  442. paddlex/modules/base/evaluator.py +13 -13
  443. paddlex/modules/base/exportor.py +12 -13
  444. paddlex/modules/base/trainer.py +21 -11
  445. paddlex/modules/base/utils/__init__.py +13 -0
  446. paddlex/modules/base/utils/cinn_setting.py +89 -0
  447. paddlex/modules/base/utils/coco_eval.py +94 -0
  448. paddlex/modules/base/utils/topk_eval.py +118 -0
  449. paddlex/modules/doc_vlm/__init__.py +18 -0
  450. paddlex/modules/doc_vlm/dataset_checker.py +29 -0
  451. paddlex/modules/doc_vlm/evaluator.py +29 -0
  452. paddlex/modules/doc_vlm/exportor.py +29 -0
  453. paddlex/modules/doc_vlm/model_list.py +16 -0
  454. paddlex/modules/doc_vlm/trainer.py +41 -0
  455. paddlex/modules/face_recognition/__init__.py +2 -2
  456. paddlex/modules/face_recognition/dataset_checker/__init__.py +2 -2
  457. paddlex/modules/face_recognition/dataset_checker/dataset_src/__init__.py +1 -1
  458. paddlex/modules/face_recognition/dataset_checker/dataset_src/check_dataset.py +3 -5
  459. paddlex/modules/face_recognition/dataset_checker/dataset_src/utils/__init__.py +1 -1
  460. paddlex/modules/face_recognition/dataset_checker/dataset_src/utils/visualizer.py +2 -5
  461. paddlex/modules/face_recognition/evaluator.py +3 -3
  462. paddlex/modules/face_recognition/exportor.py +1 -1
  463. paddlex/modules/face_recognition/model_list.py +1 -1
  464. paddlex/modules/face_recognition/trainer.py +1 -1
  465. paddlex/modules/formula_recognition/__init__.py +2 -2
  466. paddlex/modules/formula_recognition/dataset_checker/__init__.py +3 -3
  467. paddlex/modules/formula_recognition/dataset_checker/dataset_src/__init__.py +2 -2
  468. paddlex/modules/formula_recognition/dataset_checker/dataset_src/analyse_dataset.py +13 -12
  469. paddlex/modules/formula_recognition/dataset_checker/dataset_src/check_dataset.py +2 -6
  470. paddlex/modules/formula_recognition/dataset_checker/dataset_src/convert_dataset.py +11 -10
  471. paddlex/modules/formula_recognition/dataset_checker/dataset_src/split_dataset.py +1 -2
  472. paddlex/modules/formula_recognition/evaluator.py +6 -3
  473. paddlex/modules/formula_recognition/exportor.py +1 -1
  474. paddlex/modules/formula_recognition/model_list.py +4 -1
  475. paddlex/modules/formula_recognition/trainer.py +5 -3
  476. paddlex/modules/general_recognition/__init__.py +2 -2
  477. paddlex/modules/general_recognition/dataset_checker/__init__.py +2 -2
  478. paddlex/modules/general_recognition/dataset_checker/dataset_src/__init__.py +2 -2
  479. paddlex/modules/general_recognition/dataset_checker/dataset_src/analyse_dataset.py +7 -9
  480. paddlex/modules/general_recognition/dataset_checker/dataset_src/check_dataset.py +4 -5
  481. paddlex/modules/general_recognition/dataset_checker/dataset_src/convert_dataset.py +6 -5
  482. paddlex/modules/general_recognition/dataset_checker/dataset_src/split_dataset.py +1 -1
  483. paddlex/modules/general_recognition/dataset_checker/dataset_src/utils/__init__.py +1 -1
  484. paddlex/modules/general_recognition/dataset_checker/dataset_src/utils/visualizer.py +2 -5
  485. paddlex/modules/general_recognition/evaluator.py +2 -2
  486. paddlex/modules/general_recognition/exportor.py +1 -1
  487. paddlex/modules/general_recognition/model_list.py +1 -1
  488. paddlex/modules/general_recognition/trainer.py +1 -1
  489. paddlex/modules/image_classification/__init__.py +2 -2
  490. paddlex/modules/image_classification/dataset_checker/__init__.py +2 -2
  491. paddlex/modules/image_classification/dataset_checker/dataset_src/__init__.py +2 -2
  492. paddlex/modules/image_classification/dataset_checker/dataset_src/analyse_dataset.py +8 -9
  493. paddlex/modules/image_classification/dataset_checker/dataset_src/check_dataset.py +4 -3
  494. paddlex/modules/image_classification/dataset_checker/dataset_src/convert_dataset.py +4 -4
  495. paddlex/modules/image_classification/dataset_checker/dataset_src/split_dataset.py +1 -1
  496. paddlex/modules/image_classification/dataset_checker/dataset_src/utils/__init__.py +1 -1
  497. paddlex/modules/image_classification/dataset_checker/dataset_src/utils/visualizer.py +2 -5
  498. paddlex/modules/image_classification/evaluator.py +3 -3
  499. paddlex/modules/image_classification/exportor.py +1 -1
  500. paddlex/modules/image_classification/model_list.py +2 -1
  501. paddlex/modules/image_classification/trainer.py +3 -3
  502. paddlex/modules/image_unwarping/__init__.py +1 -1
  503. paddlex/modules/image_unwarping/model_list.py +1 -1
  504. paddlex/modules/instance_segmentation/__init__.py +2 -2
  505. paddlex/modules/instance_segmentation/dataset_checker/__init__.py +2 -3
  506. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/__init__.py +2 -2
  507. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/analyse_dataset.py +9 -5
  508. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/check_dataset.py +8 -5
  509. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/convert_dataset.py +8 -8
  510. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/split_dataset.py +7 -4
  511. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/utils/__init__.py +1 -1
  512. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/utils/visualizer.py +10 -8
  513. paddlex/modules/instance_segmentation/evaluator.py +2 -2
  514. paddlex/modules/instance_segmentation/exportor.py +1 -1
  515. paddlex/modules/instance_segmentation/model_list.py +1 -1
  516. paddlex/modules/instance_segmentation/trainer.py +1 -1
  517. paddlex/modules/keypoint_detection/__init__.py +2 -2
  518. paddlex/modules/keypoint_detection/dataset_checker/__init__.py +2 -2
  519. paddlex/modules/keypoint_detection/dataset_checker/dataset_src/__init__.py +1 -1
  520. paddlex/modules/keypoint_detection/dataset_checker/dataset_src/check_dataset.py +10 -5
  521. paddlex/modules/keypoint_detection/dataset_checker/dataset_src/utils/__init__.py +1 -1
  522. paddlex/modules/keypoint_detection/dataset_checker/dataset_src/utils/visualizer.py +8 -3
  523. paddlex/modules/keypoint_detection/evaluator.py +2 -2
  524. paddlex/modules/keypoint_detection/exportor.py +1 -1
  525. paddlex/modules/keypoint_detection/model_list.py +1 -1
  526. paddlex/modules/keypoint_detection/trainer.py +2 -2
  527. paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/__init__.py +2 -2
  528. paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/dataset_checker/__init__.py +3 -3
  529. paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/dataset_checker/dataset_src/__init__.py +2 -2
  530. paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/dataset_checker/dataset_src/analyse_dataset.py +8 -8
  531. paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/dataset_checker/dataset_src/check_dataset.py +1 -2
  532. paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/evaluator.py +3 -3
  533. paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/exportor.py +1 -1
  534. paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/model_list.py +1 -1
  535. paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/trainer.py +5 -7
  536. paddlex/modules/multilabel_classification/__init__.py +2 -2
  537. paddlex/modules/multilabel_classification/dataset_checker/__init__.py +2 -2
  538. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/__init__.py +2 -2
  539. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/analyse_dataset.py +8 -9
  540. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/check_dataset.py +4 -3
  541. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/convert_dataset.py +10 -7
  542. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/split_dataset.py +1 -1
  543. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/utils/__init__.py +1 -1
  544. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/utils/visualizer.py +1 -5
  545. paddlex/modules/multilabel_classification/evaluator.py +3 -3
  546. paddlex/modules/multilabel_classification/exportor.py +1 -1
  547. paddlex/modules/multilabel_classification/model_list.py +1 -1
  548. paddlex/modules/multilabel_classification/trainer.py +3 -3
  549. paddlex/modules/multilingual_speech_recognition/__init__.py +2 -2
  550. paddlex/modules/multilingual_speech_recognition/dataset_checker.py +3 -3
  551. paddlex/modules/multilingual_speech_recognition/evaluator.py +3 -3
  552. paddlex/modules/multilingual_speech_recognition/exportor.py +3 -3
  553. paddlex/modules/multilingual_speech_recognition/model_list.py +1 -1
  554. paddlex/modules/multilingual_speech_recognition/trainer.py +7 -5
  555. paddlex/modules/object_detection/__init__.py +2 -2
  556. paddlex/modules/object_detection/dataset_checker/__init__.py +2 -11
  557. paddlex/modules/object_detection/dataset_checker/dataset_src/__init__.py +2 -2
  558. paddlex/modules/object_detection/dataset_checker/dataset_src/analyse_dataset.py +10 -8
  559. paddlex/modules/object_detection/dataset_checker/dataset_src/check_dataset.py +10 -5
  560. paddlex/modules/object_detection/dataset_checker/dataset_src/convert_dataset.py +17 -12
  561. paddlex/modules/object_detection/dataset_checker/dataset_src/split_dataset.py +8 -4
  562. paddlex/modules/object_detection/dataset_checker/dataset_src/utils/__init__.py +1 -1
  563. paddlex/modules/object_detection/dataset_checker/dataset_src/utils/visualizer.py +9 -8
  564. paddlex/modules/object_detection/evaluator.py +11 -6
  565. paddlex/modules/object_detection/exportor.py +1 -1
  566. paddlex/modules/object_detection/model_list.py +3 -1
  567. paddlex/modules/object_detection/trainer.py +4 -5
  568. paddlex/modules/open_vocabulary_detection/__init__.py +2 -2
  569. paddlex/modules/open_vocabulary_detection/dataset_checker.py +3 -3
  570. paddlex/modules/open_vocabulary_detection/evaluator.py +3 -3
  571. paddlex/modules/open_vocabulary_detection/exportor.py +3 -3
  572. paddlex/modules/open_vocabulary_detection/model_list.py +2 -4
  573. paddlex/modules/open_vocabulary_detection/trainer.py +7 -5
  574. paddlex/modules/open_vocabulary_segmentation/__init__.py +2 -2
  575. paddlex/modules/open_vocabulary_segmentation/dataset_checker.py +3 -3
  576. paddlex/modules/open_vocabulary_segmentation/evaluator.py +3 -3
  577. paddlex/modules/open_vocabulary_segmentation/exportor.py +3 -3
  578. paddlex/modules/open_vocabulary_segmentation/model_list.py +1 -1
  579. paddlex/modules/open_vocabulary_segmentation/trainer.py +7 -5
  580. paddlex/modules/semantic_segmentation/__init__.py +2 -2
  581. paddlex/modules/semantic_segmentation/dataset_checker/__init__.py +2 -3
  582. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/__init__.py +2 -2
  583. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/analyse_dataset.py +6 -3
  584. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/check_dataset.py +2 -2
  585. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/convert_dataset.py +7 -4
  586. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/split_dataset.py +2 -2
  587. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/utils/__init__.py +1 -1
  588. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/utils/visualizer.py +6 -2
  589. paddlex/modules/semantic_segmentation/evaluator.py +3 -3
  590. paddlex/modules/semantic_segmentation/exportor.py +1 -1
  591. paddlex/modules/semantic_segmentation/model_list.py +1 -1
  592. paddlex/modules/semantic_segmentation/trainer.py +3 -4
  593. paddlex/modules/table_recognition/__init__.py +2 -2
  594. paddlex/modules/table_recognition/dataset_checker/__init__.py +5 -5
  595. paddlex/modules/table_recognition/dataset_checker/dataset_src/__init__.py +2 -2
  596. paddlex/modules/table_recognition/dataset_checker/dataset_src/analyse_dataset.py +3 -2
  597. paddlex/modules/table_recognition/dataset_checker/dataset_src/check_dataset.py +8 -7
  598. paddlex/modules/table_recognition/dataset_checker/dataset_src/split_dataset.py +2 -1
  599. paddlex/modules/table_recognition/evaluator.py +3 -3
  600. paddlex/modules/table_recognition/exportor.py +1 -1
  601. paddlex/modules/table_recognition/model_list.py +1 -1
  602. paddlex/modules/table_recognition/trainer.py +2 -5
  603. paddlex/modules/text_detection/__init__.py +2 -2
  604. paddlex/modules/text_detection/dataset_checker/__init__.py +4 -6
  605. paddlex/modules/text_detection/dataset_checker/dataset_src/__init__.py +2 -2
  606. paddlex/modules/text_detection/dataset_checker/dataset_src/analyse_dataset.py +12 -9
  607. paddlex/modules/text_detection/dataset_checker/dataset_src/check_dataset.py +3 -3
  608. paddlex/modules/text_detection/dataset_checker/dataset_src/split_dataset.py +3 -3
  609. paddlex/modules/text_detection/evaluator.py +3 -3
  610. paddlex/modules/text_detection/exportor.py +1 -1
  611. paddlex/modules/text_detection/model_list.py +3 -1
  612. paddlex/modules/text_detection/trainer.py +2 -5
  613. paddlex/modules/text_recognition/__init__.py +2 -2
  614. paddlex/modules/text_recognition/dataset_checker/__init__.py +4 -5
  615. paddlex/modules/text_recognition/dataset_checker/dataset_src/__init__.py +2 -2
  616. paddlex/modules/text_recognition/dataset_checker/dataset_src/analyse_dataset.py +13 -12
  617. paddlex/modules/text_recognition/dataset_checker/dataset_src/check_dataset.py +2 -5
  618. paddlex/modules/text_recognition/dataset_checker/dataset_src/convert_dataset.py +11 -10
  619. paddlex/modules/text_recognition/dataset_checker/dataset_src/split_dataset.py +1 -2
  620. paddlex/modules/text_recognition/evaluator.py +3 -3
  621. paddlex/modules/text_recognition/exportor.py +1 -1
  622. paddlex/modules/text_recognition/model_list.py +3 -1
  623. paddlex/modules/text_recognition/trainer.py +2 -3
  624. paddlex/modules/ts_anomaly_detection/__init__.py +2 -2
  625. paddlex/modules/ts_anomaly_detection/dataset_checker/__init__.py +4 -5
  626. paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/__init__.py +2 -2
  627. paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/analyse_dataset.py +1 -9
  628. paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/check_dataset.py +2 -2
  629. paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/convert_dataset.py +2 -6
  630. paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/split_dataset.py +4 -4
  631. paddlex/modules/ts_anomaly_detection/evaluator.py +3 -3
  632. paddlex/modules/ts_anomaly_detection/exportor.py +2 -3
  633. paddlex/modules/ts_anomaly_detection/model_list.py +1 -1
  634. paddlex/modules/ts_anomaly_detection/trainer.py +8 -8
  635. paddlex/modules/ts_classification/__init__.py +2 -2
  636. paddlex/modules/ts_classification/dataset_checker/__init__.py +4 -5
  637. paddlex/modules/ts_classification/dataset_checker/dataset_src/__init__.py +2 -2
  638. paddlex/modules/ts_classification/dataset_checker/dataset_src/analyse_dataset.py +8 -5
  639. paddlex/modules/ts_classification/dataset_checker/dataset_src/check_dataset.py +2 -2
  640. paddlex/modules/ts_classification/dataset_checker/dataset_src/convert_dataset.py +2 -6
  641. paddlex/modules/ts_classification/dataset_checker/dataset_src/split_dataset.py +5 -5
  642. paddlex/modules/ts_classification/evaluator.py +3 -3
  643. paddlex/modules/ts_classification/exportor.py +2 -3
  644. paddlex/modules/ts_classification/model_list.py +1 -1
  645. paddlex/modules/ts_classification/trainer.py +7 -7
  646. paddlex/modules/ts_forecast/__init__.py +2 -2
  647. paddlex/modules/ts_forecast/dataset_checker/__init__.py +4 -5
  648. paddlex/modules/ts_forecast/dataset_checker/dataset_src/__init__.py +2 -2
  649. paddlex/modules/ts_forecast/dataset_checker/dataset_src/analyse_dataset.py +1 -9
  650. paddlex/modules/ts_forecast/dataset_checker/dataset_src/check_dataset.py +2 -2
  651. paddlex/modules/ts_forecast/dataset_checker/dataset_src/convert_dataset.py +2 -6
  652. paddlex/modules/ts_forecast/dataset_checker/dataset_src/split_dataset.py +4 -4
  653. paddlex/modules/ts_forecast/evaluator.py +3 -3
  654. paddlex/modules/ts_forecast/exportor.py +2 -3
  655. paddlex/modules/ts_forecast/model_list.py +1 -1
  656. paddlex/modules/ts_forecast/trainer.py +7 -7
  657. paddlex/modules/video_classification/__init__.py +2 -2
  658. paddlex/modules/video_classification/dataset_checker/__init__.py +2 -2
  659. paddlex/modules/video_classification/dataset_checker/dataset_src/__init__.py +2 -2
  660. paddlex/modules/video_classification/dataset_checker/dataset_src/analyse_dataset.py +9 -9
  661. paddlex/modules/video_classification/dataset_checker/dataset_src/check_dataset.py +2 -3
  662. paddlex/modules/video_classification/dataset_checker/dataset_src/split_dataset.py +1 -1
  663. paddlex/modules/video_classification/evaluator.py +3 -3
  664. paddlex/modules/video_classification/exportor.py +1 -1
  665. paddlex/modules/video_classification/model_list.py +1 -1
  666. paddlex/modules/video_classification/trainer.py +3 -3
  667. paddlex/modules/video_detection/__init__.py +2 -2
  668. paddlex/modules/video_detection/dataset_checker/__init__.py +2 -2
  669. paddlex/modules/video_detection/dataset_checker/dataset_src/__init__.py +2 -2
  670. paddlex/modules/video_detection/dataset_checker/dataset_src/analyse_dataset.py +8 -9
  671. paddlex/modules/video_detection/dataset_checker/dataset_src/check_dataset.py +3 -5
  672. paddlex/modules/video_detection/evaluator.py +3 -3
  673. paddlex/modules/video_detection/exportor.py +1 -1
  674. paddlex/modules/video_detection/model_list.py +1 -1
  675. paddlex/modules/video_detection/trainer.py +3 -3
  676. paddlex/ops/__init__.py +7 -4
  677. paddlex/ops/iou3d_nms/iou3d_cpu.cpp +8 -6
  678. paddlex/ops/iou3d_nms/iou3d_cpu.h +3 -2
  679. paddlex/ops/iou3d_nms/iou3d_nms.cpp +8 -6
  680. paddlex/ops/iou3d_nms/iou3d_nms.h +6 -4
  681. paddlex/ops/iou3d_nms/iou3d_nms_api.cpp +24 -18
  682. paddlex/ops/iou3d_nms/iou3d_nms_kernel.cu +9 -7
  683. paddlex/ops/setup.py +3 -3
  684. paddlex/ops/voxel/voxelize_op.cc +22 -19
  685. paddlex/ops/voxel/voxelize_op.cu +25 -25
  686. paddlex/paddlex_cli.py +104 -87
  687. paddlex/repo_apis/Paddle3D_api/__init__.py +1 -1
  688. paddlex/repo_apis/Paddle3D_api/bev_fusion/__init__.py +1 -1
  689. paddlex/repo_apis/Paddle3D_api/bev_fusion/config.py +1 -1
  690. paddlex/repo_apis/Paddle3D_api/bev_fusion/model.py +6 -6
  691. paddlex/repo_apis/Paddle3D_api/bev_fusion/register.py +2 -2
  692. paddlex/repo_apis/Paddle3D_api/bev_fusion/runner.py +1 -1
  693. paddlex/repo_apis/Paddle3D_api/pp3d_config.py +3 -2
  694. paddlex/repo_apis/PaddleClas_api/__init__.py +1 -1
  695. paddlex/repo_apis/PaddleClas_api/cls/__init__.py +3 -3
  696. paddlex/repo_apis/PaddleClas_api/cls/config.py +5 -4
  697. paddlex/repo_apis/PaddleClas_api/cls/model.py +4 -4
  698. paddlex/repo_apis/PaddleClas_api/cls/register.py +12 -3
  699. paddlex/repo_apis/PaddleClas_api/cls/runner.py +2 -3
  700. paddlex/repo_apis/PaddleClas_api/shitu_rec/__init__.py +2 -2
  701. paddlex/repo_apis/PaddleClas_api/shitu_rec/config.py +2 -2
  702. paddlex/repo_apis/PaddleClas_api/shitu_rec/model.py +1 -4
  703. paddlex/repo_apis/PaddleClas_api/shitu_rec/register.py +2 -2
  704. paddlex/repo_apis/PaddleClas_api/shitu_rec/runner.py +1 -6
  705. paddlex/repo_apis/PaddleDetection_api/__init__.py +2 -2
  706. paddlex/repo_apis/PaddleDetection_api/config_helper.py +3 -3
  707. paddlex/repo_apis/PaddleDetection_api/instance_seg/__init__.py +2 -2
  708. paddlex/repo_apis/PaddleDetection_api/instance_seg/config.py +2 -3
  709. paddlex/repo_apis/PaddleDetection_api/instance_seg/model.py +4 -4
  710. paddlex/repo_apis/PaddleDetection_api/instance_seg/register.py +2 -3
  711. paddlex/repo_apis/PaddleDetection_api/instance_seg/runner.py +2 -3
  712. paddlex/repo_apis/PaddleDetection_api/object_det/__init__.py +3 -3
  713. paddlex/repo_apis/PaddleDetection_api/object_det/config.py +5 -4
  714. paddlex/repo_apis/PaddleDetection_api/object_det/model.py +6 -7
  715. paddlex/repo_apis/PaddleDetection_api/object_det/official_categories.py +26 -1
  716. paddlex/repo_apis/PaddleDetection_api/object_det/register.py +32 -3
  717. paddlex/repo_apis/PaddleDetection_api/object_det/runner.py +2 -3
  718. paddlex/repo_apis/PaddleNLP_api/__init__.py +1 -1
  719. paddlex/repo_apis/PaddleOCR_api/__init__.py +4 -3
  720. paddlex/repo_apis/PaddleOCR_api/config_utils.py +1 -1
  721. paddlex/repo_apis/PaddleOCR_api/formula_rec/__init__.py +1 -1
  722. paddlex/repo_apis/PaddleOCR_api/formula_rec/config.py +7 -6
  723. paddlex/repo_apis/PaddleOCR_api/formula_rec/model.py +9 -13
  724. paddlex/repo_apis/PaddleOCR_api/formula_rec/register.py +29 -3
  725. paddlex/repo_apis/PaddleOCR_api/formula_rec/runner.py +2 -3
  726. paddlex/repo_apis/PaddleOCR_api/table_rec/__init__.py +1 -1
  727. paddlex/repo_apis/PaddleOCR_api/table_rec/config.py +1 -1
  728. paddlex/repo_apis/PaddleOCR_api/table_rec/model.py +4 -4
  729. paddlex/repo_apis/PaddleOCR_api/table_rec/register.py +2 -3
  730. paddlex/repo_apis/PaddleOCR_api/table_rec/runner.py +3 -3
  731. paddlex/repo_apis/PaddleOCR_api/text_det/__init__.py +1 -1
  732. paddlex/repo_apis/PaddleOCR_api/text_det/config.py +1 -1
  733. paddlex/repo_apis/PaddleOCR_api/text_det/model.py +4 -4
  734. paddlex/repo_apis/PaddleOCR_api/text_det/register.py +20 -3
  735. paddlex/repo_apis/PaddleOCR_api/text_det/runner.py +3 -3
  736. paddlex/repo_apis/PaddleOCR_api/text_rec/__init__.py +1 -1
  737. paddlex/repo_apis/PaddleOCR_api/text_rec/config.py +7 -6
  738. paddlex/repo_apis/PaddleOCR_api/text_rec/model.py +9 -13
  739. paddlex/repo_apis/PaddleOCR_api/text_rec/register.py +20 -3
  740. paddlex/repo_apis/PaddleOCR_api/text_rec/runner.py +2 -3
  741. paddlex/repo_apis/PaddleSeg_api/__init__.py +1 -1
  742. paddlex/repo_apis/PaddleSeg_api/base_seg_config.py +2 -2
  743. paddlex/repo_apis/PaddleSeg_api/seg/__init__.py +1 -1
  744. paddlex/repo_apis/PaddleSeg_api/seg/config.py +3 -6
  745. paddlex/repo_apis/PaddleSeg_api/seg/model.py +6 -6
  746. paddlex/repo_apis/PaddleSeg_api/seg/register.py +2 -3
  747. paddlex/repo_apis/PaddleSeg_api/seg/runner.py +2 -3
  748. paddlex/repo_apis/PaddleTS_api/__init__.py +4 -3
  749. paddlex/repo_apis/PaddleTS_api/ts_ad/__init__.py +1 -1
  750. paddlex/repo_apis/PaddleTS_api/ts_ad/config.py +5 -6
  751. paddlex/repo_apis/PaddleTS_api/ts_ad/register.py +2 -2
  752. paddlex/repo_apis/PaddleTS_api/ts_ad/runner.py +2 -2
  753. paddlex/repo_apis/PaddleTS_api/ts_base/__init__.py +1 -1
  754. paddlex/repo_apis/PaddleTS_api/ts_base/config.py +2 -4
  755. paddlex/repo_apis/PaddleTS_api/ts_base/model.py +4 -4
  756. paddlex/repo_apis/PaddleTS_api/ts_base/runner.py +2 -2
  757. paddlex/repo_apis/PaddleTS_api/ts_cls/__init__.py +1 -1
  758. paddlex/repo_apis/PaddleTS_api/ts_cls/config.py +4 -5
  759. paddlex/repo_apis/PaddleTS_api/ts_cls/register.py +2 -2
  760. paddlex/repo_apis/PaddleTS_api/ts_cls/runner.py +2 -2
  761. paddlex/repo_apis/PaddleTS_api/ts_fc/__init__.py +1 -1
  762. paddlex/repo_apis/PaddleTS_api/ts_fc/config.py +6 -7
  763. paddlex/repo_apis/PaddleTS_api/ts_fc/register.py +1 -1
  764. paddlex/repo_apis/PaddleVideo_api/__init__.py +1 -1
  765. paddlex/repo_apis/PaddleVideo_api/config_utils.py +1 -1
  766. paddlex/repo_apis/PaddleVideo_api/video_cls/__init__.py +3 -3
  767. paddlex/repo_apis/PaddleVideo_api/video_cls/config.py +5 -4
  768. paddlex/repo_apis/PaddleVideo_api/video_cls/model.py +4 -4
  769. paddlex/repo_apis/PaddleVideo_api/video_cls/register.py +2 -3
  770. paddlex/repo_apis/PaddleVideo_api/video_cls/runner.py +2 -3
  771. paddlex/repo_apis/PaddleVideo_api/video_det/__init__.py +3 -3
  772. paddlex/repo_apis/PaddleVideo_api/video_det/config.py +5 -4
  773. paddlex/repo_apis/PaddleVideo_api/video_det/model.py +5 -5
  774. paddlex/repo_apis/PaddleVideo_api/video_det/register.py +2 -3
  775. paddlex/repo_apis/PaddleVideo_api/video_det/runner.py +2 -3
  776. paddlex/repo_apis/__init__.py +1 -1
  777. paddlex/repo_apis/base/__init__.py +4 -5
  778. paddlex/repo_apis/base/config.py +3 -4
  779. paddlex/repo_apis/base/model.py +11 -19
  780. paddlex/repo_apis/base/register.py +1 -1
  781. paddlex/repo_apis/base/runner.py +11 -12
  782. paddlex/repo_apis/base/utils/__init__.py +1 -1
  783. paddlex/repo_apis/base/utils/arg.py +1 -1
  784. paddlex/repo_apis/base/utils/subprocess.py +1 -1
  785. paddlex/repo_manager/__init__.py +2 -9
  786. paddlex/repo_manager/core.py +12 -30
  787. paddlex/repo_manager/meta.py +41 -31
  788. paddlex/repo_manager/repo.py +171 -161
  789. paddlex/repo_manager/utils.py +13 -224
  790. paddlex/utils/__init__.py +1 -1
  791. paddlex/utils/cache.py +8 -10
  792. paddlex/utils/config.py +6 -5
  793. paddlex/utils/{custom_device_whitelist.py → custom_device_list.py} +53 -199
  794. paddlex/utils/deps.py +249 -0
  795. paddlex/utils/device.py +87 -36
  796. paddlex/utils/download.py +4 -4
  797. paddlex/utils/env.py +37 -7
  798. paddlex/utils/errors/__init__.py +1 -1
  799. paddlex/utils/errors/dataset_checker.py +1 -1
  800. paddlex/utils/errors/others.py +2 -16
  801. paddlex/utils/file_interface.py +4 -5
  802. paddlex/utils/flags.py +17 -12
  803. paddlex/utils/fonts/__init__.py +36 -5
  804. paddlex/utils/func_register.py +1 -1
  805. paddlex/utils/install.py +87 -0
  806. paddlex/utils/interactive_get_pipeline.py +3 -3
  807. paddlex/utils/lazy_loader.py +3 -3
  808. paddlex/utils/logging.py +10 -1
  809. paddlex/utils/misc.py +6 -6
  810. paddlex/utils/pipeline_arguments.py +15 -7
  811. paddlex/utils/result_saver.py +4 -5
  812. paddlex/utils/subclass_register.py +2 -4
  813. paddlex/version.py +2 -1
  814. {paddlex-3.0.0rc0.dist-info → paddlex-3.0.1.dist-info}/METADATA +237 -102
  815. paddlex-3.0.1.dist-info/RECORD +1095 -0
  816. {paddlex-3.0.0rc0.dist-info → paddlex-3.0.1.dist-info}/WHEEL +1 -1
  817. paddlex/inference/models/base/predictor/basic_predictor.py +0 -139
  818. paddlex/paddle2onnx_requirements.txt +0 -1
  819. paddlex/repo_manager/requirements.txt +0 -21
  820. paddlex/serving_requirements.txt +0 -9
  821. paddlex-3.0.0rc0.dist-info/RECORD +0 -1015
  822. {paddlex-3.0.0rc0.dist-info → paddlex-3.0.1.dist-info}/entry_points.txt +0 -0
  823. {paddlex-3.0.0rc0.dist-info → paddlex-3.0.1.dist-info/licenses}/LICENSE +0 -0
  824. {paddlex-3.0.0rc0.dist-info → paddlex-3.0.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2162 @@
1
+ # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import copy
16
+ import inspect
17
+ from typing import Optional, Union
18
+
19
+ import paddle
20
+ import paddle.distributed as dist
21
+ import paddle.nn as nn
22
+ import paddle.nn.functional as F
23
+ from paddle import Tensor
24
+ from paddle.common_ops_import import convert_dtype
25
+ from paddle.utils import map_structure
26
+
27
+ from ......utils import logging
28
+ from ..transformers.model_outputs import ModelOutput
29
+ from .configuration_utils import DEFAULT_MAX_NEW_TOKENS, GenerationConfig
30
+ from .logits_process import (
31
+ ForcedBOSTokenLogitsProcessor,
32
+ ForcedEOSTokenLogitsProcessor,
33
+ HammingDiversityLogitsProcessor,
34
+ LogitsProcessor,
35
+ LogitsProcessorList,
36
+ MinLengthLogitsProcessor,
37
+ NoRepeatNGramLogitsProcessor,
38
+ RepetitionPenaltyLogitsProcessor,
39
+ TopKProcess,
40
+ TopPProcess,
41
+ )
42
+ from .stopping_criteria import (
43
+ StoppingCriteria,
44
+ StoppingCriteriaList,
45
+ validate_stopping_criteria,
46
+ )
47
+
48
+ __all__ = [
49
+ "GenerationMixin",
50
+ "BeamSearchScorer",
51
+ "BeamHypotheses",
52
+ "LogitsProcessorList",
53
+ "LogitsProcessor",
54
+ "MinLengthLogitsProcessor",
55
+ "RepetitionPenaltyLogitsProcessor",
56
+ "TopKProcess",
57
+ "TopPProcess",
58
+ "get_unfinished_flag",
59
+ ]
60
+
61
+
62
+ def get_scale_by_dtype(dtype: str = None, return_positive: bool = True) -> float:
63
+ """get scale value by dtype
64
+
65
+ Args:
66
+ dtype (str): the string dtype value; defaults to paddle's default dtype
+ return_positive (bool): if False, return the negated scale value
67
+
68
+ Returns:
69
+ float: the scale value
70
+ """
71
+ if dtype is None:
72
+ dtype = paddle.get_default_dtype()
73
+
74
+ dtype = convert_dtype(dtype)
75
+ scale_value = 1e6
76
+
77
+ # TODO(wj-Mcaf): support int8, int4 dtypes later
78
+ if dtype == "float16":
79
+ scale_value = 1e4
80
+
81
+ if return_positive:
82
+ return scale_value
83
+ return -1 * scale_value
84
+
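+ # Note: this scale is a large-but-finite stand-in for infinity when
+ # building additive attention masks. float16 uses 1e4 because its
+ # largest representable value is about 6.5e4, so 1e6 would overflow.
+ # A minimal usage sketch of the function defined above:
+ #   get_scale_by_dtype("float32")                         # -> 1000000.0
+ #   get_scale_by_dtype("float16", return_positive=False)  # -> -10000.0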
85
+
86
+ def get_unfinished_flag(
87
+ input_ids: Tensor,
88
+ unfinished_flag: Tensor,
89
+ eos_token_id: Union[int, list[int], list[list[int]]],
90
+ ) -> Tensor:
91
+ """get unfinished flag for generation step
92
+
93
+ Args:
94
+ input_ids (Tensor): the input_ids
+ unfinished_flag (Tensor): the current unfinished-flag tensor
+ eos_token_id (Union[int, list[int], list[list[int]]]): the end-of-sentence token id(s), which can be:
+ * a single token id, eg: 10
+ * multiple token ids to stop generation, eg: [10, 10]
+ * several groups of such stop tokens, eg: [[10], [20, 20], [30, 30, 30]]
99
+
100
+ Returns:
101
+ Tensor: the unfinished flag tensor
102
+ """
103
+ if isinstance(eos_token_id, int):
104
+ unfinished_flag = paddle.logical_and(
105
+ unfinished_flag, input_ids[:, -1:] != eos_token_id
106
+ )
107
+ else:
108
+ batch_unfinish_flag = None
109
+ for batch_eos_token_id in eos_token_id:
110
+ if batch_unfinish_flag is None:
111
+ batch_unfinish_flag = ~get_unfinished_flag(
112
+ input_ids, unfinished_flag, batch_eos_token_id
113
+ )
114
+ else:
115
+ batch_unfinish_flag = paddle.logical_or(
116
+ batch_unfinish_flag,
117
+ ~get_unfinished_flag(
118
+ input_ids, unfinished_flag, batch_eos_token_id
119
+ ),
120
+ )
121
+
122
+ unfinished_flag = ~batch_unfinish_flag
123
+ return unfinished_flag
124
+
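+ # A minimal sketch of the single-token case (hypothetical ids), for a
+ # batch of two sequences whose last tokens are 10 and 7:
+ #   ids = paddle.to_tensor([[5, 10], [5, 7]])
+ #   flag = paddle.full([2, 1], True, dtype="bool")
+ #   get_unfinished_flag(ids, flag, 10)   # -> [[False], [True]]
+ # For the list forms, the loop above ORs the per-id "finished" flags
+ # together before negating back to "unfinished".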
125
+
126
+ class BeamHypotheses:
127
+ def __init__(self, num_beams, length_penalty, early_stopping):
128
+ """
129
+ Initialize n-best list of hypotheses.
130
+ """
131
+ self.length_penalty = length_penalty
132
+ self.early_stopping = early_stopping
133
+ self.num_beams = num_beams
134
+ self.beams = []
135
+ self.worst_score = get_scale_by_dtype()
136
+
137
+ def __len__(self):
138
+ """
139
+ Number of hypotheses in the list.
140
+ """
141
+ return len(self.beams)
142
+
143
+ def add(self, hyp, sum_logprobs, origin_len=0):
144
+ """
145
+ Add a new hypothesis to the list.
146
+ """
147
+ score = sum_logprobs / (
148
+ ((hyp.shape[-1] - origin_len + 5) / 6) ** self.length_penalty
149
+ )
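+ # The denominator is the GNMT-style length penalty
+ # ((gen_len + 5) / 6) ** length_penalty: values > 1.0 favor longer
+ # hypotheses, values < 1.0 favor shorter ones, and 0.0 disables
+ # length normalization entirely.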
150
+ if len(self) < self.num_beams or score > self.worst_score:
151
+ self.beams.append((score, hyp))
152
+ if len(self) > self.num_beams:
153
+ sorted_next_scores = sorted(
154
+ [(s, idx) for idx, (s, _) in enumerate(self.beams)]
155
+ )
156
+ del self.beams[sorted_next_scores[0][1]]
157
+ self.worst_score = sorted_next_scores[1][0]
158
+ else:
159
+ self.worst_score = min(score, self.worst_score)
160
+
161
+ def is_done(self, best_sum_logprobs, cur_len, origin_len=0):
162
+ """
163
+ If there are enough hypotheses and none of the hypotheses being
164
+ generated can become better than the worst one in the heap, then we
165
+ are done with this sentence.
166
+ """
167
+ if len(self) < self.num_beams:
168
+ return False
169
+ elif self.early_stopping:
170
+ return True
171
+ else:
172
+ cur_score = (
173
+ best_sum_logprobs
174
+ / ((cur_len - origin_len + 5) / 6) ** self.length_penalty
175
+ )
176
+ ret = self.worst_score >= cur_score
177
+ return ret
178
+
179
+
180
+ class BeamSearchScorer(object):
181
+ """
182
+ Implements standard beam search decoding.
183
+ """
184
+
185
+ def __init__(
186
+ self,
187
+ batch_size,
188
+ max_length,
189
+ num_beams,
190
+ length_penalty=1.0,
191
+ do_early_stopping=False,
192
+ num_beam_hyps_to_keep=1,
193
+ num_beam_groups=1,
194
+ ):
195
+ self.max_length = max_length
196
+ self.num_beams = num_beams
197
+ self.length_penalty = length_penalty
198
+ self.do_early_stopping = do_early_stopping
199
+ self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
200
+ self.num_beam_groups = num_beam_groups
201
+ self.group_size = self.num_beams // self.num_beam_groups
202
+
203
+ self._is_init = False
204
+ self._beam_hyps = [
205
+ BeamHypotheses(
206
+ num_beams=self.num_beams,
207
+ length_penalty=self.length_penalty,
208
+ early_stopping=self.do_early_stopping,
209
+ )
210
+ for _ in range(batch_size)
211
+ ]
212
+ self._done = paddle.to_tensor([0 for _ in range(batch_size)], dtype="int64")
213
+
214
+ if not isinstance(num_beams, int) or num_beams <= 1:
215
+ raise ValueError(
216
+ "`num_beams` has to be an integer strictly greater than 1, but "
217
+ "received {}. For `num_beams` == 1, one should make use of "
218
+ "`greedy_search` instead.".format(num_beams)
219
+ )
220
+
221
+ if (
222
+ not isinstance(num_beam_groups, int)
223
+ or (num_beam_groups > num_beams)
224
+ or (num_beams % num_beam_groups != 0)
225
+ ):
226
+ raise ValueError(
227
+ "`num_beam_groups` has to be an integer smaller or equal than "
228
+ "`num_beams` and `num_beams` has to be divisible by "
229
+ "`num_beam_groups`, but received num_beam_groups={}, num_beams="
230
+ "{}.".format(num_beam_groups, num_beams)
231
+ )
232
+
233
+ @property
234
+ def is_done(self):
235
+ return paddle.min(self._done) == 1
236
+
237
+ def process(
238
+ self,
239
+ input_ids,
240
+ next_scores,
241
+ next_tokens,
242
+ next_indices,
243
+ origin_len=0,
244
+ pad_token_id=None,
245
+ eos_token_id=None,
246
+ ):
247
+ cur_len = input_ids.shape[-1]
248
+ batch_size = len(self._beam_hyps)
249
+ assert batch_size == (input_ids.shape[0] // self.group_size)
250
+
251
+ next_beam_scores = paddle.zeros(
252
+ [batch_size, self.group_size], dtype=next_scores.dtype
253
+ )
254
+ next_beam_tokens = paddle.zeros(
255
+ [batch_size, self.group_size], dtype=next_tokens.dtype
256
+ )
257
+ next_beam_indices = paddle.zeros(
258
+ [batch_size, self.group_size], dtype=next_indices.dtype
259
+ )
260
+
261
+ for batch_idx, beam_hyp in enumerate(self._beam_hyps):
262
+ if self._done[batch_idx] == 1:
263
+ assert (
264
+ len(beam_hyp) >= self.num_beams
265
+ ), "Batch can only be done if at least {} beams have been generated".format(
266
+ self.num_beams
267
+ )
268
+ assert (
269
+ eos_token_id is not None and pad_token_id is not None
270
+ ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
271
+ # pad the batch
272
+ next_beam_scores[batch_idx, :] = 0
273
+ next_beam_tokens[batch_idx, :] = pad_token_id
274
+ next_beam_indices[batch_idx, :] = 0
275
+ continue
276
+
277
+ # next tokens for this sentence
278
+ beam_idx = 0
279
+ for beam_token_rank, (next_token, next_score, next_index) in enumerate(
280
+ zip(
281
+ next_tokens[batch_idx],
282
+ next_scores[batch_idx],
283
+ next_indices[batch_idx],
284
+ )
285
+ ):
286
+ batch_beam_idx = batch_idx * self.group_size + next_index
287
+ # add to generated hypotheses if end of sentence
288
+ if (eos_token_id is not None) and (next_token.item() == eos_token_id):
289
+ # If beam_token does not belong to top num_beams tokens,
290
+ # it should not be added
291
+ is_beam_token_worse_than_top_num_beams = (
292
+ beam_token_rank >= self.group_size
293
+ )
294
+ if is_beam_token_worse_than_top_num_beams:
295
+ continue
296
+ beam_hyp.add(
297
+ input_ids[batch_beam_idx.item()].clone(),
298
+ next_score.item(),
299
+ origin_len,
300
+ )
301
+
302
+ else:
303
+ # add next predicted token since it is not eos_token
304
+ next_beam_scores[batch_idx, beam_idx] = next_score
305
+ next_beam_tokens[batch_idx, beam_idx] = next_token.item()
306
+ next_beam_indices[batch_idx, beam_idx] = batch_beam_idx.item()
307
+ beam_idx += 1
308
+
309
+ # once the beam for next step is full, don't add more tokens to it.
310
+ if beam_idx == self.group_size:
311
+ break
312
+
313
+ if beam_idx < self.group_size:
314
+ raise ValueError(
315
+ "At most {} tokens in `next_tokens[batch_idx]` can be equal "
316
+ "to `eos_token_id: {}`. Make sure `next_tokens[batch_idx]` "
317
+ "are corrected.".format(self.group_size, eos_token_id)
318
+ )
319
+
320
+ # Check if we are done so that we can save a pad step if all(done)
321
+ if beam_hyp.is_done(
322
+ next_scores[batch_idx].max().item(), cur_len, origin_len
323
+ ):
324
+ self._done[batch_idx] = 1
325
+
326
+ return {
327
+ "next_beam_scores": next_beam_scores.reshape([-1]),
328
+ "next_beam_tokens": next_beam_tokens.reshape([-1]),
329
+ "next_beam_indices": next_beam_indices.reshape([-1]),
330
+ }
331
+
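+ # Summary: each process() call scans the candidates proposed for every
+ # batch item (typically 2 * group_size, so enough non-eos beams
+ # survive), routes eos-ended candidates into BeamHypotheses, and
+ # returns exactly group_size surviving (score, token, beam index)
+ # triples per item, flattened to [batch_size * group_size].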
332
+ def finalize(
333
+ self,
334
+ input_ids,
335
+ final_beam_scores,
336
+ final_beam_tokens,
337
+ final_beam_indices,
338
+ origin_len=0,
339
+ pad_token_id=None,
340
+ eos_token_id=None,
341
+ ):
342
+ batch_size = len(self._beam_hyps)
343
+
344
+ # finalize all open beam hypotheses and add to generated hypotheses
345
+ for batch_idx, beam_hyp in enumerate(self._beam_hyps):
346
+ if self._done[batch_idx] == 1:
347
+ continue
348
+
349
+ # all open beam hypotheses are added to the beam hypothesis
350
+ # beam hypothesis class automatically keeps the best beams
351
+ for beam_id in range(self.num_beams):
352
+ batch_beam_idx = batch_idx * self.num_beams + beam_id
353
+ final_score = final_beam_scores[batch_beam_idx].item()
354
+ final_tokens = input_ids[batch_beam_idx]
355
+ beam_hyp.add(final_tokens, final_score, origin_len=origin_len)
356
+
357
+ # select the best hypotheses
358
+ sent_lengths = paddle.zeros(
359
+ [batch_size * self.num_beam_hyps_to_keep], dtype=input_ids.dtype
360
+ )
361
+ best = []
362
+
363
+ # retrieve best hypotheses
364
+ for i, beam_hyp in enumerate(self._beam_hyps):
365
+ sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
366
+ for j in range(self.num_beam_hyps_to_keep):
367
+ best_score, best_hyp = sorted_hyps.pop()
368
+ sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
369
+ best.append([best_hyp, best_score])
370
+
371
+ # prepare for adding eos
372
+ sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
373
+ decoded = paddle.zeros(
374
+ [batch_size * self.num_beam_hyps_to_keep, sent_max_len],
375
+ dtype=input_ids.dtype,
376
+ )
377
+ # shorter batches are padded if needed
378
+ if sent_lengths.min().item() != sent_lengths.max().item():
379
+ assert pad_token_id is not None, "`pad_token_id` has to be defined"
380
+ decoded[:, :] = pad_token_id
381
+ decoded_score = paddle.zeros([batch_size * self.num_beam_hyps_to_keep, 1])
382
+
383
+ # fill with hypotheses and eos_token_id if the latter fits in
384
+ for i, (hypo, score) in enumerate(best):
385
+ decoded[i, : sent_lengths[i].item()] = hypo.cpu().numpy()
386
+ decoded_score[i] = score
387
+ if sent_lengths[i] < self.max_length:
388
+ decoded[i, sent_lengths[i].item()] = eos_token_id
389
+ return decoded, decoded_score
390
+
391
+
392
+ class GenerationMixin(object):
393
+ r"""
394
+ This class implements the interface for the generation task.
395
+
396
+ It's used as the base class of `paddlenlp.transformers.PretrainedModel
397
+ <https://paddlenlp.readthedocs.io/zh/latest/source/paddlenlp.transformers.model_utils.html>`__.
398
+ """
399
+
400
+ # enable `to_static` method for CausalLM Model
401
+ enable_to_static_method = False
402
+
403
+ @staticmethod
404
+ def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
405
+ batch_size = 1
406
+ if bos_token_id is None:
407
+ raise ValueError(
408
+ "`bos_token_id` should be defined when no " "`input_ids` are provided."
409
+ )
410
+ if encoder_output is not None:
411
+ batch_size = encoder_output.shape[0]
412
+ return paddle.ones([batch_size, 1], dtype="int64") * bos_token_id
413
+
414
+ @staticmethod
415
+ def prepare_attention_mask_for_generation(input_ids, pad_token_id, eos_token_id):
416
+ is_pad_token_in_inputs_ids = (pad_token_id is not None) and paddle.any(
417
+ input_ids == pad_token_id
418
+ ).item()
419
+ is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
420
+ (eos_token_id is not None) and (pad_token_id != eos_token_id)
421
+ )
422
+ if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
423
+ attention_mask = (input_ids == pad_token_id).astype(
424
+ paddle.get_default_dtype()
425
+ ) * get_scale_by_dtype(return_positive=False)
426
+ else:
427
+ attention_mask = paddle.zeros_like(
428
+ input_ids, dtype=paddle.get_default_dtype()
429
+ )
430
+ return paddle.unsqueeze(attention_mask, axis=[1, 2])
431
+
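+ # The returned mask is additive rather than boolean: pad positions hold
+ # a large negative value (see get_scale_by_dtype) so they vanish after
+ # softmax, and the [batch_size, 1, 1, seq_len] shape broadcasts over
+ # attention heads and query positions.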
432
+ @staticmethod
433
+ def prepare_seq_len_for_generation(input_ids, pad_token_id, eos_token_id):
434
+ is_pad_token_in_inputs_ids = (pad_token_id is not None) and paddle.any(
435
+ input_ids == pad_token_id
436
+ ).item()
437
+ is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
438
+ (eos_token_id is not None) and (pad_token_id != eos_token_id)
439
+ )
440
+ if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
441
+ seq_len = paddle.sum(input_ids != pad_token_id, axis=1).unsqueeze(-1)
442
+ else:
443
+ seq_len = paddle.full(
444
+ (input_ids.shape[0], 1), input_ids.shape[1], dtype="int64"
445
+ )
446
+ return seq_len
447
+
448
+ def get_logits_processor(
449
+ self,
450
+ min_length=None,
451
+ max_length=None,
452
+ eos_token_id=None,
453
+ forced_bos_token_id=None,
454
+ forced_eos_token_id=None,
455
+ num_beams=1,
456
+ num_beam_groups=1,
457
+ diversity_rate=0.0,
458
+ repetition_penalty=None,
459
+ no_repeat_ngram_size=None,
460
+ logits_processors=None,
461
+ ):
462
+ processors = LogitsProcessorList()
463
+
464
+ if min_length is not None and eos_token_id is not None and min_length > -1:
465
+ processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
466
+ if num_beam_groups > 1 and diversity_rate > 0.0:
467
+ processors.append(
468
+ HammingDiversityLogitsProcessor(
469
+ diversity_rate=diversity_rate,
470
+ num_beams=num_beams,
471
+ num_beam_groups=num_beam_groups,
472
+ )
473
+ )
474
+ if repetition_penalty is not None and repetition_penalty != 1.0:
475
+ processors.append(
476
+ RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
477
+ )
478
+ if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
479
+ processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
480
+ if forced_bos_token_id is not None:
481
+ processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
482
+ if forced_eos_token_id is not None:
483
+ processors.append(
484
+ ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)
485
+ )
486
+ # TODO
487
+ # Add more pre_processing for distribution
488
+
489
+ if logits_processors is not None:
490
+ custom_processors = LogitsProcessorList()
491
+ custom_processors_type = [type(lp) for lp in logits_processors]
492
+
493
+ for processor in processors:
494
+ if type(processor) not in custom_processors_type:
495
+ custom_processors.append(processor)
496
+ custom_processors.extend(logits_processors)
497
+
498
+ return custom_processors
499
+ else:
500
+ return processors
501
+
502
+ @staticmethod
503
+ def expand_inputs_for_generation(
504
+ input_ids, expand_size, attention_mask=None, **model_kwargs
505
+ ):
506
+
507
+ index = paddle.tile(
508
+ paddle.arange(input_ids.shape[0], dtype="int64").unsqueeze(-1),
509
+ [1, expand_size],
510
+ ).reshape([-1])
511
+
512
+ input_ids = paddle.gather(input_ids, index)
513
+
514
+ if attention_mask is not None:
515
+ model_kwargs["attention_mask"] = paddle.gather(attention_mask, index)
516
+
517
+ if (
518
+ "token_type_ids" in model_kwargs
519
+ and model_kwargs["token_type_ids"] is not None
520
+ ):
521
+ token_type_ids = model_kwargs["token_type_ids"]
522
+ model_kwargs["token_type_ids"] = paddle.gather(token_type_ids, index)
523
+
524
+ if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None:
525
+ position_ids = model_kwargs["position_ids"]
526
+ model_kwargs["position_ids"] = paddle.gather(position_ids, index)
527
+
528
+ if "seq_len" in model_kwargs and model_kwargs["seq_len"] is not None:
529
+ seq_len = model_kwargs["seq_len"]
530
+ model_kwargs["seq_len"] = paddle.gather(seq_len, index)
531
+
532
+ if (
533
+ "encoder_output" in model_kwargs
534
+ and model_kwargs["encoder_output"] is not None
535
+ ):
536
+ encoder_output = model_kwargs["encoder_output"]
537
+ model_kwargs["encoder_output"] = paddle.gather(encoder_output, index)
538
+
539
+ if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None:
540
+ role_ids = model_kwargs["role_ids"]
541
+ model_kwargs["role_ids"] = paddle.gather(role_ids, index)
542
+
543
+ return input_ids, model_kwargs
544
+
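+ # A minimal sketch of the interleaved expansion (hypothetical shapes):
+ # with batch_size=2 and expand_size=3 the gather index is
+ # [0, 0, 0, 1, 1, 1], so every tensor handled above is repeated
+ # per-sample while rows belonging to one sample stay adjacent.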
545
+ @staticmethod
546
+ def update_model_kwargs_for_generation(
547
+ outputs, model_kwargs, is_encoder_decoder=False
548
+ ):
549
+ # Update the model inputs during generation.
550
+ # Note that if `token_type_ids` and `attention_mask` in `model_kwargs`
+ # contain pad values, the vectors updated by this method may differ
+ # from what is expected; override this method in that case.
554
+
555
+ # update cache
556
+ if (
557
+ isinstance(outputs, tuple)
558
+ and len(outputs) > 1
559
+ and not isinstance(outputs[1], paddle.Tensor)
560
+ ):
561
+ model_kwargs["cache"] = outputs[1]
562
+ model_kwargs["past_key_values"] = outputs[1]
563
+
564
+ if isinstance(outputs, ModelOutput) and "past_key_values" in outputs:
565
+ model_kwargs["cache"] = outputs.past_key_values
566
+ model_kwargs["past_key_values"] = outputs.past_key_values
567
+
568
+ # update token_type_ids with last value
569
+ if (
570
+ "token_type_ids" in model_kwargs
571
+ and model_kwargs["token_type_ids"] is not None
572
+ ):
573
+ token_type_ids = model_kwargs["token_type_ids"]
574
+ model_kwargs["token_type_ids"] = paddle.concat(
575
+ [token_type_ids, token_type_ids[:, -1:]], axis=-1
576
+ )
577
+
578
+ # update position_ids
579
+ if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None:
580
+ position_ids = model_kwargs["position_ids"]
581
+ model_kwargs["position_ids"] = paddle.concat(
582
+ [position_ids, position_ids[..., -1:] + 1], axis=-1
583
+ )
584
+
585
+ # update attention_mask
586
+ if not is_encoder_decoder and "attention_mask" in model_kwargs:
587
+ attention_mask = model_kwargs["attention_mask"]
588
+ # nn.Pad2D doesn't support the `bool` data type
589
+ if convert_dtype(attention_mask.dtype) == "bool":
590
+ attention_mask = paddle.cast(attention_mask, "int64")
591
+ if len(attention_mask.shape) == 4:
592
+ cur_device = paddle.get_device()
593
+ if cur_device.split(":")[0] == "npu":
594
+ attention_mask = nn.Pad2D([0, 0, 0, 1], mode="constant")(
595
+ attention_mask
596
+ )
597
+ attention_mask = nn.Pad2D([0, 1, 0, 0], value=0)(attention_mask)
598
+ else:
599
+ attention_mask = nn.Pad2D([0, 0, 0, 1], mode="replicate")(
600
+ attention_mask
601
+ )
602
+ attention_mask = nn.Pad2D(
603
+ [0, 1, 0, 0], value=get_scale_by_dtype(return_positive=False)
604
+ )(attention_mask)
605
+
606
+ dtype = convert_dtype(attention_mask.dtype)
607
+ if "int" in dtype:
608
+ attention_mask[:, :, -1, -1] = 1
609
+ elif "float" in dtype:
610
+ attention_mask[:, :, -1, -1] = 0.0
611
+ else:
612
+ raise ValueError(
613
+ "The data type of input `attention_mask` must "
614
+ "be bool, int or float"
615
+ )
616
+ else:
617
+ attention_mask = paddle.concat(
618
+ [
619
+ attention_mask,
620
+ paddle.ones([attention_mask.shape[0], 1], dtype="int64"),
621
+ ],
622
+ axis=-1,
623
+ )
624
+ model_kwargs["attention_mask"] = attention_mask
625
+
626
+ # update role_ids
627
+ if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None:
628
+ role_ids = model_kwargs["role_ids"]
629
+ model_kwargs["role_ids"] = paddle.concat(
630
+ [role_ids, role_ids[:, -1:]], axis=-1
631
+ )
632
+
633
+ return model_kwargs
634
+
635
+ @staticmethod
636
+ def update_scores_for_generation(scores, next_scores, length, unfinished_flag):
637
+ # update scores
638
+
639
+ unfinished_scores = (
640
+ scores * paddle.to_tensor(length, dtype=scores.dtype) + next_scores
641
+ ) / (paddle.to_tensor(length, dtype=scores.dtype) + 1)
642
+ scores = paddle.where(unfinished_flag, unfinished_scores, scores)
643
+ return scores
644
+
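+ # For unfinished sequences this maintains a running mean of per-token
+ # scores: new = (old * length + next_scores) / (length + 1); finished
+ # sequences keep their previous score unchanged.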
645
+ def prepare_encoder_decoder_kwargs_for_generation(self, input_ids, model_kwargs):
646
+ if "encoder_output" not in model_kwargs:
647
+ # retrieve encoder hidden states
648
+ encoder = self.get_encoder()
649
+ encoder_kwargs = {
650
+ argument: value
651
+ for argument, value in model_kwargs.items()
652
+ if not (
653
+ argument.startswith("decoder_")
654
+ or argument.startswith("cross_attn")
655
+ or argument == "use_cache"
656
+ )
657
+ }
658
+ # Prefer inputs_embeds over input_ids when it is provided
659
+ if "inputs_embeds" in encoder_kwargs:
660
+ model_kwargs["encoder_output"] = encoder(**encoder_kwargs)
661
+ else:
662
+ model_kwargs["encoder_output"] = encoder(
663
+ input_ids=input_ids, **encoder_kwargs
664
+ )
665
+ return model_kwargs
666
+
667
+ def prepare_decoder_input_ids_for_generation(
668
+ self, input_ids, decoder_start_token_id=None, bos_token_id=None
669
+ ):
670
+ decoder_start_token_id = (
671
+ decoder_start_token_id
672
+ if decoder_start_token_id is not None
673
+ else self.config.decoder_start_token_id
674
+ )
675
+ decoder_start_token_id = (
676
+ decoder_start_token_id
677
+ if decoder_start_token_id is not None
678
+ else bos_token_id
679
+ )
680
+
681
+ decoder_input_ids = (
682
+ paddle.ones([input_ids.shape[0], 1], dtype="int64") * decoder_start_token_id
683
+ )
684
+
685
+ return decoder_input_ids
686
+
687
+ def get_decoder_start_token_id(
688
+ self, decoder_start_token_id=None, bos_token_id=None
689
+ ):
690
+ decoder_start_token_id = (
691
+ decoder_start_token_id
692
+ if decoder_start_token_id is not None
693
+ else self.config.decoder_start_token_id
694
+ )
695
+ bos_token_id = (
696
+ bos_token_id if bos_token_id is not None else self.config.bos_token_id
697
+ )
698
+
699
+ if decoder_start_token_id is not None:
700
+ return decoder_start_token_id
701
+ elif self.config.decoder_start_token_id is not None:
702
+ return self.config.decoder_start_token_id
703
+ elif bos_token_id is not None:
704
+ return bos_token_id
705
+ elif self.config.bos_token_id is not None:
706
+ return self.config.bos_token_id
707
+ raise ValueError(
708
+ "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
709
+ )
710
+
711
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
712
+ # Implement in subclasses for custom behavior to prepare inputs in the
713
+ # generate method.
714
+
715
+ return {"input_ids": input_ids}
716
+
717
+ def adjust_logits_during_generation(self, logits):
718
+ # Implement in subclasses for custom behavior to adjust the logits in
719
+ # the generate method.
720
+
721
+ return logits
722
+
723
+ def prepare_fast_entry(self, kwargs):
724
+ return False
725
+
726
+ def _convert_to_fast(self, kwargs):
727
+ # try general convert
728
+ pass
729
+
730
+ def _build_fast(self, kwargs):
731
+ self._fast_entry = False
732
+ if kwargs["num_beam_groups"] != 1:
733
+ # group_beam_search is not supported yet in the fast version
734
+ raise AttributeError(
735
+ "'num_beam_groups != 1' is not supported yet in the fast version"
736
+ )
737
+ if (
738
+ paddle.get_default_dtype() == "float16"
739
+ and kwargs["use_fp16_decoding"] is False
740
+ ):
741
+ logging.info(
742
+ "Since the default dtype is float16, float16 would be used "
743
+ "though 'use_fp16_decoding=False'."
744
+ )
745
+ kwargs["use_fp16_decoding"] = True
746
+ self.prepare_fast_entry(kwargs)
747
+
748
+ def set_pad_token_id(self, pad_token_id, eos_token_id):
749
+ if pad_token_id is None and eos_token_id is not None:
750
+ logging.warning(
751
+ "Setting `pad_token_id` to `eos_token_id`:{} for "
752
+ "open-end generation.".format(eos_token_id)
753
+ )
754
+ if isinstance(eos_token_id, list):
755
+ pad_token_id = eos_token_id[0]
756
+ else:
757
+ pad_token_id = eos_token_id
758
+ return pad_token_id
759
+
760
+ @paddle.no_grad()
761
+ def generate(
762
+ self,
763
+ input_ids: paddle.Tensor = None,
764
+ generation_config: GenerationConfig = None,
765
+ stopping_criteria: StoppingCriteria = None,
766
+ streamer=None,
767
+ synced_gpus: Optional[bool] = None,
768
+ **kwargs,
769
+ ):
770
+ r"""
771
+ The interface for the generation task. This method generates sequences
+ with a chosen decoding strategy. Currently, three decoding strategies
+ are supported: "greedy_search", "sampling" and "beam_search".
774
+
775
+ Args:
776
+ input_ids (Tensor, optional): The input sequence ids for the
777
+ generation. It is a Tensor with shape [batch_size, sequence_length].
778
+ The data type should be int32 or int64. Defaults to None, in which
+ case it is initialized as a Tensor with shape [1, 1], filled
780
+ with the value `bos_token_id`.
781
+ generation_config (`~generation.GenerationConfig`, *optional*):
782
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
783
+ passed to generate matching the attributes of `generation_config` will override them. If
784
+ `generation_config` is not provided, the default will be used, which has the following loading
785
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
786
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
787
+ default values, whose documentation should be checked to parameterize generation.
788
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
789
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
790
+ generation config. If a stopping criterion is passed that was already created from the arguments or a
+ generation config, an error is thrown. This feature is intended for advanced users.
792
+ streamer (`~streamer.BaseStreamer`, *optional*):
793
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
794
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
795
+ synced_gpus (`bool`, *optional*):
796
+ Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
797
+ `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
798
+ generating before other GPUs. Otherwise it'll be set to `False`.
799
+ kwargs (dict): It can be used to specify additional kwargs
800
+ passed to the model.
801
+
802
+ Returns:
803
+ tuple[Tensor]: It is a tuple containing two elements: ids and scores.
804
+ Each element is a Tensor.
805
+
806
+ With the fields:
807
+
808
+ - ids (Tensor):
809
+ The ids of the generated sequences. It is a Tensor with shape
810
+ [batch_size * num_return_sequences, sequence_length]. The data
811
+ type is the same as that of the input `input_ids`.
812
+ - scores (Tensor):
813
+ The scores of the generated sequences. It is a Tensor with shape
814
+ [batch_size * num_return_sequences, 1]. The data type is float32
815
+ or float64, which is the same as the parameters in the model.
816
+
817
+ Example:
818
+ .. code-block::
819
+
820
+ import paddle
821
+ from paddlenlp.transformers import (
822
+ UnifiedTransformerLMHeadModel,
823
+ UnifiedTransformerTokenizer
824
+ )
825
+
826
+ paddle.seed(2)
827
+
828
+ # Initialize the model and tokenizer
829
+ model_name_or_path = 'unified_transformer-12L-cn-luge'
830
+ model = UnifiedTransformerLMHeadModel.from_pretrained(model_name_or_path)
831
+ tokenizer = UnifiedTransformerTokenizer.from_pretrained(model_name_or_path)
832
+
833
+ # Prepare the model inputs.
834
+ history = "早上好,今天空气质量不错。"
835
+ inputs = tokenizer.dialogue_encode(history, task_type='chitchat',
836
+ add_start_token_as_response=True, return_tensors=True)
837
+
838
+ .. code-block::
839
+
840
+ # Generate the sequence by using "greedy_search" strategy
841
+ ids, scores = model.generate(
842
+ **inputs,
843
+ decode_strategy="greedy_search")
844
+ print(ids.shape, scores.shape)
845
+ # [1, 3] [1, 1]
846
+ sequence_ids = ids.cpu().numpy().tolist()[0]
847
+ sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
848
+ response = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
849
+ print(response)
850
+ # 是的
851
+
852
+ .. code-block::
853
+
854
+ # Generate 2 sequences by using "sampling" strategy (top_k=5)
855
+ generation_config = GenerationConfig(
856
+ decode_strategy="sampling",
857
+ top_k=5,
858
+ num_return_sequences=2
859
+ )
860
+ ids, scores = model.generate(
861
+ **inputs,
862
+ generation_config=generation_config,
863
+ )
864
+ print(ids.shape, scores.shape)
865
+ # [2, 7] [2, 1]
866
+ response = []
867
+ for sequence_ids in ids.cpu().numpy().tolist():
868
+ sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
869
+ text = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
870
+ response.append(text)
871
+ print(response)
872
+ # ['天气好,心情也好', '你也是']
873
+
874
+ .. code-block::
875
+
876
+ # Generate 2 sequences by using "beam_search" strategy (num_beams=5)
877
+ generation_config = GenerationConfig(
878
+ decode_strategy="beam_search",
879
+ num_beams=5,
880
+ num_return_sequences=2
881
+ )
882
+ ids, scores = model.generate(
883
+ **inputs,
884
+ generation_config=generation_config,
885
+ )
886
+ print(ids.shape, scores.shape)
887
+ # [2, 3] [2, 1]
888
+ response = []
889
+ for sequence_ids in ids.cpu().numpy().tolist():
890
+ sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
891
+ text = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
892
+ response.append(text)
893
+ print(response)
894
+ # ['是的', '嗯嗯']
895
+ """
896
+ if generation_config is None:
897
+ if (
898
+ self.generation_config is None
899
+ or self.generation_config._from_model_config
900
+ ):
901
+ new_generation_config = GenerationConfig.from_model_config(self.config)
902
+ if new_generation_config != self.generation_config:
903
+ logging.warning(
904
+ "model.generation_config is in conflict with model.config, "
905
+ "model.config is used."
906
+ )
907
+ self.generation_config = new_generation_config
908
+ generation_config = self.generation_config
909
+
910
+ # work on a deep copy so model.generation_config is left unmodified
911
+ generation_config = copy.deepcopy(generation_config)
912
+ model_kwargs = generation_config.update(**kwargs)
913
+
914
+ assert generation_config.decode_strategy in [
915
+ "greedy_search",
916
+ "sampling",
917
+ "beam_search",
918
+ ], "`decode_strategy` must be one of 'greedy_search', 'sampling' or 'beam_search' but received {}.".format(
919
+ generation_config.decode_strategy
920
+ )
921
+
922
+ if getattr(self, "deprecated_warnings", None) is None:
923
+ self.deprecated_warnings = {}
924
+
925
+ use_fast = False
926
+ if "use_faster" in model_kwargs:
927
+ raise ValueError("`use_faster` is deprecated now.")
928
+
929
+ if "use_fast" in model_kwargs:
930
+ raise ValueError("`use_fast` is deprecated now.")
931
+
932
+ bos_token_id = (
933
+ generation_config.bos_token_id
934
+ if generation_config.bos_token_id is not None
935
+ else self.config.bos_token_id
936
+ )
937
+ eos_token_id = (
938
+ generation_config.eos_token_id
939
+ if generation_config.eos_token_id is not None
940
+ else self.config.eos_token_id
941
+ )
942
+ pad_token_id = (
943
+ generation_config.pad_token_id
944
+ if generation_config.pad_token_id is not None
945
+ else self.config.pad_token_id
946
+ )
947
+ forced_bos_token_id = (
948
+ generation_config.forced_bos_token_id
949
+ if generation_config.forced_bos_token_id is not None
950
+ else self.config.forced_bos_token_id
951
+ )
952
+ forced_eos_token_id = (
953
+ generation_config.forced_eos_token_id
954
+ if generation_config.forced_eos_token_id is not None
955
+ else self.config.forced_eos_token_id
956
+ )
957
+ decoder_start_token_id = (
958
+ generation_config.decoder_start_token_id
959
+ if generation_config.decoder_start_token_id is not None
960
+ else self.config.decoder_start_token_id
961
+ )
962
+ no_repeat_ngram_size = (
963
+ generation_config.no_repeat_ngram_size
964
+ if generation_config.no_repeat_ngram_size is not None
965
+ else self.config.no_repeat_ngram_size
966
+ )
967
+
968
+ if getattr(self, "_fast_entry", None) is not False and use_fast:
969
+ fg_args = locals()
970
+ fg_args.pop("self")
971
+ fg_args.pop("__class__", None)
972
+ model_kwargs = fg_args.pop("model_kwargs")
973
+ fg_args.update(model_kwargs)
974
+ try:
975
+ if getattr(self, "_fast_entry", None) is None:
976
+ self._build_fast(fg_args)
977
+ if self._fast_entry:
978
+ output = self._fast_entry(**fg_args)
979
+ if isinstance(output, tuple):
980
+ output_ids, dummy_score = output
981
+ else:
982
+ output_ids = output
983
+ # make the result and the fast result consistent
984
+ dummy_score = None
985
+ if generation_config.decode_strategy == "beam_search":
986
+ output_ids = output_ids.transpose([1, 2, 0])
987
+ output_ids = output_ids[
988
+ :, : generation_config.num_return_sequences, :
989
+ ].reshape([-1, output_ids.shape[-1]])
990
+ if dummy_score is not None:
991
+ dummy_score = dummy_score[
992
+ :, : generation_config.num_return_sequences
993
+ ].flatten()
994
+ else:
995
+ output_ids = output_ids.transpose([1, 0])
996
+ return output_ids, dummy_score
997
+
998
+ except Exception as e:
999
+ fg_args["model_kwargs"] = model_kwargs
1000
+ # TODO
1001
+ # Prevent self._convert_to_fast to throw Exception
1002
+ self._convert_to_fast(fg_args)
1003
+ logging.warning(e)
1004
+ logging.warning(
1005
+ "FastGeneration is not available, "
1006
+ "and the original version would be used instead."
1007
+ )
1008
+
1009
+ # passing input_ids via model_kwargs is also supported
1010
+ if "input_ids" in model_kwargs:
1011
+ _input_ids = model_kwargs.pop("input_ids")
1012
+ if input_ids is None:
1013
+ input_ids = _input_ids
1014
+
1015
+ # params check
1016
+ if input_ids is None and "inputs_embeds" not in model_kwargs:
1017
+ # Init `input_ids` with bos_token_id
1018
+ input_ids = self.prepare_input_ids_for_generation(bos_token_id)
1019
+ elif "inputs_embeds" in model_kwargs:
1020
+ # Add input embeds support
1021
+ input_ids = self.prepare_input_ids_for_generation(
1022
+ bos_token_id, encoder_output=model_kwargs["inputs_embeds"]
1023
+ )
1024
+
1025
+ if model_kwargs.get("attention_mask", None) is None:
1026
+ # TODO
1027
+ # Init `attention_mask` depending on `pad_token_id`
1028
+ model_kwargs["attention_mask"] = self.prepare_attention_mask_for_generation(
1029
+ input_ids, pad_token_id, eos_token_id
1030
+ )
1031
+ self.is_encoder_decoder = self.config.is_encoder_decoder
1032
+
1033
+ if self.is_encoder_decoder:
1034
+ model_kwargs = self.prepare_encoder_decoder_kwargs_for_generation(
1035
+ input_ids, model_kwargs
1036
+ )
1037
+ # set input_ids as decoder_input_ids
1038
+ if "decoder_input_ids" in model_kwargs:
1039
+ input_ids = model_kwargs.pop("decoder_input_ids")
1040
+ else:
1041
+ input_ids = self.prepare_decoder_input_ids_for_generation(
1042
+ input_ids, decoder_start_token_id, bos_token_id
1043
+ )
1044
+ # streamer
1045
+ if streamer is not None:
1046
+ # streamer does not support the beam_search strategy yet
1047
+ if (
1048
+ generation_config.decode_strategy == "beam_search"
1049
+ or generation_config.num_beams > 1
1050
+ ):
1051
+ raise ValueError(
1052
+ "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
1053
+ )
1054
+
1055
+ pad_token_id = self.set_pad_token_id(pad_token_id, eos_token_id)
1056
+
1057
+ if (
1058
+ generation_config.max_length != 0
1059
+ and generation_config.max_new_tokens == DEFAULT_MAX_NEW_TOKENS
1060
+ ):
1061
+ logging.warning(
1062
+ "`max_length` will be deprecated in future releases, use `max_new_tokens` instead."
1063
+ )
1064
+ generation_config.max_new_tokens = generation_config.max_length
1065
+
1066
+ if generation_config.min_length != 0 and generation_config.min_new_tokens == 0:
1067
+ logging.warning(
1068
+ "`min_length` will be deprecated in future releases, use `min_new_tokens` instead."
1069
+ )
1070
+ generation_config.min_new_tokens = generation_config.min_length
1071
+
1072
+ max_length = generation_config.max_new_tokens
1073
+ min_length = generation_config.min_new_tokens
1074
+
1075
+ input_len = input_ids.shape[-1]
1076
+ min_len = input_len + min_length
1077
+ max_len = input_len + max_length
1078
+
1079
+ logits_processors = self.get_logits_processor(
1080
+ min_length=min_len if min_length > 0 else None,
1081
+ max_length=max_len,
1082
+ eos_token_id=eos_token_id,
1083
+ forced_bos_token_id=forced_bos_token_id,
1084
+ forced_eos_token_id=forced_eos_token_id,
1085
+ num_beams=generation_config.num_beams,
1086
+ num_beam_groups=generation_config.num_beam_groups,
1087
+ diversity_rate=generation_config.diversity_rate,
1088
+ repetition_penalty=generation_config.repetition_penalty,
1089
+ no_repeat_ngram_size=generation_config.no_repeat_ngram_size,
1090
+ logits_processors=(
1091
+ model_kwargs["logits_processors"]
1092
+ if "logits_processors" in model_kwargs
1093
+ and isinstance(model_kwargs["logits_processors"], LogitsProcessorList)
1094
+ else None
1095
+ ),
1096
+ )
1097
+ if "logits_processors" in model_kwargs:
1098
+ model_kwargs.pop("logits_processors")
1099
+
1100
+ stopping_criteria = (
1101
+ stopping_criteria
1102
+ if stopping_criteria is not None
1103
+ else StoppingCriteriaList()
1104
+ )
1105
+
1106
+ if generation_config.decode_strategy == "greedy_search":
1107
+ if generation_config.num_return_sequences > 1:
1108
+ raise ValueError(
1109
+ "`num_return_sequences` has to be 1, but is {} "
1110
+ "when doing greedy search.".format(
1111
+ generation_config.num_return_sequences
1112
+ )
1113
+ )
1114
+ return self.greedy_search(
1115
+ input_ids,
1116
+ logits_processors,
1117
+ max_len,
1118
+ pad_token_id,
1119
+ eos_token_id,
1120
+ stopping_criteria=stopping_criteria,
1121
+ streamer=streamer,
1122
+ fast_ptq_sampling=generation_config.fast_ptq_sampling,
1123
+ trunc_input=generation_config.trunc_input,
1124
+ synced_gpus=synced_gpus,
1125
+ **model_kwargs,
1126
+ )
1127
+
1128
+ elif generation_config.decode_strategy == "sampling":
1129
+ if generation_config.num_return_sequences > 1:
1130
+ input_ids, model_kwargs = self.expand_inputs_for_generation(
1131
+ input_ids,
1132
+ expand_size=generation_config.num_return_sequences,
1133
+ **model_kwargs,
1134
+ )
1135
+
1136
+ return self.sample(
1137
+ input_ids,
1138
+ logits_processors,
1139
+ max_len,
1140
+ pad_token_id,
1141
+ eos_token_id,
1142
+ generation_config.top_k,
1143
+ generation_config.top_p,
1144
+ generation_config.temperature,
1145
+ stopping_criteria=stopping_criteria,
1146
+ streamer=streamer,
1147
+ fast_ptq_sampling=generation_config.fast_ptq_sampling,
1148
+ trunc_input=generation_config.trunc_input,
1149
+ synced_gpus=synced_gpus,
1150
+ **model_kwargs,
1151
+ )
1152
+
1153
+ elif generation_config.decode_strategy == "beam_search":
1154
+ batch_size = input_ids.shape[0]
1155
+ if generation_config.num_return_sequences > generation_config.num_beams:
1156
+ raise ValueError(
1157
+ "`num_return_sequences` has to be smaller or equal to "
1158
+ "`num_beams`. But received `num_return_sequences` is {}, "
1159
+ "`num_beams` is {}".format(
1160
+ generation_config.num_return_sequences,
1161
+ generation_config.num_beams,
1162
+ )
1163
+ )
1164
+ if generation_config.num_beams <= 1:
1165
+ raise ValueError(
1166
+ "`num_beams` has to be bigger than 1. But received "
1167
+ "`num_beams` is {}. If `num_beams` is 1, `decode_strategy` "
1168
+ "should be 'greedy_search'".format(generation_config.num_beams)
1169
+ )
1170
+ if generation_config.num_beam_groups > 1:
1171
+ diverse_beam_scorer = BeamSearchScorer(
1172
+ batch_size=batch_size,
1173
+ max_length=max_len,
1174
+ num_beams=generation_config.num_beams,
1175
+ length_penalty=generation_config.length_penalty,
1176
+ do_early_stopping=generation_config.early_stopping,
1177
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
1178
+ num_beam_groups=generation_config.num_beam_groups,
1179
+ )
1180
+
1181
+ # interleave with `num_beams`
1182
+ input_ids, model_kwargs = self.expand_inputs_for_generation(
1183
+ input_ids, expand_size=generation_config.num_beams, **model_kwargs
1184
+ )
1185
+
1186
+ return self.group_beam_search(
1187
+ input_ids,
1188
+ diverse_beam_scorer,
1189
+ logits_processors,
1190
+ max_len,
1191
+ pad_token_id,
1192
+ eos_token_id,
1193
+ stopping_criteria=stopping_criteria,
1194
+ fast_ptq_sampling=generation_config.fast_ptq_sampling,
1195
+ trunc_input=generation_config.trunc_input,
1196
+ synced_gpus=synced_gpus,
1197
+ **model_kwargs,
1198
+ )
1199
+ else:
1200
+ beam_scorer = BeamSearchScorer(
1201
+ batch_size=batch_size,
1202
+ max_length=max_len,
1203
+ num_beams=generation_config.num_beams,
1204
+ length_penalty=generation_config.length_penalty,
1205
+ do_early_stopping=generation_config.early_stopping,
1206
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
1207
+ )
1208
+
1209
+ input_ids, model_kwargs = self.expand_inputs_for_generation(
1210
+ input_ids, expand_size=generation_config.num_beams, **model_kwargs
1211
+ )
1212
+
1213
+ return self.beam_search(
1214
+ input_ids,
1215
+ beam_scorer,
1216
+ logits_processors,
1217
+ max_len,
1218
+ generation_config.diversity_rate,
1219
+ pad_token_id,
1220
+ eos_token_id,
1221
+ stopping_criteria=stopping_criteria,
1222
+ fast_ptq_sampling=generation_config.fast_ptq_sampling,
1223
+ trunc_input=generation_config.trunc_input,
1224
+ synced_gpus=synced_gpus,
1225
+ **model_kwargs,
1226
+ )
1227
+
1228
+ def greedy_search(
1229
+ self,
1230
+ input_ids,
1231
+ logits_processors,
1232
+ max_length,
1233
+ pad_token_id,
1234
+ eos_token_id,
1235
+ stopping_criteria=None,
1236
+ streamer=None,
1237
+ fast_ptq_sampling=False,
1238
+ trunc_input=True,
1239
+ synced_gpus=False,
1240
+ **model_kwargs,
1241
+ ):
1242
+ model_kwargs["use_cache"] = model_kwargs.get("use_cache", True)
1243
+ logits_processors = (
1244
+ logits_processors
1245
+ if logits_processors is not None
1246
+ else LogitsProcessorList()
1247
+ )
1248
+
1249
+ # max_length will be converted into a MaxLengthCriteria
1250
+ stopping_criteria = (
1251
+ stopping_criteria
1252
+ if stopping_criteria is not None
1253
+ else StoppingCriteriaList()
1254
+ )
1255
+ if max_length is not None:
1256
+ # logging.warning(
1257
+ # "`max_length` is deprecated in this function, use"
1258
+ # " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead."
1259
+ # )
1260
+ stopping_criteria = validate_stopping_criteria(
1261
+ stopping_criteria, max_length
1262
+ )
1263
+
1264
+ batch_size, cur_len = input_ids.shape
1265
+ origin_len = cur_len
1266
+ unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool")
1267
+ scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype())
1268
+ generate_end = False
1269
+ while True:
1270
+ if synced_gpus:
1271
+ # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
1272
+ # The following logic allows an early break if all peers finished generating their sequence
1273
+ this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
1274
+ # send 0.0 if we finished, 1.0 otherwise
1275
+ dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
1276
+ # did all peers finish? the reduced sum will be 0.0 then
1277
+ if this_peer_finished_flag.item() == 0.0:
1278
+ break
1279
+
1280
+ # prepare model inputs & get model output
1281
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1282
+
1283
+ outputs = self(**model_inputs)
1284
+
1285
+ if synced_gpus and generate_end:
1286
+ continue # don't waste resources running the code we don't need
1287
+
1288
+ if isinstance(outputs, tuple):
1289
+ logits = outputs[0]
1290
+ elif isinstance(outputs, ModelOutput):
1291
+ logits = outputs.logits
1292
+ else:
1293
+ logits = outputs
1294
+
1295
+ # [batch_size, vocab_size]
1296
+ next_token_logits = logits[:, -1, :]
1297
+
1298
+ # pre-process distribution
1299
+ next_token_logits = self.adjust_logits_during_generation(next_token_logits)
1300
+ probs = logits_processors(input_ids, next_token_logits)
1301
+ # greedy
1302
+ next_tokens = paddle.argmax(probs, axis=-1).unsqueeze(-1)
1303
+ next_scores = paddle.index_sample(probs, next_tokens)
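+ # Greedy step: argmax picks the highest-scoring token per row, and
+ # index_sample reads that token's value out of the processed logits
+ # (`probs` here holds processed logits, not normalized probabilities)
+ # to feed the running score update below.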
1304
+
1305
+ if eos_token_id is not None:
1306
+ next_tokens = paddle.where(
1307
+ unfinished_flag,
1308
+ next_tokens,
1309
+ paddle.full_like(next_tokens, pad_token_id),
1310
+ )
1311
+
1312
+ scores = self.update_scores_for_generation(
1313
+ scores, next_scores, cur_len - origin_len, unfinished_flag
1314
+ )
1315
+ cur_len += 1
1316
+
1317
+ input_ids = paddle.concat([input_ids, next_tokens], axis=1)
1318
+ if streamer is not None:
1319
+ if self.config.tensor_parallel_rank == 0:
1320
+ streamer.put(next_tokens.cpu())
1321
+
1322
+ if stopping_criteria(input_ids, scores):
1323
+ generate_end = True
1324
+
1325
+ if eos_token_id is not None:
1326
+ unfinished_flag = get_unfinished_flag(
1327
+ input_ids, unfinished_flag, eos_token_id
1328
+ )
1329
+ if not paddle.any(unfinished_flag):
1330
+ generate_end = True
1331
+
1332
+ # Stop when there is a </s> in all sentences
1333
+ if generate_end and not synced_gpus:
1334
+ break
1335
+
1336
+ model_kwargs = self.update_model_kwargs_for_generation(
1337
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
1338
+ )
1339
+ if fast_ptq_sampling:
1340
+ break
1341
+
1342
+ if streamer is not None:
1343
+ streamer.end()
1344
+
1345
+ return input_ids[:, origin_len:] if trunc_input else input_ids, scores
1346
+
1347
+ def sample(
1348
+ self,
1349
+ input_ids,
1350
+ logits_processors,
1351
+ max_length,
1352
+ pad_token_id,
1353
+ eos_token_id,
1354
+ top_k=None,
1355
+ top_p=None,
1356
+ temperature=None,
1357
+ min_tokens_to_keep=1,
1358
+ stopping_criteria=None,
1359
+ streamer=None,
1360
+ fast_ptq_sampling=False,
1361
+ trunc_input=True,
1362
+ synced_gpus=False,
1363
+ **model_kwargs,
1364
+ ):
1365
+ model_kwargs["use_cache"] = model_kwargs.get("use_cache", True)
1366
+
1367
+ logits_processors = (
1368
+ logits_processors
1369
+ if logits_processors is not None
1370
+ else LogitsProcessorList()
1371
+ )
1372
+
1373
+ # max_length will be converted into a MaxLengthCriteria
1374
+ stopping_criteria = (
1375
+ stopping_criteria
1376
+ if stopping_criteria is not None
1377
+ else StoppingCriteriaList()
1378
+ )
1379
+ if max_length is not None:
1380
+ # logging.warning(
1381
+ # "`max_length` is deprecated in this function, use"
1382
+ # " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead."
1383
+ # )
1384
+ stopping_criteria = validate_stopping_criteria(
1385
+ stopping_criteria, max_length
1386
+ )
1387
+
1388
+ batch_size, cur_len = input_ids.shape
1389
+ origin_len = cur_len
1390
+ unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool")
1391
+ scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype())
1392
+
1393
+ generate_end = False
1394
+ while True:
1395
+ if synced_gpus:
1396
+ # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
1397
+ # The following logic allows an early break if all peers finished generating their sequence
1398
+ this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
1399
+ # send 0.0 if we finished, 1.0 otherwise
1400
+ dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
1401
+ # did all peers finish? the reduced sum will be 0.0 then
1402
+ if this_peer_finished_flag.item() == 0.0:
1403
+ break
1404
+ # prepare model inputs & get model output
1405
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1406
+ # NOTE: decrease the ref-count and clear the outdated cache in time
1407
+ model_kwargs["cache"] = None
1408
+ model_kwargs["past_key_values"] = None
1409
+ outputs = self(**model_inputs)
1410
+ if synced_gpus and generate_end:
1411
+ continue # don't waste resources running the code we don't need
1412
+
1413
+ if isinstance(outputs, tuple):
1414
+ logits = outputs[0]
1415
+ elif isinstance(outputs, ModelOutput):
1416
+ logits = outputs.logits
1417
+ else:
1418
+ logits = outputs
1419
+
1420
+ # [batch_size, vocab_size]
1421
+ logits = logits[:, -1, :]
1422
+
1423
+ # pre-process distribution
1424
+ logits = self.adjust_logits_during_generation(logits)
1425
+ logits = logits_processors(input_ids, logits)
1426
+
1427
+ # sample
1428
+ origin_probs = F.softmax(logits)
1429
+ origin_probs = paddle.log(origin_probs)
1430
+ if temperature is not None and temperature != 1.0:
1431
+ logits = logits / temperature
1432
+ probs = F.softmax(logits)
1433
+ if top_k is not None and top_k != 0:
1434
+ probs = TopKProcess(probs, top_k, min_tokens_to_keep)
1435
+ if top_p is not None and top_p < 1.0:
1436
+ probs = TopPProcess(probs, top_p, min_tokens_to_keep)
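+ # Filtering order matters: temperature rescales the logits before
+ # softmax, TopKProcess then keeps the k most likely tokens, and
+ # TopPProcess further restricts sampling to the smallest nucleus whose
+ # cumulative mass exceeds top_p; min_tokens_to_keep bounds both cuts.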
1437
+ if paddle.device.is_compiled_with_custom_device("gcu"):
1438
+ probs = paddle.cast(probs, "float32")
1439
+ if paddle.device.is_compiled_with_xpu():
1440
+ probs = paddle.cast(probs, "float32")
1441
+
1442
+ # multinomial already supports fp16 and bf16; see issue: https://github.com/PaddlePaddle/Paddle/issues/51852
1443
+ next_tokens = paddle.multinomial(probs)
1444
+
1445
+ if self.config.tensor_parallel_degree > 1:
1446
+ # Maybe no need to broadcast if seed is set correctly.
1447
+ from paddle.distributed import fleet
1448
+
1449
+ try:
1450
+ hcg = fleet.get_hybrid_communicate_group()
1451
+ group = hcg.get_model_parallel_group()
1452
+ src = hcg.get_model_parallel_group_src_rank()
1453
+ except Exception:
1454
+ group, src = None, 0
1455
+ paddle.distributed.broadcast(next_tokens, src=src, group=group)
1456
+ # config does not include pipeline_parallel_degree, and pipeline parallel
1457
+ # uses trainer.model_wrapped to run in both train and predict mode
1458
+ # which has pp_group as an attribute
1459
+ # TODO(guosheng): only let the last stage of pipeline to do softmax
1460
+ # and sampling, and then broadcast to avoid broadcast logits.
1461
+ if getattr(self, "pp_group", None) is not None:
1462
+ paddle.distributed.broadcast(
1463
+ next_tokens,
1464
+ src=self.pp_group.ranks[0],
1465
+ group=self.pp_group, # use rank 0 for same seed to check
1466
+ )
1467
+
1468
+ next_scores = paddle.index_sample(origin_probs, next_tokens)
1469
+ if eos_token_id is not None:
1470
+ next_tokens = paddle.where(
1471
+ unfinished_flag,
1472
+ next_tokens,
1473
+ paddle.full_like(next_tokens, pad_token_id),
1474
+ )
1475
+
1476
+ scores = self.update_scores_for_generation(
1477
+ scores, next_scores, cur_len - origin_len, unfinished_flag
1478
+ )
1479
+
1480
+ cur_len += 1
1481
+ input_ids = paddle.concat([input_ids, next_tokens], axis=1)
1482
+ if streamer is not None:
1483
+ if self.config.tensor_parallel_rank == 0:
1484
+ streamer.put(next_tokens.cpu())
1485
+
1486
+ if stopping_criteria(input_ids, scores):
1487
+ generate_end = True
1488
+
1489
+ if eos_token_id is not None:
1490
+ unfinished_flag = get_unfinished_flag(
1491
+ input_ids, unfinished_flag, eos_token_id
1492
+ )
1493
+ if not paddle.any(unfinished_flag):
1494
+ generate_end = True
1495
+
1496
+ # Stop when there is a </s> in all sentences
1497
+ if generate_end and not synced_gpus:
1498
+ break
1499
+
1500
+ model_kwargs = self.update_model_kwargs_for_generation(
1501
+ outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder
1502
+ )
1503
+ if fast_ptq_sampling:
1504
+ break
1505
+
1506
+ if streamer is not None:
1507
+ streamer.end()
1508
+
1509
+ return input_ids[:, origin_len:] if trunc_input else input_ids, scores
1510
+
1511
+ def _get_model_inputs_spec(self, dtype: str):
1512
+ spec = {
1513
+ "input_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"),
1514
+ "attention_mask": paddle.static.InputSpec(
1515
+ shape=[None, None], dtype="int64"
1516
+ ),
1517
+ }
1518
+ if "position_ids" in inspect.getfullargspec(self.forward).args:
1519
+ spec["position_ids"] = paddle.static.InputSpec(
1520
+ shape=[None, None], dtype="int64"
1521
+ )
1522
+ return spec
1523
+
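+ # The [None, None] InputSpec dims mark batch size and sequence length
+ # as dynamic, so the exported static graph accepts any prompt shape;
+ # position_ids is declared only when the subclass's forward takes it.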
1524
+ def to_static(self, path: str, config: dict):
1525
+ """export generation model to static
1526
+
1527
+ Args:
1528
+ path (str): path of saved inference model
1529
+ config (dict): configuration for generation
1530
+ bos_token_id (int): token id of begin-of-sentence
1531
+ eos_token_id (int): token id of end-of-sentence
1532
+ pad_token_id (int): token id of pad token
1533
+ use_top_p (bool): whether to use the top_p decoding strategy
1534
+ """
1535
+
1536
+ use_top_p = config.get("use_top_p", True)
1537
+
1538
+ top_k_spec = (
1539
+ paddle.static.InputSpec(shape=[1], dtype="int64") if not use_top_p else 0
1540
+ )
1541
+
1542
+ top_p_spec = (
1543
+ paddle.static.InputSpec(shape=[1], dtype="float32") if use_top_p else 1.0
1544
+ )
1545
+ temperature = (
1546
+ paddle.static.InputSpec(shape=[1], dtype="float32") if use_top_p else 1.0
1547
+ )
1548
+ dtype = config.get("dtype", None)
1549
+
1550
+ logits_processors = config.get("logits_processors", None)
1551
+ model_inputs_spec = self._get_model_inputs_spec(dtype)
1552
+
1553
+ input_spec = [
1554
+ model_inputs_spec["input_ids"], # input_ids
1555
+ model_inputs_spec["attention_mask"], # attention_mask
1556
+ model_inputs_spec.get("position_ids", None), # attention_mask
1557
+ logits_processors,
1558
+ paddle.static.InputSpec(shape=[1], dtype="int64"), # max_length
1559
+ self.generation_config.pad_token_id or config.get("pad_token_id", None),
1560
+ self.generation_config.eos_token_id or config.get("eos_token_id", None),
1561
+ top_k_spec, # top_k
1562
+ top_p_spec, # top_p
1563
+ temperature, # temperature
1564
+ 1,
1565
+ ]
1566
+
1567
+ model = paddle.jit.to_static(self.sample_d2s, input_spec=input_spec)
1568
+
1569
+ paddle.jit.save(model, path)
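A hypothetical call site for this export path; the token ids and save prefix below are placeholders, and only keys that to_static actually reads (use_top_p, dtype, logits_processors, pad_token_id, eos_token_id) are passed:

# `model` is assumed to be a generation-capable instance exposing to_static
config = {
    "use_top_p": True,  # trace the top-p branch of sample_d2s
    "pad_token_id": 0,  # placeholder ids; model-specific in practice
    "eos_token_id": 2,
    "dtype": "float16",
    "logits_processors": None,
}
model.to_static("inference/generation", config)
# paddle.jit.save then writes the static graph and parameters under this prefix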
+
+    def sample_d2s(
+        self,
+        input_ids,
+        attention_mask,
+        position_ids,
+        logits_processors,
+        max_new_tokens,
+        pad_token_id,
+        eos_token_id,
+        top_k=None,
+        top_p=None,
+        temperature=None,
+        min_tokens_to_keep=1,
+    ):
+
+        pad_token_id = self.set_pad_token_id(pad_token_id, eos_token_id)
+        logits_processors = (
+            logits_processors
+            if logits_processors is not None
+            else LogitsProcessorList()
+        )
+
+        if paddle.is_tensor(top_k) and not paddle.is_tensor(top_p):
+            use_top_p = False
+        elif not paddle.is_tensor(top_k) and paddle.is_tensor(top_p):
+            use_top_p = True
+
+        # top_k and top_p are constant values
+        elif isinstance(top_p, float) or isinstance(top_k, int):
+            use_top_p = True
+        else:
+            if top_p is None and top_k is None:
+                raise ValueError("top_k and top_p should not both be None")
+            raise ValueError(
+                "top_k and top_p should not both be specified as InputSpec; exactly one InputSpec is expected"
+            )
+
+        batch_size, cur_len = input_ids.shape
+        # used for compute on gpu, avoid memcpy D2H
+        cur_len_gpu = paddle.full([1], cur_len, dtype="int64")
+
+        origin_len = input_ids.shape[1]
+        # used for compute on gpu, avoid memcpy D2H
+        origin_len_gpu = paddle.full([1], origin_len, dtype="int64")
+
+        unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool")
+
+        scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype())
+
+        # use_cache is immutable, so we split it off from the other, mutable kwargs.
+        immutable = {"use_cache": True}
+        model_kwargs = {"attention_mask": attention_mask, "position_ids": position_ids}
+
+        def _forward_(**args):
+            model_inputs = self.prepare_inputs_for_generation(
+                input_ids, **args, **immutable
+            )
+            assert "use_cache" in model_inputs
+            del model_inputs["use_cache"]
+            return self(**model_inputs, **immutable)
+
+        def _post_process_(
+            outputs,
+            input_ids,
+            cur_len,
+            origin_len,
+            scores,
+            unfinished_flag,
+            model_kwargs,
+            pad_token_id,
+        ):
+            if isinstance(outputs, tuple):
+                logits = outputs[0]
+            elif isinstance(outputs, ModelOutput):
+                logits = outputs.logits
+            else:
+                logits = outputs
+
+            # [batch_size, vocab_size]
+            logits = logits[:, -1, :]
+
+            # pre-process distribution
+            logits = self.adjust_logits_during_generation(logits)
+
+            logits = logits_processors(input_ids, logits)
+            probs = F.softmax(logits)
+
+            # sample
+            origin_probs = F.log_softmax(logits)
+            # compute next_tokens
+            if use_top_p:
+                logits = logits / temperature
+                top_ps_tensor = paddle.full(
+                    shape=[probs.shape[0], 1], fill_value=top_p, dtype=probs.dtype
+                )
+                _, next_tokens = paddle.tensor.top_p_sampling(probs, top_ps_tensor)
+            else:
+                probs = TopKProcess(probs, top_k, min_tokens_to_keep)
+                if top_k == 1:
+                    next_tokens = paddle.unsqueeze_(paddle.argmax(probs, axis=-1), -1)
+                else:
+                    next_tokens = paddle.multinomial(probs)
+
+            next_scores = paddle.index_sample(origin_probs, next_tokens)
+            scores = self.update_scores_for_generation(
+                scores, next_scores, cur_len - origin_len, unfinished_flag
+            )
+            if eos_token_id is not None:
+                next_tokens = paddle.where(
+                    unfinished_flag,
+                    next_tokens,
+                    paddle.full_like(next_tokens, pad_token_id),
+                )
+
+            input_ids = paddle.concat([input_ids, next_tokens], axis=1)
+
+            if eos_token_id is not None:
+                unfinished_flag = get_unfinished_flag(
+                    input_ids, unfinished_flag, eos_token_id
+                )
+
+            model_kwargs = self.update_model_kwargs_for_generation(
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+            )
+
+            return input_ids, scores, unfinished_flag, model_kwargs
+
+        outputs = _forward_(**model_kwargs)
+        input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
+            outputs,
+            input_ids,
+            cur_len_gpu,
+            origin_len_gpu,
+            scores,
+            unfinished_flag,
+            model_kwargs,
+            pad_token_id,
+        )
+
+        cur_len += 1
+        cur_len_gpu += 1
+
+        attn_mask = model_kwargs["attention_mask"]
+        # make the shape of attention_mask = (-1, -1, -1, -1) in dy2static.
+        model_kwargs["attention_mask"] = paddle.reshape(attn_mask, attn_mask.shape)
+        model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None
+        max_new_tokens = paddle.full([1], max_new_tokens + cur_len - 1, dtype="int64")
+
+        while cur_len < max_new_tokens and paddle.any(unfinished_flag):
+            input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
+                _forward_(**model_kwargs),
+                input_ids,
+                cur_len_gpu,
+                origin_len_gpu,
+                scores,
+                unfinished_flag,
+                model_kwargs,
+                pad_token_id,
+            )
+            cur_len += 1
+            cur_len_gpu += 1
+
+        return input_ids[:, origin_len:], scores
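Two dynamic-to-static idioms in sample_d2s are easy to miss: step counters are mirrored in GPU tensors (cur_len_gpu, origin_len_gpu) so the traced loop avoids device-to-host copies, and the attention mask is reshaped to its own shape, a no-op in eager mode that erases the traced static shape. A minimal sketch of that reshape trick, under the assumption it runs inside paddle.jit.to_static tracing:

import paddle

mask = paddle.ones([1, 1, 8, 8])
# identity in eager mode; under paddle.jit.to_static the traced shape becomes
# (-1, -1, -1, -1), so the exported graph accepts masks of any size
mask = paddle.reshape(mask, mask.shape)
print(mask.shape)  # [1, 1, 8, 8] when run eagerly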
+
+    def reorder_cache(self, cache, beam_idx):
+        cache = map_structure(lambda x: paddle.index_select(x, beam_idx), cache)
+        return cache
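reorder_cache applies the surviving-beam permutation to every tensor in the possibly nested cache. A small sketch; the import assumes map_structure is the paddle.utils helper that the surrounding module uses:

import paddle
from paddle.utils import map_structure  # assumed to match the module's import

# toy cache: two [num_beams, hidden] tensors in a dict
cache = {
    "k": paddle.arange(12, dtype="float32").reshape([4, 3]),
    "v": paddle.arange(12, dtype="float32").reshape([4, 3]),
}
beam_idx = paddle.to_tensor([2, 2, 0, 1])  # e.g. beam 2 survived twice
reordered = map_structure(lambda x: paddle.index_select(x, beam_idx), cache)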
+
+    def beam_search(
+        self,
+        input_ids,
+        beam_scorer,
+        logits_processors,
+        max_length,
+        diversity_rate,
+        pad_token_id,
+        eos_token_id,
+        stopping_criteria=None,
+        fast_ptq_sampling=False,
+        trunc_input=True,
+        synced_gpus=False,
+        **model_kwargs,
+    ):
+        model_kwargs["use_cache"] = model_kwargs.get("use_cache", True)
+
+        logits_processors = (
+            logits_processors
+            if logits_processors is not None
+            else LogitsProcessorList()
+        )
+
+        # max_length will be converted to a MaxLengthCriteria
+        stopping_criteria = (
+            stopping_criteria
+            if stopping_criteria is not None
+            else StoppingCriteriaList()
+        )
+        if max_length is not None:
+            # logging.warning(
+            #     "`max_length` is deprecated in this function, use"
+            #     " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead."
+            # )
+            stopping_criteria = validate_stopping_criteria(
+                stopping_criteria, max_length
+            )
+
+        batch_size = len(beam_scorer._beam_hyps)
+        num_beams = beam_scorer.num_beams
+        batch_beam_size, cur_len = input_ids.shape
+        origin_len = cur_len
+
+        assert (
+            num_beams * batch_size == batch_beam_size
+        ), "Batch dimension of `input_ids` should be {}, but received {}.".format(
+            num_beams * batch_size, batch_beam_size
+        )
+
+        beam_scores = paddle.zeros(
+            (batch_size, num_beams), dtype=paddle.get_default_dtype()
+        )
+
+        beam_scores[:, 1:] = get_scale_by_dtype(return_positive=False)
+        beam_scores = paddle.reshape(beam_scores, [-1])
+
+        generate_end = False
+        while True:
+            if synced_gpus:
+                # Under synced_gpus the `forward` call must continue until all GPUs complete their sequence.
+                # The following logic allows an early break if all peers finished generating their sequence
+                this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
+                # send 0.0 if we finished, 1.0 otherwise
+                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
+                # did all peers finish? the reduced sum will be 0.0 then
+                if this_peer_finished_flag.item() == 0.0:
+                    break
+            # prepare model inputs & get model output
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+            outputs = self(**model_inputs)
+            if synced_gpus and generate_end:
+                cur_len = cur_len + 1
+                continue  # don't waste resources running the code we don't need
+
+            if isinstance(outputs, tuple):
+                logits = outputs[0]
+            elif isinstance(outputs, ModelOutput):
+                logits = outputs.logits
+            else:
+                logits = outputs
+
+            # [batch_size, vocab_size]
+            logits = logits[:, -1, :]
+
+            # pre-process distribution
+            logits = self.adjust_logits_during_generation(logits)
+            # beam search
+            # [batch_size * num_beams, vocab_size]
+            next_scores = F.softmax(logits)
+            next_scores = paddle.log(next_scores)
+            next_scores = logits_processors(input_ids, next_scores)
+            next_scores = next_scores + beam_scores.unsqueeze(-1)
+
+            vocab_size = next_scores.shape[-1]
+            if diversity_rate == 0.0:
+                # reshape for beam search
+                next_scores = next_scores.reshape([batch_size, num_beams * vocab_size])
+
+                next_scores, next_tokens = paddle.topk(
+                    next_scores, 2 * num_beams, axis=1
+                )
+
+                next_indices = next_tokens // vocab_size
+                next_tokens = next_tokens % vocab_size
+
+            else:
+                next_scores, next_tokens = paddle.topk(
+                    next_scores, 2 * num_beams, axis=1
+                )
+
+                sibling_score = (
+                    paddle.arange(1, 2 * num_beams + 1, dtype="int64").unsqueeze(0)
+                    * diversity_rate
+                )
+
+                diversed_score = next_scores - sibling_score
+
+                next_scores = next_scores.reshape(
+                    [batch_size, 2 * num_beams * num_beams]
+                )
+                next_tokens = next_tokens.reshape(
+                    [batch_size, 2 * num_beams * num_beams]
+                )
+
+                diversed_score = diversed_score.reshape(
+                    [batch_size, 2 * num_beams * num_beams]
+                )
+                diversed_score, diversed_tokens = paddle.topk(
+                    diversed_score, 2 * num_beams, axis=1
+                )
+
+                # TODO: use gather_nd() to select the original tokens and scores
+                next_scores = paddle.stack(
+                    [
+                        paddle.index_select(next_scores[i], diversed_tokens[i])
+                        for i in range(next_scores.shape[0])
+                    ]
+                )
+                next_tokens = paddle.stack(
+                    [
+                        paddle.index_select(next_tokens[i], diversed_tokens[i])
+                        for i in range(next_tokens.shape[0])
+                    ]
+                )
+
+                next_indices = diversed_tokens // (2 * num_beams)
+
+            # stateless
+            beam_outputs = beam_scorer.process(
+                input_ids,
+                next_scores,
+                next_tokens,
+                next_indices,
+                origin_len=origin_len,
+                pad_token_id=pad_token_id,
+                eos_token_id=eos_token_id,
+            )
+            beam_scores = beam_outputs["next_beam_scores"]
+            beam_next_tokens = beam_outputs["next_beam_tokens"]
+            beam_idx = beam_outputs["next_beam_indices"]
+            # beam_idx may contain -1 elements and cause an error
+            # see: https://github.com/PaddlePaddle/Paddle/issues/57366
+            beam_idx = paddle.maximum(beam_idx, paddle.full_like(beam_idx, 0))
+
+            cur_len += 1
+            input_ids = paddle.concat(
+                [
+                    paddle.index_select(input_ids, beam_idx),
+                    beam_next_tokens.unsqueeze(-1),
+                ],
+                axis=-1,
+            )
+
+            if beam_scorer.is_done or stopping_criteria(input_ids, beam_scores):
+                if not synced_gpus:
+                    break
+                else:
+                    generate_end = True
+
+            model_kwargs = self.update_model_kwargs_for_generation(
+                outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder
+            )
+            if "cache" in model_kwargs:
+                # reorder the cache
+                model_kwargs["cache"] = self.reorder_cache(
+                    model_kwargs["cache"], beam_idx
+                )
+            if "past_key_values" in model_kwargs:
+                # reorder the cache
+                model_kwargs["past_key_values"] = self.reorder_cache(
+                    model_kwargs["past_key_values"], beam_idx
+                )
+            if fast_ptq_sampling:
+                break
+
+        pred_ids, scores = beam_scorer.finalize(
+            input_ids,
+            beam_scores,
+            next_tokens,
+            next_indices,
+            origin_len=origin_len,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+        )
+        return pred_ids[:, origin_len:] if trunc_input else input_ids, scores
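The diversity_rate == 0.0 branch relies on a standard trick: flatten the beam and vocabulary axes, take one top-k over the product space, then recover the source beam and the token id by integer division and modulo. A toy illustration with made-up scores:

import paddle

num_beams, vocab_size = 2, 5
# one batch row, scores for beam 0 then beam 1, already flattened
next_scores = paddle.to_tensor(
    [[0.1, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8, 0.2, 0.0]]
)
scores, ids = paddle.topk(next_scores, 2 * num_beams, axis=1)
beam_indices = ids // vocab_size  # [[0, 1, 1, 0]]: which beam each candidate extends
token_ids = ids % vocab_size      # [[1, 2, 3, 0]]: which token it proposes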
+
+    def group_beam_search(
+        self,
+        input_ids,
+        beam_scorer,
+        logits_processors,
+        max_length,
+        pad_token_id,
+        eos_token_id,
+        stopping_criteria=None,
+        fast_ptq_sampling=False,
+        trunc_input=True,
+        synced_gpus=False,
+        **model_kwargs,
+    ):
+        model_kwargs["use_cache"] = model_kwargs.get("use_cache", True)
+        logits_processors = (
+            logits_processors
+            if logits_processors is not None
+            else LogitsProcessorList()
+        )
+
+        # max_length will be converted to a MaxLengthCriteria
+        stopping_criteria = (
+            stopping_criteria
+            if stopping_criteria is not None
+            else StoppingCriteriaList()
+        )
+        if max_length is not None:
+            # logging.warning(
+            #     "`max_length` is deprecated in this function, use"
+            #     " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead."
+            # )
+            stopping_criteria = validate_stopping_criteria(
+                stopping_criteria, max_length
+            )
+
+        batch_size = len(beam_scorer._beam_hyps)
+        num_beams = beam_scorer.num_beams
+        num_beam_groups = beam_scorer.num_beam_groups
+        num_sub_beams = num_beams // num_beam_groups
+
+        batch_beam_size, cur_len = input_ids.shape
+        origin_len = cur_len
+
+        assert (
+            num_beams * batch_size == batch_beam_size
+        ), "Batch dimension of `input_ids` should be {}, but received {}.".format(
+            num_beams * batch_size, batch_beam_size
+        )
+
+        beam_scores = paddle.full(
+            (batch_size, num_beams),
+            get_scale_by_dtype(return_positive=False),
+            dtype="float32",
+        )
+        # initialise the score of the first beam of each group with 0 and the rest
+        # with a large negative value. This ensures that the beams in the same group
+        # don't all produce the same tokens every time.
+        beam_scores[:, ::num_sub_beams] = 0
+        beam_scores = paddle.reshape(beam_scores, [-1])
+
+        generate_end = False
+        while True:
+            if synced_gpus:
+                # Under synced_gpus the `forward` call must continue until all GPUs complete their sequence.
+                # The following logic allows an early break if all peers finished generating their sequence
+                this_peer_finished_flag = paddle.to_tensor(0.0 if generate_end else 1.0)
+                # send 0.0 if we finished, 1.0 otherwise
+                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
+                # did all peers finish? the reduced sum will be 0.0 then
+                if this_peer_finished_flag.item() == 0.0:
+                    break
+            # predicted tokens in the current step
+            current_tokens = paddle.zeros(
+                shape=[batch_size * num_beams], dtype=input_ids.dtype
+            )
+
+            # indices which will form the beams in the next time step
+            reordering_indices = paddle.zeros(
+                shape=[batch_size * num_beams], dtype="int64"
+            )
+            # prepare model inputs & get model output
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+            outputs = self(**model_inputs)
+            if synced_gpus and generate_end:
+                cur_len = cur_len + 1
+                continue  # don't waste resources running the code we don't need
+
+            for beam_group_idx in range(num_beam_groups):
+                group_start_idx = beam_group_idx * num_sub_beams
+                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
+                group_size = group_end_idx - group_start_idx
+
+                # indices of beams of current group among all sentences in batch
+                batch_group_indices = []
+
+                for batch_idx in range(batch_size):
+                    batch_group_indices.extend(
+                        [
+                            batch_idx * num_beams + idx
+                            for idx in range(group_start_idx, group_end_idx)
+                        ]
+                    )
+
+                group_input_ids = input_ids[batch_group_indices]
+
+                if isinstance(outputs, tuple):
+                    logits = outputs[0]
+                elif isinstance(outputs, ModelOutput):
+                    logits = outputs.logits
+                else:
+                    logits = outputs
+
+                logits = logits[:, -1, :]
+                logits = paddle.index_select(
+                    logits, paddle.to_tensor(batch_group_indices)
+                )
+                logits = self.adjust_logits_during_generation(logits)
+
+                next_scores = F.softmax(logits)
+                next_scores = paddle.log(next_scores)
+                vocab_size = next_scores.shape[-1]
+
+                next_scores = logits_processors(
+                    group_input_ids,
+                    next_scores,
+                    current_tokens=current_tokens,
+                    beam_group_idx=beam_group_idx,
+                )
+
+                next_scores = next_scores + beam_scores[batch_group_indices].unsqueeze(
+                    -1
+                )
+
+                # reshape for beam search
+                next_scores = next_scores.reshape([batch_size, group_size * vocab_size])
+
+                next_scores, next_tokens = paddle.topk(
+                    next_scores, 2 * group_size, axis=1
+                )
+
+                next_indices = next_tokens // vocab_size
+                next_tokens = next_tokens % vocab_size
+
+                beam_outputs = beam_scorer.process(
+                    group_input_ids,
+                    next_scores,
+                    next_tokens,
+                    next_indices,
+                    origin_len=origin_len,
+                    pad_token_id=pad_token_id,
+                    eos_token_id=eos_token_id,
+                )
+
+                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
+                beam_next_tokens = beam_outputs["next_beam_tokens"]
+                beam_idx = beam_outputs["next_beam_indices"]
+                # beam_idx may contain -1 elements and cause an error
+                # see: https://github.com/PaddlePaddle/Paddle/issues/57366
+                beam_idx = paddle.maximum(beam_idx, paddle.full_like(beam_idx, 0))
+
+                input_ids[batch_group_indices] = group_input_ids[beam_idx]
+                group_input_ids = paddle.concat(
+                    [
+                        paddle.index_select(group_input_ids, index=beam_idx),
+                        beam_next_tokens.unsqueeze(-1),
+                    ],
+                    axis=-1,
+                )
+                current_tokens[batch_group_indices] = beam_next_tokens
+
+                reordering_indices[batch_group_indices] = (
+                    num_beams * (beam_idx // group_size)
+                    + group_start_idx
+                    + (beam_idx % group_size)
+                )
+
+            input_ids = paddle.concat(
+                [input_ids, current_tokens.unsqueeze(-1)], axis=-1
+            )
+
+            cur_len += 1
+
+            if beam_scorer.is_done or stopping_criteria(input_ids, beam_scores):
+                if not synced_gpus:
+                    break
+                else:
+                    generate_end = True
+
+            model_kwargs = self.update_model_kwargs_for_generation(
+                outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder
+            )
+
+            if "cache" in model_kwargs:
+                # reorder the cache
+                model_kwargs["cache"] = self.reorder_cache(
+                    model_kwargs["cache"], reordering_indices
+                )
+            if "past_key_values" in model_kwargs:
+                # reorder the cache
+                model_kwargs["past_key_values"] = self.reorder_cache(
+                    model_kwargs["past_key_values"], reordering_indices
+                )
+
+            if fast_ptq_sampling:
+                break
+
+        pred_ids, scores = beam_scorer.finalize(
+            input_ids,
+            beam_scores,
+            next_tokens,
+            next_indices,
+            origin_len=origin_len,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+        )
+        return pred_ids[:, origin_len:] if trunc_input else input_ids, scores
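The reordering_indices assignment near the end of the group loop maps a flat (batch row, slot within group) index back to a global beam position, which is what the cache reorder needs. A worked check with assumed sizes:

# num_beams=4 split into two groups of group_size=2; the second group starts at 2
num_beams, group_size, group_start_idx = 4, 2, 2
beam_idx = 3  # flat index: batch row 1 (3 // 2), slot 1 inside the group (3 % 2)
global_idx = (
    num_beams * (beam_idx // group_size)
    + group_start_idx
    + (beam_idx % group_size)
)
assert global_idx == 7  # row 1 * 4 beams + group offset 2 + slot 1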