mindspore-2.2.0-cp38-cp38-manylinux1_x86_64.whl → mindspore-2.2.11-cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore has been flagged as possibly problematic by the registry diff service.
- mindspore/.commit_id +1 -1
- mindspore/_akg/akg/composite/build_module.py +104 -20
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +7 -2
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/kernel_exec.py +41 -15
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +27 -6
- mindspore/_akg/akg/utils/util.py +56 -1
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_checkparam.py +3 -3
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/splitter.py +3 -2
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
- mindspore/_extends/parse/__init__.py +3 -2
- mindspore/_extends/parse/parser.py +6 -1
- mindspore/_extends/parse/standard_method.py +14 -11
- mindspore/_extends/remote/kernel_build_server.py +2 -1
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/common/_utils.py +16 -0
- mindspore/common/api.py +1 -1
- mindspore/common/auto_dynamic_shape.py +81 -85
- mindspore/common/dump.py +1 -1
- mindspore/common/tensor.py +3 -20
- mindspore/config/op_info.config +1 -1
- mindspore/context.py +11 -4
- mindspore/dataset/engine/cache_client.py +8 -5
- mindspore/dataset/engine/datasets_standard_format.py +5 -0
- mindspore/dataset/vision/transforms.py +21 -21
- mindspore/experimental/optim/adam.py +1 -1
- mindspore/gen_ops.py +1 -1
- mindspore/include/api/model.h +17 -0
- mindspore/include/api/status.h +8 -3
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +78 -80
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/nn/cell.py +0 -3
- mindspore/nn/layer/activation.py +4 -5
- mindspore/nn/layer/conv.py +39 -23
- mindspore/nn/layer/flash_attention.py +54 -129
- mindspore/nn/layer/math.py +3 -7
- mindspore/nn/layer/rnn_cells.py +5 -5
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +12 -3
- mindspore/numpy/utils_const.py +5 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
- mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_utils/utils.py +2 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
- mindspore/ops/function/array_func.py +10 -7
- mindspore/ops/function/grad/grad_func.py +0 -1
- mindspore/ops/function/nn_func.py +98 -9
- mindspore/ops/function/random_func.py +2 -1
- mindspore/ops/op_info_register.py +24 -21
- mindspore/ops/operations/__init__.py +6 -2
- mindspore/ops/operations/_grad_ops.py +25 -6
- mindspore/ops/operations/_inner_ops.py +155 -23
- mindspore/ops/operations/array_ops.py +9 -7
- mindspore/ops/operations/comm_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +85 -68
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +7 -6
- mindspore/ops/operations/nn_ops.py +193 -49
- mindspore/parallel/_parallel_serialization.py +10 -3
- mindspore/parallel/_tensor.py +4 -1
- mindspore/parallel/checkpoint_transform.py +13 -2
- mindspore/parallel/shard.py +17 -10
- mindspore/profiler/common/util.py +1 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
- mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
- mindspore/profiler/parser/ascend_op_generator.py +1 -1
- mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
- mindspore/profiler/parser/base_timeline_generator.py +1 -1
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
- mindspore/profiler/parser/framework_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +19 -0
- mindspore/profiler/profiling.py +46 -24
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/parsers/for_parser.py +7 -7
- mindspore/rewrite/parsers/module_parser.py +4 -4
- mindspore/rewrite/symbol_tree.py +1 -4
- mindspore/run_check/_check_version.py +5 -3
- mindspore/safeguard/rewrite_obfuscation.py +52 -28
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/train/callback/_summary_collector.py +1 -1
- mindspore/train/dataset_helper.py +1 -0
- mindspore/train/model.py +2 -2
- mindspore/train/serialization.py +97 -11
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +23 -7
- mindspore/version.py +1 -1
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +3 -2
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +160 -151
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
- mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -238,13 +238,14 @@ class LambApplyOptimizerAssign(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self):
         """Initialize LambApplyOptimizerAssign"""
+        self.var_shape = "var_shape"
         self.add_prim_attr('side_effect_mem', True)
 
     def infer_shape(self, grad_shape, v_shape, m_shape, var_shape, beta1_shape, sub1_shape,
                     beta2_shape, sub2_shape, eps_shape, steps_shape, use_weight_shape, weight_decay_shape):
-        validator.check(
-        validator.check(
-        validator.check(
+        validator.check(self.var_shape, var_shape, "m_shape", m_shape, validator.EQ, self.name)
+        validator.check(self.var_shape, var_shape, "v_shape", v_shape, validator.EQ, self.name)
+        validator.check(self.var_shape, var_shape, "grad_shape", grad_shape, validator.EQ, self.name)
         return m_shape, v_shape, m_shape
 
     def infer_dtype(self, grad_dtype, v_dtype, m_dtype, var_dtype, beta1_dtype, sub1_dtype,
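In plain terms, the rewritten checks require the optimizer state and gradient shapes to equal the variable shape. A rough plain-Python sketch of that constraint (the logic only, not the MindSpore validator API):

import numpy as np

def check_lamb_shapes(var_shape, m_shape, v_shape, grad_shape):
    # Sketch: every state/gradient shape must equal var_shape, as the new validator.check calls enforce.
    for name, shape in (("m_shape", m_shape), ("v_shape", v_shape), ("grad_shape", grad_shape)):
        if tuple(shape) != tuple(var_shape):
            raise ValueError(f"var_shape {tuple(var_shape)} must equal {name} {tuple(shape)}")

check_lamb_shapes((2, 3), (2, 3), (2, 3), (2, 3))  # passes; a mismatched shape would raise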
@@ -658,3 +659,25 @@ class ScaleGrad(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self):
         """Initialize ScaleGrad"""
+
+
+class KVCacheMgr(Primitive):
+    """
+    Update past with cur and index along sequence axis.
+
+    Inputs:
+        - **past** (Parameter) - 4-D tensor with shape: :math:`(batch_size, num_head, seq_len, hidden_size)`.
+        - **cur** (Tensor) - 4-D tensor with shape: :math:`(batch_size, num_head, 1, hidden_size)`.
+        - **index** (Tensor) - 1-D tensor with shape: :math:`(batch_size,)`.
+
+    Outputs:
+        Tensor, has the same data type and shape as original `past`.
+
+    Supported Platforms:
+        ``Ascend``
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['past', 'cur', 'index'], outputs=['past'])
+        self.add_prim_attr('side_effect_mem', True)
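A minimal usage sketch for the new KVCacheMgr primitive, built only from the shapes documented above (the hunk does not show the module it lands in, so the actual call is left as a comment):

import numpy as np
from mindspore import Tensor, Parameter

batch, heads, seq_len, hidden = 1, 2, 16, 8
past = Parameter(Tensor(np.zeros((batch, heads, seq_len, hidden), np.float16)), name="past")
cur = Tensor(np.ones((batch, heads, 1, hidden), np.float16))   # one new step per batch
index = Tensor(np.array([4], np.int32))                        # write position along seq_len
# On Ascend, KVCacheMgr()(past, cur, index) writes `cur` into `past` at `index` along the
# sequence axis and returns `past`; the 'side_effect_mem' attribute marks it as an in-place op.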
@@ -1536,9 +1536,8 @@ class LpNorm(Primitive):
     """
 
     @prim_attr_register
-    def __init__(self, axis, p=2, keep_dims=False, epsilon=1e-12):
+    def __init__(self, axis=(), p=2, keep_dims=False, epsilon=1e-12):
         """Initialize LpNorm"""
-        super().__init__("LpNorm")
         validator.check_value_type("p", p, [int], self.name)
         validator.check_value_type("axis", axis, [int, tuple, list], self.name)
         validator.check_value_type("keep_dims", keep_dims, [bool], self.name)
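The user-visible effect is that `axis` now has a default, so the primitive can be constructed without arguments, and the redundant super().__init__("LpNorm") call is dropped. A small sketch, assuming the usual mindspore.ops.LpNorm export:

from mindspore import ops

per_row = ops.LpNorm(axis=1, p=2)   # explicit axis, unchanged behaviour
whole = ops.LpNorm()                # 2.2.11: constructible with the default axis=()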
@@ -2494,6 +2493,7 @@ class Reciprocal(PrimitiveWithCheck):
         self.init_prim_io_names(inputs=['x'], outputs=['y'])
 
     def infer_value(self, x):
+        """Infer value for Reciprocal"""
         if x is not None:
             x = x.asnumpy()
             out = 1.0 / x
@@ -2551,6 +2551,7 @@ class Pow(Primitive):
         self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
 
     def infer_value(self, x, power):
+        """infer value for _BinaryOp"""
         if x is not None and power is not None:
             x = x.asnumpy()
             power = power.asnumpy()
@@ -2931,7 +2932,7 @@ class Histogram(Primitive):
     """
 
     @prim_attr_register
-    def __init__(self, bins=100, min=0.0, max=0.0):
+    def __init__(self, bins=100, min=0.0, max=0.0):
         """Initialize Histogram."""
         self.init_prim_io_names(inputs=['x'], outputs=['y'])
         validator.check_value_type("bins", bins, [int], self.name)
@@ -6568,9 +6569,9 @@ class LinSpace(Primitive):
 
     Inputs:
         - **start** (Tensor) - Start value of interval, 0-D Tensor with dtype float32 or float64.
-        - **stop** (Tensor) - Last value of interval, 0-D Tensor with dtype
-        - **num** (int) - Number of ticks in the interval, inclusive of `start` and `stop`.
-
+        - **stop** (Tensor) - Last value of interval, 0-D Tensor with dtype float32 or float64.
+        - **num** (Union[int, Tensor]) - Number of ticks in the interval, inclusive of `start` and `stop`.
+          Must be a positive integer. When the input is Tensor, it must be a 0-D Tensor with dtype int32 or int64.
 
     Outputs:
         Tensor, has the same shape and dtype as `start`.
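Per the updated docstring, `num` may now be a 0-D int32/int64 Tensor as well as a Python int. A short sketch of both forms, assuming the standard LinSpace primitive:

import numpy as np
from mindspore import Tensor, ops

start, stop = Tensor(np.float32(1.0)), Tensor(np.float32(9.0))
linspace = ops.LinSpace()
out_int = linspace(start, stop, 5)                       # num as a Python int, as before
out_tensor = linspace(start, stop, Tensor(np.int64(5)))  # num as a 0-D Tensor, newly documented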
@@ -1990,6 +1990,7 @@ class MaxPoolV1(Primitive):
         self.add_prim_attr("kernel_size", kernel_size_adapted)
         self.add_prim_attr("strides", strides_adapted)
 
+
 class MaxPool3D(Primitive):
     r"""
     Applies a 3D max pooling over an input Tensor which can be regarded as a composition of 3D planes.
@@ -3918,7 +3919,6 @@ class ResizeBilinear(PrimitiveWithInfer):
     def infer_dtype(self, input_dtype):
         validator.check_tensor_dtype_valid('input_dtype', input_dtype, [mstype.float16, mstype.float32],
                                            self.name)
-        self.add_prim_attr("dtype", input_dtype)
         return input_dtype
 
 
@@ -4009,6 +4009,7 @@ class OneHot(Primitive):
 
     Note:
         If the input indices is rank `N`, the output will have rank `N+1`. The new axis is created at dimension `axis`.
+        On Ascend, if `on_value` is Int64 dtype, `indices` must be Int64 dtype.
 
     Args:
         axis (int): Position to insert the value. e.g. If shape of `indices` is :math:`(N, C)`, and `axis` is -1,
@@ -4019,12 +4020,14 @@ class OneHot(Primitive):
         - **indices** (Tensor) - A tensor of indices. Tensor of shape :math:`(X_0, \ldots, X_n)`.
           Data type must be int32 or int64.
         - **depth** (int) - A scalar defining the depth of the one-hot dimension.
-        - **on_value** (Tensor) - A value to fill in output when `indices[j] = i`.
+        - **on_value** (Tensor) - A value to fill in output when `indices[j] = i`. Data type must be int32, int64,
+          float16 or float32.
         - **off_value** (Tensor) - A value to fill in output when `indices[j] != i`.
           It has the same data type as `on_value`.
 
     Outputs:
-        Tensor, one-hot tensor. Tensor of shape :math:`(X_0, \ldots, X_{axis}, \text{depth} ,X_{axis+1}, \ldots, X_n)
+        Tensor, one-hot tensor. Tensor of shape :math:`(X_0, \ldots, X_{axis}, \text{depth} ,X_{axis+1}, \ldots, X_n)`,
+        and it has the same data type as `on_value`.
 
     Raises:
         TypeError: If `axis` or `depth` is not an int.
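The clarified text says the output dtype follows `on_value`; a short example with the ordinary OneHot call order (indices, depth, on_value, off_value):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

indices = Tensor(np.array([0, 1, 2]), ms.int32)
on_value, off_value = Tensor(1.0, ms.float32), Tensor(0.0, ms.float32)
out = ops.OneHot()(indices, 3, on_value, off_value)
print(out.shape, out.dtype)   # (3, 3) Float32, i.e. the dtype of on_value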
@@ -8259,8 +8262,12 @@ class Conv3D(Primitive):
         self.add_prim_attr('data_format', self.format)
         self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
         validator.check_value_type("group", group, (int,), self.name)
+        validator.check_int_range(group, 1, out_channel, validator.INC_BOTH, "group", self.name)
+        device_target = context.get_context("device_target")
         if self.out_channel % group != 0:
             raise ValueError("The argument 'group' should be divisible by 'out_channel'")
+        if device_target == "Ascend" and group != 1:
+            raise ValueError("On Ascend platform, group = 1 must be satisfied.")
 
         self.group = group
         self.add_prim_attr('groups', self.group)
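So 2.2.11 validates `group` against `out_channel` and rejects grouped 3-D convolutions on Ascend at construction time. A sketch of what passes and what raises (values are illustrative):

from mindspore import ops

conv3d = ops.Conv3D(out_channel=32, kernel_size=(3, 3, 3))      # group defaults to 1: accepted on every backend
# ops.Conv3D(out_channel=32, kernel_size=(3, 3, 3), group=2)    # ValueError on Ascend: group = 1 must be satisfied
# ops.Conv3D(out_channel=32, kernel_size=(3, 3, 3), group=64)   # ValueError anywhere: group must lie in [1, out_channel]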
@@ -8956,8 +8963,10 @@ class Dilation2D(Primitive):
         self.pad_mode = validator.check_string(pad_mode, ['VALID', 'SAME', 'valid', 'same'], 'pad_mode', self.name)
         self.add_prim_attr('pad_mode', self.pad_mode.upper())
         self.stride = _check_format_stride_or_dilation("stride", stride, self.name, self.data_format)
+
         def is_in_range(x):
             return 1 <= x <= 255
+
         if not is_in_range(self.stride[2]) or not is_in_range(self.stride[3]):
             raise ValueError(f'For Dilation2D, size of stride is not supported, '
                              f'stride should be in the range of [1, 255], '
@@ -11325,9 +11334,24 @@ class PromptFlashAttention(Primitive):
     S -- Sequence length
     H -- Hidden size
 
+    Refer to :func:mindspore.ops.prompt_flash_attention for more detail.
+
     .. warning::
         This is an experimental API that is subject to change or deletion.
 
+    Args:
+        num_heads (int): The number of heads.
+        scale_value (float): The scale value indicating the scale coefficient, which is used as the scalar of
+            Muls in the calculation. Default: 1.0.
+        pre_tokens (int): Previous tokens. Default: 2147483547.
+        next_tokens (int): next tokens. Default: 0.
+            indicate the upper triangle, Indicate the number of data blocks involved in the calculation. The value 0
+            indicates that the data blocks in the upper triangle are not involved in the calculation
+        input_layout (str): the data layout of the input qkv, support `(BSH)` and `(BNSD)`, Default `BSH`.
+        num_key_value_heads (int): head numbers of key/value which are used in GQA algorithm.
+            The value o indicates if the key and value have the same head nums, use numHeads. Default: 0.
+        sparse_mode (int): Default: 0
+
     Inputs:
         - **query** (Tensor) - The query tensor with data type of float16 or float32.
           Input tensor of shape :math:`(B, S, H)` / `(B, N, S, D)`.
@@ -11337,28 +11361,42 @@ class PromptFlashAttention(Primitive):
           Input tensor of shape :math:`(B, S, H)` / `(B, N, S, D)`.
         - **attn_mask** (Tensor) - The attention mask tensor with data type of float16 or float32.
           For each element, 0 indicates retention and 1 indicates discard. Input tensor of shape :math:`(B, 1, S, S)`.
-        - **padding_mask** (Tensor) - The padding mask tensor with data type of float16 or float32
         - **actual_seq_lengths** (Tensor): Describe actual sequence length of each input with data type of int.
-        - **
-        - **
-
-        - **
-        - **
-
-
-
-        - **num_key_value_heads** (int): head numbers of key/value which are used in GQA algorithm.
-          The value o indicates if the key and value have the same head nums, use numHeads. Default: 0.
+        - **actual_seq_lengths_kv** (Tensor): Describe actual sequence length of each input with data type of int.
+        - **padding_mask** (Tensor) - The padding mask tensor with data type of float16 or float32
+        - **dep_scale1** (Tensor)
+        - **quant_scale1** (Tensor)
+        - **deq_scale2** (Tensor)
+        - **quant_scale2** (Tensor)
+        - **quant_offset2** (Tensor)
+
 
     Outputs:
         - **attention_out** (Tensor) - Input tensor of shape :math:`(B, S, H)` / `(B, N, S, D)`.
 
-
-        ``
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore.ops.operations.nn_ops as P
+        >>> from mindspore import Tensor
+        >>> import numpy as np
+        >>> B = 1
+        >>> N = 16
+        >>> S = 256
+        >>> D = 16
+        >>> query = Tensor(np.ones((B, N, S, D), dtype=np.float16))
+        >>> key = Tensor(np.ones((B, N, S, D), dtype=np.float16))
+        >>> value = Tensor(np.ones((B, N, S, D), dtype=np.float16))
+        >>> pfa = P.PromptFlashAttention(N, input_layout='BNSD')
+        >>> out = pfa(query, key, value, None, None, None, None, None, None, None, None, None)
+        >>> print(out[0].shape)
+        (1, 16, 256, 16)
     """
+
     @prim_attr_register
     def __init__(self, num_heads, scale_value=1.0, pre_tokens=2147483547, next_tokens=0, input_layout='BSH',
-                 num_key_value_heads=0):
+                 num_key_value_heads=0, sparse_mode=0):
         """Initialize PromptFlashAttention."""
         validator.check_value_type('num_heads', num_heads, [int], self.name)
         validator.check_value_type('scale_value', scale_value, [float], self.name)
@@ -11366,7 +11404,10 @@ class PromptFlashAttention(Primitive):
         validator.check_value_type('next_tokens', next_tokens, [int], self.name)
         validator.check_value_type('input_layout', input_layout, [str], self.name)
         validator.check_value_type('num_key_value_heads', num_key_value_heads, [int], self.name)
-
+        validator.check_value_type('sparse_mode', sparse_mode, [int], self.name)
+        self.init_prim_io_names(inputs=["query", "key", "value", "attn_mask", "actual_seq_lengths",
+                                        "actual_seq_lengths_kv", "padding_mask", "deq_scale1", "quant_scale1",
+                                        "deq_scale2", "quant_scale2", "quant_offset2"],
                                 outputs=["attention_out"])
 
 
@@ -11376,46 +11417,57 @@ class FlashAttentionScore(Primitive):
     .. warning::
         This is an experimental API that is subject to change or deletion.
     B -- Batch size
-
-
-
-
+    S1 -- Sequence length of query
+    S2 -- Sequence length of key and value
+    N1 -- Num heads of query
+    N2 -- Num heads of key and value, and N2 must be a factor of N1
+    D -- head size
+    H1 -- Hidden size of query, which equals to N1 * D
+    H2 -- Hidden size of key and value, which equals to N2 * D
     Args:
-        head_num (int): The
+        head_num (int): The head num of query.
         keep_prob (float): The keep probability of dropout. Default: 1.0.
         scale_value (float): The scale value. Default: 1.0.
        pre_tokens (int): Previous tokens. Default: 65536.
         next_tokens (int): Next tokens. Default: 65536.
         inner_precise (int): Specify the execution mode, where 0 indicates high precision mode and 1 indicates high
-            performance mode. Default: 0.
-        input_layout (str, optional): Specifies the layout of `query`, the value must be one of ["BSH", "
-
-
-
-
-
-
-
-
-
-
-
-
-
+            performance mode. Only support 0 currently. Default: 0.
+        input_layout (str, optional): Specifies the layout of `query`, the value must be one of ["BSH", "BNSD"].
+            Default: "BSH".
+        sparse_mode (int): Default 0.
+
+    Inputs:
+        - **query** (Tensor[float16, float32, bfloat16]) - The query tensor.
+          Input tensor of shape :math:`(B, S1, H1)` or `(B, N1, S1, D)`.
+        - **key** (Tensor[float16, float32, bfloat16]) - The key tensor.
+          Input tensor of shape :math:`(B, S2, H2)` or `(B, N2, S2, D)`.
+        - **value** (Tensor[float16, float32, bfloat16]) - The value tensor.
+          Input tensor of shape :math:`(B, S2, H2)` or `(B, N2, S2, D)`.
+        - **real_shift** (Tensor[float16, float32, bfloat16], None) - The position embedding code.
+          Input tensor of shape :math: `(B, N1, S1, S2)` or `(B, N1, 1, S2)`.
+        - **drop_mask** (Tensor[uint8], None) - The dropout mask tensor.
+          Input tensor of shape :math:`(B, N1, S1, S2 // 8) or None`.
         - **padding_mask** (None) - The padding mask of float16 or float32, not implemented yet.
+        - **attn_mask** (Tensor[uint8], None) - The attention mask tensor.
+          For each element, 0 indicates retention and 1 indicates discard.
+          Input tensor of shape :math:`(B, N1, S1, S2)`, `(B, 1, S1, S2)` or `(S1, S2)`.
+        - **prefix** (Tensor[int64], None) - Not implemented yet.
+          Input tensor of shape :math:`(B,)`.
 
     Outputs:
-        - **
-        - **
-        - **
+        - **softmax_max** (Tensor[float32]) - (B, N1, S1, 8)
+        - **softmax_sum** (Tensor[float32]) - (B, N1, S1, 8)
+        - **softmax_out** (Tensor[float32]) - Useless output, ignore it. Output tensor of shape : `()`
+        - **attention_out** (Tensor[float16, float32, bfloat16]) - The output of attention, its shape, and data type
+          are the same as the query.
+
     Supported Platforms:
         ``Ascend``
     """
 
     @prim_attr_register
     def __init__(self, head_num, keep_prob=1.0, scale_value=1.0, pre_tokens=65536, next_tokens=65536, inner_precise=0,
-                 input_layout="BSH"):
+                 input_layout="BSH", sparse_mode=0):
         """Initialize FlashAttentionScore"""
         validator.check_value_type('head_num', head_num, [int], self.name)
         validator.check_value_type('keep_prob', keep_prob, [int, float], self.name)
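Together with the __init__ hunk that follows, FlashAttentionScore now accepts input_layout="BNSD" and a sparse_mode attribute. A construction-only sketch (argument values are illustrative; executing the op end to end needs an Ascend device):

import mindspore.ops.operations.nn_ops as nn_ops

fas = nn_ops.FlashAttentionScore(head_num=16, scale_value=0.25,
                                 input_layout="BNSD", sparse_mode=0)
# Per the new io names: inputs are (query, key, value, real_shift, drop_mask, padding_mask,
# attn_mask, prefix); outputs are (softmax_max, softmax_sum, softmax_out, attention_out).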
|
|
11425
11477
|
validator.check_value_type('pre_tokens', pre_tokens, [int], self.name)
|
|
11426
11478
|
validator.check_value_type('next_tokens', next_tokens, [int], self.name)
|
|
11427
11479
|
validator.check_value_type('inner_precise', inner_precise, [int], self.name)
|
|
11428
|
-
|
|
11429
|
-
|
|
11480
|
+
validator.check_value_type('sparse_mode', sparse_mode, [int], self.name)
|
|
11481
|
+
if inner_precise not in [0]:
|
|
11482
|
+
raise ValueError(f"Attribute 'inner_precise' must be 0, but got {inner_precise}")
|
|
11430
11483
|
validator.check_value_type('input_layout', input_layout, [str], self.name)
|
|
11431
|
-
if input_layout not in ["BSH"]:
|
|
11432
|
-
raise ValueError(f"Attribute 'input_layout' must be either '
|
|
11484
|
+
if input_layout not in ["BSH", "BNSD"]:
|
|
11485
|
+
raise ValueError(f"Attribute 'input_layout' must be either 'BSH' or 'BNSD', but got {input_layout}")
|
|
11433
11486
|
self.init_prim_io_names(
|
|
11434
|
-
inputs=['query', 'key', 'value', '
|
|
11435
|
-
outputs=['
|
|
11487
|
+
inputs=['query', 'key', 'value', 'real_shift', 'drop_mask', 'padding_mask', 'attn_mask', 'prefix'],
|
|
11488
|
+
outputs=['softmax_max', 'softmax_sum', 'softmax_out', 'attention_out'])
|
|
11489
|
+
|
|
11490
|
+
|
|
11491
|
+
class RmsNorm(Primitive):
|
|
11492
|
+
r"""
|
|
11493
|
+
The RmsNorm operator is a normalization operation, and its formula is:
|
|
11494
|
+
|
|
11495
|
+
.. math::
|
|
11496
|
+
y=\frac{x_i}{\sqrt{\frac{1}{n}}\sum_{i=1}^{n}{ x_i^2}+\varepsilon }\gamma_i
|
|
11497
|
+
|
|
11498
|
+
.. warning::
|
|
11499
|
+
This is an experimental API that is subject to change or deletion.
|
|
11500
|
+
|
|
11501
|
+
Args:
|
|
11502
|
+
epsilon (float): prevent division by 0, default value is `1e-6`
|
|
11503
|
+
|
|
11504
|
+
Inputs:
|
|
11505
|
+
- **input_x** (Tensor) - Input data of RmsNorm, support data type: float16, float32, bfloat16.
|
|
11506
|
+
- **gamma** (Tensor) - Support data type: float16, float32, bfloat16.
|
|
11507
|
+
|
|
11508
|
+
Outputs:
|
|
11509
|
+
- **y** (Tensor) - Has the same type and shape with `input_x`.
|
|
11510
|
+
- **rstd** (Tensor) - Has the same type with `input_x`, used by gradient calculation.
|
|
11511
|
+
|
|
11512
|
+
Raises:
|
|
11513
|
+
TypeError: If data type of `input_x` is not one of the following: float16, float32, bfloat16.
|
|
11514
|
+
TypeError: If data type of `gamma` is not one of the following: float16, float32, bfloat16.
|
|
11515
|
+
TypeError: If data type of "input_x" is not the same with the data type of "gamma"
|
|
11516
|
+
|
|
11517
|
+
Supported Platforms:
|
|
11518
|
+
``Ascend``
|
|
11519
|
+
"""
|
|
11520
|
+
|
|
11521
|
+
@prim_attr_register
|
|
11522
|
+
def __init__(self, epsilon=1e-6):
|
|
11523
|
+
"""Initialize Dense."""
|
|
11524
|
+
validator.check_value_type("epsilon", epsilon, [float], self.name)
|
|
11525
|
+
self.init_prim_io_names(inputs=['x', 'gamma'], outputs=["y", "rstd"])
|
|
11526
|
+
|
|
11527
|
+
|
|
11528
|
+
class PagedAttention(Primitive):
|
|
11529
|
+
r"""
|
|
11530
|
+
.. warning::
|
|
11531
|
+
This is an experimental API that is subject to change or deletion.
|
|
11532
|
+
"""
|
|
11533
|
+
@prim_attr_register
|
|
11534
|
+
def __init__(self, head_num, scale_value=1.0, kv_head_num=0):
|
|
11535
|
+
"""Initialize PagedAttention"""
|
|
11536
|
+
validator.check_value_type('head_num', head_num, [int], self.name)
|
|
11537
|
+
validator.check_value_type('scale_value', scale_value, [float], self.name) # scale after qkbmm
|
|
11538
|
+
validator.check_value_type('kv_head_num', kv_head_num, [int], self.name) # for MQA
|
|
11539
|
+
self.init_prim_io_names(
|
|
11540
|
+
inputs=['query', 'key_cache', 'value_cache', 'block_tables', 'context_lens'],
|
|
11541
|
+
outputs=['attention_out'])
|
|
11542
|
+
|
|
11543
|
+
|
|
11544
|
+
class PagedAttentionMask(Primitive):
|
|
11545
|
+
r"""
|
|
11546
|
+
.. warning::
|
|
11547
|
+
This is an experimental API that is subject to change or deletion.
|
|
11548
|
+
"""
|
|
11549
|
+
@prim_attr_register
|
|
11550
|
+
def __init__(self, head_num, scale_value=1.0, kv_head_num=0):
|
|
11551
|
+
"""Initialize PagedAttentionMask"""
|
|
11552
|
+
validator.check_value_type('head_num', head_num, [int], self.name)
|
|
11553
|
+
validator.check_value_type('scale_value', scale_value, [float], self.name) # scale after qkbmm
|
|
11554
|
+
validator.check_value_type('kv_head_num', kv_head_num, [int], self.name) # for MQA
|
|
11555
|
+
self.init_prim_io_names(
|
|
11556
|
+
inputs=['query', 'key_cache', 'value_cache', 'block_tables', 'context_lens', 'alibi_mask'],
|
|
11557
|
+
outputs=['attention_out'])
|
|
11558
|
+
|
|
11559
|
+
|
|
11560
|
+
class ReshapeAndCache(Primitive):
|
|
11561
|
+
r"""
|
|
11562
|
+
.. warning::
|
|
11563
|
+
This is an experimental API that is subject to change or deletion.
|
|
11564
|
+
"""
|
|
11565
|
+
__mindspore_signature__ = (
|
|
11566
|
+
sig.make_sig('key', dtype=sig.sig_dtype.T),
|
|
11567
|
+
sig.make_sig('value', dtype=sig.sig_dtype.T),
|
|
11568
|
+
sig.make_sig('key_cache', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
|
11569
|
+
sig.make_sig('value_cache', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
|
11570
|
+
sig.make_sig('slot_mapping', dtype=sig.sig_dtype.T1),
|
|
11571
|
+
)
|
|
11572
|
+
|
|
11573
|
+
@prim_attr_register
|
|
11574
|
+
def __init__(self):
|
|
11575
|
+
"""Initialize ReshapeAndCache"""
|
|
11576
|
+
self.init_prim_io_names(
|
|
11577
|
+
inputs=['key', 'value', 'key_cache', 'value_cache', 'slot_mapping'],
|
|
11578
|
+
outputs=['key_out'])
|
|
11579
|
+
self.add_prim_attr('side_effect_mem', True)
|
|
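The RmsNorm formula is easier to read as a reference implementation. A NumPy sketch of the standard RMSNorm computation the docstring describes, assuming normalization over the last axis (a reference, not the Ascend kernel):

import numpy as np

def rms_norm_reference(x, gamma, epsilon=1e-6):
    # y_i = x_i / sqrt(mean(x_i^2) + eps) * gamma_i; rstd is the reciprocal std kept for the backward pass.
    rstd = 1.0 / np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + epsilon)
    return x * rstd * gamma, rstd

x = np.random.randn(2, 4).astype(np.float32)
y, rstd = rms_norm_reference(x, np.ones(4, np.float32))   # y has x's shape, rstd has shape (2, 1)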
@@ -335,7 +335,8 @@ def _rank_list_for_transform_parallel_checkpoint(rank_id, src_strategy_list, dst
     return list(result_list)
 
 
-def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, src_strategy_list,
+def _transform_parallel_checkpoint(rank_id, param_total_dict, param_attr_dict, src_strategy_list,
+                                   dst_strategy_list, param_type_dict):
     """
     Transform model parallel dimension for distributed checkpoint files.
     """
|
|
397
398
|
transform_tensor = ms.Tensor(param_total_dict[param_name][rank_id % device_num])
|
|
398
399
|
requires_grad = param_attr_dict[param_name][rank_id % device_num][0]
|
|
399
400
|
layerwise_parallel = param_attr_dict[param_name][rank_id % device_num][1]
|
|
400
|
-
|
|
401
|
+
transform_para = ms.Parameter(transform_tensor, param_name, requires_grad, layerwise_parallel)
|
|
402
|
+
if param_type_dict[param_name][rank_id % device_num] == "BFloat16":
|
|
403
|
+
transform_para.set_dtype(ms.bfloat16)
|
|
404
|
+
transform_param_dict[param_name] = transform_para
|
|
401
405
|
|
|
402
406
|
# Handle those parameter like learning_rate, global_step which not in strategy_file.
|
|
403
407
|
for param_name, _ in param_total_dict.items():
|
|
404
408
|
if param_name not in transform_param_dict:
|
|
405
|
-
|
|
409
|
+
transform_para = ms.Parameter(
|
|
406
410
|
ms.Tensor(param_total_dict[param_name][rank_id % device_num]), param_name,
|
|
407
411
|
param_attr_dict[param_name][rank_id % device_num][0],
|
|
408
412
|
param_attr_dict[param_name][rank_id % device_num][1])
|
|
413
|
+
if param_type_dict[param_name][rank_id % device_num] == "BFloat16":
|
|
414
|
+
transform_para.set_dtype(ms.bfloat16)
|
|
415
|
+
transform_param_dict[param_name] = transform_para
|
|
409
416
|
|
|
410
417
|
transform_param_list = [{"name": param_name, "data": param_data}
|
|
411
418
|
for param_name, param_data in transform_param_dict.items()]
|
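The pattern behind these additions: numpy has no bfloat16, so each bfloat16 parameter is recorded by dtype name, cast to float32 before .asnumpy(), and the rebuilt Parameter gets its original dtype back. A standalone sketch of that round trip (the parameter here is a stand-in for a loaded checkpoint entry):

import numpy as np
import mindspore as ms

param = ms.Parameter(ms.Tensor(np.ones((2, 2), np.float32)), "w")  # stand-in for ckpt_dict[param_name]
recorded_type = str(param.data.dtype)              # the diff stores e.g. "BFloat16" per rank
if param.data.dtype == ms.bfloat16:                # cast first: .asnumpy() cannot hold bfloat16
    param.set_dtype(ms.float32)
arr = param.data.asnumpy()                         # ... sliced / concatenated by the transform ...
rebuilt = ms.Parameter(ms.Tensor(arr), param.name)
if recorded_type == "BFloat16":
    rebuilt.set_dtype(ms.bfloat16)                 # restore the original dtype on the output Parameter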
mindspore/parallel/_tensor.py CHANGED
@@ -17,6 +17,7 @@ from __future__ import division
 from __future__ import absolute_import
 
 import numpy as np
+from mindspore.common import dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore.communication.management import get_rank, get_group_size
 from mindspore._c_expression import TensorTransform
@@ -221,6 +222,8 @@ def _load_tensor(tensor, dev_mat, tensor_map, rank_id=-1):
         rank = rank_id
     tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
     tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
+    if tensor.dtype == mstype.bfloat16:
+        tensor = tensor.float()
     np_tensor = tensor.asnumpy()
     np_tensor_list = _chunk_tensor_by_strategy(np_tensor, tensor_strategy)
     np_tensor_slice = np_tensor_list[int(tensor_slice_index)]
@@ -260,7 +263,7 @@ def _load_tensor_by_layout(tensor, layout, rank_id):
             rank = get_rank(group)
             size = get_group_size(group)
             tensor_slice = np.split(tensor_slice, size)[rank]
-    return Tensor(tensor_slice)
+    return Tensor(tensor_slice, tensor.dtype)
 
 
 def _reshape_param_data(param_data, dev_mat, tensor_map):
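Passing the source tensor's dtype into the Tensor constructor is what keeps a bfloat16 slice from silently becoming float32. A minimal illustration, assuming bfloat16 Tensor creation is available in the build:

import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype

bf16 = Tensor(np.ones((4, 4), np.float32), mstype.bfloat16)  # a bfloat16 source tensor
slice_np = np.split(bf16.float().asnumpy(), 2)[0]            # cast to float32 before leaving MindSpore
print(Tensor(slice_np).dtype)                                # Float32: dtype re-inferred from numpy
print(Tensor(slice_np, bf16.dtype).dtype)                    # BFloat16: dtype carried over explicitly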
@@ -21,6 +21,7 @@ import copy
 from collections import defaultdict
 import numpy as np
 import mindspore as ms
+from mindspore.common import dtype as mstype
 from mindspore.parallel._parallel_serialization import _rank_list_for_transform_parallel_checkpoint, \
     _transform_parallel_checkpoint, _get_device_num_from_strategy, _make_dir, \
     _extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
@@ -192,6 +193,7 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_
             raise ValueError("Checkpoint file {} in rank {} not exits: ".format(local_file, rank))
     param_total_dict = defaultdict(dict)
     param_attr_dict = defaultdict(dict)
+    param_type_dict = defaultdict(dict)
     src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(rank_id, src_strategy_file, dst_strategy_file)
     # src rank => local rank inside pipeline stage
     src_stage_device_num = np.prod(src_strategy_list.get(list(src_strategy_list.keys())[0])[0]) if src_strategy_list \
@@ -208,11 +210,15 @@ def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_
                 and _parameter_not_in_local_stage(param_name, origin_dst_strategy_list, dst_strategy_list):
             continue
         src_rank = rank % src_stage_device_num
+        param_type_dict[param_name][src_rank] = str(param.data.dtype)
+        if param.data.dtype == mstype.bfloat16:
+            param.set_dtype(mstype.float32)
         param_total_dict[param_name][src_rank] = param.data.asnumpy()
         param_attr_dict[param_name][src_rank] = (param.requires_grad, param.layerwise_parallel)
     local_rank_id = rank_id % dst_stage_device_num
     transform_param_list = _transform_parallel_checkpoint(local_rank_id, param_total_dict,
-                                                          param_attr_dict, src_strategy_list, dst_strategy_list
+                                                          param_attr_dict, src_strategy_list, dst_strategy_list,
+                                                          param_type_dict)
     ms.save_checkpoint(transform_param_list, save_checkpoint_file_name)
 
 
@@ -297,11 +303,15 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
     for needed_rank_list_key, transform_rank_list in needed_rank_list_map.items():
         param_total_dict = defaultdict(dict)
         param_attr_dict = defaultdict(dict)
+        param_type_dict = defaultdict(dict)
         needed_rank_list = needed_rank_list_key.split("-")
         for needed_rank in needed_rank_list:
             ckpt_dict = ms.load_checkpoint(all_checkpoint_files_map.get(int(needed_rank)))
             for param_name, param in ckpt_dict.items():
                 src_rank = int(needed_rank) % src_stage_device_num
+                param_type_dict[param_name][src_rank] = str(param.data.dtype)
+                if param.data.dtype == mstype.bfloat16:
+                    param.set_dtype(mstype.float32)
                 param_total_dict[param_name][src_rank] = param.data.asnumpy()
                 param_attr_dict[param_name][src_rank] = (param.requires_grad, param.layerwise_parallel)
         for transform_rank in transform_rank_list:
@@ -316,7 +326,8 @@ def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix,
 
             local_rank_id = transform_rank % dst_stage_device_num
             transform_param_list = _transform_parallel_checkpoint(local_rank_id, param_total_dict_copy,
-                                                                   param_attr_dict, src_strategy_list, dst_strategy_list
+                                                                   param_attr_dict, src_strategy_list, dst_strategy_list,
+                                                                   param_type_dict)
             save_checkpoint_file = "{}{}.ckpt".format(ckpt_prefix, transform_rank)
             save_checkpoint_file_dir = os.path.join(dst_checkpoints_dir, "rank_{}".format(transform_rank))
             if not os.path.exists(save_checkpoint_file_dir):
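For callers nothing changes; bfloat16 checkpoints are simply handled now. A typical invocation of the updated entry point, with placeholder paths (the function is also exposed at the package top level as ms.transform_checkpoints):

import mindspore as ms

ms.transform_checkpoints(src_checkpoints_dir="./src_checkpoints",
                         dst_checkpoints_dir="./dst_checkpoints",
                         ckpt_prefix="transformed",
                         src_strategy_file="./src_strategy.ckpt",
                         dst_strategy_file="./dst_strategy.ckpt")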
mindspore/parallel/shard.py CHANGED
@@ -36,13 +36,17 @@ class Shard(Shard_):
     def __call__(self, fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
         if ms.context.get_context("mode") != ms.context.PYNATIVE_MODE or \
                 ms.context.get_auto_parallel_context("parallel_mode") not in ["auto_parallel"]:
-            raise AssertionError(
+            raise AssertionError(
+                f"Cell shard only supports auto parallel under PyNative mode.")
         if ms.context.get_context("device_target") not in ["Ascend", "GPU"]:
-            raise AssertionError(
+            raise AssertionError(
+                f"'Shard' now only supports 'Ascend' and 'GPU'")
         if ms.context.get_auto_parallel_context("search_mode") != "sharding_propagation":
-            raise AssertionError(
+            raise AssertionError(
+                f"'search_mode' must be 'sharding_propagation' for 'Shard'")
         if not isinstance(in_strategy, tuple):
-            raise TypeError(
+            raise TypeError(
+                f"For 'Shard', the 'in_strategy' should be a tuple, but got {type(in_strategy).__name__}")
         if not isinstance(out_strategy, (type(None), tuple)):
             raise TypeError(f"For 'Shard', the 'out_strategy' should be None or tuple, "
                             f"but got {type(out_strategy).__name__}")
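These hunks mainly reflow the error messages onto two lines, but they also spell out the preconditions for ms.shard. A configuration sketch of what must be set before calling it (device_target assumed to match the hardware):

import mindspore as ms

ms.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend")   # or "GPU"
ms.set_auto_parallel_context(parallel_mode="auto_parallel",
                             search_mode="sharding_propagation")
# Only with this combination do the checks in Shard.__call__ pass.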
@@ -117,7 +121,8 @@ class Shard(Shard_):
             return
         if isinstance(parameter_plan, dict):
             if not isinstance(fn, ms.nn.Cell):
-                raise TypeError(
+                raise TypeError(
+                    f"If parameter_plan is set, type of fn must be mindspore.nn.Cell, but got {type(fn)}")
             for k in parameter_plan.keys():
                 v = parameter_plan[k]
                 if not isinstance(k, str) or not isinstance(v, tuple):
@@ -131,10 +136,12 @@ class Shard(Shard_):
             param_strategy = parameter_plan[param_name]
             param = self._search_parameter_by_name(param_name, fn)
             if param is None:
-                logger.warning(
+                logger.warning(
+                    f"{param_name} is not exist, ignored its setting.")
                 continue
 
-            self._check_layout_is_valid(
+            self._check_layout_is_valid(
+                param_name, param.shape, param_strategy)
             if param.param_info.param_strategy:
                 logger.warning(f"The layout of parameter '{param_name}' "
                                f"has been set to {param.param_info.param_strategy}, "
|
|
|
143
150
|
|
|
144
151
|
def _is_attrs_has_been_set(self, fn, in_strategy, out_strategy, device, level):
|
|
145
152
|
return self.shard_fn is not None and self.fn == fn and self.in_strategy == in_strategy and \
|
|
146
|
-
|
|
153
|
+
self.out_strategy == out_strategy and self.device == device and self.level == level
|
|
147
154
|
|
|
148
155
|
|
|
149
156
|
def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
|
|
@@ -216,8 +223,8 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
         ...                          device_num=2)
         >>> def test_shard(x, y):
         ...     return x + y
-        >>> x = Tensor(np.ones(shape=(32, 10)))
-        >>> y = Tensor(np.ones(shape=(32, 10)))
+        >>> x = Tensor(np.ones(shape=(32, 10)), dtype=ms.float32)
+        >>> y = Tensor(np.ones(shape=(32, 10)), dtype=ms.float32)
         >>> output = ms.shard(test_shard, in_strategy=((2, 1), (2, 1)))(x, y)
         >>> print(output.shape)
         (32, 10)