paddlex 3.0.0rc1__py3-none-any.whl → 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (233) hide show
  1. paddlex/.version +1 -1
  2. paddlex/__init__.py +1 -1
  3. paddlex/configs/modules/chart_parsing/PP-Chart2Table.yaml +13 -0
  4. paddlex/configs/modules/doc_vlm/PP-DocBee2-3B.yaml +14 -0
  5. paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-L.yaml +40 -0
  6. paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-M.yaml +40 -0
  7. paddlex/configs/modules/formula_recognition/PP-FormulaNet_plus-S.yaml +40 -0
  8. paddlex/configs/modules/layout_detection/PP-DocBlockLayout.yaml +40 -0
  9. paddlex/configs/modules/layout_detection/PP-DocLayout-L.yaml +2 -2
  10. paddlex/configs/modules/layout_detection/PP-DocLayout-M.yaml +2 -2
  11. paddlex/configs/modules/layout_detection/PP-DocLayout-S.yaml +2 -2
  12. paddlex/configs/modules/layout_detection/PP-DocLayout_plus-L.yaml +40 -0
  13. paddlex/configs/modules/text_detection/PP-OCRv5_mobile_det.yaml +40 -0
  14. paddlex/configs/modules/text_detection/PP-OCRv5_server_det.yaml +40 -0
  15. paddlex/configs/modules/text_recognition/PP-OCRv5_mobile_rec.yaml +39 -0
  16. paddlex/configs/modules/text_recognition/PP-OCRv5_server_rec.yaml +39 -0
  17. paddlex/configs/modules/textline_orientation/PP-LCNet_x1_0_textline_ori.yaml +41 -0
  18. paddlex/configs/pipelines/OCR.yaml +7 -6
  19. paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml +3 -1
  20. paddlex/configs/pipelines/PP-ChatOCRv4-doc.yaml +91 -34
  21. paddlex/configs/pipelines/PP-StructureV3.yaml +72 -72
  22. paddlex/configs/pipelines/doc_understanding.yaml +1 -1
  23. paddlex/configs/pipelines/formula_recognition.yaml +2 -2
  24. paddlex/configs/pipelines/layout_parsing.yaml +3 -2
  25. paddlex/configs/pipelines/seal_recognition.yaml +1 -0
  26. paddlex/configs/pipelines/table_recognition.yaml +2 -1
  27. paddlex/configs/pipelines/table_recognition_v2.yaml +7 -1
  28. paddlex/hpip_links.html +20 -20
  29. paddlex/inference/common/batch_sampler/doc_vlm_batch_sampler.py +33 -10
  30. paddlex/inference/common/batch_sampler/image_batch_sampler.py +34 -25
  31. paddlex/inference/common/result/mixin.py +19 -12
  32. paddlex/inference/models/base/predictor/base_predictor.py +2 -8
  33. paddlex/inference/models/common/static_infer.py +11 -59
  34. paddlex/inference/models/common/tokenizer/__init__.py +2 -0
  35. paddlex/inference/models/common/tokenizer/clip_tokenizer.py +1 -1
  36. paddlex/inference/models/common/tokenizer/gpt_tokenizer.py +2 -2
  37. paddlex/inference/models/common/tokenizer/qwen2_5_tokenizer.py +112 -0
  38. paddlex/inference/models/common/tokenizer/qwen2_tokenizer.py +7 -1
  39. paddlex/inference/models/common/tokenizer/qwen_tokenizer.py +288 -0
  40. paddlex/inference/models/common/tokenizer/tokenizer_utils.py +13 -13
  41. paddlex/inference/models/common/tokenizer/tokenizer_utils_base.py +3 -3
  42. paddlex/inference/models/common/tokenizer/vocab.py +7 -7
  43. paddlex/inference/models/common/vlm/conversion_utils.py +99 -0
  44. paddlex/inference/models/common/vlm/fusion_ops.py +205 -0
  45. paddlex/inference/models/common/vlm/generation/configuration_utils.py +1 -1
  46. paddlex/inference/models/common/vlm/generation/logits_process.py +1 -1
  47. paddlex/inference/models/common/vlm/generation/utils.py +1 -1
  48. paddlex/inference/models/common/vlm/transformers/configuration_utils.py +3 -3
  49. paddlex/inference/models/common/vlm/transformers/conversion_utils.py +3 -3
  50. paddlex/inference/models/common/vlm/transformers/model_outputs.py +2 -2
  51. paddlex/inference/models/common/vlm/transformers/model_utils.py +7 -31
  52. paddlex/inference/models/doc_vlm/modeling/GOT_ocr_2_0.py +830 -0
  53. paddlex/inference/models/doc_vlm/modeling/__init__.py +2 -0
  54. paddlex/inference/models/doc_vlm/modeling/qwen2.py +1606 -0
  55. paddlex/inference/models/doc_vlm/modeling/qwen2_5_vl.py +3006 -0
  56. paddlex/inference/models/doc_vlm/modeling/qwen2_vl.py +0 -105
  57. paddlex/inference/models/doc_vlm/predictor.py +79 -24
  58. paddlex/inference/models/doc_vlm/processors/GOT_ocr_2_0.py +97 -0
  59. paddlex/inference/models/doc_vlm/processors/__init__.py +2 -0
  60. paddlex/inference/models/doc_vlm/processors/common.py +189 -0
  61. paddlex/inference/models/doc_vlm/processors/qwen2_5_vl.py +548 -0
  62. paddlex/inference/models/doc_vlm/processors/qwen2_vl.py +21 -176
  63. paddlex/inference/models/formula_recognition/predictor.py +7 -1
  64. paddlex/inference/models/formula_recognition/processors.py +92 -79
  65. paddlex/inference/models/formula_recognition/result.py +28 -27
  66. paddlex/inference/models/image_feature/processors.py +3 -4
  67. paddlex/inference/models/keypoint_detection/predictor.py +3 -0
  68. paddlex/inference/models/object_detection/predictor.py +2 -0
  69. paddlex/inference/models/object_detection/processors.py +28 -3
  70. paddlex/inference/models/object_detection/utils.py +2 -0
  71. paddlex/inference/models/table_structure_recognition/result.py +0 -10
  72. paddlex/inference/models/text_detection/predictor.py +8 -0
  73. paddlex/inference/models/text_detection/processors.py +44 -10
  74. paddlex/inference/models/text_detection/result.py +0 -10
  75. paddlex/inference/pipelines/__init__.py +9 -5
  76. paddlex/inference/pipelines/_parallel.py +172 -0
  77. paddlex/inference/pipelines/anomaly_detection/pipeline.py +16 -6
  78. paddlex/inference/pipelines/attribute_recognition/pipeline.py +11 -1
  79. paddlex/inference/pipelines/base.py +14 -4
  80. paddlex/inference/pipelines/components/faisser.py +1 -1
  81. paddlex/inference/pipelines/doc_preprocessor/pipeline.py +53 -27
  82. paddlex/inference/pipelines/formula_recognition/pipeline.py +120 -82
  83. paddlex/inference/pipelines/formula_recognition/result.py +1 -11
  84. paddlex/inference/pipelines/image_classification/pipeline.py +16 -6
  85. paddlex/inference/pipelines/image_multilabel_classification/pipeline.py +16 -6
  86. paddlex/inference/pipelines/instance_segmentation/pipeline.py +16 -6
  87. paddlex/inference/pipelines/keypoint_detection/pipeline.py +16 -6
  88. paddlex/inference/pipelines/layout_parsing/pipeline.py +34 -47
  89. paddlex/inference/pipelines/layout_parsing/pipeline_v2.py +893 -260
  90. paddlex/inference/pipelines/layout_parsing/result.py +4 -17
  91. paddlex/inference/pipelines/layout_parsing/result_v2.py +523 -245
  92. paddlex/inference/pipelines/layout_parsing/setting.py +87 -0
  93. paddlex/inference/pipelines/layout_parsing/utils.py +565 -1998
  94. paddlex/inference/pipelines/layout_parsing/xycut_enhanced/__init__.py +16 -0
  95. paddlex/inference/pipelines/layout_parsing/xycut_enhanced/utils.py +1144 -0
  96. paddlex/inference/pipelines/layout_parsing/xycut_enhanced/xycuts.py +563 -0
  97. paddlex/inference/pipelines/m_3d_bev_detection/pipeline.py +2 -2
  98. paddlex/inference/pipelines/multilingual_speech_recognition/pipeline.py +2 -2
  99. paddlex/inference/pipelines/object_detection/pipeline.py +16 -6
  100. paddlex/inference/pipelines/ocr/pipeline.py +127 -70
  101. paddlex/inference/pipelines/ocr/result.py +19 -16
  102. paddlex/inference/pipelines/open_vocabulary_detection/pipeline.py +2 -2
  103. paddlex/inference/pipelines/open_vocabulary_segmentation/pipeline.py +2 -2
  104. paddlex/inference/pipelines/pp_chatocr/pipeline_base.py +2 -2
  105. paddlex/inference/pipelines/pp_chatocr/pipeline_v3.py +2 -5
  106. paddlex/inference/pipelines/pp_chatocr/pipeline_v4.py +5 -5
  107. paddlex/inference/pipelines/rotated_object_detection/pipeline.py +16 -6
  108. paddlex/inference/pipelines/seal_recognition/pipeline.py +109 -53
  109. paddlex/inference/pipelines/semantic_segmentation/pipeline.py +16 -6
  110. paddlex/inference/pipelines/small_object_detection/pipeline.py +16 -6
  111. paddlex/inference/pipelines/table_recognition/pipeline.py +26 -18
  112. paddlex/inference/pipelines/table_recognition/pipeline_v2.py +624 -53
  113. paddlex/inference/pipelines/table_recognition/result.py +1 -1
  114. paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py +9 -5
  115. paddlex/inference/pipelines/ts_anomaly_detection/pipeline.py +2 -2
  116. paddlex/inference/pipelines/ts_classification/pipeline.py +2 -2
  117. paddlex/inference/pipelines/ts_forecasting/pipeline.py +2 -2
  118. paddlex/inference/pipelines/video_classification/pipeline.py +2 -2
  119. paddlex/inference/pipelines/video_detection/pipeline.py +2 -2
  120. paddlex/inference/serving/basic_serving/_pipeline_apps/_common/common.py +5 -1
  121. paddlex/inference/serving/basic_serving/_pipeline_apps/layout_parsing.py +0 -1
  122. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv3_doc.py +0 -1
  123. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_chatocrv4_doc.py +1 -1
  124. paddlex/inference/serving/basic_serving/_pipeline_apps/pp_structurev3.py +6 -2
  125. paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition.py +1 -5
  126. paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition_v2.py +4 -5
  127. paddlex/inference/serving/infra/utils.py +20 -22
  128. paddlex/inference/serving/schemas/formula_recognition.py +1 -1
  129. paddlex/inference/serving/schemas/layout_parsing.py +1 -2
  130. paddlex/inference/serving/schemas/pp_chatocrv3_doc.py +1 -2
  131. paddlex/inference/serving/schemas/pp_chatocrv4_doc.py +2 -2
  132. paddlex/inference/serving/schemas/pp_structurev3.py +10 -6
  133. paddlex/inference/serving/schemas/seal_recognition.py +1 -1
  134. paddlex/inference/serving/schemas/table_recognition.py +2 -6
  135. paddlex/inference/serving/schemas/table_recognition_v2.py +5 -6
  136. paddlex/inference/utils/hpi.py +8 -1
  137. paddlex/inference/utils/hpi_model_info_collection.json +81 -2
  138. paddlex/inference/utils/io/readers.py +12 -12
  139. paddlex/inference/utils/mkldnn_blocklist.py +25 -0
  140. paddlex/inference/utils/official_models.py +14 -0
  141. paddlex/inference/utils/pp_option.py +29 -8
  142. paddlex/model.py +2 -2
  143. paddlex/modules/__init__.py +1 -1
  144. paddlex/modules/anomaly_detection/evaluator.py +2 -2
  145. paddlex/modules/base/__init__.py +1 -1
  146. paddlex/modules/base/evaluator.py +5 -5
  147. paddlex/modules/base/trainer.py +1 -1
  148. paddlex/modules/doc_vlm/dataset_checker.py +2 -2
  149. paddlex/modules/doc_vlm/evaluator.py +2 -2
  150. paddlex/modules/doc_vlm/exportor.py +2 -2
  151. paddlex/modules/doc_vlm/model_list.py +1 -1
  152. paddlex/modules/doc_vlm/trainer.py +2 -2
  153. paddlex/modules/face_recognition/evaluator.py +2 -2
  154. paddlex/modules/formula_recognition/evaluator.py +5 -2
  155. paddlex/modules/formula_recognition/model_list.py +3 -0
  156. paddlex/modules/formula_recognition/trainer.py +3 -0
  157. paddlex/modules/general_recognition/evaluator.py +1 -1
  158. paddlex/modules/image_classification/evaluator.py +2 -2
  159. paddlex/modules/image_classification/model_list.py +1 -0
  160. paddlex/modules/instance_segmentation/evaluator.py +1 -1
  161. paddlex/modules/keypoint_detection/evaluator.py +1 -1
  162. paddlex/modules/m_3d_bev_detection/evaluator.py +2 -2
  163. paddlex/modules/multilabel_classification/evaluator.py +2 -2
  164. paddlex/modules/object_detection/dataset_checker/dataset_src/convert_dataset.py +4 -4
  165. paddlex/modules/object_detection/evaluator.py +2 -2
  166. paddlex/modules/object_detection/model_list.py +2 -0
  167. paddlex/modules/semantic_segmentation/evaluator.py +2 -2
  168. paddlex/modules/table_recognition/evaluator.py +2 -2
  169. paddlex/modules/text_detection/evaluator.py +2 -2
  170. paddlex/modules/text_detection/model_list.py +2 -0
  171. paddlex/modules/text_recognition/evaluator.py +2 -2
  172. paddlex/modules/text_recognition/model_list.py +2 -0
  173. paddlex/modules/ts_anomaly_detection/evaluator.py +2 -2
  174. paddlex/modules/ts_classification/dataset_checker/dataset_src/split_dataset.py +1 -1
  175. paddlex/modules/ts_classification/evaluator.py +2 -2
  176. paddlex/modules/ts_forecast/evaluator.py +2 -2
  177. paddlex/modules/video_classification/evaluator.py +2 -2
  178. paddlex/modules/video_detection/evaluator.py +2 -2
  179. paddlex/ops/__init__.py +2 -2
  180. paddlex/paddlex_cli.py +19 -13
  181. paddlex/repo_apis/Paddle3D_api/bev_fusion/model.py +2 -2
  182. paddlex/repo_apis/PaddleClas_api/cls/config.py +1 -1
  183. paddlex/repo_apis/PaddleClas_api/cls/model.py +1 -1
  184. paddlex/repo_apis/PaddleClas_api/cls/register.py +10 -0
  185. paddlex/repo_apis/PaddleClas_api/cls/runner.py +1 -1
  186. paddlex/repo_apis/PaddleDetection_api/instance_seg/model.py +1 -1
  187. paddlex/repo_apis/PaddleDetection_api/instance_seg/runner.py +1 -1
  188. paddlex/repo_apis/PaddleDetection_api/object_det/config.py +1 -1
  189. paddlex/repo_apis/PaddleDetection_api/object_det/model.py +1 -1
  190. paddlex/repo_apis/PaddleDetection_api/object_det/official_categories.py +25 -0
  191. paddlex/repo_apis/PaddleDetection_api/object_det/register.py +30 -0
  192. paddlex/repo_apis/PaddleDetection_api/object_det/runner.py +1 -1
  193. paddlex/repo_apis/PaddleOCR_api/formula_rec/config.py +3 -3
  194. paddlex/repo_apis/PaddleOCR_api/formula_rec/model.py +5 -9
  195. paddlex/repo_apis/PaddleOCR_api/formula_rec/register.py +27 -0
  196. paddlex/repo_apis/PaddleOCR_api/formula_rec/runner.py +1 -1
  197. paddlex/repo_apis/PaddleOCR_api/table_rec/model.py +1 -1
  198. paddlex/repo_apis/PaddleOCR_api/table_rec/runner.py +1 -1
  199. paddlex/repo_apis/PaddleOCR_api/text_det/model.py +1 -1
  200. paddlex/repo_apis/PaddleOCR_api/text_det/register.py +18 -0
  201. paddlex/repo_apis/PaddleOCR_api/text_det/runner.py +1 -1
  202. paddlex/repo_apis/PaddleOCR_api/text_rec/config.py +3 -3
  203. paddlex/repo_apis/PaddleOCR_api/text_rec/model.py +5 -9
  204. paddlex/repo_apis/PaddleOCR_api/text_rec/register.py +18 -0
  205. paddlex/repo_apis/PaddleOCR_api/text_rec/runner.py +1 -1
  206. paddlex/repo_apis/PaddleSeg_api/seg/model.py +1 -1
  207. paddlex/repo_apis/PaddleSeg_api/seg/runner.py +1 -1
  208. paddlex/repo_apis/PaddleTS_api/ts_ad/config.py +3 -3
  209. paddlex/repo_apis/PaddleTS_api/ts_cls/config.py +2 -2
  210. paddlex/repo_apis/PaddleTS_api/ts_fc/config.py +4 -4
  211. paddlex/repo_apis/PaddleVideo_api/video_cls/config.py +1 -1
  212. paddlex/repo_apis/PaddleVideo_api/video_cls/model.py +1 -1
  213. paddlex/repo_apis/PaddleVideo_api/video_cls/runner.py +1 -1
  214. paddlex/repo_apis/PaddleVideo_api/video_det/config.py +1 -1
  215. paddlex/repo_apis/PaddleVideo_api/video_det/model.py +1 -1
  216. paddlex/repo_apis/PaddleVideo_api/video_det/runner.py +1 -1
  217. paddlex/repo_apis/base/config.py +1 -1
  218. paddlex/repo_manager/core.py +3 -3
  219. paddlex/repo_manager/meta.py +6 -2
  220. paddlex/repo_manager/repo.py +17 -16
  221. paddlex/utils/custom_device_list.py +26 -2
  222. paddlex/utils/deps.py +1 -1
  223. paddlex/utils/device.py +15 -8
  224. paddlex/utils/env.py +4 -0
  225. paddlex/utils/flags.py +2 -4
  226. paddlex/utils/fonts/__init__.py +34 -4
  227. paddlex/utils/misc.py +1 -1
  228. {paddlex-3.0.0rc1.dist-info → paddlex-3.0.1.dist-info}/METADATA +52 -56
  229. {paddlex-3.0.0rc1.dist-info → paddlex-3.0.1.dist-info}/RECORD +233 -206
  230. {paddlex-3.0.0rc1.dist-info → paddlex-3.0.1.dist-info}/WHEEL +1 -1
  231. {paddlex-3.0.0rc1.dist-info → paddlex-3.0.1.dist-info}/entry_points.txt +0 -0
  232. {paddlex-3.0.0rc1.dist-info → paddlex-3.0.1.dist-info}/licenses/LICENSE +0 -0
  233. {paddlex-3.0.0rc1.dist-info → paddlex-3.0.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,205 @@
1
+ # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+
17
+ import paddle
18
+ import paddle.nn.functional as F
19
+
20
+ try:
21
+ from paddle.incubate.nn.functional import fused_rotary_position_embedding
22
+ except ImportError:
23
+ fused_rotary_position_embedding = None
24
+
25
try:
    from paddle.incubate.nn.functional import swiglu
except ImportError:

    def swiglu(x, y=None):
        """Fallback SwiGLU used when the fused Paddle op cannot be imported.

        Returns ``silu(x) * y``. When ``y`` is omitted, ``x`` is first split
        into two chunks along its last axis; the first chunk is passed
        through SiLU and multiplied by the second.
        """
        if y is not None:
            return F.silu(x) * y
        gate, up = paddle.chunk(x, chunks=2, axis=-1)
        return F.silu(gate) * up
33
+
34
+
35
+ from paddle.utils import try_import
36
+
37
+
38
def get_env_device():
    """
    Return the device name of running environment.

    Probing order (first match wins): a CUDA build reports "gpu"; then the
    registered custom device types are checked for "npu", "mlu", "gcu" and
    "intel_hpu"; then a ROCm build reports "rocm" and an XPU build "xpu".
    Anything else is reported as "cpu".
    """
    if paddle.is_compiled_with_cuda():
        return "gpu"

    # Custom devices are probed in a fixed priority order.
    registered_types = paddle.device.get_all_custom_device_type()
    for candidate in ("npu", "mlu", "gcu", "intel_hpu"):
        if candidate in registered_types:
            return candidate

    if paddle.is_compiled_with_rocm():
        return "rocm"
    if paddle.is_compiled_with_xpu():
        return "xpu"
    return "cpu"
57
+
58
+
59
+ try:
60
+ from paddle.incubate.nn.functional import fused_rotary_position_embedding
61
+ except ImportError:
62
+ fused_rotary_position_embedding = None
63
# Best-effort setup: on custom devices, register the custom ops found in
# CUSTOM_DEVICE_ROOT, then try to import flash attention. Any failure leaves
# ``flash_attention`` as None so importing this module never fails here.
try:
    if get_env_device() in ["npu", "mlu", "gcu"]:
        from paddle.base import core

        custom_device_root = os.getenv("CUSTOM_DEVICE_ROOT")
        # os.listdir(None) would silently scan the current working directory,
        # so only scan when the variable is actually set.
        if custom_device_root:
            for lib in os.listdir(custom_device_root):
                if lib.endswith(".so"):
                    # os.listdir yields bare file names; pass the full path so
                    # the library is loaded from the directory we just listed
                    # rather than whatever the dynamic loader search path
                    # happens to resolve.
                    paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(
                        os.path.join(custom_device_root, lib)
                    )
    from paddle.nn.functional.flash_attention import flash_attention
except Exception:
    # Deliberately broad (missing module, missing op, loader error, ...), but
    # no longer a bare ``except:`` so KeyboardInterrupt/SystemExit propagate.
    flash_attention = None
75
+
76
+
77
def fusion_rope(
    query_states,
    key_states,
    value_states,
    hidden_states,
    position_ids,
    past_key_value,
    rotary_emb,
    context_parallel_degree=-1,
):
    """Apply rotary position embedding to query/key using a fused backend kernel.

    The kernel is chosen from ``get_env_device()``:

    * "npu": the ``fused_rope`` custom op, applied to query and key separately.
    * "intel_hpu": Paddle's incubate ``fused_rotary_position_embedding`` on
      transposed ([0, 2, 1, 3]) query/key, transposed back afterwards.
    * "gcu": the ``fused_rotary_embedding_gcu`` custom op, fed by
      ``rotary_emb.get_fused_cos_sin``.
    * anything else: the module-level ``fused_rotary_position_embedding``.

    Args:
        query_states: Query tensor; its shape must unpack into four values,
            the third being the number of query heads.
        key_states: Key tensor; its shape must unpack into four values, the
            second being the key/value sequence length and the third the
            number of key/value heads.
        value_states: Only forwarded to ``rotary_emb``.
        hidden_states: Not used by this function.
        position_ids: Forwarded to the fused kernel (not used on "npu").
        past_key_value: Must be None except on "gcu" and "intel_hpu". On
            "intel_hpu", ``past_key_value[0].shape[-3]`` is added to the
            sequence length passed to ``rotary_emb``.
        rotary_emb: Callable returning ``(cos, sin)``; on "gcu" it must also
            provide ``get_fused_cos_sin``.
        context_parallel_degree: When greater than 1 (only allowed on "gpu"),
            the key/value sequence length is multiplied by it.

    Returns:
        Tuple ``(query_states, key_states)`` with the rotary embedding applied.
    """
    device = get_env_device()
    # Every backend except gcu / intel_hpu gets cos/sin computed up front and
    # does not support a KV cache.
    shares_cos_sin = device not in ["gcu", "intel_hpu"]

    if shares_cos_sin:
        assert past_key_value is None, "fuse rotary not support cache kv for now"

    _, _, num_heads, _ = query_states.shape
    _, kv_seq_len, num_key_value_heads, _ = key_states.shape

    if context_parallel_degree > 1:
        assert (
            device == "gpu"
        ), "context parallel only support cuda device for now"
        kv_seq_len *= context_parallel_degree

    if shares_cos_sin:
        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)

    if device == "npu":
        query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[
            0
        ]
        key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
        return query_states, key_states

    if device == "intel_hpu":
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-3]
        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
        cos = cos.squeeze().unsqueeze(0).unsqueeze(0)
        sin = sin.squeeze().unsqueeze(0).unsqueeze(0)

        # Rotate query first, then key, each on the transposed layout.
        rotated = []
        for states in (query_states, key_states):
            out, _, _ = paddle.incubate.nn.functional.fused_rotary_position_embedding(
                paddle.transpose(states, [0, 2, 1, 3]),
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
            )
            rotated.append(out)
        # Restore the original axis order.
        query_states, key_states = [
            paddle.transpose(out, [0, 2, 1, 3]) for out in rotated
        ]
        return query_states, key_states

    if device == "gcu":
        cos_sin = rotary_emb.get_fused_cos_sin(value_states, seq_len=kv_seq_len)
        query_states, key_states = core.eager._run_custom_op(
            "fused_rotary_embedding_gcu",
            query_states,
            key_states,
            cos_sin,
            position_ids,
            True,
        )
        return query_states, key_states

    # paddle version > 2.6 or develop support q and k/v with different num_heads
    paddle_version = float(paddle.__version__[:3])
    is_old_release = (paddle_version != 0.0) and (paddle_version <= 2.6)
    if is_old_release and (num_heads != num_key_value_heads):
        # Older releases: rotate q and k in two separate calls.
        query_states, _, _ = fused_rotary_position_embedding(
            query_states,
            None,
            None,
            sin=sin,
            cos=cos,
            position_ids=position_ids,
            use_neox_rotary_style=False,
        )
        key_states, _, _ = fused_rotary_position_embedding(
            key_states,
            None,
            None,
            sin=sin,
            cos=cos,
            position_ids=position_ids,
            use_neox_rotary_style=False,
        )
    else:
        query_states, key_states, _ = fused_rotary_position_embedding(
            query_states,
            key_states,
            v=None,
            sin=sin,
            cos=cos,
            position_ids=position_ids,
            use_neox_rotary_style=False,
        )
    return query_states, key_states
