paddlex 3.0.0b2__py3-none-any.whl → 3.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (940):
  1. paddlex/.version +1 -1
  2. paddlex/__init__.py +1 -0
  3. paddlex/__main__.py +3 -4
  4. paddlex/configs/modules/3d_bev_detection/BEVFusion.yaml +38 -0
  5. paddlex/configs/modules/face_feature/MobileFaceNet.yaml +41 -0
  6. paddlex/configs/modules/face_feature/ResNet50_face.yaml +41 -0
  7. paddlex/configs/modules/formula_recognition/LaTeX_OCR_rec.yaml +40 -0
  8. paddlex/configs/modules/formula_recognition/PP-FormulaNet-L.yaml +40 -0
  9. paddlex/configs/modules/formula_recognition/PP-FormulaNet-S.yaml +40 -0
  10. paddlex/configs/modules/formula_recognition/UniMERNet.yaml +40 -0
  11. paddlex/configs/modules/image_classification/CLIP_vit_base_patch16_224.yaml +41 -0
  12. paddlex/configs/modules/image_classification/CLIP_vit_large_patch14_224.yaml +41 -0
  13. paddlex/configs/modules/image_classification/ConvNeXt_large_384.yaml +41 -0
  14. paddlex/configs/modules/keypoint_detection/PP-TinyPose_128x96.yaml +40 -0
  15. paddlex/configs/modules/keypoint_detection/PP-TinyPose_256x192.yaml +40 -0
  16. paddlex/configs/modules/layout_detection/PP-DocLayout-L.yaml +40 -0
  17. paddlex/configs/modules/layout_detection/PP-DocLayout-M.yaml +40 -0
  18. paddlex/configs/modules/layout_detection/PP-DocLayout-S.yaml +40 -0
  19. paddlex/configs/modules/multilingual_speech_recognition/whisper_base.yaml +12 -0
  20. paddlex/configs/modules/multilingual_speech_recognition/whisper_large.yaml +12 -0
  21. paddlex/configs/modules/multilingual_speech_recognition/whisper_medium.yaml +12 -0
  22. paddlex/configs/modules/multilingual_speech_recognition/whisper_small.yaml +12 -0
  23. paddlex/configs/modules/multilingual_speech_recognition/whisper_tiny.yaml +12 -0
  24. paddlex/configs/modules/object_detection/Co-DINO-R50.yaml +40 -0
  25. paddlex/configs/modules/object_detection/Co-DINO-Swin-L.yaml +40 -0
  26. paddlex/configs/modules/object_detection/Co-Deformable-DETR-R50.yaml +40 -0
  27. paddlex/configs/modules/object_detection/Co-Deformable-DETR-Swin-T.yaml +40 -0
  28. paddlex/configs/modules/object_detection/YOLOX-X.yaml +40 -0
  29. paddlex/configs/modules/open_vocabulary_detection/GroundingDINO-T.yaml +13 -0
  30. paddlex/configs/modules/open_vocabulary_segmentation/SAM-H_box.yaml +17 -0
  31. paddlex/configs/modules/open_vocabulary_segmentation/SAM-H_point.yaml +15 -0
  32. paddlex/configs/modules/rotated_object_detection/PP-YOLOE-R-L.yaml +40 -0
  33. paddlex/configs/modules/semantic_segmentation/MaskFormer_small.yaml +42 -0
  34. paddlex/configs/modules/semantic_segmentation/MaskFormer_tiny.yaml +42 -0
  35. paddlex/configs/modules/semantic_segmentation/SeaFormer_base.yaml +40 -0
  36. paddlex/configs/modules/semantic_segmentation/SeaFormer_large.yaml +40 -0
  37. paddlex/configs/modules/semantic_segmentation/SeaFormer_small.yaml +40 -0
  38. paddlex/configs/modules/semantic_segmentation/SeaFormer_tiny.yaml +40 -0
  39. paddlex/configs/modules/table_cells_detection/RT-DETR-L_wired_table_cell_det.yaml +40 -0
  40. paddlex/configs/modules/table_cells_detection/RT-DETR-L_wireless_table_cell_det.yaml +40 -0
  41. paddlex/configs/modules/table_classification/PP-LCNet_x1_0_table_cls.yaml +41 -0
  42. paddlex/configs/modules/table_structure_recognition/SLANeXt_wired.yaml +39 -0
  43. paddlex/configs/modules/table_structure_recognition/SLANeXt_wireless.yaml +39 -0
  44. paddlex/configs/modules/text_detection/PP-OCRv3_mobile_det.yaml +40 -0
  45. paddlex/configs/modules/text_detection/PP-OCRv3_server_det.yaml +40 -0
  46. paddlex/configs/modules/text_recognition/PP-OCRv3_mobile_rec.yaml +39 -0
  47. paddlex/configs/modules/text_recognition/PP-OCRv4_server_rec_doc.yaml +39 -0
  48. paddlex/configs/modules/text_recognition/arabic_PP-OCRv3_mobile_rec.yaml +39 -0
  49. paddlex/configs/modules/text_recognition/chinese_cht_PP-OCRv3_mobile_rec.yaml +39 -0
  50. paddlex/configs/modules/text_recognition/cyrillic_PP-OCRv3_mobile_rec.yaml +39 -0
  51. paddlex/configs/modules/text_recognition/devanagari_PP-OCRv3_mobile_rec.yaml +39 -0
  52. paddlex/configs/modules/text_recognition/en_PP-OCRv3_mobile_rec.yaml +39 -0
  53. paddlex/configs/modules/text_recognition/en_PP-OCRv4_mobile_rec.yaml +39 -0
  54. paddlex/configs/modules/text_recognition/japan_PP-OCRv3_mobile_rec.yaml +39 -0
  55. paddlex/configs/modules/text_recognition/ka_PP-OCRv3_mobile_rec.yaml +39 -0
  56. paddlex/configs/modules/text_recognition/korean_PP-OCRv3_mobile_rec.yaml +39 -0
  57. paddlex/configs/modules/text_recognition/latin_PP-OCRv3_mobile_rec.yaml +39 -0
  58. paddlex/configs/modules/text_recognition/ta_PP-OCRv3_mobile_rec.yaml +39 -0
  59. paddlex/configs/modules/text_recognition/te_PP-OCRv3_mobile_rec.yaml +39 -0
  60. paddlex/configs/modules/textline_orientation/PP-LCNet_x0_25_textline_ori.yaml +41 -0
  61. paddlex/configs/modules/video_classification/PP-TSM-R50_8frames_uniform.yaml +42 -0
  62. paddlex/configs/modules/video_classification/PP-TSMv2-LCNetV2_16frames_uniform.yaml +42 -0
  63. paddlex/configs/modules/video_classification/PP-TSMv2-LCNetV2_8frames_uniform.yaml +42 -0
  64. paddlex/configs/modules/video_detection/YOWO.yaml +40 -0
  65. paddlex/configs/pipelines/3d_bev_detection.yaml +9 -0
  66. paddlex/configs/pipelines/OCR.yaml +44 -0
  67. paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml +149 -0
  68. paddlex/configs/pipelines/PP-ChatOCRv4-doc.yaml +184 -0
  69. paddlex/configs/pipelines/PP-ShiTuV2.yaml +18 -0
  70. paddlex/configs/pipelines/PP-StructureV3.yaml +226 -0
  71. paddlex/configs/pipelines/anomaly_detection.yaml +8 -0
  72. paddlex/configs/pipelines/doc_preprocessor.yaml +15 -0
  73. paddlex/configs/pipelines/face_recognition.yaml +18 -0
  74. paddlex/configs/pipelines/formula_recognition.yaml +39 -0
  75. paddlex/configs/pipelines/human_keypoint_detection.yaml +17 -0
  76. paddlex/configs/pipelines/image_classification.yaml +10 -0
  77. paddlex/configs/pipelines/image_multilabel_classification.yaml +9 -0
  78. paddlex/configs/pipelines/instance_segmentation.yaml +10 -0
  79. paddlex/configs/pipelines/layout_parsing.yaml +101 -0
  80. paddlex/configs/pipelines/multilingual_speech_recognition.yaml +9 -0
  81. paddlex/configs/pipelines/object_detection.yaml +10 -0
  82. paddlex/configs/pipelines/open_vocabulary_detection.yaml +12 -0
  83. paddlex/configs/pipelines/open_vocabulary_segmentation.yaml +13 -0
  84. paddlex/configs/pipelines/pedestrian_attribute_recognition.yaml +15 -0
  85. paddlex/configs/pipelines/rotated_object_detection.yaml +10 -0
  86. paddlex/configs/pipelines/seal_recognition.yaml +51 -0
  87. paddlex/configs/pipelines/semantic_segmentation.yaml +10 -0
  88. paddlex/configs/pipelines/small_object_detection.yaml +10 -0
  89. paddlex/configs/pipelines/table_recognition.yaml +56 -0
  90. paddlex/configs/pipelines/table_recognition_v2.yaml +76 -0
  91. paddlex/configs/pipelines/ts_anomaly_detection.yaml +8 -0
  92. paddlex/configs/pipelines/ts_classification.yaml +8 -0
  93. paddlex/configs/pipelines/ts_forecast.yaml +8 -0
  94. paddlex/configs/pipelines/vehicle_attribute_recognition.yaml +15 -0
  95. paddlex/configs/pipelines/video_classification.yaml +9 -0
  96. paddlex/configs/pipelines/video_detection.yaml +10 -0
  97. paddlex/engine.py +1 -1
  98. paddlex/hpip_links.html +19 -0
  99. paddlex/inference/__init__.py +3 -1
  100. paddlex/inference/common/batch_sampler/__init__.py +20 -0
  101. paddlex/inference/common/batch_sampler/audio_batch_sampler.py +84 -0
  102. paddlex/inference/common/batch_sampler/base_batch_sampler.py +90 -0
  103. paddlex/inference/common/batch_sampler/det_3d_batch_sampler.py +147 -0
  104. paddlex/inference/common/batch_sampler/image_batch_sampler.py +136 -0
  105. paddlex/inference/common/batch_sampler/ts_batch_sampler.py +110 -0
  106. paddlex/inference/common/batch_sampler/video_batch_sampler.py +94 -0
  107. paddlex/inference/common/reader/__init__.py +19 -0
  108. paddlex/inference/common/reader/audio_reader.py +46 -0
  109. paddlex/inference/common/reader/det_3d_reader.py +239 -0
  110. paddlex/inference/common/reader/image_reader.py +69 -0
  111. paddlex/inference/common/reader/ts_reader.py +45 -0
  112. paddlex/inference/common/reader/video_reader.py +42 -0
  113. paddlex/inference/common/result/__init__.py +29 -0
  114. paddlex/inference/common/result/base_cv_result.py +31 -0
  115. paddlex/inference/common/result/base_result.py +70 -0
  116. paddlex/inference/common/result/base_ts_result.py +42 -0
  117. paddlex/inference/common/result/base_video_result.py +36 -0
  118. paddlex/inference/common/result/mixin.py +703 -0
  119. paddlex/inference/models/3d_bev_detection/__init__.py +15 -0
  120. paddlex/inference/models/3d_bev_detection/predictor.py +314 -0
  121. paddlex/inference/models/3d_bev_detection/processors.py +978 -0
  122. paddlex/inference/models/3d_bev_detection/result.py +65 -0
  123. paddlex/inference/models/3d_bev_detection/visualizer_3d.py +131 -0
  124. paddlex/inference/models/__init__.py +37 -13
  125. paddlex/inference/models/anomaly_detection/__init__.py +15 -0
  126. paddlex/inference/models/anomaly_detection/predictor.py +145 -0
  127. paddlex/inference/models/anomaly_detection/processors.py +46 -0
  128. paddlex/inference/models/anomaly_detection/result.py +70 -0
  129. paddlex/inference/models/base/__init__.py +1 -2
  130. paddlex/inference/models/base/predictor/__init__.py +16 -0
  131. paddlex/inference/models/base/predictor/base_predictor.py +175 -0
  132. paddlex/inference/models/base/predictor/basic_predictor.py +139 -0
  133. paddlex/inference/models/common/__init__.py +35 -0
  134. paddlex/inference/models/common/static_infer.py +329 -0
  135. paddlex/inference/models/common/tokenizer/__init__.py +17 -0
  136. paddlex/inference/models/common/tokenizer/bert_tokenizer.py +655 -0
  137. paddlex/inference/models/common/tokenizer/gpt_tokenizer.py +451 -0
  138. paddlex/inference/models/common/tokenizer/tokenizer_utils.py +2141 -0
  139. paddlex/inference/models/common/tokenizer/tokenizer_utils_base.py +3504 -0
  140. paddlex/inference/models/common/tokenizer/utils.py +66 -0
  141. paddlex/inference/models/common/tokenizer/vocab.py +647 -0
  142. paddlex/inference/models/common/ts/__init__.py +15 -0
  143. paddlex/inference/models/common/ts/funcs.py +533 -0
  144. paddlex/inference/models/common/ts/processors.py +313 -0
  145. paddlex/inference/models/common/vision/__init__.py +23 -0
  146. paddlex/inference/models/common/vision/funcs.py +93 -0
  147. paddlex/inference/models/common/vision/processors.py +270 -0
  148. paddlex/inference/models/face_feature/__init__.py +15 -0
  149. paddlex/inference/models/face_feature/predictor.py +65 -0
  150. paddlex/inference/models/formula_recognition/__init__.py +15 -0
  151. paddlex/inference/models/formula_recognition/predictor.py +203 -0
  152. paddlex/inference/models/formula_recognition/processors.py +986 -0
  153. paddlex/inference/models/formula_recognition/result.py +403 -0
  154. paddlex/inference/models/image_classification/__init__.py +15 -0
  155. paddlex/inference/models/image_classification/predictor.py +182 -0
  156. paddlex/inference/models/image_classification/processors.py +87 -0
  157. paddlex/inference/models/image_classification/result.py +92 -0
  158. paddlex/inference/models/image_feature/__init__.py +15 -0
  159. paddlex/inference/models/image_feature/predictor.py +156 -0
  160. paddlex/inference/models/image_feature/processors.py +29 -0
  161. paddlex/inference/models/image_feature/result.py +33 -0
  162. paddlex/inference/models/image_multilabel_classification/__init__.py +15 -0
  163. paddlex/inference/models/image_multilabel_classification/predictor.py +94 -0
  164. paddlex/inference/models/image_multilabel_classification/processors.py +85 -0
  165. paddlex/inference/models/image_multilabel_classification/result.py +95 -0
  166. paddlex/inference/models/image_unwarping/__init__.py +15 -0
  167. paddlex/inference/models/image_unwarping/predictor.py +105 -0
  168. paddlex/inference/models/image_unwarping/processors.py +88 -0
  169. paddlex/inference/models/image_unwarping/result.py +45 -0
  170. paddlex/inference/models/instance_segmentation/__init__.py +15 -0
  171. paddlex/inference/models/instance_segmentation/predictor.py +210 -0
  172. paddlex/inference/models/instance_segmentation/processors.py +105 -0
  173. paddlex/inference/models/instance_segmentation/result.py +161 -0
  174. paddlex/inference/models/keypoint_detection/__init__.py +15 -0
  175. paddlex/inference/models/keypoint_detection/predictor.py +188 -0
  176. paddlex/inference/models/keypoint_detection/processors.py +359 -0
  177. paddlex/inference/models/keypoint_detection/result.py +192 -0
  178. paddlex/inference/models/multilingual_speech_recognition/__init__.py +15 -0
  179. paddlex/inference/models/multilingual_speech_recognition/predictor.py +141 -0
  180. paddlex/inference/models/multilingual_speech_recognition/processors.py +1941 -0
  181. paddlex/inference/models/multilingual_speech_recognition/result.py +21 -0
  182. paddlex/inference/models/object_detection/__init__.py +15 -0
  183. paddlex/inference/models/object_detection/predictor.py +348 -0
  184. paddlex/inference/models/object_detection/processors.py +855 -0
  185. paddlex/inference/models/object_detection/result.py +113 -0
  186. paddlex/inference/models/object_detection/utils.py +68 -0
  187. paddlex/inference/models/open_vocabulary_detection/__init__.py +15 -0
  188. paddlex/inference/models/open_vocabulary_detection/predictor.py +155 -0
  189. paddlex/inference/models/open_vocabulary_detection/processors/__init__.py +15 -0
  190. paddlex/inference/models/open_vocabulary_detection/processors/groundingdino_processors.py +485 -0
  191. paddlex/inference/models/open_vocabulary_segmentation/__init__.py +15 -0
  192. paddlex/inference/models/open_vocabulary_segmentation/predictor.py +120 -0
  193. paddlex/inference/models/open_vocabulary_segmentation/processors/__init__.py +15 -0
  194. paddlex/inference/models/open_vocabulary_segmentation/processors/sam_processer.py +249 -0
  195. paddlex/inference/models/open_vocabulary_segmentation/results/__init__.py +15 -0
  196. paddlex/inference/models/open_vocabulary_segmentation/results/sam_result.py +147 -0
  197. paddlex/inference/models/semantic_segmentation/__init__.py +15 -0
  198. paddlex/inference/models/semantic_segmentation/predictor.py +167 -0
  199. paddlex/inference/models/semantic_segmentation/processors.py +114 -0
  200. paddlex/inference/models/semantic_segmentation/result.py +72 -0
  201. paddlex/inference/models/table_structure_recognition/__init__.py +15 -0
  202. paddlex/inference/models/table_structure_recognition/predictor.py +171 -0
  203. paddlex/inference/models/table_structure_recognition/processors.py +235 -0
  204. paddlex/inference/models/table_structure_recognition/result.py +70 -0
  205. paddlex/inference/models/text_detection/__init__.py +15 -0
  206. paddlex/inference/models/text_detection/predictor.py +191 -0
  207. paddlex/inference/models/text_detection/processors.py +466 -0
  208. paddlex/inference/models/text_detection/result.py +51 -0
  209. paddlex/inference/models/text_recognition/__init__.py +15 -0
  210. paddlex/inference/models/text_recognition/predictor.py +106 -0
  211. paddlex/inference/models/text_recognition/processors.py +231 -0
  212. paddlex/inference/models/text_recognition/result.py +75 -0
  213. paddlex/inference/models/ts_anomaly_detection/__init__.py +15 -0
  214. paddlex/inference/models/ts_anomaly_detection/predictor.py +146 -0
  215. paddlex/inference/models/ts_anomaly_detection/processors.py +94 -0
  216. paddlex/inference/models/ts_anomaly_detection/result.py +72 -0
  217. paddlex/inference/models/ts_classification/__init__.py +15 -0
  218. paddlex/inference/models/ts_classification/predictor.py +135 -0
  219. paddlex/inference/models/ts_classification/processors.py +117 -0
  220. paddlex/inference/models/ts_classification/result.py +78 -0
  221. paddlex/inference/models/ts_forecasting/__init__.py +15 -0
  222. paddlex/inference/models/ts_forecasting/predictor.py +159 -0
  223. paddlex/inference/models/ts_forecasting/processors.py +149 -0
  224. paddlex/inference/models/ts_forecasting/result.py +83 -0
  225. paddlex/inference/models/video_classification/__init__.py +15 -0
  226. paddlex/inference/models/video_classification/predictor.py +147 -0
  227. paddlex/inference/models/video_classification/processors.py +409 -0
  228. paddlex/inference/models/video_classification/result.py +92 -0
  229. paddlex/inference/models/video_detection/__init__.py +15 -0
  230. paddlex/inference/models/video_detection/predictor.py +136 -0
  231. paddlex/inference/models/video_detection/processors.py +450 -0
  232. paddlex/inference/models/video_detection/result.py +104 -0
  233. paddlex/inference/pipelines/3d_bev_detection/__init__.py +15 -0
  234. paddlex/inference/pipelines/3d_bev_detection/pipeline.py +67 -0
  235. paddlex/inference/pipelines/__init__.py +174 -73
  236. paddlex/inference/pipelines/anomaly_detection/__init__.py +15 -0
  237. paddlex/inference/pipelines/anomaly_detection/pipeline.py +62 -0
  238. paddlex/inference/pipelines/attribute_recognition/__init__.py +15 -0
  239. paddlex/inference/pipelines/attribute_recognition/pipeline.py +105 -0
  240. paddlex/inference/pipelines/attribute_recognition/result.py +100 -0
  241. paddlex/inference/pipelines/base.py +103 -57
  242. paddlex/inference/pipelines/components/__init__.py +23 -0
  243. paddlex/inference/pipelines/components/chat_server/__init__.py +16 -0
  244. paddlex/inference/pipelines/components/chat_server/base.py +39 -0
  245. paddlex/inference/pipelines/components/chat_server/openai_bot_chat.py +236 -0
  246. paddlex/inference/pipelines/components/common/__init__.py +18 -0
  247. paddlex/inference/pipelines/components/common/base_operator.py +36 -0
  248. paddlex/inference/pipelines/components/common/base_result.py +65 -0
  249. paddlex/inference/pipelines/components/common/convert_points_and_boxes.py +46 -0
  250. paddlex/inference/pipelines/components/common/crop_image_regions.py +550 -0
  251. paddlex/inference/pipelines/components/common/seal_det_warp.py +941 -0
  252. paddlex/inference/pipelines/components/common/sort_boxes.py +83 -0
  253. paddlex/inference/pipelines/components/faisser.py +352 -0
  254. paddlex/inference/pipelines/components/prompt_engineering/__init__.py +16 -0
  255. paddlex/inference/pipelines/components/prompt_engineering/base.py +35 -0
  256. paddlex/inference/pipelines/components/prompt_engineering/generate_ensemble_prompt.py +127 -0
  257. paddlex/inference/pipelines/components/prompt_engineering/generate_kie_prompt.py +148 -0
  258. paddlex/inference/pipelines/components/retriever/__init__.py +16 -0
  259. paddlex/inference/pipelines/components/retriever/base.py +226 -0
  260. paddlex/inference/pipelines/components/retriever/openai_bot_retriever.py +70 -0
  261. paddlex/inference/pipelines/components/retriever/qianfan_bot_retriever.py +163 -0
  262. paddlex/inference/pipelines/components/utils/__init__.py +13 -0
  263. paddlex/inference/pipelines/components/utils/mixin.py +206 -0
  264. paddlex/inference/pipelines/doc_preprocessor/__init__.py +15 -0
  265. paddlex/inference/pipelines/doc_preprocessor/pipeline.py +190 -0
  266. paddlex/inference/pipelines/doc_preprocessor/result.py +103 -0
  267. paddlex/inference/pipelines/face_recognition/__init__.py +15 -0
  268. paddlex/inference/pipelines/face_recognition/pipeline.py +61 -0
  269. paddlex/inference/pipelines/face_recognition/result.py +43 -0
  270. paddlex/inference/pipelines/formula_recognition/__init__.py +15 -0
  271. paddlex/inference/pipelines/formula_recognition/pipeline.py +303 -0
  272. paddlex/inference/pipelines/formula_recognition/result.py +291 -0
  273. paddlex/inference/pipelines/image_classification/__init__.py +15 -0
  274. paddlex/inference/pipelines/image_classification/pipeline.py +71 -0
  275. paddlex/inference/pipelines/image_multilabel_classification/__init__.py +15 -0
  276. paddlex/inference/pipelines/image_multilabel_classification/pipeline.py +78 -0
  277. paddlex/inference/pipelines/instance_segmentation/__init__.py +15 -0
  278. paddlex/inference/pipelines/instance_segmentation/pipeline.py +70 -0
  279. paddlex/inference/pipelines/keypoint_detection/__init__.py +15 -0
  280. paddlex/inference/pipelines/keypoint_detection/pipeline.py +137 -0
  281. paddlex/inference/pipelines/layout_parsing/__init__.py +2 -1
  282. paddlex/inference/pipelines/layout_parsing/pipeline.py +570 -0
  283. paddlex/inference/pipelines/layout_parsing/pipeline_v2.py +739 -0
  284. paddlex/inference/pipelines/layout_parsing/result.py +203 -0
  285. paddlex/inference/pipelines/layout_parsing/result_v2.py +470 -0
  286. paddlex/inference/pipelines/layout_parsing/utils.py +2385 -0
  287. paddlex/inference/pipelines/multilingual_speech_recognition/__init__.py +15 -0
  288. paddlex/inference/pipelines/multilingual_speech_recognition/pipeline.py +67 -0
  289. paddlex/inference/pipelines/object_detection/__init__.py +15 -0
  290. paddlex/inference/pipelines/object_detection/pipeline.py +95 -0
  291. paddlex/inference/pipelines/ocr/__init__.py +15 -0
  292. paddlex/inference/pipelines/ocr/pipeline.py +389 -0
  293. paddlex/inference/pipelines/ocr/result.py +248 -0
  294. paddlex/inference/pipelines/open_vocabulary_detection/__init__.py +15 -0
  295. paddlex/inference/pipelines/open_vocabulary_detection/pipeline.py +75 -0
  296. paddlex/inference/pipelines/open_vocabulary_segmentation/__init__.py +15 -0
  297. paddlex/inference/pipelines/open_vocabulary_segmentation/pipeline.py +89 -0
  298. paddlex/inference/pipelines/pp_chatocr/__init__.py +16 -0
  299. paddlex/inference/pipelines/pp_chatocr/pipeline_base.py +102 -0
  300. paddlex/inference/pipelines/pp_chatocr/pipeline_v3.py +773 -0
  301. paddlex/inference/pipelines/pp_chatocr/pipeline_v4.py +977 -0
  302. paddlex/inference/pipelines/pp_shitu_v2/__init__.py +15 -0
  303. paddlex/inference/pipelines/pp_shitu_v2/pipeline.py +152 -0
  304. paddlex/inference/pipelines/pp_shitu_v2/result.py +126 -0
  305. paddlex/inference/pipelines/rotated_object_detection/__init__.py +15 -0
  306. paddlex/inference/pipelines/rotated_object_detection/pipeline.py +74 -0
  307. paddlex/inference/pipelines/seal_recognition/__init__.py +15 -0
  308. paddlex/inference/pipelines/seal_recognition/pipeline.py +271 -0
  309. paddlex/inference/pipelines/seal_recognition/result.py +87 -0
  310. paddlex/inference/pipelines/semantic_segmentation/__init__.py +15 -0
  311. paddlex/inference/pipelines/semantic_segmentation/pipeline.py +74 -0
  312. paddlex/inference/pipelines/small_object_detection/__init__.py +15 -0
  313. paddlex/inference/pipelines/small_object_detection/pipeline.py +74 -0
  314. paddlex/inference/pipelines/table_recognition/__init__.py +2 -1
  315. paddlex/inference/pipelines/table_recognition/pipeline.py +462 -0
  316. paddlex/inference/pipelines/table_recognition/pipeline_v2.py +792 -0
  317. paddlex/inference/pipelines/table_recognition/result.py +216 -0
  318. paddlex/inference/pipelines/table_recognition/table_recognition_post_processing.py +362 -0
  319. paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py +470 -0
  320. paddlex/inference/pipelines/table_recognition/utils.py +23 -436
  321. paddlex/inference/pipelines/ts_anomaly_detection/__init__.py +15 -0
  322. paddlex/inference/pipelines/ts_anomaly_detection/pipeline.py +62 -0
  323. paddlex/inference/pipelines/ts_classification/__init__.py +15 -0
  324. paddlex/inference/pipelines/ts_classification/pipeline.py +62 -0
  325. paddlex/inference/pipelines/ts_forecasting/__init__.py +15 -0
  326. paddlex/inference/pipelines/ts_forecasting/pipeline.py +62 -0
  327. paddlex/inference/pipelines/video_classification/__init__.py +15 -0
  328. paddlex/inference/pipelines/video_classification/pipeline.py +68 -0
  329. paddlex/inference/pipelines/video_detection/__init__.py +15 -0
  330. paddlex/inference/pipelines/video_detection/pipeline.py +73 -0
  331. paddlex/inference/serving/__init__.py +13 -0
  332. paddlex/inference/serving/basic_serving/__init__.py +18 -0
  333. paddlex/inference/serving/basic_serving/_app.py +209 -0
  334. paddlex/inference/serving/basic_serving/_pipeline_apps/__init__.py +41 -0
  335. paddlex/inference/serving/basic_serving/_pipeline_apps/_common/__init__.py +13 -0
  336. paddlex/inference/serving/basic_serving/_pipeline_apps/_common/common.py +96 -0
  337. paddlex/inference/serving/basic_serving/_pipeline_apps/_common/image_recognition.py +36 -0
  338. paddlex/inference/serving/basic_serving/_pipeline_apps/_common/ocr.py +90 -0
  339. paddlex/inference/serving/basic_serving/_pipeline_apps/anomaly_detection.py +64 -0
  340. paddlex/inference/serving/basic_serving/_pipeline_apps/doc_preprocessor.py +97 -0
  341. paddlex/inference/serving/basic_serving/_pipeline_apps/face_recognition.py +223 -0
  342. paddlex/inference/serving/basic_serving/_pipeline_apps/formula_recognition.py +97 -0
  343. paddlex/inference/serving/basic_serving/_pipeline_apps/human_keypoint_detection.py +78 -0
  344. paddlex/inference/serving/basic_serving/_pipeline_apps/image_classification.py +66 -0
  345. paddlex/inference/serving/basic_serving/_pipeline_apps/image_multilabel_classification.py +70 -0
  346. paddlex/inference/serving/basic_serving/_pipeline_apps/instance_segmentation.py +81 -0
  347. paddlex/inference/serving/basic_serving/_pipeline_apps/layout_parsing.py +115 -0
  348. paddlex/inference/serving/basic_serving/_pipeline_apps/m_3d_bev_detection.py +76 -0
  349. paddlex/inference/serving/basic_serving/_pipeline_apps/multilingual_speech_recognition.py +89 -0
  350. paddlex/inference/serving/basic_serving/_pipeline_apps/object_detection.py +74 -0
  351. paddlex/inference/serving/basic_serving/_pipeline_apps/ocr.py +99 -0
  352. paddlex/inference/serving/basic_serving/_pipeline_apps/open_vocabulary_detection.py +78 -0
  353. paddlex/inference/serving/basic_serving/_pipeline_apps/open_vocabulary_segmentation.py +85 -0
  354. paddlex/inference/serving/basic_serving/_pipeline_apps/pedestrian_attribute_recognition.py +81 -0
  355. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv3_doc.py +191 -0
  356. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv4_doc.py +221 -0
  357. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_shituv2.py +218 -0
  358. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_structurev3.py +136 -0
  359. paddlex/inference/serving/basic_serving/_pipeline_apps/rotated_object_detection.py +78 -0
  360. paddlex/inference/serving/basic_serving/_pipeline_apps/seal_recognition.py +103 -0
  361. paddlex/inference/serving/basic_serving/_pipeline_apps/semantic_segmentation.py +64 -0
  362. paddlex/inference/serving/basic_serving/_pipeline_apps/small_object_detection.py +69 -0
  363. paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition.py +105 -0
  364. paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition_v2.py +107 -0
  365. paddlex/inference/serving/basic_serving/_pipeline_apps/ts_anomaly_detection.py +62 -0
  366. paddlex/inference/serving/basic_serving/_pipeline_apps/ts_classification.py +61 -0
  367. paddlex/inference/serving/basic_serving/_pipeline_apps/ts_forecast.py +62 -0
  368. paddlex/inference/serving/basic_serving/_pipeline_apps/vehicle_attribute_recognition.py +81 -0
  369. paddlex/inference/serving/basic_serving/_pipeline_apps/video_classification.py +73 -0
  370. paddlex/inference/serving/basic_serving/_pipeline_apps/video_detection.py +89 -0
  371. paddlex/inference/serving/basic_serving/_server.py +35 -0
  372. paddlex/inference/serving/infra/__init__.py +13 -0
  373. paddlex/inference/serving/infra/config.py +36 -0
  374. paddlex/inference/serving/infra/models.py +72 -0
  375. paddlex/inference/serving/infra/storage.py +175 -0
  376. paddlex/inference/serving/infra/utils.py +259 -0
  377. paddlex/inference/serving/schemas/__init__.py +13 -0
  378. paddlex/inference/serving/schemas/anomaly_detection.py +39 -0
  379. paddlex/inference/serving/schemas/doc_preprocessor.py +54 -0
  380. paddlex/inference/serving/schemas/face_recognition.py +124 -0
  381. paddlex/inference/serving/schemas/formula_recognition.py +56 -0
  382. paddlex/inference/serving/schemas/human_keypoint_detection.py +55 -0
  383. paddlex/inference/serving/schemas/image_classification.py +45 -0
  384. paddlex/inference/serving/schemas/image_multilabel_classification.py +47 -0
  385. paddlex/inference/serving/schemas/instance_segmentation.py +53 -0
  386. paddlex/inference/serving/schemas/layout_parsing.py +72 -0
  387. paddlex/inference/serving/schemas/m_3d_bev_detection.py +48 -0
  388. paddlex/inference/serving/schemas/multilingual_speech_recognition.py +57 -0
  389. paddlex/inference/serving/schemas/object_detection.py +52 -0
  390. paddlex/inference/serving/schemas/ocr.py +60 -0
  391. paddlex/inference/serving/schemas/open_vocabulary_detection.py +52 -0
  392. paddlex/inference/serving/schemas/open_vocabulary_segmentation.py +52 -0
  393. paddlex/inference/serving/schemas/pedestrian_attribute_recognition.py +61 -0
  394. paddlex/inference/serving/schemas/pp_chatocrv3_doc.py +134 -0
  395. paddlex/inference/serving/schemas/pp_chatocrv4_doc.py +151 -0
  396. paddlex/inference/serving/schemas/pp_shituv2.py +124 -0
  397. paddlex/inference/serving/schemas/pp_structurev3.py +84 -0
  398. paddlex/inference/serving/schemas/rotated_object_detection.py +52 -0
  399. paddlex/inference/serving/schemas/seal_recognition.py +62 -0
  400. paddlex/inference/serving/schemas/semantic_segmentation.py +45 -0
  401. paddlex/inference/serving/schemas/shared/__init__.py +13 -0
  402. paddlex/inference/serving/schemas/shared/classification.py +23 -0
  403. paddlex/inference/serving/schemas/shared/image_segmentation.py +28 -0
  404. paddlex/inference/serving/schemas/shared/object_detection.py +24 -0
  405. paddlex/inference/serving/schemas/shared/ocr.py +25 -0
  406. paddlex/inference/serving/schemas/small_object_detection.py +52 -0
  407. paddlex/inference/serving/schemas/table_recognition.py +64 -0
  408. paddlex/inference/serving/schemas/table_recognition_v2.py +66 -0
  409. paddlex/inference/serving/schemas/ts_anomaly_detection.py +37 -0
  410. paddlex/inference/serving/schemas/ts_classification.py +38 -0
  411. paddlex/inference/serving/schemas/ts_forecast.py +37 -0
  412. paddlex/inference/serving/schemas/vehicle_attribute_recognition.py +61 -0
  413. paddlex/inference/serving/schemas/video_classification.py +44 -0
  414. paddlex/inference/serving/schemas/video_detection.py +56 -0
  415. paddlex/inference/utils/benchmark.py +23 -11
  416. paddlex/inference/utils/get_pipeline_path.py +2 -1
  417. paddlex/inference/utils/io/__init__.py +3 -0
  418. paddlex/inference/utils/io/readers.py +164 -17
  419. paddlex/inference/utils/io/writers.py +85 -2
  420. paddlex/inference/utils/new_ir_blacklist.py +6 -0
  421. paddlex/inference/utils/official_models.py +277 -211
  422. paddlex/inference/utils/pp_option.py +24 -4
  423. paddlex/model.py +12 -5
  424. paddlex/modules/3d_bev_detection/__init__.py +18 -0
  425. paddlex/modules/3d_bev_detection/dataset_checker/__init__.py +95 -0
  426. paddlex/modules/3d_bev_detection/dataset_checker/dataset_src/__init__.py +17 -0
  427. paddlex/modules/3d_bev_detection/dataset_checker/dataset_src/analyse_dataset.py +106 -0
  428. paddlex/modules/3d_bev_detection/dataset_checker/dataset_src/check_dataset.py +102 -0
  429. paddlex/modules/3d_bev_detection/evaluator.py +46 -0
  430. paddlex/modules/3d_bev_detection/exportor.py +22 -0
  431. paddlex/modules/3d_bev_detection/model_list.py +18 -0
  432. paddlex/modules/3d_bev_detection/trainer.py +70 -0
  433. paddlex/modules/__init__.py +34 -1
  434. paddlex/modules/base/build_model.py +1 -1
  435. paddlex/modules/base/dataset_checker/dataset_checker.py +6 -1
  436. paddlex/modules/base/evaluator.py +20 -4
  437. paddlex/modules/base/exportor.py +30 -5
  438. paddlex/modules/base/trainer.py +29 -6
  439. paddlex/modules/face_recognition/trainer.py +1 -23
  440. paddlex/modules/formula_recognition/__init__.py +5 -0
  441. paddlex/modules/formula_recognition/dataset_checker/__init__.py +113 -0
  442. paddlex/modules/formula_recognition/dataset_checker/dataset_src/__init__.py +19 -0
  443. paddlex/modules/formula_recognition/dataset_checker/dataset_src/analyse_dataset.py +157 -0
  444. paddlex/modules/formula_recognition/dataset_checker/dataset_src/check_dataset.py +80 -0
  445. paddlex/modules/formula_recognition/dataset_checker/dataset_src/convert_dataset.py +94 -0
  446. paddlex/modules/formula_recognition/dataset_checker/dataset_src/split_dataset.py +81 -0
  447. paddlex/modules/formula_recognition/evaluator.py +77 -0
  448. paddlex/modules/formula_recognition/exportor.py +22 -0
  449. paddlex/modules/formula_recognition/model_list.py +3 -0
  450. paddlex/modules/formula_recognition/trainer.py +121 -0
  451. paddlex/modules/image_classification/model_list.py +2 -0
  452. paddlex/modules/instance_segmentation/dataset_checker/__init__.py +15 -0
  453. paddlex/modules/keypoint_detection/__init__.py +18 -0
  454. paddlex/modules/keypoint_detection/dataset_checker/__init__.py +56 -0
  455. paddlex/modules/keypoint_detection/dataset_checker/dataset_src/__init__.py +15 -0
  456. paddlex/modules/keypoint_detection/dataset_checker/dataset_src/check_dataset.py +86 -0
  457. paddlex/modules/keypoint_detection/dataset_checker/dataset_src/utils/__init__.py +13 -0
  458. paddlex/modules/keypoint_detection/dataset_checker/dataset_src/utils/visualizer.py +119 -0
  459. paddlex/modules/keypoint_detection/evaluator.py +41 -0
  460. paddlex/modules/keypoint_detection/exportor.py +22 -0
  461. paddlex/modules/keypoint_detection/model_list.py +16 -0
  462. paddlex/modules/keypoint_detection/trainer.py +39 -0
  463. paddlex/modules/multilingual_speech_recognition/__init__.py +18 -0
  464. paddlex/modules/multilingual_speech_recognition/dataset_checker.py +27 -0
  465. paddlex/modules/multilingual_speech_recognition/evaluator.py +27 -0
  466. paddlex/modules/multilingual_speech_recognition/exportor.py +27 -0
  467. paddlex/modules/multilingual_speech_recognition/model_list.py +22 -0
  468. paddlex/modules/multilingual_speech_recognition/trainer.py +40 -0
  469. paddlex/modules/object_detection/evaluator.py +12 -1
  470. paddlex/modules/object_detection/model_list.py +10 -0
  471. paddlex/modules/object_detection/trainer.py +15 -1
  472. paddlex/modules/open_vocabulary_detection/__init__.py +18 -0
  473. paddlex/modules/open_vocabulary_detection/dataset_checker.py +29 -0
  474. paddlex/modules/open_vocabulary_detection/evaluator.py +29 -0
  475. paddlex/modules/open_vocabulary_detection/exportor.py +29 -0
  476. paddlex/modules/open_vocabulary_detection/model_list.py +18 -0
  477. paddlex/modules/open_vocabulary_detection/trainer.py +42 -0
  478. paddlex/modules/open_vocabulary_segmentation/__init__.py +18 -0
  479. paddlex/modules/open_vocabulary_segmentation/dataset_checker.py +29 -0
  480. paddlex/modules/open_vocabulary_segmentation/evaluator.py +29 -0
  481. paddlex/modules/open_vocabulary_segmentation/exportor.py +29 -0
  482. paddlex/modules/open_vocabulary_segmentation/model_list.py +19 -0
  483. paddlex/modules/open_vocabulary_segmentation/trainer.py +42 -0
  484. paddlex/modules/semantic_segmentation/dataset_checker/__init__.py +15 -0
  485. paddlex/modules/semantic_segmentation/exportor.py +9 -0
  486. paddlex/modules/semantic_segmentation/model_list.py +2 -0
  487. paddlex/modules/semantic_segmentation/trainer.py +2 -0
  488. paddlex/modules/table_recognition/dataset_checker/__init__.py +16 -1
  489. paddlex/modules/table_recognition/dataset_checker/dataset_src/check_dataset.py +13 -14
  490. paddlex/modules/table_recognition/model_list.py +2 -0
  491. paddlex/modules/text_detection/dataset_checker/__init__.py +16 -1
  492. paddlex/modules/text_detection/dataset_checker/dataset_src/check_dataset.py +13 -3
  493. paddlex/modules/text_detection/model_list.py +2 -0
  494. paddlex/modules/text_recognition/dataset_checker/__init__.py +16 -4
  495. paddlex/modules/text_recognition/dataset_checker/dataset_src/check_dataset.py +13 -3
  496. paddlex/modules/text_recognition/evaluator.py +4 -3
  497. paddlex/modules/text_recognition/exportor.py +0 -3
  498. paddlex/modules/text_recognition/model_list.py +14 -0
  499. paddlex/modules/text_recognition/trainer.py +4 -3
  500. paddlex/modules/ts_anomaly_detection/dataset_checker/__init__.py +15 -0
  501. paddlex/modules/ts_anomaly_detection/trainer.py +17 -1
  502. paddlex/modules/ts_classification/dataset_checker/__init__.py +15 -0
  503. paddlex/modules/ts_classification/trainer.py +17 -1
  504. paddlex/modules/ts_forecast/dataset_checker/__init__.py +15 -0
  505. paddlex/modules/ts_forecast/trainer.py +17 -1
  506. paddlex/modules/video_classification/__init__.py +18 -0
  507. paddlex/modules/video_classification/dataset_checker/__init__.py +93 -0
  508. paddlex/modules/video_classification/dataset_checker/dataset_src/__init__.py +18 -0
  509. paddlex/modules/video_classification/dataset_checker/dataset_src/analyse_dataset.py +93 -0
  510. paddlex/modules/video_classification/dataset_checker/dataset_src/check_dataset.py +121 -0
  511. paddlex/modules/video_classification/dataset_checker/dataset_src/split_dataset.py +82 -0
  512. paddlex/modules/video_classification/evaluator.py +44 -0
  513. paddlex/modules/video_classification/exportor.py +22 -0
  514. paddlex/modules/video_classification/model_list.py +19 -0
  515. paddlex/modules/video_classification/trainer.py +88 -0
  516. paddlex/modules/video_detection/__init__.py +18 -0
  517. paddlex/modules/video_detection/dataset_checker/__init__.py +86 -0
  518. paddlex/modules/video_detection/dataset_checker/dataset_src/__init__.py +17 -0
  519. paddlex/modules/video_detection/dataset_checker/dataset_src/analyse_dataset.py +101 -0
  520. paddlex/modules/video_detection/dataset_checker/dataset_src/check_dataset.py +134 -0
  521. paddlex/modules/video_detection/evaluator.py +42 -0
  522. paddlex/modules/video_detection/exportor.py +22 -0
  523. paddlex/modules/video_detection/model_list.py +15 -0
  524. paddlex/modules/video_detection/trainer.py +82 -0
  525. paddlex/ops/__init__.py +149 -0
  526. paddlex/ops/iou3d_nms/iou3d_cpu.cpp +264 -0
  527. paddlex/ops/iou3d_nms/iou3d_cpu.h +27 -0
  528. paddlex/ops/iou3d_nms/iou3d_nms.cpp +204 -0
  529. paddlex/ops/iou3d_nms/iou3d_nms.h +33 -0
  530. paddlex/ops/iou3d_nms/iou3d_nms_api.cpp +108 -0
  531. paddlex/ops/iou3d_nms/iou3d_nms_kernel.cu +482 -0
  532. paddlex/ops/setup.py +37 -0
  533. paddlex/ops/voxel/voxelize_op.cc +191 -0
  534. paddlex/ops/voxel/voxelize_op.cu +346 -0
  535. paddlex/paddle2onnx_requirements.txt +1 -0
  536. paddlex/paddlex_cli.py +339 -72
  537. paddlex/repo_apis/Paddle3D_api/__init__.py +17 -0
  538. paddlex/repo_apis/Paddle3D_api/bev_fusion/__init__.py +18 -0
  539. paddlex/repo_apis/Paddle3D_api/bev_fusion/config.py +118 -0
  540. paddlex/repo_apis/Paddle3D_api/bev_fusion/model.py +238 -0
  541. paddlex/repo_apis/Paddle3D_api/bev_fusion/register.py +55 -0
  542. paddlex/repo_apis/Paddle3D_api/bev_fusion/runner.py +104 -0
  543. paddlex/repo_apis/Paddle3D_api/pp3d_config.py +144 -0
  544. paddlex/repo_apis/PaddleClas_api/cls/model.py +6 -0
  545. paddlex/repo_apis/PaddleClas_api/cls/register.py +20 -2
  546. paddlex/repo_apis/PaddleDetection_api/instance_seg/config.py +8 -4
  547. paddlex/repo_apis/PaddleDetection_api/instance_seg/model.py +6 -0
  548. paddlex/repo_apis/PaddleDetection_api/object_det/config.py +27 -5
  549. paddlex/repo_apis/PaddleDetection_api/object_det/model.py +6 -0
  550. paddlex/repo_apis/PaddleDetection_api/object_det/official_categories.py +81 -0
  551. paddlex/repo_apis/PaddleDetection_api/object_det/register.py +182 -3
  552. paddlex/repo_apis/PaddleOCR_api/__init__.py +1 -0
  553. paddlex/repo_apis/PaddleOCR_api/formula_rec/__init__.py +16 -0
  554. paddlex/repo_apis/PaddleOCR_api/formula_rec/config.py +570 -0
  555. paddlex/repo_apis/PaddleOCR_api/formula_rec/model.py +402 -0
  556. paddlex/repo_apis/PaddleOCR_api/formula_rec/register.py +73 -0
  557. paddlex/repo_apis/PaddleOCR_api/formula_rec/runner.py +240 -0
  558. paddlex/repo_apis/PaddleOCR_api/table_rec/register.py +18 -0
  559. paddlex/repo_apis/PaddleOCR_api/text_det/register.py +18 -0
  560. paddlex/repo_apis/PaddleOCR_api/text_rec/config.py +21 -0
  561. paddlex/repo_apis/PaddleOCR_api/text_rec/model.py +6 -0
  562. paddlex/repo_apis/PaddleOCR_api/text_rec/register.py +126 -7
  563. paddlex/repo_apis/PaddleSeg_api/seg/config.py +9 -0
  564. paddlex/repo_apis/PaddleSeg_api/seg/model.py +10 -0
  565. paddlex/repo_apis/PaddleSeg_api/seg/register.py +20 -0
  566. paddlex/repo_apis/PaddleTS_api/ts_base/config.py +24 -0
  567. paddlex/repo_apis/PaddleTS_api/ts_base/model.py +11 -7
  568. paddlex/repo_apis/PaddleVideo_api/__init__.py +17 -0
  569. paddlex/repo_apis/PaddleVideo_api/config_utils.py +51 -0
  570. paddlex/repo_apis/PaddleVideo_api/video_cls/__init__.py +19 -0
  571. paddlex/repo_apis/PaddleVideo_api/video_cls/config.py +547 -0
  572. paddlex/repo_apis/PaddleVideo_api/video_cls/model.py +346 -0
  573. paddlex/repo_apis/PaddleVideo_api/video_cls/register.py +71 -0
  574. paddlex/repo_apis/PaddleVideo_api/video_cls/runner.py +205 -0
  575. paddlex/repo_apis/PaddleVideo_api/video_det/__init__.py +19 -0
  576. paddlex/repo_apis/PaddleVideo_api/video_det/config.py +548 -0
  577. paddlex/repo_apis/PaddleVideo_api/video_det/model.py +298 -0
  578. paddlex/repo_apis/PaddleVideo_api/video_det/register.py +45 -0
  579. paddlex/repo_apis/PaddleVideo_api/video_det/runner.py +200 -0
  580. paddlex/repo_apis/base/runner.py +2 -1
  581. paddlex/repo_manager/meta.py +29 -2
  582. paddlex/repo_manager/repo.py +24 -5
  583. paddlex/repo_manager/requirements.txt +10 -7
  584. paddlex/repo_manager/utils.py +62 -1
  585. paddlex/serving_requirements.txt +9 -0
  586. paddlex/utils/config.py +4 -3
  587. paddlex/utils/custom_device_whitelist.py +457 -0
  588. paddlex/utils/device.py +74 -26
  589. paddlex/utils/env.py +28 -0
  590. paddlex/utils/flags.py +4 -0
  591. paddlex/utils/fonts/__init__.py +48 -5
  592. paddlex/utils/lazy_loader.py +2 -0
  593. paddlex/utils/logging.py +1 -2
  594. paddlex/utils/pipeline_arguments.py +711 -0
  595. paddlex-3.0.0rc0.dist-info/METADATA +1035 -0
  596. paddlex-3.0.0rc0.dist-info/RECORD +1015 -0
  597. paddlex-3.0.0rc0.dist-info/WHEEL +5 -0
  598. paddlex/configs/face_recognition/MobileFaceNet.yaml +0 -44
  599. paddlex/configs/face_recognition/ResNet50_face.yaml +0 -44
  600. paddlex/configs/formula_recognition/LaTeX_OCR_rec.yaml +0 -40
  601. paddlex/configs/image_classification/CLIP_vit_base_patch16_224.yaml +0 -41
  602. paddlex/configs/image_classification/CLIP_vit_large_patch14_224.yaml +0 -41
  603. paddlex/configs/image_classification/ConvNeXt_large_384.yaml +0 -41
  604. paddlex/configs/object_detection/YOLOX-X.yaml +0 -40
  605. paddlex/configs/semantic_segmentation/SeaFormer_base.yaml +0 -40
  606. paddlex/configs/semantic_segmentation/SeaFormer_large.yaml +0 -40
  607. paddlex/configs/semantic_segmentation/SeaFormer_small.yaml +0 -40
  608. paddlex/configs/semantic_segmentation/SeaFormer_tiny.yaml +0 -40
  609. paddlex/inference/components/__init__.py +0 -18
  610. paddlex/inference/components/base.py +0 -292
  611. paddlex/inference/components/llm/__init__.py +0 -25
  612. paddlex/inference/components/llm/base.py +0 -65
  613. paddlex/inference/components/llm/erniebot.py +0 -212
  614. paddlex/inference/components/paddle_predictor/__init__.py +0 -20
  615. paddlex/inference/components/paddle_predictor/predictor.py +0 -332
  616. paddlex/inference/components/retrieval/__init__.py +0 -15
  617. paddlex/inference/components/retrieval/faiss.py +0 -359
  618. paddlex/inference/components/task_related/__init__.py +0 -33
  619. paddlex/inference/components/task_related/clas.py +0 -124
  620. paddlex/inference/components/task_related/det.py +0 -284
  621. paddlex/inference/components/task_related/instance_seg.py +0 -89
  622. paddlex/inference/components/task_related/seal_det_warp.py +0 -940
  623. paddlex/inference/components/task_related/seg.py +0 -40
  624. paddlex/inference/components/task_related/table_rec.py +0 -191
  625. paddlex/inference/components/task_related/text_det.py +0 -895
  626. paddlex/inference/components/task_related/text_rec.py +0 -353
  627. paddlex/inference/components/task_related/warp.py +0 -43
  628. paddlex/inference/components/transforms/__init__.py +0 -16
  629. paddlex/inference/components/transforms/image/__init__.py +0 -15
  630. paddlex/inference/components/transforms/image/common.py +0 -598
  631. paddlex/inference/components/transforms/image/funcs.py +0 -58
  632. paddlex/inference/components/transforms/read_data.py +0 -67
  633. paddlex/inference/components/transforms/ts/__init__.py +0 -15
  634. paddlex/inference/components/transforms/ts/common.py +0 -393
  635. paddlex/inference/components/transforms/ts/funcs.py +0 -424
  636. paddlex/inference/models/anomaly_detection.py +0 -87
  637. paddlex/inference/models/base/base_predictor.py +0 -76
  638. paddlex/inference/models/base/basic_predictor.py +0 -122
  639. paddlex/inference/models/face_recognition.py +0 -21
  640. paddlex/inference/models/formula_recognition.py +0 -55
  641. paddlex/inference/models/general_recognition.py +0 -99
  642. paddlex/inference/models/image_classification.py +0 -101
  643. paddlex/inference/models/image_unwarping.py +0 -43
  644. paddlex/inference/models/instance_segmentation.py +0 -66
  645. paddlex/inference/models/multilabel_classification.py +0 -33
  646. paddlex/inference/models/object_detection.py +0 -129
  647. paddlex/inference/models/semantic_segmentation.py +0 -86
  648. paddlex/inference/models/table_recognition.py +0 -106
  649. paddlex/inference/models/text_detection.py +0 -105
  650. paddlex/inference/models/text_recognition.py +0 -78
  651. paddlex/inference/models/ts_ad.py +0 -68
  652. paddlex/inference/models/ts_cls.py +0 -57
  653. paddlex/inference/models/ts_fc.py +0 -73
  654. paddlex/inference/pipelines/attribute_recognition.py +0 -92
  655. paddlex/inference/pipelines/face_recognition.py +0 -49
  656. paddlex/inference/pipelines/formula_recognition.py +0 -102
  657. paddlex/inference/pipelines/layout_parsing/layout_parsing.py +0 -362
  658. paddlex/inference/pipelines/ocr.py +0 -80
  659. paddlex/inference/pipelines/pp_shitu_v2.py +0 -152
  660. paddlex/inference/pipelines/ppchatocrv3/__init__.py +0 -15
  661. paddlex/inference/pipelines/ppchatocrv3/ch_prompt.yaml +0 -14
  662. paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py +0 -717
  663. paddlex/inference/pipelines/ppchatocrv3/utils.py +0 -168
  664. paddlex/inference/pipelines/seal_recognition.py +0 -152
  665. paddlex/inference/pipelines/serving/__init__.py +0 -17
  666. paddlex/inference/pipelines/serving/_pipeline_apps/__init__.py +0 -205
  667. paddlex/inference/pipelines/serving/_pipeline_apps/anomaly_detection.py +0 -80
  668. paddlex/inference/pipelines/serving/_pipeline_apps/face_recognition.py +0 -317
  669. paddlex/inference/pipelines/serving/_pipeline_apps/formula_recognition.py +0 -119
  670. paddlex/inference/pipelines/serving/_pipeline_apps/image_classification.py +0 -101
  671. paddlex/inference/pipelines/serving/_pipeline_apps/instance_segmentation.py +0 -112
  672. paddlex/inference/pipelines/serving/_pipeline_apps/layout_parsing.py +0 -205
  673. paddlex/inference/pipelines/serving/_pipeline_apps/multi_label_image_classification.py +0 -90
  674. paddlex/inference/pipelines/serving/_pipeline_apps/object_detection.py +0 -90
  675. paddlex/inference/pipelines/serving/_pipeline_apps/ocr.py +0 -98
  676. paddlex/inference/pipelines/serving/_pipeline_apps/pedestrian_attribute_recognition.py +0 -102
  677. paddlex/inference/pipelines/serving/_pipeline_apps/pp_shitu_v2.py +0 -319
  678. paddlex/inference/pipelines/serving/_pipeline_apps/ppchatocrv3.py +0 -445
  679. paddlex/inference/pipelines/serving/_pipeline_apps/seal_recognition.py +0 -110
  680. paddlex/inference/pipelines/serving/_pipeline_apps/semantic_segmentation.py +0 -82
  681. paddlex/inference/pipelines/serving/_pipeline_apps/small_object_detection.py +0 -92
  682. paddlex/inference/pipelines/serving/_pipeline_apps/table_recognition.py +0 -110
  683. paddlex/inference/pipelines/serving/_pipeline_apps/ts_ad.py +0 -68
  684. paddlex/inference/pipelines/serving/_pipeline_apps/ts_cls.py +0 -68
  685. paddlex/inference/pipelines/serving/_pipeline_apps/ts_fc.py +0 -68
  686. paddlex/inference/pipelines/serving/_pipeline_apps/vehicle_attribute_recognition.py +0 -102
  687. paddlex/inference/pipelines/serving/app.py +0 -164
  688. paddlex/inference/pipelines/serving/models.py +0 -30
  689. paddlex/inference/pipelines/serving/server.py +0 -25
  690. paddlex/inference/pipelines/serving/storage.py +0 -161
  691. paddlex/inference/pipelines/serving/utils.py +0 -190
  692. paddlex/inference/pipelines/single_model_pipeline.py +0 -76
  693. paddlex/inference/pipelines/table_recognition/table_recognition.py +0 -193
  694. paddlex/inference/results/__init__.py +0 -31
  695. paddlex/inference/results/attribute_rec.py +0 -89
  696. paddlex/inference/results/base.py +0 -43
  697. paddlex/inference/results/chat_ocr.py +0 -158
  698. paddlex/inference/results/clas.py +0 -133
  699. paddlex/inference/results/det.py +0 -86
  700. paddlex/inference/results/face_rec.py +0 -34
  701. paddlex/inference/results/formula_rec.py +0 -363
  702. paddlex/inference/results/instance_seg.py +0 -152
  703. paddlex/inference/results/ocr.py +0 -157
  704. paddlex/inference/results/seal_rec.py +0 -50
  705. paddlex/inference/results/seg.py +0 -72
  706. paddlex/inference/results/shitu.py +0 -35
  707. paddlex/inference/results/table_rec.py +0 -109
  708. paddlex/inference/results/text_det.py +0 -33
  709. paddlex/inference/results/text_rec.py +0 -66
  710. paddlex/inference/results/ts.py +0 -37
  711. paddlex/inference/results/utils/mixin.py +0 -204
  712. paddlex/inference/results/warp.py +0 -31
  713. paddlex/inference/utils/process_hook.py +0 -54
  714. paddlex/pipelines/OCR.yaml +0 -8
  715. paddlex/pipelines/PP-ChatOCRv3-doc.yaml +0 -27
  716. paddlex/pipelines/PP-ShiTuV2.yaml +0 -13
  717. paddlex/pipelines/anomaly_detection.yaml +0 -7
  718. paddlex/pipelines/face_recognition.yaml +0 -13
  719. paddlex/pipelines/formula_recognition.yaml +0 -8
  720. paddlex/pipelines/image_classification.yaml +0 -7
  721. paddlex/pipelines/instance_segmentation.yaml +0 -7
  722. paddlex/pipelines/layout_parsing.yaml +0 -14
  723. paddlex/pipelines/multi_label_image_classification.yaml +0 -7
  724. paddlex/pipelines/object_detection.yaml +0 -7
  725. paddlex/pipelines/pedestrian_attribute_recognition.yaml +0 -7
  726. paddlex/pipelines/seal_recognition.yaml +0 -10
  727. paddlex/pipelines/semantic_segmentation.yaml +0 -7
  728. paddlex/pipelines/small_object_detection.yaml +0 -7
  729. paddlex/pipelines/table_recognition.yaml +0 -12
  730. paddlex/pipelines/ts_ad.yaml +0 -7
  731. paddlex/pipelines/ts_cls.yaml +0 -7
  732. paddlex/pipelines/ts_fc.yaml +0 -7
  733. paddlex/pipelines/vehicle_attribute_recognition.yaml +0 -7
  734. paddlex/utils/fonts/PingFang-SC-Regular.ttf +0 -0
  735. paddlex-3.0.0b2.dist-info/METADATA +0 -760
  736. paddlex-3.0.0b2.dist-info/RECORD +0 -646
  737. paddlex-3.0.0b2.dist-info/WHEEL +0 -5
  738. /paddlex/configs/{doc_text_orientation → modules/doc_text_orientation}/PP-LCNet_x1_0_doc_ori.yaml +0 -0
  739. /paddlex/configs/{face_detection → modules/face_detection}/BlazeFace-FPN-SSH.yaml +0 -0
  740. /paddlex/configs/{face_detection → modules/face_detection}/BlazeFace.yaml +0 -0
  741. /paddlex/configs/{face_detection → modules/face_detection}/PP-YOLOE_plus-S_face.yaml +0 -0
  742. /paddlex/configs/{face_detection → modules/face_detection}/PicoDet_LCNet_x2_5_face.yaml +0 -0
  743. /paddlex/configs/{human_detection → modules/human_detection}/PP-YOLOE-L_human.yaml +0 -0
  744. /paddlex/configs/{human_detection → modules/human_detection}/PP-YOLOE-S_human.yaml +0 -0
  745. /paddlex/configs/{anomaly_detection → modules/image_anomaly_detection}/STFPM.yaml +0 -0
  746. /paddlex/configs/{image_classification → modules/image_classification}/ConvNeXt_base_224.yaml +0 -0
  747. /paddlex/configs/{image_classification → modules/image_classification}/ConvNeXt_base_384.yaml +0 -0
  748. /paddlex/configs/{image_classification → modules/image_classification}/ConvNeXt_large_224.yaml +0 -0
  749. /paddlex/configs/{image_classification → modules/image_classification}/ConvNeXt_small.yaml +0 -0
  750. /paddlex/configs/{image_classification → modules/image_classification}/ConvNeXt_tiny.yaml +0 -0
  751. /paddlex/configs/{image_classification → modules/image_classification}/FasterNet-L.yaml +0 -0
  752. /paddlex/configs/{image_classification → modules/image_classification}/FasterNet-M.yaml +0 -0
  753. /paddlex/configs/{image_classification → modules/image_classification}/FasterNet-S.yaml +0 -0
  754. /paddlex/configs/{image_classification → modules/image_classification}/FasterNet-T0.yaml +0 -0
  755. /paddlex/configs/{image_classification → modules/image_classification}/FasterNet-T1.yaml +0 -0
  756. /paddlex/configs/{image_classification → modules/image_classification}/FasterNet-T2.yaml +0 -0
  757. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV1_x0_25.yaml +0 -0
  758. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV1_x0_5.yaml +0 -0
  759. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV1_x0_75.yaml +0 -0
  760. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV1_x1_0.yaml +0 -0
  761. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV2_x0_25.yaml +0 -0
  762. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV2_x0_5.yaml +0 -0
  763. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV2_x1_0.yaml +0 -0
  764. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV2_x1_5.yaml +0 -0
  765. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV2_x2_0.yaml +0 -0
  766. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV3_large_x0_35.yaml +0 -0
  767. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV3_large_x0_5.yaml +0 -0
  768. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV3_large_x0_75.yaml +0 -0
  769. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV3_large_x1_0.yaml +0 -0
  770. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV3_large_x1_25.yaml +0 -0
  771. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV3_small_x0_35.yaml +0 -0
  772. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV3_small_x0_5.yaml +0 -0
  773. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV3_small_x0_75.yaml +0 -0
  774. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV3_small_x1_0.yaml +0 -0
  775. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV3_small_x1_25.yaml +0 -0
  776. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV4_conv_large.yaml +0 -0
  777. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV4_conv_medium.yaml +0 -0
  778. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV4_conv_small.yaml +0 -0
  779. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV4_hybrid_large.yaml +0 -0
  780. /paddlex/configs/{image_classification → modules/image_classification}/MobileNetV4_hybrid_medium.yaml +0 -0
  781. /paddlex/configs/{image_classification → modules/image_classification}/PP-HGNetV2-B0.yaml +0 -0
  782. /paddlex/configs/{image_classification → modules/image_classification}/PP-HGNetV2-B1.yaml +0 -0
  783. /paddlex/configs/{image_classification → modules/image_classification}/PP-HGNetV2-B2.yaml +0 -0
  784. /paddlex/configs/{image_classification → modules/image_classification}/PP-HGNetV2-B3.yaml +0 -0
  785. /paddlex/configs/{image_classification → modules/image_classification}/PP-HGNetV2-B4.yaml +0 -0
  786. /paddlex/configs/{image_classification → modules/image_classification}/PP-HGNetV2-B5.yaml +0 -0
  787. /paddlex/configs/{image_classification → modules/image_classification}/PP-HGNetV2-B6.yaml +0 -0
  788. /paddlex/configs/{image_classification → modules/image_classification}/PP-HGNet_base.yaml +0 -0
  789. /paddlex/configs/{image_classification → modules/image_classification}/PP-HGNet_small.yaml +0 -0
  790. /paddlex/configs/{image_classification → modules/image_classification}/PP-HGNet_tiny.yaml +0 -0
  791. /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNetV2_base.yaml +0 -0
  792. /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNetV2_large.yaml +0 -0
  793. /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNetV2_small.yaml +0 -0
  794. /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNet_x0_25.yaml +0 -0
  795. /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNet_x0_35.yaml +0 -0
  796. /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNet_x0_5.yaml +0 -0
  797. /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNet_x0_75.yaml +0 -0
  798. /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNet_x1_0.yaml +0 -0
  799. /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNet_x1_5.yaml +0 -0
  800. /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNet_x2_0.yaml +0 -0
  801. /paddlex/configs/{image_classification → modules/image_classification}/PP-LCNet_x2_5.yaml +0 -0
  802. /paddlex/configs/{image_classification → modules/image_classification}/ResNet101.yaml +0 -0
  803. /paddlex/configs/{image_classification → modules/image_classification}/ResNet101_vd.yaml +0 -0
  804. /paddlex/configs/{image_classification → modules/image_classification}/ResNet152.yaml +0 -0
  805. /paddlex/configs/{image_classification → modules/image_classification}/ResNet152_vd.yaml +0 -0
  806. /paddlex/configs/{image_classification → modules/image_classification}/ResNet18.yaml +0 -0
  807. /paddlex/configs/{image_classification → modules/image_classification}/ResNet18_vd.yaml +0 -0
  808. /paddlex/configs/{image_classification → modules/image_classification}/ResNet200_vd.yaml +0 -0
  809. /paddlex/configs/{image_classification → modules/image_classification}/ResNet34.yaml +0 -0
  810. /paddlex/configs/{image_classification → modules/image_classification}/ResNet34_vd.yaml +0 -0
  811. /paddlex/configs/{image_classification → modules/image_classification}/ResNet50.yaml +0 -0
  812. /paddlex/configs/{image_classification → modules/image_classification}/ResNet50_vd.yaml +0 -0
  813. /paddlex/configs/{image_classification → modules/image_classification}/StarNet-S1.yaml +0 -0
  814. /paddlex/configs/{image_classification → modules/image_classification}/StarNet-S2.yaml +0 -0
  815. /paddlex/configs/{image_classification → modules/image_classification}/StarNet-S3.yaml +0 -0
  816. /paddlex/configs/{image_classification → modules/image_classification}/StarNet-S4.yaml +0 -0
  817. /paddlex/configs/{image_classification → modules/image_classification}/SwinTransformer_base_patch4_window12_384.yaml +0 -0
  818. /paddlex/configs/{image_classification → modules/image_classification}/SwinTransformer_base_patch4_window7_224.yaml +0 -0
  819. /paddlex/configs/{image_classification → modules/image_classification}/SwinTransformer_large_patch4_window12_384.yaml +0 -0
  820. /paddlex/configs/{image_classification → modules/image_classification}/SwinTransformer_large_patch4_window7_224.yaml +0 -0
  821. /paddlex/configs/{image_classification → modules/image_classification}/SwinTransformer_small_patch4_window7_224.yaml +0 -0
  822. /paddlex/configs/{image_classification → modules/image_classification}/SwinTransformer_tiny_patch4_window7_224.yaml +0 -0
  823. /paddlex/configs/{general_recognition → modules/image_feature}/PP-ShiTuV2_rec.yaml +0 -0
  824. /paddlex/configs/{general_recognition → modules/image_feature}/PP-ShiTuV2_rec_CLIP_vit_base.yaml +0 -0
  825. /paddlex/configs/{general_recognition → modules/image_feature}/PP-ShiTuV2_rec_CLIP_vit_large.yaml +0 -0
  826. /paddlex/configs/{multilabel_classification → modules/image_multilabel_classification}/CLIP_vit_base_patch16_448_ML.yaml +0 -0
  827. /paddlex/configs/{multilabel_classification → modules/image_multilabel_classification}/PP-HGNetV2-B0_ML.yaml +0 -0
  828. /paddlex/configs/{multilabel_classification → modules/image_multilabel_classification}/PP-HGNetV2-B4_ML.yaml +0 -0
  829. /paddlex/configs/{multilabel_classification → modules/image_multilabel_classification}/PP-HGNetV2-B6_ML.yaml +0 -0
  830. /paddlex/configs/{multilabel_classification → modules/image_multilabel_classification}/PP-LCNet_x1_0_ML.yaml +0 -0
  831. /paddlex/configs/{multilabel_classification → modules/image_multilabel_classification}/ResNet50_ML.yaml +0 -0
  832. /paddlex/configs/{image_unwarping → modules/image_unwarping}/UVDoc.yaml +0 -0
  833. /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/Cascade-MaskRCNN-ResNet50-FPN.yaml +0 -0
  834. /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN.yaml +0 -0
  835. /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/Mask-RT-DETR-H.yaml +0 -0
  836. /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/Mask-RT-DETR-L.yaml +0 -0
  837. /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/Mask-RT-DETR-M.yaml +0 -0
  838. /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/Mask-RT-DETR-S.yaml +0 -0
  839. /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/Mask-RT-DETR-X.yaml +0 -0
  840. /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/MaskRCNN-ResNeXt101-vd-FPN.yaml +0 -0
  841. /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/MaskRCNN-ResNet101-FPN.yaml +0 -0
  842. /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/MaskRCNN-ResNet101-vd-FPN.yaml +0 -0
  843. /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/MaskRCNN-ResNet50-FPN.yaml +0 -0
  844. /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/MaskRCNN-ResNet50-vd-FPN.yaml +0 -0
  845. /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/MaskRCNN-ResNet50.yaml +0 -0
  846. /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/PP-YOLOE_seg-S.yaml +0 -0
  847. /paddlex/configs/{instance_segmentation → modules/instance_segmentation}/SOLOv2.yaml +0 -0
  848. /paddlex/configs/{structure_analysis → modules/layout_detection}/PicoDet-L_layout_17cls.yaml +0 -0
  849. /paddlex/configs/{structure_analysis → modules/layout_detection}/PicoDet-L_layout_3cls.yaml +0 -0
  850. /paddlex/configs/{structure_analysis → modules/layout_detection}/PicoDet-S_layout_17cls.yaml +0 -0
  851. /paddlex/configs/{structure_analysis → modules/layout_detection}/PicoDet-S_layout_3cls.yaml +0 -0
  852. /paddlex/configs/{structure_analysis → modules/layout_detection}/PicoDet_layout_1x.yaml +0 -0
  853. /paddlex/configs/{structure_analysis → modules/layout_detection}/PicoDet_layout_1x_table.yaml +0 -0
  854. /paddlex/configs/{structure_analysis → modules/layout_detection}/RT-DETR-H_layout_17cls.yaml +0 -0
  855. /paddlex/configs/{structure_analysis → modules/layout_detection}/RT-DETR-H_layout_3cls.yaml +0 -0
  856. /paddlex/configs/{mainbody_detection → modules/mainbody_detection}/PP-ShiTuV2_det.yaml +0 -0
  857. /paddlex/configs/{object_detection → modules/object_detection}/Cascade-FasterRCNN-ResNet50-FPN.yaml +0 -0
  858. /paddlex/configs/{object_detection → modules/object_detection}/Cascade-FasterRCNN-ResNet50-vd-SSLDv2-FPN.yaml +0 -0
  859. /paddlex/configs/{object_detection → modules/object_detection}/CenterNet-DLA-34.yaml +0 -0
  860. /paddlex/configs/{object_detection → modules/object_detection}/CenterNet-ResNet50.yaml +0 -0
  861. /paddlex/configs/{object_detection → modules/object_detection}/DETR-R50.yaml +0 -0
  862. /paddlex/configs/{object_detection → modules/object_detection}/FCOS-ResNet50.yaml +0 -0
  863. /paddlex/configs/{object_detection → modules/object_detection}/FasterRCNN-ResNeXt101-vd-FPN.yaml +0 -0
  864. /paddlex/configs/{object_detection → modules/object_detection}/FasterRCNN-ResNet101-FPN.yaml +0 -0
  865. /paddlex/configs/{object_detection → modules/object_detection}/FasterRCNN-ResNet101.yaml +0 -0
  866. /paddlex/configs/{object_detection → modules/object_detection}/FasterRCNN-ResNet34-FPN.yaml +0 -0
  867. /paddlex/configs/{object_detection → modules/object_detection}/FasterRCNN-ResNet50-FPN.yaml +0 -0
  868. /paddlex/configs/{object_detection → modules/object_detection}/FasterRCNN-ResNet50-vd-FPN.yaml +0 -0
  869. /paddlex/configs/{object_detection → modules/object_detection}/FasterRCNN-ResNet50-vd-SSLDv2-FPN.yaml +0 -0
  870. /paddlex/configs/{object_detection → modules/object_detection}/FasterRCNN-ResNet50.yaml +0 -0
  871. /paddlex/configs/{object_detection → modules/object_detection}/FasterRCNN-Swin-Tiny-FPN.yaml +0 -0
  872. /paddlex/configs/{object_detection → modules/object_detection}/PP-YOLOE_plus-L.yaml +0 -0
  873. /paddlex/configs/{object_detection → modules/object_detection}/PP-YOLOE_plus-M.yaml +0 -0
  874. /paddlex/configs/{object_detection → modules/object_detection}/PP-YOLOE_plus-S.yaml +0 -0
  875. /paddlex/configs/{object_detection → modules/object_detection}/PP-YOLOE_plus-X.yaml +0 -0
  876. /paddlex/configs/{object_detection → modules/object_detection}/PicoDet-L.yaml +0 -0
  877. /paddlex/configs/{object_detection → modules/object_detection}/PicoDet-M.yaml +0 -0
  878. /paddlex/configs/{object_detection → modules/object_detection}/PicoDet-S.yaml +0 -0
  879. /paddlex/configs/{object_detection → modules/object_detection}/PicoDet-XS.yaml +0 -0
  880. /paddlex/configs/{object_detection → modules/object_detection}/RT-DETR-H.yaml +0 -0
  881. /paddlex/configs/{object_detection → modules/object_detection}/RT-DETR-L.yaml +0 -0
  882. /paddlex/configs/{object_detection → modules/object_detection}/RT-DETR-R18.yaml +0 -0
  883. /paddlex/configs/{object_detection → modules/object_detection}/RT-DETR-R50.yaml +0 -0
  884. /paddlex/configs/{object_detection → modules/object_detection}/RT-DETR-X.yaml +0 -0
  885. /paddlex/configs/{object_detection → modules/object_detection}/YOLOX-L.yaml +0 -0
  886. /paddlex/configs/{object_detection → modules/object_detection}/YOLOX-M.yaml +0 -0
  887. /paddlex/configs/{object_detection → modules/object_detection}/YOLOX-N.yaml +0 -0
  888. /paddlex/configs/{object_detection → modules/object_detection}/YOLOX-S.yaml +0 -0
  889. /paddlex/configs/{object_detection → modules/object_detection}/YOLOX-T.yaml +0 -0
  890. /paddlex/configs/{object_detection → modules/object_detection}/YOLOv3-DarkNet53.yaml +0 -0
  891. /paddlex/configs/{object_detection → modules/object_detection}/YOLOv3-MobileNetV3.yaml +0 -0
  892. /paddlex/configs/{object_detection → modules/object_detection}/YOLOv3-ResNet50_vd_DCN.yaml +0 -0
  893. /paddlex/configs/{pedestrian_attribute → modules/pedestrian_attribute_recognition}/PP-LCNet_x1_0_pedestrian_attribute.yaml +0 -0
  894. /paddlex/configs/{text_detection_seal → modules/seal_text_detection}/PP-OCRv4_mobile_seal_det.yaml +0 -0
  895. /paddlex/configs/{text_detection_seal → modules/seal_text_detection}/PP-OCRv4_server_seal_det.yaml +0 -0
  896. /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/Deeplabv3-R101.yaml +0 -0
  897. /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/Deeplabv3-R50.yaml +0 -0
  898. /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/Deeplabv3_Plus-R101.yaml +0 -0
  899. /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/Deeplabv3_Plus-R50.yaml +0 -0
  900. /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/OCRNet_HRNet-W18.yaml +0 -0
  901. /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/OCRNet_HRNet-W48.yaml +0 -0
  902. /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/PP-LiteSeg-B.yaml +0 -0
  903. /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/PP-LiteSeg-T.yaml +0 -0
  904. /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/SegFormer-B0.yaml +0 -0
  905. /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/SegFormer-B1.yaml +0 -0
  906. /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/SegFormer-B2.yaml +0 -0
  907. /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/SegFormer-B3.yaml +0 -0
  908. /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/SegFormer-B4.yaml +0 -0
  909. /paddlex/configs/{semantic_segmentation → modules/semantic_segmentation}/SegFormer-B5.yaml +0 -0
  910. /paddlex/configs/{small_object_detection → modules/small_object_detection}/PP-YOLOE_plus_SOD-L.yaml +0 -0
  911. /paddlex/configs/{small_object_detection → modules/small_object_detection}/PP-YOLOE_plus_SOD-S.yaml +0 -0
  912. /paddlex/configs/{small_object_detection → modules/small_object_detection}/PP-YOLOE_plus_SOD-largesize-L.yaml +0 -0
  913. /paddlex/configs/{table_recognition → modules/table_structure_recognition}/SLANet.yaml +0 -0
  914. /paddlex/configs/{table_recognition → modules/table_structure_recognition}/SLANet_plus.yaml +0 -0
  915. /paddlex/configs/{text_detection → modules/text_detection}/PP-OCRv4_mobile_det.yaml +0 -0
  916. /paddlex/configs/{text_detection → modules/text_detection}/PP-OCRv4_server_det.yaml +0 -0
  917. /paddlex/configs/{text_recognition → modules/text_recognition}/PP-OCRv4_mobile_rec.yaml +0 -0
  918. /paddlex/configs/{text_recognition → modules/text_recognition}/PP-OCRv4_server_rec.yaml +0 -0
  919. /paddlex/configs/{text_recognition → modules/text_recognition}/ch_RepSVTR_rec.yaml +0 -0
  920. /paddlex/configs/{text_recognition → modules/text_recognition}/ch_SVTRv2_rec.yaml +0 -0
  921. /paddlex/configs/{ts_anomaly_detection → modules/ts_anomaly_detection}/AutoEncoder_ad.yaml +0 -0
  922. /paddlex/configs/{ts_anomaly_detection → modules/ts_anomaly_detection}/DLinear_ad.yaml +0 -0
  923. /paddlex/configs/{ts_anomaly_detection → modules/ts_anomaly_detection}/Nonstationary_ad.yaml +0 -0
  924. /paddlex/configs/{ts_anomaly_detection → modules/ts_anomaly_detection}/PatchTST_ad.yaml +0 -0
  925. /paddlex/configs/{ts_anomaly_detection → modules/ts_anomaly_detection}/TimesNet_ad.yaml +0 -0
  926. /paddlex/configs/{ts_classification → modules/ts_classification}/TimesNet_cls.yaml +0 -0
  927. /paddlex/configs/{ts_forecast → modules/ts_forecast}/DLinear.yaml +0 -0
  928. /paddlex/configs/{ts_forecast → modules/ts_forecast}/NLinear.yaml +0 -0
  929. /paddlex/configs/{ts_forecast → modules/ts_forecast}/Nonstationary.yaml +0 -0
  930. /paddlex/configs/{ts_forecast → modules/ts_forecast}/PatchTST.yaml +0 -0
  931. /paddlex/configs/{ts_forecast → modules/ts_forecast}/RLinear.yaml +0 -0
  932. /paddlex/configs/{ts_forecast → modules/ts_forecast}/TiDE.yaml +0 -0
  933. /paddlex/configs/{ts_forecast → modules/ts_forecast}/TimesNet.yaml +0 -0
  934. /paddlex/configs/{vehicle_attribute → modules/vehicle_attribute_recognition}/PP-LCNet_x1_0_vehicle_attribute.yaml +0 -0
  935. /paddlex/configs/{vehicle_detection → modules/vehicle_detection}/PP-YOLOE-L_vehicle.yaml +0 -0
  936. /paddlex/configs/{vehicle_detection → modules/vehicle_detection}/PP-YOLOE-S_vehicle.yaml +0 -0
  937. /paddlex/inference/{results/utils → common}/__init__.py +0 -0
  938. {paddlex-3.0.0b2.dist-info → paddlex-3.0.0rc0.dist-info}/LICENSE +0 -0
  939. {paddlex-3.0.0b2.dist-info → paddlex-3.0.0rc0.dist-info}/entry_points.txt +0 -0
  940. {paddlex-3.0.0b2.dist-info → paddlex-3.0.0rc0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1941 @@
1
+ # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # Modified from OpenAI Whisper 2022 (https://github.com/openai/whisper/whisper)
15
+ import os
16
+ import tqdm
17
+ import zlib
18
+ import soundfile
19
+ import numpy as np
20
+ import lazy_paddle as paddle
21
+
22
+ from dataclasses import dataclass
23
+ from dataclasses import field
24
+ from functools import lru_cache
25
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
26
+
27
+ from ..common.tokenizer import GPTTokenizer
28
+
29
# Public names exported by `from ... import *`.
__all__ = ["Whisper", "Tokenizer"]
33
+
34
+
35
def exact_div(x, y):
    """Return ``x // y``, asserting that ``y`` divides ``x`` with no remainder."""
    quotient, remainder = divmod(x, y)
    assert remainder == 0
    return quotient


_MODELS = ["large"]
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
# 480000: number of samples in a chunk
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE
# 3000: number of frames in a mel spectrogram input (one frame per HOP_LENGTH samples)
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)
50
+
51
+
52
@dataclass
class ModelDimensions:
    """Integer hyper-parameters that fix the shape of a Whisper model.

    Field order defines the positional constructor signature and must not change.
    """

    # number of mel-spectrogram channels fed to the audio side
    n_mels: int
    # audio-side sizes (prefix ``n_audio_``)
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    # text-side sizes (vocabulary plus prefix ``n_text_``)
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int
64
+
65
+
66
# Supported language codes mapped to English names. Insertion order matters:
# get_tokenizer derives each language token id from its position in this dict.
# fmt: off
LANGUAGES = {
    "en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian",
    "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish",
    "pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish",
    "it": "italian", "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese",
    "iw": "hebrew", "uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech",
    "ro": "romanian", "da": "danish", "hu": "hungarian", "ta": "tamil", "no": "norwegian",
    "th": "thai", "ur": "urdu", "hr": "croatian", "bg": "bulgarian", "lt": "lithuanian",
    "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh", "sk": "slovak",
    "te": "telugu", "fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian",
    "az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian",
    "br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali",
    "mn": "mongolian", "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili",
    "gl": "galician", "mr": "marathi", "pa": "punjabi", "si": "sinhala", "km": "khmer",
    "sn": "shona", "yo": "yoruba", "so": "somali", "af": "afrikaans", "oc": "occitan",
    "ka": "georgian", "be": "belarusian", "tg": "tajik", "sd": "sindhi", "gu": "gujarati",
    "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek", "fo": "faroese",
    "ht": "haitian creole", "ps": "pashto", "tk": "turkmen", "nn": "nynorsk", "mt": "maltese",
    "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan", "tl": "tagalog",
    "mg": "malagasy", "as": "assamese", "tt": "tatar", "haw": "hawaiian", "ln": "lingala",
    "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese",
}
# fmt: on

