mindspore-2.2.0-cp38-cp38-manylinux1_x86_64.whl → mindspore-2.2.11-cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (170)
  1. mindspore/.commit_id +1 -1
  2. mindspore/_akg/akg/composite/build_module.py +104 -20
  3. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  4. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  5. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  6. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  7. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  8. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  9. mindspore/_akg/akg/utils/composite_op_helper.py +7 -2
  10. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  11. mindspore/_akg/akg/utils/kernel_exec.py +41 -15
  12. mindspore/_akg/akg/utils/tbe_codegen_utils.py +27 -6
  13. mindspore/_akg/akg/utils/util.py +56 -1
  14. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  15. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  16. mindspore/_checkparam.py +3 -3
  17. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  18. mindspore/_extends/graph_kernel/splitter.py +3 -2
  19. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
  20. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
  21. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  22. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
  23. mindspore/_extends/parse/__init__.py +3 -2
  24. mindspore/_extends/parse/parser.py +6 -1
  25. mindspore/_extends/parse/standard_method.py +14 -11
  26. mindspore/_extends/remote/kernel_build_server.py +2 -1
  27. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  28. mindspore/bin/cache_admin +0 -0
  29. mindspore/bin/cache_server +0 -0
  30. mindspore/common/_utils.py +16 -0
  31. mindspore/common/api.py +1 -1
  32. mindspore/common/auto_dynamic_shape.py +81 -85
  33. mindspore/common/dump.py +1 -1
  34. mindspore/common/tensor.py +3 -20
  35. mindspore/config/op_info.config +1 -1
  36. mindspore/context.py +11 -4
  37. mindspore/dataset/engine/cache_client.py +8 -5
  38. mindspore/dataset/engine/datasets_standard_format.py +5 -0
  39. mindspore/dataset/vision/transforms.py +21 -21
  40. mindspore/experimental/optim/adam.py +1 -1
  41. mindspore/gen_ops.py +1 -1
  42. mindspore/include/api/model.h +17 -0
  43. mindspore/include/api/status.h +8 -3
  44. mindspore/lib/libdnnl.so.2 +0 -0
  45. mindspore/lib/libmindspore.so +0 -0
  46. mindspore/lib/libmindspore_backend.so +0 -0
  47. mindspore/lib/libmindspore_common.so +0 -0
  48. mindspore/lib/libmindspore_core.so +0 -0
  49. mindspore/lib/libmindspore_glog.so.0 +0 -0
  50. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  51. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  52. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  53. mindspore/lib/libmindspore_shared_lib.so +0 -0
  54. mindspore/lib/libnnacl.so +0 -0
  55. mindspore/lib/libopencv_core.so.4.5 +0 -0
  56. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  57. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  58. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  59. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  60. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  61. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  62. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  63. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  64. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  65. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  66. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  67. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  68. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  69. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  70. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  71. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  72. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  73. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +78 -80
  74. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  75. mindspore/lib/plugin/ascend/libakg.so +0 -0
  76. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  77. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  78. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  79. mindspore/lib/plugin/cpu/libakg.so +0 -0
  80. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  81. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  82. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  83. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  84. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  85. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  86. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  87. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  88. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  89. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  90. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  91. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  92. mindspore/nn/cell.py +0 -3
  93. mindspore/nn/layer/activation.py +4 -5
  94. mindspore/nn/layer/conv.py +39 -23
  95. mindspore/nn/layer/flash_attention.py +54 -129
  96. mindspore/nn/layer/math.py +3 -7
  97. mindspore/nn/layer/rnn_cells.py +5 -5
  98. mindspore/nn/wrap/__init__.py +4 -2
  99. mindspore/nn/wrap/cell_wrapper.py +12 -3
  100. mindspore/numpy/utils_const.py +5 -5
  101. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
  102. mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
  103. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
  104. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  105. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  106. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  107. mindspore/ops/_utils/utils.py +2 -0
  108. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
  109. mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
  110. mindspore/ops/function/array_func.py +10 -7
  111. mindspore/ops/function/grad/grad_func.py +0 -1
  112. mindspore/ops/function/nn_func.py +98 -9
  113. mindspore/ops/function/random_func.py +2 -1
  114. mindspore/ops/op_info_register.py +24 -21
  115. mindspore/ops/operations/__init__.py +6 -2
  116. mindspore/ops/operations/_grad_ops.py +25 -6
  117. mindspore/ops/operations/_inner_ops.py +155 -23
  118. mindspore/ops/operations/array_ops.py +9 -7
  119. mindspore/ops/operations/comm_ops.py +2 -2
  120. mindspore/ops/operations/custom_ops.py +85 -68
  121. mindspore/ops/operations/inner_ops.py +26 -3
  122. mindspore/ops/operations/math_ops.py +7 -6
  123. mindspore/ops/operations/nn_ops.py +193 -49
  124. mindspore/parallel/_parallel_serialization.py +10 -3
  125. mindspore/parallel/_tensor.py +4 -1
  126. mindspore/parallel/checkpoint_transform.py +13 -2
  127. mindspore/parallel/shard.py +17 -10
  128. mindspore/profiler/common/util.py +1 -0
  129. mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
  130. mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
  131. mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
  132. mindspore/profiler/parser/ascend_op_generator.py +1 -1
  133. mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
  134. mindspore/profiler/parser/base_timeline_generator.py +1 -1
  135. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
  136. mindspore/profiler/parser/framework_parser.py +1 -1
  137. mindspore/profiler/parser/profiler_info.py +19 -0
  138. mindspore/profiler/profiling.py +46 -24
  139. mindspore/rewrite/api/pattern_engine.py +1 -1
  140. mindspore/rewrite/parsers/for_parser.py +7 -7
  141. mindspore/rewrite/parsers/module_parser.py +4 -4
  142. mindspore/rewrite/symbol_tree.py +1 -4
  143. mindspore/run_check/_check_version.py +5 -3
  144. mindspore/safeguard/rewrite_obfuscation.py +52 -28
  145. mindspore/scipy/ops.py +55 -5
  146. mindspore/scipy/optimize/__init__.py +3 -2
  147. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  148. mindspore/train/callback/_summary_collector.py +1 -1
  149. mindspore/train/dataset_helper.py +1 -0
  150. mindspore/train/model.py +2 -2
  151. mindspore/train/serialization.py +97 -11
  152. mindspore/train/summary/_summary_adapter.py +1 -1
  153. mindspore/train/summary/summary_record.py +23 -7
  154. mindspore/version.py +1 -1
  155. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +3 -2
  156. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +160 -151
  157. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
  158. mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
  159. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
  160. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
  161. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
  162. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
  163. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  164. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  165. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  166. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  167. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  168. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  169. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
  170. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -238,13 +238,14 @@ class LambApplyOptimizerAssign(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self):
         """Initialize LambApplyOptimizerAssign"""
+        self.var_shape = "var_shape"
         self.add_prim_attr('side_effect_mem', True)
 
     def infer_shape(self, grad_shape, v_shape, m_shape, var_shape, beta1_shape, sub1_shape,
                     beta2_shape, sub2_shape, eps_shape, steps_shape, use_weight_shape, weight_decay_shape):
-        validator.check("var_shape", var_shape, "m_shape", m_shape, validator.EQ, self.name)
-        validator.check("var_shape", var_shape, "v_shape", v_shape, validator.EQ, self.name)
-        validator.check("var_shape", var_shape, "grad_shape", grad_shape, validator.EQ, self.name)
+        validator.check(self.var_shape, var_shape, "m_shape", m_shape, validator.EQ, self.name)
+        validator.check(self.var_shape, var_shape, "v_shape", v_shape, validator.EQ, self.name)
+        validator.check(self.var_shape, var_shape, "grad_shape", grad_shape, validator.EQ, self.name)
         return m_shape, v_shape, m_shape
 
     def infer_dtype(self, grad_dtype, v_dtype, m_dtype, var_dtype, beta1_dtype, sub1_dtype,
@@ -658,3 +659,25 @@ class ScaleGrad(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self):
         """Initialize ScaleGrad"""
+
+
+class KVCacheMgr(Primitive):
+    """
+    Update past with cur and index along sequence axis.
+
+    Inputs:
+        - **past** (Parameter) - 4-D tensor with shape: :math:`(batch_size, num_head, seq_len, hidden_size)`.
+        - **cur** (Tensor) - 4-D tensor with shape: :math:`(batch_size, num_head, 1, hidden_size)`.
+        - **index** (Tensor) - 1-D tensor with shape: :math:`(batch_size,)`.
+
+    Outputs:
+        Tensor, has the same data type and shape as original `past`.
+
+    Supported Platforms:
+        ``Ascend``
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['past', 'cur', 'index'], outputs=['past'])
+        self.add_prim_attr('side_effect_mem', True)
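
The hunk above adds a new KVCacheMgr primitive (the file totals suggest it lands in mindspore/ops/operations/inner_ops.py alongside the LambApplyOptimizerAssign change). Below is a minimal usage sketch, not part of the diff; the import path is inferred from the hunk context and an Ascend device is assumed, since that is the only documented platform.

    # Hypothetical usage sketch for the new KVCacheMgr primitive (Ascend only).
    import numpy as np
    from mindspore import Parameter, Tensor
    from mindspore.ops.operations.inner_ops import KVCacheMgr  # assumed import path

    batch_size, num_head, seq_len, hidden_size = 2, 4, 16, 32
    past = Parameter(Tensor(np.zeros((batch_size, num_head, seq_len, hidden_size), np.float16)), name="past")
    cur = Tensor(np.ones((batch_size, num_head, 1, hidden_size), np.float16))
    index = Tensor(np.array([3, 7], np.int32))  # write position for each batch element

    kv_cache_mgr = KVCacheMgr()
    updated = kv_cache_mgr(past, cur, index)  # writes `cur` into `past` in place (side_effect_mem)
    print(updated.shape)  # (2, 4, 16, 32), same shape and dtype as `past`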
@@ -1536,9 +1536,8 @@ class LpNorm(Primitive):
     """
 
     @prim_attr_register
-    def __init__(self, axis, p=2, keep_dims=False, epsilon=1e-12):
+    def __init__(self, axis=(), p=2, keep_dims=False, epsilon=1e-12):
         """Initialize LpNorm"""
-        super().__init__("LpNorm")
         validator.check_value_type("p", p, [int], self.name)
         validator.check_value_type("axis", axis, [int, tuple, list], self.name)
         validator.check_value_type("keep_dims", keep_dims, [bool], self.name)
@@ -2494,6 +2493,7 @@ class Reciprocal(PrimitiveWithCheck):
         self.init_prim_io_names(inputs=['x'], outputs=['y'])
 
     def infer_value(self, x):
+        """Infer value for Reciprocal"""
         if x is not None:
             x = x.asnumpy()
             out = 1.0 / x
@@ -2551,6 +2551,7 @@ class Pow(Primitive):
         self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
 
     def infer_value(self, x, power):
+        """infer value for _BinaryOp"""
         if x is not None and power is not None:
             x = x.asnumpy()
             power = power.asnumpy()
@@ -2931,7 +2932,7 @@ class Histogram(Primitive):
     """
 
     @prim_attr_register
-    def __init__(self, bins=100, min=0.0, max=0.0):  # pylint: disable=W0622
+    def __init__(self, bins=100, min=0.0, max=0.0):
         """Initialize Histogram."""
         self.init_prim_io_names(inputs=['x'], outputs=['y'])
         validator.check_value_type("bins", bins, [int], self.name)
@@ -6568,9 +6569,9 @@ class LinSpace(Primitive):
 
     Inputs:
         - **start** (Tensor) - Start value of interval, 0-D Tensor with dtype float32 or float64.
-        - **stop** (Tensor) - Last value of interval, 0-D Tensor with dtype float32 or float64.
-        - **num** (int) - Number of ticks in the interval, inclusive of `start` and `stop`.
-          Supported dtypes: int32, int64.
+        - **stop** (Tensor) - Last value of interval, 0-D Tensor with dtype float32 or float64.
+        - **num** (Union[int, Tensor]) - Number of ticks in the interval, inclusive of `start` and `stop`.
+          Must be a positive integer. When the input is Tensor, it must be a 0-D Tensor with dtype int32 or int64.
 
     Outputs:
         Tensor, has the same shape and dtype as `start`.
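
Per the LinSpace docstring update above, `num` is now documented as Union[int, Tensor]. A short sketch under that assumption; the int form is long-standing, the 0-D Tensor form is what the 2.2.11 docstring adds.

    import mindspore as ms
    from mindspore import Tensor, ops

    start = Tensor(1.0, ms.float32)
    stop = Tensor(10.0, ms.float32)
    print(ops.LinSpace()(start, stop, 5))                    # num as a Python int
    print(ops.LinSpace()(start, stop, Tensor(5, ms.int32)))  # num as a 0-D int32 Tensor (newly documented)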
@@ -1990,6 +1990,7 @@ class MaxPoolV1(Primitive):
         self.add_prim_attr("kernel_size", kernel_size_adapted)
         self.add_prim_attr("strides", strides_adapted)
 
+
 class MaxPool3D(Primitive):
     r"""
     Applies a 3D max pooling over an input Tensor which can be regarded as a composition of 3D planes.
@@ -3918,7 +3919,6 @@ class ResizeBilinear(PrimitiveWithInfer):
     def infer_dtype(self, input_dtype):
         validator.check_tensor_dtype_valid('input_dtype', input_dtype, [mstype.float16, mstype.float32],
                                            self.name)
-        self.add_prim_attr("dtype", input_dtype)
         return input_dtype
 
 
@@ -4009,6 +4009,7 @@ class OneHot(Primitive):
 
     Note:
        If the input indices is rank `N`, the output will have rank `N+1`. The new axis is created at dimension `axis`.
+       On Ascend, if `on_value` is Int64 dtype, `indices` must be Int64 dtype.
 
     Args:
        axis (int): Position to insert the value. e.g. If shape of `indices` is :math:`(N, C)`, and `axis` is -1,
@@ -4019,12 +4020,14 @@ class OneHot(Primitive):
        - **indices** (Tensor) - A tensor of indices. Tensor of shape :math:`(X_0, \ldots, X_n)`.
          Data type must be int32 or int64.
        - **depth** (int) - A scalar defining the depth of the one-hot dimension.
-       - **on_value** (Tensor) - A value to fill in output when `indices[j] = i`.
+       - **on_value** (Tensor) - A value to fill in output when `indices[j] = i`. Data type must be int32, int64,
+         float16 or float32.
        - **off_value** (Tensor) - A value to fill in output when `indices[j] != i`.
         It has the same data type as `on_value`.
 
     Outputs:
-        Tensor, one-hot tensor. Tensor of shape :math:`(X_0, \ldots, X_{axis}, \text{depth} ,X_{axis+1}, \ldots, X_n)`.
+        Tensor, one-hot tensor. Tensor of shape :math:`(X_0, \ldots, X_{axis}, \text{depth} ,X_{axis+1}, \ldots, X_n)`,
+        and it has the same data type as `on_value`.
 
     Raises:
         TypeError: If `axis` or `depth` is not an int.
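
The OneHot docstring now spells out the allowed `on_value` dtypes and states that the output dtype follows `on_value`. A small example of that contract (standard OneHot usage, nothing version-specific):

    import mindspore as ms
    from mindspore import Tensor, ops

    indices = Tensor([0, 2, 1], ms.int32)
    on_value = Tensor(1.0, ms.float32)
    off_value = Tensor(0.0, ms.float32)
    out = ops.OneHot()(indices, 3, on_value, off_value)
    print(out.shape)  # (3, 3)
    print(out.dtype)  # Float32, same as on_value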
@@ -8259,8 +8262,12 @@ class Conv3D(Primitive):
         self.add_prim_attr('data_format', self.format)
         self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
         validator.check_value_type("group", group, (int,), self.name)
+        validator.check_int_range(group, 1, out_channel, validator.INC_BOTH, "group", self.name)
+        device_target = context.get_context("device_target")
         if self.out_channel % group != 0:
             raise ValueError("The argument 'group' should be divisible by 'out_channel'")
+        if device_target == "Ascend" and group != 1:
+            raise ValueError("On Ascend platform, group = 1 must be satisfied.")
 
         self.group = group
         self.add_prim_attr('groups', self.group)
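
The Conv3D change adds two checks: `group` must lie in [1, out_channel], and on Ascend only `group=1` is accepted. A sketch of a configuration that still passes the new validation (shapes are illustrative only):

    import numpy as np
    from mindspore import Tensor, ops

    conv3d = ops.Conv3D(out_channel=8, kernel_size=(3, 3, 3), group=1)  # group=1 is accepted everywhere
    x = Tensor(np.ones((1, 4, 8, 16, 16), np.float32))   # (N, C_in, D, H, W)
    w = Tensor(np.ones((8, 4, 3, 3, 3), np.float32))     # (C_out, C_in // group, kD, kH, kW)
    print(conv3d(x, w).shape)
    # ops.Conv3D(out_channel=8, kernel_size=3, group=2) now raises ValueError on Ascend.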
@@ -8956,8 +8963,10 @@ class Dilation2D(Primitive):
         self.pad_mode = validator.check_string(pad_mode, ['VALID', 'SAME', 'valid', 'same'], 'pad_mode', self.name)
         self.add_prim_attr('pad_mode', self.pad_mode.upper())
         self.stride = _check_format_stride_or_dilation("stride", stride, self.name, self.data_format)
+
         def is_in_range(x):
             return 1 <= x <= 255
+
         if not is_in_range(self.stride[2]) or not is_in_range(self.stride[3]):
             raise ValueError(f'For Dilation2D, size of stride is not supported, '
                              f'stride should be in the range of [1, 255], '
@@ -11325,9 +11334,24 @@ class PromptFlashAttention(Primitive):
     S -- Sequence length
     H -- Hidden size
 
+    Refer to :func:mindspore.ops.prompt_flash_attention for more detail.
+
     .. warning::
         This is an experimental API that is subject to change or deletion.
 
+    Args:
+        num_heads (int): The number of heads.
+        scale_value (float): The scale value indicating the scale coefficient, which is used as the scalar of
+            Muls in the calculation. Default: 1.0.
+        pre_tokens (int): Previous tokens. Default: 2147483547.
+        next_tokens (int): next tokens. Default: 0.
+            indicate the upper triangle, Indicate the number of data blocks involved in the calculation. The value 0
+            indicates that the data blocks in the upper triangle are not involved in the calculation
+        input_layout (str): the data layout of the input qkv, support `(BSH)` and `(BNSD)`, Default `BSH`.
+        num_key_value_heads (int): head numbers of key/value which are used in GQA algorithm.
+            The value o indicates if the key and value have the same head nums, use numHeads. Default: 0.
+        sparse_mode (int): Default: 0
+
     Inputs:
         - **query** (Tensor) - The query tensor with data type of float16 or float32.
           Input tensor of shape :math:`(B, S, H)` / `(B, N, S, D)`.
@@ -11337,28 +11361,42 @@ class PromptFlashAttention(Primitive):
          Input tensor of shape :math:`(B, S, H)` / `(B, N, S, D)`.
        - **attn_mask** (Tensor) - The attention mask tensor with data type of float16 or float32.
          For each element, 0 indicates retention and 1 indicates discard. Input tensor of shape :math:`(B, 1, S, S)`.
-       - **padding_mask** (Tensor) - The padding mask tensor with data type of float16 or float32
        - **actual_seq_lengths** (Tensor): Describe actual sequence length of each input with data type of int.
-       - **num_heads** (int): The number of heads.
-       - **scale_value** (float): The scale value indicating the scale coefficient, which is used as the scalar of
-         Muls in the calculation. Default: 1.0.
-       - **pre_tokens** (int): Previous tokens. Default: 2147483547.
-       - **next_tokens** (int): next tokens. Default: 0.
-         indicate the upper triangle, Indicate the number of data blocks involved in the calculation. The value 0
-         indicates that the data blocks in the upper triangle are not involved in the calculation
-       - **input_layout** (str): the data layout of the input qkv, support `(BSH)` and `(BNSD)`, Default `BSH`.
-       - **num_key_value_heads** (int): head numbers of key/value which are used in GQA algorithm.
-         The value o indicates if the key and value have the same head nums, use numHeads. Default: 0.
+       - **actual_seq_lengths_kv** (Tensor): Describe actual sequence length of each input with data type of int.
+       - **padding_mask** (Tensor) - The padding mask tensor with data type of float16 or float32
+       - **dep_scale1** (Tensor)
+       - **quant_scale1** (Tensor)
+       - **deq_scale2** (Tensor)
+       - **quant_scale2** (Tensor)
+       - **quant_offset2** (Tensor)
+
 
     Outputs:
         - **attention_out** (Tensor) - Input tensor of shape :math:`(B, S, H)` / `(B, N, S, D)`.
 
-    Supported Platforms:
-        ``Ascend910B``
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore.ops.operations.nn_ops as P
+        >>> from mindspore import Tensor
+        >>> import numpy as np
+        >>> B = 1
+        >>> N = 16
+        >>> S = 256
+        >>> D = 16
+        >>> query = Tensor(np.ones((B, N, S, D), dtype=np.float16))
+        >>> key = Tensor(np.ones((B, N, S, D), dtype=np.float16))
+        >>> value = Tensor(np.ones((B, N, S, D), dtype=np.float16))
+        >>> pfa = P.PromptFlashAttention(N, input_layout='BNSD')
+        >>> out = pfa(query, key, value, None, None, None, None, None, None, None, None, None)
+        >>> print(out[0].shape)
+        (1, 16, 256, 16)
     """
+
     @prim_attr_register
     def __init__(self, num_heads, scale_value=1.0, pre_tokens=2147483547, next_tokens=0, input_layout='BSH',
-                 num_key_value_heads=0):
+                 num_key_value_heads=0, sparse_mode=0):
        """Initialize PromptFlashAttention."""
        validator.check_value_type('num_heads', num_heads, [int], self.name)
        validator.check_value_type('scale_value', scale_value, [float], self.name)
@@ -11366,7 +11404,10 @@ class PromptFlashAttention(Primitive):
         validator.check_value_type('next_tokens', next_tokens, [int], self.name)
         validator.check_value_type('input_layout', input_layout, [str], self.name)
         validator.check_value_type('num_key_value_heads', num_key_value_heads, [int], self.name)
-        self.init_prim_io_names(inputs=["query", "key", "value", "attn_mask", "padding_mask", "actual_seq_lengths"],
+        validator.check_value_type('sparse_mode', sparse_mode, [int], self.name)
+        self.init_prim_io_names(inputs=["query", "key", "value", "attn_mask", "actual_seq_lengths",
+                                        "actual_seq_lengths_kv", "padding_mask", "deq_scale1", "quant_scale1",
+                                        "deq_scale2", "quant_scale2", "quant_offset2"],
                                 outputs=["attention_out"])
 
 
@@ -11376,46 +11417,57 @@ class FlashAttentionScore(Primitive):
     .. warning::
         This is an experimental API that is subject to change or deletion.
     B -- Batch size
-    S -- Sequence length
-    H -- Hidden size
-    N -- Num heads
-    D -- Dim size
+    S1 -- Sequence length of query
+    S2 -- Sequence length of key and value
+    N1 -- Num heads of query
+    N2 -- Num heads of key and value, and N2 must be a factor of N1
+    D -- head size
+    H1 -- Hidden size of query, which equals to N1 * D
+    H2 -- Hidden size of key and value, which equals to N2 * D
     Args:
-        head_num (int): The number of the heads.
+        head_num (int): The head num of query.
         keep_prob (float): The keep probability of dropout. Default: 1.0.
         scale_value (float): The scale value. Default: 1.0.
         pre_tokens (int): Previous tokens. Default: 65536.
         next_tokens (int): Next tokens. Default: 65536.
         inner_precise (int): Specify the execution mode, where 0 indicates high precision mode and 1 indicates high
-            performance mode. Default: 0.
-        input_layout (str, optional): Specifies the layout of `query`, the value must be one of ["BSH", "SBH"].
-            Currently, only BSH is supported. Default: "BSH".
-
-    Inputs:
-        - **query** (Tensor) - The query tensor with data type of float16 or float32.
-          Input tensor of shape :math:`(B, S, H)`.
-        - **key** (Tensor) - The key tensor with data type of float16 or float32.
-          Input tensor of shape :math:`(B, S, H)`.
-        - **value** (Tensor) - The value tensor with data type of float16 or float32.
-          Input tensor of shape :math:`(B, S, H)`.
-        - **attn_mask** (Tensor) - The attention mask tensor with data type of float16 or float32.
-          For each element, 0 indicates retention and 1 indicates discard. Input tensor of shape :math:`(B, 1, S, S)`.
-        - **drop_mask** (Tensor) - The dropout mask tensor with data type of UInt8.
-          Input tensor of shape :math:`(B, N, S, S // 8) or ()`.
-        - **real_shift** (None) - The position embedding code of float16 or float32, not implemented yet.
+            performance mode. Only support 0 currently. Default: 0.
+        input_layout (str, optional): Specifies the layout of `query`, the value must be one of ["BSH", "BNSD"].
+            Default: "BSH".
+        sparse_mode (int): Default 0.
+
+    Inputs:
+        - **query** (Tensor[float16, float32, bfloat16]) - The query tensor.
+          Input tensor of shape :math:`(B, S1, H1)` or `(B, N1, S1, D)`.
+        - **key** (Tensor[float16, float32, bfloat16]) - The key tensor.
+          Input tensor of shape :math:`(B, S2, H2)` or `(B, N2, S2, D)`.
+        - **value** (Tensor[float16, float32, bfloat16]) - The value tensor.
+          Input tensor of shape :math:`(B, S2, H2)` or `(B, N2, S2, D)`.
+        - **real_shift** (Tensor[float16, float32, bfloat16], None) - The position embedding code.
+          Input tensor of shape :math: `(B, N1, S1, S2)` or `(B, N1, 1, S2)`.
+        - **drop_mask** (Tensor[uint8], None) - The dropout mask tensor.
+          Input tensor of shape :math:`(B, N1, S1, S2 // 8) or None`.
         - **padding_mask** (None) - The padding mask of float16 or float32, not implemented yet.
+        - **attn_mask** (Tensor[uint8], None) - The attention mask tensor.
+          For each element, 0 indicates retention and 1 indicates discard.
+          Input tensor of shape :math:`(B, N1, S1, S2)`, `(B, 1, S1, S2)` or `(S1, S2)`.
+        - **prefix** (Tensor[int64], None) - Not implemented yet.
+          Input tensor of shape :math:`(B,)`.
 
     Outputs:
-        - **attention_out** (Tensor) - (B, S, H)
-        - **softmax_max** (Tensor) - (B, N, S, 16)/(B, N, S, 8) when fp16/fp32
-        - **softmax_sum** (Tensor) - (B, N, S, 16)/(B, N, S, 8) when fp16/fp32
+        - **softmax_max** (Tensor[float32]) - (B, N1, S1, 8)
+        - **softmax_sum** (Tensor[float32]) - (B, N1, S1, 8)
+        - **softmax_out** (Tensor[float32]) - Useless output, ignore it. Output tensor of shape : `()`
+        - **attention_out** (Tensor[float16, float32, bfloat16]) - The output of attention, its shape, and data type
+          are the same as the query.
+
     Supported Platforms:
         ``Ascend``
     """
 
     @prim_attr_register
     def __init__(self, head_num, keep_prob=1.0, scale_value=1.0, pre_tokens=65536, next_tokens=65536, inner_precise=0,
-                 input_layout="BSH"):
+                 input_layout="BSH", sparse_mode=0):
        """Initialize FlashAttentionScore"""
        validator.check_value_type('head_num', head_num, [int], self.name)
        validator.check_value_type('keep_prob', keep_prob, [int, float], self.name)
@@ -11425,11 +11477,103 @@ class FlashAttentionScore(Primitive):
         validator.check_value_type('pre_tokens', pre_tokens, [int], self.name)
         validator.check_value_type('next_tokens', next_tokens, [int], self.name)
         validator.check_value_type('inner_precise', inner_precise, [int], self.name)
-        if inner_precise not in [0, 1]:
-            raise ValueError(f"Attribute 'inner_precise' must be either 0 or 1, but got {inner_precise}")
+        validator.check_value_type('sparse_mode', sparse_mode, [int], self.name)
+        if inner_precise not in [0]:
+            raise ValueError(f"Attribute 'inner_precise' must be 0, but got {inner_precise}")
         validator.check_value_type('input_layout', input_layout, [str], self.name)
-        if input_layout not in ["BSH"]:
-            raise ValueError(f"Attribute 'input_layout' must be either 'bsh' or 'sbh', but got {input_layout}")
+        if input_layout not in ["BSH", "BNSD"]:
+            raise ValueError(f"Attribute 'input_layout' must be either 'BSH' or 'BNSD', but got {input_layout}")
         self.init_prim_io_names(
-            inputs=['query', 'key', 'value', 'attn_mask', 'drop_mask', 'real_shift', 'padding_mask'],
-            outputs=['attention_out', 'softmax_max', 'softmax_sum'])
+            inputs=['query', 'key', 'value', 'real_shift', 'drop_mask', 'padding_mask', 'attn_mask', 'prefix'],
+            outputs=['softmax_max', 'softmax_sum', 'softmax_out', 'attention_out'])
+
+
+class RmsNorm(Primitive):
+    r"""
+    The RmsNorm operator is a normalization operation, and its formula is:
+
+    .. math::
+        y=\frac{x_i}{\sqrt{\frac{1}{n}\sum_{i=1}^{n}{x_i^2}+\varepsilon}}\gamma_i
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        epsilon (float): prevent division by 0, default value is `1e-6`
+
+    Inputs:
+        - **input_x** (Tensor) - Input data of RmsNorm, support data type: float16, float32, bfloat16.
+        - **gamma** (Tensor) - Support data type: float16, float32, bfloat16.
+
+    Outputs:
+        - **y** (Tensor) - Has the same type and shape with `input_x`.
+        - **rstd** (Tensor) - Has the same type with `input_x`, used by gradient calculation.
+
+    Raises:
+        TypeError: If data type of `input_x` is not one of the following: float16, float32, bfloat16.
+        TypeError: If data type of `gamma` is not one of the following: float16, float32, bfloat16.
+        TypeError: If data type of "input_x" is not the same with the data type of "gamma"
+
+    Supported Platforms:
+        ``Ascend``
+    """
+
+    @prim_attr_register
+    def __init__(self, epsilon=1e-6):
+        """Initialize Dense."""
+        validator.check_value_type("epsilon", epsilon, [float], self.name)
+        self.init_prim_io_names(inputs=['x', 'gamma'], outputs=["y", "rstd"])
+
+
+class PagedAttention(Primitive):
+    r"""
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+    """
+    @prim_attr_register
+    def __init__(self, head_num, scale_value=1.0, kv_head_num=0):
+        """Initialize PagedAttention"""
+        validator.check_value_type('head_num', head_num, [int], self.name)
+        validator.check_value_type('scale_value', scale_value, [float], self.name)  # scale after qkbmm
+        validator.check_value_type('kv_head_num', kv_head_num, [int], self.name)  # for MQA
+        self.init_prim_io_names(
+            inputs=['query', 'key_cache', 'value_cache', 'block_tables', 'context_lens'],
+            outputs=['attention_out'])
+
+
+class PagedAttentionMask(Primitive):
+    r"""
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+    """
+    @prim_attr_register
+    def __init__(self, head_num, scale_value=1.0, kv_head_num=0):
+        """Initialize PagedAttentionMask"""
+        validator.check_value_type('head_num', head_num, [int], self.name)
+        validator.check_value_type('scale_value', scale_value, [float], self.name)  # scale after qkbmm
+        validator.check_value_type('kv_head_num', kv_head_num, [int], self.name)  # for MQA
+        self.init_prim_io_names(
+            inputs=['query', 'key_cache', 'value_cache', 'block_tables', 'context_lens', 'alibi_mask'],
+            outputs=['attention_out'])
+
+
+class ReshapeAndCache(Primitive):
+    r"""
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+    """
+    __mindspore_signature__ = (
+        sig.make_sig('key', dtype=sig.sig_dtype.T),
+        sig.make_sig('value', dtype=sig.sig_dtype.T),
+        sig.make_sig('key_cache', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
+        sig.make_sig('value_cache', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
+        sig.make_sig('slot_mapping', dtype=sig.sig_dtype.T1),
+    )
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize ReshapeAndCache"""
+        self.init_prim_io_names(
+            inputs=['key', 'value', 'key_cache', 'value_cache', 'slot_mapping'],
+            outputs=['key_out'])
+        self.add_prim_attr('side_effect_mem', True)
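
Besides reordering FlashAttentionScore's inputs and outputs and adding BNSD layout plus `sparse_mode`, the hunk above introduces RmsNorm, PagedAttention, PagedAttentionMask and ReshapeAndCache in mindspore/ops/operations/nn_ops.py. A minimal RmsNorm sketch follows, not part of the diff; the import path is assumed from the hunk context and an Ascend device is required (the only documented platform).

    import numpy as np
    from mindspore import Tensor
    import mindspore.ops.operations.nn_ops as nn_ops  # assumed module for the new primitives

    x = Tensor(np.random.randn(2, 8, 64).astype(np.float16))
    gamma = Tensor(np.ones((64,), np.float16))
    rms_norm = nn_ops.RmsNorm(epsilon=1e-6)
    y, rstd = rms_norm(x, gamma)   # y: same shape/dtype as x; rstd feeds the backward pass
    print(y.shape, rstd.shape)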
@@ -335,7 +335,8 @@ def _rank_list_for_transform_parallel_checkpoint(rank_id, src_strategy_list, dst
     return list(result_list)
 
 
-def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, src_strategy_list, dst_strategy_list):
+def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, src_strategy_list,
+                                   dst_strategy_list, param_type_dict):
     """
     Transform model parallel dimension for distributed checkpoint files.
     """
@@ -397,15 +398,21 @@ def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, s
         transform_tensor = ms.Tensor(param_total_dict[param_name][rank_id % device_num])
         requires_grad = param_attr_dict[param_name][rank_id % device_num][0]
         layerwise_parallel = param_attr_dict[param_name][rank_id % device_num][1]
-        transform_param_dict[param_name] = ms.Parameter(transform_tensor, param_name, requires_grad, layerwise_parallel)
+        transform_para = ms.Parameter(transform_tensor, param_name, requires_grad, layerwise_parallel)
+        if param_type_dict[param_name][rank_id % device_num] == "BFloat16":
+            transform_para.set_dtype(ms.bfloat16)
+        transform_param_dict[param_name] = transform_para
 
     # Handle those parameter like learning_rate, global_step which not in strategy_file.
     for param_name, _ in param_total_dict.items():
         if param_name not in transform_param_dict:
-            transform_param_dict[param_name] = ms.Parameter(
+            transform_para = ms.Parameter(
                 ms.Tensor(param_total_dict[param_name][rank_id % device_num]), param_name,
                 param_attr_dict[param_name][rank_id % device_num][0],
                 param_attr_dict[param_name][rank_id % device_num][1])
+            if param_type_dict[param_name][rank_id % device_num] == "BFloat16":
+                transform_para.set_dtype(ms.bfloat16)
+            transform_param_dict[param_name] = transform_para
 
     transform_param_list = [{"name": param_name, "data": param_data}
                             for param_name, param_data in transform_param_dict.items()]
@@ -17,6 +17,7 @@ from __future__ import division
 from __future__ import absolute_import
 
 import numpy as np
+from mindspore.common import dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore.communication.management import get_rank, get_group_size
 from mindspore._c_expression import TensorTransform
@@ -221,6 +222,8 @@ def _load_tensor(tensor, dev_mat, tensor_map, rank_id=-1):
         rank = rank_id
     tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
     tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
+    if tensor.dtype == mstype.bfloat16:
+        tensor = tensor.float()
     np_tensor = tensor.asnumpy()
     np_tensor_list = _chunk_tensor_by_strategy(np_tensor, tensor_strategy)
     np_tensor_slice = np_tensor_list[int(tensor_slice_index)]
@@ -260,7 +263,7 @@ def _load_tensor_by_layout(tensor, layout, rank_id):
         rank = get_rank(group)
         size = get_group_size(group)
         tensor_slice = np.split(tensor_slice, size)[rank]
-    return Tensor(tensor_slice)
+    return Tensor(tensor_slice, tensor.dtype)
 
 
 def _reshape_param_data(param_data, dev_mat, tensor_map):
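
The one-line change in `_load_tensor_by_layout` forwards the source tensor's dtype when rebuilding the slice, so reduced-precision parameters keep their dtype instead of inheriting NumPy's. A small illustration of the difference (float16 is used here; the motivating case in this release is bfloat16, which NumPy cannot represent):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor

    np_slice = np.ones((4, 4), dtype=np.float32)   # sliced data always comes back as a NumPy array
    print(Tensor(np_slice).dtype)                  # Float32: dtype taken from NumPy
    print(Tensor(np_slice, ms.float16).dtype)      # Float16: dtype of the original tensor preserved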
@@ -21,6 +21,7 @@ import copy
 from collections import defaultdict
 import numpy as np
 import mindspore as ms
+from mindspore.common import dtype as mstype
 from mindspore.parallel._parallel_serialization import _rank_list_for_transform_parallel_checkpoint, \
     _transform_parallel_checkpoint, _get_device_num_from_strategy, _make_dir, \
     _extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
@@ -192,6 +193,7 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_
             raise ValueError("Checkpoint file {} in rank {} not exits: ".format(local_file, rank))
     param_total_dict = defaultdict(dict)
     param_attr_dict = defaultdict(dict)
+    param_type_dict = defaultdict(dict)
     src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(rank_id, src_strategy_file, dst_strategy_file)
     # src rank => local rank inside pipeline stage
     src_stage_device_num = np.prod(src_strategy_list.get(list(src_strategy_list.keys())[0])[0]) if src_strategy_list \
@@ -208,11 +210,15 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_
                     and _parameter_not_in_local_stage(param_name, origin_dst_strategy_list, dst_strategy_list):
                 continue
             src_rank = rank % src_stage_device_num
+            param_type_dict[param_name][src_rank] = str(param.data.dtype)
+            if param.data.dtype == mstype.bfloat16:
+                param.set_dtype(mstype.float32)
             param_total_dict[param_name][src_rank] = param.data.asnumpy()
             param_attr_dict[param_name][src_rank] = (param.requires_grad, param.layerwise_parallel)
     local_rank_id = rank_id % dst_stage_device_num
     transform_param_list = _transform_parallel_checkpoint(local_rank_id, param_total_dict,
-                                                          param_attr_dict, src_strategy_list, dst_strategy_list)
+                                                          param_attr_dict, src_strategy_list, dst_strategy_list,
+                                                          param_type_dict)
     ms.save_checkpoint(transform_param_list, save_checkpoint_file_name)
 
 
@@ -297,11 +303,15 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
     for needed_rank_list_key, transform_rank_list in needed_rank_list_map.items():
         param_total_dict = defaultdict(dict)
         param_attr_dict = defaultdict(dict)
+        param_type_dict = defaultdict(dict)
         needed_rank_list = needed_rank_list_key.split("-")
         for needed_rank in needed_rank_list:
             ckpt_dict = ms.load_checkpoint(all_checkpoint_files_map.get(int(needed_rank)))
             for param_name, param in ckpt_dict.items():
                 src_rank = int(needed_rank) % src_stage_device_num
+                param_type_dict[param_name][src_rank] = str(param.data.dtype)
+                if param.data.dtype == mstype.bfloat16:
+                    param.set_dtype(mstype.float32)
                 param_total_dict[param_name][src_rank] = param.data.asnumpy()
                 param_attr_dict[param_name][src_rank] = (param.requires_grad, param.layerwise_parallel)
         for transform_rank in transform_rank_list:
@@ -316,7 +326,8 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
 
             local_rank_id = transform_rank % dst_stage_device_num
             transform_param_list = _transform_parallel_checkpoint(local_rank_id, param_total_dict_copy,
-                                                                  param_attr_dict, src_strategy_list, dst_strategy_list)
+                                                                  param_attr_dict, src_strategy_list, dst_strategy_list,
+                                                                  param_type_dict)
             save_checkpoint_file = "{}{}.ckpt".format(ckpt_prefix, transform_rank)
             save_checkpoint_file_dir = os.path.join(dst_checkpoints_dir, "rank_{}".format(transform_rank))
             if not os.path.exists(save_checkpoint_file_dir):
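
Taken together, the checkpoint_transform.py and _parallel_serialization.py hunks record each parameter's dtype in a `param_type_dict`, reshard bfloat16 parameters as float32 (NumPy has no bfloat16), and restore bfloat16 with `set_dtype` afterwards. The public entry point is unchanged; a sketch with placeholder paths:

    import mindspore as ms

    ms.transform_checkpoints(
        src_checkpoints_dir="./src_ckpt",          # placeholder paths
        dst_checkpoints_dir="./dst_ckpt",
        ckpt_prefix="transformed_",
        src_strategy_file="./src_strategy.ckpt",
        dst_strategy_file="./dst_strategy.ckpt",
    )
    # bfloat16 parameters in the source checkpoints are written back as bfloat16.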
@@ -36,13 +36,17 @@ class Shard(Shard_):
     def __call__(self, fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
         if ms.context.get_context("mode") != ms.context.PYNATIVE_MODE or \
                 ms.context.get_auto_parallel_context("parallel_mode") not in ["auto_parallel"]:
-            raise AssertionError(f"Cell shard only supports auto parallel under PyNative mode.")
+            raise AssertionError(
+                f"Cell shard only supports auto parallel under PyNative mode.")
         if ms.context.get_context("device_target") not in ["Ascend", "GPU"]:
-            raise AssertionError(f"'Shard' now only supports 'Ascend' and 'GPU'")
+            raise AssertionError(
+                f"'Shard' now only supports 'Ascend' and 'GPU'")
         if ms.context.get_auto_parallel_context("search_mode") != "sharding_propagation":
-            raise AssertionError(f"'search_mode' must be 'sharding_propagation' for 'Shard'")
+            raise AssertionError(
+                f"'search_mode' must be 'sharding_propagation' for 'Shard'")
         if not isinstance(in_strategy, tuple):
-            raise TypeError(f"For 'Shard', the 'in_strategy' should be a tuple, but got {type(in_strategy).__name__}")
+            raise TypeError(
+                f"For 'Shard', the 'in_strategy' should be a tuple, but got {type(in_strategy).__name__}")
         if not isinstance(out_strategy, (type(None), tuple)):
             raise TypeError(f"For 'Shard', the 'out_strategy' should be None or tuple, "
                             f"but got {type(out_strategy).__name__}")
@@ -117,7 +121,8 @@ class Shard(Shard_):
             return
         if isinstance(parameter_plan, dict):
             if not isinstance(fn, ms.nn.Cell):
-                raise TypeError(f"If parameter_plan is set, type of fn must be mindspore.nn.Cell, but got {type(fn)}")
+                raise TypeError(
+                    f"If parameter_plan is set, type of fn must be mindspore.nn.Cell, but got {type(fn)}")
             for k in parameter_plan.keys():
                 v = parameter_plan[k]
                 if not isinstance(k, str) or not isinstance(v, tuple):
@@ -131,10 +136,12 @@ class Shard(Shard_):
             param_strategy = parameter_plan[param_name]
             param = self._search_parameter_by_name(param_name, fn)
             if param is None:
-                logger.warning(f"{param_name} is not exist, ignored its setting.")
+                logger.warning(
+                    f"{param_name} is not exist, ignored its setting.")
                 continue
 
-            self._check_layout_is_valid(param_name, param.shape, param_strategy)
+            self._check_layout_is_valid(
+                param_name, param.shape, param_strategy)
             if param.param_info.param_strategy:
                 logger.warning(f"The layout of parameter '{param_name}' "
                                f"has been set to {param.param_info.param_strategy}, "
@@ -143,7 +150,7 @@ class Shard(Shard_):
 
     def _is_attrs_has_been_set(self, fn, in_strategy, out_strategy, device, level):
         return self.shard_fn is not None and self.fn == fn and self.in_strategy == in_strategy and \
-            self.out_strategy == out_strategy and self.device == device and self.level == level
+               self.out_strategy == out_strategy and self.device == device and self.level == level
 
 
 def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
@@ -216,8 +223,8 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
         ...                          device_num=2)
         >>> def test_shard(x, y):
         ...     return x + y
-        >>> x = Tensor(np.ones(shape=(32, 10)))
-        >>> y = Tensor(np.ones(shape=(32, 10)))
+        >>> x = Tensor(np.ones(shape=(32, 10)), dtype=ms.float32)
+        >>> y = Tensor(np.ones(shape=(32, 10)), dtype=ms.float32)
        >>> output = ms.shard(test_shard, in_strategy=((2, 1), (2, 1)))(x, y)
        >>> print(output.shape)
        (32, 10)
@@ -25,6 +25,7 @@ import stat
 
 from mindspore import log as logger
 
+
 def to_int(param, param_name):
     """
     Transfer param to int type.