176
+
177
+
178
def rms_norm_fused(x_in, w, eps, use_fast_ln=False):
    """Run RMS norm through an optional fused extension module.

    Imports ``fast_ln`` (when ``use_fast_ln`` is true) or ``fused_ln`` via
    ``try_import`` at call time and returns the first element of the
    kernel's result.
    """
    if use_fast_ln:
        return try_import("fast_ln").fast_rms_norm(x_in, w, eps)[0]
    return try_import("fused_ln").fused_rms_norm(x_in, w, eps)[0]
185
+
186
+
187
def fusion_rms_norm(hidden_states, weight, variance_epsilon, use_fast_ln=False):
    """Dispatch RMS norm to the fused kernel for the current device.

    "npu", "mlu" and "gcu" use their device-specific custom op; "intel_hpu"
    uses Paddle's incubate ``fused_rms_norm`` over the last axis; every other
    device falls back to ``rms_norm_fused`` (which honours ``use_fast_ln``).
    In all cases the first element of the kernel's result is returned.
    """
    device = get_env_device()

    custom_op_by_device = {
        "npu": "rms_norm_npu",
        "mlu": "rms_norm_mlu",
        "gcu": "rms_norm_gcu",
    }
    op_name = custom_op_by_device.get(device)
    if op_name is not None:
        return core.eager._run_custom_op(
            op_name, hidden_states, weight, variance_epsilon
        )[0]

    if device == "intel_hpu":
        last_axis = hidden_states.dim() - 1
        return paddle.incubate.nn.functional.fused_rms_norm(
            hidden_states, weight, None, variance_epsilon, last_axis
        )[0]

    return rms_norm_fused(hidden_states, weight, variance_epsilon, use_fast_ln)
