mindspore 2.2.0__cp38-cp38-win_amd64.whl → 2.2.11__cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +3 -3
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/splitter.py +3 -2
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
- mindspore/_extends/parse/__init__.py +3 -2
- mindspore/_extends/parse/parser.py +6 -1
- mindspore/_extends/parse/standard_method.py +14 -11
- mindspore/_extends/remote/kernel_build_server.py +2 -1
- mindspore/common/_utils.py +16 -0
- mindspore/common/api.py +1 -1
- mindspore/common/auto_dynamic_shape.py +81 -85
- mindspore/common/dump.py +1 -1
- mindspore/common/tensor.py +3 -20
- mindspore/config/op_info.config +1 -1
- mindspore/context.py +11 -4
- mindspore/dataset/engine/cache_client.py +8 -5
- mindspore/dataset/engine/datasets_standard_format.py +5 -0
- mindspore/dataset/vision/transforms.py +21 -21
- mindspore/experimental/optim/adam.py +1 -1
- mindspore/gen_ops.py +1 -1
- mindspore/include/api/model.h +17 -0
- mindspore/include/api/status.h +8 -3
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/cell.py +0 -3
- mindspore/nn/layer/activation.py +4 -5
- mindspore/nn/layer/conv.py +39 -23
- mindspore/nn/layer/flash_attention.py +54 -129
- mindspore/nn/layer/math.py +3 -7
- mindspore/nn/layer/rnn_cells.py +5 -5
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +12 -3
- mindspore/numpy/utils_const.py +5 -5
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
- mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_utils/utils.py +2 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
- mindspore/ops/function/array_func.py +10 -7
- mindspore/ops/function/grad/grad_func.py +0 -1
- mindspore/ops/function/nn_func.py +98 -9
- mindspore/ops/function/random_func.py +2 -1
- mindspore/ops/op_info_register.py +24 -21
- mindspore/ops/operations/__init__.py +6 -2
- mindspore/ops/operations/_grad_ops.py +25 -6
- mindspore/ops/operations/_inner_ops.py +155 -23
- mindspore/ops/operations/array_ops.py +9 -7
- mindspore/ops/operations/comm_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +85 -68
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +7 -6
- mindspore/ops/operations/nn_ops.py +193 -49
- mindspore/parallel/_parallel_serialization.py +10 -3
- mindspore/parallel/_tensor.py +4 -1
- mindspore/parallel/checkpoint_transform.py +13 -2
- mindspore/parallel/shard.py +17 -10
- mindspore/profiler/common/util.py +1 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
- mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
- mindspore/profiler/parser/ascend_op_generator.py +1 -1
- mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
- mindspore/profiler/parser/base_timeline_generator.py +1 -1
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
- mindspore/profiler/parser/framework_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +19 -0
- mindspore/profiler/profiling.py +46 -24
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/parsers/for_parser.py +7 -7
- mindspore/rewrite/parsers/module_parser.py +4 -4
- mindspore/rewrite/symbol_tree.py +1 -4
- mindspore/run_check/_check_version.py +5 -3
- mindspore/safeguard/rewrite_obfuscation.py +52 -28
- mindspore/train/callback/_summary_collector.py +1 -1
- mindspore/train/dataset_helper.py +1 -0
- mindspore/train/model.py +2 -2
- mindspore/train/serialization.py +97 -11
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +23 -7
- mindspore/version.py +1 -1
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +3 -2
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +101 -112
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
- mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
mindspore/dataset/vision/transforms.py
CHANGED
|
@@ -144,14 +144,14 @@ class AdjustBrightness(ImageTensorOperation, PyTensorOperation):
|
|
|
144
144
|
|
|
145
145
|
Args:
|
|
146
146
|
device_target (str, optional): The operator will be executed on this device. Currently supports
|
|
147
|
-
``CPU``
|
|
147
|
+
``CPU`` . Default: ``CPU`` .
|
|
148
148
|
|
|
149
149
|
Raises:
|
|
150
150
|
TypeError: If `device_target` is not of type str.
|
|
151
|
-
ValueError: If `device_target` is not
|
|
151
|
+
ValueError: If `device_target` is not ``CPU`` .
|
|
152
152
|
|
|
153
153
|
Supported Platforms:
|
|
154
|
-
``CPU``
|
|
154
|
+
``CPU``
|
|
155
155
|
|
|
156
156
|
Examples:
|
|
157
157
|
>>> import mindspore.dataset as ds
|
|
@@ -227,14 +227,14 @@ class AdjustContrast(ImageTensorOperation, PyTensorOperation):
|
|
|
227
227
|
|
|
228
228
|
Args:
|
|
229
229
|
device_target (str, optional): The operator will be executed on this device. Currently supports
|
|
230
|
-
``CPU``
|
|
230
|
+
``CPU`` . Default: ``CPU`` .
|
|
231
231
|
|
|
232
232
|
Raises:
|
|
233
233
|
TypeError: If `device_target` is not of type str.
|
|
234
|
-
ValueError: If `device_target` is not
|
|
234
|
+
ValueError: If `device_target` is not ``CPU`` .
|
|
235
235
|
|
|
236
236
|
Supported Platforms:
|
|
237
|
-
``CPU``
|
|
237
|
+
``CPU``
|
|
238
238
|
|
|
239
239
|
Examples:
|
|
240
240
|
>>> import mindspore.dataset as ds
|
|
@@ -373,14 +373,14 @@ class AdjustHue(ImageTensorOperation, PyTensorOperation):
|
|
|
373
373
|
|
|
374
374
|
Args:
|
|
375
375
|
device_target (str, optional): The operator will be executed on this device. Currently supports
|
|
376
|
-
``CPU``
|
|
376
|
+
``CPU`` . Default: ``CPU`` .
|
|
377
377
|
|
|
378
378
|
Raises:
|
|
379
379
|
TypeError: If `device_target` is not of type str.
|
|
380
|
-
ValueError: If `device_target` is not
|
|
380
|
+
ValueError: If `device_target` is not ``CPU`` .
|
|
381
381
|
|
|
382
382
|
Supported Platforms:
|
|
383
|
-
``CPU``
|
|
383
|
+
``CPU``
|
|
384
384
|
|
|
385
385
|
Examples:
|
|
386
386
|
>>> import mindspore.dataset as ds
|
|
@@ -457,14 +457,14 @@ class AdjustSaturation(ImageTensorOperation, PyTensorOperation):
|
|
|
457
457
|
|
|
458
458
|
Args:
|
|
459
459
|
device_target (str, optional): The operator will be executed on this device. Currently supports
|
|
460
|
-
``CPU``
|
|
460
|
+
``CPU`` . Default: ``CPU`` .
|
|
461
461
|
|
|
462
462
|
Raises:
|
|
463
463
|
TypeError: If `device_target` is not of type str.
|
|
464
|
-
ValueError: If `device_target` is not
|
|
464
|
+
ValueError: If `device_target` is not ``CPU`` .
|
|
465
465
|
|
|
466
466
|
Supported Platforms:
|
|
467
|
-
``CPU``
|
|
467
|
+
``CPU``
|
|
468
468
|
|
|
469
469
|
Examples:
|
|
470
470
|
>>> import mindspore.dataset as ds
|
|
@@ -1159,14 +1159,14 @@ class Decode(ImageTensorOperation, PyTensorOperation):
|
|
|
1159
1159
|
|
|
1160
1160
|
Args:
|
|
1161
1161
|
device_target (str, optional): The operator will be executed on this device. Currently supports
|
|
1162
|
-
``CPU``
|
|
1162
|
+
``CPU`` . Default: ``CPU`` .
|
|
1163
1163
|
|
|
1164
1164
|
Raises:
|
|
1165
1165
|
TypeError: If `device_target` is not of type str.
|
|
1166
|
-
ValueError: If `device_target` is not
|
|
1166
|
+
ValueError: If `device_target` is not ``CPU`` .
|
|
1167
1167
|
|
|
1168
1168
|
Supported Platforms:
|
|
1169
|
-
``CPU``
|
|
1169
|
+
``CPU``
|
|
1170
1170
|
|
|
1171
1171
|
Examples:
|
|
1172
1172
|
>>> import mindspore.dataset as ds
|
|
@@ -1908,14 +1908,14 @@ class Normalize(ImageTensorOperation):
|
|
|
1908
1908
|
|
|
1909
1909
|
Args:
|
|
1910
1910
|
device_target (str, optional): The operator will be executed on this device. Currently supports
|
|
1911
|
-
``CPU``
|
|
1911
|
+
``CPU`` . Default: ``CPU`` .
|
|
1912
1912
|
|
|
1913
1913
|
Raises:
|
|
1914
1914
|
TypeError: If `device_target` is not of type str.
|
|
1915
|
-
ValueError: If `device_target` is not
|
|
1915
|
+
ValueError: If `device_target` is not ``CPU`` .
|
|
1916
1916
|
|
|
1917
1917
|
Supported Platforms:
|
|
1918
|
-
``CPU``
|
|
1918
|
+
``CPU``
|
|
1919
1919
|
|
|
1920
1920
|
Examples:
|
|
1921
1921
|
>>> import mindspore.dataset as ds
|
|
@@ -4182,14 +4182,14 @@ class Resize(ImageTensorOperation, PyTensorOperation):
|
|
|
4182
4182
|
|
|
4183
4183
|
Args:
|
|
4184
4184
|
device_target (str, optional): The operator will be executed on this device. Currently supports
|
|
4185
|
-
``CPU``
|
|
4185
|
+
``CPU`` . Default: ``CPU`` .
|
|
4186
4186
|
|
|
4187
4187
|
Raises:
|
|
4188
4188
|
TypeError: If `device_target` is not of type str.
|
|
4189
|
-
ValueError: If `device_target` is not
|
|
4189
|
+
ValueError: If `device_target` is not ``CPU`` .
|
|
4190
4190
|
|
|
4191
4191
|
Supported Platforms:
|
|
4192
|
-
``CPU``
|
|
4192
|
+
``CPU``
|
|
4193
4193
|
|
|
4194
4194
|
Examples:
|
|
4195
4195
|
>>> import mindspore.dataset as ds
|
mindspore/gen_ops.py
CHANGED
|
@@ -120,7 +120,7 @@ def generate_py_primitive(yaml_data):
|
|
|
120
120
|
assign_str += arg_name
|
|
121
121
|
args_assign.append(assign_str)
|
|
122
122
|
|
|
123
|
-
args_assign = '\n'.join(assign for assign in args_assign)
|
|
123
|
+
args_assign = '\n'.join([assign for assign in args_assign])
|
|
124
124
|
primitive_code = f"""
|
|
125
125
|
class {class_name}(Primitive):
|
|
126
126
|
def __init__(self, {', '.join(init_args_with_default)}):
|
mindspore/include/api/model.h
CHANGED
|
@@ -136,6 +136,13 @@ class MS_API Model {
|
|
|
136
136
|
/// \return Status.
|
|
137
137
|
Status UpdateWeights(const std::vector<MSTensor> &new_weights);
|
|
138
138
|
|
|
139
|
+
/// \brief Change the size and or content of weight tensors
|
|
140
|
+
///
|
|
141
|
+
/// \param[in] A vector where model constant are arranged in sequence
|
|
142
|
+
///
|
|
143
|
+
/// \return Status.
|
|
144
|
+
Status UpdateWeights(const std::vector<std::vector<MSTensor>> &new_weights);
|
|
145
|
+
|
|
139
146
|
/// \brief Inference model API. If use this API in train mode, it's equal to RunStep API.
|
|
140
147
|
///
|
|
141
148
|
/// \param[in] inputs A vector where model inputs are arranged in sequence.
|
|
@@ -358,6 +365,13 @@ class MS_API Model {
|
|
|
358
365
|
|
|
359
366
|
const std::shared_ptr<ModelImpl> impl() const { return impl_; }
|
|
360
367
|
|
|
368
|
+
/// \brief Get model info by key
|
|
369
|
+
///
|
|
370
|
+
/// \param[in] key The key of model info key-value pair
|
|
371
|
+
///
|
|
372
|
+
/// \return The value of the model info associated with the given key.
|
|
373
|
+
inline std::string GetModelInfo(const std::string &key);
|
|
374
|
+
|
|
361
375
|
private:
|
|
362
376
|
friend class Serialization;
|
|
363
377
|
// api without std::string
|
|
@@ -374,6 +388,7 @@ class MS_API Model {
|
|
|
374
388
|
const std::vector<char> &cropto_lib_path);
|
|
375
389
|
Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
|
|
376
390
|
const Key &dec_key, const std::vector<char> &dec_mode, const std::vector<char> &cropto_lib_path);
|
|
391
|
+
std::vector<char> GetModelInfo(const std::vector<char> &key);
|
|
377
392
|
std::shared_ptr<ModelImpl> impl_;
|
|
378
393
|
};
|
|
379
394
|
|
|
@@ -416,5 +431,7 @@ Status Model::Build(const std::string &model_path, ModelType model_type,
|
|
|
416
431
|
const std::shared_ptr<Context> &model_context) {
|
|
417
432
|
return Build(StringToChar(model_path), model_type, model_context);
|
|
418
433
|
}
|
|
434
|
+
|
|
435
|
+
inline std::string Model::GetModelInfo(const std::string &key) { return CharToString(GetModelInfo(StringToChar(key))); }
|
|
419
436
|
} // namespace mindspore
|
|
420
437
|
#endif // MINDSPORE_INCLUDE_API_MODEL_H
|
mindspore/include/api/status.h
CHANGED
|
@@ -83,9 +83,14 @@ enum StatusCode : uint32_t {
|
|
|
83
83
|
kLiteModelRebuild = kLite | (0x0FFFFFFF & -12), /**< Model has been built. */
|
|
84
84
|
|
|
85
85
|
// Executor error code, range: [-100,-200)
|
|
86
|
-
kLiteOutOfTensorRange = kLite | (0x0FFFFFFF & -100),
|
|
87
|
-
kLiteInputTensorError = kLite | (0x0FFFFFFF & -101),
|
|
88
|
-
kLiteReentrantError = kLite | (0x0FFFFFFF & -102),
|
|
86
|
+
kLiteOutOfTensorRange = kLite | (0x0FFFFFFF & -100), /**< Failed to check range. */
|
|
87
|
+
kLiteInputTensorError = kLite | (0x0FFFFFFF & -101), /**< Failed to check input tensor. */
|
|
88
|
+
kLiteReentrantError = kLite | (0x0FFFFFFF & -102), /**< Exist executor running. */
|
|
89
|
+
kLiteLLMWaitProcessTimeOut = kLite | (0x0FFFFFFF & -103), /**< Wait to be processed time out. */
|
|
90
|
+
kLiteLLMKVCacheNotExist = kLite | (0x0FFFFFFF & -104), /**< KV Cache not exist. */
|
|
91
|
+
kLiteLLMRepeatRequest = kLite | (0x0FFFFFFF & -105), /**< repeat request. */
|
|
92
|
+
kLiteLLMRequestAlreadyCompleted = kLite | (0x0FFFFFFF & -106), /**< request already complete!. */
|
|
93
|
+
kLiteLLMEngineFinalized = kLite | (0x0FFFFFFF & -107), /**< llm engine finalized. */
|
|
89
94
|
|
|
90
95
|
// Graph error code, range: [-200,-300)
|
|
91
96
|
kLiteGraphFileError = kLite | (0x0FFFFFFF & -200), /**< Failed to verify graph file. */
|
mindspore/mindspore_backend.dll
CHANGED
|
Binary file
|
mindspore/mindspore_common.dll
CHANGED
|
Binary file
|
mindspore/mindspore_core.dll
CHANGED
|
Binary file
|
mindspore/mindspore_shared_lib.dll
CHANGED
|
Binary file
|
mindspore/nn/cell.py
CHANGED
|
@@ -1081,9 +1081,6 @@ class Cell(Cell_):
|
|
|
1081
1081
|
if not isinstance(param, Parameter) and param is not None:
|
|
1082
1082
|
raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must be 'Parameter' if not None, "
|
|
1083
1083
|
f"but got {type(param)}.")
|
|
1084
|
-
if param is None:
|
|
1085
|
-
raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must not be None, "
|
|
1086
|
-
f"but got None.")
|
|
1087
1084
|
if isinstance(param, Parameter) and param.name == PARAMETER_NAME_DEFAULT:
|
|
1088
1085
|
param.name = param_name
|
|
1089
1086
|
self._params[param_name] = param
|
mindspore/nn/layer/activation.py
CHANGED
|
@@ -932,10 +932,8 @@ class GELU(Cell):
|
|
|
932
932
|
"""Initialize GELU."""
|
|
933
933
|
super(GELU, self).__init__()
|
|
934
934
|
validator.check_bool(approximate, 'approximate', self.cls_name)
|
|
935
|
-
self.approximate =
|
|
936
|
-
if approximate:
|
|
937
|
-
self.approximate = 'tanh'
|
|
938
|
-
else:
|
|
935
|
+
self.approximate = 'tanh'
|
|
936
|
+
if not approximate:
|
|
939
937
|
self.approximate = 'none'
|
|
940
938
|
|
|
941
939
|
def construct(self, x):
|
|
@@ -1335,7 +1333,8 @@ class LRN(Cell):
|
|
|
1335
1333
|
|
|
1336
1334
|
.. warning::
|
|
1337
1335
|
LRN is deprecated on Ascend due to potential accuracy problem. It's recommended to use other
|
|
1338
|
-
normalization methods, e.g. :class:`mindspore.nn.
|
|
1336
|
+
normalization methods, e.g. :class:`mindspore.nn.BatchNorm1d` ,
|
|
1337
|
+
:class:`mindspore.nn.BatchNorm2d` , :class:`mindspore.nn.BatchNorm3d`.
|
|
1339
1338
|
|
|
1340
1339
|
Refer to :func:`mindspore.ops.lrn` for more details.
|
|
1341
1340
|
|
mindspore/nn/layer/conv.py
CHANGED
|
@@ -718,9 +718,9 @@ class Conv3d(_Conv):
|
|
|
718
718
|
|
|
719
719
|
.. math::
|
|
720
720
|
\begin{array}{ll} \\
|
|
721
|
-
D_{out}
|
|
722
|
-
H_{out}
|
|
723
|
-
W_{out}
|
|
721
|
+
D_{out} = \left \lceil{\frac{D_{in}}{\text{stride[0]}}} \right \rceil \\
|
|
722
|
+
H_{out} = \left \lceil{\frac{H_{in}}{\text{stride[1]}}} \right \rceil \\
|
|
723
|
+
W_{out} = \left \lceil{\frac{W_{in}}{\text{stride[2]}}} \right \rceil \\
|
|
724
724
|
\end{array}
|
|
725
725
|
|
|
726
726
|
|
|
@@ -728,11 +728,11 @@ class Conv3d(_Conv):
|
|
|
728
728
|
|
|
729
729
|
.. math::
|
|
730
730
|
\begin{array}{ll} \\
|
|
731
|
-
D_{out}
|
|
731
|
+
D_{out} = \left \lfloor{\frac{D_{in} - \text{dilation[0]} \times (\text{kernel_size[0]} - 1) }
|
|
732
732
|
{\text{stride[0]}} + 1} \right \rfloor \\
|
|
733
|
-
H_{out}
|
|
733
|
+
H_{out} = \left \lfloor{\frac{H_{in} - \text{dilation[1]} \times (\text{kernel_size[1]} - 1) }
|
|
734
734
|
{\text{stride[1]}} + 1} \right \rfloor \\
|
|
735
|
-
W_{out}
|
|
735
|
+
W_{out} = \left \lfloor{\frac{W_{in} - \text{dilation[2]} \times (\text{kernel_size[2]} - 1) }
|
|
736
736
|
{\text{stride[2]}} + 1} \right \rfloor \\
|
|
737
737
|
\end{array}
|
|
738
738
|
|
|
@@ -740,11 +740,11 @@ class Conv3d(_Conv):
|
|
|
740
740
|
|
|
741
741
|
.. math::
|
|
742
742
|
\begin{array}{ll} \\
|
|
743
|
-
D_{out}
|
|
743
|
+
D_{out} = \left \lfloor{\frac{D_{in} + padding[0] + padding[1] - (\text{dilation[0]} - 1) \times
|
|
744
744
|
\text{kernel_size[0]} - 1 }{\text{stride[0]}} + 1} \right \rfloor \\
|
|
745
|
-
H_{out}
|
|
745
|
+
H_{out} = \left \lfloor{\frac{H_{in} + padding[2] + padding[3] - (\text{dilation[1]} - 1) \times
|
|
746
746
|
\text{kernel_size[1]} - 1 }{\text{stride[1]}} + 1} \right \rfloor \\
|
|
747
|
-
W_{out}
|
|
747
|
+
W_{out} = \left \lfloor{\frac{W_{in} + padding[4] + padding[5] - (\text{dilation[2]} - 1) \times
|
|
748
748
|
\text{kernel_size[2]} - 1 }{\text{stride[2]}} + 1} \right \rfloor \\
|
|
749
749
|
\end{array}
|
|
750
750
|
|
|
@@ -812,7 +812,7 @@ class Conv3d(_Conv):
|
|
|
812
812
|
bias_init,
|
|
813
813
|
data_format,
|
|
814
814
|
dtype=dtype)
|
|
815
|
-
out_channels = self.out_channels
|
|
815
|
+
out_channels = self.out_channels // group
|
|
816
816
|
self.conv3d = P.Conv3D(out_channel=out_channels,
|
|
817
817
|
kernel_size=self.kernel_size,
|
|
818
818
|
mode=1,
|
|
@@ -820,17 +820,33 @@ class Conv3d(_Conv):
|
|
|
820
820
|
pad=self.padding,
|
|
821
821
|
stride=self.stride,
|
|
822
822
|
dilation=self.dilation,
|
|
823
|
-
group=
|
|
823
|
+
group=1,
|
|
824
824
|
data_format=self.data_format)
|
|
825
825
|
self.bias_add = P.BiasAdd(data_format=self.data_format)
|
|
826
826
|
self.shape = P.Shape()
|
|
827
|
+
self.concat = P.Concat(1)
|
|
828
|
+
self.split_0 = P.Split(0, self.group)
|
|
829
|
+
self.split_1 = P.Split(1, self.group)
|
|
827
830
|
|
|
828
831
|
def construct(self, x):
|
|
829
832
|
x_shape = self.shape(x)
|
|
830
833
|
_check_input_5dims(x_shape, self.cls_name)
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
+
if self.group == 1:
|
|
835
|
+
out = self.conv3d(x, self.weight)
|
|
836
|
+
if self.has_bias:
|
|
837
|
+
out = self.bias_add(out, self.bias)
|
|
838
|
+
else:
|
|
839
|
+
features = self.split_1(x)
|
|
840
|
+
weights = self.split_0(self.weight)
|
|
841
|
+
outputs = ()
|
|
842
|
+
for i in range(self.group):
|
|
843
|
+
output = self.conv3d(features[i], weights[i])
|
|
844
|
+
outputs = outputs + (output,)
|
|
845
|
+
out = self.concat(outputs)
|
|
846
|
+
if self.bias is not None:
|
|
847
|
+
new_shape = [1 for _ in range(out.ndim)]
|
|
848
|
+
new_shape[1] = self.out_channels
|
|
849
|
+
out = out + self.bias.reshape(new_shape)
|
|
834
850
|
return out
|
|
835
851
|
|
|
836
852
|
|
|
@@ -921,9 +937,9 @@ class Conv3dTranspose(_Conv):
|
|
|
921
937
|
|
|
922
938
|
.. math::
|
|
923
939
|
\begin{array}{ll} \\
|
|
924
|
-
D_{out}
|
|
925
|
-
H_{out}
|
|
926
|
-
W_{out}
|
|
940
|
+
D_{out} = \left \lfloor{\frac{D_{in}}{\text{stride[0]}} + 1} \right \rfloor \\
|
|
941
|
+
H_{out} = \left \lfloor{\frac{H_{in}}{\text{stride[1]}} + 1} \right \rfloor \\
|
|
942
|
+
W_{out} = \left \lfloor{\frac{W_{in}}{\text{stride[2]}} + 1} \right \rfloor \\
|
|
927
943
|
\end{array}
|
|
928
944
|
|
|
929
945
|
|
|
@@ -931,11 +947,11 @@ class Conv3dTranspose(_Conv):
|
|
|
931
947
|
|
|
932
948
|
.. math::
|
|
933
949
|
\begin{array}{ll} \\
|
|
934
|
-
D_{out}
|
|
950
|
+
D_{out} = \left \lfloor{\frac{D_{in} - \text{dilation[0]} \times (\text{kernel_size[0]} - 1) }
|
|
935
951
|
{\text{stride[0]}} + 1} \right \rfloor \\
|
|
936
|
-
H_{out}
|
|
952
|
+
H_{out} = \left \lfloor{\frac{H_{in} - \text{dilation[1]} \times (\text{kernel_size[1]} - 1) }
|
|
937
953
|
{\text{stride[1]}} + 1} \right \rfloor \\
|
|
938
|
-
W_{out}
|
|
954
|
+
W_{out} = \left \lfloor{\frac{W_{in} - \text{dilation[2]} \times (\text{kernel_size[2]} - 1) }
|
|
939
955
|
{\text{stride[2]}} + 1} \right \rfloor \\
|
|
940
956
|
\end{array}
|
|
941
957
|
|
|
@@ -943,11 +959,11 @@ class Conv3dTranspose(_Conv):
|
|
|
943
959
|
|
|
944
960
|
.. math::
|
|
945
961
|
\begin{array}{ll} \\
|
|
946
|
-
D_{out}
|
|
962
|
+
D_{out} = \left \lfloor{\frac{D_{in} + padding[0] + padding[1] - (\text{dilation[0]} - 1) \times
|
|
947
963
|
\text{kernel_size[0]} - 1 }{\text{stride[0]}} + 1} \right \rfloor \\
|
|
948
|
-
H_{out}
|
|
964
|
+
H_{out} = \left \lfloor{\frac{H_{in} + padding[2] + padding[3] - (\text{dilation[1]} - 1) \times
|
|
949
965
|
\text{kernel_size[1]} - 1 }{\text{stride[1]}} + 1} \right \rfloor \\
|
|
950
|
-
W_{out}
|
|
966
|
+
W_{out} = \left \lfloor{\frac{W_{in} + padding[4] + padding[5] - (\text{dilation[2]} - 1) \times
|
|
951
967
|
\text{kernel_size[2]} - 1 }{\text{stride[2]}} + 1} \right \rfloor \\
|
|
952
968
|
\end{array}
|
|
953
969
|
|
|
mindspore/nn/layer/flash_attention.py
CHANGED
|
@@ -21,9 +21,7 @@ import mindspore.common.dtype as mstype
|
|
|
21
21
|
from mindspore.common.tensor import Tensor
|
|
22
22
|
from mindspore import ops
|
|
23
23
|
from mindspore.nn.cell import Cell
|
|
24
|
-
from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_impl import get_flash_attention
|
|
25
24
|
from mindspore.ops.operations.nn_ops import FlashAttentionScore
|
|
26
|
-
from mindspore._c_expression import MSContext
|
|
27
25
|
|
|
28
26
|
__all__ = ['FlashAttention']
|
|
29
27
|
|
|
@@ -46,25 +44,25 @@ class FlashAttention(Cell):
|
|
|
46
44
|
Default 65536.
|
|
47
45
|
next_block_num(int): A integer to define the number of blocks to look behind for local block sparse attention.
|
|
48
46
|
Default 65536.
|
|
49
|
-
tiling_stgy_name(str): A str to define tiling strategy of flash attention.
|
|
50
47
|
dp(int): data parallel.
|
|
51
48
|
Default 1.
|
|
52
49
|
mp(int): model parallel.
|
|
53
50
|
Default 1.
|
|
54
|
-
high_precision(bool): This mode has higher precision but some performance loss.
|
|
51
|
+
high_precision(bool): This mode has higher precision but some performance loss. Only take effect on Ascend910A.
|
|
55
52
|
Default False.
|
|
56
53
|
have_attention_mask_batch(bool): indicates whether attention_mask contains the batch dimension.
|
|
57
54
|
Default True
|
|
58
55
|
alibi(bool): This parameter indicates whether the flashattention supports the Alibi.
|
|
59
56
|
Default: False
|
|
57
|
+
use_mqa(bool): Using MQA if True, only take effect under 910B. Default: False.
|
|
60
58
|
|
|
61
59
|
|
|
62
60
|
Inputs:
|
|
63
61
|
- **query** (Tensor) - Tensor query (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
|
|
64
62
|
- **key** (Tensor) - Tensor key (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
|
|
65
63
|
- **value** (Tensor) - Tensor value (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
|
|
66
|
-
- **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16`
|
|
67
|
-
|
|
64
|
+
- **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16` `mstype.uint8`
|
|
65
|
+
[batch_size, seq_length, seq_length]): A matrix to pass masked information.
|
|
68
66
|
|
|
69
67
|
Outputs:
|
|
70
68
|
A Tensor. The output of the attention with shape [batch_size, head_num, seq_length, head_dim]
|
|
@@ -97,56 +95,51 @@ class FlashAttention(Cell):
|
|
|
97
95
|
dropout_rate=0.0,
|
|
98
96
|
prev_block_num=65536,
|
|
99
97
|
next_block_num=65536,
|
|
100
|
-
tiling_stgy_name="sparse",
|
|
101
98
|
dp=1,
|
|
102
99
|
mp=1,
|
|
103
100
|
high_precision=False,
|
|
104
101
|
have_attention_mask_batch=True,
|
|
105
|
-
alibi=False
|
|
102
|
+
alibi=False,
|
|
103
|
+
use_mqa=False
|
|
106
104
|
):
|
|
107
105
|
super(FlashAttention, self).__init__()
|
|
108
106
|
|
|
109
107
|
scaling_constant = math.sqrt(head_dim)
|
|
110
108
|
if scaling_constant == 0:
|
|
111
109
|
raise ValueError("the scaling constant must not be 0.")
|
|
112
|
-
self.
|
|
110
|
+
self.dropout_rate = dropout_rate
|
|
111
|
+
self.alibi = alibi
|
|
112
|
+
self.have_attention_mask_batch = have_attention_mask_batch
|
|
113
113
|
|
|
114
|
-
self.
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
)
|
|
122
|
-
|
|
123
|
-
else:
|
|
124
|
-
if alibi:
|
|
125
|
-
raise ValueError(f"When soc_version is not Ascend910A, alibi must be False")
|
|
126
|
-
self.transpose_4d_pre = ops.Transpose().shard(((dp, mp, 1, 1),))
|
|
127
|
-
self.transpose_4d_post = ops.Transpose().shard(((dp, 1, mp, 1),))
|
|
128
|
-
self.reshape = ops.Reshape()
|
|
129
|
-
self.zeros_like = ops.ZerosLike().shard(((dp, mp, 1, 1),))
|
|
130
|
-
self.zeros = ops.Zeros()
|
|
131
|
-
self.attn_expand_dims = ops.ExpandDims().shard(((dp, 1, 1),))
|
|
132
|
-
fa_strategies = ((dp, 1, mp),
|
|
133
|
-
(dp, 1, mp),
|
|
134
|
-
(dp, 1, mp),
|
|
114
|
+
self.transpose_4d_pre = ops.Transpose().shard(((dp, mp, 1, 1),))
|
|
115
|
+
self.transpose_4d_post = ops.Transpose().shard(((dp, 1, mp, 1),))
|
|
116
|
+
self.reshape = ops.Reshape()
|
|
117
|
+
self.zeros_like = ops.ZerosLike().shard(((dp, mp, 1, 1),))
|
|
118
|
+
self.zeros = ops.Zeros()
|
|
119
|
+
self.attn_cast = ops.Cast()
|
|
120
|
+
if use_mqa:
|
|
121
|
+
fa_strategies = ((dp, mp, 1, 1),
|
|
122
|
+
(dp, 1, 1, 1),
|
|
135
123
|
(dp, 1, 1, 1))
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
124
|
+
else:
|
|
125
|
+
fa_strategies = ((dp, mp, 1, 1),
|
|
126
|
+
(dp, mp, 1, 1),
|
|
127
|
+
(dp, mp, 1, 1))
|
|
128
|
+
if self.alibi:
|
|
129
|
+
self.alibi_rescale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
|
|
130
|
+
self.alibi_rescale_factor = Tensor([scaling_constant], dtype=mstype.float16)
|
|
131
|
+
fa_strategies += ((dp, mp, 1, 1),)
|
|
132
|
+
if dropout_rate > 1e-5:
|
|
133
|
+
fa_strategies += ((dp, mp, 1, 1),)
|
|
134
|
+
fa_strategies += ((dp, 1, 1, 1),)
|
|
135
|
+
self.flash_attention = FlashAttentionScore(head_num=head_num, pre_tokens=prev_block_num,
|
|
136
|
+
next_tokens=next_block_num,
|
|
137
|
+
keep_prob=1 - dropout_rate,
|
|
138
|
+
scale_value=1. / scaling_constant,
|
|
139
|
+
inner_precise=0,
|
|
140
|
+
input_layout="BNSD").shard(fa_strategies)
|
|
143
141
|
|
|
144
|
-
self.ones = ops.Ones()
|
|
145
|
-
self.dim_mask = Tensor([1 for _ in range(head_dim)], dtype=mstype.int8)
|
|
146
|
-
self.scale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
|
|
147
142
|
self.dropout_rate = dropout_rate
|
|
148
|
-
self.have_attention_mask_batch = have_attention_mask_batch
|
|
149
|
-
self.alibi = alibi
|
|
150
143
|
if self.dropout_rate > 1e-5:
|
|
151
144
|
self.keep_prob = Tensor(1 - self.dropout_rate, dtype=mstype.float16)
|
|
152
145
|
self.fill_v2 = ops.FillV2().shard(((dp, mp, 1, 1), ()))
|
|
@@ -162,46 +155,7 @@ class FlashAttention(Cell):
|
|
|
162
155
|
such as MatMul. Default: None.
|
|
163
156
|
:return:
|
|
164
157
|
"""
|
|
165
|
-
if in_strategy is None:
|
|
166
|
-
# default: dp=1, mp=1, construct inputs only contain query, key, value
|
|
167
|
-
in_strategy = (
|
|
168
|
-
(1, 1, 1, 1),
|
|
169
|
-
(1, 1, 1, 1),
|
|
170
|
-
(1, 1, 1, 1),
|
|
171
|
-
)
|
|
172
158
|
self.flash_attention.shard(in_strategy)
|
|
173
|
-
dp = in_strategy[0][0]
|
|
174
|
-
mp = in_strategy[0][1]
|
|
175
|
-
self.flash_attention.add_prim_attr("dev_matrix_shape", [dp, mp, 1, 1])
|
|
176
|
-
inputs_tensor_map = [
|
|
177
|
-
[3, 2, 1, 0],
|
|
178
|
-
[3, 2, 1, 0],
|
|
179
|
-
[3, 2, 1, 0],
|
|
180
|
-
]
|
|
181
|
-
if self.have_attention_mask_batch:
|
|
182
|
-
inputs_tensor_map.append([3, 1, 0])
|
|
183
|
-
else:
|
|
184
|
-
inputs_tensor_map.append([-1, 1, 0])
|
|
185
|
-
|
|
186
|
-
input_empty_args_num = 2
|
|
187
|
-
# dropout_mask
|
|
188
|
-
if self.dropout_rate > 1e-5:
|
|
189
|
-
input_empty_args_num -= 1
|
|
190
|
-
inputs_tensor_map.append([3, 2, 1, 0])
|
|
191
|
-
|
|
192
|
-
if self.alibi:
|
|
193
|
-
input_empty_args_num -= 1
|
|
194
|
-
inputs_tensor_map.append([3, 2, 1, 0])
|
|
195
|
-
|
|
196
|
-
self.flash_attention.add_prim_attr("inputs_tensor_map", inputs_tensor_map)
|
|
197
|
-
|
|
198
|
-
self.flash_attention.add_prim_attr("outputs_tensor_map", [
|
|
199
|
-
[3, 2, 1, 0], # O
|
|
200
|
-
[3, 2, 1], # L
|
|
201
|
-
[3, 2, 1] # M
|
|
202
|
-
])
|
|
203
|
-
self.flash_attention.add_prim_attr("as_loss_divisor", 0)
|
|
204
|
-
self.flash_attention.add_prim_attr("empty_mirror_ops", input_empty_args_num)
|
|
205
159
|
|
|
206
160
|
def construct(self, query, key, value, attn_mask=None, alibi_mask=None):
|
|
207
161
|
"""FlashAttention forward
|
|
@@ -212,53 +166,24 @@ class FlashAttention(Cell):
|
|
|
212
166
|
:param alibi_mask: [bsz, head_num, 1, seq_len], if not None
|
|
213
167
|
:return: output [bsz, head_num, seq_len, head_dim]
|
|
214
168
|
"""
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
raise ValueError(
|
|
221
|
-
"the head_num of query, key and value must be the same, "
|
|
222
|
-
"If different head_num are used, users need to change themselves to be same by tile.")
|
|
223
|
-
if seq_len % 16 != 0 or k_seq_len % 16 != 0 or k_seq_len != v_seq_len:
|
|
224
|
-
raise ValueError(
|
|
225
|
-
"query, key, value seq_len must be a multiple of 16, and key seq_len, value seq_len must be the same.")
|
|
226
|
-
|
|
227
|
-
if head_dim > 304:
|
|
228
|
-
raise ValueError(
|
|
229
|
-
"the head_dim must be less than 304, otherwise the ub would be OOM.")
|
|
230
|
-
|
|
231
|
-
if self.is_910A:
|
|
232
|
-
# 910A -- FlashAttentionPrimtive
|
|
233
|
-
if self.dropout_rate > 1e-5:
|
|
234
|
-
drop_mask_bits = self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob)
|
|
235
|
-
tensor_shape = Tensor((bsz, head_num, seq_len, seq_len), mstype.int32)
|
|
236
|
-
ones = self.fill_v2(tensor_shape, self.tensor_one)
|
|
237
|
-
ones = self.depend(ones, query)
|
|
238
|
-
drop_mask = self.do_dropout(ones, drop_mask_bits, self.keep_prob)
|
|
239
|
-
else:
|
|
240
|
-
drop_mask = None
|
|
241
|
-
output, _, _ = self.flash_attention(query, key, value, attn_mask, drop_mask, alibi_mask)
|
|
169
|
+
bsz, head_num, seq_len, _ = query.shape
|
|
170
|
+
# 910B -- FlashAttentionScore
|
|
171
|
+
if self.dropout_rate > 1e-5:
|
|
172
|
+
drop_mask_bits = self.reshape(self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob),
|
|
173
|
+
(bsz, head_num, seq_len, seq_len // 8))
|
|
242
174
|
else:
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
value,
|
|
258
|
-
attn_mask,
|
|
259
|
-
drop_mask_bits,
|
|
260
|
-
None,
|
|
261
|
-
None)
|
|
262
|
-
output = self.transpose_4d_post(self.reshape(output, (bsz, seq_len, head_num, head_dim)), (0, 2, 1, 3))
|
|
263
|
-
|
|
175
|
+
drop_mask_bits = None
|
|
176
|
+
if self.alibi:
|
|
177
|
+
alibi_mask = self.alibi_rescale_mul(alibi_mask, self.cast(self.alibi_rescale_factor, alibi_mask.dtype))
|
|
178
|
+
# (B, S, S) -> (B, 1, S, S)
|
|
179
|
+
if self.have_attention_mask_batch:
|
|
180
|
+
attn_mask = self.cast(self.reshape(attn_mask, (bsz, 1, seq_len, seq_len)), mstype.uint8)
|
|
181
|
+
_, _, _, output = self.flash_attention(query,
|
|
182
|
+
key,
|
|
183
|
+
value,
|
|
184
|
+
alibi_mask,
|
|
185
|
+
drop_mask_bits,
|
|
186
|
+
None,
|
|
187
|
+
attn_mask,
|
|
188
|
+
None)
|
|
264
189
|
return output
|