mindspore 2.2.0__cp39-cp39-win_amd64.whl → 2.2.11__cp39-cp39-win_amd64.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (112)
  1. mindspore/.commit_id +1 -1
  2. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  3. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  5. mindspore/_checkparam.py +3 -3
  6. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  7. mindspore/_extends/graph_kernel/splitter.py +3 -2
  8. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
  9. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
  10. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  11. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
  12. mindspore/_extends/parse/__init__.py +3 -2
  13. mindspore/_extends/parse/parser.py +6 -1
  14. mindspore/_extends/parse/standard_method.py +14 -11
  15. mindspore/_extends/remote/kernel_build_server.py +2 -1
  16. mindspore/common/_utils.py +16 -0
  17. mindspore/common/api.py +1 -1
  18. mindspore/common/auto_dynamic_shape.py +81 -85
  19. mindspore/common/dump.py +1 -1
  20. mindspore/common/tensor.py +3 -20
  21. mindspore/config/op_info.config +1 -1
  22. mindspore/context.py +11 -4
  23. mindspore/dataset/engine/cache_client.py +8 -5
  24. mindspore/dataset/engine/datasets_standard_format.py +5 -0
  25. mindspore/dataset/vision/transforms.py +21 -21
  26. mindspore/experimental/optim/adam.py +1 -1
  27. mindspore/gen_ops.py +1 -1
  28. mindspore/include/api/model.h +17 -0
  29. mindspore/include/api/status.h +8 -3
  30. mindspore/mindspore_backend.dll +0 -0
  31. mindspore/mindspore_common.dll +0 -0
  32. mindspore/mindspore_core.dll +0 -0
  33. mindspore/mindspore_shared_lib.dll +0 -0
  34. mindspore/nn/cell.py +0 -3
  35. mindspore/nn/layer/activation.py +4 -5
  36. mindspore/nn/layer/conv.py +39 -23
  37. mindspore/nn/layer/flash_attention.py +54 -129
  38. mindspore/nn/layer/math.py +3 -7
  39. mindspore/nn/layer/rnn_cells.py +5 -5
  40. mindspore/nn/wrap/__init__.py +4 -2
  41. mindspore/nn/wrap/cell_wrapper.py +12 -3
  42. mindspore/numpy/utils_const.py +5 -5
  43. mindspore/opencv_core452.dll +0 -0
  44. mindspore/opencv_imgcodecs452.dll +0 -0
  45. mindspore/opencv_imgproc452.dll +0 -0
  46. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
  47. mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
  48. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
  49. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  50. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  51. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  52. mindspore/ops/_utils/utils.py +2 -0
  53. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
  54. mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
  55. mindspore/ops/function/array_func.py +10 -7
  56. mindspore/ops/function/grad/grad_func.py +0 -1
  57. mindspore/ops/function/nn_func.py +98 -9
  58. mindspore/ops/function/random_func.py +2 -1
  59. mindspore/ops/op_info_register.py +24 -21
  60. mindspore/ops/operations/__init__.py +6 -2
  61. mindspore/ops/operations/_grad_ops.py +25 -6
  62. mindspore/ops/operations/_inner_ops.py +155 -23
  63. mindspore/ops/operations/array_ops.py +9 -7
  64. mindspore/ops/operations/comm_ops.py +2 -2
  65. mindspore/ops/operations/custom_ops.py +85 -68
  66. mindspore/ops/operations/inner_ops.py +26 -3
  67. mindspore/ops/operations/math_ops.py +7 -6
  68. mindspore/ops/operations/nn_ops.py +193 -49
  69. mindspore/parallel/_parallel_serialization.py +10 -3
  70. mindspore/parallel/_tensor.py +4 -1
  71. mindspore/parallel/checkpoint_transform.py +13 -2
  72. mindspore/parallel/shard.py +17 -10
  73. mindspore/profiler/common/util.py +1 -0
  74. mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
  75. mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
  76. mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
  77. mindspore/profiler/parser/ascend_op_generator.py +1 -1
  78. mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
  79. mindspore/profiler/parser/base_timeline_generator.py +1 -1
  80. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
  81. mindspore/profiler/parser/framework_parser.py +1 -1
  82. mindspore/profiler/parser/profiler_info.py +19 -0
  83. mindspore/profiler/profiling.py +46 -24
  84. mindspore/rewrite/api/pattern_engine.py +1 -1
  85. mindspore/rewrite/parsers/for_parser.py +7 -7
  86. mindspore/rewrite/parsers/module_parser.py +4 -4
  87. mindspore/rewrite/symbol_tree.py +1 -4
  88. mindspore/run_check/_check_version.py +5 -3
  89. mindspore/safeguard/rewrite_obfuscation.py +52 -28
  90. mindspore/train/callback/_summary_collector.py +1 -1
  91. mindspore/train/dataset_helper.py +1 -0
  92. mindspore/train/model.py +2 -2
  93. mindspore/train/serialization.py +97 -11
  94. mindspore/train/summary/_summary_adapter.py +1 -1
  95. mindspore/train/summary/summary_record.py +23 -7
  96. mindspore/version.py +1 -1
  97. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +3 -2
  98. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +101 -112
  99. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  100. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
  101. mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
  102. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
  103. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
  104. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
  105. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
  106. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  107. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  108. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  109. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  110. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  111. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
  112. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
mindspore/dataset/vision/transforms.py CHANGED
@@ -144,14 +144,14 @@ class AdjustBrightness(ImageTensorOperation, PyTensorOperation):

  Args:
  device_target (str, optional): The operator will be executed on this device. Currently supports
- ``CPU`` and ``Ascend`` , where ``Ascend`` refers to Ascend910B device. Default: ``CPU`` .
+ ``CPU`` . Default: ``CPU`` .

  Raises:
  TypeError: If `device_target` is not of type str.
- ValueError: If `device_target` is not within the valid set of ['CPU', 'Ascend'].
+ ValueError: If `device_target` is not ``CPU`` .

  Supported Platforms:
- ``CPU`` ``Ascend``
+ ``CPU``

  Examples:
  >>> import mindspore.dataset as ds
@@ -227,14 +227,14 @@ class AdjustContrast(ImageTensorOperation, PyTensorOperation):

  Args:
  device_target (str, optional): The operator will be executed on this device. Currently supports
- ``CPU`` and ``Ascend`` , where ``Ascend`` refers to Ascend910B device. Default: ``CPU`` .
+ ``CPU`` . Default: ``CPU`` .

  Raises:
  TypeError: If `device_target` is not of type str.
- ValueError: If `device_target` is not within the valid set of ['CPU', 'Ascend'].
+ ValueError: If `device_target` is not ``CPU`` .

  Supported Platforms:
- ``CPU`` ``Ascend``
+ ``CPU``

  Examples:
  >>> import mindspore.dataset as ds
@@ -373,14 +373,14 @@ class AdjustHue(ImageTensorOperation, PyTensorOperation):

  Args:
  device_target (str, optional): The operator will be executed on this device. Currently supports
- ``CPU`` and ``Ascend`` , where ``Ascend`` refers to Ascend910B device. Default: ``CPU`` .
+ ``CPU`` . Default: ``CPU`` .

  Raises:
  TypeError: If `device_target` is not of type str.
- ValueError: If `device_target` is not within the valid set of ['CPU', 'Ascend'].
+ ValueError: If `device_target` is not ``CPU`` .

  Supported Platforms:
- ``CPU`` ``Ascend``
+ ``CPU``

  Examples:
  >>> import mindspore.dataset as ds
@@ -457,14 +457,14 @@ class AdjustSaturation(ImageTensorOperation, PyTensorOperation):

  Args:
  device_target (str, optional): The operator will be executed on this device. Currently supports
- ``CPU`` and ``Ascend`` , where ``Ascend`` refers to Ascend910B device. Default: ``CPU`` .
+ ``CPU`` . Default: ``CPU`` .

  Raises:
  TypeError: If `device_target` is not of type str.
- ValueError: If `device_target` is not within the valid set of ['CPU', 'Ascend'].
+ ValueError: If `device_target` is not ``CPU`` .

  Supported Platforms:
- ``CPU`` ``Ascend``
+ ``CPU``

  Examples:
  >>> import mindspore.dataset as ds
@@ -1159,14 +1159,14 @@ class Decode(ImageTensorOperation, PyTensorOperation):

  Args:
  device_target (str, optional): The operator will be executed on this device. Currently supports
- ``CPU`` and ``Ascend`` , where ``Ascend`` refers to Ascend910B device. Default: ``CPU`` .
+ ``CPU`` . Default: ``CPU`` .

  Raises:
  TypeError: If `device_target` is not of type str.
- ValueError: If `device_target` is not within the valid set of ['CPU', 'Ascend'].
+ ValueError: If `device_target` is not ``CPU`` .

  Supported Platforms:
- ``CPU`` ``Ascend``
+ ``CPU``

  Examples:
  >>> import mindspore.dataset as ds
@@ -1908,14 +1908,14 @@ class Normalize(ImageTensorOperation):

  Args:
  device_target (str, optional): The operator will be executed on this device. Currently supports
- ``CPU`` and ``Ascend`` , where ``Ascend`` refers to Ascend910B device. Default: ``CPU`` .
+ ``CPU`` . Default: ``CPU`` .

  Raises:
  TypeError: If `device_target` is not of type str.
- ValueError: If `device_target` is not within the valid set of ['CPU', 'Ascend'].
+ ValueError: If `device_target` is not ``CPU`` .

  Supported Platforms:
- ``CPU`` ``Ascend``
+ ``CPU``

  Examples:
  >>> import mindspore.dataset as ds
@@ -4182,14 +4182,14 @@ class Resize(ImageTensorOperation, PyTensorOperation):

  Args:
  device_target (str, optional): The operator will be executed on this device. Currently supports
- ``CPU`` and ``Ascend`` , where ``Ascend`` refers to Ascend910B device. Default: ``CPU`` .
+ ``CPU`` . Default: ``CPU`` .

  Raises:
  TypeError: If `device_target` is not of type str.
- ValueError: If `device_target` is not within the valid set of ['CPU', 'Ascend'].
+ ValueError: If `device_target` is not ``CPU`` .

  Supported Platforms:
- ``CPU`` ``Ascend``
+ ``CPU``

  Examples:
  >>> import mindspore.dataset as ds
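For context, a minimal eager-mode sketch of the CPU-only behaviour these docstrings now describe; the `brightness_factor` argument and the direct call on a NumPy image follow the usual vision-transform pattern and are assumptions for illustration, not taken from the diff:

    import numpy as np
    import mindspore.dataset.vision as vision

    # Hypothetical HWC uint8 image; per the 2.2.11 docstrings above, these
    # operators document CPU as the only supported device_target.
    img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
    out = vision.AdjustBrightness(brightness_factor=1.5)(img)  # assumed signature
    print(out.shape)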
mindspore/experimental/optim/adam.py CHANGED
@@ -43,7 +43,7 @@ def _run_adam_with_amsgrad_opt(opt, beta1_power, beta2_power, lr, gradient, para

  class Adam(Optimizer):
  r"""
- Implements Adam algorithm..
+ Implements Adam algorithm.

  The updating formulas are as follows:

mindspore/gen_ops.py CHANGED
@@ -120,7 +120,7 @@ def generate_py_primitive(yaml_data):
  assign_str += arg_name
  args_assign.append(assign_str)

- args_assign = '\n'.join(assign for assign in args_assign)
+ args_assign = '\n'.join([assign for assign in args_assign])
  primitive_code = f"""
  class {class_name}(Primitive):
  def __init__(self, {', '.join(init_args_with_default)}):
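As a side note on the one-line gen_ops.py change above, both forms of the join produce the same string; a short standalone check:

    # str.join materializes a generator into a sequence internally anyway, so
    # passing a ready-made list (here via a list comprehension) is equivalent
    # and marginally faster.
    args_assign = ["self.a = a", "self.b = b"]
    assert '\n'.join(assign for assign in args_assign) == \
           '\n'.join([assign for assign in args_assign])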
mindspore/include/api/model.h CHANGED
@@ -136,6 +136,13 @@ class MS_API Model {
  /// \return Status.
  Status UpdateWeights(const std::vector<MSTensor> &new_weights);

+ /// \brief Change the size and or content of weight tensors
+ ///
+ /// \param[in] A vector where model constant are arranged in sequence
+ ///
+ /// \return Status.
+ Status UpdateWeights(const std::vector<std::vector<MSTensor>> &new_weights);
+
  /// \brief Inference model API. If use this API in train mode, it's equal to RunStep API.
  ///
  /// \param[in] inputs A vector where model inputs are arranged in sequence.
@@ -358,6 +365,13 @@ class MS_API Model {

  const std::shared_ptr<ModelImpl> impl() const { return impl_; }

+ /// \brief Get model info by key
+ ///
+ /// \param[in] key The key of model info key-value pair
+ ///
+ /// \return The value of the model info associated with the given key.
+ inline std::string GetModelInfo(const std::string &key);
+
  private:
  friend class Serialization;
  // api without std::string
@@ -374,6 +388,7 @@ class MS_API Model {
  const std::vector<char> &cropto_lib_path);
  Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
  const Key &dec_key, const std::vector<char> &dec_mode, const std::vector<char> &cropto_lib_path);
+ std::vector<char> GetModelInfo(const std::vector<char> &key);
  std::shared_ptr<ModelImpl> impl_;
  };

@@ -416,5 +431,7 @@ Status Model::Build(const std::string &model_path, ModelType model_type,
  const std::shared_ptr<Context> &model_context) {
  return Build(StringToChar(model_path), model_type, model_context);
  }
+
+ inline std::string Model::GetModelInfo(const std::string &key) { return CharToString(GetModelInfo(StringToChar(key))); }
  } // namespace mindspore
  #endif // MINDSPORE_INCLUDE_API_MODEL_H
mindspore/include/api/status.h CHANGED
@@ -83,9 +83,14 @@ enum StatusCode : uint32_t {
  kLiteModelRebuild = kLite | (0x0FFFFFFF & -12), /**< Model has been built. */

  // Executor error code, range: [-100,-200)
- kLiteOutOfTensorRange = kLite | (0x0FFFFFFF & -100), /**< Failed to check range. */
- kLiteInputTensorError = kLite | (0x0FFFFFFF & -101), /**< Failed to check input tensor. */
- kLiteReentrantError = kLite | (0x0FFFFFFF & -102), /**< Exist executor running. */
+ kLiteOutOfTensorRange = kLite | (0x0FFFFFFF & -100), /**< Failed to check range. */
+ kLiteInputTensorError = kLite | (0x0FFFFFFF & -101), /**< Failed to check input tensor. */
+ kLiteReentrantError = kLite | (0x0FFFFFFF & -102), /**< Exist executor running. */
+ kLiteLLMWaitProcessTimeOut = kLite | (0x0FFFFFFF & -103), /**< Wait to be processed time out. */
+ kLiteLLMKVCacheNotExist = kLite | (0x0FFFFFFF & -104), /**< KV Cache not exist. */
+ kLiteLLMRepeatRequest = kLite | (0x0FFFFFFF & -105), /**< repeat request. */
+ kLiteLLMRequestAlreadyCompleted = kLite | (0x0FFFFFFF & -106), /**< request already complete!. */
+ kLiteLLMEngineFinalized = kLite | (0x0FFFFFFF & -107), /**< llm engine finalized. */

  // Graph error code, range: [-200,-300)
  kLiteGraphFileError = kLite | (0x0FFFFFFF & -200), /**< Failed to verify graph file. */
mindspore/mindspore_backend.dll CHANGED (binary file)
mindspore/mindspore_common.dll CHANGED (binary file)
mindspore/mindspore_core.dll CHANGED (binary file)
mindspore/mindspore_shared_lib.dll CHANGED (binary file)
mindspore/nn/cell.py CHANGED
@@ -1081,9 +1081,6 @@ class Cell(Cell_):
  if not isinstance(param, Parameter) and param is not None:
  raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must be 'Parameter' if not None, "
  f"but got {type(param)}.")
- if param is None:
- raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must not be None, "
- f"but got None.")
  if isinstance(param, Parameter) and param.name == PARAMETER_NAME_DEFAULT:
  param.name = param_name
  self._params[param_name] = param
mindspore/nn/layer/activation.py CHANGED
@@ -932,10 +932,8 @@ class GELU(Cell):
  """Initialize GELU."""
  super(GELU, self).__init__()
  validator.check_bool(approximate, 'approximate', self.cls_name)
- self.approximate = approximate
- if approximate:
- self.approximate = 'tanh'
- else:
+ self.approximate = 'tanh'
+ if not approximate:
  self.approximate = 'none'

  def construct(self, x):
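The GELU hunk above only changes how the boolean `approximate` flag is mapped to the string mode; a standalone sketch of the equivalent mapping (not the nn.GELU class itself):

    def map_approximate(approximate: bool) -> str:
        # True -> 'tanh' approximation, False -> exact ('none') formulation,
        # matching the new branch in GELU.__init__ above.
        return 'tanh' if approximate else 'none'

    assert map_approximate(True) == 'tanh'
    assert map_approximate(False) == 'none'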
@@ -1335,7 +1333,8 @@ class LRN(Cell):

  .. warning::
  LRN is deprecated on Ascend due to potential accuracy problem. It's recommended to use other
- normalization methods, e.g. :class:`mindspore.nn.BatchNorm`.
+ normalization methods, e.g. :class:`mindspore.nn.BatchNorm1d` ,
+ :class:`mindspore.nn.BatchNorm2d` , :class:`mindspore.nn.BatchNorm3d`.

  Refer to :func:`mindspore.ops.lrn` for more details.

mindspore/nn/layer/conv.py CHANGED
@@ -718,9 +718,9 @@ class Conv3d(_Conv):

  .. math::
  \begin{array}{ll} \\
- D_{out} \left \lceil{\frac{D_{in}}{\text{stride[0]}}} \right \rceil \\
- H_{out} \left \lceil{\frac{H_{in}}{\text{stride[1]}}} \right \rceil \\
- W_{out} \left \lceil{\frac{W_{in}}{\text{stride[2]}}} \right \rceil \\
+ D_{out} = \left \lceil{\frac{D_{in}}{\text{stride[0]}}} \right \rceil \\
+ H_{out} = \left \lceil{\frac{H_{in}}{\text{stride[1]}}} \right \rceil \\
+ W_{out} = \left \lceil{\frac{W_{in}}{\text{stride[2]}}} \right \rceil \\
  \end{array}


@@ -728,11 +728,11 @@ class Conv3d(_Conv):

  .. math::
  \begin{array}{ll} \\
- D_{out} \left \lfloor{\frac{D_{in} - \text{dilation[0]} \times (\text{kernel_size[0]} - 1) }
+ D_{out} = \left \lfloor{\frac{D_{in} - \text{dilation[0]} \times (\text{kernel_size[0]} - 1) }
  {\text{stride[0]}} + 1} \right \rfloor \\
- H_{out} \left \lfloor{\frac{H_{in} - \text{dilation[1]} \times (\text{kernel_size[1]} - 1) }
+ H_{out} = \left \lfloor{\frac{H_{in} - \text{dilation[1]} \times (\text{kernel_size[1]} - 1) }
  {\text{stride[1]}} + 1} \right \rfloor \\
- W_{out} \left \lfloor{\frac{W_{in} - \text{dilation[2]} \times (\text{kernel_size[2]} - 1) }
+ W_{out} = \left \lfloor{\frac{W_{in} - \text{dilation[2]} \times (\text{kernel_size[2]} - 1) }
  {\text{stride[2]}} + 1} \right \rfloor \\
  \end{array}

@@ -740,11 +740,11 @@ class Conv3d(_Conv):

  .. math::
  \begin{array}{ll} \\
- D_{out} \left \lfloor{\frac{D_{in} + padding[0] + padding[1] - (\text{dilation[0]} - 1) \times
+ D_{out} = \left \lfloor{\frac{D_{in} + padding[0] + padding[1] - (\text{dilation[0]} - 1) \times
  \text{kernel_size[0]} - 1 }{\text{stride[0]}} + 1} \right \rfloor \\
- H_{out} \left \lfloor{\frac{H_{in} + padding[2] + padding[3] - (\text{dilation[1]} - 1) \times
+ H_{out} = \left \lfloor{\frac{H_{in} + padding[2] + padding[3] - (\text{dilation[1]} - 1) \times
  \text{kernel_size[1]} - 1 }{\text{stride[1]}} + 1} \right \rfloor \\
- W_{out} \left \lfloor{\frac{W_{in} + padding[4] + padding[5] - (\text{dilation[2]} - 1) \times
+ W_{out} = \left \lfloor{\frac{W_{in} + padding[4] + padding[5] - (\text{dilation[2]} - 1) \times
  \text{kernel_size[2]} - 1 }{\text{stride[2]}} + 1} \right \rfloor \\
  \end{array}

@@ -812,7 +812,7 @@ class Conv3d(_Conv):
  bias_init,
  data_format,
  dtype=dtype)
- out_channels = self.out_channels
+ out_channels = self.out_channels // group
  self.conv3d = P.Conv3D(out_channel=out_channels,
  kernel_size=self.kernel_size,
  mode=1,
@@ -820,17 +820,33 @@ class Conv3d(_Conv):
  pad=self.padding,
  stride=self.stride,
  dilation=self.dilation,
- group=group,
+ group=1,
  data_format=self.data_format)
  self.bias_add = P.BiasAdd(data_format=self.data_format)
  self.shape = P.Shape()
+ self.concat = P.Concat(1)
+ self.split_0 = P.Split(0, self.group)
+ self.split_1 = P.Split(1, self.group)

  def construct(self, x):
  x_shape = self.shape(x)
  _check_input_5dims(x_shape, self.cls_name)
- out = self.conv3d(x, self.weight)
- if self.has_bias:
- out = self.bias_add(out, self.bias)
+ if self.group == 1:
+ out = self.conv3d(x, self.weight)
+ if self.has_bias:
+ out = self.bias_add(out, self.bias)
+ else:
+ features = self.split_1(x)
+ weights = self.split_0(self.weight)
+ outputs = ()
+ for i in range(self.group):
+ output = self.conv3d(features[i], weights[i])
+ outputs = outputs + (output,)
+ out = self.concat(outputs)
+ if self.bias is not None:
+ new_shape = [1 for _ in range(out.ndim)]
+ new_shape[1] = self.out_channels
+ out = out + self.bias.reshape(new_shape)
  return out

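The new Conv3d.construct above emulates grouped 3D convolution by splitting the input along the channel axis and the weight along the output-channel axis, convolving each group, and concatenating the results. A standalone NumPy sketch of that bookkeeping, reduced to a 1x1x1 kernel so a plain einsum stands in for Conv3D; shapes are illustrative, not taken from the diff:

    import numpy as np

    def grouped_conv3d_1x1(x, w, group):
        # x: (N, C_in, D, H, W); w: (C_out, C_in // group, 1, 1, 1)
        xs = np.split(x, group, axis=1)   # split features on the channel axis (P.Split(1, group))
        ws = np.split(w, group, axis=0)   # split weights on the output-channel axis (P.Split(0, group))
        outs = [np.einsum('ncdhw,ocijk->nodhw', xi, wi) for xi, wi in zip(xs, ws)]
        return np.concatenate(outs, axis=1)  # P.Concat(1) in the diff

    x = np.random.rand(2, 4, 3, 3, 3).astype(np.float32)
    w = np.random.rand(6, 2, 1, 1, 1).astype(np.float32)
    print(grouped_conv3d_1x1(x, w, group=2).shape)  # (2, 6, 3, 3, 3)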
@@ -921,9 +937,9 @@ class Conv3dTranspose(_Conv):

  .. math::
  \begin{array}{ll} \\
- D_{out} \left \lfloor{\frac{D_{in}}{\text{stride[0]}} + 1} \right \rfloor \\
- H_{out} \left \lfloor{\frac{H_{in}}{\text{stride[1]}} + 1} \right \rfloor \\
- W_{out} \left \lfloor{\frac{W_{in}}{\text{stride[2]}} + 1} \right \rfloor \\
+ D_{out} = \left \lfloor{\frac{D_{in}}{\text{stride[0]}} + 1} \right \rfloor \\
+ H_{out} = \left \lfloor{\frac{H_{in}}{\text{stride[1]}} + 1} \right \rfloor \\
+ W_{out} = \left \lfloor{\frac{W_{in}}{\text{stride[2]}} + 1} \right \rfloor \\
  \end{array}


@@ -931,11 +947,11 @@ class Conv3dTranspose(_Conv):

  .. math::
  \begin{array}{ll} \\
- D_{out} \left \lfloor{\frac{D_{in} - \text{dilation[0]} \times (\text{kernel_size[0]} - 1) }
+ D_{out} = \left \lfloor{\frac{D_{in} - \text{dilation[0]} \times (\text{kernel_size[0]} - 1) }
  {\text{stride[0]}} + 1} \right \rfloor \\
- H_{out} \left \lfloor{\frac{H_{in} - \text{dilation[1]} \times (\text{kernel_size[1]} - 1) }
+ H_{out} = \left \lfloor{\frac{H_{in} - \text{dilation[1]} \times (\text{kernel_size[1]} - 1) }
  {\text{stride[1]}} + 1} \right \rfloor \\
- W_{out} \left \lfloor{\frac{W_{in} - \text{dilation[2]} \times (\text{kernel_size[2]} - 1) }
+ W_{out} = \left \lfloor{\frac{W_{in} - \text{dilation[2]} \times (\text{kernel_size[2]} - 1) }
  {\text{stride[2]}} + 1} \right \rfloor \\
  \end{array}

@@ -943,11 +959,11 @@ class Conv3dTranspose(_Conv):

  .. math::
  \begin{array}{ll} \\
- D_{out} \left \lfloor{\frac{D_{in} + padding[0] + padding[1] - (\text{dilation[0]} - 1) \times
+ D_{out} = \left \lfloor{\frac{D_{in} + padding[0] + padding[1] - (\text{dilation[0]} - 1) \times
  \text{kernel_size[0]} - 1 }{\text{stride[0]}} + 1} \right \rfloor \\
- H_{out} \left \lfloor{\frac{H_{in} + padding[2] + padding[3] - (\text{dilation[1]} - 1) \times
+ H_{out} = \left \lfloor{\frac{H_{in} + padding[2] + padding[3] - (\text{dilation[1]} - 1) \times
  \text{kernel_size[1]} - 1 }{\text{stride[1]}} + 1} \right \rfloor \\
- W_{out} \left \lfloor{\frac{W_{in} + padding[4] + padding[5] - (\text{dilation[2]} - 1) \times
+ W_{out} = \left \lfloor{\frac{W_{in} + padding[4] + padding[5] - (\text{dilation[2]} - 1) \times
  \text{kernel_size[2]} - 1 }{\text{stride[2]}} + 1} \right \rfloor \\
  \end{array}

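A worked check of the `pad`-mode output-size expression restored above, evaluated for one spatial dimension with illustrative numbers (not values taken from the diff):

    import math

    # D_out = floor((D_in + padding[0] + padding[1] - (dilation[0] - 1) * kernel_size[0] - 1) / stride[0] + 1)
    d_in, pad_head, pad_tail, dilation, kernel, stride = 16, 1, 1, 1, 3, 2
    d_out = math.floor((d_in + pad_head + pad_tail - (dilation - 1) * kernel - 1) / stride + 1)
    print(d_out)  # 9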
mindspore/nn/layer/flash_attention.py CHANGED
@@ -21,9 +21,7 @@ import mindspore.common.dtype as mstype
  from mindspore.common.tensor import Tensor
  from mindspore import ops
  from mindspore.nn.cell import Cell
- from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_impl import get_flash_attention
  from mindspore.ops.operations.nn_ops import FlashAttentionScore
- from mindspore._c_expression import MSContext

  __all__ = ['FlashAttention']

@@ -46,25 +44,25 @@ class FlashAttention(Cell):
  Default 65536.
  next_block_num(int): A integer to define the number of blocks to look behind for local block sparse attention.
  Default 65536.
- tiling_stgy_name(str): A str to define tiling strategy of flash attention.
  dp(int): data parallel.
  Default 1.
  mp(int): model parallel.
  Default 1.
- high_precision(bool): This mode has higher precision but some performance loss.
+ high_precision(bool): This mode has higher precision but some performance loss. Only take effect on Ascend910A.
  Default False.
  have_attention_mask_batch(bool): indicates whether attention_mask contains the batch dimension.
  Default True
  alibi(bool): This parameter indicates whether the flashattention supports the Alibi.
  Default: False
+ use_mqa(bool): Using MQA if True, only take effect under 910B. Default: False.


  Inputs:
  - **query** (Tensor) - Tensor query (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
  - **key** (Tensor) - Tensor key (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
  - **value** (Tensor) - Tensor value (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
- - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16` [batch_size, seq_length,
- seq_length]): A matrix to pass masked information.
+ - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16` `mstype.uint8`
+ [batch_size, seq_length, seq_length]): A matrix to pass masked information.

  Outputs:
  A Tensor. The output of the attention with shape [batch_size, head_num, seq_length, head_dim]
@@ -97,56 +95,51 @@ class FlashAttention(Cell):
  dropout_rate=0.0,
  prev_block_num=65536,
  next_block_num=65536,
- tiling_stgy_name="sparse",
  dp=1,
  mp=1,
  high_precision=False,
  have_attention_mask_batch=True,
- alibi=False
+ alibi=False,
+ use_mqa=False
  ):
  super(FlashAttention, self).__init__()

  scaling_constant = math.sqrt(head_dim)
  if scaling_constant == 0:
  raise ValueError("the scaling constant must not be 0.")
- self.scale_factor = Tensor([1. / scaling_constant], dtype=mstype.float16)
+ self.dropout_rate = dropout_rate
+ self.alibi = alibi
+ self.have_attention_mask_batch = have_attention_mask_batch

- self.is_910A = MSContext.get_instance().get_ascend_soc_version() == "Ascend910"
- if self.is_910A:
- self.flash_attention = get_flash_attention(
- prev_block_num=prev_block_num,
- next_block_num=next_block_num,
- tiling_stgy_name=tiling_stgy_name,
- high_precision=high_precision
- )
- self.flash_attention.add_prim_attr("primitive_target", "Ascend")
- else:
- if alibi:
- raise ValueError(f"When soc_version is not Ascend910A, alibi must be False")
- self.transpose_4d_pre = ops.Transpose().shard(((dp, mp, 1, 1),))
- self.transpose_4d_post = ops.Transpose().shard(((dp, 1, mp, 1),))
- self.reshape = ops.Reshape()
- self.zeros_like = ops.ZerosLike().shard(((dp, mp, 1, 1),))
- self.zeros = ops.Zeros()
- self.attn_expand_dims = ops.ExpandDims().shard(((dp, 1, 1),))
- fa_strategies = ((dp, 1, mp),
- (dp, 1, mp),
- (dp, 1, mp),
+ self.transpose_4d_pre = ops.Transpose().shard(((dp, mp, 1, 1),))
+ self.transpose_4d_post = ops.Transpose().shard(((dp, 1, mp, 1),))
+ self.reshape = ops.Reshape()
+ self.zeros_like = ops.ZerosLike().shard(((dp, mp, 1, 1),))
+ self.zeros = ops.Zeros()
+ self.attn_cast = ops.Cast()
+ if use_mqa:
+ fa_strategies = ((dp, mp, 1, 1),
+ (dp, 1, 1, 1),
  (dp, 1, 1, 1))
- if dropout_rate > 1e-5:
- fa_strategies += ((dp, mp, 1, 1),)
- self.flash_attention = FlashAttentionScore(head_num=head_num, pre_tokens=prev_block_num,
- next_tokens=next_block_num,
- keep_prob=1 - dropout_rate,
- scale_value=1.0,
- inner_precise=0 if high_precision else 1).shard(fa_strategies)
+ else:
+ fa_strategies = ((dp, mp, 1, 1),
+ (dp, mp, 1, 1),
+ (dp, mp, 1, 1))
+ if self.alibi:
+ self.alibi_rescale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
+ self.alibi_rescale_factor = Tensor([scaling_constant], dtype=mstype.float16)
+ fa_strategies += ((dp, mp, 1, 1),)
+ if dropout_rate > 1e-5:
+ fa_strategies += ((dp, mp, 1, 1),)
+ fa_strategies += ((dp, 1, 1, 1),)
+ self.flash_attention = FlashAttentionScore(head_num=head_num, pre_tokens=prev_block_num,
+ next_tokens=next_block_num,
+ keep_prob=1 - dropout_rate,
+ scale_value=1. / scaling_constant,
+ inner_precise=0,
+ input_layout="BNSD").shard(fa_strategies)

- self.ones = ops.Ones()
- self.dim_mask = Tensor([1 for _ in range(head_dim)], dtype=mstype.int8)
- self.scale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
  self.dropout_rate = dropout_rate
- self.have_attention_mask_batch = have_attention_mask_batch
- self.alibi = alibi
  if self.dropout_rate > 1e-5:
  self.keep_prob = Tensor(1 - self.dropout_rate, dtype=mstype.float16)
  self.fill_v2 = ops.FillV2().shard(((dp, mp, 1, 1), ()))
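The __init__ hunk above assembles the shard-strategy tuple incrementally, one entry per FlashAttentionScore input that is actually used. A standalone sketch of that assembly logic (plain Python, no MindSpore required); the comment on the final entry is inferred from the call order in construct:

    def build_fa_strategies(dp, mp, use_mqa=False, alibi=False, dropout_rate=0.0):
        # query/key/value strategies; under MQA the key/value head axis is not split.
        if use_mqa:
            strategies = ((dp, mp, 1, 1), (dp, 1, 1, 1), (dp, 1, 1, 1))
        else:
            strategies = ((dp, mp, 1, 1), (dp, mp, 1, 1), (dp, mp, 1, 1))
        if alibi:
            strategies += ((dp, mp, 1, 1),)   # alibi_mask
        if dropout_rate > 1e-5:
            strategies += ((dp, mp, 1, 1),)   # dropout mask bits
        strategies += ((dp, 1, 1, 1),)        # added unconditionally; matches attn_mask's position
        return strategies

    print(build_fa_strategies(dp=2, mp=4, alibi=True))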
@@ -162,46 +155,7 @@ class FlashAttention(Cell):
  such as MatMul. Default: None.
  :return:
  """
- if in_strategy is None:
- # default: dp=1, mp=1, construct inputs only contain query, key, value
- in_strategy = (
- (1, 1, 1, 1),
- (1, 1, 1, 1),
- (1, 1, 1, 1),
- )
  self.flash_attention.shard(in_strategy)
- dp = in_strategy[0][0]
- mp = in_strategy[0][1]
- self.flash_attention.add_prim_attr("dev_matrix_shape", [dp, mp, 1, 1])
- inputs_tensor_map = [
- [3, 2, 1, 0],
- [3, 2, 1, 0],
- [3, 2, 1, 0],
- ]
- if self.have_attention_mask_batch:
- inputs_tensor_map.append([3, 1, 0])
- else:
- inputs_tensor_map.append([-1, 1, 0])
-
- input_empty_args_num = 2
- # dropout_mask
- if self.dropout_rate > 1e-5:
- input_empty_args_num -= 1
- inputs_tensor_map.append([3, 2, 1, 0])
-
- if self.alibi:
- input_empty_args_num -= 1
- inputs_tensor_map.append([3, 2, 1, 0])
-
- self.flash_attention.add_prim_attr("inputs_tensor_map", inputs_tensor_map)
-
- self.flash_attention.add_prim_attr("outputs_tensor_map", [
- [3, 2, 1, 0], # O
- [3, 2, 1], # L
- [3, 2, 1] # M
- ])
- self.flash_attention.add_prim_attr("as_loss_divisor", 0)
- self.flash_attention.add_prim_attr("empty_mirror_ops", input_empty_args_num)

  def construct(self, query, key, value, attn_mask=None, alibi_mask=None):
  """FlashAttention forward
@@ -212,53 +166,24 @@ class FlashAttention(Cell):
  :param alibi_mask: [bsz, head_num, 1, seq_len], if not None
  :return: output [bsz, head_num, seq_len, head_dim]
  """
- query = self.scale_mul(query, self.scale_factor)
- bsz, head_num, seq_len, head_dim = query.shape
- _, k_head_num, k_seq_len, _ = key.shape
- _, v_head_num, v_seq_len, _ = value.shape
- if head_num != k_head_num or head_num != v_head_num:
- raise ValueError(
- "the head_num of query, key and value must be the same, "
- "If different head_num are used, users need to change themselves to be same by tile.")
- if seq_len % 16 != 0 or k_seq_len % 16 != 0 or k_seq_len != v_seq_len:
- raise ValueError(
- "query, key, value seq_len must be a multiple of 16, and key seq_len, value seq_len must be the same.")
-
- if head_dim > 304:
- raise ValueError(
- "the head_dim must be less than 304, otherwise the ub would be OOM.")
-
- if self.is_910A:
- # 910A -- FlashAttentionPrimtive
- if self.dropout_rate > 1e-5:
- drop_mask_bits = self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob)
- tensor_shape = Tensor((bsz, head_num, seq_len, seq_len), mstype.int32)
- ones = self.fill_v2(tensor_shape, self.tensor_one)
- ones = self.depend(ones, query)
- drop_mask = self.do_dropout(ones, drop_mask_bits, self.keep_prob)
- else:
- drop_mask = None
- output, _, _ = self.flash_attention(query, key, value, attn_mask, drop_mask, alibi_mask)
+ bsz, head_num, seq_len, _ = query.shape
+ # 910B -- FlashAttentionScore
+ if self.dropout_rate > 1e-5:
+ drop_mask_bits = self.reshape(self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob),
+ (bsz, head_num, seq_len, seq_len // 8))
  else:
- # FlashAttentionScore
- # Useless input, just for binary calls.
- if self.dropout_rate > 1e-5:
- drop_mask_bits = self.reshape(self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob),
- (bsz, head_num, seq_len, seq_len // 8))
- else:
- drop_mask_bits = None
- # (B, N, S, D) -> (B, S, H)
- query = self.reshape(self.transpose_4d_pre(query, (0, 2, 1, 3)), (bsz, seq_len, -1))
- key = self.reshape(self.transpose_4d_pre(key, (0, 2, 1, 3)), (bsz, seq_len, -1))
- value = self.reshape(self.transpose_4d_pre(value, (0, 2, 1, 3)), (bsz, seq_len, -1))
- attn_mask = self.attn_expand_dims(attn_mask, 1)
- output, _, _ = self.flash_attention(query,
- key,
- value,
- attn_mask,
- drop_mask_bits,
- None,
- None)
- output = self.transpose_4d_post(self.reshape(output, (bsz, seq_len, head_num, head_dim)), (0, 2, 1, 3))
-
+ drop_mask_bits = None
+ if self.alibi:
+ alibi_mask = self.alibi_rescale_mul(alibi_mask, self.cast(self.alibi_rescale_factor, alibi_mask.dtype))
+ # (B, S, S) -> (B, 1, S, S)
+ if self.have_attention_mask_batch:
+ attn_mask = self.cast(self.reshape(attn_mask, (bsz, 1, seq_len, seq_len)), mstype.uint8)
+ _, _, _, output = self.flash_attention(query,
+ key,
+ value,
+ alibi_mask,
+ drop_mask_bits,
+ None,
+ attn_mask,
+ None)
  return output
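Finally, a standalone sketch of the mask shapes the new construct prepares: the (B, S, S) attention mask becomes (B, 1, S, S) uint8, and the dropout mask stays bit-packed with 8 positions per byte, hence the seq_len // 8 trailing dimension. Shapes are illustrative only:

    import numpy as np

    bsz, head_num, seq_len = 2, 4, 16
    attn_mask = (np.random.rand(bsz, seq_len, seq_len) > 0.5)
    attn_mask = attn_mask.reshape(bsz, 1, seq_len, seq_len).astype(np.uint8)   # (B, 1, S, S)
    drop_mask_bits = np.random.randint(0, 256, (bsz, head_num, seq_len, seq_len // 8), dtype=np.uint8)
    print(attn_mask.shape, drop_mask_bits.shape)  # (2, 1, 16, 16) (2, 4, 16, 2)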