# Reverse lookup (language name -> code), extended with a few common aliases.
TO_LANGUAGE_CODE = {
    **{name: code for code, name in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
}
183
+
184
+
185
def compression_ratio(text) -> float:
    """Return the zlib compression ratio of ``text`` measured in UTF-8 bytes.

    Larger values mean more redundant (e.g. repetitive) text.

    Args:
        text: the string to measure.

    Returns:
        ``len(utf8_bytes) / len(zlib.compress(utf8_bytes))``. For an empty string this
        is 0.0, because the compressed form of empty input is never empty.
    """
    # Count the numerator in bytes, like the denominator. Counting characters would
    # understate the ratio for non-ASCII text (e.g. CJK uses 3 bytes per character).
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))
187
+
188
+
189
def format_timestamp(
    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
    """Render ``seconds`` as ``[HH:]MM:SS<marker>mmm``.

    The hours field is emitted when ``always_include_hours`` is set or when the
    value is at least one hour. Input is rounded to the nearest millisecond.

    Raises:
        AssertionError: if ``seconds`` is negative.
    """
    assert seconds >= 0, "non-negative timestamp expected"
    total_ms = round(seconds * 1000.0)

    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    whole_seconds, millis = divmod(remainder, 1_000)

    if always_include_hours or hours > 0:
        prefix = f"{hours:02d}:"
    else:
        prefix = ""
    return f"{prefix}{minutes:02d}:{whole_seconds:02d}{decimal_marker}{millis:03d}"
208
+
209
+
210
@dataclass(frozen=True)
class Tokenizer:
    """A thin wrapper around `GPTTokenizer` providing quick access to special tokens.

    Attributes:
        tokenizer: the wrapped ``GPTTokenizer``. ``get_tokenizer`` passes one built by
            ``build_tokenizer``, which registers the Whisper special tokens.
        language: language code used by ``language_token``, or ``None``.
        sot_sequence: token ids that begin a transcript.

    Note:
        The cached properties below use ``functools.lru_cache`` keyed on ``self``.
        This relies on the frozen dataclass being hashable, so all fields must be
        hashable.
    """

    tokenizer: "GPTTokenizer"
    language: Optional[str]
    sot_sequence: Tuple[int]

    def encode(self, text, **kwargs):
        """Encode ``text`` with the wrapped tokenizer, forwarding ``kwargs`` unchanged."""
        return self.tokenizer.encode(text, **kwargs)

    def decode(
        self, token_ids: Union[int, List[int], np.ndarray, paddle.Tensor], **kwargs
    ):
        """Decode token ids to text with the wrapped tokenizer.

        Multi-token input: tensor elements are converted with ``.item()``, and ids
        ``>= len(self.tokenizer)`` are dropped (presumably timestamp tokens).
        Single-token input (a one-element sequence or a bare Python/NumPy integer)
        is forwarded to the wrapped tokenizer as a scalar.

        Raises:
            ValueError: if ``token_ids`` is an empty sequence.
        """
        # The type hint advertises `int`, but a bare integer has no len(). Treat it
        # as a one-element sequence instead of failing with TypeError.
        if isinstance(token_ids, (int, np.integer)):
            token_ids = [token_ids]

        if len(token_ids) > 1:
            ids_list = []
            for ids in token_ids:
                if paddle.is_tensor(ids):
                    ids = ids.item()
                if ids < len(self.tokenizer):
                    ids_list.append(ids)
            token_ids = ids_list
        elif len(token_ids) == 1:
            token_ids = token_ids[0]
        else:
            raise ValueError(f"token_ids {token_ids} load error.")

        return self.tokenizer.decode(token_ids, **kwargs)

    def decode_with_timestamps(self, tokens) -> str:
        """
        Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
        This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
        """
        # Runs of ordinary tokens are collected into lists; each timestamp token
        # becomes a "<|seconds|>" string (0.02 s per timestamp step).
        outputs = [[]]
        for token in tokens:
            if token >= self.timestamp_begin:
                timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
                outputs.append(timestamp)
                outputs.append([])
            else:
                outputs[-1].append(token)
        outputs = [
            s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs
        ]
        return "".join(outputs)

    @property
    @lru_cache()
    def eot(self) -> int:
        """Id of the end-of-text token (the tokenizer's ``eos_token_id``)."""
        return self.tokenizer.eos_token_id

    @property
    @lru_cache()
    def sot(self) -> int:
        """Id of ``<|startoftranscript|>``."""
        return self._get_single_token_id("<|startoftranscript|>")

    @property
    @lru_cache()
    def sot_lm(self) -> int:
        """Id of ``<|startoflm|>``."""
        return self._get_single_token_id("<|startoflm|>")

    @property
    @lru_cache()
    def sot_prev(self) -> int:
        """Id of ``<|startofprev|>``."""
        return self._get_single_token_id("<|startofprev|>")

    @property
    @lru_cache()
    def no_speech(self) -> int:
        """Id of ``<|nospeech|>``."""
        return self._get_single_token_id("<|nospeech|>")

    @property
    @lru_cache()
    def no_timestamps(self) -> int:
        """Id of ``<|notimestamps|>``."""
        return self._get_single_token_id("<|notimestamps|>")

    @property
    @lru_cache()
    def timestamp_begin(self) -> int:
        """First timestamp token id: one past the last special token id."""
        return self.tokenizer.all_special_ids[-1] + 1

    @property
    @lru_cache()
    def language_token(self) -> int:
        """Returns the token id corresponding to the value of the `language` field"""
        if self.language is None:
            raise ValueError("This tokenizer does not have language token configured")

        additional_tokens = dict(
            zip(
                self.tokenizer.additional_special_tokens,
                self.tokenizer.additional_special_tokens_ids,
            )
        )
        candidate = f"<|{self.language}|>"
        if candidate in additional_tokens:
            return additional_tokens[candidate]

        raise KeyError(f"Language {self.language} not found in tokenizer.")

    @property
    @lru_cache()
    def all_language_tokens(self) -> Tuple[int]:
        """Ids of every additional special token of the form ``<|code|>`` with ``code`` in LANGUAGES."""
        result = []
        for token, token_id in zip(
            self.tokenizer.additional_special_tokens,
            self.tokenizer.additional_special_tokens_ids,
        ):
            if token.strip("<|>") in LANGUAGES:
                result.append(token_id)
        return tuple(result)

    @property
    @lru_cache()
    def all_language_codes(self) -> Tuple[str]:
        """Language codes matching ``all_language_tokens``, in the same order."""
        return tuple(
            self.decode([token_id]).strip("<|>")
            for token_id in self.all_language_tokens
        )

    @property
    @lru_cache()
    def sot_sequence_including_notimestamps(self) -> Tuple[int]:
        """``sot_sequence`` followed by the ``<|notimestamps|>`` token id."""
        return tuple(list(self.sot_sequence) + [self.no_timestamps])

    @property
    @lru_cache()
    def non_speech_tokens(self) -> Tuple[int]:
        """
        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.

        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,

        keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
        """
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += (
            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
        )

        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        result = {
            self.tokenizer.encode(" -").input_ids[0],
            self.tokenizer.encode(" '").input_ids[0],
        }
        for symbol in symbols + list(miscellaneous):
            for tokens in [
                self.tokenizer.encode(symbol).input_ids,
                self.tokenizer.encode(" " + symbol).input_ids,
            ]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])

        return tuple(sorted(result))

    def _get_single_token_id(self, text) -> int:
        """Encode ``text`` and return its id, asserting it maps to exactly one token."""
        tokens = self.tokenizer.encode(text).input_ids
        assert len(tokens) == 1, f"{text} is not encoded as a single token"
        return tokens[0]
