mindspore-2.2.0-cp39-cp39-win_amd64.whl → mindspore-2.2.10-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (122)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  7. mindspore/_checkparam.py +3 -3
  8. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  9. mindspore/_extends/graph_kernel/splitter.py +3 -2
  10. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
  11. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
  12. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  13. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
  14. mindspore/_extends/parse/standard_method.py +2 -9
  15. mindspore/_extends/remote/kernel_build_server.py +2 -1
  16. mindspore/atlprov.dll +0 -0
  17. mindspore/c1.dll +0 -0
  18. mindspore/c1xx.dll +0 -0
  19. mindspore/c2.dll +0 -0
  20. mindspore/common/api.py +1 -1
  21. mindspore/common/auto_dynamic_shape.py +81 -85
  22. mindspore/common/dump.py +1 -1
  23. mindspore/common/tensor.py +3 -20
  24. mindspore/config/op_info.config +1 -1
  25. mindspore/context.py +11 -4
  26. mindspore/dataset/engine/datasets_standard_format.py +5 -0
  27. mindspore/dataset/vision/transforms.py +21 -21
  28. mindspore/dnnl.dll +0 -0
  29. mindspore/dpcmi.dll +0 -0
  30. mindspore/experimental/optim/adam.py +1 -1
  31. mindspore/gen_ops.py +1 -1
  32. mindspore/include/api/model.h +17 -0
  33. mindspore/include/api/status.h +8 -3
  34. mindspore/jpeg62.dll +0 -0
  35. mindspore/mindspore_backend.dll +0 -0
  36. mindspore/mindspore_common.dll +0 -0
  37. mindspore/mindspore_core.dll +0 -0
  38. mindspore/mindspore_glog.dll +0 -0
  39. mindspore/mindspore_shared_lib.dll +0 -0
  40. mindspore/msobj140.dll +0 -0
  41. mindspore/mspdb140.dll +0 -0
  42. mindspore/mspdbcore.dll +0 -0
  43. mindspore/mspdbst.dll +0 -0
  44. mindspore/mspft140.dll +0 -0
  45. mindspore/msvcdis140.dll +0 -0
  46. mindspore/msvcp140_1.dll +0 -0
  47. mindspore/msvcp140_2.dll +0 -0
  48. mindspore/msvcp140_atomic_wait.dll +0 -0
  49. mindspore/msvcp140_codecvt_ids.dll +0 -0
  50. mindspore/nn/cell.py +0 -3
  51. mindspore/nn/layer/activation.py +4 -5
  52. mindspore/nn/layer/conv.py +39 -23
  53. mindspore/nn/layer/flash_attention.py +90 -78
  54. mindspore/nn/layer/math.py +3 -7
  55. mindspore/nn/layer/rnn_cells.py +5 -5
  56. mindspore/nn/wrap/cell_wrapper.py +6 -0
  57. mindspore/numpy/utils_const.py +5 -5
  58. mindspore/opencv_core452.dll +0 -0
  59. mindspore/opencv_imgcodecs452.dll +0 -0
  60. mindspore/opencv_imgproc452.dll +0 -0
  61. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
  62. mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
  63. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
  64. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  65. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  66. mindspore/ops/_utils/utils.py +2 -0
  67. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
  68. mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
  69. mindspore/ops/function/array_func.py +10 -7
  70. mindspore/ops/function/grad/grad_func.py +0 -1
  71. mindspore/ops/function/nn_func.py +98 -9
  72. mindspore/ops/function/random_func.py +2 -1
  73. mindspore/ops/op_info_register.py +24 -21
  74. mindspore/ops/operations/__init__.py +3 -2
  75. mindspore/ops/operations/_grad_ops.py +24 -4
  76. mindspore/ops/operations/_inner_ops.py +155 -23
  77. mindspore/ops/operations/array_ops.py +9 -7
  78. mindspore/ops/operations/comm_ops.py +2 -2
  79. mindspore/ops/operations/custom_ops.py +85 -68
  80. mindspore/ops/operations/inner_ops.py +26 -3
  81. mindspore/ops/operations/math_ops.py +4 -3
  82. mindspore/ops/operations/nn_ops.py +109 -28
  83. mindspore/parallel/_parallel_serialization.py +10 -3
  84. mindspore/parallel/_tensor.py +4 -1
  85. mindspore/parallel/checkpoint_transform.py +13 -2
  86. mindspore/parallel/shard.py +17 -10
  87. mindspore/pgodb140.dll +0 -0
  88. mindspore/pgort140.dll +0 -0
  89. mindspore/profiler/common/util.py +1 -0
  90. mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
  91. mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
  92. mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
  93. mindspore/profiler/parser/ascend_op_generator.py +1 -1
  94. mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
  95. mindspore/profiler/parser/base_timeline_generator.py +1 -1
  96. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
  97. mindspore/profiler/parser/framework_parser.py +1 -1
  98. mindspore/profiler/parser/profiler_info.py +19 -0
  99. mindspore/profiler/profiling.py +46 -24
  100. mindspore/rewrite/api/pattern_engine.py +1 -1
  101. mindspore/rewrite/parsers/for_parser.py +1 -1
  102. mindspore/rewrite/symbol_tree.py +1 -4
  103. mindspore/run_check/_check_version.py +5 -3
  104. mindspore/safeguard/rewrite_obfuscation.py +52 -28
  105. mindspore/tbbmalloc.dll +0 -0
  106. mindspore/tinyxml2.dll +0 -0
  107. mindspore/train/callback/_summary_collector.py +1 -1
  108. mindspore/train/dataset_helper.py +1 -0
  109. mindspore/train/model.py +2 -2
  110. mindspore/train/serialization.py +97 -11
  111. mindspore/train/summary/_summary_adapter.py +1 -1
  112. mindspore/train/summary/summary_record.py +23 -7
  113. mindspore/turbojpeg.dll +0 -0
  114. mindspore/vcmeta.dll +0 -0
  115. mindspore/vcruntime140.dll +0 -0
  116. mindspore/vcruntime140_1.dll +0 -0
  117. mindspore/version.py +1 -1
  118. {mindspore-2.2.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +1 -1
  119. {mindspore-2.2.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +122 -122
  120. {mindspore-2.2.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  121. {mindspore-2.2.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -0
  122. {mindspore-2.2.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
mindspore/include/api/model.h CHANGED
@@ -136,6 +136,13 @@ class MS_API Model {
  /// \return Status.
  Status UpdateWeights(const std::vector<MSTensor> &new_weights);

+ /// \brief Change the size and or content of weight tensors
+ ///
+ /// \param[in] A vector where model constant are arranged in sequence
+ ///
+ /// \return Status.
+ Status UpdateWeights(const std::vector<std::vector<MSTensor>> &new_weights);
+
  /// \brief Inference model API. If use this API in train mode, it's equal to RunStep API.
  ///
  /// \param[in] inputs A vector where model inputs are arranged in sequence.
@@ -358,6 +365,13 @@ class MS_API Model {

  const std::shared_ptr<ModelImpl> impl() const { return impl_; }

+ /// \brief Get model info by key
+ ///
+ /// \param[in] key The key of model info key-value pair
+ ///
+ /// \return The value of the model info associated with the given key.
+ inline std::string GetModelInfo(const std::string &key);
+
  private:
  friend class Serialization;
  // api without std::string
@@ -374,6 +388,7 @@ class MS_API Model {
  const std::vector<char> &cropto_lib_path);
  Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
  const Key &dec_key, const std::vector<char> &dec_mode, const std::vector<char> &cropto_lib_path);
+ std::vector<char> GetModelInfo(const std::vector<char> &key);
  std::shared_ptr<ModelImpl> impl_;
  };

@@ -416,5 +431,7 @@ Status Model::Build(const std::string &model_path, ModelType model_type,
  const std::shared_ptr<Context> &model_context) {
  return Build(StringToChar(model_path), model_type, model_context);
  }
+
+ inline std::string Model::GetModelInfo(const std::string &key) { return CharToString(GetModelInfo(StringToChar(key))); }
  } // namespace mindspore
  #endif // MINDSPORE_INCLUDE_API_MODEL_H
mindspore/include/api/status.h CHANGED
@@ -83,9 +83,14 @@ enum StatusCode : uint32_t {
  kLiteModelRebuild = kLite | (0x0FFFFFFF & -12), /**< Model has been built. */

  // Executor error code, range: [-100,-200)
- kLiteOutOfTensorRange = kLite | (0x0FFFFFFF & -100), /**< Failed to check range. */
- kLiteInputTensorError = kLite | (0x0FFFFFFF & -101), /**< Failed to check input tensor. */
- kLiteReentrantError = kLite | (0x0FFFFFFF & -102), /**< Exist executor running. */
+ kLiteOutOfTensorRange = kLite | (0x0FFFFFFF & -100), /**< Failed to check range. */
+ kLiteInputTensorError = kLite | (0x0FFFFFFF & -101), /**< Failed to check input tensor. */
+ kLiteReentrantError = kLite | (0x0FFFFFFF & -102), /**< Exist executor running. */
+ kLiteLLMWaitProcessTimeOut = kLite | (0x0FFFFFFF & -103), /**< Wait to be processed time out. */
+ kLiteLLMKVCacheNotExist = kLite | (0x0FFFFFFF & -104), /**< KV Cache not exist. */
+ kLiteLLMRepeatRequest = kLite | (0x0FFFFFFF & -105), /**< repeat request. */
+ kLiteLLMRequestAlreadyCompleted = kLite | (0x0FFFFFFF & -106), /**< request already complete!. */
+ kLiteLLMEngineFinalized = kLite | (0x0FFFFFFF & -107), /**< llm engine finalized. */

  // Graph error code, range: [-200,-300)
  kLiteGraphFileError = kLite | (0x0FFFFFFF & -200), /**< Failed to verify graph file. */
mindspore/jpeg62.dll CHANGED
Binary file
mindspore/mindspore_backend.dll CHANGED
Binary file
mindspore/mindspore_common.dll CHANGED
Binary file
mindspore/mindspore_core.dll CHANGED
Binary file
mindspore/mindspore_glog.dll CHANGED
Binary file
mindspore/mindspore_shared_lib.dll CHANGED
Binary file
mindspore/msobj140.dll CHANGED
Binary file
mindspore/mspdb140.dll CHANGED
Binary file
mindspore/mspdbcore.dll CHANGED
Binary file
mindspore/mspdbst.dll CHANGED
Binary file
mindspore/mspft140.dll CHANGED
Binary file
mindspore/msvcdis140.dll CHANGED
Binary file
mindspore/msvcp140_1.dll CHANGED
Binary file
mindspore/msvcp140_2.dll CHANGED
Binary file
mindspore/msvcp140_atomic_wait.dll CHANGED
Binary file
mindspore/msvcp140_codecvt_ids.dll CHANGED
Binary file
mindspore/nn/cell.py CHANGED
@@ -1081,9 +1081,6 @@ class Cell(Cell_):
  if not isinstance(param, Parameter) and param is not None:
  raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must be 'Parameter' if not None, "
  f"but got {type(param)}.")
- if param is None:
- raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must not be None, "
- f"but got None.")
  if isinstance(param, Parameter) and param.name == PARAMETER_NAME_DEFAULT:
  param.name = param_name
  self._params[param_name] = param
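The removed check above means insert_param_to_cell no longer rejects a None param in 2.2.10; the slot is simply stored. A minimal sketch of the relaxed behavior (class and parameter names below are illustrative, not from the diff):

import numpy as np
import mindspore.nn as nn
from mindspore import Parameter, Tensor

class Demo(nn.Cell):
    def __init__(self):
        super().__init__()
        # 2.2.0 raised TypeError for a None param; 2.2.10 stores the empty slot.
        self.insert_param_to_cell("placeholder", None)
        # A real Parameter is handled exactly as before.
        self.insert_param_to_cell("weight", Parameter(Tensor(np.ones((2, 2), np.float32))))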
mindspore/nn/layer/activation.py CHANGED
@@ -932,10 +932,8 @@ class GELU(Cell):
  """Initialize GELU."""
  super(GELU, self).__init__()
  validator.check_bool(approximate, 'approximate', self.cls_name)
- self.approximate = approximate
- if approximate:
- self.approximate = 'tanh'
- else:
+ self.approximate = 'tanh'
+ if not approximate:
  self.approximate = 'none'

  def construct(self, x):
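With the rewrite above, the boolean approximate flag is normalized to the string 'tanh' or 'none' at construction time; the public interface is unchanged. A quick sketch:

import mindspore.nn as nn

gelu_tanh = nn.GELU(approximate=True)    # stored internally as 'tanh'
gelu_exact = nn.GELU(approximate=False)  # stored internally as 'none'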
@@ -1335,7 +1333,8 @@ class LRN(Cell):

  .. warning::
  LRN is deprecated on Ascend due to potential accuracy problem. It's recommended to use other
- normalization methods, e.g. :class:`mindspore.nn.BatchNorm`.
+ normalization methods, e.g. :class:`mindspore.nn.BatchNorm1d` ,
+ :class:`mindspore.nn.BatchNorm2d` , :class:`mindspore.nn.BatchNorm3d`.

  Refer to :func:`mindspore.ops.lrn` for more details.

mindspore/nn/layer/conv.py CHANGED
@@ -718,9 +718,9 @@ class Conv3d(_Conv):

  .. math::
  \begin{array}{ll} \\
- D_{out} \left \lceil{\frac{D_{in}}{\text{stride[0]}}} \right \rceil \\
- H_{out} \left \lceil{\frac{H_{in}}{\text{stride[1]}}} \right \rceil \\
- W_{out} \left \lceil{\frac{W_{in}}{\text{stride[2]}}} \right \rceil \\
+ D_{out} = \left \lceil{\frac{D_{in}}{\text{stride[0]}}} \right \rceil \\
+ H_{out} = \left \lceil{\frac{H_{in}}{\text{stride[1]}}} \right \rceil \\
+ W_{out} = \left \lceil{\frac{W_{in}}{\text{stride[2]}}} \right \rceil \\
  \end{array}


@@ -728,11 +728,11 @@ class Conv3d(_Conv):

  .. math::
  \begin{array}{ll} \\
- D_{out} \left \lfloor{\frac{D_{in} - \text{dilation[0]} \times (\text{kernel_size[0]} - 1) }
+ D_{out} = \left \lfloor{\frac{D_{in} - \text{dilation[0]} \times (\text{kernel_size[0]} - 1) }
  {\text{stride[0]}} + 1} \right \rfloor \\
- H_{out} \left \lfloor{\frac{H_{in} - \text{dilation[1]} \times (\text{kernel_size[1]} - 1) }
+ H_{out} = \left \lfloor{\frac{H_{in} - \text{dilation[1]} \times (\text{kernel_size[1]} - 1) }
  {\text{stride[1]}} + 1} \right \rfloor \\
- W_{out} \left \lfloor{\frac{W_{in} - \text{dilation[2]} \times (\text{kernel_size[2]} - 1) }
+ W_{out} = \left \lfloor{\frac{W_{in} - \text{dilation[2]} \times (\text{kernel_size[2]} - 1) }
  {\text{stride[2]}} + 1} \right \rfloor \\
  \end{array}

@@ -740,11 +740,11 @@ class Conv3d(_Conv):

  .. math::
  \begin{array}{ll} \\
- D_{out} \left \lfloor{\frac{D_{in} + padding[0] + padding[1] - (\text{dilation[0]} - 1) \times
+ D_{out} = \left \lfloor{\frac{D_{in} + padding[0] + padding[1] - (\text{dilation[0]} - 1) \times
  \text{kernel_size[0]} - 1 }{\text{stride[0]}} + 1} \right \rfloor \\
- H_{out} \left \lfloor{\frac{H_{in} + padding[2] + padding[3] - (\text{dilation[1]} - 1) \times
+ H_{out} = \left \lfloor{\frac{H_{in} + padding[2] + padding[3] - (\text{dilation[1]} - 1) \times
  \text{kernel_size[1]} - 1 }{\text{stride[1]}} + 1} \right \rfloor \\
- W_{out} \left \lfloor{\frac{W_{in} + padding[4] + padding[5] - (\text{dilation[2]} - 1) \times
+ W_{out} = \left \lfloor{\frac{W_{in} + padding[4] + padding[5] - (\text{dilation[2]} - 1) \times
  \text{kernel_size[2]} - 1 }{\text{stride[2]}} + 1} \right \rfloor \\
  \end{array}

@@ -812,7 +812,7 @@ class Conv3d(_Conv):
  bias_init,
  data_format,
  dtype=dtype)
- out_channels = self.out_channels
+ out_channels = self.out_channels // group
  self.conv3d = P.Conv3D(out_channel=out_channels,
  kernel_size=self.kernel_size,
  mode=1,
@@ -820,17 +820,33 @@ class Conv3d(_Conv):
  pad=self.padding,
  stride=self.stride,
  dilation=self.dilation,
- group=group,
+ group=1,
  data_format=self.data_format)
  self.bias_add = P.BiasAdd(data_format=self.data_format)
  self.shape = P.Shape()
+ self.concat = P.Concat(1)
+ self.split_0 = P.Split(0, self.group)
+ self.split_1 = P.Split(1, self.group)

  def construct(self, x):
  x_shape = self.shape(x)
  _check_input_5dims(x_shape, self.cls_name)
- out = self.conv3d(x, self.weight)
- if self.has_bias:
- out = self.bias_add(out, self.bias)
+ if self.group == 1:
+ out = self.conv3d(x, self.weight)
+ if self.has_bias:
+ out = self.bias_add(out, self.bias)
+ else:
+ features = self.split_1(x)
+ weights = self.split_0(self.weight)
+ outputs = ()
+ for i in range(self.group):
+ output = self.conv3d(features[i], weights[i])
+ outputs = outputs + (output,)
+ out = self.concat(outputs)
+ if self.bias is not None:
+ new_shape = [1 for _ in range(out.ndim)]
+ new_shape[1] = self.out_channels
+ out = out + self.bias.reshape(new_shape)
  return out
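With group > 1, the construct above now splits the input along the channel axis (P.Split(1, group)) and the weight along the output-channel axis (P.Split(0, group)), runs the single-group Conv3D on each slice, concatenates the results, and adds the bias through a reshape. A minimal construction sketch, assuming a backend that accepts group > 1 (shapes are illustrative):

import numpy as np
import mindspore as ms
import mindspore.nn as nn

# in_channels and out_channels must both be divisible by group.
net = nn.Conv3d(in_channels=4, out_channels=8, kernel_size=3, group=2, has_bias=True)
x = ms.Tensor(np.ones((1, 4, 8, 16, 16)), ms.float32)  # (N, C_in, D, H, W)
y = net(x)  # (1, 8, 8, 16, 16) with the default pad_mode='same'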


@@ -921,9 +937,9 @@ class Conv3dTranspose(_Conv):

  .. math::
  \begin{array}{ll} \\
- D_{out} \left \lfloor{\frac{D_{in}}{\text{stride[0]}} + 1} \right \rfloor \\
- H_{out} \left \lfloor{\frac{H_{in}}{\text{stride[1]}} + 1} \right \rfloor \\
- W_{out} \left \lfloor{\frac{W_{in}}{\text{stride[2]}} + 1} \right \rfloor \\
+ D_{out} = \left \lfloor{\frac{D_{in}}{\text{stride[0]}} + 1} \right \rfloor \\
+ H_{out} = \left \lfloor{\frac{H_{in}}{\text{stride[1]}} + 1} \right \rfloor \\
+ W_{out} = \left \lfloor{\frac{W_{in}}{\text{stride[2]}} + 1} \right \rfloor \\
  \end{array}


@@ -931,11 +947,11 @@ class Conv3dTranspose(_Conv):

  .. math::
  \begin{array}{ll} \\
- D_{out} \left \lfloor{\frac{D_{in} - \text{dilation[0]} \times (\text{kernel_size[0]} - 1) }
+ D_{out} = \left \lfloor{\frac{D_{in} - \text{dilation[0]} \times (\text{kernel_size[0]} - 1) }
  {\text{stride[0]}} + 1} \right \rfloor \\
- H_{out} \left \lfloor{\frac{H_{in} - \text{dilation[1]} \times (\text{kernel_size[1]} - 1) }
+ H_{out} = \left \lfloor{\frac{H_{in} - \text{dilation[1]} \times (\text{kernel_size[1]} - 1) }
  {\text{stride[1]}} + 1} \right \rfloor \\
- W_{out} \left \lfloor{\frac{W_{in} - \text{dilation[2]} \times (\text{kernel_size[2]} - 1) }
+ W_{out} = \left \lfloor{\frac{W_{in} - \text{dilation[2]} \times (\text{kernel_size[2]} - 1) }
  {\text{stride[2]}} + 1} \right \rfloor \\
  \end{array}

@@ -943,11 +959,11 @@ class Conv3dTranspose(_Conv):

  .. math::
  \begin{array}{ll} \\
- D_{out} \left \lfloor{\frac{D_{in} + padding[0] + padding[1] - (\text{dilation[0]} - 1) \times
+ D_{out} = \left \lfloor{\frac{D_{in} + padding[0] + padding[1] - (\text{dilation[0]} - 1) \times
  \text{kernel_size[0]} - 1 }{\text{stride[0]}} + 1} \right \rfloor \\
- H_{out} \left \lfloor{\frac{H_{in} + padding[2] + padding[3] - (\text{dilation[1]} - 1) \times
+ H_{out} = \left \lfloor{\frac{H_{in} + padding[2] + padding[3] - (\text{dilation[1]} - 1) \times
  \text{kernel_size[1]} - 1 }{\text{stride[1]}} + 1} \right \rfloor \\
- W_{out} \left \lfloor{\frac{W_{in} + padding[4] + padding[5] - (\text{dilation[2]} - 1) \times
+ W_{out} = \left \lfloor{\frac{W_{in} + padding[4] + padding[5] - (\text{dilation[2]} - 1) \times
  \text{kernel_size[2]} - 1 }{\text{stride[2]}} + 1} \right \rfloor \\
  \end{array}

mindspore/nn/layer/flash_attention.py CHANGED
@@ -57,14 +57,15 @@ class FlashAttention(Cell):
  Default True
  alibi(bool): This parameter indicates whether the flashattention supports the Alibi.
  Default: False
+ use_mqa(bool): Using MHA if True, only take effect under 910B. Default: False.


  Inputs:
  - **query** (Tensor) - Tensor query (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
  - **key** (Tensor) - Tensor key (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
  - **value** (Tensor) - Tensor value (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
- - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16` [batch_size, seq_length,
- seq_length]): A matrix to pass masked information.
+ - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16` `mstype.uint8`
+ [batch_size, seq_length, seq_length]): A matrix to pass masked information.

  Outputs:
  A Tensor. The output of the attention with shape [batch_size, head_num, seq_length, head_dim]
@@ -102,17 +103,23 @@
  mp=1,
  high_precision=False,
  have_attention_mask_batch=True,
- alibi=False
+ alibi=False,
+ use_mqa=False
  ):
  super(FlashAttention, self).__init__()

  scaling_constant = math.sqrt(head_dim)
  if scaling_constant == 0:
  raise ValueError("the scaling constant must not be 0.")
- self.scale_factor = Tensor([1. / scaling_constant], dtype=mstype.float16)
-
- self.is_910A = MSContext.get_instance().get_ascend_soc_version() == "Ascend910"
+ self.dropout_rate = dropout_rate
+ self.is_910A = MSContext.get_instance().get_ascend_soc_version() == "ascend910"
  if self.is_910A:
+ self.scale_factor = Tensor([1. / math.sqrt(scaling_constant)], dtype=mstype.float16)
+ self.scale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
+ self.ones = ops.Ones()
+ self.dim_mask = Tensor([1 for _ in range(head_dim)], dtype=mstype.int8)
+ self.have_attention_mask_batch = have_attention_mask_batch
+ self.alibi = alibi
  self.flash_attention = get_flash_attention(
  prev_block_num=prev_block_num,
  next_block_num=next_block_num,
@@ -120,6 +127,10 @@
  high_precision=high_precision
  )
  self.flash_attention.add_prim_attr("primitive_target", "Ascend")
+ fa_strategies = ((dp, mp, 1, 1),
+ (dp, mp, 1, 1),
+ (dp, mp, 1, 1))
+ self.shard(fa_strategies)
  else:
  if alibi:
  raise ValueError(f"When soc_version is not Ascend910A, alibi must be False")
@@ -128,25 +139,27 @@
  self.reshape = ops.Reshape()
  self.zeros_like = ops.ZerosLike().shard(((dp, mp, 1, 1),))
  self.zeros = ops.Zeros()
- self.attn_expand_dims = ops.ExpandDims().shard(((dp, 1, 1),))
- fa_strategies = ((dp, 1, mp),
- (dp, 1, mp),
- (dp, 1, mp),
- (dp, 1, 1, 1))
+ self.attn_cast = ops.Cast()
+ if use_mqa:
+ fa_strategies = ((dp, mp, 1, 1),
+ (dp, 1, 1, 1),
+ (dp, 1, 1, 1),
+ (dp, 1, 1, 1))
+ else:
+ fa_strategies = ((dp, mp, 1, 1),
+ (dp, mp, 1, 1),
+ (dp, mp, 1, 1),
+ (dp, 1, 1, 1))
  if dropout_rate > 1e-5:
  fa_strategies += ((dp, mp, 1, 1),)
  self.flash_attention = FlashAttentionScore(head_num=head_num, pre_tokens=prev_block_num,
  next_tokens=next_block_num,
  keep_prob=1 - dropout_rate,
- scale_value=1.0,
- inner_precise=0 if high_precision else 1).shard(fa_strategies)
+ scale_value=1. / scaling_constant,
+ inner_precise=0 if high_precision else 1,
+ input_layout="BNSD").shard(fa_strategies)

- self.ones = ops.Ones()
- self.dim_mask = Tensor([1 for _ in range(head_dim)], dtype=mstype.int8)
- self.scale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
  self.dropout_rate = dropout_rate
- self.have_attention_mask_batch = have_attention_mask_batch
- self.alibi = alibi
  if self.dropout_rate > 1e-5:
  self.keep_prob = Tensor(1 - self.dropout_rate, dtype=mstype.float16)
  self.fill_v2 = ops.FillV2().shard(((dp, mp, 1, 1), ()))
@@ -162,46 +175,49 @@
  such as MatMul. Default: None.
  :return:
  """
- if in_strategy is None:
- # default: dp=1, mp=1, construct inputs only contain query, key, value
- in_strategy = (
- (1, 1, 1, 1),
- (1, 1, 1, 1),
- (1, 1, 1, 1),
- )
- self.flash_attention.shard(in_strategy)
- dp = in_strategy[0][0]
- mp = in_strategy[0][1]
- self.flash_attention.add_prim_attr("dev_matrix_shape", [dp, mp, 1, 1])
- inputs_tensor_map = [
- [3, 2, 1, 0],
- [3, 2, 1, 0],
- [3, 2, 1, 0],
- ]
- if self.have_attention_mask_batch:
- inputs_tensor_map.append([3, 1, 0])
- else:
- inputs_tensor_map.append([-1, 1, 0])
+ if self.is_910A:
+ if in_strategy is None:
+ # default: dp=1, mp=1, construct inputs only contain query, key, value
+ in_strategy = (
+ (1, 1, 1, 1),
+ (1, 1, 1, 1),
+ (1, 1, 1, 1),
+ )
+ self.flash_attention.shard(in_strategy)
+ dp = in_strategy[0][0]
+ mp = in_strategy[0][1]
+ self.flash_attention.add_prim_attr("dev_matrix_shape", [dp, mp, 1, 1])
+ inputs_tensor_map = [
+ [3, 2, 1, 0],
+ [3, 2, 1, 0],
+ [3, 2, 1, 0],
+ ]
+ if self.have_attention_mask_batch:
+ inputs_tensor_map.append([3, 1, 0])
+ else:
+ inputs_tensor_map.append([-1, 1, 0])

- input_empty_args_num = 2
- # dropout_mask
- if self.dropout_rate > 1e-5:
- input_empty_args_num -= 1
- inputs_tensor_map.append([3, 2, 1, 0])
+ input_empty_args_num = 2
+ # dropout_mask
+ if self.dropout_rate > 1e-5:
+ input_empty_args_num -= 1
+ inputs_tensor_map.append([3, 2, 1, 0])

- if self.alibi:
- input_empty_args_num -= 1
- inputs_tensor_map.append([3, 2, 1, 0])
+ if self.alibi:
+ input_empty_args_num -= 1
+ inputs_tensor_map.append([3, 2, 1, 0])

- self.flash_attention.add_prim_attr("inputs_tensor_map", inputs_tensor_map)
+ self.flash_attention.add_prim_attr("inputs_tensor_map", inputs_tensor_map)

- self.flash_attention.add_prim_attr("outputs_tensor_map", [
- [3, 2, 1, 0], # O
- [3, 2, 1], # L
- [3, 2, 1] # M
- ])
- self.flash_attention.add_prim_attr("as_loss_divisor", 0)
- self.flash_attention.add_prim_attr("empty_mirror_ops", input_empty_args_num)
+ self.flash_attention.add_prim_attr("outputs_tensor_map", [
+ [3, 2, 1, 0], # O
+ [3, 2, 1], # L
+ [3, 2, 1] # M
+ ])
+ self.flash_attention.add_prim_attr("as_loss_divisor", 0)
+ self.flash_attention.add_prim_attr("empty_mirror_ops", input_empty_args_num)
+ else:
+ self.flash_attention.shard(in_strategy)

  def construct(self, query, key, value, attn_mask=None, alibi_mask=None):
  """FlashAttention forward
@@ -212,24 +228,22 @@
  :param alibi_mask: [bsz, head_num, 1, seq_len], if not None
  :return: output [bsz, head_num, seq_len, head_dim]
  """
- query = self.scale_mul(query, self.scale_factor)
  bsz, head_num, seq_len, head_dim = query.shape
- _, k_head_num, k_seq_len, _ = key.shape
- _, v_head_num, v_seq_len, _ = value.shape
- if head_num != k_head_num or head_num != v_head_num:
- raise ValueError(
- "the head_num of query, key and value must be the same, "
- "If different head_num are used, users need to change themselves to be same by tile.")
- if seq_len % 16 != 0 or k_seq_len % 16 != 0 or k_seq_len != v_seq_len:
- raise ValueError(
- "query, key, value seq_len must be a multiple of 16, and key seq_len, value seq_len must be the same.")
-
- if head_dim > 304:
- raise ValueError(
- "the head_dim must be less than 304, otherwise the ub would be OOM.")
-
  if self.is_910A:
+ _, k_head_num, k_seq_len, _ = key.shape
+ _, v_head_num, v_seq_len, _ = value.shape
+ if head_num != k_head_num or head_num != v_head_num:
+ raise ValueError(
+ "the head_num of query, key and value must be the same, "
+ "If different head_num are used, users need to change themselves to be same by tile.")
+ if seq_len % 16 != 0 or k_seq_len % 16 != 0 or k_seq_len != v_seq_len:
+ raise ValueError(
+ "query, key, value seq_len must be a multiple of 16, "
+ "and the seq_len between key and value must be equal.")
  # 910A -- FlashAttentionPrimtive
+ if head_dim > 304:
+ raise ValueError(
+ "the head_dim must be less than 304, otherwise the ub would be OOM.")
  if self.dropout_rate > 1e-5:
  drop_mask_bits = self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob)
  tensor_shape = Tensor((bsz, head_num, seq_len, seq_len), mstype.int32)
@@ -238,27 +252,25 @@
  drop_mask = self.do_dropout(ones, drop_mask_bits, self.keep_prob)
  else:
  drop_mask = None
+ query = self.scale_mul(query, self.scale_factor)
+ key = self.scale_mul(key, self.scale_factor)
+ attn_mask = self.cast(attn_mask, mstype.float16)
  output, _, _ = self.flash_attention(query, key, value, attn_mask, drop_mask, alibi_mask)
  else:
- # FlashAttentionScore
- # Useless input, just for binary calls.
+ # 910B -- FlashAttentionScore
  if self.dropout_rate > 1e-5:
  drop_mask_bits = self.reshape(self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob),
  (bsz, head_num, seq_len, seq_len // 8))
  else:
  drop_mask_bits = None
- # (B, N, S, D) -> (B, S, H)
- query = self.reshape(self.transpose_4d_pre(query, (0, 2, 1, 3)), (bsz, seq_len, -1))
- key = self.reshape(self.transpose_4d_pre(key, (0, 2, 1, 3)), (bsz, seq_len, -1))
- value = self.reshape(self.transpose_4d_pre(value, (0, 2, 1, 3)), (bsz, seq_len, -1))
- attn_mask = self.attn_expand_dims(attn_mask, 1)
+ # (B, S, S) -> (B, 1, S, S)
+ attn_mask = self.cast(self.reshape(attn_mask, (bsz, 1, seq_len, seq_len)), mstype.uint8)
  output, _, _ = self.flash_attention(query,
  key,
  value,
  attn_mask,
  drop_mask_bits,
  None,
+ None,
  None)
- output = self.transpose_4d_post(self.reshape(output, (bsz, seq_len, head_num, head_dim)), (0, 2, 1, 3))
-
  return output
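Taken together, these FlashAttention changes route non-910A (910B-class) SoCs through FlashAttentionScore with input_layout="BNSD", move the 1/sqrt(head_dim) scaling into the primitive via scale_value, cast the attention mask to uint8, and add the use_mqa switch for multi-query sharding. A hedged construction sketch that only uses keywords visible in this diff (values are illustrative, and an Ascend context is assumed):

from mindspore.nn.layer.flash_attention import FlashAttention

fa = FlashAttention(head_dim=128, head_num=16, dropout_rate=0.0, use_mqa=False)
# query/key/value: float16 tensors shaped (batch_size, head_num, seq_length, head_dim)
# attn_mask: (batch_size, seq_length, seq_length); reshaped to (batch_size, 1, seq_length, seq_length) on 910B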
mindspore/nn/layer/math.py CHANGED
@@ -375,9 +375,6 @@ class DiGamma(Cell):
  nan, real_result)


- eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32)
-
-
  def _while_helper_func(cond, body, vals):
  while cond(vals).any():
  vals = body(vals)
@@ -394,7 +391,7 @@ def _igamma_series(ax, x, a, enabled):
  select = P.Select()

  # If more data types are supported, this epsilon need to be selected.
- epsilon = eps_fp32
+ epsilon = Tensor(np.finfo(np.float32).eps, mstype.float32)

  def cond(vals):
  enabled = vals[0]
@@ -443,7 +440,7 @@ def _igammac_continued_fraction(ax, x, a, enabled):
  select = P.Select()

  # If more data types are supported, this epsilon need to be selected.
- epsilon = eps_fp32
+ epsilon = Tensor(np.finfo(np.float32).eps, mstype.float32)

  def cond(vals):
  enabled = vals[0]
@@ -620,8 +617,7 @@ class IGamma(Cell):
  x = F.broadcast_to(x, para_shape)
  a = F.broadcast_to(a, para_shape)
  x_is_zero = self.equal(x, 0)
- log_maxfloat = self.log_maxfloat32
- underflow = self.less(ax, self.neg(log_maxfloat))
+ underflow = self.less(ax, self.neg(self.log_maxfloat32))
  ax = self.exp(ax)
  enabled = self.logicalnot(self.logicalor(self.logicalor(x_is_zero, domain_error), underflow))
  output = self.select(use_igammac,
mindspore/nn/layer/rnn_cells.py CHANGED
@@ -83,7 +83,7 @@ def _check_lstmcell_init(func):


  def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
- '''RNN cell function with tanh activation'''
+ """RNN cell function with tanh activation"""
  if b_ih is None:
  igates = P.MatMul(False, True)(inputs, w_ih)
  hgates = P.MatMul(False, True)(hidden, w_hh)
@@ -94,7 +94,7 @@ def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):


  def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
- '''RNN cell function with relu activation'''
+ """RNN cell function with relu activation"""
  if b_ih is None:
  igates = P.MatMul(False, True)(inputs, w_ih)
  hgates = P.MatMul(False, True)(hidden, w_hh)
@@ -105,7 +105,7 @@ def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):


  def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
- '''LSTM cell function'''
+ """LSTM cell function"""
  hx, cx = hidden
  if b_ih is None:
  gates = P.MatMul(False, True)(inputs, w_ih) + P.MatMul(False, True)(hx, w_hh)
@@ -125,7 +125,7 @@ def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):


  def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
- '''GRU cell function'''
+ """GRU cell function"""
  if b_ih is None:
  gi = P.MatMul(False, True)(inputs, w_ih)
  gh = P.MatMul(False, True)(hidden, w_hh)
@@ -144,7 +144,7 @@ def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):


  class RNNCellBase(Cell):
- '''Basic class for RNN Cells'''
+ """Basic class for RNN Cells"""
  def __init__(self, input_size: int, hidden_size: int, has_bias: bool, num_chunks: int,
  dtype=mstype.float32):
  super().__init__()
mindspore/nn/wrap/cell_wrapper.py CHANGED
@@ -644,6 +644,9 @@ class PipelineCell(Cell):
  self.micro_inputs = nn.CellList()
  self.micro_size = micro_size
  self.add_list = []
+ if not isinstance(network, Cell):
+ raise TypeError("For 'PipelineCell', the argument 'network' must cell type, "
+ "but got the type : {}.".format(type(network)))
  if not isinstance(micro_size, int):
  raise TypeError("For 'PipelineCell', the argument 'micro_size' must be integer, "
  "but got the type : {}.".format(type(micro_size)))
@@ -689,6 +692,9 @@ class GradAccumulationCell(Cell):
  self.micro_inputs = nn.CellList()
  self.micro_size = micro_size
  self.add_list = []
+ if not isinstance(network, Cell):
+ raise TypeError("For 'GradAccumulationCell', the argument 'network' must cell type, "
+ "but got the type : {}.".format(type(network)))
  if not isinstance(micro_size, int):
  raise TypeError("For 'GradAccumulationCell', the argument 'micro_size' must be integer, "
  "but got the type : {}.".format(type(micro_size)))