paddlex 3.0.0rc0__py3-none-any.whl → 3.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (785)
  1. paddlex/.version +1 -1
  2. paddlex/__init__.py +17 -34
  3. paddlex/__main__.py +1 -1
  4. paddlex/configs/modules/doc_vlm/PP-DocBee-2B.yaml +14 -0
  5. paddlex/configs/modules/doc_vlm/PP-DocBee-7B.yaml +14 -0
  6. paddlex/configs/modules/open_vocabulary_detection/YOLO-Worldv2-L.yaml +13 -0
  7. paddlex/configs/pipelines/anomaly_detection.yaml +1 -1
  8. paddlex/configs/pipelines/doc_understanding.yaml +9 -0
  9. paddlex/configs/pipelines/ts_anomaly_detection.yaml +1 -1
  10. paddlex/configs/pipelines/ts_classification.yaml +1 -1
  11. paddlex/configs/pipelines/ts_forecast.yaml +1 -1
  12. paddlex/constants.py +17 -0
  13. paddlex/engine.py +7 -5
  14. paddlex/hpip_links.html +23 -11
  15. paddlex/inference/__init__.py +3 -3
  16. paddlex/inference/common/__init__.py +1 -1
  17. paddlex/inference/common/batch_sampler/__init__.py +5 -4
  18. paddlex/inference/common/batch_sampler/audio_batch_sampler.py +5 -6
  19. paddlex/inference/common/batch_sampler/base_batch_sampler.py +20 -16
  20. paddlex/inference/common/batch_sampler/det_3d_batch_sampler.py +4 -7
  21. paddlex/inference/common/batch_sampler/doc_vlm_batch_sampler.py +64 -0
  22. paddlex/inference/common/batch_sampler/image_batch_sampler.py +12 -36
  23. paddlex/inference/common/batch_sampler/ts_batch_sampler.py +9 -10
  24. paddlex/inference/common/batch_sampler/video_batch_sampler.py +2 -22
  25. paddlex/inference/common/reader/__init__.py +4 -4
  26. paddlex/inference/common/reader/audio_reader.py +3 -3
  27. paddlex/inference/common/reader/det_3d_reader.py +7 -5
  28. paddlex/inference/common/reader/image_reader.py +16 -12
  29. paddlex/inference/common/reader/ts_reader.py +3 -2
  30. paddlex/inference/common/reader/video_reader.py +3 -3
  31. paddlex/inference/common/result/__init__.py +7 -7
  32. paddlex/inference/common/result/base_cv_result.py +12 -2
  33. paddlex/inference/common/result/base_result.py +7 -5
  34. paddlex/inference/common/result/base_ts_result.py +1 -2
  35. paddlex/inference/common/result/base_video_result.py +2 -2
  36. paddlex/inference/common/result/mixin.py +12 -13
  37. paddlex/inference/models/__init__.py +41 -85
  38. paddlex/inference/models/anomaly_detection/__init__.py +1 -1
  39. paddlex/inference/models/anomaly_detection/predictor.py +9 -19
  40. paddlex/inference/models/anomaly_detection/processors.py +9 -2
  41. paddlex/inference/models/anomaly_detection/result.py +3 -2
  42. paddlex/inference/models/base/__init__.py +2 -2
  43. paddlex/inference/models/base/predictor/__init__.py +1 -2
  44. paddlex/inference/models/base/predictor/base_predictor.py +284 -39
  45. paddlex/inference/models/common/__init__.py +6 -15
  46. paddlex/inference/models/common/static_infer.py +764 -243
  47. paddlex/inference/models/common/tokenizer/__init__.py +5 -3
  48. paddlex/inference/models/common/tokenizer/bert_tokenizer.py +1 -1
  49. paddlex/inference/models/common/tokenizer/clip_tokenizer.py +609 -0
  50. paddlex/inference/models/common/tokenizer/gpt_tokenizer.py +7 -5
  51. paddlex/inference/models/common/tokenizer/qwen2_tokenizer.py +432 -0
  52. paddlex/inference/models/common/tokenizer/tokenizer_utils.py +72 -64
  53. paddlex/inference/models/common/tokenizer/tokenizer_utils_base.py +337 -121
  54. paddlex/inference/models/common/tokenizer/utils.py +1 -1
  55. paddlex/inference/models/common/tokenizer/vocab.py +1 -1
  56. paddlex/inference/models/common/ts/__init__.py +1 -1
  57. paddlex/inference/models/common/ts/funcs.py +13 -6
  58. paddlex/inference/models/common/ts/processors.py +14 -5
  59. paddlex/inference/models/common/vision/__init__.py +3 -3
  60. paddlex/inference/models/common/vision/funcs.py +17 -12
  61. paddlex/inference/models/common/vision/processors.py +61 -46
  62. paddlex/inference/models/common/vlm/__init__.py +13 -0
  63. paddlex/inference/models/common/vlm/activations.py +189 -0
  64. paddlex/inference/models/common/vlm/bert_padding.py +127 -0
  65. paddlex/inference/models/common/vlm/distributed.py +229 -0
  66. paddlex/inference/models/common/vlm/flash_attn_utils.py +119 -0
  67. paddlex/inference/models/common/vlm/generation/__init__.py +34 -0
  68. paddlex/inference/models/common/vlm/generation/configuration_utils.py +533 -0
  69. paddlex/inference/models/common/vlm/generation/logits_process.py +730 -0
  70. paddlex/inference/models/common/vlm/generation/stopping_criteria.py +106 -0
  71. paddlex/inference/models/common/vlm/generation/utils.py +2162 -0
  72. paddlex/inference/models/common/vlm/transformers/__init__.py +16 -0
  73. paddlex/inference/models/common/vlm/transformers/configuration_utils.py +1037 -0
  74. paddlex/inference/models/common/vlm/transformers/conversion_utils.py +408 -0
  75. paddlex/inference/models/common/vlm/transformers/model_outputs.py +1612 -0
  76. paddlex/inference/models/common/vlm/transformers/model_utils.py +2038 -0
  77. paddlex/inference/models/common/vlm/transformers/utils.py +178 -0
  78. paddlex/inference/models/common/vlm/utils.py +109 -0
  79. paddlex/inference/models/doc_vlm/__init__.py +15 -0
  80. paddlex/inference/models/doc_vlm/modeling/__init__.py +15 -0
  81. paddlex/inference/models/doc_vlm/modeling/qwen2_vl.py +2600 -0
  82. paddlex/inference/models/doc_vlm/predictor.py +198 -0
  83. paddlex/inference/models/doc_vlm/processors/__init__.py +15 -0
  84. paddlex/inference/models/doc_vlm/processors/common.py +372 -0
  85. paddlex/inference/models/doc_vlm/processors/qwen2_vl.py +698 -0
  86. paddlex/inference/models/doc_vlm/result.py +21 -0
  87. paddlex/inference/models/face_feature/__init__.py +1 -1
  88. paddlex/inference/models/face_feature/predictor.py +2 -1
  89. paddlex/inference/models/formula_recognition/__init__.py +1 -1
  90. paddlex/inference/models/formula_recognition/predictor.py +11 -27
  91. paddlex/inference/models/formula_recognition/processors.py +35 -19
  92. paddlex/inference/models/formula_recognition/result.py +19 -12
  93. paddlex/inference/models/image_classification/__init__.py +1 -1
  94. paddlex/inference/models/image_classification/predictor.py +9 -19
  95. paddlex/inference/models/image_classification/processors.py +4 -2
  96. paddlex/inference/models/image_classification/result.py +4 -3
  97. paddlex/inference/models/image_feature/__init__.py +1 -1
  98. paddlex/inference/models/image_feature/predictor.py +9 -19
  99. paddlex/inference/models/image_feature/processors.py +4 -1
  100. paddlex/inference/models/image_feature/result.py +2 -3
  101. paddlex/inference/models/image_multilabel_classification/__init__.py +1 -1
  102. paddlex/inference/models/image_multilabel_classification/predictor.py +7 -6
  103. paddlex/inference/models/image_multilabel_classification/processors.py +6 -2
  104. paddlex/inference/models/image_multilabel_classification/result.py +4 -3
  105. paddlex/inference/models/image_unwarping/__init__.py +1 -1
  106. paddlex/inference/models/image_unwarping/predictor.py +8 -16
  107. paddlex/inference/models/image_unwarping/processors.py +6 -2
  108. paddlex/inference/models/image_unwarping/result.py +4 -2
  109. paddlex/inference/models/instance_segmentation/__init__.py +1 -1
  110. paddlex/inference/models/instance_segmentation/predictor.py +7 -15
  111. paddlex/inference/models/instance_segmentation/processors.py +4 -7
  112. paddlex/inference/models/instance_segmentation/result.py +11 -10
  113. paddlex/inference/models/keypoint_detection/__init__.py +1 -1
  114. paddlex/inference/models/keypoint_detection/predictor.py +2 -3
  115. paddlex/inference/models/keypoint_detection/processors.py +11 -3
  116. paddlex/inference/models/keypoint_detection/result.py +9 -4
  117. paddlex/inference/models/{3d_bev_detection → m_3d_bev_detection}/__init__.py +1 -1
  118. paddlex/inference/models/{3d_bev_detection → m_3d_bev_detection}/predictor.py +15 -26
  119. paddlex/inference/models/{3d_bev_detection → m_3d_bev_detection}/processors.py +26 -14
  120. paddlex/inference/models/{3d_bev_detection → m_3d_bev_detection}/result.py +15 -12
  121. paddlex/inference/models/{3d_bev_detection → m_3d_bev_detection}/visualizer_3d.py +77 -39
  122. paddlex/inference/models/multilingual_speech_recognition/__init__.py +1 -1
  123. paddlex/inference/models/multilingual_speech_recognition/predictor.py +11 -15
  124. paddlex/inference/models/multilingual_speech_recognition/processors.py +45 -53
  125. paddlex/inference/models/multilingual_speech_recognition/result.py +1 -1
  126. paddlex/inference/models/object_detection/__init__.py +1 -1
  127. paddlex/inference/models/object_detection/predictor.py +6 -12
  128. paddlex/inference/models/object_detection/processors.py +36 -31
  129. paddlex/inference/models/object_detection/result.py +5 -4
  130. paddlex/inference/models/object_detection/utils.py +1 -1
  131. paddlex/inference/models/open_vocabulary_detection/__init__.py +1 -1
  132. paddlex/inference/models/open_vocabulary_detection/predictor.py +31 -14
  133. paddlex/inference/models/open_vocabulary_detection/processors/__init__.py +3 -2
  134. paddlex/inference/models/open_vocabulary_detection/processors/common.py +114 -0
  135. paddlex/inference/models/open_vocabulary_detection/processors/groundingdino_processors.py +19 -8
  136. paddlex/inference/models/open_vocabulary_detection/processors/yoloworld_processors.py +209 -0
  137. paddlex/inference/models/open_vocabulary_segmentation/__init__.py +1 -1
  138. paddlex/inference/models/open_vocabulary_segmentation/predictor.py +6 -13
  139. paddlex/inference/models/open_vocabulary_segmentation/processors/__init__.py +1 -1
  140. paddlex/inference/models/open_vocabulary_segmentation/processors/sam_processer.py +12 -12
  141. paddlex/inference/models/open_vocabulary_segmentation/results/__init__.py +1 -1
  142. paddlex/inference/models/open_vocabulary_segmentation/results/sam_result.py +11 -9
  143. paddlex/inference/models/semantic_segmentation/__init__.py +1 -1
  144. paddlex/inference/models/semantic_segmentation/predictor.py +9 -18
  145. paddlex/inference/models/semantic_segmentation/processors.py +11 -8
  146. paddlex/inference/models/semantic_segmentation/result.py +4 -3
  147. paddlex/inference/models/table_structure_recognition/__init__.py +1 -1
  148. paddlex/inference/models/table_structure_recognition/predictor.py +8 -18
  149. paddlex/inference/models/table_structure_recognition/processors.py +23 -29
  150. paddlex/inference/models/table_structure_recognition/result.py +9 -6
  151. paddlex/inference/models/text_detection/__init__.py +1 -1
  152. paddlex/inference/models/text_detection/predictor.py +16 -24
  153. paddlex/inference/models/text_detection/processors.py +74 -36
  154. paddlex/inference/models/text_detection/result.py +9 -4
  155. paddlex/inference/models/text_recognition/__init__.py +1 -1
  156. paddlex/inference/models/text_recognition/predictor.py +11 -19
  157. paddlex/inference/models/text_recognition/processors.py +27 -13
  158. paddlex/inference/models/text_recognition/result.py +3 -2
  159. paddlex/inference/models/ts_anomaly_detection/__init__.py +1 -1
  160. paddlex/inference/models/ts_anomaly_detection/predictor.py +12 -17
  161. paddlex/inference/models/ts_anomaly_detection/processors.py +6 -2
  162. paddlex/inference/models/ts_anomaly_detection/result.py +21 -10
  163. paddlex/inference/models/ts_classification/__init__.py +1 -1
  164. paddlex/inference/models/ts_classification/predictor.py +14 -27
  165. paddlex/inference/models/ts_classification/processors.py +7 -2
  166. paddlex/inference/models/ts_classification/result.py +21 -12
  167. paddlex/inference/models/ts_forecasting/__init__.py +1 -1
  168. paddlex/inference/models/ts_forecasting/predictor.py +13 -18
  169. paddlex/inference/models/ts_forecasting/processors.py +12 -3
  170. paddlex/inference/models/ts_forecasting/result.py +24 -11
  171. paddlex/inference/models/video_classification/__init__.py +1 -1
  172. paddlex/inference/models/video_classification/predictor.py +9 -15
  173. paddlex/inference/models/video_classification/processors.py +24 -24
  174. paddlex/inference/models/video_classification/result.py +7 -3
  175. paddlex/inference/models/video_detection/__init__.py +1 -1
  176. paddlex/inference/models/video_detection/predictor.py +8 -15
  177. paddlex/inference/models/video_detection/processors.py +24 -11
  178. paddlex/inference/models/video_detection/result.py +10 -5
  179. paddlex/inference/pipelines/__init__.py +44 -37
  180. paddlex/inference/pipelines/anomaly_detection/__init__.py +1 -1
  181. paddlex/inference/pipelines/anomaly_detection/pipeline.py +16 -6
  182. paddlex/inference/pipelines/attribute_recognition/__init__.py +1 -1
  183. paddlex/inference/pipelines/attribute_recognition/pipeline.py +13 -8
  184. paddlex/inference/pipelines/attribute_recognition/result.py +10 -8
  185. paddlex/inference/pipelines/base.py +31 -11
  186. paddlex/inference/pipelines/components/__init__.py +14 -8
  187. paddlex/inference/pipelines/components/chat_server/__init__.py +1 -1
  188. paddlex/inference/pipelines/components/chat_server/base.py +2 -2
  189. paddlex/inference/pipelines/components/chat_server/openai_bot_chat.py +8 -8
  190. paddlex/inference/pipelines/components/common/__init__.py +5 -4
  191. paddlex/inference/pipelines/components/common/base_operator.py +2 -1
  192. paddlex/inference/pipelines/components/common/base_result.py +3 -2
  193. paddlex/inference/pipelines/components/common/convert_points_and_boxes.py +1 -2
  194. paddlex/inference/pipelines/components/common/crop_image_regions.py +11 -5
  195. paddlex/inference/pipelines/components/common/seal_det_warp.py +44 -13
  196. paddlex/inference/pipelines/components/common/sort_boxes.py +4 -2
  197. paddlex/inference/pipelines/components/common/warp_image.py +50 -0
  198. paddlex/inference/pipelines/components/faisser.py +9 -4
  199. paddlex/inference/pipelines/components/prompt_engineering/__init__.py +2 -2
  200. paddlex/inference/pipelines/components/prompt_engineering/base.py +2 -2
  201. paddlex/inference/pipelines/components/prompt_engineering/generate_ensemble_prompt.py +2 -1
  202. paddlex/inference/pipelines/components/prompt_engineering/generate_kie_prompt.py +2 -2
  203. paddlex/inference/pipelines/components/retriever/__init__.py +2 -2
  204. paddlex/inference/pipelines/components/retriever/base.py +18 -16
  205. paddlex/inference/pipelines/components/retriever/openai_bot_retriever.py +2 -2
  206. paddlex/inference/pipelines/components/retriever/qianfan_bot_retriever.py +87 -84
  207. paddlex/inference/pipelines/components/utils/__init__.py +1 -1
  208. paddlex/inference/pipelines/components/utils/mixin.py +7 -7
  209. paddlex/inference/pipelines/doc_preprocessor/__init__.py +1 -1
  210. paddlex/inference/pipelines/doc_preprocessor/pipeline.py +21 -28
  211. paddlex/inference/pipelines/doc_preprocessor/result.py +5 -10
  212. paddlex/inference/pipelines/doc_understanding/__init__.py +15 -0
  213. paddlex/inference/pipelines/doc_understanding/pipeline.py +71 -0
  214. paddlex/inference/pipelines/face_recognition/__init__.py +1 -1
  215. paddlex/inference/pipelines/face_recognition/pipeline.py +3 -1
  216. paddlex/inference/pipelines/face_recognition/result.py +3 -2
  217. paddlex/inference/pipelines/formula_recognition/__init__.py +1 -1
  218. paddlex/inference/pipelines/formula_recognition/pipeline.py +22 -16
  219. paddlex/inference/pipelines/formula_recognition/result.py +20 -19
  220. paddlex/inference/pipelines/image_classification/__init__.py +1 -1
  221. paddlex/inference/pipelines/image_classification/pipeline.py +17 -8
  222. paddlex/inference/pipelines/image_multilabel_classification/__init__.py +1 -1
  223. paddlex/inference/pipelines/image_multilabel_classification/pipeline.py +18 -9
  224. paddlex/inference/pipelines/instance_segmentation/__init__.py +1 -1
  225. paddlex/inference/pipelines/instance_segmentation/pipeline.py +17 -6
  226. paddlex/inference/pipelines/keypoint_detection/__init__.py +1 -1
  227. paddlex/inference/pipelines/keypoint_detection/pipeline.py +17 -6
  228. paddlex/inference/pipelines/layout_parsing/__init__.py +1 -1
  229. paddlex/inference/pipelines/layout_parsing/pipeline.py +23 -12
  230. paddlex/inference/pipelines/layout_parsing/pipeline_v2.py +16 -6
  231. paddlex/inference/pipelines/layout_parsing/result.py +5 -4
  232. paddlex/inference/pipelines/layout_parsing/result_v2.py +5 -8
  233. paddlex/inference/pipelines/layout_parsing/utils.py +7 -8
  234. paddlex/inference/pipelines/{3d_bev_detection → m_3d_bev_detection}/__init__.py +1 -1
  235. paddlex/inference/pipelines/{3d_bev_detection → m_3d_bev_detection}/pipeline.py +17 -10
  236. paddlex/inference/pipelines/multilingual_speech_recognition/__init__.py +1 -1
  237. paddlex/inference/pipelines/multilingual_speech_recognition/pipeline.py +17 -6
  238. paddlex/inference/pipelines/object_detection/__init__.py +1 -1
  239. paddlex/inference/pipelines/object_detection/pipeline.py +16 -6
  240. paddlex/inference/pipelines/ocr/__init__.py +1 -1
  241. paddlex/inference/pipelines/ocr/pipeline.py +28 -11
  242. paddlex/inference/pipelines/ocr/result.py +13 -9
  243. paddlex/inference/pipelines/open_vocabulary_detection/__init__.py +1 -1
  244. paddlex/inference/pipelines/open_vocabulary_detection/pipeline.py +17 -6
  245. paddlex/inference/pipelines/open_vocabulary_segmentation/__init__.py +1 -1
  246. paddlex/inference/pipelines/open_vocabulary_segmentation/pipeline.py +17 -6
  247. paddlex/inference/pipelines/pp_chatocr/__init__.py +1 -1
  248. paddlex/inference/pipelines/pp_chatocr/pipeline_base.py +14 -5
  249. paddlex/inference/pipelines/pp_chatocr/pipeline_v3.py +22 -11
  250. paddlex/inference/pipelines/pp_chatocr/pipeline_v4.py +31 -13
  251. paddlex/inference/pipelines/pp_shitu_v2/__init__.py +1 -1
  252. paddlex/inference/pipelines/pp_shitu_v2/pipeline.py +12 -8
  253. paddlex/inference/pipelines/pp_shitu_v2/result.py +4 -4
  254. paddlex/inference/pipelines/rotated_object_detection/__init__.py +1 -1
  255. paddlex/inference/pipelines/rotated_object_detection/pipeline.py +17 -6
  256. paddlex/inference/pipelines/seal_recognition/__init__.py +1 -1
  257. paddlex/inference/pipelines/seal_recognition/pipeline.py +21 -13
  258. paddlex/inference/pipelines/seal_recognition/result.py +4 -2
  259. paddlex/inference/pipelines/semantic_segmentation/__init__.py +1 -1
  260. paddlex/inference/pipelines/semantic_segmentation/pipeline.py +17 -6
  261. paddlex/inference/pipelines/small_object_detection/__init__.py +1 -1
  262. paddlex/inference/pipelines/small_object_detection/pipeline.py +17 -6
  263. paddlex/inference/pipelines/table_recognition/__init__.py +1 -1
  264. paddlex/inference/pipelines/table_recognition/pipeline.py +41 -25
  265. paddlex/inference/pipelines/table_recognition/pipeline_v2.py +65 -33
  266. paddlex/inference/pipelines/table_recognition/result.py +11 -9
  267. paddlex/inference/pipelines/table_recognition/table_recognition_post_processing.py +12 -8
  268. paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py +46 -32
  269. paddlex/inference/pipelines/table_recognition/utils.py +1 -1
  270. paddlex/inference/pipelines/ts_anomaly_detection/__init__.py +1 -1
  271. paddlex/inference/pipelines/ts_anomaly_detection/pipeline.py +16 -6
  272. paddlex/inference/pipelines/ts_classification/__init__.py +1 -1
  273. paddlex/inference/pipelines/ts_classification/pipeline.py +16 -6
  274. paddlex/inference/pipelines/ts_forecasting/__init__.py +1 -1
  275. paddlex/inference/pipelines/ts_forecasting/pipeline.py +16 -6
  276. paddlex/inference/pipelines/video_classification/__init__.py +1 -1
  277. paddlex/inference/pipelines/video_classification/pipeline.py +17 -6
  278. paddlex/inference/pipelines/video_detection/__init__.py +1 -1
  279. paddlex/inference/pipelines/video_detection/pipeline.py +20 -7
  280. paddlex/inference/serving/__init__.py +5 -1
  281. paddlex/inference/serving/basic_serving/__init__.py +1 -1
  282. paddlex/inference/serving/basic_serving/_app.py +31 -19
  283. paddlex/inference/serving/basic_serving/_pipeline_apps/__init__.py +7 -4
  284. paddlex/inference/serving/basic_serving/_pipeline_apps/_common/__init__.py +1 -1
  285. paddlex/inference/serving/basic_serving/_pipeline_apps/_common/common.py +7 -3
  286. paddlex/inference/serving/basic_serving/_pipeline_apps/_common/image_recognition.py +1 -1
  287. paddlex/inference/serving/basic_serving/_pipeline_apps/_common/ocr.py +7 -2
  288. paddlex/inference/serving/basic_serving/_pipeline_apps/anomaly_detection.py +10 -7
  289. paddlex/inference/serving/basic_serving/_pipeline_apps/doc_preprocessor.py +10 -7
  290. paddlex/inference/serving/basic_serving/_pipeline_apps/doc_understanding.py +153 -0
  291. paddlex/inference/serving/basic_serving/_pipeline_apps/face_recognition.py +16 -13
  292. paddlex/inference/serving/basic_serving/_pipeline_apps/formula_recognition.py +10 -7
  293. paddlex/inference/serving/basic_serving/_pipeline_apps/human_keypoint_detection.py +10 -7
  294. paddlex/inference/serving/basic_serving/_pipeline_apps/image_classification.py +10 -7
  295. paddlex/inference/serving/basic_serving/_pipeline_apps/image_multilabel_classification.py +10 -7
  296. paddlex/inference/serving/basic_serving/_pipeline_apps/instance_segmentation.py +13 -7
  297. paddlex/inference/serving/basic_serving/_pipeline_apps/layout_parsing.py +10 -7
  298. paddlex/inference/serving/basic_serving/_pipeline_apps/m_3d_bev_detection.py +10 -7
  299. paddlex/inference/serving/basic_serving/_pipeline_apps/multilingual_speech_recognition.py +10 -7
  300. paddlex/inference/serving/basic_serving/_pipeline_apps/object_detection.py +10 -7
  301. paddlex/inference/serving/basic_serving/_pipeline_apps/ocr.py +10 -7
  302. paddlex/inference/serving/basic_serving/_pipeline_apps/open_vocabulary_detection.py +10 -7
  303. paddlex/inference/serving/basic_serving/_pipeline_apps/open_vocabulary_segmentation.py +13 -7
  304. paddlex/inference/serving/basic_serving/_pipeline_apps/pedestrian_attribute_recognition.py +10 -7
  305. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv3_doc.py +14 -11
  306. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv4_doc.py +16 -13
  307. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_shituv2.py +16 -13
  308. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_structurev3.py +10 -7
  309. paddlex/inference/serving/basic_serving/_pipeline_apps/rotated_object_detection.py +10 -7
  310. paddlex/inference/serving/basic_serving/_pipeline_apps/seal_recognition.py +10 -7
  311. paddlex/inference/serving/basic_serving/_pipeline_apps/semantic_segmentation.py +10 -7
  312. paddlex/inference/serving/basic_serving/_pipeline_apps/small_object_detection.py +10 -7
  313. paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition.py +10 -7
  314. paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition_v2.py +10 -7
  315. paddlex/inference/serving/basic_serving/_pipeline_apps/ts_anomaly_detection.py +10 -7
  316. paddlex/inference/serving/basic_serving/_pipeline_apps/ts_classification.py +10 -7
  317. paddlex/inference/serving/basic_serving/_pipeline_apps/ts_forecast.py +10 -7
  318. paddlex/inference/serving/basic_serving/_pipeline_apps/vehicle_attribute_recognition.py +10 -7
  319. paddlex/inference/serving/basic_serving/_pipeline_apps/video_classification.py +10 -7
  320. paddlex/inference/serving/basic_serving/_pipeline_apps/video_detection.py +10 -7
  321. paddlex/inference/serving/basic_serving/_server.py +9 -4
  322. paddlex/inference/serving/infra/__init__.py +1 -1
  323. paddlex/inference/serving/infra/config.py +1 -1
  324. paddlex/inference/serving/infra/models.py +13 -6
  325. paddlex/inference/serving/infra/storage.py +9 -4
  326. paddlex/inference/serving/infra/utils.py +37 -9
  327. paddlex/inference/serving/schemas/__init__.py +1 -1
  328. paddlex/inference/serving/schemas/anomaly_detection.py +1 -1
  329. paddlex/inference/serving/schemas/doc_preprocessor.py +1 -1
  330. paddlex/inference/serving/schemas/doc_understanding.py +78 -0
  331. paddlex/inference/serving/schemas/face_recognition.py +1 -1
  332. paddlex/inference/serving/schemas/formula_recognition.py +1 -1
  333. paddlex/inference/serving/schemas/human_keypoint_detection.py +1 -1
  334. paddlex/inference/serving/schemas/image_classification.py +1 -1
  335. paddlex/inference/serving/schemas/image_multilabel_classification.py +1 -1
  336. paddlex/inference/serving/schemas/instance_segmentation.py +1 -1
  337. paddlex/inference/serving/schemas/layout_parsing.py +1 -1
  338. paddlex/inference/serving/schemas/m_3d_bev_detection.py +1 -1
  339. paddlex/inference/serving/schemas/multilingual_speech_recognition.py +1 -1
  340. paddlex/inference/serving/schemas/object_detection.py +1 -1
  341. paddlex/inference/serving/schemas/ocr.py +1 -1
  342. paddlex/inference/serving/schemas/open_vocabulary_detection.py +1 -1
  343. paddlex/inference/serving/schemas/open_vocabulary_segmentation.py +1 -1
  344. paddlex/inference/serving/schemas/pedestrian_attribute_recognition.py +1 -1
  345. paddlex/inference/serving/schemas/pp_chatocrv3_doc.py +1 -1
  346. paddlex/inference/serving/schemas/pp_chatocrv4_doc.py +1 -1
  347. paddlex/inference/serving/schemas/pp_shituv2.py +1 -1
  348. paddlex/inference/serving/schemas/pp_structurev3.py +1 -1
  349. paddlex/inference/serving/schemas/rotated_object_detection.py +1 -1
  350. paddlex/inference/serving/schemas/seal_recognition.py +1 -1
  351. paddlex/inference/serving/schemas/semantic_segmentation.py +1 -1
  352. paddlex/inference/serving/schemas/shared/__init__.py +1 -1
  353. paddlex/inference/serving/schemas/shared/classification.py +1 -1
  354. paddlex/inference/serving/schemas/shared/image_segmentation.py +1 -1
  355. paddlex/inference/serving/schemas/shared/object_detection.py +1 -1
  356. paddlex/inference/serving/schemas/shared/ocr.py +1 -1
  357. paddlex/inference/serving/schemas/small_object_detection.py +1 -1
  358. paddlex/inference/serving/schemas/table_recognition.py +1 -1
  359. paddlex/inference/serving/schemas/table_recognition_v2.py +1 -1
  360. paddlex/inference/serving/schemas/ts_anomaly_detection.py +1 -1
  361. paddlex/inference/serving/schemas/ts_classification.py +1 -1
  362. paddlex/inference/serving/schemas/ts_forecast.py +1 -1
  363. paddlex/inference/serving/schemas/vehicle_attribute_recognition.py +1 -1
  364. paddlex/inference/serving/schemas/video_classification.py +1 -1
  365. paddlex/inference/serving/schemas/video_detection.py +1 -1
  366. paddlex/inference/utils/__init__.py +1 -1
  367. paddlex/inference/utils/benchmark.py +332 -179
  368. paddlex/inference/utils/color_map.py +1 -1
  369. paddlex/inference/utils/get_pipeline_path.py +1 -1
  370. paddlex/inference/utils/hpi.py +251 -0
  371. paddlex/inference/utils/hpi_model_info_collection.json +2252 -0
  372. paddlex/inference/utils/io/__init__.py +11 -11
  373. paddlex/inference/utils/io/readers.py +22 -18
  374. paddlex/inference/utils/io/style.py +21 -14
  375. paddlex/inference/utils/io/tablepyxl.py +13 -5
  376. paddlex/inference/utils/io/writers.py +9 -10
  377. paddlex/inference/utils/model_paths.py +48 -0
  378. paddlex/inference/utils/{new_ir_blacklist.py → new_ir_blocklist.py} +1 -2
  379. paddlex/inference/utils/official_models.py +264 -262
  380. paddlex/inference/utils/pp_option.py +164 -93
  381. paddlex/inference/utils/trt_blocklist.py +43 -0
  382. paddlex/inference/utils/trt_config.py +420 -0
  383. paddlex/model.py +28 -10
  384. paddlex/modules/__init__.py +57 -80
  385. paddlex/modules/anomaly_detection/__init__.py +2 -2
  386. paddlex/modules/anomaly_detection/dataset_checker/__init__.py +2 -3
  387. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/__init__.py +2 -2
  388. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/analyse_dataset.py +6 -3
  389. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/check_dataset.py +8 -4
  390. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/convert_dataset.py +7 -4
  391. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/split_dataset.py +2 -2
  392. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/utils/__init__.py +1 -1
  393. paddlex/modules/anomaly_detection/dataset_checker/dataset_src/utils/visualizer.py +7 -2
  394. paddlex/modules/anomaly_detection/evaluator.py +1 -1
  395. paddlex/modules/anomaly_detection/exportor.py +1 -1
  396. paddlex/modules/anomaly_detection/model_list.py +1 -1
  397. paddlex/modules/anomaly_detection/trainer.py +3 -4
  398. paddlex/modules/base/__init__.py +5 -5
  399. paddlex/modules/base/build_model.py +1 -2
  400. paddlex/modules/base/dataset_checker/__init__.py +2 -2
  401. paddlex/modules/base/dataset_checker/dataset_checker.py +4 -4
  402. paddlex/modules/base/dataset_checker/utils.py +1 -3
  403. paddlex/modules/base/evaluator.py +8 -8
  404. paddlex/modules/base/exportor.py +12 -13
  405. paddlex/modules/base/trainer.py +21 -11
  406. paddlex/modules/base/utils/__init__.py +13 -0
  407. paddlex/modules/base/utils/cinn_setting.py +89 -0
  408. paddlex/modules/base/utils/coco_eval.py +94 -0
  409. paddlex/modules/base/utils/topk_eval.py +118 -0
  410. paddlex/modules/doc_vlm/__init__.py +18 -0
  411. paddlex/modules/doc_vlm/dataset_checker.py +29 -0
  412. paddlex/modules/doc_vlm/evaluator.py +29 -0
  413. paddlex/modules/doc_vlm/exportor.py +29 -0
  414. paddlex/modules/doc_vlm/model_list.py +16 -0
  415. paddlex/modules/doc_vlm/trainer.py +41 -0
  416. paddlex/modules/face_recognition/__init__.py +2 -2
  417. paddlex/modules/face_recognition/dataset_checker/__init__.py +2 -2
  418. paddlex/modules/face_recognition/dataset_checker/dataset_src/__init__.py +1 -1
  419. paddlex/modules/face_recognition/dataset_checker/dataset_src/check_dataset.py +3 -5
  420. paddlex/modules/face_recognition/dataset_checker/dataset_src/utils/__init__.py +1 -1
  421. paddlex/modules/face_recognition/dataset_checker/dataset_src/utils/visualizer.py +2 -5
  422. paddlex/modules/face_recognition/evaluator.py +1 -1
  423. paddlex/modules/face_recognition/exportor.py +1 -1
  424. paddlex/modules/face_recognition/model_list.py +1 -1
  425. paddlex/modules/face_recognition/trainer.py +1 -1
  426. paddlex/modules/formula_recognition/__init__.py +2 -2
  427. paddlex/modules/formula_recognition/dataset_checker/__init__.py +3 -3
  428. paddlex/modules/formula_recognition/dataset_checker/dataset_src/__init__.py +2 -2
  429. paddlex/modules/formula_recognition/dataset_checker/dataset_src/analyse_dataset.py +13 -12
  430. paddlex/modules/formula_recognition/dataset_checker/dataset_src/check_dataset.py +2 -6
  431. paddlex/modules/formula_recognition/dataset_checker/dataset_src/convert_dataset.py +11 -10
  432. paddlex/modules/formula_recognition/dataset_checker/dataset_src/split_dataset.py +1 -2
  433. paddlex/modules/formula_recognition/evaluator.py +1 -1
  434. paddlex/modules/formula_recognition/exportor.py +1 -1
  435. paddlex/modules/formula_recognition/model_list.py +1 -1
  436. paddlex/modules/formula_recognition/trainer.py +2 -3
  437. paddlex/modules/general_recognition/__init__.py +2 -2
  438. paddlex/modules/general_recognition/dataset_checker/__init__.py +2 -2
  439. paddlex/modules/general_recognition/dataset_checker/dataset_src/__init__.py +2 -2
  440. paddlex/modules/general_recognition/dataset_checker/dataset_src/analyse_dataset.py +7 -9
  441. paddlex/modules/general_recognition/dataset_checker/dataset_src/check_dataset.py +4 -5
  442. paddlex/modules/general_recognition/dataset_checker/dataset_src/convert_dataset.py +6 -5
  443. paddlex/modules/general_recognition/dataset_checker/dataset_src/split_dataset.py +1 -1
  444. paddlex/modules/general_recognition/dataset_checker/dataset_src/utils/__init__.py +1 -1
  445. paddlex/modules/general_recognition/dataset_checker/dataset_src/utils/visualizer.py +2 -5
  446. paddlex/modules/general_recognition/evaluator.py +1 -1
  447. paddlex/modules/general_recognition/exportor.py +1 -1
  448. paddlex/modules/general_recognition/model_list.py +1 -1
  449. paddlex/modules/general_recognition/trainer.py +1 -1
  450. paddlex/modules/image_classification/__init__.py +2 -2
  451. paddlex/modules/image_classification/dataset_checker/__init__.py +2 -2
  452. paddlex/modules/image_classification/dataset_checker/dataset_src/__init__.py +2 -2
  453. paddlex/modules/image_classification/dataset_checker/dataset_src/analyse_dataset.py +8 -9
  454. paddlex/modules/image_classification/dataset_checker/dataset_src/check_dataset.py +4 -3
  455. paddlex/modules/image_classification/dataset_checker/dataset_src/convert_dataset.py +4 -4
  456. paddlex/modules/image_classification/dataset_checker/dataset_src/split_dataset.py +1 -1
  457. paddlex/modules/image_classification/dataset_checker/dataset_src/utils/__init__.py +1 -1
  458. paddlex/modules/image_classification/dataset_checker/dataset_src/utils/visualizer.py +2 -5
  459. paddlex/modules/image_classification/evaluator.py +1 -1
  460. paddlex/modules/image_classification/exportor.py +1 -1
  461. paddlex/modules/image_classification/model_list.py +1 -1
  462. paddlex/modules/image_classification/trainer.py +3 -3
  463. paddlex/modules/image_unwarping/__init__.py +1 -1
  464. paddlex/modules/image_unwarping/model_list.py +1 -1
  465. paddlex/modules/instance_segmentation/__init__.py +2 -2
  466. paddlex/modules/instance_segmentation/dataset_checker/__init__.py +2 -3
  467. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/__init__.py +2 -2
  468. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/analyse_dataset.py +9 -5
  469. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/check_dataset.py +8 -5
  470. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/convert_dataset.py +8 -8
  471. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/split_dataset.py +7 -4
  472. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/utils/__init__.py +1 -1
  473. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/utils/visualizer.py +10 -8
  474. paddlex/modules/instance_segmentation/evaluator.py +1 -1
  475. paddlex/modules/instance_segmentation/exportor.py +1 -1
  476. paddlex/modules/instance_segmentation/model_list.py +1 -1
  477. paddlex/modules/instance_segmentation/trainer.py +1 -1
  478. paddlex/modules/keypoint_detection/__init__.py +2 -2
  479. paddlex/modules/keypoint_detection/dataset_checker/__init__.py +2 -2
  480. paddlex/modules/keypoint_detection/dataset_checker/dataset_src/__init__.py +1 -1
  481. paddlex/modules/keypoint_detection/dataset_checker/dataset_src/check_dataset.py +10 -5
  482. paddlex/modules/keypoint_detection/dataset_checker/dataset_src/utils/__init__.py +1 -1
  483. paddlex/modules/keypoint_detection/dataset_checker/dataset_src/utils/visualizer.py +8 -3
  484. paddlex/modules/keypoint_detection/evaluator.py +1 -1
  485. paddlex/modules/keypoint_detection/exportor.py +1 -1
  486. paddlex/modules/keypoint_detection/model_list.py +1 -1
  487. paddlex/modules/keypoint_detection/trainer.py +2 -2
  488. paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/__init__.py +2 -2
  489. paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/dataset_checker/__init__.py +3 -3
  490. paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/dataset_checker/dataset_src/__init__.py +2 -2
  491. paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/dataset_checker/dataset_src/analyse_dataset.py +8 -8
  492. paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/dataset_checker/dataset_src/check_dataset.py +1 -2
  493. paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/evaluator.py +1 -1
  494. paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/exportor.py +1 -1
  495. paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/model_list.py +1 -1
  496. paddlex/modules/{3d_bev_detection → m_3d_bev_detection}/trainer.py +5 -7
  497. paddlex/modules/multilabel_classification/__init__.py +2 -2
  498. paddlex/modules/multilabel_classification/dataset_checker/__init__.py +2 -2
  499. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/__init__.py +2 -2
  500. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/analyse_dataset.py +8 -9
  501. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/check_dataset.py +4 -3
  502. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/convert_dataset.py +10 -7
  503. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/split_dataset.py +1 -1
  504. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/utils/__init__.py +1 -1
  505. paddlex/modules/multilabel_classification/dataset_checker/dataset_src/utils/visualizer.py +1 -5
  506. paddlex/modules/multilabel_classification/evaluator.py +1 -1
  507. paddlex/modules/multilabel_classification/exportor.py +1 -1
  508. paddlex/modules/multilabel_classification/model_list.py +1 -1
  509. paddlex/modules/multilabel_classification/trainer.py +3 -3
  510. paddlex/modules/multilingual_speech_recognition/__init__.py +2 -2
  511. paddlex/modules/multilingual_speech_recognition/dataset_checker.py +3 -3
  512. paddlex/modules/multilingual_speech_recognition/evaluator.py +3 -3
  513. paddlex/modules/multilingual_speech_recognition/exportor.py +3 -3
  514. paddlex/modules/multilingual_speech_recognition/model_list.py +1 -1
  515. paddlex/modules/multilingual_speech_recognition/trainer.py +7 -5
  516. paddlex/modules/object_detection/__init__.py +2 -2
  517. paddlex/modules/object_detection/dataset_checker/__init__.py +2 -11
  518. paddlex/modules/object_detection/dataset_checker/dataset_src/__init__.py +2 -2
  519. paddlex/modules/object_detection/dataset_checker/dataset_src/analyse_dataset.py +10 -8
  520. paddlex/modules/object_detection/dataset_checker/dataset_src/check_dataset.py +10 -5
  521. paddlex/modules/object_detection/dataset_checker/dataset_src/convert_dataset.py +13 -8
  522. paddlex/modules/object_detection/dataset_checker/dataset_src/split_dataset.py +8 -4
  523. paddlex/modules/object_detection/dataset_checker/dataset_src/utils/__init__.py +1 -1
  524. paddlex/modules/object_detection/dataset_checker/dataset_src/utils/visualizer.py +9 -8
  525. paddlex/modules/object_detection/evaluator.py +9 -4
  526. paddlex/modules/object_detection/exportor.py +1 -1
  527. paddlex/modules/object_detection/model_list.py +1 -1
  528. paddlex/modules/object_detection/trainer.py +4 -5
  529. paddlex/modules/open_vocabulary_detection/__init__.py +2 -2
  530. paddlex/modules/open_vocabulary_detection/dataset_checker.py +3 -3
  531. paddlex/modules/open_vocabulary_detection/evaluator.py +3 -3
  532. paddlex/modules/open_vocabulary_detection/exportor.py +3 -3
  533. paddlex/modules/open_vocabulary_detection/model_list.py +2 -4
  534. paddlex/modules/open_vocabulary_detection/trainer.py +7 -5
  535. paddlex/modules/open_vocabulary_segmentation/__init__.py +2 -2
  536. paddlex/modules/open_vocabulary_segmentation/dataset_checker.py +3 -3
  537. paddlex/modules/open_vocabulary_segmentation/evaluator.py +3 -3
  538. paddlex/modules/open_vocabulary_segmentation/exportor.py +3 -3
  539. paddlex/modules/open_vocabulary_segmentation/model_list.py +1 -1
  540. paddlex/modules/open_vocabulary_segmentation/trainer.py +7 -5
  541. paddlex/modules/semantic_segmentation/__init__.py +2 -2
  542. paddlex/modules/semantic_segmentation/dataset_checker/__init__.py +2 -3
  543. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/__init__.py +2 -2
  544. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/analyse_dataset.py +6 -3
  545. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/check_dataset.py +2 -2
  546. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/convert_dataset.py +7 -4
  547. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/split_dataset.py +2 -2
  548. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/utils/__init__.py +1 -1
  549. paddlex/modules/semantic_segmentation/dataset_checker/dataset_src/utils/visualizer.py +6 -2
  550. paddlex/modules/semantic_segmentation/evaluator.py +1 -1
  551. paddlex/modules/semantic_segmentation/exportor.py +1 -1
  552. paddlex/modules/semantic_segmentation/model_list.py +1 -1
  553. paddlex/modules/semantic_segmentation/trainer.py +3 -4
  554. paddlex/modules/table_recognition/__init__.py +2 -2
  555. paddlex/modules/table_recognition/dataset_checker/__init__.py +5 -5
  556. paddlex/modules/table_recognition/dataset_checker/dataset_src/__init__.py +2 -2
  557. paddlex/modules/table_recognition/dataset_checker/dataset_src/analyse_dataset.py +3 -2
  558. paddlex/modules/table_recognition/dataset_checker/dataset_src/check_dataset.py +8 -7
  559. paddlex/modules/table_recognition/dataset_checker/dataset_src/split_dataset.py +2 -1
  560. paddlex/modules/table_recognition/evaluator.py +1 -1
  561. paddlex/modules/table_recognition/exportor.py +1 -1
  562. paddlex/modules/table_recognition/model_list.py +1 -1
  563. paddlex/modules/table_recognition/trainer.py +2 -5
  564. paddlex/modules/text_detection/__init__.py +2 -2
  565. paddlex/modules/text_detection/dataset_checker/__init__.py +4 -6
  566. paddlex/modules/text_detection/dataset_checker/dataset_src/__init__.py +2 -2
  567. paddlex/modules/text_detection/dataset_checker/dataset_src/analyse_dataset.py +12 -9
  568. paddlex/modules/text_detection/dataset_checker/dataset_src/check_dataset.py +3 -3
  569. paddlex/modules/text_detection/dataset_checker/dataset_src/split_dataset.py +3 -3
  570. paddlex/modules/text_detection/evaluator.py +1 -1
  571. paddlex/modules/text_detection/exportor.py +1 -1
  572. paddlex/modules/text_detection/model_list.py +1 -1
  573. paddlex/modules/text_detection/trainer.py +2 -5
  574. paddlex/modules/text_recognition/__init__.py +2 -2
  575. paddlex/modules/text_recognition/dataset_checker/__init__.py +4 -5
  576. paddlex/modules/text_recognition/dataset_checker/dataset_src/__init__.py +2 -2
  577. paddlex/modules/text_recognition/dataset_checker/dataset_src/analyse_dataset.py +13 -12
  578. paddlex/modules/text_recognition/dataset_checker/dataset_src/check_dataset.py +2 -5
  579. paddlex/modules/text_recognition/dataset_checker/dataset_src/convert_dataset.py +11 -10
  580. paddlex/modules/text_recognition/dataset_checker/dataset_src/split_dataset.py +1 -2
  581. paddlex/modules/text_recognition/evaluator.py +1 -1
  582. paddlex/modules/text_recognition/exportor.py +1 -1
  583. paddlex/modules/text_recognition/model_list.py +1 -1
  584. paddlex/modules/text_recognition/trainer.py +2 -3
  585. paddlex/modules/ts_anomaly_detection/__init__.py +2 -2
  586. paddlex/modules/ts_anomaly_detection/dataset_checker/__init__.py +4 -5
  587. paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/__init__.py +2 -2
  588. paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/analyse_dataset.py +1 -9
  589. paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/check_dataset.py +2 -2
  590. paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/convert_dataset.py +2 -6
  591. paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/split_dataset.py +4 -4
  592. paddlex/modules/ts_anomaly_detection/evaluator.py +1 -1
  593. paddlex/modules/ts_anomaly_detection/exportor.py +2 -3
  594. paddlex/modules/ts_anomaly_detection/model_list.py +1 -1
  595. paddlex/modules/ts_anomaly_detection/trainer.py +8 -8
  596. paddlex/modules/ts_classification/__init__.py +2 -2
  597. paddlex/modules/ts_classification/dataset_checker/__init__.py +4 -5
  598. paddlex/modules/ts_classification/dataset_checker/dataset_src/__init__.py +2 -2
  599. paddlex/modules/ts_classification/dataset_checker/dataset_src/analyse_dataset.py +8 -5
  600. paddlex/modules/ts_classification/dataset_checker/dataset_src/check_dataset.py +2 -2
  601. paddlex/modules/ts_classification/dataset_checker/dataset_src/convert_dataset.py +2 -6
  602. paddlex/modules/ts_classification/dataset_checker/dataset_src/split_dataset.py +4 -4
  603. paddlex/modules/ts_classification/evaluator.py +1 -1
  604. paddlex/modules/ts_classification/exportor.py +2 -3
  605. paddlex/modules/ts_classification/model_list.py +1 -1
  606. paddlex/modules/ts_classification/trainer.py +7 -7
  607. paddlex/modules/ts_forecast/__init__.py +2 -2
  608. paddlex/modules/ts_forecast/dataset_checker/__init__.py +4 -5
  609. paddlex/modules/ts_forecast/dataset_checker/dataset_src/__init__.py +2 -2
  610. paddlex/modules/ts_forecast/dataset_checker/dataset_src/analyse_dataset.py +1 -9
  611. paddlex/modules/ts_forecast/dataset_checker/dataset_src/check_dataset.py +2 -2
  612. paddlex/modules/ts_forecast/dataset_checker/dataset_src/convert_dataset.py +2 -6
  613. paddlex/modules/ts_forecast/dataset_checker/dataset_src/split_dataset.py +4 -4
  614. paddlex/modules/ts_forecast/evaluator.py +1 -1
  615. paddlex/modules/ts_forecast/exportor.py +2 -3
  616. paddlex/modules/ts_forecast/model_list.py +1 -1
  617. paddlex/modules/ts_forecast/trainer.py +7 -7
  618. paddlex/modules/video_classification/__init__.py +2 -2
  619. paddlex/modules/video_classification/dataset_checker/__init__.py +2 -2
  620. paddlex/modules/video_classification/dataset_checker/dataset_src/__init__.py +2 -2
  621. paddlex/modules/video_classification/dataset_checker/dataset_src/analyse_dataset.py +9 -9
  622. paddlex/modules/video_classification/dataset_checker/dataset_src/check_dataset.py +2 -3
  623. paddlex/modules/video_classification/dataset_checker/dataset_src/split_dataset.py +1 -1
  624. paddlex/modules/video_classification/evaluator.py +1 -1
  625. paddlex/modules/video_classification/exportor.py +1 -1
  626. paddlex/modules/video_classification/model_list.py +1 -1
  627. paddlex/modules/video_classification/trainer.py +3 -3
  628. paddlex/modules/video_detection/__init__.py +2 -2
  629. paddlex/modules/video_detection/dataset_checker/__init__.py +2 -2
  630. paddlex/modules/video_detection/dataset_checker/dataset_src/__init__.py +2 -2
  631. paddlex/modules/video_detection/dataset_checker/dataset_src/analyse_dataset.py +8 -9
  632. paddlex/modules/video_detection/dataset_checker/dataset_src/check_dataset.py +3 -5
  633. paddlex/modules/video_detection/evaluator.py +1 -1
  634. paddlex/modules/video_detection/exportor.py +1 -1
  635. paddlex/modules/video_detection/model_list.py +1 -1
  636. paddlex/modules/video_detection/trainer.py +3 -3
  637. paddlex/ops/__init__.py +5 -2
  638. paddlex/ops/iou3d_nms/iou3d_cpu.cpp +8 -6
  639. paddlex/ops/iou3d_nms/iou3d_cpu.h +3 -2
  640. paddlex/ops/iou3d_nms/iou3d_nms.cpp +8 -6
  641. paddlex/ops/iou3d_nms/iou3d_nms.h +6 -4
  642. paddlex/ops/iou3d_nms/iou3d_nms_api.cpp +24 -18
  643. paddlex/ops/iou3d_nms/iou3d_nms_kernel.cu +9 -7
  644. paddlex/ops/setup.py +3 -3
  645. paddlex/ops/voxel/voxelize_op.cc +22 -19
  646. paddlex/ops/voxel/voxelize_op.cu +25 -25
  647. paddlex/paddlex_cli.py +86 -75
  648. paddlex/repo_apis/Paddle3D_api/__init__.py +1 -1
  649. paddlex/repo_apis/Paddle3D_api/bev_fusion/__init__.py +1 -1
  650. paddlex/repo_apis/Paddle3D_api/bev_fusion/config.py +1 -1
  651. paddlex/repo_apis/Paddle3D_api/bev_fusion/model.py +4 -4
  652. paddlex/repo_apis/Paddle3D_api/bev_fusion/register.py +2 -2
  653. paddlex/repo_apis/Paddle3D_api/bev_fusion/runner.py +1 -1
  654. paddlex/repo_apis/Paddle3D_api/pp3d_config.py +3 -2
  655. paddlex/repo_apis/PaddleClas_api/__init__.py +1 -1
  656. paddlex/repo_apis/PaddleClas_api/cls/__init__.py +3 -3
  657. paddlex/repo_apis/PaddleClas_api/cls/config.py +4 -3
  658. paddlex/repo_apis/PaddleClas_api/cls/model.py +3 -3
  659. paddlex/repo_apis/PaddleClas_api/cls/register.py +2 -3
  660. paddlex/repo_apis/PaddleClas_api/cls/runner.py +1 -2
  661. paddlex/repo_apis/PaddleClas_api/shitu_rec/__init__.py +2 -2
  662. paddlex/repo_apis/PaddleClas_api/shitu_rec/config.py +2 -2
  663. paddlex/repo_apis/PaddleClas_api/shitu_rec/model.py +1 -4
  664. paddlex/repo_apis/PaddleClas_api/shitu_rec/register.py +2 -2
  665. paddlex/repo_apis/PaddleClas_api/shitu_rec/runner.py +1 -6
  666. paddlex/repo_apis/PaddleDetection_api/__init__.py +2 -2
  667. paddlex/repo_apis/PaddleDetection_api/config_helper.py +3 -3
  668. paddlex/repo_apis/PaddleDetection_api/instance_seg/__init__.py +2 -2
  669. paddlex/repo_apis/PaddleDetection_api/instance_seg/config.py +2 -3
  670. paddlex/repo_apis/PaddleDetection_api/instance_seg/model.py +3 -3
  671. paddlex/repo_apis/PaddleDetection_api/instance_seg/register.py +2 -3
  672. paddlex/repo_apis/PaddleDetection_api/instance_seg/runner.py +1 -2
  673. paddlex/repo_apis/PaddleDetection_api/object_det/__init__.py +3 -3
  674. paddlex/repo_apis/PaddleDetection_api/object_det/config.py +4 -3
  675. paddlex/repo_apis/PaddleDetection_api/object_det/model.py +5 -6
  676. paddlex/repo_apis/PaddleDetection_api/object_det/official_categories.py +1 -1
  677. paddlex/repo_apis/PaddleDetection_api/object_det/register.py +2 -3
  678. paddlex/repo_apis/PaddleDetection_api/object_det/runner.py +1 -2
  679. paddlex/repo_apis/PaddleNLP_api/__init__.py +1 -1
  680. paddlex/repo_apis/PaddleOCR_api/__init__.py +4 -3
  681. paddlex/repo_apis/PaddleOCR_api/config_utils.py +1 -1
  682. paddlex/repo_apis/PaddleOCR_api/formula_rec/__init__.py +1 -1
  683. paddlex/repo_apis/PaddleOCR_api/formula_rec/config.py +4 -3
  684. paddlex/repo_apis/PaddleOCR_api/formula_rec/model.py +4 -4
  685. paddlex/repo_apis/PaddleOCR_api/formula_rec/register.py +2 -3
  686. paddlex/repo_apis/PaddleOCR_api/formula_rec/runner.py +1 -2
  687. paddlex/repo_apis/PaddleOCR_api/table_rec/__init__.py +1 -1
  688. paddlex/repo_apis/PaddleOCR_api/table_rec/config.py +1 -1
  689. paddlex/repo_apis/PaddleOCR_api/table_rec/model.py +3 -3
  690. paddlex/repo_apis/PaddleOCR_api/table_rec/register.py +2 -3
  691. paddlex/repo_apis/PaddleOCR_api/table_rec/runner.py +2 -2
  692. paddlex/repo_apis/PaddleOCR_api/text_det/__init__.py +1 -1
  693. paddlex/repo_apis/PaddleOCR_api/text_det/config.py +1 -1
  694. paddlex/repo_apis/PaddleOCR_api/text_det/model.py +3 -3
  695. paddlex/repo_apis/PaddleOCR_api/text_det/register.py +2 -3
  696. paddlex/repo_apis/PaddleOCR_api/text_det/runner.py +2 -2
  697. paddlex/repo_apis/PaddleOCR_api/text_rec/__init__.py +1 -1
  698. paddlex/repo_apis/PaddleOCR_api/text_rec/config.py +4 -3
  699. paddlex/repo_apis/PaddleOCR_api/text_rec/model.py +4 -4
  700. paddlex/repo_apis/PaddleOCR_api/text_rec/register.py +2 -3
  701. paddlex/repo_apis/PaddleOCR_api/text_rec/runner.py +1 -2
  702. paddlex/repo_apis/PaddleSeg_api/__init__.py +1 -1
  703. paddlex/repo_apis/PaddleSeg_api/base_seg_config.py +2 -2
  704. paddlex/repo_apis/PaddleSeg_api/seg/__init__.py +1 -1
  705. paddlex/repo_apis/PaddleSeg_api/seg/config.py +3 -6
  706. paddlex/repo_apis/PaddleSeg_api/seg/model.py +5 -5
  707. paddlex/repo_apis/PaddleSeg_api/seg/register.py +2 -3
  708. paddlex/repo_apis/PaddleSeg_api/seg/runner.py +1 -2
  709. paddlex/repo_apis/PaddleTS_api/__init__.py +4 -3
  710. paddlex/repo_apis/PaddleTS_api/ts_ad/__init__.py +1 -1
  711. paddlex/repo_apis/PaddleTS_api/ts_ad/config.py +2 -3
  712. paddlex/repo_apis/PaddleTS_api/ts_ad/register.py +2 -2
  713. paddlex/repo_apis/PaddleTS_api/ts_ad/runner.py +2 -2
  714. paddlex/repo_apis/PaddleTS_api/ts_base/__init__.py +1 -1
  715. paddlex/repo_apis/PaddleTS_api/ts_base/config.py +2 -4
  716. paddlex/repo_apis/PaddleTS_api/ts_base/model.py +4 -4
  717. paddlex/repo_apis/PaddleTS_api/ts_base/runner.py +2 -2
  718. paddlex/repo_apis/PaddleTS_api/ts_cls/__init__.py +1 -1
  719. paddlex/repo_apis/PaddleTS_api/ts_cls/config.py +2 -3
  720. paddlex/repo_apis/PaddleTS_api/ts_cls/register.py +2 -2
  721. paddlex/repo_apis/PaddleTS_api/ts_cls/runner.py +2 -2
  722. paddlex/repo_apis/PaddleTS_api/ts_fc/__init__.py +1 -1
  723. paddlex/repo_apis/PaddleTS_api/ts_fc/config.py +2 -3
  724. paddlex/repo_apis/PaddleTS_api/ts_fc/register.py +1 -1
  725. paddlex/repo_apis/PaddleVideo_api/__init__.py +1 -1
  726. paddlex/repo_apis/PaddleVideo_api/config_utils.py +1 -1
  727. paddlex/repo_apis/PaddleVideo_api/video_cls/__init__.py +3 -3
  728. paddlex/repo_apis/PaddleVideo_api/video_cls/config.py +4 -3
  729. paddlex/repo_apis/PaddleVideo_api/video_cls/model.py +3 -3
  730. paddlex/repo_apis/PaddleVideo_api/video_cls/register.py +2 -3
  731. paddlex/repo_apis/PaddleVideo_api/video_cls/runner.py +1 -2
  732. paddlex/repo_apis/PaddleVideo_api/video_det/__init__.py +3 -3
  733. paddlex/repo_apis/PaddleVideo_api/video_det/config.py +4 -3
  734. paddlex/repo_apis/PaddleVideo_api/video_det/model.py +4 -4
  735. paddlex/repo_apis/PaddleVideo_api/video_det/register.py +2 -3
  736. paddlex/repo_apis/PaddleVideo_api/video_det/runner.py +1 -2
  737. paddlex/repo_apis/__init__.py +1 -1
  738. paddlex/repo_apis/base/__init__.py +4 -5
  739. paddlex/repo_apis/base/config.py +2 -3
  740. paddlex/repo_apis/base/model.py +11 -19
  741. paddlex/repo_apis/base/register.py +1 -1
  742. paddlex/repo_apis/base/runner.py +11 -12
  743. paddlex/repo_apis/base/utils/__init__.py +1 -1
  744. paddlex/repo_apis/base/utils/arg.py +1 -1
  745. paddlex/repo_apis/base/utils/subprocess.py +1 -1
  746. paddlex/repo_manager/__init__.py +2 -9
  747. paddlex/repo_manager/core.py +9 -27
  748. paddlex/repo_manager/meta.py +37 -31
  749. paddlex/repo_manager/repo.py +169 -160
  750. paddlex/repo_manager/utils.py +13 -224
  751. paddlex/utils/__init__.py +1 -1
  752. paddlex/utils/cache.py +8 -10
  753. paddlex/utils/config.py +6 -5
  754. paddlex/utils/{custom_device_whitelist.py → custom_device_list.py} +29 -199
  755. paddlex/utils/deps.py +249 -0
  756. paddlex/utils/device.py +73 -29
  757. paddlex/utils/download.py +4 -4
  758. paddlex/utils/env.py +33 -7
  759. paddlex/utils/errors/__init__.py +1 -1
  760. paddlex/utils/errors/dataset_checker.py +1 -1
  761. paddlex/utils/errors/others.py +2 -16
  762. paddlex/utils/file_interface.py +4 -5
  763. paddlex/utils/flags.py +19 -12
  764. paddlex/utils/fonts/__init__.py +2 -1
  765. paddlex/utils/func_register.py +1 -1
  766. paddlex/utils/install.py +87 -0
  767. paddlex/utils/interactive_get_pipeline.py +3 -3
  768. paddlex/utils/lazy_loader.py +3 -3
  769. paddlex/utils/logging.py +10 -1
  770. paddlex/utils/misc.py +5 -5
  771. paddlex/utils/pipeline_arguments.py +15 -7
  772. paddlex/utils/result_saver.py +4 -5
  773. paddlex/utils/subclass_register.py +2 -4
  774. paddlex/version.py +2 -1
  775. {paddlex-3.0.0rc0.dist-info → paddlex-3.0.0rc1.dist-info}/METADATA +212 -73
  776. paddlex-3.0.0rc1.dist-info/RECORD +1068 -0
  777. {paddlex-3.0.0rc0.dist-info → paddlex-3.0.0rc1.dist-info}/WHEEL +1 -1
  778. paddlex/inference/models/base/predictor/basic_predictor.py +0 -139
  779. paddlex/paddle2onnx_requirements.txt +0 -1
  780. paddlex/repo_manager/requirements.txt +0 -21
  781. paddlex/serving_requirements.txt +0 -9
  782. paddlex-3.0.0rc0.dist-info/RECORD +0 -1015
  783. {paddlex-3.0.0rc0.dist-info → paddlex-3.0.0rc1.dist-info}/entry_points.txt +0 -0
  784. {paddlex-3.0.0rc0.dist-info → paddlex-3.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  785. {paddlex-3.0.0rc0.dist-info → paddlex-3.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,730 @@
1
+ # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from abc import ABC
17
+ from collections import OrderedDict
18
+ from typing import Callable, Dict, List, Tuple, Union
19
+
20
+ import numpy as np
21
+ import paddle
22
+ from paddle.nn.layer.layers import in_declarative_mode
23
+
24
+
25
class LogitsProcessor(ABC):
    """Base interface for per-step logits transformations used in generation.

    Subclasses override ``__call__`` to map ``(input_ids, logits)`` to new
    logits; invoking the base class directly raises ``NotImplementedError``.
    """

    def __call__(self, input_ids: paddle.Tensor, logits: paddle.Tensor):
        message = (
            f"{self.__class__} is an abstract class. "
            "Only classes inheriting this class can be called."
        )
        raise NotImplementedError(message)
36
+
37
+
38
class LogitsProcessorList:
    """Ordered collection of ``LogitsProcessor`` objects applied in sequence.

    Processors are stored in an ``OrderedDict`` keyed by insertion index, so
    ``__call__`` runs them in exactly the order they were appended.
    """

    def __init__(
        self, processors: Union[List[LogitsProcessor], None] = None
    ) -> None:
        """Create the pipeline, optionally pre-populated with *processors*.

        Args:
            processors: Optional list of processors appended in order.
                (Annotation fixed: the default is ``None``, which the previous
                bare ``List[...]`` annotation did not admit.)
        """
        self._processors = OrderedDict()
        processors = processors or []
        for processor in processors:
            self.append(processor)

    def __call__(self, input_ids: paddle.Tensor, logits: paddle.Tensor, **kwargs):
        """Apply every registered processor to ``logits`` in insertion order.

        Processors whose ``__call__`` declares extra parameters beyond
        ``input_ids`` and ``logits`` receive them from ``kwargs``; each such
        parameter name must be present in ``kwargs``.

        Returns:
            The logits after all processors have been applied.
        """
        for processor in self._processors.values():
            processor_args = inspect.signature(processor.__call__).parameters
            if len(processor_args) > 2:
                # NOTE(review): `assert` is stripped under `python -O`; kept
                # as-is so callers relying on AssertionError are unaffected.
                assert all(
                    arg in kwargs for arg in list(processor_args.keys())[2:]
                ), f"The parameters don't match for {processor.__class__}"
                logits = processor(input_ids, logits, **kwargs)
            else:
                logits = processor(input_ids, logits)
        return logits

    def append(self, processor: LogitsProcessor):
        """Add *processor* to the end of the pipeline."""
        self._processors[len(self._processors)] = processor
61
+
62
+
63
class MinLengthLogitsProcessor(LogitsProcessor):
    r"""
    Enforcing a min-length by setting EOS probability to 0.

    Args:
        min_length (int): The minimum length of generation sequence.
        eos_token_id (int|list[int]): The id, or list of ids, of the
            `end-of-sequence` token.
    """

    def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
        if min_length < 0 and not in_declarative_mode():
            raise ValueError(
                "`min_length` should be a positive integer, but get {}".format(
                    min_length
                )
            )

        # The signature advertises Union[int, List[int]], but the original
        # implementation rejected every list; accept a non-empty list of
        # non-negative ints as well (backward compatible for int callers).
        if isinstance(eos_token_id, int):
            eos_is_valid = eos_token_id >= 0
        elif isinstance(eos_token_id, list) and eos_token_id:
            eos_is_valid = all(
                isinstance(token, int) and token >= 0 for token in eos_token_id
            )
        else:
            eos_is_valid = False
        if not eos_is_valid:
            raise ValueError(
                "`eos_token_id` should be a positive integer, but get {}".format(
                    eos_token_id
                )
            )

        self.min_length = min_length
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: paddle.Tensor, logits: paddle.Tensor):
        """Mask the EOS logit(s) while the sequence is shorter than min_length."""
        cur_len = input_ids.shape[-1]
        if cur_len < self.min_length:
            # dtype-min ensures EOS can never win argmax/sampling.
            logits[:, self.eos_token_id] = paddle.finfo(logits.dtype).min
        return logits
95
+
96
+
97
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    Enforcing an exponential penalty on repeated sequences.

    Args:
        repetition_penalty (float):
            The parameter for repetition penalty. 1.0 means no penalty. See `this paper
            <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
    """

    def __init__(self, penalty: float):
        # Guard skipped in declarative (static-graph) mode, where `penalty`
        # may be a symbolic value that cannot be compared eagerly.
        if not (penalty > 0) and not in_declarative_mode():
            raise ValueError(
                f"`penalty` has to be a strictly positive float, but is {penalty}"
            )

        self.penalty = penalty

    def __call__(self, input_ids: paddle.Tensor, logits: paddle.Tensor):
        # Gather the current logit of every token already generated in each row.
        score = paddle.index_sample(logits, input_ids)
        # CTRL-style penalty: negative logits are multiplied (pushed further
        # down), positive logits are divided (pulled toward zero) — both make
        # repeated tokens less likely for penalty > 1.
        score = paddle.where(score < 0, score * self.penalty, score / self.penalty)
        # Turn per-row token ids into flat indices into `logits.flatten()`
        # (row_index * vocab_size + token_id) so one scatter can write all
        # penalized scores back in a single op.
        input_ids = (
            input_ids
            + paddle.arange(logits.shape[0], dtype="int64").unsqueeze(-1)
            * logits.shape[-1]
        )
        outputs = paddle.scatter(
            logits.flatten(), input_ids.flatten(), score.flatten()
        ).reshape(logits.shape)
        return outputs
127
+
128
+
129
def _get_ngrams(ngram_size: int, prev_input_ids: paddle.Tensor, num_hypos: int):
    """
    Assume ngram_size=2 and prev_input_ids=tensor([[40, 2883, 2712, 4346]]). The output of generated ngrams look like
    this {(40,): [2883], (2883,): [2712], (2712,): [4346]}.

    Args:
        ngram_size (`int`):
            The number sequential tokens taken as a group which may only occur once before being banned.
        prev_input_ids (`paddle.Tensor`):
            Generated token ids for the current hypothesis.
        num_hypos (`int`):
            The number of hypotheses for which n-grams need to be generated.

    Returns:
        generated_ngrams (`dict`):
            Dictionary of generated ngrams.
    """
    all_ngram_maps = []
    for hypo_idx in range(num_hypos):
        tokens = prev_input_ids[hypo_idx].tolist()
        ngram_map = {}
        # Slide a window of `ngram_size` tokens over the sequence and map each
        # (n-1)-token prefix to the list of tokens observed right after it.
        windows = zip(*(tokens[offset:] for offset in range(ngram_size)))
        for window in windows:
            ngram_map.setdefault(tuple(window[:-1]), []).append(window[-1])
        all_ngram_maps.append(ngram_map)
    return all_ngram_maps
156
+
157
+
158
+ def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
159
+ """
160
+ Determines the banned tokens for the current hypothesis based on previously generated n-grams.
161
+
162
+ Args:
163
+ banned_ngrams (`dict`):
164
+ A dictionary containing previously generated n-grams for each hypothesis.
165
+ prev_input_ids (`paddle.Tensor`):
166
+ Generated token ids for the current hypothesis.
167
+ ngram_size (`int`):
168
+ The number sequential tokens taken as a group which may only occur once before being banned.
169
+ cur_len (`int`):
170
+ The current length of the token sequences for which the n-grams are being checked.
171
+
172
+ Returns:
173
+ List of tokens that are banned.
174
+ """
175
+ start_idx = cur_len + 1 - ngram_size
176
+ ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
177
+ return banned_ngrams.get(ngram_idx, [])
178
+
179
+
180
def _calc_banned_ngram_tokens(
    ngram_size: int, prev_input_ids: paddle.Tensor, num_hypos: int, cur_len: int
):
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    if cur_len + 1 < ngram_size:
        # Too few generated tokens to complete any n-gram: nothing is banned.
        return [[] for _ in range(num_hypos)]

    # Build the n-gram lookup tables for every hypothesis, then query each one
    # with its own trailing prefix.
    generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)

    banned_tokens = []
    for hypo_idx in range(num_hypos):
        banned_tokens.append(
            _get_generated_ngrams(
                generated_ngrams[hypo_idx],
                prev_input_ids[hypo_idx],
                ngram_size,
                cur_len,
            )
        )
    return banned_tokens
197
+
198
+
199
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that enforces no repetition of n-grams. See
    [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).
    Args:
        ngram_size (`int`):
            All ngrams of size `ngram_size` can only occur once.
    """

    def __init__(self, ngram_size: int):
        is_valid = isinstance(ngram_size, int) and ngram_size > 0
        if not is_valid:
            raise ValueError(
                f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}"
            )
        self.ngram_size = ngram_size

    def __call__(self, input_ids: paddle.Tensor, scores: paddle.Tensor):
        """Set the score of any token that would repeat an n-gram to dtype-min."""
        n_hypos = scores.shape[0]
        seq_len = input_ids.shape[-1]
        banned_per_hypo = _calc_banned_ngram_tokens(
            self.ngram_size, input_ids, n_hypos, seq_len
        )

        dtype_min = paddle.finfo(scores.dtype).min
        for hypo_idx, banned in enumerate(banned_per_hypo):
            if banned:
                scores[hypo_idx, banned] = dtype_min

        return scores
228
+
229
+
230
class HammingDiversityLogitsProcessor(LogitsProcessor):
    """
    This `LogitsProcessor` enforces diverse beam search. Note that this logits
    processor is only effective for `group_beam_search`. See
    `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.

    Args:
        diversity_rate (float): This value is subtracted from a beam's score if
            it generates a token same as any beam from other group at a particular
            time.
        num_beams (int): Number of beams used for group beam search.
        num_beam_groups (int): Number of groups to divide `num_beams` into in order
            to ensure diversity among different groups of beams.
    """

    def __init__(self, diversity_rate: float, num_beams: int, num_beam_groups: int):
        # Validate each argument in turn, keeping the original error texts.
        if not isinstance(diversity_rate, float) or (not diversity_rate > 0.0):
            raise ValueError(
                "`diversity_rate` should be a float strictly larger than 0."
            )
        self._diversity_rate = diversity_rate
        if not isinstance(num_beams, int) or num_beams < 2:
            raise ValueError("`num_beams` should be an integer strictly larger than 1.")
        self._num_beams = num_beams
        if not isinstance(num_beam_groups, int) or num_beam_groups < 2:
            raise ValueError(
                "`num_beam_groups` should be an integer strictly larger than 1."
            )
        self._num_sub_beams = num_beams // num_beam_groups

    def __call__(
        self,
        input_ids: paddle.Tensor,
        scores: paddle.Tensor,
        current_tokens: paddle.Tensor,
        beam_group_idx: int,
    ):
        group_start = beam_group_idx * self._num_sub_beams
        # The first group has no earlier groups to diversify against.
        if group_start == 0:
            return scores

        group_end = min(group_start + self._num_sub_beams, self._num_beams)
        group_size = group_end - group_start
        vocab_size = scores.shape[-1]
        batch_size = current_tokens.shape[0] // self._num_beams

        for batch_idx in range(batch_size):
            # Tokens chosen by beams in all *previous* groups of this batch item.
            beam_offset = batch_idx * self._num_beams
            earlier_tokens = current_tokens[beam_offset : beam_offset + group_start]
            token_frequency = paddle.bincount(earlier_tokens, minlength=vocab_size)
            # Penalize this group's scores for tokens already used elsewhere.
            scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= (
                self._diversity_rate * token_frequency
            )

        return scores
289
+
290
+
291
class ForcedBOSTokenLogitsProcessor(LogitsProcessor):
    """
    This `LogitsProcessor` enforces the first generated token to be the selected `forced_bos_token`.

    Args:
        forced_bos_token_id (:obj:`int`):
            The id of the token to be generated as the first token.
    """

    def __init__(self, forced_bos_token_id: int):
        self.forced_bos_token_id = forced_bos_token_id

    def __call__(self, input_ids: paddle.Tensor, scores: paddle.Tensor):
        # Right after the first position, mask everything and allow only
        # the forced BOS token (bias 0 vs. dtype minimum elsewhere).
        if input_ids.shape[-1] == 1:
            scores[:] = paddle.finfo(scores.dtype).min
            scores[:, self.forced_bos_token_id] = 0
        return scores
309
+
310
+
311
class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
    """
    This `LogitsProcessor` enforces the last generated token to be the selected `forced_eos_token`.

    Args:
        max_length (int): The maximum length of the sequence to be generated.
        forced_eos_token_id (int): The id of the token to be generated as the last token.
    """

    def __init__(self, max_length: int, forced_eos_token_id: Union[int, List[int]]):
        self.max_length = max_length
        self.forced_eos_token_id = forced_eos_token_id

    def __call__(self, input_ids, scores):
        # At the penultimate step, force EOS so the sequence terminates
        # exactly at max_length.
        if input_ids.shape[-1] == self.max_length - 1:
            scores[:] = paddle.finfo(scores.dtype).min
            scores[:, self.forced_eos_token_id] = 0
        return scores
330
+
331
+
332
def TopKProcess(probs: paddle.Tensor, top_k: int, min_tokens_to_keep: int):
    """Zero out every probability that is not among the top-k per row.

    `top_k` is clamped to [min_tokens_to_keep, vocab_size] before filtering.
    """
    lower_bound = paddle.maximum(
        paddle.to_tensor(top_k), paddle.to_tensor(min_tokens_to_keep)
    )
    top_k = paddle.minimum(lower_bound, paddle.to_tensor(probs.shape[-1]))

    # cast to float32 to support generation & d2s (topk lacks bfloat16 support)
    if probs.dtype == paddle.bfloat16:
        probs = paddle.cast(probs, paddle.float32)
        topk_probs, _ = paddle.topk(probs, k=top_k)
        topk_probs = paddle.cast(topk_probs, paddle.bfloat16)
    else:
        topk_probs, _ = paddle.topk(probs, k=top_k)

    # Remove all tokens with a probability less than the last token of the top-k.
    threshold = topk_probs[:, -1:]
    return paddle.where(probs >= threshold, probs, paddle.full_like(probs, 0.0))
350
+
351
+
352
def TopPProcess(probs: paddle.Tensor, top_p: float, min_tokens_to_keep: int):
    """Nucleus (top-p) filtering: zero out tokens outside the smallest set
    whose cumulative probability exceeds `top_p`, keeping at least
    `min_tokens_to_keep` tokens per row.
    """
    # sort/argsort lack bfloat16 support; compute the ordering in float32.
    if probs.dtype == paddle.bfloat16:
        probs = paddle.cast(probs, paddle.float32)
        sorted_indices = paddle.argsort(probs, descending=True)
        sorted_probs = paddle.sort(probs, descending=True)
        sorted_probs = paddle.cast(sorted_probs, paddle.bfloat16)
    else:
        sorted_indices = paddle.argsort(probs, descending=True)
        sorted_probs = paddle.sort(probs, descending=True)

    cumulative_probs = paddle.cumsum(sorted_probs, axis=-1)

    # Remove tokens whose cumulative probability exceeds top_p, but keep at
    # least min_tokens_to_keep tokens.
    sorted_indices_to_remove = cumulative_probs > top_p
    if min_tokens_to_keep > 1:
        # 'min_tokens_to_keep - 1' because the first token is always kept below.
        sorted_indices_to_remove[:, : min_tokens_to_keep - 1] = 0
    # Shift the mask right by one so the first token crossing the top_p
    # threshold is still kept, and the top-1 token is never removed.
    sorted_indices_to_remove = paddle.cast(sorted_indices_to_remove, dtype="int64")
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0

    # Scatter the removal mask from sorted order back to the original
    # indexing; rows are flattened with a per-row offset first.
    row_offsets = paddle.arange(probs.shape[0], dtype="int64").unsqueeze(-1)
    flat_positions = (sorted_indices + row_offsets * probs.shape[-1]).flatten()
    condition = paddle.scatter(
        sorted_indices_to_remove.flatten(),
        flat_positions,
        sorted_indices_to_remove.flatten(),
    )
    condition = paddle.cast(condition, "bool").reshape(probs.shape)
    return paddle.where(condition, paddle.full_like(probs, 0.0), probs)
391
+
392
+
393
class LogitsWarper:
    """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""

    def __call__(self, input_ids: paddle.Tensor, scores: paddle.Tensor):
        # Subclasses must override this with the actual warping logic
        # (e.g. temperature scaling); the base class is never callable.
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )
400
+
401
+
402
class TemperatureLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] for temperature (exponential scaling output probability distribution).
    Args:
        temperature (`float`):
            The value used to module the logits distribution.
    """

    def __init__(self, temperature: float):
        valid = isinstance(temperature, float) and temperature > 0
        if not valid:
            raise ValueError(
                f"`temperature` has to be a strictly positive float, but is {temperature}"
            )
        self.temperature = temperature

    def __call__(self, input_ids: paddle.Tensor, scores: paddle.Tensor):
        # Temperature < 1 sharpens the distribution; > 1 flattens it.
        return scores / self.temperature
421
+
422
+
423
class SequenceBiasLogitsProcessor(LogitsProcessor):
    """
    [`LogitsProcessor`] that applies an additive bias on sequences. The bias is applied to the last token of a sequence
    when the next generated token can complete it. Consequently, to take the most of biasing sequences with more than
    one token, consider using beam methods (to gracefully work around partially completed sequences that have a
    negative bias) and applying the bias to their prefixes (to ensure the bias is applied earlier).

    <Tip>

    In order to get the token ids of the sequences that you want to bias, make sure to set `add_prefix_space=True` when
    initializing the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The
    `add_prefix_space` argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours
    come from `pre tokenizers`.

    </Tip>

    Args:
        sequence_bias (`Dict[Tuple[int], float]`):
            Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
            sequence being selected, while negative biases do the opposite. If a sequence has a length of 1, its bias
            will always be applied. Otherwise, the bias will only be applied if the sequence in question is about to be
            completed (in the token selection step after this processor is applied).

    Examples:

    ```python
    >>> from paddlenlp.transformers import AutoTokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("gpt2-en")
    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2-en")
    >>> inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pt")

    >>> summary_ids = model.generate(inputs["input_ids"], max_new_tokens=4)
    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald J. Trump Jr

    >>> # Now let's control generation through a bias. Please note that the tokenizer is initialized differently!
    >>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("gpt2-en")


    >>> def get_tokens_as_tuple(word):
    ...     return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0])


    >>> # If we add a negative bias without beam search, it may become "stuck" in a prefix without good continuations
    >>> sequence_bias = {get_tokens_as_tuple("Trump"): -10.0}
    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, sequence_bias=sequence_bias)
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald J. Donald,

    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald Rumsfeld,

    >>> # We can also add a positive bias to nudge the model towards specific tokens or continuations
    >>> sequence_bias = {get_tokens_as_tuple("Donald Duck"): 10.0}
    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald Duck.
    ```
    """

    def __init__(self, sequence_bias: Dict[Tuple[int], float]):
        self.sequence_bias = sequence_bias
        self._validate_arguments()

        # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size
        # is infered in the first usage, which inhibits initializing here)
        self.length_1_bias = None
        self.prepared_bias_variables = False

    def __call__(self, input_ids, scores):
        """Return `scores` with the configured sequence biases added.

        Single-token biases are always applied; multi-token biases are applied
        to the last token only for rows whose recent tokens match the
        sequence's prefix.
        """
        # 1 - Prepares the bias tensors. This is only needed the first time the logit processor is called.
        if not self.prepared_bias_variables:
            self._prepare_bias_variables(scores)

        # 2 - prepares an empty bias to add
        bias = paddle.zeros_like(scores)

        # 3 - include the bias from length = 1
        if self.length_1_bias is not None:
            bias += self.length_1_bias

        # 4 - include the bias from length > 1, after determining which biased sequences may be completed.
        for sequence_ids, sequence_bias in self.sequence_bias.items():
            if len(sequence_ids) == 1:  # the sequence is of length 1, already applied
                continue
            if (
                len(sequence_ids) > input_ids.shape[1]
            ):  # the sequence is longer than the context, ignore
                continue
            prefix_length = len(sequence_ids) - 1
            last_token = sequence_ids[-1]
            # Per-row 0/1 flag: 1 iff the last `prefix_length` generated tokens
            # equal the sequence's prefix (elementwise equality reduced via prod).
            matching_rows = (
                paddle.equal(
                    input_ids[:, -prefix_length:],
                    paddle.to_tensor(sequence_ids[:-1], dtype=input_ids.dtype),
                )
                .astype(paddle.int64)
                .prod(axis=1)
            )
            # Add the bias to the completing token only where the prefix matched.
            bias[:, last_token] += paddle.where(
                matching_rows == 1,
                paddle.to_tensor(sequence_bias),
                paddle.to_tensor(0.0),
            )

        # 5 - apply the bias to the scores
        scores = scores + bias
        return scores

    def _prepare_bias_variables(self, scores):
        """Validate token ids against the vocabulary and precompute the
        length-1 bias vector. Runs once, on the first `__call__`.

        Raises:
            ValueError: if any biased token id is >= the vocabulary size
                inferred from `scores`.
        """
        vocabulary_size = scores.shape[-1]

        # Check biased tokens out of bounds
        invalid_biases = []
        for sequence_ids in self.sequence_bias:
            for token_id in sequence_ids:
                if token_id >= vocabulary_size:
                    invalid_biases.append(token_id)
        if len(invalid_biases) > 0:
            raise ValueError(
                f"The model vocabulary size is {vocabulary_size}, but the following tokens were being biased: "
                f"{invalid_biases}"
            )

        # Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied
        # with simpler logic.
        self.length_1_bias = paddle.zeros((vocabulary_size,))
        for sequence_ids, bias in self.sequence_bias.items():
            if len(sequence_ids) == 1:
                self.length_1_bias[sequence_ids[-1]] = bias

        self.prepared_bias_variables = True

    def _validate_arguments(self):
        """Check that `sequence_bias` is a non-empty dict mapping non-empty
        tuples of non-negative ints to float biases; raise ValueError otherwise.
        """
        sequence_bias = self.sequence_bias
        if not isinstance(sequence_bias, dict) or len(sequence_bias) == 0:
            raise ValueError(
                f"`sequence_bias` has to be a non-empty dictionary, but is {sequence_bias}."
            )
        if any(
            not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias.keys()
        ):
            raise ValueError(
                f"`sequence_bias` has to be a dict with tuples as keys, but is {sequence_bias}."
            )
        if any(
            any(
                (not isinstance(token_id, (int, np.integer)) or token_id < 0)
                for token_id in sequence_ids
            )
            or len(sequence_ids) == 0
            for sequence_ids in sequence_bias.keys()
        ):
            raise ValueError(
                f"Each key in `sequence_bias` has to be a non-empty tuple of positive integers, but is "
                f"{sequence_bias}."
            )
        if any(not isinstance(bias, float) for bias in sequence_bias.values()):
            raise ValueError(
                f"`sequence_bias` has to be a dict with floats as values, but is {sequence_bias}."
            )
586
+
587
+
588
class NoBadWordsLogitsProcessor(SequenceBiasLogitsProcessor):
    """
    [`LogitsProcessor`] that enforces that specified sequences will never be selected.

    <Tip>

    In order to get the token ids of the words that should not appear in the generated text, make sure to set
    `add_prefix_space=True` when initializing the tokenizer, and use `tokenizer(bad_words,
    add_special_tokens=False).input_ids`. The `add_prefix_space` argument is only supported for some slow tokenizers,
    as fast tokenizers' prefixing behaviours come from `pre tokenizers`. Read more
    [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).

    </Tip>

    Args:
        bad_words_ids (`List[List[int]]`):
            List of list of token ids that are not allowed to be generated.
        eos_token_id (`Union[int, List[int]]`):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
            Single-token bad words equal to an EOS token are ignored so that generation can still terminate.

    Implemented on top of [`SequenceBiasLogitsProcessor`] by giving every
    banned sequence a bias of ``-inf``.
    """

    def __init__(
        self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]
    ):
        self.bad_word_ids = bad_words_ids
        self._validate_arguments()

        # Normalize eos_token_id to a list, then drop bad-word entries that
        # consist of exactly one EOS token (banning EOS would prevent
        # generation from ever stopping).
        if eos_token_id is None:
            eos_token_id = []
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        kept_sequences = [
            token_seq
            for token_seq in bad_words_ids
            if not any(token_seq == [eos] for eos in eos_token_id)
        ]

        # Forbidding a sequence is equivalent to setting its bias to -inf
        super().__init__(
            sequence_bias={tuple(seq): float("-inf") for seq in kept_sequences}
        )

    def _validate_arguments(self):
        """Raise ValueError unless `bad_word_ids` is a non-empty list of lists
        of non-negative integers."""
        bad_words_ids = self.bad_word_ids
        if not isinstance(bad_words_ids, list) or len(bad_words_ids) == 0:
            raise ValueError(
                f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}."
            )
        if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
            raise ValueError(
                f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}."
            )

        def _invalid_token(token_id):
            return not isinstance(token_id, (int, np.integer)) or token_id < 0

        if any(
            any(_invalid_token(token_id) for token_id in bad_word_ids)
            for bad_word_ids in bad_words_ids
        ):
            raise ValueError(
                f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
            )
694
+
695
+
696
class PrefixConstrainedLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that enforces constrained generation and is useful for prefix-conditioned constrained
    generation. See [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904) for more information.

    Args:
        prefix_allowed_tokens_fn (`Callable[[int, paddle.Tensor], List[int]]`):
            Callback constraining generation to allowed tokens at each step. It
            receives the batch id and the tokens generated so far for one beam,
            and must return the list of token ids allowed for the next step.
        num_beams (`int`):
            Number of beams per batch item, used to regroup the flattened
            `input_ids` rows.
    """

    def __init__(
        self,
        prefix_allowed_tokens_fn: Callable[[int, paddle.Tensor], List[int]],
        num_beams: int,
    ):
        self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
        self._num_beams = num_beams

    def __call__(
        self, input_ids: paddle.Tensor, scores: paddle.Tensor
    ) -> paddle.Tensor:
        # Start from an all-disallowed mask (dtype minimum everywhere) and
        # open up only the tokens the callback permits per (batch, beam) row.
        mask = paddle.full_like(scores, paddle.finfo(scores.dtype).min)
        grouped = input_ids.reshape([-1, self._num_beams, input_ids.shape[-1]])
        for batch_id, beam_sentences in enumerate(grouped):
            for beam_id, sent in enumerate(beam_sentences):
                row = batch_id * self._num_beams + beam_id
                allowed = self._prefix_allowed_tokens_fn(batch_id, sent)
                mask[row, allowed] = 0

        return scores + mask