377
+
378
+
379
@lru_cache(maxsize=None)
def build_tokenizer(resource_path: str, name: str = "gpt2"):
    """Load ``<resource_path>/assets/<name>`` and register Whisper's special tokens.

    Side effect: sets the ``TOKENIZERS_PARALLELISM`` environment variable to
    ``"false"``. Results are cached per ``(resource_path, name)``.

    The order of the registered tokens matters: ``get_tokenizer`` reads ids at
    fixed positions of ``all_special_ids``.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    tokenizer = GPTTokenizer.from_pretrained(
        os.path.join(resource_path, "assets", name)
    )

    language_tokens = [f"<|{code}|>" for code in LANGUAGES]
    trailing_tokens = [
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
    ]
    specials = ["<|startoftranscript|>", *language_tokens, *trailing_tokens]

    tokenizer.add_special_tokens({"additional_special_tokens": specials})
    return tokenizer
398
+
399
+
400
@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    resource_path: str,
    *,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
    language: Optional[str] = None,
) -> Tokenizer:
    """Build a ``Tokenizer`` together with its start-of-transcript token sequence.

    Args:
        multilingual: use the "multilingual" vocabulary, defaulting ``task`` to
            "transcribe" and ``language`` to "en". Otherwise use "gpt2" and ignore
            both ``task`` and ``language``.
        resource_path: directory whose ``assets/<name>`` folder holds the vocabulary.
        task: "transcribe" selects the transcribe token; any other non-None value
            selects the translate token.
        language: language code or name (case-insensitive). Names and aliases from
            ``TO_LANGUAGE_CODE`` are accepted.

    Returns:
        A ``Tokenizer`` whose ``sot_sequence`` is ``[sot]``, optionally followed by
        the language token and then the task token.

    Raises:
        ValueError: if ``language`` is neither a known code nor a known name/alias.
    """
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language not in TO_LANGUAGE_CODE:
                raise ValueError(f"Unsupported language: {language}")
            language = TO_LANGUAGE_CODE[language]

    if multilingual:
        tokenizer_name = "multilingual"
        task = task or "transcribe"
        language = language or "en"
    else:
        tokenizer_name, task, language = "gpt2", None, None

    tokenizer = build_tokenizer(resource_path=resource_path, name=tokenizer_name)

    # Fixed positions in all_special_ids. These rely on the registration order
    # used by build_tokenizer.
    special_ids: List[int] = tokenizer.all_special_ids
    sot, translate, transcribe = special_ids[1], special_ids[-6], special_ids[-5]

    sot_sequence = [sot]
    if language is not None:
        # Language tokens are assumed to follow <|startoftranscript|> consecutively,
        # in LANGUAGES order.
        sot_sequence.append(sot + 1 + list(LANGUAGES).index(language))
    if task is not None:
        sot_sequence.append(transcribe if task == "transcribe" else translate)

    return Tokenizer(
        tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
    )
441
+
442
+
443
class MultiHeadAttention(paddle.nn.Layer):
    """Multi-head scaled dot-product attention used for self- and cross-attention.

    The key projection has no bias; the query, value and output projections do.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = paddle.nn.Linear(n_state, n_state, bias_attr=True)
        self.key = paddle.nn.Linear(n_state, n_state, bias_attr=False)
        self.value = paddle.nn.Linear(n_state, n_state, bias_attr=True)
        self.out = paddle.nn.Linear(n_state, n_state, bias_attr=True)

    def forward(
        self,
        x: paddle.Tensor,
        xa: Optional[paddle.Tensor] = None,
        mask: Optional[paddle.Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """Attend from ``x`` to itself, or to ``xa`` when given (cross-attention).

        ``kv_cache`` is a dict keyed by the ``key``/``value`` Linear layers. For
        cross-attention, an existing entry for ``self.key`` is reused instead of
        re-projecting ``xa``. For self-attention, any cache prepending is presumably
        done by forward hooks installed elsewhere (not visible here).
        """
        q = self.query(x)

        reuse_cached = kv_cache is not None and xa is not None and self.key in kv_cache
        if reuse_cached:
            # Cross-attention keys/values were computed on an earlier call.
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        else:
            source = x if xa is None else xa
            k = self.key(source)
            v = self.value(source)

        return self.out(self.qkv_attention(q, k, v, mask))

    def qkv_attention(
        self,
        q: paddle.Tensor,
        k: paddle.Tensor,
        v: paddle.Tensor,
        mask: Optional[paddle.Tensor] = None,
    ):
        """Compute softmax(q k^T) v per head and merge the heads back.

        ``q`` must be rank 3, with its last axis split evenly across heads. The
        ``scale`` factor is applied to both q and k, so their product carries
        ``head_dim ** -0.5``. ``mask`` (if any) is sliced to ``[:n_ctx, :n_ctx]``,
        where ``n_ctx`` is the query length.
        """
        _, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25

        def split_heads(t, perm):
            # [batch, len, n_state] -> [batch, len, n_head, head_dim], then permute.
            return paddle.transpose(t.reshape([*t.shape[:2], self.n_head, -1]), perm)

        q = split_heads(q, (0, 2, 1, 3)) * scale
        k = split_heads(k, (0, 2, 3, 1)) * scale
        v = split_heads(v, (0, 2, 1, 3))

        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]

        weights = paddle.nn.functional.softmax(qk.astype(q.dtype), axis=-1)
        return paddle.transpose(weights @ v, (0, 2, 1, 3)).flatten(start_axis=2)
499
+
500
+
501
class ResidualAttentionBlock(paddle.nn.Layer):
    """Pre-LayerNorm transformer block.

    Applies self-attention, then cross-attention (only when built with
    ``cross_attention=True``), then a 4x-wide GELU MLP. Each sub-layer is added as
    a residual: ``x = x + f(ln(x))``.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = paddle.nn.LayerNorm(n_state)

        if cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head)
            self.cross_attn_ln = paddle.nn.LayerNorm(n_state)
        else:
            self.cross_attn = None
            self.cross_attn_ln = None

        hidden = n_state * 4
        self.mlp = paddle.nn.Sequential(
            paddle.nn.Linear(n_state, hidden, bias_attr=True),
            paddle.nn.GELU(),
            paddle.nn.Linear(hidden, n_state, bias_attr=True),
        )
        self.mlp_ln = paddle.nn.LayerNorm(n_state)

    def forward(
        self,
        x: paddle.Tensor,
        xa: Optional[paddle.Tensor] = None,
        mask: Optional[paddle.Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """Run the block on ``x``. ``xa`` is only used by the cross-attention sub-layer."""
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
        return x + self.mlp(self.mlp_ln(x))
533
+
534
+
535
def sinusoids(length, channels, max_timescale=10000):
    """Return a fixed sinusoidal positional-embedding table of shape [length, channels].

    The first ``channels // 2`` columns hold sin values and the rest hold cos values.
    Timescales are spaced geometrically from 1 to ``max_timescale``.

    Raises:
        AssertionError: if ``channels`` is odd.
    """
    assert channels % 2 == 0
    half = channels // 2
    log_timescale_increment = np.log(max_timescale) / (half - 1)
    inv_timescales = paddle.exp(
        -log_timescale_increment * paddle.arange(half, dtype=paddle.float32)
    )
    positions = paddle.arange(length, dtype=paddle.float32)[:, np.newaxis]
    scaled_time = positions * inv_timescales[np.newaxis, :]
    table = paddle.concat([paddle.sin(scaled_time), paddle.cos(scaled_time)], axis=1)
    return paddle.to_tensor(table)
549
+
550
+
551
class AudioEncoder(paddle.nn.Layer):
    """Convolutional stem plus transformer blocks that encode a mel spectrogram.

    Two 1-D convolutions (the second with stride 2) are followed by a fixed
    sinusoidal positional embedding, ``n_layer`` self-attention blocks and a final
    LayerNorm.
    """

    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = paddle.nn.Conv1D(
            n_mels, n_state, kernel_size=3, stride=1, padding=1, bias_attr=True
        )
        self.conv2 = paddle.nn.Conv1D(
            n_state, n_state, kernel_size=3, stride=2, padding=1, bias_attr=True
        )
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = paddle.nn.LayerNorm(n_state)

    def forward(self, x: paddle.Tensor):
        """
        x : paddle.Tensor, shape = (batch_size, n_mels, n_frames)
            the mel spectrogram of the audio; after the stride-2 convolution the
            time length must equal the ``n_ctx`` given at construction

        Returns a tensor of shape (batch_size, n_ctx, n_state).
        """
        gelu = paddle.nn.functional.gelu
        x = gelu(self.conv2(gelu(self.conv1(x))))
        # [batch, n_state, time] -> [batch, time, n_state]
        x = paddle.transpose(x, (0, 2, 1))

        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        x = x + self.positional_embedding

        for block in self.blocks:
            x = block(x)

        return self.ln_post(x)
586
+
587
+
588
class TextDecoder(paddle.nn.Layer):
    """Transformer decoder mapping text tokens plus encoded audio to vocabulary logits.

    Output logits reuse the token embedding matrix (weight tying).
    """

    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()

        self.token_embedding = paddle.nn.Embedding(n_vocab, n_state)
        self.positional_embedding = paddle.create_parameter(
            shape=[n_ctx, n_state], dtype="float32"
        )

        self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = paddle.nn.LayerNorm(n_state)

        # Causal mask over token positions: -inf strictly above the diagonal, 0 elsewhere.
        # It must be square (n_ctx x n_ctx). Attention slices it as mask[:L, :L], where
        # L is the query length. The previous [n_ctx, n_state] shape broke whenever
        # L > n_state and wasted memory otherwise. Values inside the sliced region are
        # unchanged, and the buffer is non-persistable, so checkpoints are unaffected.
        mask = paddle.full(shape=[n_ctx, n_ctx], fill_value=-np.inf, dtype="float32")
        mask = paddle.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistable=False)

    def forward(
        self, x: paddle.Tensor, xa: paddle.Tensor, kv_cache: Optional[dict] = None
    ):
        """
        x : paddle.Tensor (integer ids), shape = (batch_size, <= n_ctx)
            the text tokens
        xa : paddle.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on (the AudioEncoder output)
        kv_cache : optional dict of cached key/value tensors. When non-empty, the
            length (axis 1) of its first entry is used as the positional offset, so
            only new tokens need to be passed.

        Returns logits of shape (batch_size, len(x), n_vocab).
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        # Weight tying: project back to the vocabulary with the embedding matrix.
        logits = x @ paddle.transpose(self.token_embedding.weight, (1, 0))

        return logits
634
+
635
+
636
@dataclass(frozen=True)
class DecodingOptions:
    """Immutable bundle of options controlling how audio is decoded into text."""

    # --- task and language ---
    # whether to perform X->X "transcribe" or X->English "translate"
    task: str = "transcribe"
    # language that the audio is in; uses detected language if None
    language: Optional[str] = None

    # --- sampling ---
    temperature: float = 0.0
    # maximum number of tokens to sample
    sample_len: Optional[int] = None
    # number of independent samples to collect, when t > 0
    best_of: Optional[int] = None
    # number of beams in beam search, when t == 0
    beam_size: Optional[int] = None
    # patience in beam search (https://arxiv.org/abs/2204.05424)
    patience: Optional[float] = None

    # --- ranking generations (either beams or best-of-N samples) ---
    # "alpha" in Google NMT, None defaults to length norm
    length_penalty: Optional[float] = None

    # --- prompt, prefix and token suppression ---
    # text or tokens for the previous context
    prompt: Optional[Union[str, List[int]]] = None
    # text or tokens to prefix the current context
    prefix: Optional[Union[str, List[int]]] = None
    # this will suppress blank outputs
    suppress_blank: bool = True
    # list of token ids (or comma-separated token ids) to suppress;
    # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"

    # --- timestamps ---
    # use <|notimestamps|> to sample text tokens only
    without_timestamps: bool = False
    # the initial timestamp cannot be later than this
    max_initial_timestamp: Optional[float] = 1.0

    # --- implementation details ---
    # use fp16 for most of the calculation
    fp16: bool = False
681
+
682
+
683
@dataclass(frozen=True)
class DecodingResult:
    """Immutable record describing the outcome of decoding one audio segment.

    Only ``audio_features`` and ``language`` are required. The numeric scores
    default to NaN, and ``tokens`` defaults to a fresh empty list per instance.
    """

    # encoder output the text was decoded from
    audio_features: paddle.Tensor
    language: str
    language_probs: Optional[Dict[str, float]] = None
    tokens: List[int] = field(default_factory=list)
    text: str = ""
    avg_logprob: float = np.nan
    no_speech_prob: float = np.nan
    temperature: float = np.nan
    compression_ratio: float = np.nan
694
+
695
+
696
class Inference:
    """Interface for stepping the text decoder during decoding.

    Subclasses must implement ``logits`` and ``rearrange_kv_cache``.
    ``cleanup_caching`` is optional and does nothing by default.
    """

    def logits(
        self, tokens: paddle.Tensor, audio_features: paddle.Tensor
    ) -> paddle.Tensor:
        """Run the decoder on ``tokens`` and return the per-token logits."""
        raise NotImplementedError

    def rearrange_kv_cache(self, source_indices) -> None:
        """Reorder the key/value cache to follow the beams chosen by ``source_indices``."""
        raise NotImplementedError

    def cleanup_caching(self) -> None:
        """Release hooks or other resources once decoding is finished (no-op here)."""
        return None
710
+
711
+
712
class WhisperInference(Inference):
    """``Inference`` implementation that runs ``model.decoder`` with a key/value cache.

    The cache and its hooks are obtained lazily from
    ``model.install_kv_cache_hooks()`` on the first ``logits`` call.
    """

    def __init__(self, model: "Whisper", initial_token_length: int):
        self.model: "Whisper" = model
        self.initial_token_length = initial_token_length
        self.kv_cache = {}
        self.hooks = []

    def logits(
        self, tokens: paddle.Tensor, audio_features: paddle.Tensor
    ) -> paddle.Tensor:
        """Decoder logits for ``tokens``.

        After the initial prompt has been processed, only the newest token is fed.
        """
        if not self.kv_cache:
            self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()

        is_initial_pass = tokens.shape[-1] <= self.initial_token_length
        if not is_initial_pass:
            # Earlier positions are already represented in the kv cache.
            tokens = tokens[:, -1:]

        return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)

    def cleanup_caching(self):
        """Remove every installed hook and start over with an empty cache."""
        for hook in self.hooks:
            hook.remove()

        self.kv_cache = {}
        self.hooks = []

    def rearrange_kv_cache(self, source_indices):
        """Select the cached rows given by ``source_indices`` (e.g. surviving beams)."""
        # Update the existing dict in place rather than rebinding it: it was
        # obtained from install_kv_cache_hooks, whose hooks may hold a reference.
        for module in self.kv_cache:
            self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
