mindspore-2.2.0-cp38-cp38-win_amd64.whl → mindspore-2.2.11-cp38-cp38-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore has been flagged as potentially problematic.
Files changed (112)
  1. mindspore/.commit_id +1 -1
  2. mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
  3. mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
  4. mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
  5. mindspore/_checkparam.py +3 -3
  6. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  7. mindspore/_extends/graph_kernel/splitter.py +3 -2
  8. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
  9. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
  10. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  11. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
  12. mindspore/_extends/parse/__init__.py +3 -2
  13. mindspore/_extends/parse/parser.py +6 -1
  14. mindspore/_extends/parse/standard_method.py +14 -11
  15. mindspore/_extends/remote/kernel_build_server.py +2 -1
  16. mindspore/common/_utils.py +16 -0
  17. mindspore/common/api.py +1 -1
  18. mindspore/common/auto_dynamic_shape.py +81 -85
  19. mindspore/common/dump.py +1 -1
  20. mindspore/common/tensor.py +3 -20
  21. mindspore/config/op_info.config +1 -1
  22. mindspore/context.py +11 -4
  23. mindspore/dataset/engine/cache_client.py +8 -5
  24. mindspore/dataset/engine/datasets_standard_format.py +5 -0
  25. mindspore/dataset/vision/transforms.py +21 -21
  26. mindspore/experimental/optim/adam.py +1 -1
  27. mindspore/gen_ops.py +1 -1
  28. mindspore/include/api/model.h +17 -0
  29. mindspore/include/api/status.h +8 -3
  30. mindspore/mindspore_backend.dll +0 -0
  31. mindspore/mindspore_common.dll +0 -0
  32. mindspore/mindspore_core.dll +0 -0
  33. mindspore/mindspore_shared_lib.dll +0 -0
  34. mindspore/nn/cell.py +0 -3
  35. mindspore/nn/layer/activation.py +4 -5
  36. mindspore/nn/layer/conv.py +39 -23
  37. mindspore/nn/layer/flash_attention.py +54 -129
  38. mindspore/nn/layer/math.py +3 -7
  39. mindspore/nn/layer/rnn_cells.py +5 -5
  40. mindspore/nn/wrap/__init__.py +4 -2
  41. mindspore/nn/wrap/cell_wrapper.py +12 -3
  42. mindspore/numpy/utils_const.py +5 -5
  43. mindspore/opencv_core452.dll +0 -0
  44. mindspore/opencv_imgcodecs452.dll +0 -0
  45. mindspore/opencv_imgproc452.dll +0 -0
  46. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
  47. mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
  48. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
  49. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  50. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  51. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  52. mindspore/ops/_utils/utils.py +2 -0
  53. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
  54. mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
  55. mindspore/ops/function/array_func.py +10 -7
  56. mindspore/ops/function/grad/grad_func.py +0 -1
  57. mindspore/ops/function/nn_func.py +98 -9
  58. mindspore/ops/function/random_func.py +2 -1
  59. mindspore/ops/op_info_register.py +24 -21
  60. mindspore/ops/operations/__init__.py +6 -2
  61. mindspore/ops/operations/_grad_ops.py +25 -6
  62. mindspore/ops/operations/_inner_ops.py +155 -23
  63. mindspore/ops/operations/array_ops.py +9 -7
  64. mindspore/ops/operations/comm_ops.py +2 -2
  65. mindspore/ops/operations/custom_ops.py +85 -68
  66. mindspore/ops/operations/inner_ops.py +26 -3
  67. mindspore/ops/operations/math_ops.py +7 -6
  68. mindspore/ops/operations/nn_ops.py +193 -49
  69. mindspore/parallel/_parallel_serialization.py +10 -3
  70. mindspore/parallel/_tensor.py +4 -1
  71. mindspore/parallel/checkpoint_transform.py +13 -2
  72. mindspore/parallel/shard.py +17 -10
  73. mindspore/profiler/common/util.py +1 -0
  74. mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
  75. mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
  76. mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
  77. mindspore/profiler/parser/ascend_op_generator.py +1 -1
  78. mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
  79. mindspore/profiler/parser/base_timeline_generator.py +1 -1
  80. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
  81. mindspore/profiler/parser/framework_parser.py +1 -1
  82. mindspore/profiler/parser/profiler_info.py +19 -0
  83. mindspore/profiler/profiling.py +46 -24
  84. mindspore/rewrite/api/pattern_engine.py +1 -1
  85. mindspore/rewrite/parsers/for_parser.py +7 -7
  86. mindspore/rewrite/parsers/module_parser.py +4 -4
  87. mindspore/rewrite/symbol_tree.py +1 -4
  88. mindspore/run_check/_check_version.py +5 -3
  89. mindspore/safeguard/rewrite_obfuscation.py +52 -28
  90. mindspore/train/callback/_summary_collector.py +1 -1
  91. mindspore/train/dataset_helper.py +1 -0
  92. mindspore/train/model.py +2 -2
  93. mindspore/train/serialization.py +97 -11
  94. mindspore/train/summary/_summary_adapter.py +1 -1
  95. mindspore/train/summary/summary_record.py +23 -7
  96. mindspore/version.py +1 -1
  97. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +3 -2
  98. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +101 -112
  99. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  100. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
  101. mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
  102. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
  103. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
  104. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
  105. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
  106. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  107. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  108. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  109. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  110. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  111. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
  112. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -36,13 +36,16 @@ if platform.system() == "Linux":
 BUILT_IN_OPS_REGISTER_PATH = "mindspore/ops/_op_impl"
 BUILT_IN_CUSTOM_OPS_REGISTER_PATH = "mindspore/ops/_op_impl/_custom_op"
 
+KEY_NAME = "name"
+ASCEND_CUSTOM_OPP_PATH = "ASCEND_CUSTOM_OPP_PATH"
 
-def _get_reg_info_attr(op_info, attr_name):
+
+def _get_reg_info_attr(op_info, attr_name, default_value=None):
     """get attr value"""
     for _, item in enumerate(op_info.get("attr", [])):
-        if item.get("name") == attr_name:
+        if item.get(KEY_NAME) == attr_name:
             return item.get("defaultValue")
-    return None
+    return default_value
 
 
 class _CustomInstaller:
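For readers skimming the diff, here is a quick standalone illustration of the behavioral change in `_get_reg_info_attr`: a missing attribute now yields a caller-supplied default instead of always `None`. The `op_info` dict below is made up for illustration only; only the helper itself is taken from the diff.

```python
KEY_NAME = "name"

def _get_reg_info_attr(op_info, attr_name, default_value=None):
    """get attr value"""
    for _, item in enumerate(op_info.get("attr", [])):
        if item.get(KEY_NAME) == attr_name:
            return item.get("defaultValue")
    return default_value

# Illustrative op_info registration dict (not from the package).
op_info = {"attr": [{"name": "cust_aicpu", "defaultValue": "libcust.so"}]}
print(_get_reg_info_attr(op_info, "cust_aicpu"))           # 'libcust.so'
print(_get_reg_info_attr(op_info, "missing"))              # None (old behavior)
print(_get_reg_info_attr(op_info, "missing", "fallback"))  # 'fallback' (new default_value path)
```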
@@ -66,12 +69,12 @@ class _CustomInstaller:
     @staticmethod
     def _set_env(custom_opp_path):
         """set custom file path to env"""
-        if not os.environ.get("ASCEND_CUSTOM_OPP_PATH"):
-            os.environ["ASCEND_CUSTOM_OPP_PATH"] = custom_opp_path
+        if not os.environ.get(ASCEND_CUSTOM_OPP_PATH):
+            os.environ[ASCEND_CUSTOM_OPP_PATH] = custom_opp_path
         else:
-            paths = os.environ["ASCEND_CUSTOM_OPP_PATH"].split(':')
+            paths = os.environ[ASCEND_CUSTOM_OPP_PATH].split(':')
             if custom_opp_path not in paths:
-                os.environ["ASCEND_CUSTOM_OPP_PATH"] = custom_opp_path + ':' + os.environ["ASCEND_CUSTOM_OPP_PATH"]
+                os.environ[ASCEND_CUSTOM_OPP_PATH] = custom_opp_path + ':' + os.environ[ASCEND_CUSTOM_OPP_PATH]
 
     @staticmethod
     def _create_dir(*dir_names):
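The `_set_env` change is purely a string-constant refactor; the prepend-if-absent logic is unchanged. A small self-contained sketch of that logic, with illustrative paths (the function body mirrors the diff, the call sequence is only an example):

```python
import os

ASCEND_CUSTOM_OPP_PATH = "ASCEND_CUSTOM_OPP_PATH"

def set_env(custom_opp_path):
    """Prepend custom_opp_path to the env var unless it is already listed."""
    if not os.environ.get(ASCEND_CUSTOM_OPP_PATH):
        os.environ[ASCEND_CUSTOM_OPP_PATH] = custom_opp_path
    else:
        paths = os.environ[ASCEND_CUSTOM_OPP_PATH].split(':')
        if custom_opp_path not in paths:
            os.environ[ASCEND_CUSTOM_OPP_PATH] = custom_opp_path + ':' + os.environ[ASCEND_CUSTOM_OPP_PATH]

set_env("/tmp/opp_a")
set_env("/tmp/opp_b")
set_env("/tmp/opp_a")  # already present, so not duplicated
print(os.environ[ASCEND_CUSTOM_OPP_PATH])  # /tmp/opp_b:/tmp/opp_a
```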
@@ -94,11 +97,11 @@ class _CustomInstaller:
         _CustomInstaller.copied_paths.append(src_path)
         if os.path.isfile(src_path):
             lock_file = os.path.join(dst_dir, "file.lock")
-            with open(lock_file, "w") as f:
+            with os.fdopen(os.open(lock_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), 'w') as f:
                 fcntl.flock(f.fileno(), fcntl.LOCK_EX)
                 shutil.copy(src_path, dst_dir)
 
-    def _check(self):
+    def check(self):
         """check if the reg info need written"""
         if platform.system() != "Linux":
             return False
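The `open(lock_file, "w")` → `os.fdopen(os.open(...), 'w')` change here (and in the repo-writing hunk further down) caps the lock file's permissions at owner read/write when the file is created, instead of the looser default mode. A hedged, Linux-only sketch of that pattern outside the installer; the path and payload are illustrative:

```python
import fcntl
import os

lock_file = "/tmp/example_file.lock"  # illustrative path

# O_CREAT | O_TRUNC with mode 0o600: a newly created lock file gets at most
# owner read/write permissions (the process umask may restrict it further),
# rather than the wider default that plain open() would request.
with os.fdopen(os.open(lock_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), 'w') as f:
    fcntl.flock(f.fileno(), fcntl.LOCK_EX)  # exclusive lock for the critical section
    f.write("guarded work goes here\n")
# The lock is released when the file object is closed.
```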
@@ -153,12 +156,12 @@ class _CustomInstaller:
         # attr
         attrs_name = []
         for _, item in enumerate(self.op_info.get("attr", [])):
-            attr_name = item.get("name")
+            attr_name = item.get(KEY_NAME)
             attrs_name.append(attr_name)
             key = "attr_" + attr_name
             op_info[key] = {}
             for k, v in item.items():
-                if k != "name":
+                if k != KEY_NAME:
                     op_info[key][k] = v
         if attrs_name:
             op_info["attr"] = {"list": ",".join(attrs_name)}
@@ -171,7 +174,7 @@ class _CustomInstaller:
             item = inputs[i] if i < input_num else outputs[i - input_num]
             key = "input" if i < input_num else "output"
             key += str(item.get("index"))
-            op_info[key] = {"name": item.get("name"),
+            op_info[key] = {KEY_NAME: item.get(KEY_NAME),
                             "paramType": item.get("paramType", "required"),
                             "shape": item.get("shape", "all")}
             dtype, formats = _get_dtype_format(i)
@@ -181,7 +184,8 @@ class _CustomInstaller:
                 op_info[key]["format"] = ",".join(formats)
         return op_info
 
-    def _gen_ai_cpu_reg_info(self, so_file):
+    @staticmethod
+    def _gen_ai_cpu_reg_info(so_file):
         """generate reg info"""
         op_info = {"opInfo": {"computeCost": "100",
                               "engine": "DNN_VM_AICPU",
@@ -198,7 +202,7 @@ class _CustomInstaller:
         repo = {}
         save_path = os.path.join(dst_dir, file_name)
         lock_file = os.path.join(dst_dir, "file.lock")
-        with open(lock_file, "w") as f:
+        with os.fdopen(os.open(lock_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), 'w') as f:
             fcntl.flock(f.fileno(), fcntl.LOCK_EX)
             if os.path.isfile(save_path):
                 with open(save_path, 'r') as fr:
@@ -211,7 +215,7 @@ class _CustomInstaller:
 
     def run(self):
         """save reg info to file"""
-        if not self._check():
+        if not self.check():
             return
         so_name = _get_reg_info_attr(self.op_info, "cust_aicpu")
         if so_name:
@@ -380,7 +384,6 @@ class RegOp:
         """
         if not isinstance(value, str):
             raise TypeError("%s value must be str" % str(value))
-        return True
 
     def _is_int(self, value):
         """
@@ -394,7 +397,6 @@ class RegOp:
         """
        if not isinstance(value, int):
            raise TypeError("%s value must be int" % str(value))
-        return True
 
     def _is_bool(self, value):
         """
@@ -408,7 +410,6 @@ class RegOp:
         """
         if not isinstance(value, bool):
             raise TypeError("%s value must be bool" % str(value))
-        return True
 
     @staticmethod
     def _is_list(value):
@@ -423,7 +424,6 @@ class RegOp:
         """
         if not isinstance(value, list):
             raise TypeError("%s value must be list" % str(value))
-        return True
 
     def _check_param(self, param_list, key_list, fn_list, kwargs):
         """
@@ -491,7 +491,9 @@ class RegOp:
             self._is_string(arg[1])
             if len(arg) == 3:
                 self._is_string(arg[2])
-            dtype_format.append(arg)
+                dtype_format.append(arg)
+            else:
+                dtype_format.append(arg)
         self.dtype_format_.append(tuple(dtype_format))
         return self
 
@@ -920,7 +922,8 @@ class TBERegOp(RegOp):
         Args:
             pattern (str): Value of op pattern, e.g. "broadcast", "reduce". Default: ``None`` .
         """
-        if pattern is not None and self._is_string(pattern):
+        if pattern is not None:
+            self._is_string(pattern)
             self.op_pattern_ = pattern
         return self
 
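A note on why the `op_pattern` condition had to be split once the `RegOp` validators above stopped returning `True`: a validator that only raises on failure (and implicitly returns `None`) is always falsy inside an `and`, so the old compound check would silently skip the assignment. A minimal sketch of that failure mode; the names below are illustrative stand-ins, not the library code:

```python
def is_string(value):
    """Raise on non-str; implicitly returns None, like the updated validators."""
    if not isinstance(value, str):
        raise TypeError("%s value must be str" % str(value))

pattern = "broadcast"
op_pattern = None

if pattern is not None and is_string(pattern):  # old style: is_string(...) is None -> falsy
    op_pattern = pattern
print(op_pattern)                               # None: the attribute was never set

if pattern is not None:                         # new style: validate first, then assign
    is_string(pattern)
    op_pattern = pattern
print(op_pattern)                               # broadcast
```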
@@ -118,7 +118,7 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa
                      Dilation2D, DataFormatVecPermute, DeformableOffsets, Dense, FractionalAvgPool,
                      FractionalMaxPool, FractionalMaxPool3DWithFixedKsize, FractionalMaxPoolWithFixedKsize,
                      GridSampler2D, TripletMarginLoss, UpsampleNearest3D, UpsampleTrilinear3D, PadV3, ChannelShuffle,
-                     GLU, MaxUnpool3D, Pdist)
+                     GLU, MaxUnpool3D, Pdist, RmsNorm, PagedAttention, PagedAttentionMask, ReshapeAndCache)
 from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
                         ConfusionMatrix, UpdateState, Load, StopGradient,
                         CheckValid, Partial, Depend, Push, Pull, PyExecute, PyFunc, _DynamicLossScale,
@@ -691,7 +691,11 @@ __all__ = [
     "IndexPut",
     "MaskedScatter",
     "Ormqr",
-    "RandpermV2"
+    "RandpermV2",
+    "RmsNorm",
+    "PagedAttention",
+    "PagedAttentionMask",
+    "ReshapeAndCache"
 ]
 
 __custom__ = [
@@ -3845,7 +3845,7 @@ class FlashAttentionScoreGrad(Primitive):
     """
     @prim_attr_register
     def __init__(self, head_num, keep_prob=1.0, scale_value=1.0, pre_tokens=65536, next_tokens=65536, inner_precise=1,
-                 input_layout='BSH'):
+                 input_layout='BSH', sparse_mode=0):
         """Initialize FlashAttentionScoreGrad."""
         validator.check_value_type('head_num', head_num, [int], self.name)
         validator.check_value_type('keep_prob', keep_prob, [int, float], self.name)
@@ -3855,11 +3855,30 @@ class FlashAttentionScoreGrad(Primitive):
         validator.check_value_type('pre_tokens', pre_tokens, [int], self.name)
         validator.check_value_type('next_tokens', next_tokens, [int], self.name)
         validator.check_value_type('inner_precise', inner_precise, [int], self.name)
+        validator.check_value_type('sparse_mode', sparse_mode, [int], self.name)
         if inner_precise not in [0, 1]:
             raise ValueError(f"Attribute 'inner_precise' must be either 0 or 1, but got {inner_precise}")
         validator.check_value_type('input_layout', input_layout, [str], self.name)
-        if input_layout not in ["BSH"]:
-            raise ValueError(f"Attribute 'input_layout' must be either 'bsh' or 'sbh', but got {input_layout}")
-        self.init_prim_io_names(inputs=['query', 'key', 'value', 'attn_mask', 'attention_in', 'softmax_max',
-                                        'softmax_sum', 'dy', 'drop_mask', 'real_shift', "padding_mask", 'softmax_out'],
-                                outputs=['dq', 'dk', 'dv'])
+        if input_layout not in ["BSH", "BNSD"]:
+            raise ValueError(f"Attribute 'input_layout' must be either 'BSH' or 'BNSD', but got {input_layout}")
+        self.init_prim_io_names(inputs=['query', 'key', 'value', 'dy', 'pse_shift', 'drop_mask', "padding_mask",
+                                        'attn_mask', 'softmax_max', 'softmax_sum', 'softmax_out', 'attention_in',
+                                        'prefix'],
+                                outputs=['dq', 'dk', 'dv', 'dpse'])
+
+
+class RmsNormGrad(Primitive):
+    r"""
+    Calculates the gradient of RmsNorm operation.
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Supported Platforms:
+        ``Ascend``
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize RmsNormGrad."""
+        self.init_prim_io_names(inputs=["dy", "x", "rstd", "gamma"],
+                                outputs=["dx", "dgamma"])
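For orientation, the new `RmsNormGrad` primitive consumes `dy`, `x`, `rstd` (the reciprocal root-mean-square) and `gamma`. Below is a hedged numpy reference of the standard RMSNorm forward pass, assuming the usual definition; the epsilon value and shapes are illustrative and not read from the kernel:

```python
import numpy as np

def rms_norm_reference(x, gamma, eps=1e-6):
    """Standard RMSNorm: x / sqrt(mean(x**2) + eps) * gamma, normalized over the last axis."""
    rstd = 1.0 / np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
    return x * rstd * gamma, rstd  # rstd is the quantity the backward primitive consumes

x = np.random.randn(2, 8).astype(np.float32)
gamma = np.ones(8, dtype=np.float32)
y, rstd = rms_norm_reference(x, gamma)
print(y.shape, rstd.shape)  # (2, 8) (2, 1)
```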
@@ -26,7 +26,7 @@ from mindspore.ops.operations._scalar_ops import bit_or, bit_and
 from mindspore.ops.operations.comm_ops import ReduceOp
 from mindspore.ops import signature as sig
 from mindspore.ops.operations.math_ops import _infer_shape_reduce
-from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive,\
+from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, \
     _run_op, _check_contains_variable
 from mindspore._c_expression import Tensor as Tensor_
 from mindspore._c_expression import typing
@@ -167,6 +167,7 @@ class Quant(PrimitiveWithInfer):
         self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
         self.round_mode = validator.check_string(round_mode, ["Round", "Floor", "Ceil", "Trunc"],
                                                  "round_mode", self.name)
+        self.add_prim_attr("dst_type", mstype.int8)
 
     def infer_shape(self, x_shape):
         return x_shape
@@ -174,7 +175,7 @@ class Quant(PrimitiveWithInfer):
     def infer_dtype(self, x_type):
         validator.check_subclass("input_x", x_type, mstype.tensor_type, self.name)
         validator.check_type_name("input_x", x_type, [mstype.float16, mstype.float32], self.name)
-        return mstype.int8
+        return self.get_attr_dict()['dst_type']
 
 
 class Lamb(PrimitiveWithInfer):
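The `Quant` change swaps a hard-coded output dtype for a `dst_type` attribute that is set in the constructor and read back in `infer_dtype`, so a pass that rewrites the attribute also changes the inferred type. A framework-free sketch of that pattern; the class and attribute store below are illustrative, not MindSpore internals:

```python
class AttrDrivenQuant:
    """Toy stand-in: the output dtype comes from a mutable attribute, not a literal."""

    def __init__(self):
        self.attrs = {"dst_type": "int8"}  # analogous to add_prim_attr("dst_type", mstype.int8)

    def infer_dtype(self, x_type):
        return self.attrs["dst_type"]      # analogous to get_attr_dict()['dst_type']

op = AttrDrivenQuant()
print(op.infer_dtype("float16"))           # int8 (the default)
op.attrs["dst_type"] = "int4"              # a later pass could rewrite the attribute
print(op.infer_dtype("float16"))           # int4
```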
@@ -491,7 +492,7 @@ class Receive(PrimitiveWithInfer):
         self.dtype = dtype
         self.group = group
         self.add_prim_attr("no_eliminate", True)
-        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
+        valid_type = [mstype.float16, mstype.bfloat16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
         args = {"dtype": dtype}
         validator.check_scalar_or_tensor_types_same(args, valid_type, self.name)
 
@@ -2146,13 +2147,14 @@ class ClipByNorm(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, axis=None):
         """Initialize ClipByNorm"""
+        self.axis_str = 'axis'
         self.axis = () if axis is None else axis
-        validator.check_value_type('axis', self.axis, [int, tuple, list], self.name)
+        validator.check_value_type(self.axis_str, self.axis, [int, tuple, list], self.name)
         axis_check = self.axis if isinstance(self.axis, Iterable) else (self.axis,)
         for i, value in enumerate(axis_check):
             validator.check_value_type('axis[%d]' % i, value, [int], self.name)
-        self.init_attrs['axis'] = self.axis
-        self.add_prim_attr('axis', self.axis)
+        self.init_attrs[self.axis_str] = self.axis
+        self.add_prim_attr(self.axis_str, self.axis)
         self.init_prim_io_names(inputs=['x', 'clip_norm'], outputs=['output'])
 
     def infer_shape(self, x_shape, clip_norm_shape):
@@ -2729,27 +2731,29 @@ class CopyWithSlice(Primitive):
         self.init_prim_io_names(inputs=['x', 'y'], outputs=['x'])
 
 
-class MoeFFN(Primitive):
+class FFN(Primitive):
     r"""
-    The MoeFFN computation is similar to Feed-Forward Network, it contains matmul + gelu + matmul.
+    The FFN computation is similar to Feed-Forward Network, it contains matmul + gelu + matmul.
 
     Args:
         activation (string): The activation type, set to 'fastgelu' or 'gelu'.
-            Only support 'fastgelu' for now. Default: "fastgelu".
+            Only support 'fastgelu' for now. Default: "fastgelu".
+        inner_precise (int): The precise mode, set to 0 for high precision or 1 for high performance.
+            Only support 1 for now. Default: 0.
 
     Inputs:
         - **x** (Tensor) - The input tensor with data type of int8, float16.
          Input tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`.
+        - **weight1** (Tensor) - The weight1 tensor with data type of float16.
+          Weight1 tensor of shape :math:`(expert\_num, hidden\_size, ffn\_hidden\_size)`.
+        - **weight2** (Tensor) - The weight2 tensor with data type of float16.
+          Weight2 tensor of shape :math:`(expert\_num, ffn\_hidden\_size, hidden\_size)`.
         - **expert_tokens** (Tensor]) - The expert tokens tensor with data type of int64.
          Expert tokens tensor of shape :math:`(16,)`. For example, `(2, 1, 0, .., 9)`
          indicate that the 0th expert deals with 2 tokens, the 1th expert deals with 1 tokens,
          the 2th expert do noting and so on.
-        - **weight1** (Tensor) - The weight1 tensor with data type of float16.
-          Weight1 tensor of shape :math:`(expert\_num, hidden\_size, ffn\_hidden\_size)`.
         - **bias1** (Tensor) - The bias1 tensor with data type of float16.
          Bias1 tensor of shape :math:`(expert\_num, ffn\_hidden\_size)`.
-        - **weight2** (Tensor) - The weight2 tensor with data type of float16.
-          Weight2 tensor of shape :math:`(expert\_num, ffn\_hidden\_size, hidden\_size)`.
         - **bias2** (Tensor) - The bias2 tensor with data type of float16.
          Bias2 tensor of shape :math:`(expert\_num, hidden\_size)`.
         - **scale** (Tensor) - The scale tensor with data type of float16. Not enable now.
@@ -2771,21 +2775,149 @@ class MoeFFN(Primitive):
         >>> h_f = 4 * h
         >>> e = 16
         >>> x = Tensor(np.random.randn(b * s, h).astype(np.float16))
-        >>> expert_tokens = Tensor(np.random.randn(e).astype(np.int64))
         >>> w1 = Tensor(np.random.randn(e, h, h_f).astype(np.float16))
-        >>> bias1 = Tensor(np.random.randn(e, h_f).astype(np.float16))
         >>> w2 = Tensor(np.random.randn(e, h_f, h).astype(np.float16))
+        >>> expert_tokens = Tensor(np.random.randn(e).astype(np.int64))
+        >>> bias1 = Tensor(np.random.randn(e, h_f).astype(np.float16))
         >>> bias2 = Tensor(np.random.randn(e, h).astype(np.float16))
-        >>> moe_ffn = _inner_ops.MoeFFN("fastgelu")
-        >>> output = moe_ffn(x, w1, bias1, w2, bias2)
+        >>> ffn = _inner_ops.FFN("fastgelu", 1)
+        >>> output = ffn(x, w1, w2, expert_tokens, bias1, bias2)
         >>> print(output)
     """
 
     @prim_attr_register
-    def __init__(self, activation):
-        """Initialize MoeFFN."""
-        self.init_prim_io_names(inputs=["x", "expert_tokens", "weight1", "bias1",
-                                        "weight2", "bias2", "scale", "offset", "deq_scale1"
-                                        "deq_scale2"],
+    def __init__(self, activation, inner_precise):
+        """Initialize FFN."""
+        self.init_prim_io_names(inputs=["x", "weight1", "weight2", "expert_tokens", "bias1",
+                                        "bias2", "scale", "offset", "deq_scale1", "deq_scale2"],
                                 outputs=["y"])
-        self.activation = activation
+        cls_name = self.name
+        validator.check_value_type("activation", activation, [str], cls_name)
+        validator.check_value_type("inner_precise", inner_precise, [int], cls_name)
+
+
+class DecoderKVCache(Primitive):
+    r"""
+    The DecoderKVCache is used for decoding the KVCache of transformer network.
+
+    Args:
+        cache (Tensor): The cahe tensor with data type of int8, uint8, int16, uint16, float16, float32 and int32.
+            When seq_len_axis is 2, cache tensor of shape
+            :math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)`.
+            When seq_len_axis is 1, cache tensor of shape
+            :math:`(batch\_size, max\_seq\_length, num_head, hidden\_size)`.
+        update (Tensor]): The tensor which is used to update the cache tensor. Same data type as cache tensor.
+            When seq_len_axis is 2, update tensor of shape
+            :math:`(batch\_size, num_head, update\_seq\_length, hidden\_size)`.
+            When seq_len_axis is 1, update tensor of shape
+            :math:`(batch\_size, update\_seq\_length, num_head, hidden\_size)`.
+        valid_seq_len (Tensor): The valid_seq_len tensor with data type of int64.
+            Valid_seq_len tensor of shape :math:`(batch\_size)`.
+        batch_index (Tensor): The batch_index tensor with data type of int64.
+            Batch_index tensor of shape :math:`(1)`. Indicate that which batch of cache tensor is going to be update.
+        seq_len_axis (int64): The seq_len_axis indicate which axis is seq_eln, set to '1' or '2'. Default: "2".
+        new_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
+            New_max_seq_len tensor of shape :math:`(1)`.
+            Indicate that user want to change the shape of cache tensor from
+            :math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)` to
+            :math:
+            `(batch\_size * max\_seq\_length / new\_max\_seq\_length, num_head, new\_max\_seq\_length, hidden\_size)`
+            to update the cache tensor. This will not real change the shape of `cache` tensor. Not able for now.
+        cur_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
+            Cur_max_seq_len tensor of shape :math:`(1)`. Keep the current seq_len of cache tensor. Not abel for now.
+
+    Outputs:
+        With same data type and same shape as `cache` tensor.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore.ops.operations import _inner_ops
+        >>> b = 4
+        >>> h = 40
+        >>> max_s = 1024
+        >>> s = 1
+        >>> d = 128
+        >>> cache = Tensor(np.random.randn(b, h, max_s, d).astype(np.float16))
+        >>> update = Tensor(np.random.randn(b, h, s, d).astype(np.float16))
+        >>> valid_seq_len = Tensor(np.random.randn(b).astype(np.int64))
+        >>> batch_index = Tensor(np.random.randn(1).astype(np.int64))
+        >>> new_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
+        >>> cur_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
+        >>> decoder_kv_cache = _inner_ops.DecoderKVCache()
+        >>> output = decoder_kv_cache(cache, update, valid_seq_len, batch_index, 2, new_max_seq_len, cur_max_seq_len)
+        >>> print(cache)
+    """
+    @prim_attr_register
+    def __init__(self):
+        """Initialize DecoderKVCache."""
+        self.init_prim_io_names(inputs=["cache", "update", "valid_seq_len", "batch_index", "seq_len_axis",
+                                        "new_max_seq_len", "cur_max_seq_len"],
+                                outputs=["out"])
+        self.add_prim_attr('side_effect_mem', True)
+
+
+class PromptKVCache(Primitive):
+    r"""
+    The PromptKVCache is used for prefill the KVCache of transformer network.
+
+    Args:
+        cache (Tensor): The cahe tensor with data type of int8, uint8, int16, uint16, float16, float32 and int32.
+            When seq_len_axis is 2, cache tensor of shape
+            :math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)`.
+            When seq_len_axis is 1, cache tensor of shape
+            :math:`(batch\_size, max\_seq\_length, num_head, hidden\_size)`.
+        update (Tensor]): The tensor which is used to update the cache tensor. Same data type as cache tensor.
+            When seq_len_axis is 2, update tensor of shape
+            :math:`(batch\_size, num_head, update\_seq\_length, hidden\_size)`.
+            When seq_len_axis is 1, update tensor of shape
+            :math:`(batch\_size, update\_seq\_length, num_head, hidden\_size)`.
+        valid_seq_len (Tensor): The valid_seq_len tensor with data type of int64.
+            Valid_seq_len tensor of shape :math:`(batch\_size)`.
+        batch_index (Tensor): The batch_index tensor with data type of int64.
+            Batch_index tensor of shape :math:`(1)`. Indicate that which batch of cache tensor is going to be update.
+        seq_len_axis (int64): The seq_len_axis indicate which axis is seq_eln, set to '1' or '2'. Default: "2".
+        new_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
+            New_max_seq_len tensor of shape :math:`(1)`.
+            Indicate that user want to change the shape of cache tensor from
+            :math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)` to
+            :math:
+            `(batch\_size * max\_seq\_length / new\_max\_seq\_length, num_head, new\_max\_seq\_length, hidden\_size)`
+            to update the cache tensor. This will not real change the shape of `cache` tensor. Not able for now.
+        cur_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
+            Cur_max_seq_len tensor of shape :math:`(1)`. Keep the current seq_len of cache tensor. Not abel for now.
+        align_mode (int64): indicate which axis is seq_eln, 0 is 'right', 1 is 'left'. Default: 0.
+
+    Outputs:
+        With same data type and same shape as `cache` tensor.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops.operations import _inner_ops
+        >>> b = 4
+        >>> h = 40
+        >>> max_s = 1024
+        >>> s = 256
+        >>> d = 128
+        >>> cache = Tensor(np.random.randn(b, h, max_s, d).astype(np.float16))
+        >>> update = Tensor(np.random.randn(b, h, s, d).astype(np.float16))
+        >>> valid_seq_len = Tensor(np.random.randn(b).astype(np.int64))
+        >>> batch_index = Tensor(np.random.randn(1).astype(np.int64))
+        >>> new_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
+        >>> cur_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
+        >>> prompt_kv_cache = _inner_ops.PromptKVCache(0)
+        >>> output = prompt_kv_cache(cache, update, valid_seq_len, batch_index, 2, new_max_seq_len, cur_max_seq_len)
+        >>> print(cache)
+    """
+    @prim_attr_register
+    def __init__(self, padding_mode="right"):
+        """Initialize PromptKVCache."""
+        self.init_prim_io_names(inputs=["cache", "update", "valid_seq_len", "batch_index", "seq_len_axis",
+                                        "new_max_seq_len", "cur_max_seq_len"],
+                                outputs=["out"])
+        self.add_prim_attr('side_effect_mem', True)
+        self.padding_mode = padding_mode
@@ -1208,7 +1208,7 @@ class UniqueWithPad(Primitive):
 
 
 class Split(Primitive):
-    """
+    r"""
     Splits the input tensor into output_num of tensors along the given axis and output numbers.
 
     Refer to :func:`mindspore.ops.split` for more details.
@@ -1222,7 +1222,7 @@ class Split(Primitive):
 
     Outputs:
         tuple[Tensor], the shape of each output tensor is the same, which is
-        :math:`(x_0, x_1, ..., x_{axis}/{output_num}, ..., x_{R-1})`.
+        :math:`(x_0, x_1, ..., x_{axis}/{output\_num}, ..., x_{R-1})`.
         And the data type is the same as `input_x`.
 
     Supported Platforms:
@@ -1763,16 +1763,18 @@ class FillV2(PrimitiveWithCheck):
         self.init_prim_io_names(inputs=['shape', 'value'], outputs=['y'])
 
     def check_elim(self, dims, x):
-        if x is None or (not isinstance(x, (Tensor, Tensor_))) or (x.shape != ()) or\
-                dims is None or (isinstance(dims, (tuple, list)) and dims) or\
-                isinstance(dims, (Tensor, Tensor_)):
+        x_is_invalid = x is None or (not isinstance(x, (Tensor, Tensor_))) or (x.shape != ())
+        dims_is_invalid = dims is None or (isinstance(dims, (tuple, list)) and dims) or\
+            isinstance(dims, (Tensor, Tensor_))
+        if x_is_invalid or dims_is_invalid:
             return (False, None)
         return (True, x)
 
     def infer_value(self, dims, x):
-        if x is None or dims is None or\
+        dims_is_invalid = dims is None or\
                 (isinstance(dims, (tuple, list)) and dims) or\
-                isinstance(dims, (Tensor, Tensor_)):
+                isinstance(dims, (Tensor, Tensor_))
+        if x is None or dims_is_invalid:
             return None
         return x
 
@@ -94,7 +94,7 @@ class ReduceOp:
 
 def check_collective_target_dtype(data_name, data_dtype, prim_name):
     """Check if data type is valid."""
-    default_target_dtypes = (mstype.int8, mstype.int32, mstype.float16, mstype.float32)
+    default_target_dtypes = (mstype.int8, mstype.uint8, mstype.int32, mstype.float16, mstype.bfloat16, mstype.float32)
     gpu_target_dtypes = (mstype.bool_, mstype.int8, mstype.int32, mstype.int64, mstype.uint32, mstype.uint64,
                          mstype.float16, mstype.float32, mstype.float64)
 
@@ -1310,4 +1310,4 @@ class _GetTensorSlice(PrimitiveWithInfer):
         from mindspore.parallel._tensor import _load_tensor
         validator.check_value_type("dev_mat", dev_mat, [tuple], self.name)
         validator.check_value_type("tensor_map", tensor_map, [tuple], self.name)
-        return Tensor(_load_tensor(x, dev_mat, tensor_map))
+        return Tensor(_load_tensor(x, dev_mat, tensor_map), x.dtype)
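The final change passes the source tensor's dtype explicitly when wrapping the sliced data, so the slice keeps the original dtype rather than whatever the intermediate conversion produced. A small hedged illustration of that difference, assuming `mindspore.Tensor` accepts an explicit dtype argument; the numpy array stands in for the `_load_tensor` result:

```python
import numpy as np
import mindspore as ms

np_slice = np.ones((2, 2), dtype=np.float32)  # stand-in for the _load_tensor output
t_inferred = ms.Tensor(np_slice)              # dtype taken from numpy: Float32
t_pinned = ms.Tensor(np_slice, ms.float16)    # dtype pinned to the source tensor's dtype
print(t_inferred.dtype, t_pinned.dtype)       # Float32 Float16
```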