mindspore-2.2.0-cp39-cp39-win_amd64.whl → mindspore-2.2.10-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +3 -3
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/splitter.py +3 -2
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
- mindspore/_extends/parse/standard_method.py +2 -9
- mindspore/_extends/remote/kernel_build_server.py +2 -1
- mindspore/atlprov.dll +0 -0
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/api.py +1 -1
- mindspore/common/auto_dynamic_shape.py +81 -85
- mindspore/common/dump.py +1 -1
- mindspore/common/tensor.py +3 -20
- mindspore/config/op_info.config +1 -1
- mindspore/context.py +11 -4
- mindspore/dataset/engine/datasets_standard_format.py +5 -0
- mindspore/dataset/vision/transforms.py +21 -21
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/optim/adam.py +1 -1
- mindspore/gen_ops.py +1 -1
- mindspore/include/api/model.h +17 -0
- mindspore/include/api/status.h +8 -3
- mindspore/jpeg62.dll +0 -0
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/cell.py +0 -3
- mindspore/nn/layer/activation.py +4 -5
- mindspore/nn/layer/conv.py +39 -23
- mindspore/nn/layer/flash_attention.py +90 -78
- mindspore/nn/layer/math.py +3 -7
- mindspore/nn/layer/rnn_cells.py +5 -5
- mindspore/nn/wrap/cell_wrapper.py +6 -0
- mindspore/numpy/utils_const.py +5 -5
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
- mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_utils/utils.py +2 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
- mindspore/ops/function/array_func.py +10 -7
- mindspore/ops/function/grad/grad_func.py +0 -1
- mindspore/ops/function/nn_func.py +98 -9
- mindspore/ops/function/random_func.py +2 -1
- mindspore/ops/op_info_register.py +24 -21
- mindspore/ops/operations/__init__.py +3 -2
- mindspore/ops/operations/_grad_ops.py +24 -4
- mindspore/ops/operations/_inner_ops.py +155 -23
- mindspore/ops/operations/array_ops.py +9 -7
- mindspore/ops/operations/comm_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +85 -68
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +4 -3
- mindspore/ops/operations/nn_ops.py +109 -28
- mindspore/parallel/_parallel_serialization.py +10 -3
- mindspore/parallel/_tensor.py +4 -1
- mindspore/parallel/checkpoint_transform.py +13 -2
- mindspore/parallel/shard.py +17 -10
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/common/util.py +1 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
- mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
- mindspore/profiler/parser/ascend_op_generator.py +1 -1
- mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
- mindspore/profiler/parser/base_timeline_generator.py +1 -1
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
- mindspore/profiler/parser/framework_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +19 -0
- mindspore/profiler/profiling.py +46 -24
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/parsers/for_parser.py +1 -1
- mindspore/rewrite/symbol_tree.py +1 -4
- mindspore/run_check/_check_version.py +5 -3
- mindspore/safeguard/rewrite_obfuscation.py +52 -28
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/callback/_summary_collector.py +1 -1
- mindspore/train/dataset_helper.py +1 -0
- mindspore/train/model.py +2 -2
- mindspore/train/serialization.py +97 -11
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +23 -7
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.2.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +1 -1
- {mindspore-2.2.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +122 -122
- {mindspore-2.2.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
- {mindspore-2.2.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -0
- {mindspore-2.2.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
mindspore/include/api/model.h
CHANGED
@@ -136,6 +136,13 @@ class MS_API Model {
   /// \return Status.
   Status UpdateWeights(const std::vector<MSTensor> &new_weights);
 
+  /// \brief Change the size and or content of weight tensors
+  ///
+  /// \param[in] A vector where model constant are arranged in sequence
+  ///
+  /// \return Status.
+  Status UpdateWeights(const std::vector<std::vector<MSTensor>> &new_weights);
+
   /// \brief Inference model API. If use this API in train mode, it's equal to RunStep API.
   ///
   /// \param[in] inputs A vector where model inputs are arranged in sequence.
@@ -358,6 +365,13 @@ class MS_API Model {
 
   const std::shared_ptr<ModelImpl> impl() const { return impl_; }
 
+  /// \brief Get model info by key
+  ///
+  /// \param[in] key The key of model info key-value pair
+  ///
+  /// \return The value of the model info associated with the given key.
+  inline std::string GetModelInfo(const std::string &key);
+
  private:
  friend class Serialization;
   // api without std::string
@@ -374,6 +388,7 @@ class MS_API Model {
                const std::vector<char> &cropto_lib_path);
   Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
               const Key &dec_key, const std::vector<char> &dec_mode, const std::vector<char> &cropto_lib_path);
+  std::vector<char> GetModelInfo(const std::vector<char> &key);
   std::shared_ptr<ModelImpl> impl_;
 };
 
@@ -416,5 +431,7 @@ Status Model::Build(const std::string &model_path, ModelType model_type,
                     const std::shared_ptr<Context> &model_context) {
   return Build(StringToChar(model_path), model_type, model_context);
 }
+
+inline std::string Model::GetModelInfo(const std::string &key) { return CharToString(GetModelInfo(StringToChar(key))); }
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_MODEL_H

mindspore/include/api/status.h
CHANGED
@@ -83,9 +83,14 @@ enum StatusCode : uint32_t {
   kLiteModelRebuild = kLite | (0x0FFFFFFF & -12), /**< Model has been built. */
 
   // Executor error code, range: [-100,-200)
-  kLiteOutOfTensorRange = kLite | (0x0FFFFFFF & -100),
-  kLiteInputTensorError = kLite | (0x0FFFFFFF & -101),
-  kLiteReentrantError = kLite | (0x0FFFFFFF & -102),
+  kLiteOutOfTensorRange = kLite | (0x0FFFFFFF & -100),            /**< Failed to check range. */
+  kLiteInputTensorError = kLite | (0x0FFFFFFF & -101),            /**< Failed to check input tensor. */
+  kLiteReentrantError = kLite | (0x0FFFFFFF & -102),              /**< Exist executor running. */
+  kLiteLLMWaitProcessTimeOut = kLite | (0x0FFFFFFF & -103),       /**< Wait to be processed time out. */
+  kLiteLLMKVCacheNotExist = kLite | (0x0FFFFFFF & -104),          /**< KV Cache not exist. */
+  kLiteLLMRepeatRequest = kLite | (0x0FFFFFFF & -105),            /**< repeat request. */
+  kLiteLLMRequestAlreadyCompleted = kLite | (0x0FFFFFFF & -106),  /**< request already complete!. */
+  kLiteLLMEngineFinalized = kLite | (0x0FFFFFFF & -107),          /**< llm engine finalized. */
 
   // Graph error code, range: [-200,-300)
   kLiteGraphFileError = kLite | (0x0FFFFFFF & -200), /**< Failed to verify graph file. */

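Each of these Lite status values packs a component tag into the high bits and a negative error offset masked to the low 28 bits. A quick sketch of that arithmetic; the kLite base value 0xF0000000 is an assumption taken from the CompCode enum elsewhere in this header, so verify it against the shipped copy:

# Sketch: how a Lite status code such as
#   kLiteOutOfTensorRange = kLite | (0x0FFFFFFF & -100)
# is laid out. K_LITE below is an assumed component base, not read from the wheel.
K_LITE = 0xF0000000

def lite_code(offset: int) -> int:
    """Keep the low 28 bits of the negative offset, then OR in the component id.
    Python's two's-complement semantics make this match the C expression."""
    return K_LITE | (0x0FFFFFFF & offset)

for name, off in [("kLiteOutOfTensorRange", -100),
                  ("kLiteLLMWaitProcessTimeOut", -103),
                  ("kLiteLLMEngineFinalized", -107)]:
    print(name, hex(lite_code(off)))
# kLiteOutOfTensorRange 0xffffff9c
# kLiteLLMWaitProcessTimeOut 0xffffff99
# kLiteLLMEngineFinalized 0xffffff95
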
mindspore/jpeg62.dll
CHANGED
Binary file

mindspore/mindspore_backend.dll
CHANGED
Binary file

mindspore/mindspore_common.dll
CHANGED
Binary file

mindspore/mindspore_core.dll
CHANGED
Binary file

mindspore/mindspore_glog.dll
CHANGED
Binary file

mindspore/mindspore_shared_lib.dll
CHANGED
Binary file

mindspore/msobj140.dll
CHANGED
Binary file

mindspore/mspdb140.dll
CHANGED
Binary file

mindspore/mspdbcore.dll
CHANGED
Binary file

mindspore/mspdbst.dll
CHANGED
Binary file

mindspore/mspft140.dll
CHANGED
Binary file

mindspore/msvcdis140.dll
CHANGED
Binary file

mindspore/msvcp140_1.dll
CHANGED
Binary file

mindspore/msvcp140_2.dll
CHANGED
Binary file

mindspore/msvcp140_atomic_wait.dll
CHANGED
Binary file

mindspore/msvcp140_codecvt_ids.dll
CHANGED
Binary file

mindspore/nn/cell.py
CHANGED
@@ -1081,9 +1081,6 @@ class Cell(Cell_):
         if not isinstance(param, Parameter) and param is not None:
             raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must be 'Parameter' if not None, "
                             f"but got {type(param)}.")
-        if param is None:
-            raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must not be None, "
-                            f"but got None.")
         if isinstance(param, Parameter) and param.name == PARAMETER_NAME_DEFAULT:
             param.name = param_name
         self._params[param_name] = param

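With the redundant None check removed, insert_param_to_cell again tolerates param=None; only non-Parameter, non-None values are rejected. A minimal, hypothetical sketch of the new behaviour (the TinyCell class is illustrative only, not code from the package):

# Sketch of the behaviour change in Cell.insert_param_to_cell.
import numpy as np
import mindspore.nn as nn
from mindspore import Parameter, Tensor

class TinyCell(nn.Cell):
    def __init__(self):
        super().__init__()
        # A real Parameter is always accepted.
        self.insert_param_to_cell("w", Parameter(Tensor(np.ones((2, 2), np.float32)), name="w"))
        # In 2.2.0 this raised TypeError("... must not be None ..."); after the
        # change above, None is simply stored as an empty entry.
        self.insert_param_to_cell("placeholder", None)

net = TinyCell()   # constructs without raising on the None entry
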
mindspore/nn/layer/activation.py
CHANGED
@@ -932,10 +932,8 @@ class GELU(Cell):
         """Initialize GELU."""
         super(GELU, self).__init__()
         validator.check_bool(approximate, 'approximate', self.cls_name)
-        self.approximate =
-        if approximate:
-            self.approximate = 'tanh'
-        else:
+        self.approximate = 'tanh'
+        if not approximate:
             self.approximate = 'none'
 
     def construct(self, x):
@@ -1335,7 +1333,8 @@ class LRN(Cell):
 
     .. warning::
         LRN is deprecated on Ascend due to potential accuracy problem. It's recommended to use other
-        normalization methods, e.g. :class:`mindspore.nn.
+        normalization methods, e.g. :class:`mindspore.nn.BatchNorm1d` ,
+        :class:`mindspore.nn.BatchNorm2d` , :class:`mindspore.nn.BatchNorm3d`.
 
     Refer to :func:`mindspore.ops.lrn` for more details.

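For reference, the two variants selected by the fixed approximate flag are the exact erf-based GELU ('none') and the standard tanh approximation ('tanh'). The NumPy sketch below only illustrates those textbook formulas; it is not MindSpore code:

# GELU variants selected by the `approximate` flag.
import math
import numpy as np

erf = np.vectorize(math.erf)

def gelu_exact(x):
    # "none": GELU(x) = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

def gelu_tanh(x):
    # "tanh": 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

x = np.linspace(-3.0, 3.0, 7)
print(np.max(np.abs(gelu_exact(x) - gelu_tanh(x))))  # small; the approximation tracks the exact form closely

With the corrected initializer, nn.GELU(approximate=True) selects the 'tanh' form and approximate=False the exact 'none' form.
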
mindspore/nn/layer/conv.py
CHANGED
@@ -718,9 +718,9 @@ class Conv3d(_Conv):
 
     .. math::
         \begin{array}{ll} \\
-            D_{out}
-            H_{out}
-            W_{out}
+            D_{out} = \left \lceil{\frac{D_{in}}{\text{stride[0]}}} \right \rceil \\
+            H_{out} = \left \lceil{\frac{H_{in}}{\text{stride[1]}}} \right \rceil \\
+            W_{out} = \left \lceil{\frac{W_{in}}{\text{stride[2]}}} \right \rceil \\
         \end{array}
 
@@ -728,11 +728,11 @@ class Conv3d(_Conv):
 
     .. math::
         \begin{array}{ll} \\
-            D_{out}
+            D_{out} = \left \lfloor{\frac{D_{in} - \text{dilation[0]} \times (\text{kernel_size[0]} - 1) }
             {\text{stride[0]}} + 1} \right \rfloor \\
-            H_{out}
+            H_{out} = \left \lfloor{\frac{H_{in} - \text{dilation[1]} \times (\text{kernel_size[1]} - 1) }
             {\text{stride[1]}} + 1} \right \rfloor \\
-            W_{out}
+            W_{out} = \left \lfloor{\frac{W_{in} - \text{dilation[2]} \times (\text{kernel_size[2]} - 1) }
             {\text{stride[2]}} + 1} \right \rfloor \\
         \end{array}
 
@@ -740,11 +740,11 @@ class Conv3d(_Conv):
 
     .. math::
         \begin{array}{ll} \\
-            D_{out}
+            D_{out} = \left \lfloor{\frac{D_{in} + padding[0] + padding[1] - (\text{dilation[0]} - 1) \times
             \text{kernel_size[0]} - 1 }{\text{stride[0]}} + 1} \right \rfloor \\
-            H_{out}
+            H_{out} = \left \lfloor{\frac{H_{in} + padding[2] + padding[3] - (\text{dilation[1]} - 1) \times
             \text{kernel_size[1]} - 1 }{\text{stride[1]}} + 1} \right \rfloor \\
-            W_{out}
+            W_{out} = \left \lfloor{\frac{W_{in} + padding[4] + padding[5] - (\text{dilation[2]} - 1) \times
             \text{kernel_size[2]} - 1 }{\text{stride[2]}} + 1} \right \rfloor \\
         \end{array}
 
@@ -812,7 +812,7 @@ class Conv3d(_Conv):
                          bias_init,
                          data_format,
                          dtype=dtype)
-        out_channels = self.out_channels
+        out_channels = self.out_channels // group
         self.conv3d = P.Conv3D(out_channel=out_channels,
                                kernel_size=self.kernel_size,
                                mode=1,
@@ -820,17 +820,33 @@ class Conv3d(_Conv):
                                pad=self.padding,
                                stride=self.stride,
                                dilation=self.dilation,
-                               group=
+                               group=1,
                                data_format=self.data_format)
         self.bias_add = P.BiasAdd(data_format=self.data_format)
         self.shape = P.Shape()
+        self.concat = P.Concat(1)
+        self.split_0 = P.Split(0, self.group)
+        self.split_1 = P.Split(1, self.group)
 
     def construct(self, x):
         x_shape = self.shape(x)
         _check_input_5dims(x_shape, self.cls_name)
-
-
-
+        if self.group == 1:
+            out = self.conv3d(x, self.weight)
+            if self.has_bias:
+                out = self.bias_add(out, self.bias)
+        else:
+            features = self.split_1(x)
+            weights = self.split_0(self.weight)
+            outputs = ()
+            for i in range(self.group):
+                output = self.conv3d(features[i], weights[i])
+                outputs = outputs + (output,)
+            out = self.concat(outputs)
+            if self.bias is not None:
+                new_shape = [1 for _ in range(out.ndim)]
+                new_shape[1] = self.out_channels
+                out = out + self.bias.reshape(new_shape)
         return out
 
 
@@ -921,9 +937,9 @@ class Conv3dTranspose(_Conv):
 
     .. math::
         \begin{array}{ll} \\
-            D_{out}
-            H_{out}
-            W_{out}
+            D_{out} = \left \lfloor{\frac{D_{in}}{\text{stride[0]}} + 1} \right \rfloor \\
+            H_{out} = \left \lfloor{\frac{H_{in}}{\text{stride[1]}} + 1} \right \rfloor \\
+            W_{out} = \left \lfloor{\frac{W_{in}}{\text{stride[2]}} + 1} \right \rfloor \\
         \end{array}
 
@@ -931,11 +947,11 @@ class Conv3dTranspose(_Conv):
 
     .. math::
         \begin{array}{ll} \\
-            D_{out}
+            D_{out} = \left \lfloor{\frac{D_{in} - \text{dilation[0]} \times (\text{kernel_size[0]} - 1) }
             {\text{stride[0]}} + 1} \right \rfloor \\
-            H_{out}
+            H_{out} = \left \lfloor{\frac{H_{in} - \text{dilation[1]} \times (\text{kernel_size[1]} - 1) }
             {\text{stride[1]}} + 1} \right \rfloor \\
-            W_{out}
+            W_{out} = \left \lfloor{\frac{W_{in} - \text{dilation[2]} \times (\text{kernel_size[2]} - 1) }
             {\text{stride[2]}} + 1} \right \rfloor \\
         \end{array}
 
@@ -943,11 +959,11 @@ class Conv3dTranspose(_Conv):
 
     .. math::
         \begin{array}{ll} \\
-            D_{out}
+            D_{out} = \left \lfloor{\frac{D_{in} + padding[0] + padding[1] - (\text{dilation[0]} - 1) \times
             \text{kernel_size[0]} - 1 }{\text{stride[0]}} + 1} \right \rfloor \\
-            H_{out}
+            H_{out} = \left \lfloor{\frac{H_{in} + padding[2] + padding[3] - (\text{dilation[1]} - 1) \times
             \text{kernel_size[1]} - 1 }{\text{stride[1]}} + 1} \right \rfloor \\
-            W_{out}
+            W_{out} = \left \lfloor{\frac{W_{in} + padding[4] + padding[5] - (\text{dilation[2]} - 1) \times
             \text{kernel_size[2]} - 1 }{\text{stride[2]}} + 1} \right \rfloor \\
         \end{array}

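The rewritten Conv3d.construct emulates group > 1 by splitting the input along its channel axis and the weight along its output-channel axis, running an ungrouped Conv3D per group, then concatenating. The NumPy sketch below mirrors only that split/loop/concat structure; single_group_conv3d is a hypothetical stand-in for the 1x1x1-kernel, stride-1 case, not MindSpore's P.Conv3D:

# Illustrative sketch of the group > 1 path added to Conv3d.construct.
import numpy as np

def single_group_conv3d(x, w):
    # Stand-in for an ungrouped 3D convolution with kernel_size=1, stride=1:
    # x: (N, C_in_g, D, H, W), w: (C_out_g, C_in_g, 1, 1, 1)
    return np.einsum("ncdhw,ocijk->nodhw", x, w)

def grouped_conv3d(x, weight, group):
    features = np.split(x, group, axis=1)        # split input channels
    weights = np.split(weight, group, axis=0)    # split output channels
    outputs = [single_group_conv3d(f, w) for f, w in zip(features, weights)]
    return np.concatenate(outputs, axis=1)       # concat along the channel axis

x = np.random.randn(2, 4, 3, 3, 3).astype(np.float32)   # N, C_in, D, H, W
w = np.random.randn(8, 2, 1, 1, 1).astype(np.float32)   # C_out, C_in // group, 1, 1, 1
print(grouped_conv3d(x, w, group=2).shape)               # (2, 8, 3, 3, 3)
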
mindspore/nn/layer/flash_attention.py
CHANGED
@@ -57,14 +57,15 @@ class FlashAttention(Cell):
             Default True
         alibi(bool): This parameter indicates whether the flashattention supports the Alibi.
             Default: False
+        use_mqa(bool): Using MHA if True, only take effect under 910B. Default: False.
 
 
     Inputs:
         - **query** (Tensor) - Tensor query (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
         - **key** (Tensor) - Tensor key (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
         - **value** (Tensor) - Tensor value (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
-        - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16`
-
+        - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16` `mstype.uint8`
+          [batch_size, seq_length, seq_length]): A matrix to pass masked information.
 
     Outputs:
         A Tensor. The output of the attention with shape [batch_size, head_num, seq_length, head_dim]
@@ -102,17 +103,23 @@ class FlashAttention(Cell):
                  mp=1,
                  high_precision=False,
                  have_attention_mask_batch=True,
-                 alibi=False
+                 alibi=False,
+                 use_mqa=False
                  ):
         super(FlashAttention, self).__init__()
 
         scaling_constant = math.sqrt(head_dim)
         if scaling_constant == 0:
             raise ValueError("the scaling constant must not be 0.")
-        self.
-
-        self.is_910A = MSContext.get_instance().get_ascend_soc_version() == "Ascend910"
+        self.dropout_rate = dropout_rate
+        self.is_910A = MSContext.get_instance().get_ascend_soc_version() == "ascend910"
         if self.is_910A:
+            self.scale_factor = Tensor([1. / math.sqrt(scaling_constant)], dtype=mstype.float16)
+            self.scale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
+            self.ones = ops.Ones()
+            self.dim_mask = Tensor([1 for _ in range(head_dim)], dtype=mstype.int8)
+            self.have_attention_mask_batch = have_attention_mask_batch
+            self.alibi = alibi
             self.flash_attention = get_flash_attention(
                 prev_block_num=prev_block_num,
                 next_block_num=next_block_num,
@@ -120,6 +127,10 @@ class FlashAttention(Cell):
                 high_precision=high_precision
             )
             self.flash_attention.add_prim_attr("primitive_target", "Ascend")
+            fa_strategies = ((dp, mp, 1, 1),
+                             (dp, mp, 1, 1),
+                             (dp, mp, 1, 1))
+            self.shard(fa_strategies)
         else:
             if alibi:
                 raise ValueError(f"When soc_version is not Ascend910A, alibi must be False")
@@ -128,25 +139,27 @@ class FlashAttention(Cell):
             self.reshape = ops.Reshape()
             self.zeros_like = ops.ZerosLike().shard(((dp, mp, 1, 1),))
             self.zeros = ops.Zeros()
-            self.
-
-
-
-
+            self.attn_cast = ops.Cast()
+            if use_mqa:
+                fa_strategies = ((dp, mp, 1, 1),
+                                 (dp, 1, 1, 1),
+                                 (dp, 1, 1, 1),
+                                 (dp, 1, 1, 1))
+            else:
+                fa_strategies = ((dp, mp, 1, 1),
+                                 (dp, mp, 1, 1),
+                                 (dp, mp, 1, 1),
+                                 (dp, 1, 1, 1))
             if dropout_rate > 1e-5:
                 fa_strategies += ((dp, mp, 1, 1),)
             self.flash_attention = FlashAttentionScore(head_num=head_num, pre_tokens=prev_block_num,
                                                        next_tokens=next_block_num,
                                                        keep_prob=1 - dropout_rate,
-                                                       scale_value=1.
-                                                       inner_precise=0 if high_precision else 1
+                                                       scale_value=1. / scaling_constant,
+                                                       inner_precise=0 if high_precision else 1,
+                                                       input_layout="BNSD").shard(fa_strategies)
 
-        self.ones = ops.Ones()
-        self.dim_mask = Tensor([1 for _ in range(head_dim)], dtype=mstype.int8)
-        self.scale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
         self.dropout_rate = dropout_rate
-        self.have_attention_mask_batch = have_attention_mask_batch
-        self.alibi = alibi
         if self.dropout_rate > 1e-5:
             self.keep_prob = Tensor(1 - self.dropout_rate, dtype=mstype.float16)
             self.fill_v2 = ops.FillV2().shard(((dp, mp, 1, 1), ()))
@@ -162,46 +175,49 @@ class FlashAttention(Cell):
                 such as MatMul. Default: None.
         :return:
         """
-        if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self.is_910A:
+            if in_strategy is None:
+                # default: dp=1, mp=1, construct inputs only contain query, key, value
+                in_strategy = (
+                    (1, 1, 1, 1),
+                    (1, 1, 1, 1),
+                    (1, 1, 1, 1),
+                )
+            self.flash_attention.shard(in_strategy)
+            dp = in_strategy[0][0]
+            mp = in_strategy[0][1]
+            self.flash_attention.add_prim_attr("dev_matrix_shape", [dp, mp, 1, 1])
+            inputs_tensor_map = [
+                [3, 2, 1, 0],
+                [3, 2, 1, 0],
+                [3, 2, 1, 0],
+            ]
+            if self.have_attention_mask_batch:
+                inputs_tensor_map.append([3, 1, 0])
+            else:
+                inputs_tensor_map.append([-1, 1, 0])
 
-
-
-
-
-
+            input_empty_args_num = 2
+            # dropout_mask
+            if self.dropout_rate > 1e-5:
+                input_empty_args_num -= 1
+                inputs_tensor_map.append([3, 2, 1, 0])
 
-
-
-
+            if self.alibi:
+                input_empty_args_num -= 1
+                inputs_tensor_map.append([3, 2, 1, 0])
 
-
+            self.flash_attention.add_prim_attr("inputs_tensor_map", inputs_tensor_map)
 
-
-
-
-
-
-
-
+            self.flash_attention.add_prim_attr("outputs_tensor_map", [
+                [3, 2, 1, 0],  # O
+                [3, 2, 1],  # L
+                [3, 2, 1]  # M
+            ])
+            self.flash_attention.add_prim_attr("as_loss_divisor", 0)
+            self.flash_attention.add_prim_attr("empty_mirror_ops", input_empty_args_num)
+        else:
+            self.flash_attention.shard(in_strategy)
 
     def construct(self, query, key, value, attn_mask=None, alibi_mask=None):
         """FlashAttention forward
@@ -212,24 +228,22 @@ class FlashAttention(Cell):
         :param alibi_mask: [bsz, head_num, 1, seq_len], if not None
         :return: output [bsz, head_num, seq_len, head_dim]
         """
-        query = self.scale_mul(query, self.scale_factor)
         bsz, head_num, seq_len, head_dim = query.shape
-        _, k_head_num, k_seq_len, _ = key.shape
-        _, v_head_num, v_seq_len, _ = value.shape
-        if head_num != k_head_num or head_num != v_head_num:
-            raise ValueError(
-                "the head_num of query, key and value must be the same, "
-                "If different head_num are used, users need to change themselves to be same by tile.")
-        if seq_len % 16 != 0 or k_seq_len % 16 != 0 or k_seq_len != v_seq_len:
-            raise ValueError(
-                "query, key, value seq_len must be a multiple of 16, and key seq_len, value seq_len must be the same.")
-
-        if head_dim > 304:
-            raise ValueError(
-                "the head_dim must be less than 304, otherwise the ub would be OOM.")
-
         if self.is_910A:
+            _, k_head_num, k_seq_len, _ = key.shape
+            _, v_head_num, v_seq_len, _ = value.shape
+            if head_num != k_head_num or head_num != v_head_num:
+                raise ValueError(
+                    "the head_num of query, key and value must be the same, "
+                    "If different head_num are used, users need to change themselves to be same by tile.")
+            if seq_len % 16 != 0 or k_seq_len % 16 != 0 or k_seq_len != v_seq_len:
+                raise ValueError(
+                    "query, key, value seq_len must be a multiple of 16, "
+                    "and the seq_len between key and value must be equal.")
             # 910A -- FlashAttentionPrimtive
+            if head_dim > 304:
+                raise ValueError(
+                    "the head_dim must be less than 304, otherwise the ub would be OOM.")
             if self.dropout_rate > 1e-5:
                 drop_mask_bits = self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob)
                 tensor_shape = Tensor((bsz, head_num, seq_len, seq_len), mstype.int32)
@@ -238,27 +252,25 @@ class FlashAttention(Cell):
                 drop_mask = self.do_dropout(ones, drop_mask_bits, self.keep_prob)
             else:
                 drop_mask = None
+            query = self.scale_mul(query, self.scale_factor)
+            key = self.scale_mul(key, self.scale_factor)
+            attn_mask = self.cast(attn_mask, mstype.float16)
             output, _, _ = self.flash_attention(query, key, value, attn_mask, drop_mask, alibi_mask)
         else:
-            # FlashAttentionScore
-            # Useless input, just for binary calls.
+            # 910B -- FlashAttentionScore
            if self.dropout_rate > 1e-5:
                 drop_mask_bits = self.reshape(self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob),
                                               (bsz, head_num, seq_len, seq_len // 8))
             else:
                 drop_mask_bits = None
-            # (B,
-
-            key = self.reshape(self.transpose_4d_pre(key, (0, 2, 1, 3)), (bsz, seq_len, -1))
-            value = self.reshape(self.transpose_4d_pre(value, (0, 2, 1, 3)), (bsz, seq_len, -1))
-            attn_mask = self.attn_expand_dims(attn_mask, 1)
+            # (B, S, S) -> (B, 1, S, S)
+            attn_mask = self.cast(self.reshape(attn_mask, (bsz, 1, seq_len, seq_len)), mstype.uint8)
             output, _, _ = self.flash_attention(query,
                                                 key,
                                                 value,
                                                 attn_mask,
                                                 drop_mask_bits,
                                                 None,
+                                                None,
                                                 None)
-            output = self.transpose_4d_post(self.reshape(output, (bsz, seq_len, head_num, head_dim)), (0, 2, 1, 3))
-
         return output

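A hedged usage sketch for the updated FlashAttention cell. The keyword names come only from parameters visible in this diff (head_num, head_dim, dropout_rate, use_mqa); the full signature may differ between builds, the import path is taken from the file listing, and the cell only executes on Ascend hardware, so treat this purely as an illustration of the documented input shapes:

# Illustrative only; verify the constructor signature against the installed wheel.
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.nn.layer.flash_attention import FlashAttention

bsz, head_num, seq_len, head_dim = 1, 8, 1024, 128   # seq_len must be a multiple of 16

fa = FlashAttention(head_num=head_num, head_dim=head_dim,
                    dropout_rate=0.0, use_mqa=False)  # use_mqa only takes effect on 910B

q = Tensor(np.random.randn(bsz, head_num, seq_len, head_dim), ms.float16)
k = Tensor(np.random.randn(bsz, head_num, seq_len, head_dim), ms.float16)
v = Tensor(np.random.randn(bsz, head_num, seq_len, head_dim), ms.float16)
# Causal mask, 1 = masked position, shape (bsz, seq_len, seq_len) as documented above.
mask = Tensor(np.triu(np.ones((bsz, seq_len, seq_len)), k=1), ms.float16)

out = fa(q, k, v, mask)   # -> (bsz, head_num, seq_len, head_dim)
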
mindspore/nn/layer/math.py
CHANGED
@@ -375,9 +375,6 @@ class DiGamma(Cell):
                            nan, real_result)
 
 
-eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32)
-
-
 def _while_helper_func(cond, body, vals):
     while cond(vals).any():
         vals = body(vals)
@@ -394,7 +391,7 @@ def _igamma_series(ax, x, a, enabled):
     select = P.Select()
 
     # If more data types are supported, this epsilon need to be selected.
-    epsilon =
+    epsilon = Tensor(np.finfo(np.float32).eps, mstype.float32)
 
     def cond(vals):
         enabled = vals[0]
@@ -443,7 +440,7 @@ def _igammac_continued_fraction(ax, x, a, enabled):
     select = P.Select()
 
     # If more data types are supported, this epsilon need to be selected.
-    epsilon =
+    epsilon = Tensor(np.finfo(np.float32).eps, mstype.float32)
 
     def cond(vals):
         enabled = vals[0]
@@ -620,8 +617,7 @@ class IGamma(Cell):
         x = F.broadcast_to(x, para_shape)
         a = F.broadcast_to(a, para_shape)
         x_is_zero = self.equal(x, 0)
-
-        underflow = self.less(ax, self.neg(log_maxfloat))
+        underflow = self.less(ax, self.neg(self.log_maxfloat32))
         ax = self.exp(ax)
         enabled = self.logicalnot(self.logicalor(self.logicalor(x_is_zero, domain_error), underflow))
         output = self.select(use_igammac,

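Both igamma helpers now build the float32 machine-epsilon tensor locally instead of relying on the removed module-level eps_fp32. For reference, the value used as the convergence threshold (plain NumPy, nothing MindSpore-specific):

# np.finfo(np.float32).eps is the gap between 1.0 and the next representable
# float32, i.e. 2**-23; the series / continued-fraction loops stop once the
# per-step update drops below this value.
import numpy as np

eps = np.finfo(np.float32).eps
print(eps, eps == 2.0 ** -23)   # 1.1920929e-07 True
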
mindspore/nn/layer/rnn_cells.py
CHANGED
@@ -83,7 +83,7 @@ def _check_lstmcell_init(func):
 
 
 def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
-
+    """RNN cell function with tanh activation"""
     if b_ih is None:
         igates = P.MatMul(False, True)(inputs, w_ih)
         hgates = P.MatMul(False, True)(hidden, w_hh)
@@ -94,7 +94,7 @@ def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
 
 
 def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
-
+    """RNN cell function with relu activation"""
     if b_ih is None:
         igates = P.MatMul(False, True)(inputs, w_ih)
         hgates = P.MatMul(False, True)(hidden, w_hh)
@@ -105,7 +105,7 @@ def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
 
 
 def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
-
+    """LSTM cell function"""
     hx, cx = hidden
     if b_ih is None:
         gates = P.MatMul(False, True)(inputs, w_ih) + P.MatMul(False, True)(hx, w_hh)
@@ -125,7 +125,7 @@ def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
 
 
 def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
-
+    """GRU cell function"""
     if b_ih is None:
         gi = P.MatMul(False, True)(inputs, w_ih)
         gh = P.MatMul(False, True)(hidden, w_hh)
@@ -144,7 +144,7 @@ def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
 
 
 class RNNCellBase(Cell):
-
+    """Basic class for RNN Cells"""
     def __init__(self, input_size: int, hidden_size: int, has_bias: bool, num_chunks: int,
                  dtype=mstype.float32):
         super().__init__()

mindspore/nn/wrap/cell_wrapper.py
CHANGED
@@ -644,6 +644,9 @@ class PipelineCell(Cell):
         self.micro_inputs = nn.CellList()
         self.micro_size = micro_size
         self.add_list = []
+        if not isinstance(network, Cell):
+            raise TypeError("For 'PipelineCell', the argument 'network' must cell type, "
+                            "but got the type : {}.".format(type(network)))
         if not isinstance(micro_size, int):
             raise TypeError("For 'PipelineCell', the argument 'micro_size' must be integer, "
                             "but got the type : {}.".format(type(micro_size)))
@@ -689,6 +692,9 @@ class GradAccumulationCell(Cell):
         self.micro_inputs = nn.CellList()
         self.micro_size = micro_size
         self.add_list = []
+        if not isinstance(network, Cell):
+            raise TypeError("For 'GradAccumulationCell', the argument 'network' must cell type, "
+                            "but got the type : {}.".format(type(network)))
         if not isinstance(micro_size, int):
             raise TypeError("For 'GradAccumulationCell', the argument 'micro_size' must be integer, "
                             "but got the type : {}.".format(type(micro_size)))
