mindspore 2.2.10__cp37-none-any.whl → 2.2.14__cp37-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +2 -1
- mindspore/_akg/akg/composite/build_module.py +95 -5
- mindspore/_akg/akg/topi/cpp/impl.py +1 -1
- mindspore/_akg/akg/tvm/_ffi/base.py +1 -1
- mindspore/_akg/akg/utils/composite_op_helper.py +7 -2
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/util.py +18 -1
- mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_extends/parse/__init__.py +3 -2
- mindspore/_extends/parse/parser.py +6 -1
- mindspore/_extends/parse/standard_method.py +12 -2
- mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/common/_utils.py +16 -0
- mindspore/common/tensor.py +0 -2
- mindspore/communication/management.py +3 -0
- mindspore/context.py +34 -4
- mindspore/dataset/engine/cache_client.py +8 -5
- mindspore/dataset/engine/datasets.py +23 -0
- mindspore/dataset/engine/validators.py +1 -1
- mindspore/dataset/vision/py_transforms_util.py +2 -2
- mindspore/experimental/optim/lr_scheduler.py +5 -6
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +118 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +49 -57
- mindspore/mindrecord/tools/cifar10_to_mr.py +46 -55
- mindspore/mindrecord/tools/csv_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +4 -9
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -4
- mindspore/nn/layer/activation.py +1 -1
- mindspore/nn/layer/embedding.py +2 -2
- mindspore/nn/layer/flash_attention.py +48 -135
- mindspore/nn/loss/loss.py +1 -1
- mindspore/nn/optim/ada_grad.py +2 -2
- mindspore/nn/optim/sgd.py +3 -2
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +6 -3
- mindspore/numpy/math_ops.py +1 -1
- mindspore/ops/__init__.py +3 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -31
- mindspore/ops/_grad_experimental/grad_comm_ops.py +4 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +8 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +37 -17
- mindspore/ops/_op_impl/aicpu/__init__.py +1 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/function/array_func.py +6 -5
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/linalg_func.py +21 -11
- mindspore/ops/function/math_func.py +3 -0
- mindspore/ops/function/nn_func.py +13 -11
- mindspore/ops/function/parameter_func.py +2 -0
- mindspore/ops/function/sparse_unary_func.py +2 -2
- mindspore/ops/function/vmap_func.py +1 -0
- mindspore/ops/operations/__init__.py +5 -2
- mindspore/ops/operations/_embedding_cache_ops.py +1 -1
- mindspore/ops/operations/_grad_ops.py +3 -4
- mindspore/ops/operations/_inner_ops.py +56 -1
- mindspore/ops/operations/_quant_ops.py +4 -4
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +15 -4
- mindspore/ops/operations/custom_ops.py +1 -1
- mindspore/ops/operations/debug_ops.py +1 -1
- mindspore/ops/operations/image_ops.py +3 -3
- mindspore/ops/operations/inner_ops.py +49 -0
- mindspore/ops/operations/math_ops.py +65 -3
- mindspore/ops/operations/nn_ops.py +95 -28
- mindspore/ops/operations/random_ops.py +2 -0
- mindspore/ops/operations/sparse_ops.py +4 -4
- mindspore/ops/silent_check.py +162 -0
- mindspore/parallel/__init__.py +3 -2
- mindspore/parallel/_auto_parallel_context.py +82 -3
- mindspore/parallel/_parallel_serialization.py +34 -2
- mindspore/parallel/_tensor.py +3 -1
- mindspore/parallel/_transformer/transformer.py +8 -8
- mindspore/parallel/checkpoint_transform.py +191 -45
- mindspore/profiler/parser/ascend_cluster_generator.py +111 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +315 -0
- mindspore/profiler/parser/ascend_flops_generator.py +8 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +8 -2
- mindspore/profiler/parser/ascend_hccl_generator.py +2 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +30 -6
- mindspore/profiler/parser/ascend_msprof_generator.py +16 -5
- mindspore/profiler/parser/ascend_op_generator.py +15 -7
- mindspore/profiler/parser/ascend_timeline_generator.py +5 -2
- mindspore/profiler/parser/base_timeline_generator.py +11 -3
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -1
- mindspore/profiler/parser/framework_parser.py +8 -2
- mindspore/profiler/parser/memory_usage_parser.py +8 -2
- mindspore/profiler/parser/minddata_analyzer.py +8 -2
- mindspore/profiler/parser/minddata_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_analyzer.py +4 -2
- mindspore/profiler/parser/msadvisor_parser.py +9 -3
- mindspore/profiler/profiling.py +97 -25
- mindspore/rewrite/api/node.py +1 -1
- mindspore/rewrite/api/symbol_tree.py +2 -2
- mindspore/rewrite/parsers/for_parser.py +6 -6
- mindspore/rewrite/parsers/module_parser.py +4 -4
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/train/callback/_checkpoint.py +8 -8
- mindspore/train/callback/_landscape.py +2 -3
- mindspore/train/callback/_summary_collector.py +6 -7
- mindspore/train/dataset_helper.py +6 -0
- mindspore/train/model.py +17 -5
- mindspore/train/serialization.py +6 -1
- mindspore/train/summary/_writer_pool.py +1 -1
- mindspore/train/summary/summary_record.py +5 -6
- mindspore/version.py +1 -1
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/METADATA +3 -2
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/RECORD +141 -149
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
- mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/WHEEL +0 -0
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/entry_points.txt +0 -0
- {mindspore-2.2.10.dist-info → mindspore-2.2.14.dist-info}/top_level.txt +0 -0
mindspore/ops/operations/math_ops.py CHANGED
@@ -123,6 +123,64 @@ class _MathBinaryOp(_BinaryOp):
         real_shape = [dim if cmp_dim > 0 else cmp_dim for dim, cmp_dim in zip(shape_value, cmp_shape)]
         return tuple(real_shape)

+class SilentCheck(Primitive):
+    """
+    Implement SilentCheck on `pre_val`, `min_val`, `max_val`, `result` and
+    update them inplace with given parameters.
+
+    Args:
+        c_min_steps (int): an int determines...
+
+        c_thresh_l1 (float): a float determines...
+
+        c_coeff_l1 (float): a float determines...
+
+        c_thresh_l2 (float): a float determines...
+
+        c_coeff_l2 (float): a float determines...
+
+    Inputs:
+        - **val** (Tensor) - Tensor with dtype float32.
+        - **input_grad** (Parameter) - Tensor with dtype float32.
+        - **pre_val** (Parameter) - Input Parameter with dtype float32.
+        - **min_val** (Parameter) - Input Parameter with dtype float32.
+        - **max_val** (Parameter) - Input Parameter with dtype float32.
+        - **val_counter** (Parameter) - Input Parameter with dtype int32.
+
+    Outputs:
+        Tuple of 5 Tensors, the updated parameters.
+        - **input_grad** (Tensor) - Tensor with dtype float32.
+        - **pre_val** (Tensor) - Tensor with dtype float32.
+        - **min_val** (Tensor) - Tensor with dtype float32.
+        - **max_val** (Tensor) - Tensor with dtype float32.
+        - **result** (Tensor) - Tensor with dtype int32.
+
+    Raises:
+        TypeError: If `val` is not Tensor with dtype float32.
+        TypeError: If `result` is not Tensor with dtype int32.
+        TypeError: If `pre_val`, `min_val`, `max_val`, `input_grad` are not all Parameter type with dtype float32.
+        TypeError: If `c_thresh_l1` or `c_coeff_l1` is not a float number.
+        TypeError: If `c_min_steps` is not an int number.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore.ops.operations.math_ops import SilentCheck
+        >>> silent_check = SilentCheck()
+        xxx
+    """
+
+    @prim_attr_register
+    def __init__(self, c_min_steps, c_thresh_l1, c_coeff_l1, c_thresh_l2, c_coeff_l2):
+        """Initialize SilentCheck."""
+        validator.check_value_type("c_min_steps", c_min_steps, [int], self.name)
+        validator.check_value_type("c_thresh_l1", c_thresh_l1, [float], self.name)
+        validator.check_value_type("c_coeff_l1", c_coeff_l1, [float], self.name)
+        validator.check_value_type("c_thresh_l2", c_thresh_l2, [float], self.name)
+        validator.check_value_type("c_coeff_l2", c_coeff_l2, [float], self.name)
+        self.add_prim_attr('side_effect_mem', True)
+

 class _BitwiseBinaryOp(_MathBinaryOp):
     """
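The `SilentCheck` docstring above still carries placeholder text (`an int determines...`, example body `xxx`), so the snippet below is only a construction sketch inferred from the `__init__` signature and the documented Inputs table in this hunk. The attribute values, shapes and the positional call are illustrative assumptions, and actually executing the primitive requires an Ascend device.

```python
# Illustrative sketch only; values and shapes are assumptions, not release defaults.
import numpy as np
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype
from mindspore.ops.operations.math_ops import SilentCheck

# (c_min_steps, c_thresh_l1, c_coeff_l1, c_thresh_l2, c_coeff_l2)
silent_check = SilentCheck(100, 1000.0, 0.1, 10000.0, 0.01)

# Documented inputs, in order: val, input_grad, pre_val, min_val, max_val, val_counter.
val = Tensor(np.array(1.0), mstype.float32)
input_grad = Parameter(Tensor(np.zeros(4), mstype.float32), name="input_grad")
pre_val = Parameter(Tensor(0.0, mstype.float32), name="pre_val")
min_val = Parameter(Tensor(0.0, mstype.float32), name="min_val")
max_val = Parameter(Tensor(0.0, mstype.float32), name="max_val")
val_counter = Parameter(Tensor(0, mstype.int32), name="val_counter")

# Assumed positional call matching the documented input order (Ascend only).
outputs = silent_check(val, input_grad, pre_val, min_val, max_val, val_counter)
```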
@@ -462,6 +520,7 @@ class AssignAdd(Primitive):
         >>> import mindspore
         >>> import numpy as np
         >>> from mindspore import Tensor, ops, nn
+        >>> from mindspore.common.initializer import initializer
         >>> class Net(nn.Cell):
         ...     def __init__(self):
         ...         super(Net, self).__init__()
@@ -512,6 +571,7 @@ class AssignSub(Primitive):
         >>> import mindspore
         >>> import numpy as np
         >>> from mindspore import Tensor, ops, nn
+        >>> from mindspore.common.initializer import initializer
         >>> class Net(nn.Cell):
         ...     def __init__(self):
         ...         super(Net, self).__init__()
@@ -6569,9 +6629,9 @@ class LinSpace(Primitive):

     Inputs:
         - **start** (Tensor) - Start value of interval, 0-D Tensor with dtype float32 or float64.
-        - **stop** (Tensor) - Last value of interval, 0-D Tensor with dtype
-        - **num** (int) - Number of ticks in the interval, inclusive of `start` and `stop`.
-
+        - **stop** (Tensor) - Last value of interval, 0-D Tensor with dtype float32 or float64.
+        - **num** (Union[int, Tensor]) - Number of ticks in the interval, inclusive of `start` and `stop`.
+          Must be a positive integer. When the input is Tensor, it must be a 0-D Tensor with dtype int32 or int64.

     Outputs:
         Tensor, has the same shape and dtype as `start`.
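The LinSpace hunk above now documents `num` as `Union[int, Tensor]`. A minimal sketch of both call forms, assuming the usual `ops.LinSpace()(start, stop, num)` calling convention and arbitrary endpoint values:

```python
import mindspore
from mindspore import Tensor, ops

linspace = ops.LinSpace()
start = Tensor(1.0, mindspore.float32)
stop = Tensor(10.0, mindspore.float32)

# `num` as a plain positive int (existing behaviour).
print(linspace(start, stop, 5))
# `num` as a 0-D int32 Tensor, the form documented in this release.
print(linspace(start, stop, Tensor(5, mindspore.int32)))
```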
@@ -7253,6 +7313,7 @@ class Igamma(Primitive):

     Examples:
         >>> import numpy as np
+        >>> import mindspore
         >>> from mindspore import Tensor, ops
         >>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
         >>> x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
@@ -7291,6 +7352,7 @@ class Igammac(Primitive):
         ``Ascend`` ``GPU`` ``CPU``

     Examples:
+        >>> import mindspore
         >>> import numpy as np
         >>> from mindspore import Tensor, ops
         >>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
mindspore/ops/operations/nn_ops.py CHANGED
@@ -3777,7 +3777,7 @@ class LayerNorm(Primitive):
         - **output_x** (Tensor) - The normalized input, has the same type and shape as the `input_x`.
         - **mean** (Tensor) - The first `begin_norm_axis` dimensions of `mean` shape is the same as `input_x`,
           and the remaining dimensions are 1. Suppose the shape of the `input_x` is :math:`(x_1, x_2, \ldots, x_R)`,
-          the shape of the `mean` is :math:`(x_1, \ldots, x_{
+          the shape of the `mean` is :math:`(x_1, \ldots, x_{begin\_params\_axis}, 1, \ldots, 1)`
           (when `begin_params_axis=0`, the shape of `mean` is :math:`(1, \ldots, 1)` ).
         - **variance** (Tensor) - Shape is the same as `mean` .

@@ -4917,6 +4917,7 @@ class Adam(Primitive):
         >>> import mindspore
         >>> import numpy as np
         >>> from mindspore import Tensor, nn, ops
+        >>> from mindspore import Parameter
         >>> class Net(nn.Cell):
         ...     def __init__(self):
         ...         super(Net, self).__init__()
@@ -9991,6 +9992,9 @@ class FractionalMaxPool3DWithFixedKsize(Primitive):
         ``Ascend`` ``GPU`` ``CPU``

     Examples:
+        >>> import numpy as np
+        >>> from mindspore import Tensor, ops
+        >>> from mindspore import dtype as mstype
         >>> x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16])
         ...            .reshape([1, 1, 2, 2, 4]), mstype.float32)
         >>> random_samples = Tensor(np.array([0.7, 0.7, 0.7]).reshape([1, 1, 3]), mstype.float32)
@@ -11363,7 +11367,7 @@ class PromptFlashAttention(Primitive):
           For each element, 0 indicates retention and 1 indicates discard. Input tensor of shape :math:`(B, 1, S, S)`.
         - **actual_seq_lengths** (Tensor): Describe actual sequence length of each input with data type of int.
         - **actual_seq_lengths_kv** (Tensor): Describe actual sequence length of each input with data type of int.
-        - **
+        - **pse_shift** (Tensor) - The position encoding tensor with data type of float16 or float32.
         - **dep_scale1** (Tensor)
         - **quant_scale1** (Tensor)
         - **deq_scale2** (Tensor)
@@ -11406,7 +11410,7 @@ class PromptFlashAttention(Primitive):
         validator.check_value_type('num_key_value_heads', num_key_value_heads, [int], self.name)
         validator.check_value_type('sparse_mode', sparse_mode, [int], self.name)
         self.init_prim_io_names(inputs=["query", "key", "value", "attn_mask", "actual_seq_lengths",
-                                        "actual_seq_lengths_kv", "
+                                        "actual_seq_lengths_kv", "pse_shift", "deq_scale1", "quant_scale1",
                                         "deq_scale2", "quant_scale2", "quant_offset2"],
                                 outputs=["attention_out"])

@@ -11417,41 +11421,50 @@ class FlashAttentionScore(Primitive):
     .. warning::
         This is an experimental API that is subject to change or deletion.
     B -- Batch size
-
-
-
-
+    S1 -- Sequence length of query
+    S2 -- Sequence length of key and value
+    N1 -- Num heads of query
+    N2 -- Num heads of key and value, and N2 must be a factor of N1
+    D -- head size
+    H1 -- Hidden size of query, which equals to N1 * D
+    H2 -- Hidden size of key and value, which equals to N2 * D
     Args:
-        head_num (int): The
+        head_num (int): The head num of query.
         keep_prob (float): The keep probability of dropout. Default: 1.0.
         scale_value (float): The scale value. Default: 1.0.
         pre_tokens (int): Previous tokens. Default: 65536.
         next_tokens (int): Next tokens. Default: 65536.
         inner_precise (int): Specify the execution mode, where 0 indicates high precision mode and 1 indicates high
-            performance mode. Default: 0.
+            performance mode. Only support 0 currently. Default: 0.
         input_layout (str, optional): Specifies the layout of `query`, the value must be one of ["BSH", "BNSD"].
             Default: "BSH".
         sparse_mode (int): Default 0.

     Inputs:
-        - **query** (Tensor) - The query tensor
-          Input tensor of shape :math:`(B,
-        - **key** (Tensor) - The key tensor
-          Input tensor of shape :math:`(B,
-        - **value** (Tensor) - The value tensor
-          Input tensor of shape :math:`(B,
-        - **
-
-        - **drop_mask** (Tensor) - The dropout mask tensor
-          Input tensor of shape :math:`(B,
-        - **real_shift** (None) - The position embedding code of float16 or float32, not implemented yet.
+        - **query** (Tensor[float16, float32, bfloat16]) - The query tensor.
+          Input tensor of shape :math:`(B, S1, H1)` or `(B, N1, S1, D)`.
+        - **key** (Tensor[float16, float32, bfloat16]) - The key tensor.
+          Input tensor of shape :math:`(B, S2, H2)` or `(B, N2, S2, D)`.
+        - **value** (Tensor[float16, float32, bfloat16]) - The value tensor.
+          Input tensor of shape :math:`(B, S2, H2)` or `(B, N2, S2, D)`.
+        - **real_shift** (Tensor[float16, float32, bfloat16], None) - The position embedding code.
+          Input tensor of shape :math: `(B, N1, S1, S2)` or `(B, N1, 1, S2)`.
+        - **drop_mask** (Tensor[uint8], None) - The dropout mask tensor.
+          Input tensor of shape :math:`(B, N1, S1, S2 // 8) or None`.
         - **padding_mask** (None) - The padding mask of float16 or float32, not implemented yet.
-        - **
+        - **attn_mask** (Tensor[uint8], None) - The attention mask tensor.
+          For each element, 0 indicates retention and 1 indicates discard.
+          Input tensor of shape :math:`(B, N1, S1, S2)`, `(B, 1, S1, S2)` or `(S1, S2)`.
+        - **prefix** (Tensor[int64], None) - Not implemented yet.
+          Input tensor of shape :math:`(B,)`.

     Outputs:
-        - **
-        - **
-        - **
+        - **softmax_max** (Tensor[float32]) - (B, N1, S1, 8)
+        - **softmax_sum** (Tensor[float32]) - (B, N1, S1, 8)
+        - **softmax_out** (Tensor[float32]) - Useless output, ignore it. Output tensor of shape : `()`
+        - **attention_out** (Tensor[float16, float32, bfloat16]) - The output of attention, its shape, and data type
+          are the same as the query.
+
     Supported Platforms:
         ``Ascend``
     """
@@ -11469,14 +11482,14 @@ class FlashAttentionScore(Primitive):
         validator.check_value_type('next_tokens', next_tokens, [int], self.name)
         validator.check_value_type('inner_precise', inner_precise, [int], self.name)
         validator.check_value_type('sparse_mode', sparse_mode, [int], self.name)
-        if inner_precise not in [0
-        raise ValueError(f"Attribute 'inner_precise' must be
+        if inner_precise not in [0]:
+            raise ValueError(f"Attribute 'inner_precise' must be 0, but got {inner_precise}")
         validator.check_value_type('input_layout', input_layout, [str], self.name)
         if input_layout not in ["BSH", "BNSD"]:
             raise ValueError(f"Attribute 'input_layout' must be either 'BSH' or 'BNSD', but got {input_layout}")
         self.init_prim_io_names(
-            inputs=['query', 'key', 'value', '
-            outputs=['
+            inputs=['query', 'key', 'value', 'real_shift', 'drop_mask', 'padding_mask', 'attn_mask', 'prefix'],
+            outputs=['softmax_max', 'softmax_sum', 'softmax_out', 'attention_out'])


 class RmsNorm(Primitive):
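Read together, the two FlashAttentionScore hunks pin down the input/output names and the "BSH"/"BNSD" layouts. The sketch below only shows how the documented shapes fit together for the "BSH" layout; the concrete dimensions are arbitrary assumptions, the optional inputs are passed as None following the input list above, and the import path mirrors the surrounding nn_ops.py hunks. Executing it requires an Ascend back end that ships the kernel.

```python
# Shape sketch for input_layout="BSH"; B, S1, S2, N1, D are arbitrary.
import numpy as np
from mindspore import Tensor
from mindspore.ops.operations.nn_ops import FlashAttentionScore

B, S1, S2, N1, D = 1, 128, 128, 4, 32
H1 = H2 = N1 * D  # single head-group case: N2 == N1, so H2 == H1

fa = FlashAttentionScore(head_num=N1, input_layout="BSH")
query = Tensor(np.random.randn(B, S1, H1).astype(np.float16))
key = Tensor(np.random.randn(B, S2, H2).astype(np.float16))
value = Tensor(np.random.randn(B, S2, H2).astype(np.float16))
attn_mask = Tensor(np.zeros((S1, S2), np.uint8))  # 0 keeps, 1 discards

# Input order from init_prim_io_names:
# query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix
softmax_max, softmax_sum, softmax_out, attention_out = fa(
    query, key, value, None, None, None, attn_mask, None)
```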
@@ -11514,3 +11527,57 @@ class RmsNorm(Primitive):
         """Initialize Dense."""
         validator.check_value_type("epsilon", epsilon, [float], self.name)
         self.init_prim_io_names(inputs=['x', 'gamma'], outputs=["y", "rstd"])
+
+
+class PagedAttention(Primitive):
+    r"""
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+    """
+    @prim_attr_register
+    def __init__(self, head_num, scale_value=1.0, kv_head_num=0):
+        """Initialize PagedAttention"""
+        validator.check_value_type('head_num', head_num, [int], self.name)
+        validator.check_value_type('scale_value', scale_value, [float], self.name) # scale after qkbmm
+        validator.check_value_type('kv_head_num', kv_head_num, [int], self.name) # for MQA
+        self.init_prim_io_names(
+            inputs=['query', 'key_cache', 'value_cache', 'block_tables', 'context_lens'],
+            outputs=['attention_out'])
+
+
+class PagedAttentionMask(Primitive):
+    r"""
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+    """
+    @prim_attr_register
+    def __init__(self, head_num, scale_value=1.0, kv_head_num=0):
+        """Initialize PagedAttentionMask"""
+        validator.check_value_type('head_num', head_num, [int], self.name)
+        validator.check_value_type('scale_value', scale_value, [float], self.name) # scale after qkbmm
+        validator.check_value_type('kv_head_num', kv_head_num, [int], self.name) # for MQA
+        self.init_prim_io_names(
+            inputs=['query', 'key_cache', 'value_cache', 'block_tables', 'context_lens', 'alibi_mask'],
+            outputs=['attention_out'])
+
+
+class ReshapeAndCache(Primitive):
+    r"""
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+    """
+    __mindspore_signature__ = (
+        sig.make_sig('key', dtype=sig.sig_dtype.T),
+        sig.make_sig('value', dtype=sig.sig_dtype.T),
+        sig.make_sig('key_cache', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
+        sig.make_sig('value_cache', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
+        sig.make_sig('slot_mapping', dtype=sig.sig_dtype.T1),
+    )
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize ReshapeAndCache"""
+        self.init_prim_io_names(
+            inputs=['key', 'value', 'key_cache', 'value_cache', 'slot_mapping'],
+            outputs=['key_out'])
+        self.add_prim_attr('side_effect_mem', True)
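The three primitives added here ship with no shape documentation yet, only registered input and output names, so the sketch below stays at construction level and simply records those names. The import path and the attribute values are assumptions.

```python
# Construction-only sketch; cache-tensor shapes and dtypes are not documented
# in this release, so none are fabricated here.
from mindspore.ops.operations.nn_ops import PagedAttention, PagedAttentionMask, ReshapeAndCache

# Registered inputs: query, key_cache, value_cache, block_tables, context_lens
paged_attention = PagedAttention(head_num=8, kv_head_num=8)

# Same inputs plus a trailing alibi_mask
paged_attention_mask = PagedAttentionMask(head_num=8, kv_head_num=8)

# In-place cache update: key, value, key_cache, value_cache, slot_mapping -> key_out
reshape_and_cache = ReshapeAndCache()
```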
mindspore/ops/operations/sparse_ops.py CHANGED
@@ -479,8 +479,8 @@ class SparseToDenseV2(Primitive):
         Tensor, converted from sparse tensor. The dtype is same as `values`, and the shape is `output_shape`.

     Raises:
-        TypeError: If the dtype of `indices` is neither
-        TypeError: If the dtype of `outputshape` is neither
+        TypeError: If the dtype of `indices` is neither int32 nor int64.
+        TypeError: If the dtype of `outputshape` is neither int32 nor int64.
         ValueError: If the shape of `output_shape`, shape of `indices`,
             shape of `default_value` and shape of `values` don't meet the parameter description.
         ValueError: If each Element of `output_shape` is not > 0.
@@ -2382,8 +2382,8 @@ class SparseCountSparseOutput(Primitive):
     Args:
         binary_output (bool) - If ``False`` , output the number of occurrences of each value,
             if ``True`` output 1 for orresponding values. Default: ``False`` .
-        minlength(Scalar) -
-        maxlength(Scalar) -
+        minlength(Scalar) - int type minimum value to count, Default: ``-1`` .
+        maxlength(Scalar) - int type maximum value to count, Default: ``-1`` .

     Inputs:
         - **indices** (Tensor) - Tensor representing the position of the element in the sparse
mindspore/ops/silent_check.py ADDED
@@ -0,0 +1,162 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Silent Check."""
+import os
+
+from mindspore.common.tensor import Tensor
+from mindspore.common.parameter import Parameter
+import mindspore.common.dtype as mstype
+
+from . import operations
+from .operations._inner_ops import _MirrorSilentCheck
+from .operations import RmsNorm as OriginRmsNorm
+from .operations import LayerNorm as OriginLayerNorm
+from .primitive import Primitive
+
+
+NPU_ASD_ENABLE = 'NPU_ASD_ENABLE'
+
+
+class ASDBase:
+    """
+    ASDBase is the base class of operator with accuracy-sensitive detection feature in python.
+
+    Args:
+        cls (Primitive): Original operator requiring accuracy-sensitive detection feature.
+        args (tuple): A variable parameter tuple to the original operator.
+        kwargs (dict): A variable parameter dictionary passed the original operator.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore.ops.silent_check import ASDBase
+        >>> from mindspore.ops import LayerNorm as OriginLayerNorm
+        >>> class LayerNormASD(ASDBase):
+        ...     def __init__(self, *args, **kwargs):
+        ...         super().__init__(OriginLayerNorm, *args, **kwargs)
+        ...         # init parameters for accuracy-sensitive detection by calling the base class method generate_params()
+        ...         self.pre_val, self.min_val, self.max_val, self.cnt = self.generate_params()
+        ...
+        ...     def __call__(self, input_x, gamma, beta):
+        ...         if self.enable_check:
+        ...             # execute accuracy-sensitive detection by calling the check_op of base class
+        ...             input_x = self.check_op(
+        ...                 input_x, self.pre_val, self.min_val, self.max_val, self.cnt, None)
+        ...             self.cnt += 1
+        ...         # return the result of original operator
+        ...         return self.op(input_x, gamma, beta)
+    """
+    _index = 0
+    __ms_class__ = True
+
+    def __init__(self, cls, *args, **kwargs):
+        self.op = cls(*args, **kwargs)
+        self.check_op = _MirrorSilentCheck()
+        self._suffix = "ASD_" + cls.__name__
+        primitive_attr = dir(Primitive)
+        self._op_attr_dict = {
+            name for name in primitive_attr if not name.startswith("_")}
+        self.enable_check = os.environ.get(NPU_ASD_ENABLE) == "1"
+
+    def __getattr__(self, name):
+        def method_wrapper(*args, **kwargs):
+            out = getattr(self.op, name)(*args, **kwargs)
+            if out is self.op:
+                return self
+            return out
+
+        if name in self._op_attr_dict:
+            if callable(getattr(self.op, name)):
+                return method_wrapper
+            if hasattr(self.op, name):
+                return getattr(self.op, name)
+        return super().__getattr__(self, name)
+
+    def __repr__(self):
+        return self.op.__repr__()
+
+    def generate_params(self):
+        """
+        Generate support params for accuracy-sensitive detection.
+
+        Returns:
+            tuple consisting of four elements.
+            The derived class initializes the parameters required for accuracy-sensitive detection by calling
+            this function.
+
+        Examples:
+            >>> from mindspore.ops.silent_check import ASDBase
+            >>> from mindspore.ops import LayerNorm as OriginLayerNorm
+            >>> class LayerNormASD(ASDBase):
+            ...     def __init__(self, *args, **kwargs):
+            ...         super().__init__(OriginLayerNorm, *args, **kwargs)
+            ...         # init parameters for accuracy-sensitive detection by calling the base class function
+            ...         self.pre_val, self.min_val, self.max_val, self.cnt = self.generate_params()
+        """
+        pre_val = Parameter(Tensor(0, mstype.float32),
+                            name=f"{self._suffix}_pre_val_{self._index}",
+                            requires_grad=False)
+        min_val = Parameter(Tensor(0, mstype.float32),
+                            name=f"{self._suffix}_min_val_{self._index}",
+                            requires_grad=False)
+        max_val = Parameter(Tensor(0, mstype.float32),
+                            name=f"{self._suffix}_max_val_{self._index}",
+                            requires_grad=False)
+        cnt = Parameter(Tensor(0, mstype.int32),
+                        name=f"{self._suffix}_cnt_{self._index}",
+                        requires_grad=False)
+        ASDBase._index += 1
+        return pre_val, min_val, max_val, cnt
+
+
+class RmsNormASD(ASDBase):
+    """
+    RmsNorm with ASD.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(OriginRmsNorm, *args, **kwargs)
+        self.pre_val, self.min_val, self.max_val, self.cnt = self.generate_params()
+
+    def __call__(self, input_x, gamma):
+        if self.enable_check:
+            input_x = self.check_op(
+                input_x, self.pre_val, self.min_val, self.max_val, self.cnt, None)
+            self.cnt += 1
+        return self.op(input_x, gamma)
+
+
+class LayerNormASD(ASDBase):
+    """
+    LayerNorm with ASD.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(OriginLayerNorm, *args, **kwargs)
+        self.pre_val, self.min_val, self.max_val, self.cnt = self.generate_params()
+
+    def __call__(self, input_x, gamma, beta):
+        if self.enable_check:
+            input_x = self.check_op(
+                input_x, self.pre_val, self.min_val, self.max_val, self.cnt, None)
+            self.cnt += 1
+        return self.op(input_x, gamma, beta)
+
+
+def _silent_check():
+    if os.environ.get(NPU_ASD_ENABLE) == "1":
+        operations.LayerNorm = LayerNormASD
+        operations.RmsNorm = RmsNormASD
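The new module is gated on the NPU_ASD_ENABLE environment variable, which is read once in ASDBase.__init__ and again in _silent_check(), and the ASD wrappers forward their constructor arguments to the wrapped primitive (`self.op = cls(*args, **kwargs)`). A hedged sketch of how the wrappers could be enabled and instantiated, assuming LayerNorm and RmsNorm keep their usual constructor arguments:

```python
import os
# Must be set before the ASD-wrapped operators are instantiated, because
# ASDBase reads NPU_ASD_ENABLE once at construction time.
os.environ["NPU_ASD_ENABLE"] = "1"

from mindspore.ops.silent_check import LayerNormASD, RmsNormASD

# Constructor arguments are forwarded to the wrapped LayerNorm / RmsNorm primitives.
layer_norm = LayerNormASD(begin_norm_axis=1, begin_params_axis=1)
rms_norm = RmsNormASD(epsilon=1e-6)
# Each __call__ now runs _MirrorSilentCheck on the first input before
# delegating to the original primitive (Ascend only).
```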
mindspore/parallel/__init__.py CHANGED
@@ -18,8 +18,9 @@ from __future__ import absolute_import
 from mindspore.parallel.algo_parameter_config import get_algo_parameters, reset_algo_parameters, \
     set_algo_parameters
 from mindspore.parallel.checkpoint_transform import rank_list_for_transform, transform_checkpoint_by_rank, \
-    transform_checkpoints, merge_pipeline_strategys
+    transform_checkpoints, merge_pipeline_strategys, load_segmented_checkpoints
 from mindspore.parallel.shard import shard

 __all__ = ["set_algo_parameters", "reset_algo_parameters", "get_algo_parameters", "rank_list_for_transform",
-           "transform_checkpoint_by_rank", "transform_checkpoints", "merge_pipeline_strategys", "shard"
+           "transform_checkpoint_by_rank", "transform_checkpoints", "merge_pipeline_strategys", "shard",
+           "load_segmented_checkpoints"]
mindspore/parallel/_auto_parallel_context.py CHANGED
@@ -65,6 +65,19 @@ class _ParallelOptimizerConfig:
     OPTIMIZER_WEIGHT_SHARD_SIZE = "optimizer_weight_shard_size"


+class _PipelineConfig:
+    """
+    The key of the Pipeline parallelism.
+    """
+    PIPELINE_INTERLEAVE = "pipeline_interleave"
+    PIPELINE_SCHEDULER = "pipeline_scheduler"
+
+
+class _PipelineScheduler:
+    PIPELINE_1F1B = "1f1b"
+    PIPELINE_GPIPE = "gpipe"
+
+
 class _AutoParallelContext:
     """
     _AutoParallelContext is the environment in which operations are executed
@@ -105,11 +118,11 @@ class _AutoParallelContext:
             device_num (int): The device number.

         Raises:
-            ValueError: If the device num is not
+            ValueError: If the device num is not a positive integer.
         """
         self.check_context_handle()
-        if device_num < 1
-        raise ValueError("The context configuration parameter 'device_num' must be
+        if device_num < 1:
+            raise ValueError("The context configuration parameter 'device_num' must be a positive integer, "
                              "but got the value of device_num : {}.".format(device_num))
         from mindspore.communication._comm_helper import _HCCL_TEST_AVAILABLE
         self._context_handle.set_hccl_test_avaible(_HCCL_TEST_AVAILABLE)
@@ -229,6 +242,16 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_pipeline_stage_split_num()

+    def get_pipeline_interleave(self):
+        """Get pipeline interleave flag"""
+        self.check_context_handle()
+        return self._context_handle.get_pipeline_interleave()
+
+    def get_pipeline_scheduler(self):
+        """Get pipeline scheduler"""
+        self.check_context_handle()
+        return self._context_handle.get_pipeline_scheduler()
+
     def set_pipeline_segments(self, segments):
         """Set the segments of the pipeline"""
         if isinstance(segments, bool) or not isinstance(segments, int):
@@ -782,6 +805,57 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_enable_fold_pipeline()

+    def set_pipeline_config(self, pipeline_config):
+        r"""
+        Set the configuration for pipeline parallelism. The configuration provides more detailed behavior control about
+        parallel training when pipeline parallelism is enabled.
+
+        Args:
+            pipeline_config (dict): The configuration for pipeline parallelism. It supports following keys:
+
+                - pipeline_interleave(bool): Setting true enable interleave scheduler for pipeline parallelism. This
+                  scheduler requires more memory but less bubble.
+                - pipeline_scheduler(string): There are two choices, "1f1b" and "gpipe". default is "1f1b"
+
+                  - 1f1b: It requires less memory and bubble ratio, for it run backward pass when corresponding forward pass
+                    finished.
+                  - gpipe: It requires more memory and bubble ratio, for it run backward pass after all forward pass
+                    finished.
+
+        Raises:
+            TypeError: If the type of `pipeline_config` is not `dict`.
+            ValueError: If the key in `pipeline_config` not in ["pipeline_interleave", "pipeline_scheduler"].
+            ValueError: If pipeline interleave is False, pipeline scheduler is not `1f1b`.
+        """
+        self.check_context_handle()
+
+        if not isinstance(pipeline_config, dict):
+            raise TypeError("For 'set_pipeline_config', the argument 'pipeine_config' "
+                            "must be dict, but got the type : {}.".format(type(pipeline_config)))
+
+        pp_interleave = _PipelineConfig.PIPELINE_INTERLEAVE
+        pp_scheduler = _PipelineConfig.PIPELINE_SCHEDULER
+
+        for config_name in pipeline_config:
+            unknown_config = []
+            if config_name not in [pp_interleave, pp_scheduler]:
+                unknown_config.append(config_name)
+
+            if unknown_config:
+                raise ValueError("Unknown config: {}".format(unknown_config))
+
+        Validator.check_bool(
+            pipeline_config[pp_interleave], pp_interleave, pp_interleave)
+        self._context_handle.set_pipeline_interleave(
+            pipeline_config[pp_interleave])
+
+        Validator.check_string(pipeline_config[pp_scheduler], [_PipelineScheduler.PIPELINE_1F1B,
+                                                               _PipelineScheduler.PIPELINE_GPIPE])
+        if not pipeline_config[pp_interleave] and pipeline_config[pp_scheduler] != _PipelineScheduler.PIPELINE_1F1B:
+            raise ValueError(f"When pipeline_interleave is False, {pp_scheduler} is not supported")
+
+        self._context_handle.set_pipeline_scheduler(pipeline_config[pp_scheduler])
+
     def get_enable_parallel_optimizer(self):
         """Get parallel optimizer flag."""
         self.check_context_handle()
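A later hunk in this file registers "pipeline_config" in `_set_auto_parallel_context_func_map` and the two getters in `_get_auto_parallel_context_func_map`, so the options above should be reachable through the public context API. A hedged sketch using only the keys and values validated by `set_pipeline_config`:

```python
import mindspore as ms

# Keys and accepted values come from _PipelineConfig / _PipelineScheduler above;
# a scheduler other than "1f1b" is rejected unless pipeline_interleave is True.
ms.set_auto_parallel_context(pipeline_config={"pipeline_interleave": True,
                                              "pipeline_scheduler": "1f1b"})

# The matching getters are registered in the hunk below.
print(ms.get_auto_parallel_context("pipeline_interleave"))
print(ms.get_auto_parallel_context("pipeline_scheduler"))
```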
@@ -1068,6 +1142,7 @@ class _AutoParallelContext:
         self.set_enable_all_gather_fusion(openstate)
         self.set_enable_reduce_scatter_fusion(openstate)

+
 def _set_ops_strategy_json_config(type="SAVE", path="", mode="all"):
     """
     Set strategy json configuration.
@@ -1091,6 +1166,7 @@ def _set_ops_strategy_json_config(type="SAVE", path="", mode="all"):
     else:
         raise KeyError("Type must be 'SAVE' or 'LOAD' and mode must be 'all' or 'principal'")

+
 _AUTO_PARALLEL_CONTEXT = None


@@ -1126,6 +1202,7 @@ _set_auto_parallel_context_func_map = {
     "dataset_strategy": auto_parallel_context().set_dataset_strategy,
     "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
     "parallel_optimizer_config": auto_parallel_context().set_parallel_optimizer_config,
+    "pipeline_config": auto_parallel_context().set_pipeline_config,
     "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
     "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices,
     "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
@@ -1143,6 +1220,8 @@ _get_auto_parallel_context_func_map = {
     "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync,
     "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
     "pipeline_stages": auto_parallel_context().get_pipeline_stages,
+    "pipeline_interleave": auto_parallel_context().get_pipeline_interleave,
+    "pipeline_scheduler": auto_parallel_context().get_pipeline_scheduler,
     "parallel_mode": auto_parallel_context().get_parallel_mode,
     "search_mode": auto_parallel_context().get_strategy_search_mode,
     "auto_parallel_search_mode": auto_parallel_context().get_auto_parallel_search_mode,