mindspore-2.7.0-cp310-cp310-win_amd64.whl → mindspore-2.7.1-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (290)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_extends/parse/compile_config.py +24 -1
  7. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
  8. mindspore/_extends/parse/resources.py +1 -1
  9. mindspore/_extends/parse/standard_method.py +8 -1
  10. mindspore/_extends/parse/trope.py +2 -1
  11. mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/boost/base.py +29 -2
  18. mindspore/common/_decorator.py +3 -2
  19. mindspore/common/_grad_function.py +3 -1
  20. mindspore/common/_tensor_cpp_method.py +1 -1
  21. mindspore/common/_tensor_docs.py +275 -64
  22. mindspore/common/_utils.py +0 -44
  23. mindspore/common/api.py +285 -35
  24. mindspore/common/dump.py +7 -108
  25. mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
  26. mindspore/common/hook_handle.py +60 -0
  27. mindspore/common/jit_config.py +5 -1
  28. mindspore/common/jit_trace.py +27 -12
  29. mindspore/common/lazy_inline.py +5 -3
  30. mindspore/common/parameter.py +13 -107
  31. mindspore/common/recompute.py +4 -11
  32. mindspore/common/tensor.py +16 -169
  33. mindspore/communication/_comm_helper.py +11 -1
  34. mindspore/communication/comm_func.py +138 -4
  35. mindspore/communication/management.py +85 -1
  36. mindspore/config/op_info.config +0 -15
  37. mindspore/context.py +5 -85
  38. mindspore/dataset/engine/datasets.py +8 -4
  39. mindspore/dataset/engine/datasets_vision.py +1 -1
  40. mindspore/dataset/engine/validators.py +1 -15
  41. mindspore/dnnl.dll +0 -0
  42. mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
  43. mindspore/graph/custom_pass.py +55 -0
  44. mindspore/include/dataset/execute.h +2 -2
  45. mindspore/jpeg62.dll +0 -0
  46. mindspore/mindrecord/__init__.py +3 -3
  47. mindspore/mindrecord/common/exceptions.py +1 -0
  48. mindspore/mindrecord/config.py +1 -1
  49. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  50. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  51. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  52. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  53. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  54. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  55. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  56. mindspore/mindrecord/filereader.py +4 -4
  57. mindspore/mindrecord/filewriter.py +5 -5
  58. mindspore/mindrecord/mindpage.py +2 -2
  59. mindspore/mindrecord/tools/cifar10.py +1 -1
  60. mindspore/mindrecord/tools/cifar100.py +1 -1
  61. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  62. mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
  63. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  64. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  65. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  66. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  67. mindspore/mindspore_backend_common.dll +0 -0
  68. mindspore/mindspore_backend_manager.dll +0 -0
  69. mindspore/mindspore_cluster.dll +0 -0
  70. mindspore/mindspore_common.dll +0 -0
  71. mindspore/mindspore_core.dll +0 -0
  72. mindspore/mindspore_cpu.dll +0 -0
  73. mindspore/mindspore_dump.dll +0 -0
  74. mindspore/mindspore_frontend.dll +0 -0
  75. mindspore/mindspore_glog.dll +0 -0
  76. mindspore/mindspore_hardware_abstract.dll +0 -0
  77. mindspore/mindspore_memory_pool.dll +0 -0
  78. mindspore/mindspore_ms_backend.dll +0 -0
  79. mindspore/mindspore_ops.dll +0 -0
  80. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  81. mindspore/mindspore_profiler.dll +0 -0
  82. mindspore/mindspore_pyboost.dll +0 -0
  83. mindspore/mindspore_pynative.dll +0 -0
  84. mindspore/mindspore_runtime_pipeline.dll +0 -0
  85. mindspore/mindspore_runtime_utils.dll +0 -0
  86. mindspore/mindspore_tools.dll +0 -0
  87. mindspore/mint/__init__.py +15 -10
  88. mindspore/mint/distributed/distributed.py +182 -62
  89. mindspore/mint/nn/__init__.py +2 -16
  90. mindspore/mint/nn/functional.py +4 -110
  91. mindspore/mint/nn/layer/__init__.py +0 -2
  92. mindspore/mint/nn/layer/activation.py +0 -6
  93. mindspore/mint/nn/layer/basic.py +0 -47
  94. mindspore/mint/nn/layer/conv.py +4 -4
  95. mindspore/mint/nn/layer/normalization.py +8 -13
  96. mindspore/mint/nn/layer/pooling.py +0 -4
  97. mindspore/nn/__init__.py +1 -3
  98. mindspore/nn/cell.py +16 -66
  99. mindspore/nn/layer/basic.py +49 -1
  100. mindspore/nn/layer/container.py +16 -0
  101. mindspore/nn/layer/embedding.py +4 -169
  102. mindspore/nn/layer/normalization.py +2 -1
  103. mindspore/nn/layer/thor_layer.py +4 -85
  104. mindspore/nn/optim/ada_grad.py +0 -1
  105. mindspore/nn/optim/adafactor.py +0 -1
  106. mindspore/nn/optim/adam.py +31 -124
  107. mindspore/nn/optim/adamax.py +0 -1
  108. mindspore/nn/optim/asgd.py +0 -1
  109. mindspore/nn/optim/ftrl.py +8 -102
  110. mindspore/nn/optim/lamb.py +0 -1
  111. mindspore/nn/optim/lars.py +0 -3
  112. mindspore/nn/optim/lazyadam.py +25 -218
  113. mindspore/nn/optim/momentum.py +5 -43
  114. mindspore/nn/optim/optimizer.py +6 -55
  115. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  116. mindspore/nn/optim/rmsprop.py +0 -1
  117. mindspore/nn/optim/rprop.py +0 -1
  118. mindspore/nn/optim/sgd.py +0 -1
  119. mindspore/nn/optim/tft_wrapper.py +0 -1
  120. mindspore/nn/optim/thor.py +0 -2
  121. mindspore/nn/probability/bijector/bijector.py +7 -8
  122. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  123. mindspore/nn/probability/bijector/power_transform.py +20 -21
  124. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  125. mindspore/nn/probability/bijector/softplus.py +13 -14
  126. mindspore/nn/wrap/grad_reducer.py +4 -74
  127. mindspore/numpy/array_creations.py +2 -2
  128. mindspore/numpy/fft.py +9 -9
  129. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  130. mindspore/onnx/onnx_export.py +137 -0
  131. mindspore/opencv_core4110.dll +0 -0
  132. mindspore/opencv_imgcodecs4110.dll +0 -0
  133. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  134. mindspore/ops/__init__.py +2 -0
  135. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  136. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  137. mindspore/ops/_op_impl/cpu/__init__.py +0 -5
  138. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
  139. mindspore/ops/auto_generate/gen_extend_func.py +2 -7
  140. mindspore/ops/auto_generate/gen_ops_def.py +98 -141
  141. mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
  142. mindspore/ops/communication.py +97 -0
  143. mindspore/ops/composite/__init__.py +5 -2
  144. mindspore/ops/composite/base.py +15 -1
  145. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  146. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  147. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  148. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  149. mindspore/ops/function/__init__.py +1 -0
  150. mindspore/ops/function/array_func.py +14 -12
  151. mindspore/ops/function/comm_func.py +3883 -0
  152. mindspore/ops/function/debug_func.py +3 -4
  153. mindspore/ops/function/math_func.py +45 -54
  154. mindspore/ops/function/nn_func.py +75 -294
  155. mindspore/ops/function/random_func.py +9 -18
  156. mindspore/ops/functional.py +2 -0
  157. mindspore/ops/functional_overload.py +354 -18
  158. mindspore/ops/operations/__init__.py +2 -5
  159. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  160. mindspore/ops/operations/_inner_ops.py +1 -38
  161. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  162. mindspore/ops/operations/array_ops.py +1 -0
  163. mindspore/ops/operations/comm_ops.py +94 -2
  164. mindspore/ops/operations/custom_ops.py +228 -19
  165. mindspore/ops/operations/debug_ops.py +27 -29
  166. mindspore/ops/operations/manually_defined/ops_def.py +27 -306
  167. mindspore/ops/operations/nn_ops.py +2 -2
  168. mindspore/ops/operations/sparse_ops.py +0 -83
  169. mindspore/ops/primitive.py +1 -17
  170. mindspore/ops/tensor_method.py +72 -3
  171. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  172. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  173. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  174. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  175. mindspore/ops_generate/common/gen_constants.py +11 -10
  176. mindspore/ops_generate/common/op_proto.py +18 -1
  177. mindspore/ops_generate/common/template.py +102 -245
  178. mindspore/ops_generate/common/template_utils.py +212 -0
  179. mindspore/ops_generate/gen_custom_ops.py +69 -0
  180. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  181. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  182. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  183. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  184. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  185. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  186. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  187. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  188. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  189. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  190. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  191. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  192. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  193. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  194. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  195. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  196. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  197. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  198. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  199. mindspore/parallel/_cell_wrapper.py +1 -1
  200. mindspore/parallel/_parallel_serialization.py +1 -4
  201. mindspore/parallel/_utils.py +29 -6
  202. mindspore/parallel/checkpoint_transform.py +18 -2
  203. mindspore/parallel/cluster/process_entity/_api.py +24 -32
  204. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  205. mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
  206. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  207. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  208. mindspore/parallel/strategy.py +336 -0
  209. mindspore/parallel/transform_safetensors.py +117 -16
  210. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
  211. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  212. mindspore/profiler/common/constant.py +5 -0
  213. mindspore/profiler/common/file_manager.py +9 -0
  214. mindspore/profiler/common/msprof_cmd_tool.py +38 -2
  215. mindspore/profiler/common/path_manager.py +56 -24
  216. mindspore/profiler/common/profiler_context.py +2 -12
  217. mindspore/profiler/common/profiler_info.py +3 -3
  218. mindspore/profiler/common/profiler_path_manager.py +13 -0
  219. mindspore/profiler/common/util.py +30 -3
  220. mindspore/profiler/experimental_config.py +2 -1
  221. mindspore/profiler/platform/npu_profiler.py +33 -6
  222. mindspore/run_check/_check_version.py +108 -24
  223. mindspore/runtime/__init__.py +3 -2
  224. mindspore/runtime/executor.py +11 -3
  225. mindspore/runtime/memory.py +112 -0
  226. mindspore/swresample-4.dll +0 -0
  227. mindspore/swscale-6.dll +0 -0
  228. mindspore/tinyxml2.dll +0 -0
  229. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  230. mindspore/tools/data_dump.py +130 -0
  231. mindspore/tools/sdc_detect.py +91 -0
  232. mindspore/tools/stress_detect.py +63 -0
  233. mindspore/train/__init__.py +6 -6
  234. mindspore/train/_utils.py +5 -18
  235. mindspore/train/amp.py +6 -4
  236. mindspore/train/callback/_checkpoint.py +0 -9
  237. mindspore/train/callback/_train_fault_tolerance.py +69 -18
  238. mindspore/train/data_sink.py +1 -5
  239. mindspore/train/model.py +38 -211
  240. mindspore/train/serialization.py +126 -387
  241. mindspore/turbojpeg.dll +0 -0
  242. mindspore/utils/__init__.py +6 -3
  243. mindspore/utils/dlpack.py +92 -0
  244. mindspore/utils/dryrun.py +1 -1
  245. mindspore/utils/runtime_execution_order_check.py +10 -0
  246. mindspore/utils/sdc_detect.py +14 -12
  247. mindspore/utils/stress_detect.py +43 -0
  248. mindspore/utils/utils.py +144 -8
  249. mindspore/version.py +1 -1
  250. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  251. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
  252. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
  253. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  254. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  255. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  256. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  257. mindspore/experimental/llm_boost/register.py +0 -130
  258. mindspore/experimental/llm_boost/utils.py +0 -31
  259. mindspore/include/OWNERS +0 -7
  260. mindspore/mindspore_cpu_res_manager.dll +0 -0
  261. mindspore/mindspore_ops_kernel_common.dll +0 -0
  262. mindspore/mindspore_res_manager.dll +0 -0
  263. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  264. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  265. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  266. mindspore/nn/reinforcement/tensor_array.py +0 -145
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  270. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  271. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  272. mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
  273. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  274. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  275. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  276. mindspore/ops/operations/_tensor_array.py +0 -359
  277. mindspore/ops/operations/rl_ops.py +0 -288
  278. mindspore/parallel/_offload_context.py +0 -275
  279. mindspore/parallel/_recovery_context.py +0 -115
  280. mindspore/parallel/_transformer/__init__.py +0 -35
  281. mindspore/parallel/_transformer/layers.py +0 -765
  282. mindspore/parallel/_transformer/loss.py +0 -251
  283. mindspore/parallel/_transformer/moe.py +0 -693
  284. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  285. mindspore/parallel/_transformer/transformer.py +0 -3124
  286. mindspore/parallel/mpi/_mpi_config.py +0 -116
  287. mindspore/train/memory_profiling_pb2.py +0 -298
  288. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  289. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  290. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py (deleted)
@@ -1,210 +0,0 @@
- # Copyright 2024 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """AscendNative Llama Boost APIs."""
-
- import os
- import numpy as np
- from mindspore.common import Tensor, dtype
- from mindspore.experimental.llm_boost.ascend_native.llm_boost import LLMBoost
- from mindspore.experimental.llm_boost.register import LlmBoostRegister, LlmBoostType
-
- def RoundUp(val: int, align: int) -> int:
-     if align == 0:
-         return 0
-     return -(val // -align) * align
-
-
- def ConvertTensor(nd_mat: np.ndarray, transpose: bool = True, nd2nz: bool = True) -> np.ndarray:
-     """ Transforms tensor format from Nd to Nz """
-     if transpose:
-         nd_mat = np.transpose(nd_mat)
-     if not nd2nz:
-         return nd_mat
-     block_size = (16, 16)
-     r = RoundUp(nd_mat.shape[0], block_size[0])
-     c = RoundUp(nd_mat.shape[1], block_size[1])
-     r_pad = r - nd_mat.shape[0]
-     c_pad = c - nd_mat.shape[1]
-     nd_mat = np.pad(nd_mat, ((0, r_pad), (0, c_pad)))
-     nz_mat = np.transpose(np.reshape(
-         nd_mat, (r, c // block_size[1], block_size[1])), (1, 0, 2))
-     nz_mat = nz_mat.reshape(r, c)
-     return nz_mat
-
-
- @LlmBoostRegister.register(LlmBoostType.ASCEND_NATIVE, "Llama")
- class LlamaBoostAscendNative(LLMBoost):
-     r"""
-     Implements an Llama model in a single kernel.
-     it forwards the python functions to the C++ binded object
-     """
-     def _get_from_dict(self, dictionary, name):
-         """ internal function to get a specific tensor from the dictionary """
-         all_relevant_layers = [value for key, value in dictionary.items() if name in key]
-         if all_relevant_layers:
-             return all_relevant_layers[0].asnumpy()
-         return None
-
-     def _get_quant_triplet_from_dict(self, dictionary, name):
-         """ internal function to get a weight triple tensor from the dictionary """
-         weights = self._get_from_dict(dictionary, name + "._handler.weight")
-         scale = self._get_from_dict(dictionary, name + "._weight_quantizer.scale")
-         offset = self._get_from_dict(dictionary, name + "._weight_quantizer.zp_neg")
-         return weights, scale, offset
-
-     def _prepare_single_layer(self, ckpt, config, id):
-         """ prepares the dictionary of weights of a single layer """
-         prefix = 'model.layers.' + str(id)
-         is_last = id == config.num_layers-1
-         layer = 'layers.' + str(id) + '.'
-         l_dict = {key: value for key, value in ckpt.items() if layer in key}
-         if config.n_kv_heads is None:
-             config.n_kv_heads = config.num_heads
-         start = 0
-         end = config.hidden_size
-         kv_start = 0
-         kv_end = int(config.hidden_size*config.n_kv_heads/config.num_heads)
-         ffn_hid = [value for key, value in l_dict.items() if "w3" in key][0].shape[0]
-         ffn_start = 0
-         ffn_end = ffn_hid
-         rank_size = int(os.getenv('RANK_SIZE', '1'))
-         #Emir if (config.parallel_mode != 2): # 2 - AUTO_PARALLEL
-         hid_size = end
-         kv_hid_size = kv_end
-         embed_size = config.vocab_size
-         rank_id = int(os.getenv('RANK_ID', '0'))
-         if (hid_size % rank_size == 0) and (ffn_hid % rank_size == 0) and (embed_size % rank_size == 0):
-             start = int(rank_id * hid_size / rank_size)
-             end = int((rank_id + 1) * hid_size / rank_size)
-             kv_start = int(rank_id * kv_hid_size / rank_size)
-             kv_end = int((rank_id + 1) * kv_hid_size / rank_size)
-             ffn_start = int(rank_id * ffn_hid / rank_size)
-             ffn_end = int((rank_id + 1) * ffn_hid / rank_size)
-         else:
-             raise RuntimeError("hidden size and ffn hidden size must be divided by rank size without remainder. \
-                 hidden_size: ", hid_size, " ffn_hidden_size: ", ffn_hid, " rank_size: ", rank_size)
-         quant = self._get_from_dict(l_dict, "_weight_quantizer") is not None
-         unite_qkv = config.num_heads == config.n_kv_heads
-         self.dictionary[prefix + ".attention_norm.weight"] = \
-             Tensor(self._get_from_dict(l_dict, "attention_norm"), dtype=dtype.float16)
-         self.dictionary[prefix + ".ffn_norm.weight"] = \
-             Tensor(self._get_from_dict(l_dict, "ffn_norm"), dtype=dtype.float16)
-         if is_last:
-             self.dictionary['lm_head.weight'] = Tensor(ConvertTensor(ckpt['lm_head.weight'].asnumpy()[:, start:end]))
-
-         if not quant:
-             self._pack_attn_weights(l_dict, prefix, start, end, kv_start, kv_end, unite_qkv)
-             self._pack_ffn_weights(l_dict, prefix, ffn_start, ffn_end)
-         else:
-             self._pack_attn_quant_weights(l_dict, prefix, start, end, kv_start, kv_end, unite_qkv)
-             self._pack_ffn_quant_weights(l_dict, prefix, ffn_start, ffn_end)
-
-     def _pack_attn_weights(self, l_dict, prefix, start, end, kv_start, kv_end, unite_qkv):
-         """ prepares the dictionary of weights of an attention block """
-         wq = self._get_from_dict(l_dict, "wq")[start:end, :]
-         wk = self._get_from_dict(l_dict, "wk")[kv_start:kv_end, :]
-         wv = self._get_from_dict(l_dict, "wv")[kv_start:kv_end, :]
-         self.dictionary[prefix + ".attention.wo.weight"] = \
-             Tensor(ConvertTensor(self._get_from_dict(l_dict, "wo")[:, start:end]))
-         if unite_qkv:
-             self.dictionary[prefix + ".attention.wqkv.weight"] = Tensor(ConvertTensor(np.concatenate((wq, wk, wv))))
-         else:
-             self.dictionary[prefix + ".attention.wq.weight"] = Tensor(ConvertTensor(wq))
-             self.dictionary[prefix + ".attention.wkv.weight"] = Tensor(ConvertTensor(np.concatenate((wk, wv))))
-
-     def _pack_ffn_weights(self, l_dict, prefix, ffn_start, ffn_end):
-         """ prepares the dictionary of weights of an ffn block """
-         self.dictionary[prefix + ".feed_forward.w2.weight"] = \
-             Tensor(ConvertTensor(self._get_from_dict(l_dict, "w2")[:, ffn_start:ffn_end]))
-         w1 = self._get_from_dict(l_dict, "w1")[ffn_start:ffn_end, :]
-         w3 = self._get_from_dict(l_dict, "w3")[ffn_start:ffn_end, :]
-         self.dictionary[prefix + ".feed_forward.w13.weight"] = Tensor(ConvertTensor(np.concatenate((w1, w3))))
-
-     def _pack_attn_quant_weights(self, l_dict, prefix, start, end, kv_start, kv_end, unite_qkv):
-         """ prepares the dictionary of weights of a quantized attention block """
-         wq, wq_scale, wq_offset = self._get_quant_triplet_from_dict(l_dict, "wq")
-         wk, wk_scale, wk_offset = self._get_quant_triplet_from_dict(l_dict, "wk")
-         wv, wv_scale, wv_offset = self._get_quant_triplet_from_dict(l_dict, "wv")
-         wo, wo_scale, wo_offset = self._get_quant_triplet_from_dict(l_dict, "wo")
-         self.dictionary[prefix + ".attention.wo.weight"] = Tensor(ConvertTensor(wo[:, start:end], nd2nz=False))
-         self.dictionary[prefix + ".attention.wo.weight.scale"] = Tensor(wo_scale[start:end])
-         self.dictionary[prefix + ".attention.wo.weight.offset"] = Tensor(wo_offset[start:end])
-
-         if unite_qkv:
-             self.dictionary[prefix + ".attention.wqkv.weight"] = \
-                 Tensor(ConvertTensor(np.concatenate((wq[start:end, :], wk[kv_start:kv_end, :], wv[kv_start:kv_end, :])),
-                                      nd2nz=False))
-             self.dictionary[prefix + ".attention.wqkv.weight.scale"] = \
-                 Tensor(np.concatenate((wq_scale[start:end], wk_scale[kv_start:kv_end], wv_scale[kv_start:kv_end])))
-             self.dictionary[prefix + ".attention.wqkv.weight.offset"] = \
-                 Tensor(np.concatenate((wq_offset[start:end], wk_offset[kv_start:kv_end], wv_offset[kv_start:kv_end])))
-         else:
-             self.dictionary[prefix + ".attention.wq.weight"] = Tensor(ConvertTensor(wq[start:end, :], nd2nz=False))
-             self.dictionary[prefix + ".attention.wq.weight.scale"] = Tensor(wq_scale[start:end])
-             self.dictionary[prefix + ".attention.wq.weight.offset"] = Tensor(wq_offset[start:end])
-             self.dictionary[prefix + ".attention.wkv.weight"] = \
-                 Tensor(ConvertTensor(np.concatenate((wk[kv_start:kv_end, :], wv[kv_start:kv_end, :])), nd2nz=False))
-             self.dictionary[prefix + ".attention.wkv.weight.scale"] = \
-                 Tensor(np.concatenate((wk_scale[kv_start:kv_end], wv_scale[kv_start:kv_end])))
-             self.dictionary[prefix + ".attention.wkv.weight.offset"] = \
-                 Tensor(np.concatenate((wk_offset[kv_start:kv_end], wv_offset[kv_start:kv_end])))
-
-     def _pack_ffn_quant_weights(self, l_dict, prefix, ffn_start, ffn_end):
-         """ prepares the dictionary of weights of a quantized ffn block """
-         w1, w1_scale, w1_offset = self._get_quant_triplet_from_dict(l_dict, "w1")
-         w2, w2_scale, w2_offset = self._get_quant_triplet_from_dict(l_dict, "w2")
-         w3, w3_scale, w3_offset = self._get_quant_triplet_from_dict(l_dict, "w3")
-         self.dictionary[prefix + ".feed_forward.w2.weight"] = Tensor(ConvertTensor(w2[:, ffn_start:ffn_end],
-                                                                                    nd2nz=False))
-         self.dictionary[prefix + ".feed_forward.w2.weight.scale"] = Tensor(w2_scale[ffn_start:ffn_end])
-         self.dictionary[prefix + ".feed_forward.w2.weight.offset"] = Tensor(w2_offset[ffn_start:ffn_end])
-
-         self.dictionary[prefix + ".feed_forward.w13.weight"] = \
-             Tensor(ConvertTensor(np.concatenate((w1[ffn_start:ffn_end, :], w3[ffn_start:ffn_end, :])), nd2nz=False))
-         self.dictionary[prefix + ".feed_forward.w13.weight.scale"] = \
-             Tensor(np.concatenate((w1_scale[ffn_start:ffn_end], w3_scale[ffn_start:ffn_end])))
-         self.dictionary[prefix + ".feed_forward.w13.weight.offset"] = \
-             Tensor(np.concatenate((w1_offset[ffn_start:ffn_end], w3_offset[ffn_start:ffn_end])))
-
-     def _prepare_cos_sin_arrays(self, config, theta=10000):
-         """ prepares the cosine and sine arrays """
-         head_dim = config.hidden_size // config.num_heads
-         max_position_embedding = \
-             config.max_position_embedding if config.max_position_embedding is not None else config.seq_length
-         freqs_base = np.arange(0, head_dim, 2)[: (head_dim // 2)].astype(np.float32)
-         freqs = 1.0 / (theta ** (freqs_base / head_dim))
-         t = np.arange(0, max_position_embedding, 1).astype(np.float32)
-         freqs = np.outer(t, freqs)
-         emb = np.concatenate((freqs, freqs), axis=-1)
-         freqs_cos = Tensor(np.cos(emb), dtype=dtype.float16)
-         sin = np.sin(emb)
-
-         sin[:, :int(emb.shape[1]/2)] = -sin[:, :int(emb.shape[1]/2)]
-         self.dictionary['model.cos.weight'] = freqs_cos
-         freqs_sin = Tensor(sin, dtype=dtype.float16)
-         self.dictionary['model.sin.weight'] = freqs_sin
-
-     def set_weights(self, ckpt_dict):
-         """ load the checkpoint """
-         self.dictionary = {}
-         self.dictionary['model.tok_embeddings.embedding_weight'] = \
-             Tensor(ckpt_dict['model.tok_embeddings.embedding_weight'].asnumpy())
-         self.dictionary['model.norm_out.weight'] = \
-             Tensor(ckpt_dict['model.norm_out.weight'].asnumpy(), dtype=dtype.float16)
-         self._prepare_cos_sin_arrays(self.config)
-         for layer_id in range(self.config.num_layers):
-             self._prepare_single_layer(ckpt_dict, self.config, layer_id)
-
-         self.binder.set_weights_map(self.dictionary)
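
For reference, here is a standalone NumPy-only sketch of the Nd -> Nz layout transform performed by the removed ConvertTensor helper above (pad both dimensions up to multiples of 16, then regroup the columns into 16-wide blocks). This is illustrative only, not a MindSpore API; the 20x40 demo matrix is made up.

    import numpy as np

    def round_up(val: int, align: int) -> int:
        # Round val up to the next multiple of align (0 if align is 0), as RoundUp does above.
        if align == 0:
            return 0
        return -(val // -align) * align

    def nd_to_nz(nd_mat: np.ndarray, transpose: bool = True) -> np.ndarray:
        # Optionally transpose, zero-pad rows and columns up to multiples of 16,
        # then split the columns into 16-wide blocks and stack those blocks row-major.
        if transpose:
            nd_mat = nd_mat.T
        r = round_up(nd_mat.shape[0], 16)
        c = round_up(nd_mat.shape[1], 16)
        nd_mat = np.pad(nd_mat, ((0, r - nd_mat.shape[0]), (0, c - nd_mat.shape[1])))
        nz = np.transpose(nd_mat.reshape(r, c // 16, 16), (1, 0, 2))
        return nz.reshape(r, c)

    w = np.arange(20 * 40, dtype=np.float32).reshape(20, 40)  # made-up weight matrix
    print(nd_to_nz(w).shape)  # (48, 32): after the transpose, 40 rows pad to 48 and 20 columns pad to 32
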
mindspore/experimental/llm_boost/ascend_native/llm_boost.py (deleted)
@@ -1,52 +0,0 @@
- # Copyright 2024 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """LLMBoost APIs."""
-
- from mindspore.common import Tensor
-
- class LLMBoost():
-     r"""
-     Implements an LLM in a single kernel.
-     it forwards the python function to the C++ binded object
-     """
-     def __init__(self, config):
-         r"""
-         initialize the parameters of the llm binder.
-         config is simply the config object of the model
-         """
-         from mindspore._c_expression import LlmBoostBinder
-         self.config = config
-         self.binder = LlmBoostBinder("AscendNative", config.model_type)
-         self.binder.init_model(config.to_dict())
-
-     def init(self):
-         """
-         Initialize the object
-         returns True if object needs input manipulation by mindformers
-         """
-         return False
-
-     def set_kvcache(self, k_caches=None, v_caches=None):
-         return
-
-     def forward(self, input_ids, batch_valid_length, position_ids=None):
-         ret = self.binder.forward([input_ids, batch_valid_length], "nothing really")
-         return Tensor(ret[0])
-
-     def set_weights(self, ckpt_dict):
-         self.binder.set_weights_map(ckpt_dict)
-
-     def add_flags(self, is_first_iteration=False):
-         self.binder.add_flags(is_first_iteration=is_first_iteration)
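
The removed llm_boost.py above is a thin facade: every public method forwards to a backend binder created from mindspore._c_expression.LlmBoostBinder. As a minimal sketch of that delegation pattern only (with a made-up stand-in binder, since the real one is a C++ binding), the wrapper's surface looked roughly like this:

    class _StubBinder:
        # Hypothetical stand-in for the C++ LlmBoostBinder used above.
        def __init__(self, backend, model_type):
            self.backend, self.model_type = backend, model_type

        def init_model(self, cfg):
            print(f"init {self.model_type} on {self.backend}: {cfg}")

        def forward(self, inputs, param):
            return [sum(inputs)]  # placeholder for real logits

        def set_weights_map(self, weights):
            print(f"loaded {len(weights)} weights")

    class BoostFacade:
        # Mirrors the removed wrapper's shape: construct a binder, then delegate every call.
        def __init__(self, backend, model_type, cfg):
            self.binder = _StubBinder(backend, model_type)
            self.binder.init_model(cfg)

        def set_weights(self, ckpt_dict):
            self.binder.set_weights_map(ckpt_dict)

        def forward(self, input_ids, batch_valid_length):
            return self.binder.forward([input_ids, batch_valid_length], "nothing really")[0]

    boost = BoostFacade("AscendNative", "Llama", {"num_layers": 2})
    boost.set_weights({"model.norm_out.weight": None})
    print(boost.forward(3, 4))  # 7 with the stub binder
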
mindspore/experimental/llm_boost/atb/boost_base.py (deleted)
@@ -1,385 +0,0 @@
- # Copyright 2024 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """boost base class"""
- from enum import Enum
- import numpy as np
- import mindspore as ms
- from mindspore import ops, Tensor
- from mindspore import log as logger
- import mindspore.common.dtype as mstype
- from mindspore._c_expression import _set_format
- from mindspore.common.parameter import Parameter
- from mindspore.experimental.llm_boost.utils import get_real_rank, get_real_group_size
- from mindspore.common.initializer import Zero
-
- FORMAT_NZ = "FRACTAL_NZ"
- BUILDIN_BACKEND_NAME = "ATB"
-
-
- class PositionEmbeddingType(int, Enum):
-     ROPE = 0
-     ALIBI = 1
-     ABSOLUTE = 2
-
-
- class NormType(int, Enum):
-     RMS_NORM = 0
-     LAYER_NORM = 1
-
-
- class AttentionMask:
-     """attention mask"""
-
-     @classmethod
-     def static(cls, max_seq_len, dtype=mstype.float16, need_nz=False):
-         """cache mask"""
-         bias_cache = Tensor(
-             np.tril(np.ones((max_seq_len, max_seq_len), dtype=np.bool_))
-         ).reshape(max_seq_len, max_seq_len)
-         bias_cache = ~bias_cache
-         if dtype == mstype.float16:
-             mask_value = Tensor(np.finfo(np.float32).min, mstype.float16)
-         else:
-             mask_value = Tensor(1)
-         attn_mask = ops.masked_fill(
-             Tensor(np.zeros((max_seq_len, max_seq_len)), dtype=mstype.float16),
-             bias_cache,
-             mask_value,
-         )
-         if need_nz:
-             # ND -> NZ
-             attn_mask = ops.reshape(attn_mask, (1, max_seq_len, max_seq_len))
-             attn_mask = ops.reshape(attn_mask, (1, max_seq_len, max_seq_len // 16, 16))
-             attn_mask = ops.transpose(attn_mask, (0, 2, 1, 3)).contiguous()
-             attn_mask = _set_format(attn_mask, FORMAT_NZ)
-         return attn_mask
-
-
- class AtbBoostBase:
-     """atb boost base class"""
-
-     def __init__(self, config):
-         super().__init__()
-         self.backend_name = BUILDIN_BACKEND_NAME
-         self.is_first_iteration = False
-         self.config = config
-         self.dtype = config.compute_dtype
-         self.num_heads = config.num_heads
-         self.num_kv_heads = config.n_kv_heads if config.n_kv_heads else self.num_heads
-         self.num_layers = config.num_layers
-         self.n_kv_heads = config.n_kv_heads if config.n_kv_heads else config.num_heads
-         self.head_dim = config.hidden_size // self.num_heads
-         self.need_nz = False
-         if hasattr(config, "need_nz"):
-             self.need_nz = config.need_nz
-         self.placeholder = Tensor(np.zeros(1), dtype=self.dtype)
-         self.lm_head_indices_fake = Tensor([0], dtype=mstype.int64)
-         self.position_embedding_type = PositionEmbeddingType.ROPE
-         self.add_norm_enable = True
-         self.max_decode_length = self.config.max_decode_length
-         self.max_base_len = 128
-         self.attn_mask = AttentionMask.static(
-             self.max_base_len, dtype=self.dtype, need_nz=self.need_nz
-         )
-
-         self.cast = ops.Cast()
-         self.reshape = ops.Reshape()
-         self.kv_quant = None
-         self.rank_id = get_real_rank()
-         self.device_num = get_real_group_size()
-         self.ascend_weight = []
-         self.k_caches = []
-         self.v_caches = []
-
-     def _convert_tensor_format_and_dtype(self, tensor, dtype=mstype.float16):
-         tensor = self.cast(tensor, dtype=dtype)
-         if self.need_nz:
-             tensor = _set_format(tensor, FORMAT_NZ)
-         return tensor
-
-     def _convert_qkv_concat_weight(self, param_dict):
-         """convert qkv concat weight"""
-         for i in range(self.num_layers):
-             # qkv weight concat
-             wq_weight_name = f"model.layers.{i}.attention.wq.weight"
-             wk_weight_name = f"model.layers.{i}.attention.wk.weight"
-             wv_weight_name = f"model.layers.{i}.attention.wv.weight"
-             qkv_concat_weight_name = f"model.layers.{i}.attention.w_qkv.weight"
-             if wq_weight_name not in param_dict:
-                 break
-             wq_weight = param_dict[wq_weight_name].asnumpy()
-             wk_weight = param_dict[wk_weight_name].asnumpy()
-             wv_weight = param_dict[wv_weight_name].asnumpy()
-             qkv_weight = np.concatenate((wq_weight, wk_weight, wv_weight), 0)
-             param_dict[qkv_concat_weight_name] = Parameter(
-                 qkv_weight, name=qkv_concat_weight_name
-             )
-
-             # gate hidden weight concat
-             ffn_gate_weight_name = f"model.layers.{i}.feed_forward.w1.weight"
-             ffn_hidden_weight_name = f"model.layers.{i}.feed_forward.w3.weight"
-             gate_hidden_concat_weight_name = (
-                 f"model.layers.{i}.feed_forward.w_gate_hidden.weight"
-             )
-
-             ffn_gate_weight = param_dict[ffn_gate_weight_name].asnumpy()
-             ffn_hidden_weight = param_dict[ffn_hidden_weight_name].asnumpy()
-             gate_hidden_weight = np.concatenate((ffn_gate_weight, ffn_hidden_weight), 0)
-             param_dict[gate_hidden_concat_weight_name] = Parameter(
-                 gate_hidden_weight, name=gate_hidden_concat_weight_name
-             )
-
-             param_dict.pop(wq_weight_name)
-             param_dict.pop(wk_weight_name)
-             param_dict.pop(wv_weight_name)
-             param_dict.pop(ffn_gate_weight_name)
-             param_dict.pop(ffn_hidden_weight_name)
-             logger.info(f"transform: {qkv_concat_weight_name}")
-             logger.info(f"transform: {gate_hidden_concat_weight_name}")
-
-         for i in range(self.num_layers):
-             # qkv bias concat
-             wq_bias_name = f"model.layers.{i}.attention.wq.bias"
-             wk_bias_name = f"model.layers.{i}.attention.wk.bias"
-             wv_bias_name = f"model.layers.{i}.attention.wv.bias"
-             qkv_concat_bias_name = f"model.layers.{i}.attention.w_qkv.bias"
-             if wq_bias_name not in param_dict:
-                 break
-
-             wq_bias_weight = param_dict[wq_bias_name].asnumpy()
-             wk_bias_weight = param_dict[wk_bias_name].asnumpy()
-             wv_bias_weight = param_dict[wv_bias_name].asnumpy()
-             qkv_bias_weight = np.concatenate(
-                 (wq_bias_weight, wk_bias_weight, wv_bias_weight), 0
-             )
-             param_dict[qkv_concat_bias_name] = Parameter(
-                 qkv_bias_weight, name=qkv_concat_bias_name
-             )
-
-             param_dict.pop(wq_bias_name)
-             param_dict.pop(wk_bias_name)
-             param_dict.pop(wv_bias_name)
-             logger.info(f"transform: {qkv_concat_bias_name}")
-         return param_dict
-
-     def set_weights(self, parm_dict, dtype=mstype.float16):
-         """set weights for llm boost"""
-         self._convert_qkv_concat_weight(parm_dict)
-         embedding_weight_name = "model.tok_embeddings.embedding_weight"
-         attention_norm_name = "attention_norm"
-         qkv_name = "attention.w_qkv"
-         o_name = "attention.wo"
-         mlp_norm_name = "ffn_norm"
-         mlp_gate_name = "feed_forward.w_gate_hidden"
-         mlp_down_name = "feed_forward.w2"
-         norm_out_name = "model.norm_out"
-         lm_head_name = "lm_head"
-         placeholder = Parameter(Tensor(np.zeros(1), dtype=dtype))
-
-         ascend_weight = []
-         ascend_weight.append(self.cast(parm_dict[embedding_weight_name], dtype))
-         for i in range(self.num_layers):
-             ascend_weight.append(
-                 self._convert_tensor_format_and_dtype(
-                     parm_dict[f"model.layers.{i}.{attention_norm_name}.weight"], dtype
-                 )
-             )
-             ascend_weight.extend([placeholder] * 3)
-
-             ascend_weight.append(
-                 self._convert_tensor_format_and_dtype(
-                     parm_dict[f"model.layers.{i}.{qkv_name}.weight"], dtype
-                 )
-             )
-             ascend_weight.append(
-                 self._convert_tensor_format_and_dtype(
-                     parm_dict.get(f"model.layers.{i}.{qkv_name}.bias", placeholder),
-                     dtype,
-                 )
-             )
-             ascend_weight.extend([placeholder] * 16)
-
-             ascend_weight.append(
-                 self._convert_tensor_format_and_dtype(
-                     parm_dict[f"model.layers.{i}.{o_name}.weight"], dtype
-                 )
-             )
-             ascend_weight.append(
-                 self._convert_tensor_format_and_dtype(
-                     parm_dict.get(f"model.layers.{i}.{o_name}.bias", placeholder), dtype
-                 )
-             )
-             ascend_weight.extend([placeholder] * 4)
-
-             ascend_weight.append(
-                 self._convert_tensor_format_and_dtype(
-                     parm_dict[f"model.layers.{i}.{mlp_norm_name}.weight"], dtype
-                 )
-             )
-             ascend_weight.extend([placeholder] * 3)
-
-             ascend_weight.append(
-                 self._convert_tensor_format_and_dtype(
-                     parm_dict[f"model.layers.{i}.{mlp_gate_name}.weight"], dtype
-                 )
-             )
-             ascend_weight.append(
-                 self._convert_tensor_format_and_dtype(
-                     parm_dict.get(
-                         f"model.layers.{i}.{mlp_gate_name}.bias", placeholder
-                     ),
-                     dtype,
-                 )
-             )
-             ascend_weight.extend([placeholder] * 10)
-
-             ascend_weight.append(
-                 self._convert_tensor_format_and_dtype(
-                     parm_dict[f"model.layers.{i}.{mlp_down_name}.weight"], dtype
-                 )
-             )
-             ascend_weight.append(
-                 self._convert_tensor_format_and_dtype(
-                     parm_dict.get(
-                         f"model.layers.{i}.{mlp_down_name}.bias", placeholder
-                     ),
-                     dtype,
-                 )
-             )
-             ascend_weight.extend([placeholder] * 4)
-
-         ascend_weight.append(
-             self._convert_tensor_format_and_dtype(
-                 parm_dict[f"{norm_out_name}.weight"], dtype
-             )
-         )
-         ascend_weight.append(
-             self._convert_tensor_format_and_dtype(
-                 parm_dict[f"{lm_head_name}.weight"], dtype
-             )
-         )
-         self.ascend_weight = ascend_weight
-         self.atb_encoder_operation.set_weights(ascend_weight)
-         self.atb_decoder_operation.set_weights(ascend_weight)
-
-     def set_kvcache(self, k_caches=None, v_caches=None):
-         """set kv_cache for llm boost"""
-         if not k_caches or v_caches:
-             if self.need_nz:
-                 kv_shape = (
-                     self.config.num_blocks,
-                     self.num_kv_heads * self.head_dim // self.device_num // 16,
-                     self.config.block_size,
-                     16,
-                 )
-                 k_caches = [
-                     _set_format(
-                         Parameter(
-                             Tensor(shape=kv_shape, dtype=self.dtype, init=Zero())
-                         ),
-                         FORMAT_NZ,
-                     )
-                     for _ in range(self.num_layers)
-                 ]
-                 v_caches = [
-                     _set_format(
-                         Parameter(
-                             Tensor(shape=kv_shape, dtype=self.dtype, init=Zero())
-                         ),
-                         FORMAT_NZ,
-                     )
-                     for _ in range(self.num_layers)
-                 ]
-             else:
-                 kv_shape = (
-                     self.config.num_blocks,
-                     self.config.block_size,
-                     self.num_kv_heads // self.device_num,
-                     self.head_dim,
-                 )
-                 k_caches = [
-                     Parameter(Tensor(shape=kv_shape, dtype=self.dtype, init=Zero()))
-                     for _ in range(self.num_layers)
-                 ]
-                 v_caches = [
-                     Parameter(Tensor(shape=kv_shape, dtype=self.dtype, init=Zero()))
-                     for _ in range(self.num_layers)
-                 ]
-         self.k_caches = k_caches
-         self.v_caches = v_caches
-         self.atb_encoder_operation.set_kvcache(k_caches, v_caches)
-         self.atb_decoder_operation.set_kvcache(k_caches, v_caches)
-
-     def add_flags(self, is_first_iteration):
-         """add_flags."""
-         self.is_first_iteration = is_first_iteration
-
-     def _execute_operator(self, acl_inputs, acl_param):
-         """execute operator."""
-         if self.is_first_iteration:
-             acl_model_out = self.atb_encoder_operation.forward(acl_inputs, acl_param)
-         else:
-             acl_model_out = self.atb_decoder_operation.forward(acl_inputs, acl_param)
-         acl_hidden_state = acl_model_out[0]
-         return acl_hidden_state
-
-     def forward(self, boost_inputs):
-         r"""
-         LlmBoost forward.
-         """
-         input_ids = boost_inputs.get("input_ids", None)
-         position_ids = boost_inputs.get("position_ids", None)
-         cos_embed = boost_inputs.get("cos_embed", None)
-         sin_embed = boost_inputs.get("sin_embed", None)
-         block_tables = boost_inputs.get("block_tables", None)
-         slot_mapping = boost_inputs.get("slot_mapping", None)
-         batch_valid_length = boost_inputs.get("batch_valid_length", None)
-         lm_head_indices = boost_inputs.get("lm_head_indices", None)
-         seqLen = boost_inputs.get("seq_lens", None)
-         input_ids = self.reshape(input_ids, (-1,))
-         if self.is_first_iteration:
-             attention_mask = self.attn_mask
-         else:
-             if position_ids is None:
-                 position_ids = batch_valid_length - 1
-             attention_mask = self.placeholder
-             lm_head_indices = self.lm_head_indices_fake
-
-         if input_ids is not None and input_ids.dtype != mstype.int64:
-             input_ids = self.cast(input_ids, mstype.int64)
-         if position_ids is not None and position_ids.dtype != mstype.int64:
-             position_ids = self.cast(position_ids, mstype.int64)
-         if batch_valid_length is not None and batch_valid_length.dtype != mstype.int32:
-             batch_valid_length = self.cast(batch_valid_length, mstype.int32)
-         if lm_head_indices is not None and lm_head_indices.dtype != mstype.int64:
-             lm_head_indices = self.cast(lm_head_indices, mstype.int64)
-
-         acl_inputs, acl_param = self._prepare_inputs(
-             prefill=self.is_first_iteration,
-             input_ids=input_ids,
-             position_ids=position_ids,
-             cos_embed=cos_embed,
-             sin_embed=sin_embed,
-             attention_mask=attention_mask,
-             block_tables=block_tables,
-             slots=slot_mapping,
-             input_lengths=batch_valid_length,
-             lm_head_indices=lm_head_indices,
-             seqLen=seqLen,
-         )
-         ms.hal.synchronize()
-         logits = self._execute_operator(acl_inputs, acl_param)
-         logits = self.cast(logits, mstype.float32)
-         return logits
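
The checkpoint repacking at the heart of the removed AtbBoostBase._convert_qkv_concat_weight can be illustrated with a NumPy-only sketch: per layer, the separate q/k/v projection weights are stacked into a single w_qkv matrix and the w1/w3 feed-forward weights into a single w_gate_hidden matrix. Key names follow the removed code; the toy shapes below are made up.

    import numpy as np

    def concat_qkv_and_gate_hidden(param_dict, num_layers):
        # Stack per-layer q/k/v and gate/hidden weights along axis 0, dropping the originals.
        for i in range(num_layers):
            wq = param_dict.pop(f"model.layers.{i}.attention.wq.weight")
            wk = param_dict.pop(f"model.layers.{i}.attention.wk.weight")
            wv = param_dict.pop(f"model.layers.{i}.attention.wv.weight")
            param_dict[f"model.layers.{i}.attention.w_qkv.weight"] = np.concatenate((wq, wk, wv), 0)

            w1 = param_dict.pop(f"model.layers.{i}.feed_forward.w1.weight")
            w3 = param_dict.pop(f"model.layers.{i}.feed_forward.w3.weight")
            param_dict[f"model.layers.{i}.feed_forward.w_gate_hidden.weight"] = np.concatenate((w1, w3), 0)
        return param_dict

    params = {
        "model.layers.0.attention.wq.weight": np.zeros((8, 8)),
        "model.layers.0.attention.wk.weight": np.zeros((4, 8)),
        "model.layers.0.attention.wv.weight": np.zeros((4, 8)),
        "model.layers.0.feed_forward.w1.weight": np.zeros((16, 8)),
        "model.layers.0.feed_forward.w3.weight": np.zeros((16, 8)),
    }
    params = concat_qkv_and_gate_hidden(params, num_layers=1)
    print(params["model.layers.0.attention.w_qkv.weight"].shape)             # (16, 8)
    print(params["model.layers.0.feed_forward.w_gate_hidden.weight"].shape)  # (32, 8)
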