@@ -88,7 +88,7 @@ class GenerationConfig:
88
88
  use_fast: (bool, optional): Whether to use fast entry of model
89
89
  for FastGeneration. Default to False.
90
90
  use_fp16_decoding: (bool, optional): Whether to use fp16 for decoding.
91
- Only works when fast entry is avalible. Default to False.
91
+ Only works when fast entry is available. Default to False.
92
92
  trunc_input: (bool, optional): Whether to truncate the inputs from
93
93
  output sequences . Default to True.
94
94
  model_kwargs (dict): It can be used to specify additional kwargs
@@ -487,7 +487,7 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
487
487
  self._validate_arguments()
488
488
 
489
489
  # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size
490
- # is infered in the first usage, which inhibits initializing here)
490
+ # is inferred in the first usage, which inhibits initializing here)
491
491
  self.length_1_bias = None
492
492
  self.prepared_bias_variables = False
493
493
 
@@ -1443,7 +1443,7 @@ class GenerationMixin(object):
1443
1443
  next_tokens = paddle.multinomial(probs)
1444
1444
 
1445
1445
  if self.config.tensor_parallel_degree > 1:
1446
- # Maybe no need to broadcast if seed is set correclty.
1446
+ # Maybe no need to broadcast if seed is set correctly.
1447
1447
  from paddle.distributed import fleet
