mindspore 2.2.0-cp39-cp39-win_amd64.whl → 2.2.10-cp39-cp39-win_amd64.whl

This diff compares the contents of publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (122)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  7. mindspore/_checkparam.py +3 -3
  8. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  9. mindspore/_extends/graph_kernel/splitter.py +3 -2
  10. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
  11. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
  12. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  13. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
  14. mindspore/_extends/parse/standard_method.py +2 -9
  15. mindspore/_extends/remote/kernel_build_server.py +2 -1
  16. mindspore/atlprov.dll +0 -0
  17. mindspore/c1.dll +0 -0
  18. mindspore/c1xx.dll +0 -0
  19. mindspore/c2.dll +0 -0
  20. mindspore/common/api.py +1 -1
  21. mindspore/common/auto_dynamic_shape.py +81 -85
  22. mindspore/common/dump.py +1 -1
  23. mindspore/common/tensor.py +3 -20
  24. mindspore/config/op_info.config +1 -1
  25. mindspore/context.py +11 -4
  26. mindspore/dataset/engine/datasets_standard_format.py +5 -0
  27. mindspore/dataset/vision/transforms.py +21 -21
  28. mindspore/dnnl.dll +0 -0
  29. mindspore/dpcmi.dll +0 -0
  30. mindspore/experimental/optim/adam.py +1 -1
  31. mindspore/gen_ops.py +1 -1
  32. mindspore/include/api/model.h +17 -0
  33. mindspore/include/api/status.h +8 -3
  34. mindspore/jpeg62.dll +0 -0
  35. mindspore/mindspore_backend.dll +0 -0
  36. mindspore/mindspore_common.dll +0 -0
  37. mindspore/mindspore_core.dll +0 -0
  38. mindspore/mindspore_glog.dll +0 -0
  39. mindspore/mindspore_shared_lib.dll +0 -0
  40. mindspore/msobj140.dll +0 -0
  41. mindspore/mspdb140.dll +0 -0
  42. mindspore/mspdbcore.dll +0 -0
  43. mindspore/mspdbst.dll +0 -0
  44. mindspore/mspft140.dll +0 -0
  45. mindspore/msvcdis140.dll +0 -0
  46. mindspore/msvcp140_1.dll +0 -0
  47. mindspore/msvcp140_2.dll +0 -0
  48. mindspore/msvcp140_atomic_wait.dll +0 -0
  49. mindspore/msvcp140_codecvt_ids.dll +0 -0
  50. mindspore/nn/cell.py +0 -3
  51. mindspore/nn/layer/activation.py +4 -5
  52. mindspore/nn/layer/conv.py +39 -23
  53. mindspore/nn/layer/flash_attention.py +90 -78
  54. mindspore/nn/layer/math.py +3 -7
  55. mindspore/nn/layer/rnn_cells.py +5 -5
  56. mindspore/nn/wrap/cell_wrapper.py +6 -0
  57. mindspore/numpy/utils_const.py +5 -5
  58. mindspore/opencv_core452.dll +0 -0
  59. mindspore/opencv_imgcodecs452.dll +0 -0
  60. mindspore/opencv_imgproc452.dll +0 -0
  61. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
  62. mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
  63. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
  64. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  65. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  66. mindspore/ops/_utils/utils.py +2 -0
  67. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
  68. mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
  69. mindspore/ops/function/array_func.py +10 -7
  70. mindspore/ops/function/grad/grad_func.py +0 -1
  71. mindspore/ops/function/nn_func.py +98 -9
  72. mindspore/ops/function/random_func.py +2 -1
  73. mindspore/ops/op_info_register.py +24 -21
  74. mindspore/ops/operations/__init__.py +3 -2
  75. mindspore/ops/operations/_grad_ops.py +24 -4
  76. mindspore/ops/operations/_inner_ops.py +155 -23
  77. mindspore/ops/operations/array_ops.py +9 -7
  78. mindspore/ops/operations/comm_ops.py +2 -2
  79. mindspore/ops/operations/custom_ops.py +85 -68
  80. mindspore/ops/operations/inner_ops.py +26 -3
  81. mindspore/ops/operations/math_ops.py +4 -3
  82. mindspore/ops/operations/nn_ops.py +109 -28
  83. mindspore/parallel/_parallel_serialization.py +10 -3
  84. mindspore/parallel/_tensor.py +4 -1
  85. mindspore/parallel/checkpoint_transform.py +13 -2
  86. mindspore/parallel/shard.py +17 -10
  87. mindspore/pgodb140.dll +0 -0
  88. mindspore/pgort140.dll +0 -0
  89. mindspore/profiler/common/util.py +1 -0
  90. mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
  91. mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
  92. mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
  93. mindspore/profiler/parser/ascend_op_generator.py +1 -1
  94. mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
  95. mindspore/profiler/parser/base_timeline_generator.py +1 -1
  96. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
  97. mindspore/profiler/parser/framework_parser.py +1 -1
  98. mindspore/profiler/parser/profiler_info.py +19 -0
  99. mindspore/profiler/profiling.py +46 -24
  100. mindspore/rewrite/api/pattern_engine.py +1 -1
  101. mindspore/rewrite/parsers/for_parser.py +1 -1
  102. mindspore/rewrite/symbol_tree.py +1 -4
  103. mindspore/run_check/_check_version.py +5 -3
  104. mindspore/safeguard/rewrite_obfuscation.py +52 -28
  105. mindspore/tbbmalloc.dll +0 -0
  106. mindspore/tinyxml2.dll +0 -0
  107. mindspore/train/callback/_summary_collector.py +1 -1
  108. mindspore/train/dataset_helper.py +1 -0
  109. mindspore/train/model.py +2 -2
  110. mindspore/train/serialization.py +97 -11
  111. mindspore/train/summary/_summary_adapter.py +1 -1
  112. mindspore/train/summary/summary_record.py +23 -7
  113. mindspore/turbojpeg.dll +0 -0
  114. mindspore/vcmeta.dll +0 -0
  115. mindspore/vcruntime140.dll +0 -0
  116. mindspore/vcruntime140_1.dll +0 -0
  117. mindspore/version.py +1 -1
  118. {mindspore-2.2.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +1 -1
  119. {mindspore-2.2.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +122 -122
  120. {mindspore-2.2.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  121. {mindspore-2.2.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -0
  122. {mindspore-2.2.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
@@ -3845,7 +3845,7 @@ class FlashAttentionScoreGrad(Primitive):
     """
     @prim_attr_register
     def __init__(self, head_num, keep_prob=1.0, scale_value=1.0, pre_tokens=65536, next_tokens=65536, inner_precise=1,
-                 input_layout='BSH'):
+                 input_layout='BSH', sparse_mode=0):
         """Initialize FlashAttentionScoreGrad."""
         validator.check_value_type('head_num', head_num, [int], self.name)
         validator.check_value_type('keep_prob', keep_prob, [int, float], self.name)
@@ -3855,11 +3855,31 @@ class FlashAttentionScoreGrad(Primitive):
         validator.check_value_type('pre_tokens', pre_tokens, [int], self.name)
         validator.check_value_type('next_tokens', next_tokens, [int], self.name)
         validator.check_value_type('inner_precise', inner_precise, [int], self.name)
+        validator.check_value_type('sparse_mode', sparse_mode, [int], self.name)
         if inner_precise not in [0, 1]:
             raise ValueError(f"Attribute 'inner_precise' must be either 0 or 1, but got {inner_precise}")
         validator.check_value_type('input_layout', input_layout, [str], self.name)
-        if input_layout not in ["BSH"]:
-            raise ValueError(f"Attribute 'input_layout' must be either 'bsh' or 'sbh', but got {input_layout}")
+        if input_layout not in ["BSH", "BNSD"]:
+            raise ValueError(f"Attribute 'input_layout' must be either 'BSH' or 'BNSD', but got {input_layout}")
         self.init_prim_io_names(inputs=['query', 'key', 'value', 'attn_mask', 'attention_in', 'softmax_max',
-                                        'softmax_sum', 'dy', 'drop_mask', 'real_shift', "padding_mask", 'softmax_out'],
+                                        'softmax_sum', 'dy', 'drop_mask', 'real_shift', "padding_mask", 'softmax_out',
+                                        'prefix'],
                                 outputs=['dq', 'dk', 'dv'])
+
+
+class RmsNormGrad(Primitive):
+    r"""
+    Calculates the gradient of RmsNorm operation.
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Supported Platforms:
+        ``Ascend``
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize RmsNormGrad."""
+        self.init_prim_io_names(inputs=["dy", "x", "rstd", "gamma"],
+                                outputs=["dx", "dgamma"])
+
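A minimal sketch of the 2.2.10 surface implied by this hunk, assuming both primitives live in mindspore.ops.operations._grad_ops (the _grad_ops.py entry in the file list); argument values are illustrative only:

    # Hedged sketch, not taken from the diff: constructing the updated primitives.
    from mindspore.ops.operations._grad_ops import FlashAttentionScoreGrad, RmsNormGrad

    # 'BNSD' is rejected by 2.2.0 but accepted in 2.2.10; sparse_mode is new and defaults to 0.
    fa_grad = FlashAttentionScoreGrad(head_num=8, input_layout='BNSD', sparse_mode=0)

    # RmsNormGrad is new in 2.2.10: inputs (dy, x, rstd, gamma), outputs (dx, dgamma).
    rms_grad = RmsNormGrad()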
@@ -26,7 +26,7 @@ from mindspore.ops.operations._scalar_ops import bit_or, bit_and
 from mindspore.ops.operations.comm_ops import ReduceOp
 from mindspore.ops import signature as sig
 from mindspore.ops.operations.math_ops import _infer_shape_reduce
-from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive,\
+from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, \
     _run_op, _check_contains_variable
 from mindspore._c_expression import Tensor as Tensor_
 from mindspore._c_expression import typing
@@ -167,6 +167,7 @@ class Quant(PrimitiveWithInfer):
         self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
         self.round_mode = validator.check_string(round_mode, ["Round", "Floor", "Ceil", "Trunc"],
                                                  "round_mode", self.name)
+        self.add_prim_attr("dst_type", mstype.int8)

     def infer_shape(self, x_shape):
         return x_shape
@@ -174,7 +175,7 @@ class Quant(PrimitiveWithInfer):
     def infer_dtype(self, x_type):
         validator.check_subclass("input_x", x_type, mstype.tensor_type, self.name)
         validator.check_type_name("input_x", x_type, [mstype.float16, mstype.float32], self.name)
-        return mstype.int8
+        return self.get_attr_dict()['dst_type']


 class Lamb(PrimitiveWithInfer):
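The net effect of the two Quant hunks is that the inferred output dtype is no longer hard-coded: __init__ registers a dst_type attribute (still mstype.int8 by default) and infer_dtype reads it back. A hedged sketch of that behaviour, where quant_prim stands for an already-constructed Quant primitive and some_dtype is a hypothetical override:

    # Hedged sketch: the inferred dtype now follows the registered attribute.
    dst = quant_prim.get_attr_dict()['dst_type']      # mstype.int8 unless overridden
    quant_prim.add_prim_attr('dst_type', some_dtype)  # hypothetical override; infer_dtype would then return some_dtype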
@@ -491,7 +492,7 @@ class Receive(PrimitiveWithInfer):
         self.dtype = dtype
         self.group = group
         self.add_prim_attr("no_eliminate", True)
-        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
+        valid_type = [mstype.float16, mstype.bfloat16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
         args = {"dtype": dtype}
         validator.check_scalar_or_tensor_types_same(args, valid_type, self.name)

@@ -2146,13 +2147,14 @@ class ClipByNorm(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, axis=None):
         """Initialize ClipByNorm"""
+        self.axis_str = 'axis'
         self.axis = () if axis is None else axis
-        validator.check_value_type('axis', self.axis, [int, tuple, list], self.name)
+        validator.check_value_type(self.axis_str, self.axis, [int, tuple, list], self.name)
         axis_check = self.axis if isinstance(self.axis, Iterable) else (self.axis,)
         for i, value in enumerate(axis_check):
             validator.check_value_type('axis[%d]' % i, value, [int], self.name)
-        self.init_attrs['axis'] = self.axis
-        self.add_prim_attr('axis', self.axis)
+        self.init_attrs[self.axis_str] = self.axis
+        self.add_prim_attr(self.axis_str, self.axis)
         self.init_prim_io_names(inputs=['x', 'clip_norm'], outputs=['output'])

     def infer_shape(self, x_shape, clip_norm_shape):
@@ -2729,27 +2731,29 @@ class CopyWithSlice(Primitive):
         self.init_prim_io_names(inputs=['x', 'y'], outputs=['x'])


-class MoeFFN(Primitive):
+class FFN(Primitive):
     r"""
-    The MoeFFN computation is similar to Feed-Forward Network, it contains matmul + gelu + matmul.
+    The FFN computation is similar to Feed-Forward Network, it contains matmul + gelu + matmul.

     Args:
         activation (string): The activation type, set to 'fastgelu' or 'gelu'.
-            Only support 'fastgelu' for now. Default: "fastgelu".
+            Only support 'fastgelu' for now. Default: "fastgelu".
+        inner_precise (int): The precise mode, set to 0 for high precision or 1 for high performance.
+            Only support 1 for now. Default: 0.

     Inputs:
         - **x** (Tensor) - The input tensor with data type of int8, float16.
           Input tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`.
+        - **weight1** (Tensor) - The weight1 tensor with data type of float16.
+          Weight1 tensor of shape :math:`(expert\_num, hidden\_size, ffn\_hidden\_size)`.
+        - **weight2** (Tensor) - The weight2 tensor with data type of float16.
+          Weight2 tensor of shape :math:`(expert\_num, ffn\_hidden\_size, hidden\_size)`.
         - **expert_tokens** (Tensor]) - The expert tokens tensor with data type of int64.
           Expert tokens tensor of shape :math:`(16,)`. For example, `(2, 1, 0, .., 9)`
           indicate that the 0th expert deals with 2 tokens, the 1th expert deals with 1 tokens,
           the 2th expert do noting and so on.
-        - **weight1** (Tensor) - The weight1 tensor with data type of float16.
-          Weight1 tensor of shape :math:`(expert\_num, hidden\_size, ffn\_hidden\_size)`.
         - **bias1** (Tensor) - The bias1 tensor with data type of float16.
           Bias1 tensor of shape :math:`(expert\_num, ffn\_hidden\_size)`.
-        - **weight2** (Tensor) - The weight2 tensor with data type of float16.
-          Weight2 tensor of shape :math:`(expert\_num, ffn\_hidden\_size, hidden\_size)`.
         - **bias2** (Tensor) - The bias2 tensor with data type of float16.
           Bias2 tensor of shape :math:`(expert\_num, hidden\_size)`.
         - **scale** (Tensor) - The scale tensor with data type of float16. Not enable now.
@@ -2771,21 +2775,149 @@ class MoeFFN(Primitive):
         >>> h_f = 4 * h
         >>> e = 16
         >>> x = Tensor(np.random.randn(b * s, h).astype(np.float16))
-        >>> expert_tokens = Tensor(np.random.randn(e).astype(np.int64))
         >>> w1 = Tensor(np.random.randn(e, h, h_f).astype(np.float16))
-        >>> bias1 = Tensor(np.random.randn(e, h_f).astype(np.float16))
         >>> w2 = Tensor(np.random.randn(e, h_f, h).astype(np.float16))
+        >>> expert_tokens = Tensor(np.random.randn(e).astype(np.int64))
+        >>> bias1 = Tensor(np.random.randn(e, h_f).astype(np.float16))
         >>> bias2 = Tensor(np.random.randn(e, h).astype(np.float16))
-        >>> moe_ffn = _inner_ops.MoeFFN("fastgelu")
-        >>> output = moe_ffn(x, w1, bias1, w2, bias2)
+        >>> ffn = _inner_ops.FFN("fastgelu", 1)
+        >>> output = ffn(x, w1, w2, expert_tokens, bias1, bias2)
         >>> print(output)
     """

     @prim_attr_register
-    def __init__(self, activation):
-        """Initialize MoeFFN."""
-        self.init_prim_io_names(inputs=["x", "expert_tokens", "weight1", "bias1",
-                                        "weight2", "bias2", "scale", "offset", "deq_scale1"
-                                        "deq_scale2"],
+    def __init__(self, activation, inner_precise):
+        """Initialize FFN."""
+        self.init_prim_io_names(inputs=["x", "weight1", "weight2", "expert_tokens", "bias1",
+                                        "bias2", "scale", "offset", "deq_scale1", "deq_scale2"],
                                 outputs=["y"])
-        self.activation = activation
+        cls_name = self.name
+        validator.check_value_type("activation", activation, [str], cls_name)
+        validator.check_value_type("inner_precise", inner_precise, [int], cls_name)
+
+
+class DecoderKVCache(Primitive):
+    r"""
+    The DecoderKVCache is used for decoding the KVCache of transformer network.
+
+    Args:
+        cache (Tensor): The cahe tensor with data type of int8, uint8, int16, uint16, float16, float32 and int32.
+            When seq_len_axis is 2, cache tensor of shape
+            :math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)`.
+            When seq_len_axis is 1, cache tensor of shape
+            :math:`(batch\_size, max\_seq\_length, num_head, hidden\_size)`.
+        update (Tensor]): The tensor which is used to update the cache tensor. Same data type as cache tensor.
+            When seq_len_axis is 2, update tensor of shape
+            :math:`(batch\_size, num_head, update\_seq\_length, hidden\_size)`.
+            When seq_len_axis is 1, update tensor of shape
+            :math:`(batch\_size, update\_seq\_length, num_head, hidden\_size)`.
+        valid_seq_len (Tensor): The valid_seq_len tensor with data type of int64.
+            Valid_seq_len tensor of shape :math:`(batch\_size)`.
+        batch_index (Tensor): The batch_index tensor with data type of int64.
+            Batch_index tensor of shape :math:`(1)`. Indicate that which batch of cache tensor is going to be update.
+        seq_len_axis (int64): The seq_len_axis indicate which axis is seq_eln, set to '1' or '2'. Default: "2".
+        new_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
+            New_max_seq_len tensor of shape :math:`(1)`.
+            Indicate that user want to change the shape of cache tensor from
+            :math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)` to
+            :math:
+            `(batch\_size * max\_seq\_length / new\_max\_seq\_length, num_head, new\_max\_seq\_length, hidden\_size)`
+            to update the cache tensor. This will not real change the shape of `cache` tensor. Not able for now.
+        cur_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
+            Cur_max_seq_len tensor of shape :math:`(1)`. Keep the current seq_len of cache tensor. Not abel for now.
+
+    Outputs:
+        With same data type and same shape as `cache` tensor.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore.ops.operations import _inner_ops
+        >>> b = 4
+        >>> h = 40
+        >>> max_s = 1024
+        >>> s = 1
+        >>> d = 128
+        >>> cache = Tensor(np.random.randn(b, h, max_s, d).astype(np.float16))
+        >>> update = Tensor(np.random.randn(b, h, s, d).astype(np.float16))
+        >>> valid_seq_len = Tensor(np.random.randn(b).astype(np.int64))
+        >>> batch_index = Tensor(np.random.randn(1).astype(np.int64))
+        >>> new_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
+        >>> cur_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
+        >>> decoder_kv_cache = _inner_ops.DecoderKVCache()
+        >>> output = decoder_kv_cache(cache, update, valid_seq_len, batch_index, 2, new_max_seq_len, cur_max_seq_len)
+        >>> print(cache)
+    """
+    @prim_attr_register
+    def __init__(self):
+        """Initialize DecoderKVCache."""
+        self.init_prim_io_names(inputs=["cache", "update", "valid_seq_len", "batch_index", "seq_len_axis",
+                                        "new_max_seq_len", "cur_max_seq_len"],
+                                outputs=["out"])
+        self.add_prim_attr('side_effect_mem', True)
+
+
+class PromptKVCache(Primitive):
+    r"""
+    The PromptKVCache is used for prefill the KVCache of transformer network.
+
+    Args:
+        cache (Tensor): The cahe tensor with data type of int8, uint8, int16, uint16, float16, float32 and int32.
+            When seq_len_axis is 2, cache tensor of shape
+            :math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)`.
+            When seq_len_axis is 1, cache tensor of shape
+            :math:`(batch\_size, max\_seq\_length, num_head, hidden\_size)`.
+        update (Tensor]): The tensor which is used to update the cache tensor. Same data type as cache tensor.
+            When seq_len_axis is 2, update tensor of shape
+            :math:`(batch\_size, num_head, update\_seq\_length, hidden\_size)`.
+            When seq_len_axis is 1, update tensor of shape
+            :math:`(batch\_size, update\_seq\_length, num_head, hidden\_size)`.
+        valid_seq_len (Tensor): The valid_seq_len tensor with data type of int64.
+            Valid_seq_len tensor of shape :math:`(batch\_size)`.
+        batch_index (Tensor): The batch_index tensor with data type of int64.
+            Batch_index tensor of shape :math:`(1)`. Indicate that which batch of cache tensor is going to be update.
+        seq_len_axis (int64): The seq_len_axis indicate which axis is seq_eln, set to '1' or '2'. Default: "2".
+        new_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
+            New_max_seq_len tensor of shape :math:`(1)`.
+            Indicate that user want to change the shape of cache tensor from
+            :math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)` to
+            :math:
+            `(batch\_size * max\_seq\_length / new\_max\_seq\_length, num_head, new\_max\_seq\_length, hidden\_size)`
+            to update the cache tensor. This will not real change the shape of `cache` tensor. Not able for now.
+        cur_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
+            Cur_max_seq_len tensor of shape :math:`(1)`. Keep the current seq_len of cache tensor. Not abel for now.
+        align_mode (int64): indicate which axis is seq_eln, 0 is 'right', 1 is 'left'. Default: 0.
+
+    Outputs:
+        With same data type and same shape as `cache` tensor.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops.operations import _inner_ops
+        >>> b = 4
+        >>> h = 40
+        >>> max_s = 1024
+        >>> s = 256
+        >>> d = 128
+        >>> cache = Tensor(np.random.randn(b, h, max_s, d).astype(np.float16))
+        >>> update = Tensor(np.random.randn(b, h, s, d).astype(np.float16))
+        >>> valid_seq_len = Tensor(np.random.randn(b).astype(np.int64))
+        >>> batch_index = Tensor(np.random.randn(1).astype(np.int64))
+        >>> new_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
+        >>> cur_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
+        >>> prompt_kv_cache = _inner_ops.PromptKVCache(0)
+        >>> output = prompt_kv_cache(cache, update, valid_seq_len, batch_index, 2, new_max_seq_len, cur_max_seq_len)
+        >>> print(cache)
+    """
+    @prim_attr_register
+    def __init__(self, padding_mode="right"):
+        """Initialize PromptKVCache."""
+        self.init_prim_io_names(inputs=["cache", "update", "valid_seq_len", "batch_index", "seq_len_axis",
+                                        "new_max_seq_len", "cur_max_seq_len"],
+                                outputs=["out"])
+        self.add_prim_attr('side_effect_mem', True)
+        self.padding_mode = padding_mode
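For call sites still using the removed MoeFFN, the old and new doctest lines above imply roughly the following migration; a sketch only, reusing the tensor names from the example and the same _inner_ops import:

    # 2.2.0 (removed):
    #   moe_ffn = _inner_ops.MoeFFN("fastgelu")
    #   output = moe_ffn(x, w1, bias1, w2, bias2)
    # 2.2.10: FFN takes an extra inner_precise attribute and reordered tensor inputs.
    ffn = _inner_ops.FFN("fastgelu", 1)
    output = ffn(x, w1, w2, expert_tokens, bias1, bias2)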
@@ -1208,7 +1208,7 @@ class UniqueWithPad(Primitive):


 class Split(Primitive):
-    """
+    r"""
     Splits the input tensor into output_num of tensors along the given axis and output numbers.

     Refer to :func:`mindspore.ops.split` for more details.
@@ -1222,7 +1222,7 @@ class Split(Primitive):

     Outputs:
         tuple[Tensor], the shape of each output tensor is the same, which is
-        :math:`(x_0, x_1, ..., x_{axis}/{output_num}, ..., x_{R-1})`.
+        :math:`(x_0, x_1, ..., x_{axis}/{output\_num}, ..., x_{R-1})`.
         And the data type is the same as `input_x`.

     Supported Platforms:
@@ -1763,16 +1763,18 @@ class FillV2(PrimitiveWithCheck):
         self.init_prim_io_names(inputs=['shape', 'value'], outputs=['y'])

     def check_elim(self, dims, x):
-        if x is None or (not isinstance(x, (Tensor, Tensor_))) or (x.shape != ()) or\
-                dims is None or (isinstance(dims, (tuple, list)) and dims) or\
-                isinstance(dims, (Tensor, Tensor_)):
+        x_is_invalid = x is None or (not isinstance(x, (Tensor, Tensor_))) or (x.shape != ())
+        dims_is_invalid = dims is None or (isinstance(dims, (tuple, list)) and dims) or\
+            isinstance(dims, (Tensor, Tensor_))
+        if x_is_invalid or dims_is_invalid:
             return (False, None)
         return (True, x)

     def infer_value(self, dims, x):
-        if x is None or dims is None or\
+        dims_is_invalid = dims is None or\
             (isinstance(dims, (tuple, list)) and dims) or\
-            isinstance(dims, (Tensor, Tensor_)):
+            isinstance(dims, (Tensor, Tensor_))
+        if x is None or dims_is_invalid:
             return None
         return x

@@ -94,7 +94,7 @@ class ReduceOp:


 def check_collective_target_dtype(data_name, data_dtype, prim_name):
     """Check if data type is valid."""
-    default_target_dtypes = (mstype.int8, mstype.int32, mstype.float16, mstype.float32)
+    default_target_dtypes = (mstype.int8, mstype.uint8, mstype.int32, mstype.float16, mstype.bfloat16, mstype.float32)
     gpu_target_dtypes = (mstype.bool_, mstype.int8, mstype.int32, mstype.int64, mstype.uint32, mstype.uint64,
                          mstype.float16, mstype.float32, mstype.float64)

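Together with the Receive change above, this widens the dtype whitelist used by the communication ops in comm_ops.py. A hedged sketch of the effect, assuming the check rejects dtypes outside the whitelist on the default (non-GPU) path:

    # Hedged sketch: accepted by the 2.2.10 whitelist, rejected by the 2.2.0 one.
    check_collective_target_dtype("x", mstype.bfloat16, "AllReduce")
    check_collective_target_dtype("x", mstype.uint8, "AllReduce")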
@@ -1310,4 +1310,4 @@ class _GetTensorSlice(PrimitiveWithInfer):
         from mindspore.parallel._tensor import _load_tensor
         validator.check_value_type("dev_mat", dev_mat, [tuple], self.name)
         validator.check_value_type("tensor_map", tensor_map, [tuple], self.name)
-        return Tensor(_load_tensor(x, dev_mat, tensor_map))
+        return Tensor(_load_tensor(x, dev_mat, tensor_map), x.dtype)