742
+
743
+
744
@paddle.no_grad()
def detect_language(
    model: "Whisper",
    mel: paddle.Tensor,
    resource_path: str,
    tokenizer: Tokenizer = None,
) -> Tuple[paddle.Tensor, List[dict]]:
    """
    Detect the spoken language of the audio from a single decoder step.

    The step runs outside the main decode loop so that it does not interfere
    with kv-caching. Only the logits of the language tokens are kept.

    Parameters
    ----------
    model : Whisper
        model providing ``encoder``, ``logits``, ``dims`` and ``is_multilingual``
    mel : paddle.Tensor
        a mel spectrogram or already-encoded audio features, with or without a
        leading batch axis (2-D input is treated as a single item). Input whose
        trailing two dimensions equal ``(n_audio_ctx, n_audio_state)`` is treated
        as encoded and is not passed through the encoder again.
    resource_path : str
        tokenizer resource location, only used when ``tokenizer`` is None
    tokenizer : Tokenizer, optional
        tokenizer to use; built with ``get_tokenizer`` when omitted

    Returns
    -------
    language_tokens : Tensor, shape = (batch_size,)
        ids of the most probable language tokens, which appear after the startoftranscript token.
    language_probs : List[Dict[str, float]], length = batch_size
        list of dictionaries containing the probability distribution over all languages.
    For unbatched (2-D) input, the single element of each is returned instead.

    Raises
    ------
    ValueError
        if the tokenizer has no language token in its sot sequence
    """
    if tokenizer is None:
        tokenizer = get_tokenizer(model.is_multilingual, resource_path=resource_path)
    if (
        tokenizer.language is None
        or tokenizer.language_token not in tokenizer.sot_sequence
    ):
        raise ValueError(
            "This model doesn't have language tokens so it can't perform lang id"
        )

    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    # Skip the encoder forward pass if already-encoded audio features were given.
    # paddle's Tensor.shape is a list, and a list never equals a tuple, so convert first.
    if tuple(mel.shape[-2:]) != (model.dims.n_audio_ctx, model.dims.n_audio_state):
        mel = model.encoder(mel)

    # forward pass using a single token, startoftranscript
    batch_size = mel.shape[0]
    x = paddle.to_tensor([[tokenizer.sot]] * batch_size)  # [batch_size, 1]
    logits = model.logits(x, mel)[:, 0]

    # collect detected languages; suppress all non-language tokens
    mask = paddle.ones([logits.shape[-1]], dtype=bool)
    mask[list(tokenizer.all_language_tokens)] = False
    logits[:, mask] = -np.inf
    language_tokens = paddle.argmax(logits, axis=-1)
    language_token_probs = paddle.nn.functional.softmax(logits, axis=-1)
    language_probs = [
        {
            c: language_token_probs[i, j].item()
            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
        }
        for i in range(batch_size)
    ]

    if single:
        language_tokens = language_tokens[0]
        language_probs = language_probs[0]

    return language_tokens, language_probs