1448
1448
 
1449
1449
  try:
@@ -496,7 +496,7 @@ class PretrainedConfig:
496
496
  if num_labels is not None and len(self.id2label) != num_labels:
497
497
  logging.warning(
498
498
  f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
499
- f"{self.id2label}. The number of labels wil be overwritten to {self.num_labels}."
499
+ f"{self.id2label}. The number of labels will be overwritten to {self.num_labels}."
500
500
  )
501
501
  self.id2label = dict(
502
502
  (int(key), value) for key, value in self.id2label.items()
@@ -909,7 +909,7 @@ class PretrainedConfig:
909
909
 
910
910
  def register_unsavable_keys(self, keys):
911
911
  # Save: not save it in any case
912
- # Print: show it if non defalut value
912
+ # Print: show it if non default value
913
913
  if isinstance(keys, list) or isinstance(keys, tuple):
914
914
  for key in keys:
915
915
  self._unsavable_keys.add(key)
@@ -939,7 +939,7 @@ class PretrainedConfig:
939
939
 
940
940
  output[key] = value
941
941
 
942
- # Fix for rewrited from_pretrained method, hasattr
942
+ # Fix for rewritten from_pretrained method, hasattr
943
943
  if saving_file and hasattr(self, "_unsavable_keys"):
944
944
  for key in list(output.keys()):
945
945
  if key in self._unsavable_keys:
@@ -51,7 +51,7 @@ class StateDictNameMapping:
51
51
  return self.action == "transpose"
52
52
 
53
53
  def should_merge_last_two_dim(self) -> bool:
54
- """check that wether merge last two dim"""
54
+ """check whether to merge the last two dims"""
55
55
  return self.action == "merge_last_two_dim"
56
56
 
57
57
  def run(self, state_dict: dict[str, ndarray], name: str) -> ndarray:
@@ -104,7 +104,7 @@ class StateDictNameMapping:
104
104
  class ConversionMixin:
105
105
  @classmethod
106
106
  def support_conversion(cls, config: PretrainedConfig) -> bool:
107
- """check wether the model support conversion"""
107
+ """check whether the model support conversion"""
108
108
  try:
109
109
  # try to get the name-mapping info
110
110
  _ = cls._get_name_mappings(config)
@@ -166,7 +166,7 @@ class ConversionMixin:
166
166
  with device_guard("cpu"):
167
167
  state_dict = paddle.load(weight_file, return_numpy=False)
168
168
  logging.info(
169
- "Starting to convert orignal state_dict to tensor parallel state_dict."
169
+ "Starting to convert original state_dict to tensor parallel state_dict."
170
170
  )
171
171
 
172
172
  state_keys_map = cls._resolve_prefix_keys(
@@ -298,11 +298,11 @@ def _transformer_encoder_fwd(
298
298
  # MultiHeadAttention not so efficiently, and maybe optimize it later.
299
299
  if cache is None and getattr(self, "_use_cache", False):
300
300
  cache = [tuple(self.layers[0].gen_cache(src))] * len(self.layers)
301
- # To be compatible with `TransformerEncoder.forward`, `_use_cache` defualts
301
+ # To be compatible with `TransformerEncoder.forward`, `_use_cache` defaults
302
302
  # to True when cache is not None.
303
303
  new_caches = [] if cache is not None and getattr(self, "_use_cache", True) else None
304
304
  all_attentions = [] if output_attentions else None
305
- # NOTE: Also includes embeding output which is same as HF.
305
+ # NOTE: Also includes embedding output which is same as HF.
306
306
  all_hidden_states = [output] if output_hidden_states else None
307
307
  for i, mod in enumerate(self.layers):
308
308
  # if output has no gradient, recompute is unnecessary
@@ -185,7 +185,7 @@ def _convert_state_dict_dtype_and_shape(state_dict, model_to_load):
185
185
  if key in list(state_dict.keys()):
186
186
  if isinstance(state_dict[key], np.ndarray):
187
187
  raise ValueError(
188
- "convert_state_dict_dtype expected paddle.Tensor not numpy.ndarray, plase convert numpy.ndarray to paddle.Tensor"
188
+ "convert_state_dict_dtype expected paddle.Tensor not numpy.ndarray, please convert numpy.ndarray to paddle.Tensor"
189
189
  )
190
190
  # confirm parameter cast is executed on the same device as model
191
191
  # TODO: cast(FP32 -> FP16) has diff on different devices, need to fix it
@@ -1080,7 +1080,7 @@ class PretrainedModel(
1080
1080
  elif "pytorch_model.bin" in str(resolved_archive_file):
1081
1081
  if not from_hf_hub and not convert_from_torch:
1082
1082
  raise ValueError(
1083
- f"Download pytorch wight in "
1083
+ f"Download pytorch weight in "
1084
1084
  f" {resolved_archive_file}. Please set convert_from_torch=True in from_pretrained. eg, Model.from_pretrained(model_name, convert_from_torch=True) "
1085
1085
  )
1086
1086
 
@@ -1488,7 +1488,7 @@ class PretrainedModel(
1488
1488
  dtype,
1489
1489
  )
1490
1490
 
1491
- # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
1491
+ # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
1492
1492
  # matching the weights in the model.
1493
1493
  mismatched_keys += _find_mismatched_keys(
1494
1494
  state_dict,
@@ -1681,14 +1681,14 @@ class PretrainedModel(
1681
1681
  )
1682
1682
  convert_from_torch = True
1683
1683
 
1684
- # from_hf_hub defalut enable convert_from_torch
1684
+ # from_hf_hub default enable convert_from_torch
1685
1685
  if from_hf_hub and convert_from_torch is None:
1686
1686
  logging.warning(
1687
1687
  "If you are attempting to load weights from Hugging Face Hub and want to disable the default behavior of considering torch weights,"
1688
1688
  " you can set ·convert_from_torch=False·. By default, `convert_from_torch` is set to `True`. "
1689
1689
  )
1690
1690
  convert_from_torch = True
1691
- # convert_from_torch defalut is False
1691
+ # convert_from_torch default is False
1692
1692
  if convert_from_torch is None:
1693
1693
  convert_from_torch = False
1694
1694
 
@@ -1922,7 +1922,7 @@ class PretrainedModel(
1922
1922
  assert (
1923
1923
  k
1924
1924
  not in final_config["mp_config"]["parallelize_plan"].keys()
1925
- ), f"sublayer mp_config shuld be a subset of model but got sublayer config {config['mp_config']} and model config {final_config['mp_config']}."
1925
+ ), f"sublayer mp_config should be a subset of model but got sublayer config {config['mp_config']} and model config {final_config['mp_config']}."
1926
1926
  final_config["mp_config"]["parallelize_plan"][k] = v
1927
1927
  if "sp_config" in config and config["sp_config"] is not None:
1928
1928
  if final_config["sp_config"] is None:
@@ -1932,7 +1932,7 @@ class PretrainedModel(
1932
1932
  assert (
1933
1933
  k
1934
1934
  not in final_config["sp_config"]["parallelize_plan"].keys()
1935
- ), f"sublayer sp_config shuld be a subset of model but got sublayer config {config['sp_config']} and model config {final_config['sp_config']}."
1935
+ ), f"sublayer sp_config should be a subset of model but got sublayer config {config['sp_config']} and model config {final_config['sp_config']}."
1936
1936
  final_config["sp_config"]["parallelize_plan"][k] = v
1937
1937
  if "pp_config" in config and config["pp_config"] is not None:
1938
1938
  if isinstance(config["pp_config"]["split_spec"], str):
@@ -2011,28 +2011,4 @@ class PretrainedModel(
2011
2011
  merged_config["pp_config"] is not None
2012
2012
  final_config["pp_config"] = merged_config["pp_config"]
2013
2013
 
2014
- if (
2015
- "data_sharding_parallel" in auto_dist_degree
2016
- and auto_dist_degree["data_sharding_parallel"]
2017
- ):
2018
- # to avoid a circular import
2019
- from paddlenlp.trainer.trainer_utils import ShardingOption
2020
-
2021
- level = 0
2022
- if (
2023
- "sharding" in auto_dist_degree
2024
- and auto_dist_degree["sharding"] is not None
2025
- ):
2026
- sharding = auto_dist_degree["sharding"]
2027
- if ShardingOption.SHARD_OP in sharding:
2028
- level = 1
2029
- if ShardingOption.SHARD_GRAD_OP in sharding:
2030
- level = 2
2031
- if ShardingOption.FULL_SHARD in sharding:
2032
- level = 3
2033
- final_config["dp_config"] = {
2034
- "sharding_level": level,
2035
- "sharding_mesh_dim": auto_dist_degree.get("sharding_mesh_dim", None),
2036
- }
2037
-
2038
2014
  return final_config