mindspore-2.4.1-cp39-cp39-macosx_11_0_arm64.whl → mindspore-2.4.10-cp39-cp39-macosx_11_0_arm64.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (56)
  1. mindspore/.commit_id +1 -1
  2. mindspore/_c_dataengine.cpython-39-darwin.so +0 -0
  3. mindspore/_c_expression.cpython-39-darwin.so +0 -0
  4. mindspore/_c_mindrecord.cpython-39-darwin.so +0 -0
  5. mindspore/common/api.py +1 -4
  6. mindspore/common/file_system.py +2 -0
  7. mindspore/common/parameter.py +1 -14
  8. mindspore/communication/_comm_helper.py +5 -0
  9. mindspore/context.py +7 -2
  10. mindspore/dataset/engine/datasets_standard_format.py +17 -0
  11. mindspore/dataset/engine/datasets_user_defined.py +27 -1
  12. mindspore/experimental/llm_boost/__init__.py +2 -2
  13. mindspore/experimental/llm_boost/atb/boost_base.py +240 -64
  14. mindspore/experimental/llm_boost/atb/llama_boost.py +46 -29
  15. mindspore/experimental/llm_boost/atb/qwen_boost.py +47 -24
  16. mindspore/include/api/context.h +1 -1
  17. mindspore/include/dataset/constants.h +2 -2
  18. mindspore/lib/libavcodec.59.dylib +0 -0
  19. mindspore/lib/libavdevice.59.dylib +0 -0
  20. mindspore/lib/libavfilter.8.dylib +0 -0
  21. mindspore/lib/libavformat.59.dylib +0 -0
  22. mindspore/lib/libavutil.57.dylib +0 -0
  23. mindspore/lib/libmindspore_backend.dylib +0 -0
  24. mindspore/lib/libmindspore_common.dylib +0 -0
  25. mindspore/lib/libmindspore_core.dylib +0 -0
  26. mindspore/lib/libmindspore_gpr.15.dylib +0 -0
  27. mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
  28. mindspore/lib/libmindspore_grpc.15.dylib +0 -0
  29. mindspore/lib/libmindspore_ops.dylib +0 -0
  30. mindspore/lib/libswresample.4.dylib +0 -0
  31. mindspore/lib/libswscale.6.dylib +0 -0
  32. mindspore/nn/__init__.py +2 -0
  33. mindspore/nn/cell.py +16 -2
  34. mindspore/nn/layer/conv.py +3 -0
  35. mindspore/nn/layer/pooling.py +8 -10
  36. mindspore/nn/utils/__init__.py +22 -0
  37. mindspore/nn/utils/init.py +71 -0
  38. mindspore/ops/_grad_experimental/grad_comm_ops.py +25 -7
  39. mindspore/ops/auto_generate/gen_ops_prim.py +3 -2
  40. mindspore/ops/function/math_func.py +5 -4
  41. mindspore/ops/operations/comm_ops.py +4 -1
  42. mindspore/ops/operations/custom_ops.py +6 -4
  43. mindspore/ops/operations/nn_ops.py +7 -2
  44. mindspore/parallel/_auto_parallel_context.py +23 -4
  45. mindspore/parallel/_cell_wrapper.py +22 -3
  46. mindspore/parallel/_utils.py +0 -1
  47. mindspore/run_check/_check_version.py +17 -8
  48. mindspore/train/callback/_tft_register.py +7 -6
  49. mindspore/train/model.py +1 -0
  50. mindspore/train/serialization.py +4 -1
  51. mindspore/version.py +1 -1
  52. {mindspore-2.4.1.dist-info → mindspore-2.4.10.dist-info}/METADATA +2 -2
  53. {mindspore-2.4.1.dist-info → mindspore-2.4.10.dist-info}/RECORD +56 -54
  54. {mindspore-2.4.1.dist-info → mindspore-2.4.10.dist-info}/WHEEL +0 -0
  55. {mindspore-2.4.1.dist-info → mindspore-2.4.10.dist-info}/entry_points.txt +0 -0
  56. {mindspore-2.4.1.dist-info → mindspore-2.4.10.dist-info}/top_level.txt +0 -0
mindspore/.commit_id CHANGED
@@ -1 +1 @@
- __commit_id__ = ''[sha1]:01847825,[branch]:(HEAD,origin/r2.4.1,r2.4.1)''
+ __commit_id__ = ''[sha1]:8e2ae935,[branch]:(HEAD,origin/r2.4.1,r2.4.1)''
mindspore/common/api.py CHANGED
@@ -1703,7 +1703,6 @@ class _CellGraphExecutor:
  self._graph_executor = GraphExecutor_.get_instance()
  self._graph_executor.set_py_exe_path(sys.executable)
  self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
- self._pid = os.getpid()

  def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes,
  input_indexs, phase='dataset', need_run=True):
@@ -1934,9 +1933,7 @@ class _CellGraphExecutor:

  def del_net_res(self, obj, net_id):
  """Clear the memory resource of a network."""
- # no need to del net res by gc in independent dataset process which is a subprocess forked by main process
- if self._pid == os.getpid():
- self._graph_executor.del_net_res(obj, net_id)
+ self._graph_executor.del_net_res(obj, net_id)

  def _get_branch_control_input(self):
  if ('obf_ratio' not in self.obfuscate_config.keys()) or (
mindspore/common/file_system.py CHANGED
@@ -22,6 +22,7 @@ class FileSystem:
  self.create_args = ("ab",)
  self.open = open
  self.open_args = ("rb",)
+ self.backend = "basic"


  def _register_basic_file_system(fs: FileSystem):
@@ -45,4 +46,5 @@ def _register_mindio_file_system(fs: FileSystem):
  fs.create_args = ()
  fs.open = mindio.open_file
  fs.open_args = ()
+ fs.backend = "mindio"
  return True
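The new backend attribute is a plain string tag ("basic" or "mindio") set by whichever register function wins, so callers can branch on the active storage backend instead of comparing bound functions. A minimal standalone sketch of the pattern (illustrative names only, not the MindSpore internals):

    # Hedged sketch of the backend-tag pattern added above; FileSystem here is a
    # stand-in holder, not the real mindspore.common.file_system.FileSystem.
    class FileSystem:
        def __init__(self):
            self.create = open
            self.create_args = ("ab",)
            self.open = open
            self.open_args = ("rb",)
            self.backend = "basic"   # default tag, overwritten if mindio registers

    def uses_mindio(fs):
        # Callers can inspect the tag rather than the bound open function.
        return getattr(fs, "backend", "basic") == "mindio"

    print(uses_mindio(FileSystem()))  # False unless the mindio backend was registered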
mindspore/common/parameter.py CHANGED
@@ -22,7 +22,6 @@ import os
  import sys
  import math
  import numbers
- from contextlib import contextmanager
  import numpy as np
  from mindspore import log as logger
  from mindspore.log import _LogActionOnce
@@ -55,16 +54,6 @@ PARAMETER_NAME_PREFIX_MAX_LEN = 1024
  _GLOBAL_PARAMETER_KEY = -1


- @contextmanager
- def no_init_parameters():
- init_class = globals()["Parameter"]
- setattr(init_class, "init_param", False)
- try:
- yield
- finally:
- setattr(init_class, "init_param", True)
-
-
  def _is_in_auto_parallel_mode():
  """Get parallel mode."""
  return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"]
@@ -988,9 +977,7 @@ class Parameter(Tensor_):
  """
  if self.is_default_input_init and self.is_in_parallel != _is_in_auto_parallel_mode():
  raise RuntimeError("Must set or change parallel mode before any initializer Tensor created.")
- if hasattr(self, "init_param") and self.init_param:
- return self
- if self.init_mode is None:
+ if self.init_mode is None or not self.has_init:
  return self
  if self.inited_param is not None:
  return self.inited_param
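With the no_init_parameters context manager and its init_param flag removed, the early return in Parameter.init_data is driven purely by init_mode and has_init. A hedged sketch of the public code path this guards (standard Parameter usage, nothing release-specific):

    # Hedged sketch: a Parameter built from an initializer stays lazy until
    # init_data() is called; the changed guard above decides when init_data()
    # can simply return self.
    import mindspore as ms
    from mindspore import Parameter
    from mindspore.common.initializer import initializer

    w = Parameter(initializer("normal", (2, 3), ms.float32), name="w")
    w = w.init_data()        # materializes the deferred initializer
    print(w.shape, w.dtype)  # (2, 3) Float32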
mindspore/communication/_comm_helper.py CHANGED
@@ -499,3 +499,8 @@ def _destroy_group_helper(group):
  hccl.create_group(group)
  else:
  CollectiveManager.get_instance().destroy_group(group)
+
+
+ def _get_group_map():
+ """Get the group map"""
+ return CollectiveManager.get_instance().get_group_map()
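The new helper is a thin private wrapper around CollectiveManager; a hedged sketch of the call shape (it is not part of the public API and assumes an initialized communication context inside a launched distributed job):

    # Hedged sketch only: _get_group_map() is a private helper added in this
    # release; it forwards to CollectiveManager.get_instance().get_group_map().
    from mindspore.communication import init
    from mindspore.communication._comm_helper import _get_group_map

    init()                   # assumes a properly launched distributed job
    print(_get_group_map())  # mapping of the communication groups created so far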
mindspore/context.py CHANGED
@@ -936,6 +936,7 @@ def set_auto_parallel_context(**kwargs):
  \ group_ckpt_save_file
  \ auto_pipeline
  \ dump_local_norm
+ \ dump_device_local_norm
  =========================== ===========================

  Args:
@@ -1090,6 +1091,9 @@ def set_auto_parallel_context(**kwargs):
  dump_local_norm (bool): Whether to dump local_norm value, when the `parallel_mode` is set to
  ``semi_auto_parallel`` or ``auto_parallel``.
  Default: ``False`` .
+ dump_device_local_norm (bool): Whether to dump device_local_norm value, when the `parallel_mode` is set to
+ ``semi_auto_parallel`` or ``auto_parallel``.
+ Default: ``False`` .

  Raises:
  ValueError: If input key is not attribute in auto parallel context.
@@ -1165,8 +1169,9 @@ def reset_auto_parallel_context():
  - pipeline_stages: 1.
  - pipeline_result_broadcast: False.
  - fusion_threshold: 64.
- - dump_local_norm: False.
  - auto_pipeline: False.
+ - dump_local_norm: False.
+ - dump_device_local_norm: False.

  Examples:
  >>> import mindspore as ms
@@ -1793,7 +1798,7 @@ def set_context(**kwargs):
  When both exist simultaneously, the global jit config will not overwrite the local network's jit config.

  - jit_level (str): Used to control the compilation optimization level. Default: ``""`` , The framework
- automatically selects the execution method based on product, Altas training product is O2, and all other
+ automatically selects the execution method based on product, Atlas training product is O2, and all other
  products are O0. In addition, The option of the dynamic shape must be O0 or O1, O2 is not supported.
  The value range is as follows:
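A hedged usage sketch for the new dump_device_local_norm switch documented above; it mirrors dump_local_norm and only has an effect when the job actually runs under semi_auto_parallel or auto_parallel:

    # Hedged sketch: setting the new flag alongside dump_local_norm; both reset
    # to False via reset_auto_parallel_context(), as listed in the defaults above.
    import mindspore as ms

    ms.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                 dump_local_norm=True,
                                 dump_device_local_norm=True)
    ms.reset_auto_parallel_context()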
mindspore/dataset/engine/datasets_standard_format.py CHANGED
@@ -33,6 +33,7 @@ from .datasets import UnionBaseDataset, SourceDataset, MappableDataset, Shuffle,
  from .datasets_user_defined import GeneratorDataset
  from .obs.obs_mindrecord_dataset import MindRecordFromOBS
  from .validators import check_csvdataset, check_minddataset, check_tfrecorddataset, check_obsminddataset
+ from ..core.validator_helpers import type_check
  from ...mindrecord.config import _get_enc_key, _get_dec_mode, _get_hash_mode, decrypt, verify_file_hash


@@ -301,6 +302,22 @@ class MindDataset(MappableDataset, UnionBaseDataset):
  else:
  self.new_padded_sample[k] = v

+ def __deepcopy__(self, memodict):
+ if id(self) in memodict:
+ return memodict[id(self)]
+ return self.__safe_deepcopy__(memodict, exclude=("mindrecord_op"))
+
+ def __getitem__(self, index):
+ type_check(index, (int,), "index")
+ if index < 0:
+ raise ValueError("index cannot be negative, but got {0}.".format(index))
+ if not hasattr(self, "mindrecord_op"):
+ minddata_node = cde.MindDataNode(
+ self.dataset_files, self.columns_list, self.sampler, self.new_padded_sample,
+ self.num_padded, shuffle_to_shuffle_mode(self.shuffle_option))
+ self.mindrecord_op = minddata_node.Build()
+ return [t.as_array() for t in self.mindrecord_op[index]]
+

  class TFRecordDataset(SourceDataset, UnionBaseDataset):
  """
mindspore/dataset/engine/datasets_user_defined.py CHANGED
@@ -19,6 +19,7 @@ After declaring the dataset object, you can further apply dataset operations
  (e.g. filter, skip, concat, map, batch) on it.
  """
  import builtins
+ import copy
  import errno
  import itertools
  import math
@@ -50,6 +51,7 @@ from ..core.config import get_enable_shared_mem, get_prefetch_size, get_multipro
  get_enable_watchdog, get_debug_mode, get_seed, set_seed
  from ..core.datatypes import mstypelist_to_detypelist
  from ..core.py_util_helpers import ExceptionHandler
+ from ..core.validator_helpers import type_check
  from ..transforms import transforms


@@ -427,6 +429,10 @@ class SamplerFn(cde.PythonMultiprocessingRuntime):
  subprocess_file_descriptor = w.sentinel
  st = time.time()
  while _PythonMultiprocessing.is_process_alive(w.pid):
+ process = psutil.Process(w.pid)
+ if process.status() == psutil.STATUS_ZOMBIE:
+ process.kill()
+ break
  time.sleep(0.01) # sleep 10ms, waiting for the subprocess exit
  if time.time() - st > check_interval:
  logger.warning("Waiting for the subprocess worker [{}] to exit.".format(w.pid))
@@ -469,7 +475,7 @@ class SamplerFn(cde.PythonMultiprocessingRuntime):

  # let the quit event notify the worker process to exit
  w.join(timeout=5)
- if w.is_alive():
+ if _PythonMultiprocessing.is_process_alive(w.pid):
  # if the worker process did not exit, it may hang, try to terminate it
  w.terminate()
  w.close()
@@ -907,6 +913,26 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
  return memodict[id(self)]
  return self.__safe_deepcopy__(memodict, exclude=("source", "__transfer_dataset__"))

+ def __getitem__(self, index):
+ type_check(index, (int, np.number), "index")
+ if not hasattr(self.source, "__getitem__"):
+ raise RuntimeError("Dataset don't support randomized access.")
+ if not hasattr(self, "generator_op"):
+ dataset = copy.deepcopy(self)
+ self.prepared_source = _generator_fn_wrapper(_cpp_sampler_fn, self.source)
+ if self.schema is None:
+ dataset.generator_node = cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types,
+ self.source_len, self.sampler, 1, None)
+ else:
+ schema = self.schema
+ if isinstance(schema, Schema):
+ schema = self.schema.cpp_schema
+ dataset.generator_node = cde.GeneratorNode(self.prepared_source, schema, self.source_len,
+ self.sampler, 1, None)
+ self.generator_op = dataset.generator_node.Build()
+ sample_id = self.generator_op.GetMappedIndex(index)
+ return self.source[sample_id]
+
  def is_shuffled(self):
  if self.sampler:
  return self.sampler.is_shuffled()
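Both MindDataset (earlier in this diff) and GeneratorDataset gain a __getitem__ hook in this release, so a single row can be fetched by index without running a full pipeline; for GeneratorDataset the index is mapped through the sampler (GetMappedIndex) and read straight from the user source, which must itself support indexing. A hedged usage sketch (the MindRecord path is a placeholder):

    # Hedged sketch of the new random-access paths added in this diff.
    import numpy as np
    import mindspore.dataset as ds

    # GeneratorDataset: with shuffle=False the mapped index equals the requested
    # index, so dataset[3] returns the fourth element of the source directly.
    source = [(np.array(i, dtype=np.int32),) for i in range(10)]
    gen = ds.GeneratorDataset(source, column_names=["data"], shuffle=False)
    print(gen[3])

    # MindDataset: indexing lazily builds a MindDataNode and returns one row as
    # a list of numpy arrays ("data.mindrecord" is a placeholder for a real file).
    md = ds.MindDataset(dataset_files="data.mindrecord")
    print([col.shape for col in md[0]])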
mindspore/experimental/llm_boost/__init__.py CHANGED
@@ -15,7 +15,7 @@
  """LlmBoost Register"""
  from __future__ import absolute_import

- from mindspore.experimental.llm_boost.atb import *
+ from mindspore.experimental.llm_boost.atb import LlamaBoost, QwenBoost
  from mindspore.experimental.llm_boost.register import LlmBoostRegister

- __all__ = ['LlmBoostRegister']
+ __all__ = ["LlmBoostRegister"]
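The wildcard import is replaced by explicit imports, while __all__ still re-exports only LlmBoostRegister; the boost classes remain reachable through their full module path. A hedged sketch of the resulting import surface (the atb imports assume a platform where the ATB/Ascend backend is available):

    # Hedged sketch: importing the atb module may fail on platforms without the
    # ATB/Ascend runtime, so treat these paths as assumptions.
    from mindspore.experimental.llm_boost import LlmBoostRegister            # public export
    from mindspore.experimental.llm_boost.atb import LlamaBoost, QwenBoost   # explicit path
    print(LlmBoostRegister, LlamaBoost, QwenBoost)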
mindspore/experimental/llm_boost/atb/boost_base.py CHANGED
@@ -13,17 +13,32 @@
  # limitations under the License.
  # ============================================================================
  """boost base class"""
+ from enum import Enum
  import numpy as np
  import mindspore as ms
  from mindspore import ops, Tensor
+ from mindspore import log as logger
  from mindspore.ops import operations as P
  import mindspore.common.dtype as mstype
  from mindspore._c_expression import _set_format
-
  from mindspore.common.parameter import Parameter
  from mindspore.experimental.llm_boost.utils import get_real_rank, get_real_group_size
  from mindspore.common.initializer import Zero

+ FORMAT_NZ = "FRACTAL_NZ"
+ BUILDIN_BACKEND_NAME = "ATB"
+
+
+ class PositionEmbeddingType(int, Enum):
+ ROPE = 0
+ ALIBI = 1
+ ABSOLUTE = 2
+
+
+ class NormType(int, Enum):
+ RMS_NORM = 0
+ LAYER_NORM = 1
+

  class AttentionMask:
  """attention mask"""
@@ -31,30 +46,34 @@ class AttentionMask:
  @classmethod
  def static(cls, max_seq_len, dtype=mstype.float16, need_nz=False):
  """cache mask"""
- bias_cache = Tensor(np.tril(np.ones((max_seq_len, max_seq_len), dtype=np.bool_))).reshape(max_seq_len,
- max_seq_len)
+ bias_cache = Tensor(
+ np.tril(np.ones((max_seq_len, max_seq_len), dtype=np.bool_))
+ ).reshape(max_seq_len, max_seq_len)
  bias_cache = ~bias_cache
  if dtype == mstype.float16:
  mask_value = Tensor(np.finfo(np.float32).min, mstype.float16)
  else:
  mask_value = Tensor(1)
- attn_mask = ops.masked_fill(Tensor(np.zeros(
- (max_seq_len, max_seq_len)), dtype=mstype.float16), bias_cache, mask_value)
+ attn_mask = ops.masked_fill(
+ Tensor(np.zeros((max_seq_len, max_seq_len)), dtype=mstype.float16),
+ bias_cache,
+ mask_value,
+ )
  if need_nz:
  # ND -> NZ
  attn_mask = ops.reshape(attn_mask, (1, max_seq_len, max_seq_len))
- attn_mask = ops.reshape(
- attn_mask, (1, max_seq_len, max_seq_len // 16, 16))
+ attn_mask = ops.reshape(attn_mask, (1, max_seq_len, max_seq_len // 16, 16))
  attn_mask = ops.transpose(attn_mask, (0, 2, 1, 3)).contiguous()
- attn_mask = _set_format(attn_mask, "FRACTAL_NZ")
+ attn_mask = _set_format(attn_mask, FORMAT_NZ)
  return attn_mask


- class AtbBoostBase():
+ class AtbBoostBase:
  """atb boost base class"""

  def __init__(self, config):
  super().__init__()
+ self.backend_name = BUILDIN_BACKEND_NAME
  self.is_first_iteration = False
  self.config = config
  self.dtype = config.compute_dtype
@@ -68,27 +87,98 @@ class AtbBoostBase():
  self.need_nz = config.need_nz
  self.placeholder = Tensor(np.zeros(1), dtype=self.dtype)
  self.lm_head_indices_fake = Tensor([0], dtype=mstype.int64)
- self.position_embedding_type = "ROPE"
+ self.position_embedding_type = PositionEmbeddingType.ROPE
  self.add_norm_enable = True
  self.max_decode_length = self.config.max_decode_length
  self.max_base_len = 128
  self.attn_mask = AttentionMask.static(
- self.max_base_len, dtype=self.dtype, need_nz=self.need_nz)
+ self.max_base_len, dtype=self.dtype, need_nz=self.need_nz
+ )

  self.cast = P.Cast()
  self.reshape = P.Reshape()
  self.kv_quant = None
  self.rank_id = get_real_rank()
  self.device_num = get_real_group_size()
+ self.ascend_weight = []
+ self.k_caches = []
+ self.v_caches = []

  def _convert_tensor_format_and_dtype(self, tensor, dtype=mstype.float16):
  tensor = self.cast(tensor, dtype=dtype)
  if self.need_nz:
- tensor = _set_format(tensor, "FRACTAL_NZ")
+ tensor = _set_format(tensor, FORMAT_NZ)
  return tensor

+ def _convert_qkv_concat_weight(self, param_dict):
+ """convert qkv concat weight"""
+ assume_num_layers = 500
+ for i in range(assume_num_layers):
+ # qkv weight concat
+ wq_weight_name = f"model.layers.{i}.attention.wq.weight"
+ wk_weight_name = f"model.layers.{i}.attention.wk.weight"
+ wv_weight_name = f"model.layers.{i}.attention.wv.weight"
+ qkv_concat_weight_name = f"model.layers.{i}.attention.w_qkv.weight"
+ if wq_weight_name not in param_dict:
+ break
+ wq_weight = param_dict[wq_weight_name].asnumpy()
+ wk_weight = param_dict[wk_weight_name].asnumpy()
+ wv_weight = param_dict[wv_weight_name].asnumpy()
+ qkv_weight = np.concatenate((wq_weight, wk_weight, wv_weight), 0)
+ param_dict[qkv_concat_weight_name] = Parameter(
+ qkv_weight, name=qkv_concat_weight_name
+ )
+
+ # gate hidden weight concat
+ ffn_gate_weight_name = f"model.layers.{i}.feed_forward.w1.weight"
+ ffn_hidden_weight_name = f"model.layers.{i}.feed_forward.w3.weight"
+ gate_hidden_concat_weight_name = (
+ f"model.layers.{i}.feed_forward.w_gate_hidden.weight"
+ )
+
+ ffn_gate_weight = param_dict[ffn_gate_weight_name].asnumpy()
+ ffn_hidden_weight = param_dict[ffn_hidden_weight_name].asnumpy()
+ gate_hidden_weight = np.concatenate((ffn_gate_weight, ffn_hidden_weight), 0)
+ param_dict[gate_hidden_concat_weight_name] = Parameter(
+ gate_hidden_weight, name=gate_hidden_concat_weight_name
+ )
+
+ param_dict.pop(wq_weight_name)
+ param_dict.pop(wk_weight_name)
+ param_dict.pop(wv_weight_name)
+ param_dict.pop(ffn_gate_weight_name)
+ param_dict.pop(ffn_hidden_weight_name)
+ logger.info(f"transform: {qkv_concat_weight_name}")
+ logger.info(f"transform: {gate_hidden_concat_weight_name}")
+
+ for i in range(assume_num_layers):
+ # qkv bias concat
+ wq_bias_name = f"model.layers.{i}.attention.wq.bias"
+ wk_bias_name = f"model.layers.{i}.attention.wk.bias"
+ wv_bias_name = f"model.layers.{i}.attention.wv.bias"
+ qkv_concat_bias_name = f"model.layers.{i}.attention.w_qkv.bias"
+ if wq_bias_name not in param_dict:
+ break
+
+ wq_bias_weight = param_dict[wq_bias_name].asnumpy()
+ wk_bias_weight = param_dict[wk_bias_name].asnumpy()
+ wv_bias_weight = param_dict[wv_bias_name].asnumpy()
+ qkv_bias_weight = np.concatenate(
+ (wq_bias_weight, wk_bias_weight, wv_bias_weight), 0
+ )
+ param_dict[qkv_concat_bias_name] = Parameter(
+ qkv_bias_weight, name=qkv_concat_bias_name
+ )
+
+ param_dict.pop(wq_bias_name)
+ param_dict.pop(wk_bias_name)
+ param_dict.pop(wv_bias_name)
+ logger.info(f"transform: {qkv_concat_bias_name}")
+ return param_dict
+
  def set_weights(self, parm_dict, dtype=mstype.float16):
  """set weights for llm boost"""
+ self._convert_qkv_concat_weight(parm_dict)
  embedding_weight_name = "model.tok_embeddings.embedding_weight"
  attention_norm_name = "attention_norm"
  qkv_name = "attention.w_qkv"
@@ -101,45 +191,88 @@ class AtbBoostBase():
  placeholder = Parameter(Tensor(np.zeros(1), dtype=dtype))

  ascend_weight = []
- ascend_weight.append(
- self.cast(parm_dict[embedding_weight_name], dtype))
+ ascend_weight.append(self.cast(parm_dict[embedding_weight_name], dtype))
  for i in range(self.num_layers):
- ascend_weight.append(self._convert_tensor_format_and_dtype(
- parm_dict[f"model.layers.{i}.{attention_norm_name}.weight"], dtype))
+ ascend_weight.append(
+ self._convert_tensor_format_and_dtype(
+ parm_dict[f"model.layers.{i}.{attention_norm_name}.weight"], dtype
+ )
+ )
  ascend_weight.extend([placeholder] * 3)

  ascend_weight.append(
- self._convert_tensor_format_and_dtype(parm_dict[f"model.layers.{i}.{qkv_name}.weight"], dtype))
- ascend_weight.append(self._convert_tensor_format_and_dtype(parm_dict.get(
- f"model.layers.{i}.{qkv_name}.bias", placeholder), dtype))
+ self._convert_tensor_format_and_dtype(
+ parm_dict[f"model.layers.{i}.{qkv_name}.weight"], dtype
+ )
+ )
+ ascend_weight.append(
+ self._convert_tensor_format_and_dtype(
+ parm_dict.get(f"model.layers.{i}.{qkv_name}.bias", placeholder),
+ dtype,
+ )
+ )
  ascend_weight.extend([placeholder] * 16)

  ascend_weight.append(
- self._convert_tensor_format_and_dtype(parm_dict[f"model.layers.{i}.{o_name}.weight"], dtype))
- ascend_weight.append(self._convert_tensor_format_and_dtype(parm_dict.get(
- f"model.layers.{i}.{o_name}.bias", placeholder), dtype))
+ self._convert_tensor_format_and_dtype(
+ parm_dict[f"model.layers.{i}.{o_name}.weight"], dtype
+ )
+ )
+ ascend_weight.append(
+ self._convert_tensor_format_and_dtype(
+ parm_dict.get(f"model.layers.{i}.{o_name}.bias", placeholder), dtype
+ )
+ )
  ascend_weight.extend([placeholder] * 4)

  ascend_weight.append(
- self._convert_tensor_format_and_dtype(parm_dict[f"model.layers.{i}.{mlp_norm_name}.weight"], dtype))
+ self._convert_tensor_format_and_dtype(
+ parm_dict[f"model.layers.{i}.{mlp_norm_name}.weight"], dtype
+ )
+ )
  ascend_weight.extend([placeholder] * 3)

  ascend_weight.append(
- self._convert_tensor_format_and_dtype(parm_dict[f"model.layers.{i}.{mlp_gate_name}.weight"], dtype))
- ascend_weight.append(self._convert_tensor_format_and_dtype(parm_dict.get(
- f"model.layers.{i}.{mlp_gate_name}.bias", placeholder), dtype))
+ self._convert_tensor_format_and_dtype(
+ parm_dict[f"model.layers.{i}.{mlp_gate_name}.weight"], dtype
+ )
+ )
+ ascend_weight.append(
+ self._convert_tensor_format_and_dtype(
+ parm_dict.get(
+ f"model.layers.{i}.{mlp_gate_name}.bias", placeholder
+ ),
+ dtype,
+ )
+ )
  ascend_weight.extend([placeholder] * 10)

  ascend_weight.append(
- self._convert_tensor_format_and_dtype(parm_dict[f"model.layers.{i}.{mlp_down_name}.weight"], dtype))
- ascend_weight.append(self._convert_tensor_format_and_dtype(parm_dict.get(
- f"model.layers.{i}.{mlp_down_name}.bias", placeholder), dtype))
+ self._convert_tensor_format_and_dtype(
+ parm_dict[f"model.layers.{i}.{mlp_down_name}.weight"], dtype
+ )
+ )
+ ascend_weight.append(
+ self._convert_tensor_format_and_dtype(
+ parm_dict.get(
+ f"model.layers.{i}.{mlp_down_name}.bias", placeholder
+ ),
+ dtype,
+ )
+ )
  ascend_weight.extend([placeholder] * 4)

  ascend_weight.append(
- self._convert_tensor_format_and_dtype(parm_dict[f"{norm_out_name}.weight"], dtype))
+ self._convert_tensor_format_and_dtype(
+ parm_dict[f"{norm_out_name}.weight"], dtype
+ )
+ )
  ascend_weight.append(
- self._convert_tensor_format_and_dtype(parm_dict[f"{lm_head_name}.weight"], dtype))
+ self._convert_tensor_format_and_dtype(
+ parm_dict[f"{lm_head_name}.weight"], dtype
+ )
+ )
+ self.ascend_weight = ascend_weight
  self.atb_encoder_operation.set_weights(ascend_weight)
  self.atb_decoder_operation.set_weights(ascend_weight)

@@ -147,20 +280,47 @@ class AtbBoostBase():
  """set kv_cache for llm boost"""
  if not k_caches or v_caches:
  if self.need_nz:
- kv_shape = (self.config.num_blocks, self.num_kv_heads*self.head_dim //
- self.device_num // 16, self.config.block_size, 16)
- k_caches = [_set_format(Parameter(Tensor(
- shape=kv_shape, dtype=self.dtype, init=Zero())), "FRACTAL_NZ") for _ in range(self.num_layers)]
- v_caches = [_set_format(Parameter(Tensor(
- shape=kv_shape, dtype=self.dtype, init=Zero())), "FRACTAL_NZ") for _ in range(self.num_layers)]
+ kv_shape = (
+ self.config.num_blocks,
+ self.num_kv_heads * self.head_dim // self.device_num // 16,
+ self.config.block_size,
+ 16,
+ )
+ k_caches = [
+ _set_format(
+ Parameter(
+ Tensor(shape=kv_shape, dtype=self.dtype, init=Zero())
+ ),
+ FORMAT_NZ,
+ )
+ for _ in range(self.num_layers)
+ ]
+ v_caches = [
+ _set_format(
+ Parameter(
+ Tensor(shape=kv_shape, dtype=self.dtype, init=Zero())
+ ),
+ FORMAT_NZ,
+ )
+ for _ in range(self.num_layers)
+ ]
  else:
- kv_shape = (self.config.num_blocks, self.config.block_size,
- self.num_kv_heads // self.device_num, self.head_dim)
- k_caches = [Parameter(Tensor(
- shape=kv_shape, dtype=self.dtype, init=Zero())) for _ in range(self.num_layers)]
- v_caches = [Parameter(Tensor(
- shape=kv_shape, dtype=self.dtype, init=Zero())) for _ in range(self.num_layers)]
-
+ kv_shape = (
+ self.config.num_blocks,
+ self.config.block_size,
+ self.num_kv_heads // self.device_num,
+ self.head_dim,
+ )
+ k_caches = [
+ Parameter(Tensor(shape=kv_shape, dtype=self.dtype, init=Zero()))
+ for _ in range(self.num_layers)
+ ]
+ v_caches = [
+ Parameter(Tensor(shape=kv_shape, dtype=self.dtype, init=Zero()))
+ for _ in range(self.num_layers)
+ ]
+ self.k_caches = k_caches
+ self.v_caches = v_caches
  self.atb_encoder_operation.set_kvcache(k_caches, v_caches)
  self.atb_decoder_operation.set_kvcache(k_caches, v_caches)

@@ -171,11 +331,9 @@ class AtbBoostBase():
  def _execute_operator(self, acl_inputs, acl_param):
  """execute operator."""
  if self.is_first_iteration:
- acl_model_out = self.atb_encoder_operation.forward(
- acl_inputs, acl_param)
+ acl_model_out = self.atb_encoder_operation.forward(acl_inputs, acl_param)
  else:
- acl_model_out = self.atb_decoder_operation.forward(
- acl_inputs, acl_param)
+ acl_model_out = self.atb_decoder_operation.forward(acl_inputs, acl_param)
  acl_hidden_state = acl_model_out[0]
  return acl_hidden_state

@@ -183,28 +341,46 @@ class AtbBoostBase():
  r"""
  LlmBoost forward.
  """
- input_ids = boost_inputs["input_ids"]
- position_ids = boost_inputs["position_ids"]
- cos_embed = boost_inputs["cos_embed"]
- sin_embed = boost_inputs["sin_embed"]
- block_tables = boost_inputs["block_tables"]
- slot_mapping = boost_inputs["slot_mapping"]
- batch_valid_length = boost_inputs["batch_valid_length"]
- lm_head_indices = boost_inputs["lm_head_indices"]
- seqLen = boost_inputs["seq_lens"]
+ input_ids = boost_inputs.get("input_ids", None)
+ position_ids = boost_inputs.get("position_ids", None)
+ cos_embed = boost_inputs.get("cos_embed", None)
+ sin_embed = boost_inputs.get("sin_embed", None)
+ block_tables = boost_inputs.get("block_tables", None)
+ slot_mapping = boost_inputs.get("slot_mapping", None)
+ batch_valid_length = boost_inputs.get("batch_valid_length", None)
+ lm_head_indices = boost_inputs.get("lm_head_indices", None)
+ seqLen = boost_inputs.get("seq_lens", None)
+ input_ids = self.reshape(input_ids, (-1,))
  if self.is_first_iteration:
  attention_mask = self.attn_mask
  else:
- position_ids = batch_valid_length - 1
+ if position_ids is None:
+ position_ids = batch_valid_length - 1
  attention_mask = self.placeholder
  lm_head_indices = self.lm_head_indices_fake

- acl_inputs, acl_param = self._prepare_inputs(prefill=self.is_first_iteration, input_ids=input_ids,
- position_ids=position_ids, cos_embed=cos_embed,
- sin_embed=sin_embed, attention_mask=attention_mask,
- block_tables=block_tables, slots=slot_mapping,
- input_lengths=batch_valid_length, lm_head_indices=lm_head_indices,
- seqLen=seqLen)
+ if input_ids is not None and input_ids.dtype != mstype.int64:
+ input_ids = self.cast(input_ids, mstype.int64)
+ if position_ids is not None and position_ids.dtype != mstype.int64:
+ position_ids = self.cast(position_ids, mstype.int64)
+ if batch_valid_length is not None and batch_valid_length.dtype != mstype.int32:
+ batch_valid_length = self.cast(batch_valid_length, mstype.int32)
+ if lm_head_indices is not None and lm_head_indices.dtype != mstype.int64:
+ lm_head_indices = self.cast(lm_head_indices, mstype.int64)
+
+ acl_inputs, acl_param = self._prepare_inputs(
+ prefill=self.is_first_iteration,
+ input_ids=input_ids,
+ position_ids=position_ids,
+ cos_embed=cos_embed,
+ sin_embed=sin_embed,
+ attention_mask=attention_mask,
+ block_tables=block_tables,
+ slots=slot_mapping,
+ input_lengths=batch_valid_length,
+ lm_head_indices=lm_head_indices,
+ seqLen=seqLen,
+ )
  ms.hal.synchronize()
  logits = self._execute_operator(acl_inputs, acl_param)
  logits = self.cast(logits, mstype.float32)