805
+
806
+
807
def transcribe(
    model: "Whisper",
    mel: paddle.Tensor,
    resource_path: str,
    *,
    verbose: Optional[bool] = None,
    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = True,
    **decode_options,
):
    """
    Transcribe an audio file using Whisper

    The mel spectrogram is processed window by window (``N_FRAMES`` frames at a
    time, advancing by ``seek``). Each window is decoded with temperature
    fallback, split into timestamped segments, and optionally used as the prompt
    for the next window.

    Parameters
    ----------
    model: Whisper
        The Whisper model instance

    mel: paddle.Tensor
        The audio feature (mel spectrogram); its last axis is treated as frames

    resource_path: str
        Location of the tokenizer resources, forwarded to ``get_tokenizer``,
        ``model.detect_language`` and ``model.decode``

    verbose: bool
        Whether to display the text being decoded to the console. If True, displays all the details,
        If False, displays minimal details. If None, does not display anything

    temperature: Union[float, Tuple[float, ...]]
        Temperature for sampling. It can be a tuple of temperatures, which will be successively used
        upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.

    compression_ratio_threshold: float
        If the gzip compression ratio is above this value, treat as failed

    logprob_threshold: float
        If the average log probability over sampled tokens is below this value, treat as failed

    no_speech_threshold: float
        If the no_speech probability is higher than this value AND the average log probability
        over sampled tokens is below `logprob_threshold`, consider the segment as silent

    condition_on_previous_text: bool
        if True, the previous output of the model is provided as a prompt for the next window;
        disabling may make the text inconsistent across windows, but the model becomes less prone to
        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

    decode_options: dict
        Keyword arguments to construct `DecodingOptions` instances. The strings "None" for
        ``language`` and ``initial_prompt`` are treated like a missing value.

    Returns
    -------
    A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
    the spoken language ("language"), which is detected when `decode_options["language"]` is None.
    """
    dtype = np.float32  # paddle only support float32

    # float32 is always used here, so fp16 decoding is always switched off
    if dtype == np.float32:
        decode_options["fp16"] = False

    # a missing language (None or the string "None") is resolved before decoding:
    # English for English-only models, otherwise detected from the first window
    if (
        decode_options.get("language") == "None"
        or decode_options.get("language", None) is None
    ):
        if not model.is_multilingual:
            decode_options["language"] = "en"
        else:
            if verbose:
                print(
                    "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
                )
            segment = pad_or_trim(mel, N_FRAMES)
            _, probs = model.detect_language(segment, resource_path)
            decode_options["language"] = max(probs, key=probs.get)
            if verbose is not None:
                print(
                    f"Detected language: {LANGUAGES[decode_options['language']].title()}"
                )

    language = decode_options["language"]
    task = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(
        model.is_multilingual, resource_path=resource_path, language=language, task=task
    )

    def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult:
        """Decode ``segment`` at each temperature in turn.

        Stops at the first result that passes both the compression-ratio and the
        average-log-probability checks. If none passes, the result of the last
        temperature tried is returned.
        """
        temperatures = (
            [temperature] if isinstance(temperature, (int, float)) else temperature
        )
        decode_result = None

        for t in temperatures:
            # shallow copy so the per-temperature pops below do not touch decode_options
            kwargs = {**decode_options}
            if t > 0:
                # disable beam_size and patience when t > 0
                kwargs.pop("beam_size", None)
                kwargs.pop("patience", None)
            else:
                # disable best_of when t == 0
                kwargs.pop("best_of", None)

            options = DecodingOptions(**kwargs, temperature=t)
            decode_result = model.decode(segment, options, resource_path)

            needs_fallback = False
            if (
                compression_ratio_threshold is not None
                and decode_result.compression_ratio > compression_ratio_threshold
            ):
                needs_fallback = True  # too repetitive
            if (
                logprob_threshold is not None
                and decode_result.avg_logprob < logprob_threshold
            ):
                needs_fallback = True  # average log probability is too low

            if not needs_fallback:
                break

        return decode_result

    # seek: index of the first mel frame of the current window
    seek = 0
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    time_precision = (
        input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
    all_tokens = []
    all_segments = []
    # tokens before this index in all_tokens are not used as the prompt any more
    prompt_reset_since = 0

    # the initial prompt seeds all_tokens so it conditions the first window; it is
    # removed again from the final text via the len(initial_prompt) slice below
    initial_prompt = decode_options.pop("initial_prompt", None)
    if initial_prompt and initial_prompt != "None":
        initial_prompt = tokenizer.encode(" " + initial_prompt.strip()).input_ids
        all_tokens.extend(initial_prompt)
    else:
        initial_prompt = []

    def add_segment(
        *, start: float, end: float, text_tokens: paddle.Tensor, result: DecodingResult
    ):
        """Append one segment to all_segments unless its decoded text is blank.

        Only tokens below ``tokenizer.eot`` are decoded into text; the segment
        records the current ``seek`` and the statistics of ``result``.
        """
        text = tokenizer.decode(
            [token for token in text_tokens if token < tokenizer.eot]
        )
        if len(text.strip()) == 0:  # skip empty text output
            return

        all_segments.append(
            {
                "id": len(all_segments),
                "seek": seek,
                "start": start,
                "end": end,
                "text": text,
                "tokens": result.tokens,
                "temperature": result.temperature,
                "avg_logprob": result.avg_logprob,
                "compression_ratio": result.compression_ratio,
                "no_speech_prob": result.no_speech_prob,
            }
        )
        if verbose:
            print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")

    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
    num_frames = mel.shape[-1]
    previous_seek_value = seek

    with tqdm.tqdm(
        total=num_frames, unit="frames", disable=verbose is not False
    ) as pbar:
        while seek < num_frames:
            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
            segment = pad_or_trim(mel[:, seek:], N_FRAMES)
            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE

            # condition this window on the tokens emitted since the last prompt reset
            decode_options["prompt"] = all_tokens[prompt_reset_since:]
            result: DecodingResult = decode_with_fallback(segment)
            tokens = paddle.to_tensor(result.tokens)

            if no_speech_threshold is not None:
                # no voice activity check
                should_skip = result.no_speech_prob > no_speech_threshold
                if (
                    logprob_threshold is not None
                    and result.avg_logprob > logprob_threshold
                ):
                    # don't skip if the logprob is high enough, despite the no_speech_prob
                    should_skip = False

                if should_skip:
                    seek += segment.shape[
                        -1
                    ]  # fast-forward to the next segment boundary
                    continue

            # boolean mask: which emitted tokens are timestamp tokens
            timestamp_tokens: paddle.Tensor = tokens.greater_equal(
                paddle.to_tensor(tokenizer.timestamp_begin)
            )

            # positions i where tokens i and i+1 are both timestamps (segment boundaries)
            consecutive = paddle.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
            if (
                len(consecutive) > 0
            ):  # if the output contains two consecutive timestamp tokens
                # shift to the index of the second timestamp, used as the slice end
                consecutive = paddle.add(consecutive, paddle.to_tensor(1))
                last_slice = 0
                for current_slice in consecutive:
                    # each slice runs from an opening timestamp to its closing timestamp
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_position = (
                        sliced_tokens[0].item() - tokenizer.timestamp_begin
                    )
                    end_timestamp_position = (
                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
                    )
                    add_segment(
                        start=timestamp_offset
                        + start_timestamp_position * time_precision,
                        end=timestamp_offset + end_timestamp_position * time_precision,
                        text_tokens=sliced_tokens[1:-1],
                        result=result,
                    )
                    last_slice = current_slice
                # advance only up to the last complete segment; the rest is re-decoded
                # as the start of the next window
                last_timestamp_position = (
                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                )
                seek += last_timestamp_position * input_stride
                all_tokens.extend(tokens[: last_slice + 1].tolist())
            else:
                duration = segment_duration
                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                if (
                    len(timestamps) > 0
                    and timestamps[-1].item() != tokenizer.timestamp_begin
                ):
                    # no consecutive timestamps but it has a timestamp; use the last one.
                    # single timestamp at the end means no speech after the last timestamp.
                    last_timestamp_position = (
                        timestamps[-1].item() - tokenizer.timestamp_begin
                    )
                    duration = last_timestamp_position * time_precision

                add_segment(
                    start=timestamp_offset,
                    end=timestamp_offset + duration,
                    text_tokens=tokens,
                    result=result,
                )

                # the whole window was consumed
                seek += segment.shape[-1]
                all_tokens.extend(tokens.tolist())

            if not condition_on_previous_text or result.temperature > 0.5:
                # do not feed the prompt tokens if a high temperature was used
                prompt_reset_since = len(all_tokens)

            # update progress bar
            pbar.update(min(num_frames, seek) - previous_seek_value)
            previous_seek_value = seek

    return dict(
        text=tokenizer.decode(all_tokens[len(initial_prompt) :]),
        segments=all_segments,
        language=language,
    )
1072
+
1073
+
1074
class SequenceRanker:
    """Interface for choosing the best candidate within each group of decoded samples."""

    def rank(
        self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]
    ) -> List[int]:
        """
        Pick one sample per group.

        Parameters
        ----------
        tokens : List[List[Tensor]]
            candidate token sequences, grouped per audio input
        sum_logprobs : List[List[float]]
            cumulative log probability of each candidate, grouped the same way

        Returns
        -------
        List[int]
            for every group, the index of the selected candidate
        """
        raise NotImplementedError
