mindspore-2.3.0-cp310-cp310-win_amd64.whl → mindspore-2.4.0-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (308)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +3 -1
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +50 -9
  9. mindspore/_extends/parse/compile_config.py +41 -0
  10. mindspore/_extends/parse/parser.py +9 -7
  11. mindspore/_extends/parse/standard_method.py +52 -14
  12. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  13. mindspore/amp.py +24 -10
  14. mindspore/atlprov.dll +0 -0
  15. mindspore/avcodec-59.dll +0 -0
  16. mindspore/avdevice-59.dll +0 -0
  17. mindspore/avfilter-8.dll +0 -0
  18. mindspore/avformat-59.dll +0 -0
  19. mindspore/avutil-57.dll +0 -0
  20. mindspore/c1.dll +0 -0
  21. mindspore/c1xx.dll +0 -0
  22. mindspore/c2.dll +0 -0
  23. mindspore/common/__init__.py +6 -4
  24. mindspore/common/_pijit_context.py +190 -0
  25. mindspore/common/_register_for_tensor.py +2 -1
  26. mindspore/common/_tensor_overload.py +139 -0
  27. mindspore/common/api.py +102 -87
  28. mindspore/common/dump.py +5 -6
  29. mindspore/common/generator.py +1 -7
  30. mindspore/common/hook_handle.py +14 -26
  31. mindspore/common/mindir_util.py +2 -2
  32. mindspore/common/parameter.py +46 -13
  33. mindspore/common/recompute.py +39 -9
  34. mindspore/common/sparse_tensor.py +7 -3
  35. mindspore/common/tensor.py +209 -29
  36. mindspore/communication/__init__.py +1 -1
  37. mindspore/communication/_comm_helper.py +38 -3
  38. mindspore/communication/comm_func.py +310 -55
  39. mindspore/communication/management.py +14 -14
  40. mindspore/context.py +123 -22
  41. mindspore/dataset/__init__.py +1 -1
  42. mindspore/dataset/audio/__init__.py +1 -1
  43. mindspore/dataset/core/config.py +7 -0
  44. mindspore/dataset/core/validator_helpers.py +7 -0
  45. mindspore/dataset/engine/cache_client.py +1 -1
  46. mindspore/dataset/engine/datasets.py +72 -44
  47. mindspore/dataset/engine/datasets_audio.py +7 -7
  48. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  49. mindspore/dataset/engine/datasets_text.py +20 -20
  50. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  51. mindspore/dataset/engine/datasets_vision.py +33 -33
  52. mindspore/dataset/engine/iterators.py +29 -0
  53. mindspore/dataset/engine/obs/util.py +7 -0
  54. mindspore/dataset/engine/queue.py +114 -60
  55. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  56. mindspore/dataset/engine/validators.py +34 -14
  57. mindspore/dataset/text/__init__.py +1 -4
  58. mindspore/dataset/transforms/__init__.py +0 -3
  59. mindspore/dataset/utils/line_reader.py +2 -0
  60. mindspore/dataset/vision/__init__.py +1 -4
  61. mindspore/dataset/vision/utils.py +1 -1
  62. mindspore/dataset/vision/validators.py +2 -1
  63. mindspore/dnnl.dll +0 -0
  64. mindspore/dpcmi.dll +0 -0
  65. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  66. mindspore/experimental/es/embedding_service.py +883 -0
  67. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  68. mindspore/experimental/llm_boost/__init__.py +21 -0
  69. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  70. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  71. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  72. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  73. mindspore/experimental/llm_boost/register.py +129 -0
  74. mindspore/experimental/llm_boost/utils.py +31 -0
  75. mindspore/experimental/optim/adamw.py +85 -0
  76. mindspore/experimental/optim/optimizer.py +3 -0
  77. mindspore/hal/__init__.py +3 -3
  78. mindspore/hal/contiguous_tensors_handle.py +175 -0
  79. mindspore/hal/stream.py +18 -0
  80. mindspore/include/api/model_group.h +13 -1
  81. mindspore/include/api/types.h +10 -10
  82. mindspore/include/dataset/config.h +2 -2
  83. mindspore/include/dataset/constants.h +2 -2
  84. mindspore/include/dataset/execute.h +2 -2
  85. mindspore/include/dataset/vision.h +4 -0
  86. mindspore/jpeg62.dll +0 -0
  87. mindspore/log.py +1 -1
  88. mindspore/mindrecord/filewriter.py +68 -51
  89. mindspore/mindspore_backend.dll +0 -0
  90. mindspore/mindspore_common.dll +0 -0
  91. mindspore/mindspore_core.dll +0 -0
  92. mindspore/mindspore_glog.dll +0 -0
  93. mindspore/mindspore_np_dtype.dll +0 -0
  94. mindspore/mindspore_ops.dll +0 -0
  95. mindspore/mint/__init__.py +495 -46
  96. mindspore/mint/distributed/__init__.py +31 -0
  97. mindspore/mint/distributed/distributed.py +254 -0
  98. mindspore/mint/nn/__init__.py +266 -21
  99. mindspore/mint/nn/functional.py +125 -19
  100. mindspore/mint/nn/layer/__init__.py +39 -0
  101. mindspore/mint/nn/layer/activation.py +133 -0
  102. mindspore/mint/nn/layer/normalization.py +477 -0
  103. mindspore/mint/nn/layer/pooling.py +110 -0
  104. mindspore/mint/optim/adamw.py +28 -7
  105. mindspore/mint/special/__init__.py +63 -0
  106. mindspore/msobj140.dll +0 -0
  107. mindspore/mspdb140.dll +0 -0
  108. mindspore/mspdbcore.dll +0 -0
  109. mindspore/mspdbst.dll +0 -0
  110. mindspore/mspft140.dll +0 -0
  111. mindspore/msvcdis140.dll +0 -0
  112. mindspore/msvcp140_1.dll +0 -0
  113. mindspore/msvcp140_2.dll +0 -0
  114. mindspore/msvcp140_atomic_wait.dll +0 -0
  115. mindspore/msvcp140_codecvt_ids.dll +0 -0
  116. mindspore/multiprocessing/__init__.py +2 -1
  117. mindspore/nn/__init__.py +0 -1
  118. mindspore/nn/cell.py +275 -93
  119. mindspore/nn/layer/activation.py +211 -44
  120. mindspore/nn/layer/basic.py +113 -3
  121. mindspore/nn/layer/embedding.py +120 -2
  122. mindspore/nn/layer/normalization.py +101 -5
  123. mindspore/nn/layer/padding.py +34 -48
  124. mindspore/nn/layer/pooling.py +161 -7
  125. mindspore/nn/layer/transformer.py +3 -3
  126. mindspore/nn/loss/__init__.py +2 -2
  127. mindspore/nn/loss/loss.py +84 -6
  128. mindspore/nn/optim/__init__.py +2 -1
  129. mindspore/nn/optim/adadelta.py +1 -1
  130. mindspore/nn/optim/adam.py +1 -1
  131. mindspore/nn/optim/lamb.py +1 -1
  132. mindspore/nn/optim/tft_wrapper.py +127 -0
  133. mindspore/nn/wrap/cell_wrapper.py +12 -23
  134. mindspore/nn/wrap/grad_reducer.py +5 -5
  135. mindspore/nn/wrap/loss_scale.py +17 -3
  136. mindspore/numpy/__init__.py +1 -1
  137. mindspore/numpy/array_creations.py +65 -68
  138. mindspore/numpy/array_ops.py +64 -60
  139. mindspore/numpy/fft.py +610 -75
  140. mindspore/numpy/logic_ops.py +11 -10
  141. mindspore/numpy/math_ops.py +85 -84
  142. mindspore/numpy/utils_const.py +4 -4
  143. mindspore/opencv_core452.dll +0 -0
  144. mindspore/opencv_imgcodecs452.dll +0 -0
  145. mindspore/opencv_imgproc452.dll +0 -0
  146. mindspore/ops/__init__.py +6 -4
  147. mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
  148. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  149. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  150. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  151. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  152. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
  153. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  154. mindspore/ops/auto_generate/gen_extend_func.py +734 -13
  155. mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
  156. mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
  157. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  158. mindspore/ops/composite/base.py +85 -48
  159. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  160. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  161. mindspore/ops/function/__init__.py +22 -0
  162. mindspore/ops/function/array_func.py +490 -153
  163. mindspore/ops/function/debug_func.py +113 -1
  164. mindspore/ops/function/fft_func.py +15 -2
  165. mindspore/ops/function/grad/grad_func.py +3 -2
  166. mindspore/ops/function/math_func.py +558 -207
  167. mindspore/ops/function/nn_func.py +817 -383
  168. mindspore/ops/function/other_func.py +3 -2
  169. mindspore/ops/function/random_func.py +184 -8
  170. mindspore/ops/function/reshard_func.py +13 -11
  171. mindspore/ops/function/sparse_unary_func.py +1 -1
  172. mindspore/ops/function/vmap_func.py +3 -2
  173. mindspore/ops/functional.py +24 -14
  174. mindspore/ops/op_info_register.py +3 -3
  175. mindspore/ops/operations/__init__.py +6 -1
  176. mindspore/ops/operations/_grad_ops.py +2 -76
  177. mindspore/ops/operations/_infer_ops.py +1 -1
  178. mindspore/ops/operations/_inner_ops.py +71 -94
  179. mindspore/ops/operations/array_ops.py +12 -146
  180. mindspore/ops/operations/comm_ops.py +42 -53
  181. mindspore/ops/operations/custom_ops.py +83 -19
  182. mindspore/ops/operations/debug_ops.py +42 -10
  183. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  184. mindspore/ops/operations/manually_defined/ops_def.py +265 -10
  185. mindspore/ops/operations/math_ops.py +12 -223
  186. mindspore/ops/operations/nn_ops.py +20 -114
  187. mindspore/ops/operations/other_ops.py +7 -4
  188. mindspore/ops/operations/random_ops.py +46 -1
  189. mindspore/ops/primitive.py +18 -6
  190. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  191. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  192. mindspore/ops_generate/gen_constants.py +36 -0
  193. mindspore/ops_generate/gen_ops.py +67 -52
  194. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  195. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  196. mindspore/ops_generate/op_proto.py +10 -3
  197. mindspore/ops_generate/pyboost_utils.py +14 -1
  198. mindspore/ops_generate/template.py +43 -21
  199. mindspore/parallel/__init__.py +3 -1
  200. mindspore/parallel/_auto_parallel_context.py +28 -8
  201. mindspore/parallel/_cell_wrapper.py +83 -0
  202. mindspore/parallel/_parallel_serialization.py +47 -19
  203. mindspore/parallel/_tensor.py +81 -11
  204. mindspore/parallel/_utils.py +13 -1
  205. mindspore/parallel/algo_parameter_config.py +5 -5
  206. mindspore/parallel/checkpoint_transform.py +46 -39
  207. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  208. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  209. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  210. mindspore/parallel/parameter_broadcast.py +3 -4
  211. mindspore/parallel/shard.py +162 -31
  212. mindspore/parallel/transform_safetensors.py +993 -0
  213. mindspore/pgodb140.dll +0 -0
  214. mindspore/pgort140.dll +0 -0
  215. mindspore/profiler/__init__.py +2 -1
  216. mindspore/profiler/common/constant.py +29 -0
  217. mindspore/profiler/common/registry.py +47 -0
  218. mindspore/profiler/common/util.py +28 -0
  219. mindspore/profiler/dynamic_profiler.py +694 -0
  220. mindspore/profiler/envprofiling.py +17 -19
  221. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  222. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  223. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  224. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  225. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  226. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  227. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  228. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  229. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  230. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  231. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  232. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  233. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  234. mindspore/profiler/parser/framework_parser.py +1 -391
  235. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  236. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  237. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  238. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  239. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  240. mindspore/profiler/parser/profiler_info.py +78 -6
  241. mindspore/profiler/profiler.py +153 -0
  242. mindspore/profiler/profiling.py +280 -412
  243. mindspore/rewrite/__init__.py +1 -2
  244. mindspore/rewrite/common/namespace.py +4 -4
  245. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  246. mindspore/run_check/_check_version.py +36 -103
  247. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  248. mindspore/swresample-4.dll +0 -0
  249. mindspore/swscale-6.dll +0 -0
  250. mindspore/tbbmalloc.dll +0 -0
  251. mindspore/tinyxml2.dll +0 -0
  252. mindspore/train/__init__.py +4 -3
  253. mindspore/train/_utils.py +28 -2
  254. mindspore/train/amp.py +171 -53
  255. mindspore/train/callback/__init__.py +2 -2
  256. mindspore/train/callback/_callback.py +4 -4
  257. mindspore/train/callback/_checkpoint.py +85 -22
  258. mindspore/train/callback/_cluster_monitor.py +1 -1
  259. mindspore/train/callback/_flops_collector.py +1 -0
  260. mindspore/train/callback/_loss_monitor.py +3 -3
  261. mindspore/train/callback/_on_request_exit.py +134 -31
  262. mindspore/train/callback/_summary_collector.py +5 -5
  263. mindspore/train/callback/_tft_register.py +352 -0
  264. mindspore/train/dataset_helper.py +7 -3
  265. mindspore/train/metrics/metric.py +3 -3
  266. mindspore/train/metrics/roc.py +4 -4
  267. mindspore/train/mind_ir_pb2.py +44 -39
  268. mindspore/train/model.py +134 -58
  269. mindspore/train/serialization.py +336 -112
  270. mindspore/turbojpeg.dll +0 -0
  271. mindspore/utils/__init__.py +21 -0
  272. mindspore/utils/utils.py +60 -0
  273. mindspore/vcmeta.dll +0 -0
  274. mindspore/vcruntime140.dll +0 -0
  275. mindspore/vcruntime140_1.dll +0 -0
  276. mindspore/version.py +1 -1
  277. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
  278. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +281 -275
  279. mindspore/include/c_api/ms/abstract.h +0 -67
  280. mindspore/include/c_api/ms/attribute.h +0 -197
  281. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  282. mindspore/include/c_api/ms/base/macros.h +0 -32
  283. mindspore/include/c_api/ms/base/status.h +0 -33
  284. mindspore/include/c_api/ms/base/types.h +0 -283
  285. mindspore/include/c_api/ms/context.h +0 -102
  286. mindspore/include/c_api/ms/graph.h +0 -160
  287. mindspore/include/c_api/ms/node.h +0 -606
  288. mindspore/include/c_api/ms/tensor.h +0 -161
  289. mindspore/include/c_api/ms/value.h +0 -84
  290. mindspore/mindspore_shared_lib.dll +0 -0
  291. mindspore/nn/extend/basic.py +0 -140
  292. mindspore/nn/extend/embedding.py +0 -143
  293. mindspore/nn/extend/layer/normalization.py +0 -109
  294. mindspore/nn/extend/pooling.py +0 -117
  295. mindspore/nn/layer/embedding_service.py +0 -531
  296. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  297. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  298. mindspore/ops/extend/__init__.py +0 -53
  299. mindspore/ops/extend/array_func.py +0 -218
  300. mindspore/ops/extend/math_func.py +0 -76
  301. mindspore/ops/extend/nn_func.py +0 -308
  302. mindspore/ops/silent_check.py +0 -162
  303. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  304. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  305. mindspore/train/callback/_mindio_ttp.py +0 -443
  306. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  307. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
  308. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
@@ -19,8 +19,10 @@ import mindspore as ms
 from mindspore import nn, ops, Tensor, Parameter
 from mindspore.ops.auto_generate import init_partition_map, init_embedding_hashmap, embedding_table_find_and_init,\
     embedding_table_find, fake_remote_lookup_uniqued
-from mindspore.ops.operations.manually_defined import EmbeddingTableImport, EmbeddingTableExport, \
-    EmbeddingComputeVarImport, EmbeddingComputeVarExport
+from mindspore.ops.auto_generate import EmbeddingTableImport, EmbeddingTableExport, \
+    EmbeddingComputeVarImport, EmbeddingComputeVarExport, EmbeddingTableEvict, EmbeddingFeatureMappingV2, \
+    EmbeddingFeatureMappingTableSize, EmbeddingFeatureMappingFind, EmbeddingFeatureMappingExport, \
+    EmbeddingFeatureMappingFileSize, EmbeddingFeatureMappingImport, EmbeddingFeatureMappingInsert


 class CounterFilter:
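
In 2.4.0 the EmbeddingService import/export primitives move from mindspore.ops.operations.manually_defined to mindspore.ops.auto_generate, and the module now also pulls in the new evict and feature-mapping primitives. A minimal import sketch against the 2.4.0 layout, using only names visible in this hunk:

    # 2.4.0: everything comes from mindspore.ops.auto_generate
    # (in 2.3.0 the first two lived in mindspore.ops.operations.manually_defined)
    from mindspore.ops.auto_generate import (
        EmbeddingTableImport, EmbeddingTableExport,      # relocated
        EmbeddingTableEvict, EmbeddingFeatureMappingV2,  # new in 2.4.0
    )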
@@ -55,12 +57,14 @@ def _get_backward_float_params(optimizer_mode):
           [beta1_power, beta2_power, lr, weight_decay, beta1, beta2, epsilon]
         - when the backward_mode is 'adagrad', it means [lr,]
     """
-    if optimizer_mode == "adagrad":
+    if optimizer_mode == "adagrad" or optimizer_mode == "sgd":
         return [0.001]
     if optimizer_mode == "adam":
         return [0.9, 0.99, 0.001, 0.9, 0.999, 1e-08]
     if optimizer_mode == "ftrl":
         return [0.001, -0.5, 0.0, 0.0]
+    if optimizer_mode == "rmsprop":
+        return [0.001, 0.9, 0.1, 1e-08]
     # adamw
     return [0.9, 0.99, 0.001, 0.01, 0.9, 0.999, 1e-08]

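The helper now treats "sgd" like "adagrad" (a single learning-rate entry) and grows an "rmsprop" branch; unrecognized modes still fall through to the adamw defaults. A standalone mirror of the same dispatch, handy for eyeballing the defaults (a sketch, not the shipped function):

    # Mirror of _get_backward_float_params in 2.4.0 (sketch only)
    BACKWARD_FLOAT_DEFAULTS = {
        "adagrad": [0.001],                               # [lr]
        "sgd":     [0.001],                               # new: same shape as adagrad
        "adam":    [0.9, 0.99, 0.001, 0.9, 0.999, 1e-08],
        "ftrl":    [0.001, -0.5, 0.0, 0.0],
        "rmsprop": [0.001, 0.9, 0.1, 1e-08],              # new in 2.4.0
    }
    ADAMW_DEFAULTS = [0.9, 0.99, 0.001, 0.01, 0.9, 0.999, 1e-08]  # fallback

    def backward_float_params(mode):
        return BACKWARD_FLOAT_DEFAULTS.get(mode, ADAMW_DEFAULTS)

    assert backward_float_params("rmsprop") == [0.001, 0.9, 0.1, 1e-08]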
@@ -99,6 +103,9 @@ class ESInitLayer(nn.Cell):
         self.default_value = None

     def construct(self):
+        """
+        ESInitLayer construct: init embedding hashmap
+        """
         init_partition = init_partition_map(self.ps_num_tensor,
                                             self.ps_ids_tensor,
                                             _embedding_dim=self.embedding_dim,
@@ -145,9 +152,36 @@ class ESInitLayer(nn.Cell):


 class EsEmbeddingLookup(nn.Cell):
+    r"""
+    Look up a PS embedding.
+
+    .. warning::
+        This is an experimental EmbeddingService API that is subject to change.
+
+    Args:
+        table_id (int): The table id.
+        es_initializer (EsInitializer): The EsInitializer object for the PS embedding with table_id,
+            which can be None when inference is performed.
+        embedding_dim (int): The embedding dim of keys for the PS embedding with table_id.
+        max_key_num (int): The number of keys in a lookup.
+        optimizer_mode (str): The type of optimizer. Default is ``None``.
+        optimizer_params (tuple[float]): The parameters of the optimizer. Default is ``None``.
+        es_filter (CounterFilter): The counter filter option for the PS embedding with table_id. Default is ``None``.
+        es_padding_key (PaddingParamsOption): The padding key option for the PS embedding with table_id.
+            Default is ``None``.
+        es_completion_key (CompletionKeyOption): The completion key option for the PS embedding with table_id.
+            Default is ``None``.
+
+    Inputs:
+        - **keys** (Tensor): The keys of each feature in the PS embedding.
+        - **actual_keys_input** (Tensor): Tensor composed of all unique elements of keys.
+        - **unique_indices** (Tensor): The index of each element of keys in actual_keys_input.
+        - **key_count** (Tensor): The count of each element of actual_keys_input in keys.
+
+    Supported Platforms:
+        ``Atlas A2 training series products``
     """
-    EsEmbeddingLookup.
-    """
+
     def __init__(self, table_id, es_initializer, embedding_dim, max_key_num, optimizer_mode=None,
                  optimizer_params=None, es_filter=None, es_padding_key=None, es_completion_key=None):
         super(EsEmbeddingLookup, self).__init__()
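
The new docstring pins down the constructor surface. A hedged construction sketch based only on that signature; es_init and keys are hypothetical placeholders produced by the EmbeddingService setup (see mindspore/experimental/es/embedding_service.py in this release), not values shown in this diff:

    # Sketch under the documented signature; es_init and keys are placeholders.
    lookup = EsEmbeddingLookup(table_id=0,
                               es_initializer=es_init,   # EsInitializer; None for inference
                               embedding_dim=16,
                               max_key_num=20480,
                               optimizer_mode="sgd",     # accepted as of 2.4.0
                               optimizer_params=(0.001,))
    output = lookup(keys)                                # keys: Tensor of feature keys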
@@ -182,7 +216,7 @@ class EsEmbeddingLookup(nn.Cell):
             self.filter_freq = 1
             self.default_key_or_value = 1
             self.default_key = 0
-            self.default_value = -1.0
+            self.default_value = 1.0

         self.global_step = 1
         if es_padding_key is not None:
@@ -193,7 +227,7 @@ class EsEmbeddingLookup(nn.Cell):
             self.mask_zero = 0
             self.padding_key = 0
             self.padding_key_mask = 1
-        if self.optimizer_mode in ["adam", "ftrl", "adagrad"]:
+        if self.optimizer_mode in ["adam", "ftrl", "adagrad", "sgd", "rmsprop"]:
             self.backward_int_params = ([self.global_step], [self.mask_zero],
                                         [self.padding_key], [self.padding_key_mask])
         else:
@@ -211,6 +245,9 @@ class EsEmbeddingLookup(nn.Cell):
         self.max_grad_norm = Tensor([1.0], ms.float32)

     def construct(self, keys, actual_keys_input=None, unique_indices=None, key_count=None):
+        """
+        Use the corresponding query method to calculate the PS embedding for each key.
+        """
         origin_shape = None
         if len(keys.shape) != 1:
             origin_shape = keys.shape
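
construct flattens multi-dimensional keys before the lookup and restores the shape afterwards by appending the embedding axis (origin_shape + (-1,), visible at the end of this method). A NumPy-only sketch of that round trip, with the lookup itself stubbed out:

    import numpy as np

    keys = np.arange(20480 * 2).reshape(20480, 2)   # multi-dimensional keys
    origin_shape = keys.shape
    flat = keys.reshape(-1)                         # the lookup consumes 1-D keys
    embedding_dim = 16
    out = np.zeros((flat.size, embedding_dim))      # stand-in for the lookup result
    out = out.reshape(origin_shape + (-1,))         # back to (20480, 2, embedding_dim)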
@@ -227,11 +264,9 @@ class EsEmbeddingLookup(nn.Cell):
             key_count = keys
         if self.training:
             if use_host_unique:
-                output = fake_remote_lookup_uniqued(table_id=self.table_id,
-                                                    keys=keys,
+                output = fake_remote_lookup_uniqued(table_id=self.table_id, keys=keys,
                                                     actual_keys_num=actual_keys_input,
-                                                    unique_indices=unique_indices,
-                                                    key_count=key_count,
+                                                    unique_indices=unique_indices, key_count=key_count,
                                                     max_grad_norm=self.max_grad_norm,
                                                     embedding_dim=self.embedding_dim,
                                                     initializer_mode=self.es_initializer.initializer_mode,
@@ -250,8 +285,7 @@ class EsEmbeddingLookup(nn.Cell):
                                                     default_value=self.default_value,
                                                     optimizer_mode=self.optimizer_mode,
                                                     optimizer_params=self.optimizer_params,
-                                                    _max_key_num=self.max_key_num,
-                                                    _table_id=self._table_id,
+                                                    _max_key_num=self.max_key_num, _table_id=self._table_id,
                                                     _use_counter_filter=use_counter_filter,
                                                     backward_mode=self.optimizer_mode,
                                                     backward_int_params=self.backward_int_params,
@@ -280,8 +314,7 @@ class EsEmbeddingLookup(nn.Cell):
                                                     default_value=self.default_value,
                                                     optimizer_mode=self.optimizer_mode,
                                                     optimizer_params=self.optimizer_params,
-                                                    _max_key_num=self.max_key_num,
-                                                    _table_id=self._table_id,
+                                                    _max_key_num=self.max_key_num, _table_id=self._table_id,
                                                     _use_counter_filter=use_counter_filter,
                                                     backward_mode=self.optimizer_mode,
                                                     backward_int_params=self.backward_int_params,
@@ -290,14 +323,10 @@ class EsEmbeddingLookup(nn.Cell):
                                                     completion_key_mask=self.completion_key_mask,
                                                     parameter=self.b)
         else:
-            output = embedding_table_find(self.table_id, keys,
-                                          embedding_dim=self.embedding_dim,
+            output = embedding_table_find(self.table_id, keys, embedding_dim=self.embedding_dim,
                                           default_value=self.default_value,
-                                          _max_key_num=self.max_key_num,
-                                          _table_id=self._table_id,
+                                          _max_key_num=self.max_key_num, _table_id=self._table_id,
                                           _use_counter_filter=use_counter_filter)
-        # input 20480 2 ->41960
-        # output 41960 embedding_dim -> 20480 2 embedding_dim
         if origin_shape is not None:
             output = self.reshape(output, origin_shape + (-1,))
         return output
@@ -321,10 +350,10 @@ class ESEmbeddingCKPTExport(nn.Cell):
         self.table_id_tensor = Tensor(table_id_list, ms.int32)
         self.depend = ops.Depend()

-    def construct(self):
-        export_op1 = self.embedding_table_export(self.file_path, self.ps_id_tensor, self.table_id_tensor)
+    def construct(self, global_step):
+        export_op1 = self.embedding_table_export(self.file_path, self.ps_id_tensor, self.table_id_tensor, global_step)
         z = self.depend(self.file_path, export_op1)
-        export_op2 = self.embedding_compute_var_export(z, self.ps_id_tensor, self.table_id_tensor)
+        export_op2 = self.embedding_compute_var_export(z, self.ps_id_tensor, self.table_id_tensor, global_step)
         return export_op2

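From 2.4.0 on, construct takes a global_step that is threaded through both export primitives, so the caller supplies the step at call time. A hedged call sketch; the int64 tensor form is an assumption modeled on the Tensor([-1], ms.int64) sentinel used by the feature-mapping cells later in this diff:

    # 2.3.0: export_net()
    # 2.4.0: export_net(global_step)
    export_net = ...  # an ESEmbeddingCKPTExport instance, built elsewhere (placeholder)
    export_net(Tensor([current_step], ms.int64))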
@@ -345,8 +374,31 @@ class ESEmbeddingTableExport(nn.Cell):
         self.ps_id_tensor = Tensor(0, ms.int32)
         self.table_id_tensor = Tensor(table_id_list, ms.int32)

-    def construct(self):
-        y = self.op(self.file_path, self.ps_id_tensor, self.table_id_tensor)
+    def construct(self, global_step):
+        y = self.op(self.file_path, self.ps_id_tensor, self.table_id_tensor, global_step)
+        return y
+
+
+class ESIncrementalEmbeddingTableExport(nn.Cell):
+    """
+    ESIncrementalEmbeddingTableExport.
+    """
+    def __init__(self, embedding_dim_list, value_total_len_list, table_name_list, table_id_list,
+                 file_path, steps_to_live_list):
+        super(ESIncrementalEmbeddingTableExport, self).__init__()
+        self.op = EmbeddingTableExport(
+            embedding_dim_list,
+            value_total_len_list,
+            table_name=table_name_list,
+            steps_to_live_list=steps_to_live_list,
+            export_mode="new",
+            only_var_flag=True)
+        self.file_path = Tensor(np.array(file_path))
+        self.ps_id_tensor = Tensor(0, ms.int32)
+        self.table_id_tensor = Tensor(table_id_list, ms.int32)
+
+    def construct(self, global_step):
+        y = self.op(self.file_path, self.ps_id_tensor, self.table_id_tensor, global_step)
         return y

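The new ESIncrementalEmbeddingTableExport wraps the same EmbeddingTableExport primitive but with export_mode="new" and only_var_flag=True, which is what makes the export incremental. A hypothetical instantiation; the list values below are illustrative, not taken from the diff:

    inc_export = ESIncrementalEmbeddingTableExport(
        embedding_dim_list=[16], value_total_len_list=[52],  # hypothetical sizes
        table_name_list=["user_table"], table_id_list=[0],   # hypothetical names/ids
        file_path="/tmp/es_ckpt", steps_to_live_list=[0])
    inc_export(Tensor([current_step], ms.int64))             # same global_step contract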
@@ -366,10 +418,10 @@ class ESEmbeddingCKPTImport(nn.Cell):
         self.table_id_tensor = Tensor(table_id_list, ms.int32)
         self.depend = ops.Depend()

-    def construct(self):
-        export_op1 = self.embedding_table_import(self.file_path, self.ps_id_tensor, self.table_id_tensor)
+    def construct(self, global_step):
+        export_op1 = self.embedding_table_import(self.file_path, self.ps_id_tensor, self.table_id_tensor, global_step)
         z = self.depend(self.file_path, export_op1)
-        export_op2 = self.embedding_compute_var_import(z, self.ps_id_tensor, self.table_id_tensor)
+        export_op2 = self.embedding_compute_var_import(z, self.ps_id_tensor, self.table_id_tensor, global_step)
         return export_op2

@@ -388,6 +440,142 @@ class ESEmbeddingTableImport(nn.Cell):
         self.ps_id_tensor = Tensor(0, ms.int32)
         self.table_id_tensor = Tensor(table_id_list, ms.int32)

+    def construct(self, global_step):
+        y = self.op(self.file_path, self.ps_id_tensor, self.table_id_tensor, global_step)
+        return y
+
+
+class ESEmbeddingTableEvict(nn.Cell):
+    """
+    ESEmbeddingTableEvict.
+    """
+    def __init__(self, var_handle, global_step, steps_to_live):
+        super(ESEmbeddingTableEvict, self).__init__()
+        self.op = EmbeddingTableEvict()
+        self.var_handle = Tensor(var_handle, ms.int32)
+        self.global_step = global_step
+        self.steps_to_live = steps_to_live
+
     def construct(self):
-        y = self.op(self.file_path, self.ps_id_tensor, self.table_id_tensor)
+        y = self.op(self.var_handle, self.global_step, self.steps_to_live)
         return y
+
+
+class ESEmbeddingFeatureMappingExport(nn.Cell):
+    """
+    ESEmbeddingFeatureMappingExport.
+    """
+    def __init__(self, file_path, export_value, var, var_name, small_table_embedding_dim):
+        super(ESEmbeddingFeatureMappingExport, self).__init__()
+        self.embedding_feature_mapping_table_size = EmbeddingFeatureMappingTableSize()
+        self.embedding_feature_mapping_find = EmbeddingFeatureMappingFind()
+        self.embedding_feature_mapping_export = EmbeddingFeatureMappingExport()
+        self.file_path = file_path
+        self.export_value = export_value
+        self.gather = ops.Gather()
+        self.var = Tensor(var, ms.float32)
+        self.var_name = Tensor(np.array([var_name]))
+        self.small_table_embedding_dim = [small_table_embedding_dim]
+        self.global_step = Tensor([-1], ms.int64)
+
+    def construct(self):
+        """
+        ESEmbeddingFeatureMappingExport construct: export feature mapping for data_parallel embedding.
+        """
+        feature_size = self.embedding_feature_mapping_table_size(self.var_name)
+        feature_id, offset_id = self.embedding_feature_mapping_find(self.var_name, feature_size, 1)
+        values = self.gather(self.var, offset_id, 0)
+        if self.export_value:
+            embed_values = values
+        else:
+            embed_values = Tensor([0], ms.float32)
+        feature_mapping_export = self.embedding_feature_mapping_export(self.file_path, self.var_name, self.global_step,
+                                                                       embed_values, self.small_table_embedding_dim,
+                                                                       [feature_id], [offset_id])
+        return feature_mapping_export
+
+
+class ESEmbeddingFeatureMappingImport(nn.Cell):
+    """
+    ESEmbeddingFeatureMappingImport.
+    """
+    def __init__(self, file_path, small_table_name, small_table_embedding_dim, only_offset_flag):
+        super(ESEmbeddingFeatureMappingImport, self).__init__()
+        self.embedding_feature_mapping_file_size = EmbeddingFeatureMappingFileSize()
+        self.embedding_feature_mapping_import = EmbeddingFeatureMappingImport()
+        self.embedding_feature_mapping_insert = EmbeddingFeatureMappingInsert()
+        self.file_path = file_path
+        self.small_table_name = Tensor(np.array([small_table_name]))
+        self.small_table_embedding_dim = [small_table_embedding_dim]
+        self.only_offset_flag = only_offset_flag
+        self.global_step = Tensor([-1], ms.int64)
+
+    def construct(self):
+        """
+        ESEmbeddingFeatureMappingImport construct: import feature mapping for data_parallel embedding.
+        """
+        feature_size = self.embedding_feature_mapping_file_size(self.file_path,
+                                                                self.small_table_name,
+                                                                self.global_step,
+                                                                self.small_table_embedding_dim,
+                                                                self.only_offset_flag)
+        feature_id, offset_id = self.embedding_feature_mapping_import(self.file_path,
+                                                                      self.small_table_name,
+                                                                      feature_size, self.global_step,
+                                                                      self.small_table_embedding_dim,
+                                                                      self.only_offset_flag, 1)
+        feature_mapping_insert = self.embedding_feature_mapping_insert(self.small_table_name, 1,
+                                                                       [feature_id], [offset_id])
+        return feature_mapping_insert
+
+
+class ESEmbeddingSmallTableLookup(nn.Cell):
+    r"""
+    Look up a data_parallel embedding.
+
+    .. warning::
+        This is an experimental EmbeddingService API that is subject to change.
+
+    Args:
+        name (str): The data_parallel embedding name.
+        rank_id (int): The rank id used when looking up data_parallel embedding keys.
+        rank_size (int): The rank size used when looking up data_parallel embedding keys.
+        small_table_to_variable (dict[str, Parameter]): The dict that stores data_parallel embedding information:
+            key is table name, value is parameter.
+
+    Inputs:
+        - **ids_list** (Tensor) - The keys of each feature in the data_parallel embedding.
+
+    Supported Platforms:
+        ``Atlas A2 training series products``
+    """
+
+    def __init__(self, name, rank_id, rank_size, small_table_to_variable):
+        super(ESEmbeddingSmallTableLookup, self).__init__()
+        self.small_table_to_variable = small_table_to_variable[name]
+        self.small_table_to_variable.feature_name = name
+        self.allgather = ops.AllGather()
+        self.gather = ops.Gather()
+        self.embedding_feature_mapping_v2 = EmbeddingFeatureMappingV2()
+        self.name = name
+        self.rank_id = rank_id
+        self.rank_size = rank_size
+
+    def construct(self, ids_list):
+        """
+        Use EmbeddingFeatureMappingV2 to map each hash key to a non-hash key, then gather the embedding values.
+        """
+        hash_key_shape = ids_list.shape
+        if self.rank_size > 1 and (hash_key_shape[0] is not None):
+            hash_key = ops.stop_gradient(self.allgather(ids_list))
+            non_hash_key = self.embedding_feature_mapping_v2(self.name, hash_key, [1], [1])
+            recovery_matrix = []
+            for i in range(hash_key_shape[0]):
+                recovery_matrix.append(self.rank_id * hash_key_shape[0] + i)
+            local_non_hash_keys = self.gather(non_hash_key, Tensor(recovery_matrix), 0)
+        else:
+            hash_key = ids_list
+            local_non_hash_keys = self.embedding_feature_mapping_v2(self.name, hash_key, [1], [1])
+
+        embedding = self.gather(self.small_table_to_variable, local_non_hash_keys, 0)
+        return embedding
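
In the multi-rank branch of ESEmbeddingSmallTableLookup.construct, each rank allgathers its ids, maps the whole gathered batch, then recovers just its own rows with rank_id * local_batch + i indices. A NumPy sketch of that recovery step, with the feature-mapping op stubbed out:

    import numpy as np

    rank_size, local_batch, rank_id = 4, 3, 2
    gathered = np.arange(rank_size * local_batch)    # stands in for allgather(ids_list)
    non_hash = gathered * 10                         # stub for EmbeddingFeatureMappingV2
    recovery = [rank_id * local_batch + i for i in range(local_batch)]
    local_non_hash = non_hash[recovery]              # rows 6..8: this rank's keys only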
@@ -0,0 +1,21 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""LlmBoost Register"""
+from __future__ import absolute_import
+
+from mindspore.experimental.llm_boost.atb import *
+from mindspore.experimental.llm_boost.register import LlmBoostRegister
+
+__all__ = ['LlmBoostRegister']
@@ -13,15 +13,11 @@
 # limitations under the License.
 # ============================================================================
 """
-Layer.
-
-The high-level components(Cells) used to construct the neural network.
+Provide llm boost for inference, such as LlamaBoost.
 """
 from __future__ import absolute_import

-from mindspore.nn.extend.layer import normalization
-from mindspore.nn.extend.layer.normalization import *
-
-__all__ = []
+from mindspore.experimental.llm_boost.atb.llama_boost import LlamaBoost
+from mindspore.experimental.llm_boost.atb.qwen_boost import QwenBoost

-__all__.extend(normalization.__all__)
+__all__ = ['LlamaBoost', 'QwenBoost']
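
After the reshuffle, the boost entry points live under mindspore.experimental.llm_boost; the two __all__ lists above spell out exactly what is re-exported:

    from mindspore.experimental.llm_boost import LlmBoostRegister
    from mindspore.experimental.llm_boost.atb import LlamaBoost, QwenBoost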
@@ -0,0 +1,211 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""boost base class"""
+import numpy as np
+import mindspore as ms
+from mindspore import ops, Tensor
+from mindspore.ops import operations as P
+import mindspore.common.dtype as mstype
+from mindspore._c_expression import _set_format
+
+from mindspore.common.parameter import Parameter
+from mindspore.experimental.llm_boost.utils import get_real_rank, get_real_group_size
+from mindspore.common.initializer import Zero
+
+
+class AttentionMask:
+    """attention mask"""
+
+    @classmethod
+    def static(cls, max_seq_len, dtype=mstype.float16, need_nz=False):
+        """cache mask"""
+        bias_cache = Tensor(np.tril(np.ones((max_seq_len, max_seq_len), dtype=np.bool_))).reshape(max_seq_len,
+                                                                                                  max_seq_len)
+        bias_cache = ~bias_cache
+        if dtype == mstype.float16:
+            mask_value = Tensor(np.finfo(np.float32).min, mstype.float16)
+        else:
+            mask_value = Tensor(1)
+        attn_mask = ops.masked_fill(Tensor(np.zeros(
+            (max_seq_len, max_seq_len)), dtype=mstype.float16), bias_cache, mask_value)
+        if need_nz:
+            # ND -> NZ
+            attn_mask = ops.reshape(attn_mask, (1, max_seq_len, max_seq_len))
+            attn_mask = ops.reshape(
+                attn_mask, (1, max_seq_len, max_seq_len // 16, 16))
+            attn_mask = ops.transpose(attn_mask, (0, 2, 1, 3)).contiguous()
+            attn_mask = _set_format(attn_mask, "FRACTAL_NZ")
+        return attn_mask
+
+
+class AtbBoostBase():
+    """atb boost base class"""
+
+    def __init__(self, config):
+        super().__init__()
+        self.is_first_iteration = False
+        self.config = config
+        self.dtype = config.compute_dtype
+        self.num_heads = config.num_heads
+        self.num_kv_heads = config.n_kv_heads if config.n_kv_heads else self.num_heads
+        self.num_layers = config.num_layers
+        self.n_kv_heads = config.n_kv_heads if config.n_kv_heads else config.num_heads
+        self.head_dim = config.hidden_size // self.num_heads
+        self.need_nz = False
+        if hasattr(config, "need_nz"):
+            self.need_nz = config.need_nz
+        self.placeholder = Tensor(np.zeros(1), dtype=self.dtype)
+        self.lm_head_indices_fake = Tensor([0], dtype=mstype.int64)
+        self.position_embedding_type = "ROPE"
+        self.add_norm_enable = True
+        self.max_decode_length = self.config.max_decode_length
+        self.max_base_len = 128
+        self.attn_mask = AttentionMask.static(
+            self.max_base_len, dtype=self.dtype, need_nz=self.need_nz)
+
+        self.cast = P.Cast()
+        self.reshape = P.Reshape()
+        self.kv_quant = None
+        self.rank_id = get_real_rank()
+        self.device_num = get_real_group_size()
+
+    def _convert_tensor_format_and_dtype(self, tensor, dtype=mstype.float16):
+        tensor = self.cast(tensor, dtype=dtype)
+        if self.need_nz:
+            tensor = _set_format(tensor, "FRACTAL_NZ")
+        return tensor
+
+    def set_weights(self, parm_dict, dtype=mstype.float16):
+        """set weights for llm boost"""
+        embedding_weight_name = "model.tok_embeddings.embedding_weight"
+        attention_norm_name = "attention_norm"
+        qkv_name = "attention.w_qkv"
+        o_name = "attention.wo"
+        mlp_norm_name = "ffn_norm"
+        mlp_gate_name = "feed_forward.w_gate_hidden"
+        mlp_down_name = "feed_forward.w2"
+        norm_out_name = "model.norm_out"
+        lm_head_name = "lm_head"
+        placeholder = Parameter(Tensor(np.zeros(1), dtype=dtype))
+
+        ascend_weight = []
+        ascend_weight.append(
+            self.cast(parm_dict[embedding_weight_name], dtype))
+        for i in range(self.num_layers):
+            ascend_weight.append(self._convert_tensor_format_and_dtype(
+                parm_dict[f"model.layers.{i}.{attention_norm_name}.weight"], dtype))
+            ascend_weight.extend([placeholder] * 3)
+
+            ascend_weight.append(
+                self._convert_tensor_format_and_dtype(parm_dict[f"model.layers.{i}.{qkv_name}.weight"], dtype))
+            ascend_weight.append(self._convert_tensor_format_and_dtype(parm_dict.get(
+                f"model.layers.{i}.{qkv_name}.bias", placeholder), dtype))
+            ascend_weight.extend([placeholder] * 16)
+
+            ascend_weight.append(
+                self._convert_tensor_format_and_dtype(parm_dict[f"model.layers.{i}.{o_name}.weight"], dtype))
+            ascend_weight.append(self._convert_tensor_format_and_dtype(parm_dict.get(
+                f"model.layers.{i}.{o_name}.bias", placeholder), dtype))
+            ascend_weight.extend([placeholder] * 4)
+
+            ascend_weight.append(
+                self._convert_tensor_format_and_dtype(parm_dict[f"model.layers.{i}.{mlp_norm_name}.weight"], dtype))
+            ascend_weight.extend([placeholder] * 3)
+
+            ascend_weight.append(
+                self._convert_tensor_format_and_dtype(parm_dict[f"model.layers.{i}.{mlp_gate_name}.weight"], dtype))
+            ascend_weight.append(self._convert_tensor_format_and_dtype(parm_dict.get(
+                f"model.layers.{i}.{mlp_gate_name}.bias", placeholder), dtype))
+            ascend_weight.extend([placeholder] * 10)
+
+            ascend_weight.append(
+                self._convert_tensor_format_and_dtype(parm_dict[f"model.layers.{i}.{mlp_down_name}.weight"], dtype))
+            ascend_weight.append(self._convert_tensor_format_and_dtype(parm_dict.get(
+                f"model.layers.{i}.{mlp_down_name}.bias", placeholder), dtype))
+            ascend_weight.extend([placeholder] * 4)
+
+        ascend_weight.append(
+            self._convert_tensor_format_and_dtype(parm_dict[f"{norm_out_name}.weight"], dtype))
+        ascend_weight.append(
+            self._convert_tensor_format_and_dtype(parm_dict[f"{lm_head_name}.weight"], dtype))
+        self.atb_encoder_operation.set_weights(ascend_weight)
+        self.atb_decoder_operation.set_weights(ascend_weight)
+
+    def set_kvcache(self, k_caches=None, v_caches=None):
+        """set kv_cache for llm boost"""
+        if not k_caches or v_caches:
+            if self.need_nz:
+                kv_shape = (self.config.num_blocks, self.num_kv_heads*self.head_dim //
+                            self.device_num // 16, self.config.block_size, 16)
+                k_caches = [_set_format(Parameter(Tensor(
+                    shape=kv_shape, dtype=self.dtype, init=Zero())), "FRACTAL_NZ") for _ in range(self.num_layers)]
+                v_caches = [_set_format(Parameter(Tensor(
+                    shape=kv_shape, dtype=self.dtype, init=Zero())), "FRACTAL_NZ") for _ in range(self.num_layers)]
+            else:
+                kv_shape = (self.config.num_blocks, self.config.block_size,
+                            self.num_kv_heads // self.device_num, self.head_dim)
+                k_caches = [Parameter(Tensor(
+                    shape=kv_shape, dtype=self.dtype, init=Zero())) for _ in range(self.num_layers)]
+                v_caches = [Parameter(Tensor(
+                    shape=kv_shape, dtype=self.dtype, init=Zero())) for _ in range(self.num_layers)]
+
+        self.atb_encoder_operation.set_kvcache(k_caches, v_caches)
+        self.atb_decoder_operation.set_kvcache(k_caches, v_caches)
+
+    def add_flags(self, is_first_iteration):
+        """add_flags."""
+        self.is_first_iteration = is_first_iteration
+
+    def _execute_operator(self, acl_inputs, acl_param):
+        """execute operator."""
+        if self.is_first_iteration:
+            acl_model_out = self.atb_encoder_operation.forward(
+                acl_inputs, acl_param)
+        else:
+            acl_model_out = self.atb_decoder_operation.forward(
+                acl_inputs, acl_param)
+        acl_hidden_state = acl_model_out[0]
+        return acl_hidden_state
+
+    def forward(self, boost_inputs):
+        r"""
+        LlmBoost forward.
+        """
+        input_ids = boost_inputs["input_ids"]
+        position_ids = boost_inputs["position_ids"]
+        cos_embed = boost_inputs["cos_embed"]
+        sin_embed = boost_inputs["sin_embed"]
+        block_tables = boost_inputs["block_tables"]
+        slot_mapping = boost_inputs["slot_mapping"]
+        batch_valid_length = boost_inputs["batch_valid_length"]
+        lm_head_indices = boost_inputs["lm_head_indices"]
+        seqLen = boost_inputs["seq_lens"]
+        if self.is_first_iteration:
+            attention_mask = self.attn_mask
+        else:
+            position_ids = batch_valid_length - 1
+            attention_mask = self.placeholder
+            lm_head_indices = self.lm_head_indices_fake
+
+        acl_inputs, acl_param = self._prepare_inputs(prefill=self.is_first_iteration, input_ids=input_ids,
+                                                     position_ids=position_ids, cos_embed=cos_embed,
+                                                     sin_embed=sin_embed, attention_mask=attention_mask,
+                                                     block_tables=block_tables, slots=slot_mapping,
+                                                     input_lengths=batch_valid_length, lm_head_indices=lm_head_indices,
+                                                     seqLen=seqLen)
+        ms.hal.synchronize()
+        logits = self._execute_operator(acl_inputs, acl_param)
+        logits = self.cast(logits, mstype.float32)
+        return logits
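
AttentionMask.static caches a causal mask: a lower-triangular matrix of ones is inverted so entries above the diagonal are True, and masked_fill writes a large negative value there on the float16 path. A NumPy-only sketch of the same construction, leaving out the dtype and FRACTAL_NZ handling:

    import numpy as np

    max_seq_len = 4
    allowed = np.tril(np.ones((max_seq_len, max_seq_len), dtype=np.bool_))
    future = ~allowed                                # True strictly above the diagonal
    mask = np.zeros((max_seq_len, max_seq_len), dtype=np.float32)
    mask[future] = np.finfo(np.float32).min          # blocks attention to future tokens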