1083
+
1084
+
1085
class MaximumLikelihoodRanker(SequenceRanker):
    """
    Rank candidates by length-normalized cumulative log probability.

    When ``length_penalty`` is unset (None or the string "None"), each log
    probability is divided by the sequence length. Otherwise it is divided by
    the Google NMT length penalty ``((5 + length) / 6) ** length_penalty``.
    """

    def __init__(self, length_penalty: Optional[float]):
        self.length_penalty = length_penalty

    def rank(self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]):
        """Return, for each group, the index of the candidate with the highest score."""
        alpha = self.length_penalty
        plain_length = alpha is None or alpha == "None"

        def normalizer(length):
            # raw length, or the length penalty from the Google NMT paper
            return length if plain_length else ((5 + length) / 6) ** alpha

        choices = []
        for group_logprobs, group_tokens in zip(sum_logprobs, tokens):
            normalized = [
                logprob / normalizer(len(candidate))
                for logprob, candidate in zip(group_logprobs, group_tokens)
            ]
            choices.append(np.argmax(normalized))
        return choices
1109
+
1110
+
1111
class TokenDecoder:
    """Interface for strategies that pick the next token and produce final candidates."""

    def reset(self):
        """Clear any per-sequence state before decoding a new batch; no-op by default."""
        return None

    def update(
        self, tokens: paddle.Tensor, logits: paddle.Tensor, sum_logprobs: paddle.Tensor
    ) -> Tuple[paddle.Tensor, bool]:
        """Choose the next token for every sequence.

        Parameters
        ----------
        tokens : Tensor, shape = (n_batch, current_sequence_length)
            every token in the context so far, including the prefix and sot_sequence tokens

        logits : Tensor, shape = (n_batch, vocab_size)
            logits of the next-token distribution at the current step

        sum_logprobs : Tensor, shape = (n_batch)
            running total of log probabilities for each sequence

        Returns
        -------
        tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
            the input tokens with the chosen next token appended

        completed : bool
            whether every sequence has reached the end of text
        """
        raise NotImplementedError

    def finalize(
        self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor
    ) -> Tuple[Sequence[Sequence[paddle.Tensor]], List[List[float]]]:
        """Close the search and return the candidate sequences.

        Parameters
        ----------
        tokens : Tensor, shape = (batch_size, beam_size, current_sequence_length)
            every token in the context so far, including the prefix and sot_sequence

        sum_logprobs : Tensor, shape = (batch_size, beam_size)
            running total of log probabilities for each sequence

        Returns
        -------
        tokens : Sequence[Sequence[Tensor]], length = batch_size
            candidate token sequences for each audio input

        sum_logprobs : List[List[float]], length = batch_size
            the cumulative log probabilities matching those candidates
        """
        raise NotImplementedError
1165
+
1166
+
1167
class GreedyDecoder(TokenDecoder):
    """Decoder that extends every sequence by exactly one token per step.

    At temperature 0 the arg-max token is taken; otherwise a token is sampled
    from ``Categorical(logits / temperature)``. A sequence whose last token is
    ``eot`` keeps receiving ``eot`` and contributes no further log probability.
    """

    def __init__(self, temperature: float, eot: int):
        self.temperature = temperature
        self.eot = eot

    def update(
        self, tokens: paddle.Tensor, logits: paddle.Tensor, sum_logprobs: paddle.Tensor
    ) -> Tuple[paddle.Tensor, bool]:
        """Append one token to every sequence (see ``TokenDecoder.update``)."""
        if self.temperature != 0:
            sampler = paddle.distribution.Categorical(logits=logits / self.temperature)
            draws = sampler.sample([1])
            # flatten the draws into one next token per row
            picked = paddle.reshape(draws, [draws.shape[0] * draws.shape[1]])
        else:
            picked = paddle.argmax(logits, axis=-1)

        step_logprobs = paddle.nn.functional.log_softmax(
            logits, axis=-1, dtype=paddle.float32
        )
        row_ids = paddle.arange(step_logprobs.shape[0])
        picked_logprobs = step_logprobs[row_ids, picked]
        still_running = paddle.to_tensor(
            (tokens[:, -1] != self.eot), dtype=paddle.float32
        )
        # kept as an augmented assignment (not `a = a + b`) so that, assuming paddle
        # applies `+=` in place, the caller's running totals see the update
        sum_logprobs += picked_logprobs * still_running

        # sequences that already finished keep emitting eot
        picked[tokens[:, -1] == self.eot] = self.eot
        tokens = paddle.concat([tokens, picked[:, None]], axis=-1)

        completed = paddle.all(tokens[:, -1] == self.eot)
        return tokens, completed

    def finalize(self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
        """Pad one trailing ``eot`` onto every sequence and return log probs as lists."""
        padded = paddle.nn.functional.pad(
            tokens, (0, 1), value=self.eot, data_format="NCL"
        )
        return padded, sum_logprobs.tolist()
1209
+
1210
+
1211
class BeamSearchDecoder(TokenDecoder):
    """Beam search keeping ``beam_size`` live hypotheses per audio input.

    Hypotheses that end in ``eot`` are moved to a per-audio "finished" dict.
    Decoding is complete once every audio has
    ``max_candidates = round(beam_size * patience)`` finished hypotheses.
    """

    def __init__(
        self,
        beam_size: int,
        eot: int,
        inference: Inference,
        patience: Optional[float] = None,
    ):
        self.beam_size = beam_size
        self.eot = eot
        self.inference = inference
        # patience may arrive as None or as the literal string "None"
        self.patience = 1.0 if (patience is None or patience == "None") else patience
        self.max_candidates: int = round(beam_size * self.patience)
        self.finished_sequences = None

        assert (
            self.max_candidates > 0
        ), f"Invalid beam size ({beam_size}) or patience ({patience})"

    def reset(self):
        """Forget finished hypotheses from a previous batch."""
        self.finished_sequences = None

    def update(
        self, tokens: paddle.Tensor, logits: paddle.Tensor, sum_logprobs: paddle.Tensor
    ) -> Tuple[paddle.Tensor, bool]:
        """Advance every beam by one token and collect newly finished hypotheses.

        ``sum_logprobs`` is overwritten in place with the scores of the kept
        hypotheses, and the inference kv-cache is reordered to match them.

        Raises
        ------
        ValueError
            if the number of rows is not a multiple of ``beam_size``
        """
        n_rows = tokens.shape[0]
        if n_rows % self.beam_size != 0:
            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

        n_audio = n_rows // self.beam_size
        if self.finished_sequences is None:  # first step: one dict per audio
            self.finished_sequences = [{} for _ in range(n_audio)]

        logprobs = paddle.nn.functional.log_softmax(logits, axis=-1, dtype="float32")
        next_tokens, source_indices, finished_sequences = [], [], []
        for audio_idx in range(n_audio):
            scores, sources, finished = {}, {}, {}

            # STEP 1: score the top (beam_size + 1) extensions of every beam
            for beam in range(self.beam_size):
                row = audio_idx * self.beam_size + beam
                prefix = tokens[row].tolist()
                top_logprobs, top_tokens = paddle.topk(
                    logprobs[row], k=self.beam_size + 1
                )
                for cand_logprob, cand_token in zip(top_logprobs, top_tokens):
                    total = (sum_logprobs[row] + cand_logprob).item()
                    candidate = tuple(prefix + [cand_token.item()])
                    scores[candidate] = total
                    sources[candidate] = row

            # STEP 2: keep the best beam_size unfinished candidates; set finished ones aside
            kept = 0
            for candidate in sorted(scores, key=scores.get, reverse=True):
                if candidate[-1] == self.eot:
                    finished[candidate] = scores[candidate]
                    continue
                sum_logprobs[len(next_tokens)] = scores[candidate]
                next_tokens.append(candidate)
                source_indices.append(sources[candidate])
                kept += 1
                if kept == self.beam_size:
                    break

            finished_sequences.append(finished)

        tokens = paddle.to_tensor(next_tokens)
        self.inference.rearrange_kv_cache(source_indices)

        # merge the newly finished hypotheses, best first, up to max_candidates each
        assert len(self.finished_sequences) == len(finished_sequences)
        for already_done, just_done in zip(self.finished_sequences, finished_sequences):
            for seq in sorted(just_done, key=just_done.get, reverse=True):
                if len(already_done) >= self.max_candidates:
                    break  # the candidate list is full
                already_done[seq] = just_done[seq]

        completed = all(
            len(done) >= self.max_candidates for done in self.finished_sequences
        )
        return tokens, completed

    def finalize(self, preceding_tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
        """Return finished hypotheses per audio, topped up with live ones if too few finished."""
        sum_logprobs = sum_logprobs.cpu()
        for i, candidates in enumerate(self.finished_sequences):
            if len(candidates) >= self.beam_size:
                continue
            # too few finished: add the best unfinished beams, each closed with eot
            for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                closed = preceding_tokens[i, j].tolist() + [self.eot]
                candidates[tuple(closed)] = sum_logprobs[i][j].item()
                if len(candidates) >= self.beam_size:
                    break

        tokens: List[List[paddle.Tensor]] = [
            [paddle.to_tensor(seq) for seq in candidates.keys()]
            for candidates in self.finished_sequences
        ]
        sum_logprobs: List[List[float]] = [
            list(candidates.values()) for candidates in self.finished_sequences
        ]
        return tokens, sum_logprobs
1320
+
1321
+
1322
class LogitFilter:
    """Base class for rules that mask or penalize logits in place before token selection."""

    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor) -> None:
        """
        Modify ``logits`` in place.

        Parameters
        ----------
        logits : Tensor, shape = (n_batch, vocab_size)
            logits of the next-token distribution for the current step

        tokens : Tensor, shape = (n_batch, current_sequence_length)
            every token in the context so far, including the prefix and sot_sequence tokens
        """
        raise NotImplementedError
1336
+
1337
+
1338
class SuppressBlank(LogitFilter):
    """At the first sampled position, forbid the encoding of " " and the ``eot`` token."""

    def __init__(self, tokenizer: Tokenizer, sample_begin: int):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin

    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
        at_first_sample = tokens.shape[1] == self.sample_begin
        if not at_first_sample:
            return
        banned = self.tokenizer.encode(" ").input_ids + [self.tokenizer.eot]
        logits[:, banned] = -np.inf
1348
+
1349
+
1350
class SuppressTokens(LogitFilter):
    """Mask a fixed set of token ids to -inf at every decoding step."""

    def __init__(self, suppress_tokens: Sequence[int]):
        # materialize once into a list used as the column index at every step
        self.suppress_tokens = [token_id for token_id in suppress_tokens]

    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
        banned = self.suppress_tokens
        logits[:, banned] = -np.inf
1356
+
1357
+
1358
class ApplyTimestampRules(LogitFilter):
    """Enforce the timestamp-token rules on the logits, in place."""

    def __init__(
        self,
        tokenizer: Tokenizer,
        sample_begin: int,
        max_initial_timestamp_index: Optional[int],
    ):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin
        self.max_initial_timestamp_index = max_initial_timestamp_index

    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
        tok = self.tokenizer
        ts_begin = tok.timestamp_begin

        # <|notimestamps|> is governed by the without_timestamps option, never sampled
        if tok.no_timestamps is not None:
            logits[:, tok.no_timestamps] = -np.inf

        # timestamps come in pairs, except directly before EOT
        for row in range(tokens.shape[0]):
            sampled = tokens[row, self.sample_begin :].tolist()
            last_is_ts = len(sampled) >= 1 and sampled[-1] >= ts_begin
            if not last_is_ts:
                continue
            penultimate_is_ts = len(sampled) < 2 or sampled[-2] >= ts_begin
            if penultimate_is_ts:
                # next token has to be a non-timestamp
                logits[row, ts_begin:] = -np.inf
            else:
                # next token cannot be a normal text token
                logits[row, : tok.eot] = -np.inf

        # on the first sampled token, cap the timestamp at max_initial_timestamp
        first_step = tokens.shape[1] == self.sample_begin
        if first_step and self.max_initial_timestamp_index is not None:
            last_allowed = ts_begin + self.max_initial_timestamp_index
            logits[:, last_allowed + 1 :] = -np.inf

        # if the total probability of timestamps beats every single text token, force a timestamp
        logprobs = paddle.nn.functional.log_softmax(logits, axis=-1, dtype="float32")
        for row in range(tokens.shape[0]):
            # NOTE: paddle.logsumexp hit CUDA error 700 on a 32GB Tesla-V100 in CI, so the
            # log-sum-exp is decomposed into exp/sum/log (about 2e-6 precision difference).
            # TODO: revert this after logsumexp been fixed.
            ts_logprob = paddle.log(
                paddle.sum(paddle.exp(logprobs[row, ts_begin:]), axis=-1)
            )
            best_text_logprob = paddle.max(logprobs[row, :ts_begin])
            if ts_logprob > best_text_logprob:
                logits[row, :ts_begin] = -np.inf
1417
+
1418
+
1419
+ class DecodingTask:
1420
+ inference: Inference
1421
+ sequence_ranker: SequenceRanker
1422
+ decoder: TokenDecoder
1423
+ logit_filters: List[LogitFilter]
1424
+
1425
+ def __init__(self, model: "Whisper", options: DecodingOptions, resource_path: str):
1426
+ self.model = model
1427
+
1428
+ language = options.language or "en"
1429
+ tokenizer = get_tokenizer(
1430
+ model.is_multilingual,
1431
+ resource_path=resource_path,
1432
+ language=language,
1433
+ task=options.task,
1434
+ )
1435
+ self.tokenizer: Tokenizer = tokenizer
1436
+ self.options: DecodingOptions = self._verify_options(options)
1437
+ self.resource_path: str = resource_path
1438
+
1439
+ self.beam_size: int = options.beam_size or options.best_of or 1
1440
+ self.n_ctx: int = model.dims.n_text_ctx
1441
+ self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
1442
+
1443
+ self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
1444
+ if self.options.without_timestamps:
1445
+ self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
1446
+
1447
+ self.initial_tokens: Tuple[int] = self._get_initial_tokens()
1448
+ self.sample_begin: int = len(self.initial_tokens)
1449
+ self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
1450
+
1451
+ # inference: implements the forward pass through the decoder, including kv caching
1452
+ self.inference = WhisperInference(model, len(self.initial_tokens))
1453
+
1454
+ # sequence ranker: implements how to rank a group of sampled sequences
1455
+ self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
1456
+
1457
+ # decoder: implements how to select the next tokens, given the autoregressive distribution
1458
+ if options.beam_size is not None:
1459
+ self.decoder = BeamSearchDecoder(
1460
+ options.beam_size, tokenizer.eot, self.inference, options.patience
1461
+ )
1462
+ else:
1463
+ self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
1464
+
1465
+ # logit filters: applies various rules to suppress or penalize certain tokens
1466
+ self.logit_filters = []
1467
+ if self.options.suppress_blank:
1468
+ self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
1469
+ if self.options.suppress_tokens:
1470
+ self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
1471
+ if not options.without_timestamps:
1472
+ precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
1473
+ max_initial_timestamp_index = None
1474
+ if options.max_initial_timestamp:
1475
+ max_initial_timestamp_index = round(
1476
+ self.options.max_initial_timestamp / precision
1477
+ )
1478
+ self.logit_filters.append(
1479
+ ApplyTimestampRules(
1480
+ tokenizer, self.sample_begin, max_initial_timestamp_index
1481
+ )
1482
+ )
1483
+
1484
+ def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
1485
+ if options.beam_size is not None and options.best_of is not None:
1486
+ raise ValueError("beam_size and best_of can't be given together")
1487
+ if options.temperature == 0:
1488
+ if options.best_of is not None:
1489
+ raise ValueError("best_of with greedy sampling (T=0) is not compatible")
1490
+ if options.patience is not None and options.beam_size is None:
1491
+ raise ValueError("patience requires beam_size to be given")
1492
+ if options.length_penalty is not None and options.length_penalty != "None":
1493
+ if not (0 <= options.length_penalty <= 1):
1494
+ raise ValueError(
1495
+ "length_penalty (alpha) should be a value between 0 and 1"
1496
+ )
1497
+
1498
+ return options
1499
+
1500
+ def _get_initial_tokens(self) -> Tuple[int]:
1501
+ tokens = list(self.sot_sequence)
1502
+ prefix = self.options.prefix
1503
+ prompt = self.options.prompt
1504
+
1505
+ if prefix:
1506
+ prefix_tokens = (
1507
+ self.tokenizer.encode(" " + prefix.strip().input_ids)
1508
+ if isinstance(prefix, str)
1509
+ else prefix
1510
+ )
1511
+ if self.sample_len is not None:
1512
+ max_prefix_len = self.n_ctx // 2 - self.sample_len
1513
+ prefix_tokens = prefix_tokens[-max_prefix_len:]
1514
+ tokens = tokens + prefix_tokens
1515
+
1516
+ if prompt:
1517
+ prompt_tokens = (
1518
+ self.tokenizer.encode(" " + prompt.strip().input_ids)
1519
+ if isinstance(prompt, str)
1520
+ else prompt
1521
+ )
1522
+ tokens = (
1523
+ [self.tokenizer.sot_prev]
1524
+ + prompt_tokens[-(self.n_ctx // 2 - 1) :]
1525
+ + tokens
1526
+ )
1527
+
1528
+ return tuple(tokens)
1529
+
1530
+ def _get_suppress_tokens(self) -> Tuple[int]:
1531
+ suppress_tokens = self.options.suppress_tokens
1532
+
1533
+ if isinstance(suppress_tokens, str):
1534
+ suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
1535
+
1536
+ if -1 in suppress_tokens:
1537
+ suppress_tokens = [t for t in suppress_tokens if t >= 0]
1538
+ suppress_tokens.extend(self.tokenizer.non_speech_tokens)
1539
+ elif suppress_tokens is None or len(suppress_tokens) == 0:
1540
+ suppress_tokens = [] # interpret empty string as an empty list
1541
+ else:
1542
+ assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
1543
+
1544
+ suppress_tokens.extend(
1545
+ [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
1546
+ )
1547
+ if self.tokenizer.no_speech is not None:
1548
+ # no-speech probability is collected separately
1549
+ suppress_tokens.append(self.tokenizer.no_speech)
1550
+
1551
+ return tuple(sorted(set(suppress_tokens)))
1552
+
1553
+ def _get_audio_features(self, mel: paddle.Tensor):
1554
+
1555
+ if mel.shape[-2:] == (
1556
+ self.model.dims.n_audio_ctx,
1557
+ self.model.dims.n_audio_state,
1558
+ ):
1559
+ # encoded audio features are given; skip audio encoding
1560
+ audio_features = mel
1561
+ else:
1562
+ audio_features = self.model.encoder(mel)
1563
+
1564
+ return audio_features
1565
+
1566
def _detect_language(
    self, audio_features: paddle.Tensor, tokens: paddle.Tensor, resource_path: str
):
    """
    Determine the language for every item in the batch.

    If ``self.options.language`` is set and the task is not ``"lang_id"``,
    that language is reported for every item and ``lang_probs`` is ``None``.
    Otherwise ``self.model.detect_language`` is called and each item's
    language is the key with the highest probability. When no language was
    configured, the detected language tokens are also written in place into
    ``tokens[:, self.sot_index + 1]``.

    Parameters
    ----------
    audio_features: paddle.Tensor
        Encoder output; its first dimension is the batch size.
    tokens: paddle.Tensor
        Initial token rows; modified in place as described above.
    resource_path: str
        Forwarded to ``self.model.detect_language``. Falls back to
        ``self.resource_path`` when ``None``.

    Returns
    -------
    (languages, lang_probs)
    """
    if resource_path is None:
        resource_path = self.resource_path

    languages = [self.options.language] * audio_features.shape[0]
    lang_probs = None

    if self.options.language is None or self.options.task == "lang_id":
        lang_tokens, lang_probs = self.model.detect_language(
            audio_features, self.tokenizer, resource_path
        )
        languages = [max(probs, key=probs.get) for probs in lang_probs]
        if self.options.language is None:
            tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens

    return languages, lang_probs
1581
+
1582
def _main_loop(self, audio_features: paddle.Tensor, tokens: paddle.Tensor):
    """
    Extend ``tokens`` one step at a time, for at most ``self.sample_len`` steps.

    On each step the model logits are computed, every filter in
    ``self.logit_filters`` is applied to the last-position logits, and
    ``self.decoder.update`` produces the extended tokens. The loop stops early
    when the decoder reports completion or the sequence exceeds
    ``self.n_ctx``. ``self.inference.cleanup_caching()`` always runs, even if
    an exception is raised.

    Returns
    -------
    (tokens, sum_logprobs, no_speech_probs)
        ``sum_logprobs`` is the float32 tensor handed to ``decoder.update``
        (presumably accumulated there in place; TODO confirm).
        ``no_speech_probs`` stays NaN unless the tokenizer defines a
        no-speech token.
    """
    assert audio_features.shape[0] == tokens.shape[0]
    n_batch = tokens.shape[0]
    # A plain list shape; no need to build a tensor just to describe the shape.
    sum_logprobs: paddle.Tensor = paddle.zeros([n_batch], dtype=paddle.float32)
    no_speech_probs = [np.nan] * n_batch

    try:
        for i in range(self.sample_len):
            logits = self.inference.logits(tokens, audio_features)

            if (
                i == 0 and self.tokenizer.no_speech is not None
            ):  # save no_speech_probs
                probs_at_sot = paddle.nn.functional.softmax(
                    logits[:, self.sot_index], axis=-1, dtype=paddle.float32
                )
                no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

            # now we need to consider the logits at the last token only
            logits = logits[:, -1]

            # apply the logit filters, e.g. for suppressing or applying penalty to
            for logit_filter in self.logit_filters:
                logit_filter.apply(logits, tokens)

            # expand the tokens tensor with the selected next tokens
            tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
            if completed or tokens.shape[-1] > self.n_ctx:
                break
    finally:
        self.inference.cleanup_caching()

    return tokens, sum_logprobs, no_speech_probs
1617
+
1618
@paddle.no_grad()
def run(self, mel: paddle.Tensor) -> List[DecodingResult]:
    """
    Decode a batch of Mel spectrograms (or pre-encoded audio features).

    The steps are:

    1. Encode the audio.
    2. Build one initial token row per batch item.
    3. Optionally detect the language.
    4. Unless ``self.options.task == "lang_id"``, sample ``self.beam_size``
       candidates per item and pick one with ``self.sequence_ranker``.

    Returns
    -------
    List[DecodingResult]
        One result per batch item.
    """
    self.decoder.reset()
    tokenizer: Tokenizer = self.tokenizer
    batch_size: int = mel.shape[0]

    audio_features: paddle.Tensor = self._get_audio_features(
        mel
    )  # encoder forward pass

    # One row of the initial prompt per batch item, whatever the batch size.
    tokens: paddle.Tensor = paddle.to_tensor([self.initial_tokens] * batch_size)

    # Detect the language if requested, overwriting the language token.
    # `tokens` is passed without copying (paddle.to_tensor would copy), so the
    # in-place write of the detected language token reaches the sampling loop.
    languages, language_probs = self._detect_language(
        audio_features,
        tokens,
        self.resource_path,
    )

    if self.options.task == "lang_id":
        return [
            DecodingResult(
                audio_features=features, language=language, language_probs=probs
            )
            for features, language, probs in zip(
                audio_features, languages, language_probs
            )
        ]

    # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
    audio_features = paddle.repeat_interleave(
        audio_features, self.beam_size, axis=0
    )
    tokens = paddle.repeat_interleave(tokens, self.beam_size, axis=0)

    # call the main sampling loop
    tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)

    # reshape the tensors to have (batch_size, beam_size) as the first two dimensions
    audio_features = audio_features[:: self.beam_size]
    no_speech_probs = no_speech_probs[:: self.beam_size]
    assert audio_features.shape[0] == len(no_speech_probs) == batch_size

    tokens = tokens.reshape([batch_size, self.beam_size, -1])
    sum_logprobs = sum_logprobs.reshape([batch_size, self.beam_size])

    # get the final candidates for each group, and slice between the first sampled token and EOT
    tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
    tokens: List[List[paddle.Tensor]] = [
        [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
        for s in tokens
    ]

    # select the top-ranked sample in each group
    selected = self.sequence_ranker.rank(tokens, sum_logprobs)
    tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
    texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]

    sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
    # the +1 accounts for the EOT token that was sliced off above
    avg_logprobs: List[float] = [
        lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
    ]

    fields = (
        texts,
        languages,
        tokens,
        audio_features,
        avg_logprobs,
        no_speech_probs,
    )
    if len(set(map(len, fields))) != 1:
        raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

    return [
        DecodingResult(
            audio_features=features,
            language=language,
            tokens=tokens,
            text=text,
            avg_logprob=avg_logprob,
            no_speech_prob=no_speech_prob,
            temperature=self.options.temperature,
            compression_ratio=compression_ratio(text),
        )
        for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
            *fields
        )
    ]
1719
+
1720
+
1721
@paddle.no_grad()
def decode(
    model: "Whisper",
    mel: paddle.Tensor,
    options: DecodingOptions = DecodingOptions(),
    resource_path: Optional[str] = None,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).

    Parameters
    ----------
    model: Whisper
        the Whisper model instance

    mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000)
        A tensor containing the Mel spectrogram(s)

    options: DecodingOptions
        A dataclass that contains all necessary options for decoding 30-second segments

    resource_path: Optional[str]
        Directory with the model's auxiliary resources; forwarded unchanged to
        ``DecodingTask``. (The previous default was the ``str`` type object
        itself, which is never a usable path.)

    Returns
    -------
    result: Union[DecodingResult, List[DecodingResult]]
        The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
    """
    # A 2-D input is a single spectrogram: add a batch axis, then unwrap the result.
    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    result = DecodingTask(model, options, resource_path).run(mel)

    if single:
        result = result[0]

    return result
1757
+
1758
+
1759
class Whisper(paddle.nn.Layer):
    """
    Whisper speech model: an ``AudioEncoder`` over log-Mel input followed by a
    ``TextDecoder`` that attends to the encoded audio.

    The module-level ``detect_language``, ``transcribe`` and ``decode``
    functions are attached to the class as methods.
    """

    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        d = dims
        self.encoder = AudioEncoder(
            d.n_mels, d.n_audio_ctx, d.n_audio_state, d.n_audio_head, d.n_audio_layer
        )
        self.decoder = TextDecoder(
            d.n_vocab, d.n_text_ctx, d.n_text_state, d.n_text_head, d.n_text_layer
        )

    def embed_audio(self, mel: paddle.Tensor):
        """Run only the audio encoder (calls ``forward`` directly)."""
        return self.encoder.forward(mel)

    def logits(self, tokens: paddle.Tensor, audio_features: paddle.Tensor):
        """Run only the text decoder on already-encoded audio (calls ``forward`` directly)."""
        return self.decoder.forward(tokens, audio_features)

    def forward(
        self, mel: paddle.Tensor, tokens: paddle.Tensor
    ) -> Dict[str, paddle.Tensor]:
        """Encode ``mel``, then decode ``tokens`` against the resulting features."""
        audio_features = self.encoder(mel)
        return self.decoder(tokens, audio_features)

    @property
    def device(self):
        """The globally selected Paddle device (``paddle.device.get_device()``)."""
        return paddle.device.get_device()

    @property
    def is_multilingual(self):
        """True when the vocabulary size is exactly 51865."""
        return self.dims.n_vocab == 51865

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        Register forward-post hooks that cache key/value projection outputs.

        A forward-post hook is added to the ``key`` and ``value`` sublayers of
        every ``MultiHeadAttention`` layer inside ``self.decoder``.

        - On a sublayer's first call, or when its output's second dimension
          exceeds the decoder's positional-embedding length, the output is
          stored as-is.
        - Otherwise the output is concatenated to the cached value along
          axis 1 and detached.

        In both cases the hook returns the cached tensor, so that tensor
        becomes the sublayer's output.

        Parameters
        ----------
        cache : Optional[dict]
            Existing entries to start from; it is shallow-copied, not modified.

        Returns
        -------
        cache : Dict[nn.Layer, paddle.Tensor]
            Maps each hooked projection sublayer to its cached output.
        hooks : list
            The handles returned by Paddle's ``register_forward_post_hook``;
            call ``.remove()`` on each to detach the hooks.
        """
        cache = {} if cache is None else {**cache}
        hooks = []

        def save_to_cache(layer, _, output):
            if (
                layer in cache
                and output.shape[1] <= self.decoder.positional_embedding.shape[0]
            ):
                cache[layer] = paddle.concat([cache[layer], output], axis=1).detach()
            else:
                # save as-is, for the first token or cross attention
                cache[layer] = output
            return cache[layer]

        def install_hooks(layer: paddle.nn.Layer):
            if not isinstance(layer, MultiHeadAttention):
                return
            for projection in (layer.key, layer.value):
                hooks.append(projection.register_forward_post_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    detect_language = detect_language
    transcribe = transcribe
    decode = decode
1841
+
1842
+
1843
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad or trim ``array`` along ``axis`` to exactly ``length`` entries, as
    expected by the encoder (default ``N_SAMPLES``).

    Accepts a paddle Tensor or a NumPy array of any rank and returns the same
    kind of object. Trimming keeps the first ``length`` entries along
    ``axis``; padding appends zeros at the end of ``axis``.
    """
    if paddle.is_tensor(array):
        if array.shape[axis] > length:
            array = array.index_select(axis=axis, index=paddle.arange(length))

        if array.shape[axis] < length:
            # Append a zero block rather than using F.pad plus a hard-coded 2-D
            # transpose: concat works for any rank and any axis.
            pad_shape = list(array.shape)
            pad_shape[axis] = length - array.shape[axis]
            array = paddle.concat(
                [array, paddle.zeros(pad_shape, dtype=array.dtype)], axis=axis
            )
    else:
        if array.shape[axis] > length:
            array = array.take(indices=range(length), axis=axis)

        if array.shape[axis] < length:
            # np.pad handles any rank directly; no transpose is needed.
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = np.pad(array, pad_widths)

    return array
1873
+
1874
+
1875
def hann_window(n_fft: int = N_FFT):
    """
    Periodic Hann window of length ``n_fft`` as a float32 paddle Tensor.

    w[n] = 0.5 - 0.5 * cos(2 * pi * n / n_fft), for n = 0 .. n_fft - 1.
    The denominator is ``n_fft`` (not ``n_fft - 1``), i.e. the periodic form.

    n_fft: The number of frequency components of the discrete Fourier transform.
    """
    # Vectorised over all sample indices instead of one np.cos call per element.
    n = np.arange(n_fft)
    return paddle.to_tensor(
        0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft),
        dtype=paddle.float32,
    )
1884
+
1885
+
1886
@lru_cache(maxsize=None)
def mel_filters(resource_path: str, n_mels: int = N_MELS) -> paddle.Tensor:
    """
    Load the Mel filterbank matrix that projects STFT magnitudes onto Mel bands.

    The matrix is read from ``<resource_path>/assets/mel_filters.npz`` under
    the key ``mel_<n_mels>``, which avoids a runtime dependency on librosa.
    The file was produced with::

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )

    Results are memoised per call arguments, so callers share the returned
    tensor.
    """
    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
    npz_path = os.path.join(resource_path, "assets", "mel_filters.npz")
    key = f"mel_{n_mels}"
    with np.load(npz_path) as archive:
        matrix = archive[key]
    return paddle.to_tensor(matrix)
1900
+
1901
+
1902
def log_mel_spectrogram(
    audio: Union[str, np.ndarray, paddle.Tensor],
    n_mels: int = N_MELS,
    resource_path: str = None,
):
    """
    Compute the log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: Union[str, np.ndarray, paddle.Tensor], shape = (*)
        The path to an audio file, or a NumPy array / Tensor containing the
        waveform in 16 kHz. Files are read with ``soundfile`` and only the
        first channel is kept. Non-tensor inputs are converted to float32.

    n_mels: int
        The number of Mel-frequency filters, only 80 is supported

    resource_path: str
        Directory containing ``assets/mel_filters.npz``.

    Returns
    -------
    paddle.Tensor, shape = (80, n_frames) (with leading batch dims for batched input)
        A Tensor that contains the Mel spectrogram
    """
    if not paddle.is_tensor(audio):
        if isinstance(audio, str):
            audio, _ = soundfile.read(audio, dtype="float32", always_2d=True)
            audio = audio[:, 0]
        # The Hann window is float32 (and the filterbank presumably is too),
        # so a float64 array would cause a dtype mismatch in stft / matmul.
        audio = paddle.to_tensor(np.asarray(audio, dtype=np.float32))

    window = hann_window(N_FFT)
    stft = paddle.signal.stft(audio, N_FFT, HOP_LENGTH, window=window)

    # Drop the last STFT frame. `...` selects the frame axis for both 1-D
    # and batched input; `[:, :-1]` would hit the frequency axis when batched.
    magnitudes = stft[..., :-1].abs() ** 2

    filters = mel_filters(resource_path, n_mels)
    mel_spec = filters @ magnitudes
    # NOTE(review): value-preserving round trip through Python lists;
    # presumably a workaround kept from the original port — confirm before removing.
    mel_spec = paddle.to_tensor(mel_spec.numpy().tolist())

    # Whisper normalisation: clamp to at most 8 log10 units below the global
    # maximum, then rescale.
    log_spec = paddle.clip(mel_spec, min=1e-10).log10()
    log_spec = paddle.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec