mindspore 2.7.0__cp311-cp311-win_amd64.whl → 2.7.1__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +8 -1
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +275 -64
- mindspore/common/_utils.py +0 -44
- mindspore/common/api.py +285 -35
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
- mindspore/common/hook_handle.py +60 -0
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/parameter.py +13 -107
- mindspore/common/recompute.py +4 -11
- mindspore/common/tensor.py +16 -169
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +5 -85
- mindspore/dataset/engine/datasets.py +8 -4
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dnnl.dll +0 -0
- mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +1 -1
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/distributed.py +182 -62
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +4 -4
- mindspore/mint/nn/layer/normalization.py +8 -13
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +16 -66
- mindspore/nn/layer/basic.py +49 -1
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +31 -124
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +0 -1
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +0 -1
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/wrap/grad_reducer.py +4 -74
- mindspore/numpy/array_creations.py +2 -2
- mindspore/numpy/fft.py +9 -9
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +0 -5
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
- mindspore/ops/auto_generate/gen_extend_func.py +2 -7
- mindspore/ops/auto_generate/gen_ops_def.py +98 -141
- mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +15 -1
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +1 -0
- mindspore/ops/function/array_func.py +14 -12
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +3 -4
- mindspore/ops/function/math_func.py +45 -54
- mindspore/ops/function/nn_func.py +75 -294
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +2 -0
- mindspore/ops/functional_overload.py +354 -18
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +1 -38
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +1 -0
- mindspore/ops/operations/comm_ops.py +94 -2
- mindspore/ops/operations/custom_ops.py +228 -19
- mindspore/ops/operations/debug_ops.py +27 -29
- mindspore/ops/operations/manually_defined/ops_def.py +27 -306
- mindspore/ops/operations/nn_ops.py +2 -2
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +1 -17
- mindspore/ops/tensor_method.py +72 -3
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +1 -4
- mindspore/parallel/_utils.py +29 -6
- mindspore/parallel/checkpoint_transform.py +18 -2
- mindspore/parallel/cluster/process_entity/_api.py +24 -32
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +117 -16
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +38 -2
- mindspore/profiler/common/path_manager.py +56 -24
- mindspore/profiler/common/profiler_context.py +2 -12
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/experimental_config.py +2 -1
- mindspore/profiler/platform/npu_profiler.py +33 -6
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +3 -2
- mindspore/runtime/executor.py +11 -3
- mindspore/runtime/memory.py +112 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +5 -18
- mindspore/train/amp.py +6 -4
- mindspore/train/callback/_checkpoint.py +0 -9
- mindspore/train/callback/_train_fault_tolerance.py +69 -18
- mindspore/train/data_sink.py +1 -5
- mindspore/train/model.py +38 -211
- mindspore/train/serialization.py +126 -387
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +144 -8
- mindspore/version.py +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/train/memory_profiling_pb2.py +0 -298
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
|
@@ -1,693 +0,0 @@
|
|
|
1
|
-
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
"""
|
|
16
|
-
Note:
|
|
17
|
-
Mixture of Expert (MoE) structure.
|
|
18
|
-
These are experimental APIs that are subject to change or deletion.
|
|
19
|
-
"""
|
|
20
|
-
from __future__ import absolute_import
|
|
21
|
-
from __future__ import division
|
|
22
|
-
|
|
23
|
-
import numpy as np
|
|
24
|
-
|
|
25
|
-
from mindspore.common.tensor import Tensor
|
|
26
|
-
import mindspore.common.dtype as mstype
|
|
27
|
-
import mindspore.communication.management as D
|
|
28
|
-
from mindspore import _checkparam as Validator
|
|
29
|
-
from mindspore.ops import operations as P
|
|
30
|
-
from mindspore.ops import functional as F
|
|
31
|
-
from mindspore.ops.primitive import _primexpr
|
|
32
|
-
from mindspore.nn.cell import Cell
|
|
33
|
-
from mindspore.nn.layer import Dense
|
|
34
|
-
from mindspore.context import ParallelMode
|
|
35
|
-
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
|
|
36
|
-
from mindspore.parallel._transformer.op_parallel_config import default_moeparallel_config
|
|
37
|
-
|
|
38
|
-
# Public names exported by `from <this module> import *`.
__all__ = ["MoEConfig"]
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
class MoEConfig:
    r"""
    The configuration of MoE (Mixture of Expert).

    Args:
        expert_num (int): The number of experts employed. Default: 1.
        capacity_factor (float): The factor is used to indicate how much to expand expert capacity,
            which is >=1.0. Default: 1.1.
        aux_loss_factor (float): The factor is used to indicate how much the load balance loss (produced by the
            router) to be added to the entire model loss, which is < 1.0. Default: 0.05.
        num_experts_chosen (int): The number of experts is chosen by each token and it should not be larger
            than expert_num. Default: 1.
        expert_group_size (int): The number of tokens in each data parallel group. Default: ``None``.
            This parameter is effective only when in AUTO_PARALLEL mode, and NOT SHARDING_PROPAGATION.
        group_wise_a2a (bool): Whether to enable group-wise alltoall communication, which can reduce communication
            time by converting part of inter communication into intra communication. Default: ``False``.
            This parameter is effective only when model parallel > 1 and data_parallel equal to expert parallel.
        comp_comm_parallel (bool): Whether to enable ffn compute and communication parallel, which can reduce pure
            communication time by splitting and overlapping compute and communication. Default: ``False``.
        comp_comm_parallel_degree (int): The split number of compute and communication. The larger the numbers,
            the more overlap there will be but will consume more memory. Default: 2. This parameter is effective
            only when comp_comm_parallel enable.

    Raises:
        ValueError: If `capacity_factor` is less than 1.0.
        ValueError: If `aux_loss_factor` is equal to or greater than 1.0.
        ValueError: If `num_experts_chosen` is larger than `expert_num`.
        Additionally, the `Validator` checks reject values that are not positive ints/floats or not bools;
        the exact exception type comes from `mindspore._checkparam` (presumably TypeError/ValueError — confirm).

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> from mindspore.nn.transformer import MoEConfig
        >>> moe_config = MoEConfig(expert_num=4, capacity_factor=5.0, aux_loss_factor=0.05, num_experts_chosen=1,
        ...                        expert_group_size=64, group_wise_a2a=True, comp_comm_parallel=False,
        ...                        comp_comm_parallel_degree=2)
    """

    def __init__(self, expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05, num_experts_chosen=1,
                 expert_group_size=None, group_wise_a2a=False, comp_comm_parallel=False, comp_comm_parallel_degree=2):
        # Type/positivity checks first, so the range checks below compare well-typed values.
        Validator.check_positive_int(expert_num, "expert_num")
        Validator.check_positive_float(capacity_factor, "capacity_factor")
        Validator.check_positive_float(aux_loss_factor, "aux_loss_factor")
        Validator.check_positive_int(num_experts_chosen, "num_experts_chosen")
        Validator.check_bool(group_wise_a2a, "group_wise_a2a")
        Validator.check_bool(comp_comm_parallel, "comp_comm_parallel")
        Validator.check_positive_int(comp_comm_parallel_degree, "comp_comm_parallel_degree")
        # expert_group_size is optional; only validate it when supplied.
        if expert_group_size is not None:
            Validator.check_positive_int(expert_group_size, "expert_group_size")
        if capacity_factor < 1.0:
            raise ValueError(f"'capacity_factor' must be equal to or greater than 1.0, "
                             f"but got {capacity_factor}.")
        if aux_loss_factor >= 1.0:
            raise ValueError(f"'aux_loss_factor' must be less than 1.0, "
                             f"but got {aux_loss_factor}.")
        if num_experts_chosen > expert_num:
            # Report both values: the violated constraint relates the two arguments.
            raise ValueError(f"'num_experts_chosen' must not be larger than 'expert_num', "
                             f"but got {num_experts_chosen} for 'num_experts_chosen' and "
                             f"{expert_num} for 'expert_num'.")
        self.expert_num = expert_num
        self.capacity_factor = capacity_factor
        self.aux_loss_factor = aux_loss_factor
        self.num_experts_chosen = num_experts_chosen
        self.expert_group_size = expert_group_size
        self.group_wise_a2a = group_wise_a2a
        self.comp_comm_parallel = comp_comm_parallel
        self.comp_comm_parallel_degree = comp_comm_parallel_degree
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
# Shared MoEConfig built with all default arguments; used as the default `moe_config` of `MoE`.
default_moe_config = MoEConfig()
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
def _check_moe_config(moe_config=None, parallel_config=None):
    """
    Check that an MoE configuration is consistent with the parallel configuration.

    When MoE is disabled (``expert_num`` is not greater than 1), only the type of `moe_config` is checked.

    Args:
        moe_config (MoEConfig): The MoE configuration to validate.
        parallel_config: The parallel configuration. It is read for its ``expert_parallel``,
            ``data_parallel`` and ``model_parallel`` attributes and must be provided when MoE
            is enabled (presumably a MoEParallelConfig — confirm against callers).

    Raises:
        TypeError: If `moe_config` is not an instance of MoEConfig.
        ValueError: If `expert_num` is not a multiple of `expert_parallel`, the device number is not
            a multiple of `expert_parallel`, `data_parallel` is not a multiple of `expert_parallel`,
            or ``data_parallel * model_parallel`` is greater than the device number.
    """
    if not isinstance(moe_config, MoEConfig):
        raise TypeError(f"'moe_config' must be an instance of MoEConfig, but got {type(moe_config).__name__}.")
    use_moe = moe_config.expert_num > 1
    if not use_moe:
        # Nothing else to validate when MoE is not in use.
        return
    expert_parallel = parallel_config.expert_parallel
    data_parallel = parallel_config.data_parallel
    model_parallel = parallel_config.model_parallel
    if moe_config.expert_num % expert_parallel != 0:
        raise ValueError(f"When using MoE, the 'expert_num' in {type(moe_config).__name__} must be a multiple "
                         f"of 'expert_parallel' value in {type(parallel_config).__name__}, but got "
                         f"{moe_config.expert_num} for 'expert_num' and {expert_parallel} for "
                         f"'expert_parallel'.")

    device_num = D.get_group_size()
    if device_num % expert_parallel != 0:
        raise ValueError(f"device_num: {device_num} must be a multiple of expert_parallel: "
                         f"{expert_parallel}.")
    if data_parallel % expert_parallel != 0:
        raise ValueError(f"data parallel: {data_parallel} must be a multiple of "
                         f"expert_parallel: {expert_parallel} when using MoE.")
    # Equality is allowed: the product may use every device, it just must not exceed them.
    if data_parallel * model_parallel > device_num:
        raise ValueError(f"The product of the data parallel: {data_parallel} and "
                         f"model parallel: {model_parallel} "
                         f"should not be greater than device_num: {device_num}.")
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
@_primexpr
def calculate_expert_capacity(k, tokens_per_group, capacity_factor, expert_dim):
    """
    Return the capacity of each expert: ``k * tokens_per_group * capacity_factor / expert_dim``,
    rounded up to an integer (i.e. the ceiling of the quotient).

    A non-negative quotient with a fractional part is bumped to the next integer. A negative
    quotient is left to ``int()``, whose truncation toward zero already rounds it up.
    """
    quotient = k * tokens_per_group * capacity_factor / expert_dim
    truncated = int(quotient)
    needs_round_up = quotient >= 0 and quotient != truncated
    if needs_round_up:
        return truncated + 1
    return truncated
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
class MoE(Cell):
|
|
144
|
-
"""
|
|
145
|
-
The mixture of experts (MoE) implementation. The implementation includes a router and a FeedForward layer.
|
|
146
|
-
The router dispatches tokens to experts in FeedForward, then FeedForward does computation, and the final output is
|
|
147
|
-
obtained by multiplying FeedForward's output and router's combine weight.
|
|
148
|
-
|
|
149
|
-
Args:
|
|
150
|
-
hidden_size (int): The dimension of the inputs.
|
|
151
|
-
ffn_hidden_size (int): The intermediate hidden size.
|
|
152
|
-
dropout_rate (float): The dropout rate for the second linear's output.
|
|
153
|
-
hidden_act (str): The activation of the internal feedforward layer. Supports 'relu',
|
|
154
|
-
'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
|
|
155
|
-
'hsigmoid', 'logsigmoid' and so on. Default: gelu.
|
|
156
|
-
param_init_type (dtype.Number): The parameter initialization type. Can be dtype.float32 or dtype.float16.
|
|
157
|
-
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig with
|
|
158
|
-
default values. Please see `MoEConfig`.
|
|
159
|
-
parallel_config(MoEParallelConfig): The parallel config for MoE, see `MoEParallelConfig`.
|
|
160
|
-
Default `default_moeparallel_config`, an instance of `MoEParallelConfig` with default args.
|
|
161
|
-
|
|
162
|
-
Inputs:
|
|
163
|
-
- **x** (Tensor) - should be `[batch, seq_length, hidden_size]`. Float tensor.
|
|
164
|
-
|
|
165
|
-
Outputs:
|
|
166
|
-
Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size]`.
|
|
167
|
-
"""
|
|
168
|
-
|
|
169
|
-
def __init__(self, hidden_size,
|
|
170
|
-
ffn_hidden_size,
|
|
171
|
-
dropout_rate,
|
|
172
|
-
hidden_act='gelu',
|
|
173
|
-
param_init_type=mstype.float32,
|
|
174
|
-
moe_config=default_moe_config,
|
|
175
|
-
parallel_config=default_moeparallel_config):
|
|
176
|
-
super(MoE, self).__init__()
|
|
177
|
-
if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
|
|
178
|
-
self.hidden_size = hidden_size
|
|
179
|
-
self.expert_dim = moe_config.expert_num
|
|
180
|
-
self.capacity_factor = moe_config.capacity_factor
|
|
181
|
-
self.aux_loss_factor = moe_config.aux_loss_factor
|
|
182
|
-
self.num_experts_chosen = moe_config.num_experts_chosen
|
|
183
|
-
self.expert_group_size = moe_config.expert_group_size
|
|
184
|
-
self.dp_group = parallel_config.data_parallel
|
|
185
|
-
self.dp = parallel_config.data_parallel
|
|
186
|
-
self.ep = parallel_config.expert_parallel
|
|
187
|
-
self.mp = parallel_config.model_parallel
|
|
188
|
-
self.comp_comm_parallel = moe_config.comp_comm_parallel
|
|
189
|
-
self.comp_comm_parallel_degree = moe_config.comp_comm_parallel_degree
|
|
190
|
-
self.group_wise_a2a = moe_config.group_wise_a2a
|
|
191
|
-
if not (self.mp > 1 and self.dp == self.ep):
|
|
192
|
-
self.group_wise_a2a = False
|
|
193
|
-
from mindspore.parallel._transformer import FeedForward
|
|
194
|
-
|
|
195
|
-
self.ffn = FeedForward(hidden_size=hidden_size,
|
|
196
|
-
ffn_hidden_size=ffn_hidden_size,
|
|
197
|
-
dropout_rate=dropout_rate,
|
|
198
|
-
hidden_act=hidden_act,
|
|
199
|
-
expert_num=self.expert_dim,
|
|
200
|
-
expert_group_size=self.expert_group_size,
|
|
201
|
-
param_init_type=param_init_type,
|
|
202
|
-
parallel_config=parallel_config)
|
|
203
|
-
self.reshape = P.Reshape()
|
|
204
|
-
self.shape = P.Shape()
|
|
205
|
-
self.transpose_2dim = P.Transpose().shard(((self.dp, 1),))
|
|
206
|
-
self.transpose_3dim = P.Transpose().shard(((self.dp, 1, 1),))
|
|
207
|
-
self.transpose_4dim = P.Transpose().shard(((1, self.dp, 1, 1),))
|
|
208
|
-
self.transpose_4dim_dp = P.Transpose().shard(((1, 1, self.dp, 1),))
|
|
209
|
-
self.batch_mm = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
|
|
210
|
-
self.batch_mm2 = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
|
|
211
|
-
self.mul = P.Mul()
|
|
212
|
-
self.router = Router(d_model=hidden_size, moe_config=moe_config, routing_policy=None,
|
|
213
|
-
training=True, parallel_config=parallel_config)
|
|
214
|
-
self.cast = P.Cast()
|
|
215
|
-
self.concat = P.Concat(3).shard(tuple((self.dp, 1, 1, 1) for _ in range(self.comp_comm_parallel_degree)))
|
|
216
|
-
self.concat_dp = P.Concat(2).shard(((1, self.dp, 1, 1), (1, self.dp, 1, 1)))
|
|
217
|
-
self.split = P.Split(axis=2, output_num=self.comp_comm_parallel_degree).shard(((1, self.dp, 1, 1),))
|
|
218
|
-
self.stride_slice = P.StridedSlice().shard(((self.dp, 1, 1, 1),))
|
|
219
|
-
self.stride_slice_dp = P.StridedSlice().shard(((1, self.dp, 1, 1),))
|
|
220
|
-
self.stride_slice_ep = P.StridedSlice().shard(((self.ep, 1, 1, 1),))
|
|
221
|
-
self.stride_slice_dp_mp = P.StridedSlice().shard(((1, self.dp, self.mp, 1),))
|
|
222
|
-
self.stride_slice_ep_mp = P.StridedSlice().shard(((self.ep, 1, self.mp, 1),))
|
|
223
|
-
else:
|
|
224
|
-
self.hidden_size = hidden_size
|
|
225
|
-
self.expert_dim = moe_config.expert_num
|
|
226
|
-
self.capacity_factor = moe_config.capacity_factor
|
|
227
|
-
self.aux_loss_factor = moe_config.aux_loss_factor
|
|
228
|
-
self.num_experts_chosen = moe_config.num_experts_chosen
|
|
229
|
-
self.dp_group = parallel_config.data_parallel
|
|
230
|
-
self.dp = parallel_config.data_parallel
|
|
231
|
-
self.ep = parallel_config.expert_parallel
|
|
232
|
-
self.mp = parallel_config.model_parallel
|
|
233
|
-
self.comp_comm_parallel = moe_config.comp_comm_parallel
|
|
234
|
-
self.comp_comm_parallel_degree = moe_config.comp_comm_parallel_degree
|
|
235
|
-
self.group_wise_a2a = moe_config.group_wise_a2a
|
|
236
|
-
if not (self.mp > 1 and self.dp == self.ep):
|
|
237
|
-
self.group_wise_a2a = False
|
|
238
|
-
from mindspore.parallel._transformer import FeedForward
|
|
239
|
-
|
|
240
|
-
self.ffn = FeedForward(hidden_size=hidden_size,
|
|
241
|
-
ffn_hidden_size=ffn_hidden_size,
|
|
242
|
-
dropout_rate=dropout_rate,
|
|
243
|
-
hidden_act=hidden_act,
|
|
244
|
-
expert_num=self.expert_dim,
|
|
245
|
-
param_init_type=param_init_type,
|
|
246
|
-
parallel_config=parallel_config)
|
|
247
|
-
self.reshape = P.Reshape()
|
|
248
|
-
self.shape = P.Shape()
|
|
249
|
-
self.transpose_2dim = P.Transpose().shard(((self.dp, 1),))
|
|
250
|
-
self.transpose_3dim = P.Transpose().shard(((self.dp, 1, 1),))
|
|
251
|
-
self.transpose_4dim = P.Transpose().shard(((1, self.dp, 1, 1),))
|
|
252
|
-
self.transpose_4dim_dp = P.Transpose().shard(((1, 1, self.dp, 1),))
|
|
253
|
-
self.batch_mm = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
|
|
254
|
-
self.batch_mm2 = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
|
|
255
|
-
self.mul = P.Mul().shard(((), ()))
|
|
256
|
-
self.router = Router(d_model=hidden_size, moe_config=moe_config, routing_policy=None,
|
|
257
|
-
training=True, parallel_config=parallel_config)
|
|
258
|
-
self.cast = P.Cast()
|
|
259
|
-
self.concat = P.Concat(3).shard(tuple((self.dp, 1, 1, 1) for _ in range(self.comp_comm_parallel_degree)))
|
|
260
|
-
self.concat_dp = P.Concat(2).shard(((1, self.dp, 1, 1), (1, self.dp, 1, 1)))
|
|
261
|
-
self.split = P.Split(axis=2, output_num=self.comp_comm_parallel_degree).shard(((1, self.dp, 1, 1),))
|
|
262
|
-
self.stride_slice = P.StridedSlice().shard(((self.dp, 1, 1, 1),))
|
|
263
|
-
self.stride_slice_dp = P.StridedSlice().shard(((1, self.dp, 1, 1),))
|
|
264
|
-
self.stride_slice_ep = P.StridedSlice().shard(((self.ep, 1, 1, 1),))
|
|
265
|
-
self.stride_slice_dp_mp = P.StridedSlice().shard(((1, self.dp, self.mp, 1),))
|
|
266
|
-
self.stride_slice_ep_mp = P.StridedSlice().shard(((self.ep, 1, self.mp, 1),))
|
|
267
|
-
|
|
268
|
-
def ffn_infer(self, expert_input, capacity):
    """
    Computing the FFN.

    Runs the expert feed-forward network (``self.ffn``) on the dispatched tokens.

    When ``self.group_wise_a2a`` is set, the input first goes through three steps:
    - The capacity axis (axis 2) is padded up to a multiple of ``self.mp`` with copies of its first
      ``pad_size`` capacity slots.
    - The input passes through a chain of StridedSlice ops with different shard strategies. The
      comments below label them "capacity shard by mp", "group-wise alltoall" and "allgather";
      the layout changes between consecutive strategies are what drive that communication.

    The mirrored chain is applied to the output, and the padding is sliced off again.

    Args:
        expert_input (Tensor): presumably of shape (expert_dim, dp_group, capacity, hidden_size), as
            produced by ``construct`` / ``ffn_parallel_infer`` -- TODO confirm for other callers.
        capacity (int): size of the capacity axis of ``expert_input``.

    Returns:
        Tensor: of shape (dp_group, hidden_size, expert_dim, capacity), produced by the final transpose.
    """
    pad_size = 0
    if self.group_wise_a2a:
        # If capacity can't div by mp, pad for mp shard.
        if capacity % self.mp != 0:
            pad_size = self.mp - (capacity % self.mp)
        if pad_size != 0:
            capacity += pad_size
            # The padding reuses the first `pad_size` capacity slots rather than zeros; it is removed again
            # before returning.
            # NOTE(review): if pad_size could exceed the incoming capacity, this slice may yield fewer than
            # pad_size slots -- confirm capacity >= mp - 1 holds for all callers.
            pad_tensor = self.stride_slice_dp(expert_input, (0, 0, 0, 0),
                                              (self.expert_dim, self.dp_group, pad_size, self.hidden_size),
                                              (1, 1, 1, 1))
            expert_input = self.concat_dp((expert_input, pad_tensor))
        # capacity shard by mp
        expert_input = self.stride_slice_dp_mp(expert_input, (0, 0, 0, 0),
                                               (self.expert_dim, self.dp_group, capacity, self.hidden_size),
                                               (1, 1, 1, 1))
        # group-wise alltoall
        expert_input = self.stride_slice_ep_mp(expert_input, (0, 0, 0, 0),
                                               (self.expert_dim, self.dp_group, capacity, self.hidden_size),
                                               (1, 1, 1, 1))
        # allgather
        expert_input = self.stride_slice_ep(expert_input, (0, 0, 0, 0),
                                            (self.expert_dim, self.dp_group, capacity, self.hidden_size),
                                            (1, 1, 1, 1))

    # Flatten every (expert, group, slot) row into a 2-D batch for the FFN.
    expert_input = self.reshape(expert_input, (self.expert_dim * self.dp_group * capacity,
                                               self.hidden_size))
    # expert_output's shape: (self.expert_dim, self.dp_group*expert_capacity, self.hidden_size)
    expert_output = self.ffn(expert_input)
    expert_output = self.reshape(expert_output, (self.expert_dim, self.dp_group,
                                                 capacity, self.hidden_size))

    if self.group_wise_a2a:
        # Mirror of the input-side slice chain, in reverse order.
        # capacity shard by mp
        expert_output = self.stride_slice_ep_mp(expert_output, (0, 0, 0, 0),
                                                (self.expert_dim, self.dp_group, capacity, self.hidden_size),
                                                (1, 1, 1, 1))
        # group-wise alltoall
        expert_output = self.stride_slice_dp_mp(expert_output, (0, 0, 0, 0),
                                                (self.expert_dim, self.dp_group, capacity, self.hidden_size),
                                                (1, 1, 1, 1))
        # allgather
        expert_output = self.stride_slice_dp(expert_output, (0, 0, 0, 0),
                                             (self.expert_dim, self.dp_group, capacity, self.hidden_size),
                                             (1, 1, 1, 1))
        # Slice capacity back to org shape.
        if pad_size != 0:
            capacity -= pad_size
            expert_output = self.stride_slice_dp(expert_output, (0, 0, 0, 0),
                                                 (self.expert_dim, self.dp_group, capacity, self.hidden_size),
                                                 (1, 1, 1, 1))
    # expert_output's shape: (self.dp_group, self.hidden_size, self.expert_dim, expert_capacity)
    expert_output = self.transpose_4dim(expert_output, (1, 3, 0, 2))
    return expert_output
|
|
325
|
-
|
|
326
|
-
def ffn_parallel_infer(self, expert_input, capacity):
    """
    Split and overlap FFN compute and communication.

    The method proceeds in four steps:
    - Pads the capacity axis (axis 2) of ``expert_input`` up to a multiple of
      ``self.comp_comm_parallel_degree``, using copies of its first ``pad_size`` capacity slots.
    - Splits the result into that many equal chunks along axis 2.
    - Runs ``ffn_infer`` on each chunk and concatenates the chunk outputs along axis 3.
    - Slices the padding off again.

    Args:
        expert_input (Tensor): presumably of shape (expert_dim, dp_group, capacity, hidden_size) --
            TODO confirm; the padding slice and the split both assume capacity is axis 2.
        capacity (int): size of the capacity axis of ``expert_input``.

    Returns:
        Tensor: of shape (dp_group, hidden_size, expert_dim, capacity).
    """
    # Pad capacity for comp_comm_parallel_degree split.
    pad_size = 0
    if capacity % self.comp_comm_parallel_degree != 0:
        pad_size = self.comp_comm_parallel_degree - (capacity % self.comp_comm_parallel_degree)
        capacity += pad_size
        # NOTE(review): if pad_size could exceed the incoming capacity, this slice may yield fewer than
        # pad_size slots and the split below would not be even -- confirm for small capacities.
        pad_tensor = self.stride_slice_dp(expert_input, (0, 0, 0, 0),
                                          (self.expert_dim, self.dp_group, pad_size, self.hidden_size),
                                          (1, 1, 1, 1))
        expert_input = self.concat_dp((expert_input, pad_tensor))

    sub_capacity = capacity // self.comp_comm_parallel_degree
    output_list = []
    # Each chunk is processed independently; ffn_infer returns (dp_group, hidden_size, expert_dim, sub_capacity).
    for sub_expert_input in self.split(expert_input):
        sub_expert_output = self.ffn_infer(sub_expert_input, sub_capacity)
        output_list.append(sub_expert_output)
    # Re-join the chunks along the capacity axis (axis 3 of the ffn_infer output).
    expert_output = self.concat(output_list)

    # Slice capacity back to org shape.
    if pad_size != 0:
        capacity -= pad_size
        expert_output = self.stride_slice(expert_output, (0, 0, 0, 0),
                                          (self.dp_group, self.hidden_size, self.expert_dim, capacity),
                                          (1, 1, 1, 1))
    return expert_output
|
|
354
|
-
|
|
355
|
-
def construct(self, input_tensor):
    """
    Route tokens to experts, run the expert FFNs and combine the expert outputs.

    Args:
        input_tensor (Tensor): any shape whose trailing dimension is ``hidden_size``. The leading
            dimensions are flattened and split into ``dp_group`` groups of ``tokens_per_group`` tokens,
            so the flattened token count must be divisible by ``dp_group`` for the reshape to succeed.

    Returns:
        tuple: ``(combined_output, aux_loss)``.
        - ``combined_output`` has the same shape as ``input_tensor``.
        - ``aux_loss`` is the router's auxiliary loss multiplied by ``aux_loss_factor``.
    """
    input_shape = F.shape(input_tensor)
    input_tensor = self.reshape(input_tensor, (-1, self.hidden_size))
    bs_and_dmodel = self.shape(input_tensor)
    tokens_per_group = bs_and_dmodel[0] // self.dp_group
    input_tensor = self.reshape(input_tensor, (self.dp_group, tokens_per_group, self.hidden_size))

    expert_capacity = calculate_expert_capacity(self.num_experts_chosen, tokens_per_group,
                                                self.capacity_factor, self.expert_dim)
    # dispatch_tensor's shape: (self.dp_group, tokens_per_group, self.expert_dim, expert_capacity)
    # combine_tensor's shape: (self.dp_group, tokens_per_group, self.expert_dim, expert_capacity)
    dispatch_tensor, combine_tensor, aux_loss = self.router(input_tensor)

    # after transpose, input_tensor's shape: (self.dp_group, self.hidden_size, tokens_per_group)
    input_tensor = self.transpose_3dim(input_tensor, (0, 2, 1))
    dispatch_tensor = self.reshape(dispatch_tensor, (self.dp_group, tokens_per_group,
                                                     self.expert_dim * expert_capacity))
    # The dispatch mask is cast to the activation dtype so it can be used as a matmul operand.
    dispatch_tensor = self.cast(dispatch_tensor, F.dtype(input_tensor))
    # expert_input's shape: (self.dp_group, self.hidden_size, self.expert_dim * expert_capacity)
    expert_input = self.batch_mm(input_tensor, dispatch_tensor)
    expert_input = self.reshape(expert_input, (self.dp_group, self.hidden_size, self.expert_dim,
                                               expert_capacity))
    # The following four ops are to implement transpose(expert_input, (2, 0, 3, 1)), for that a single transpose
    # has bad performance
    expert_input = self.reshape(expert_input, (self.dp_group * self.hidden_size,
                                               self.expert_dim * expert_capacity))
    expert_input = self.transpose_2dim(expert_input, (1, 0))
    expert_input = self.reshape(expert_input, (self.expert_dim, expert_capacity, self.dp_group,
                                               self.hidden_size))
    # expert_input's shape: (self.expert_dim, self.dp_group, expert_capacity, self.hidden_size)
    expert_input = self.transpose_4dim_dp(expert_input, (0, 2, 1, 3))

    # expert_output's shape: (self.dp_group, self.hidden_size, self.expert_dim, expert_capacity)
    if self.comp_comm_parallel:
        expert_output = self.ffn_parallel_infer(expert_input, expert_capacity)
    else:
        expert_output = self.ffn_infer(expert_input, expert_capacity)

    expert_output = self.reshape(expert_output, (self.dp_group, self.hidden_size,
                                                 self.expert_dim * expert_capacity))
    combine_tensor = self.reshape(combine_tensor, (self.dp_group, tokens_per_group,
                                                   self.expert_dim * expert_capacity))
    # combine_tensor's shape: (self.dp_group, self.expert_dim*expert_capacity, tokens_per_group)
    combine_tensor = self.transpose_3dim(combine_tensor, (0, 2, 1))
    combine_tensor = self.cast(combine_tensor, F.dtype(expert_output))

    # combined_output's shape: (self.dp_group, self.hidden_size, tokens_per_group)
    combined_output = self.batch_mm2(expert_output, combine_tensor)
    # combined_output's shape: (self.dp_group, tokens_per_group, self.hidden_size)
    combined_output = self.transpose_3dim(combined_output, (0, 2, 1))
    combined_output = self.reshape(combined_output, (bs_and_dmodel[0], bs_and_dmodel[1]))
    # Restore the caller's original (un-flattened) shape.
    combined_output = self.reshape(combined_output, input_shape)

    aux_loss = self.mul(self.aux_loss_factor, aux_loss)
    return combined_output, aux_loss
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
class Router(Cell):
    r"""
    A router backbone used to calculate logits of each token, which should be cascaded by router implementations
    mapping tokens to experts.
    When moe_config.num_experts_chosen = 1, use top1 routing;
    when moe_config.num_experts_chosen > 1, use topk routing.

    Args:
        d_model (int): The hidden size of each token.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
        routing_policy (Cell): The policy mapping router logits to experts. Default: ``None``, in which case a
            :class:`TopkRouter` is built from ``moe_config``.
        training (bool): Whether the router is in the training phase. Default: ``True``.
        parallel_config: The parallel-related configuration. Its ``data_parallel`` attribute is read
            unconditionally, so it must not be ``None``.

    Inputs:
        - **input_tensor** (Tensor) - Tensor whose last dimension is ``d_model``, e.g. of shape
          :math:`(expert\_parallel, tokens\_per\_device, hidden\_size)`. It is cast to float32 first.

    Outputs:
        Whatever ``routing_policy`` returns for the router logits, whose last dimension is ``expert_dim``.
        With the default :class:`TopkRouter` this is the tuple ``(dispatch_tensor, combine_tensor, aux_loss)``.
    """

    def __init__(self,
                 d_model,
                 moe_config,
                 routing_policy=None,
                 training=True,
                 parallel_config=None):
        super(Router, self).__init__()
        dp = parallel_config.data_parallel
        self.d_model = d_model
        self.expert_dim = moe_config.expert_num
        self.capacity_factor = moe_config.capacity_factor
        self.num_experts_chosen = moe_config.num_experts_chosen
        self.training = training
        self.routing_policy = routing_policy
        self.noisy_policy = None  # candidate: ["jitter", "rsample", "None"]
        self.noisy_epsilon = 1e-2
        # Multiplicative jitter factors, one per hidden unit. Created as float32 so the jitter multiply
        # matches the float32 cast applied to the input in construct (a float64 numpy array would otherwise
        # produce a float64 Tensor).
        self.noise = Tensor(np.random.uniform(1 - self.noisy_epsilon, 1 + self.noisy_epsilon, (d_model,)),
                            mstype.float32)

        # Linear projection from d_model to one logit per expert; rows are split across data-parallel ranks.
        self.dense = Dense(in_channels=self.d_model, out_channels=self.expert_dim, has_bias=False)
        self.dense.matmul.shard(((dp, 1), (1, 1)))
        self.mul = P.Mul()
        self.cast = P.Cast()

        if self.routing_policy is None:
            self.router = TopkRouter(d_model=d_model, moe_config=moe_config, training=training,
                                     parallel_config=parallel_config)
        else:
            self.router = routing_policy

        if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()):
            # The noise vector spans the hidden dimension, which the activations leave unsplit ((dp, 1, 1)),
            # so the noise must not be split either; a (dp,) strategy would contradict the activation layout.
            self.mul.shard(((dp, 1, 1), (1,)))

    def construct(self, input_tensor):
        input_tensor = self.cast(input_tensor, mstype.float32)
        if self.noisy_policy == "jitter" and self.training:
            # Here, we temporarily implement the multiplicative jitter this way,
            # for the lack of UniforReal parallel operator.
            input_tensor = self.mul(input_tensor, self.noise)

        router_logits = self.dense(input_tensor)
        return self.router(router_logits)
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
class TopkRouter(Cell):
|
|
477
|
-
r"""
|
|
478
|
-
A router implementation which maps each token to its top-k experts.
|
|
479
|
-
|
|
480
|
-
Args:
|
|
481
|
-
d_model (int): The hidden size of each token.
|
|
482
|
-
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
|
|
483
|
-
training (bool): The value indicating whether is in training phase.
|
|
484
|
-
parallel_config: The parallel-related configuration.
|
|
485
|
-
Inputs:
|
|
486
|
-
- **router_logits** (Tensor) - Router logits whose last dimension is expert\_dim, e.g. of shape
:math:`(expert\_parallel, tokens\_per\_device, expert\_dim)`.
|
|
488
|
-
|
|
489
|
-
Outputs:
|
|
490
|
-
Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim, expert\_capacity)`,
|
|
491
|
-
Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim, expert\_capacity)`,
|
|
492
|
-
Tensor of shape :math:`(1)`.
|
|
493
|
-
"""
|
|
494
|
-
|
|
495
|
-
def __init__(self,
             d_model,
             moe_config,
             training=True,
             parallel_config=None):
    """
    Build the primitives used by the top-k routing graph.

    Under AUTO_PARALLEL with sharding propagation enabled, only ``mul8`` and ``mul9`` receive explicit
    shard strategies. In every other mode, all primitives except ``cast``/``reshape``/``shape`` are
    given data-parallel (``dp``) strategies.

    Args:
        d_model (int): hidden size of each token.
        moe_config (MoEConfig): supplies ``expert_num``, ``capacity_factor`` and ``num_experts_chosen``.
        training (bool): whether the router is in the training phase.
        parallel_config: must expose ``data_parallel``; it is read unconditionally.
    """
    super(TopkRouter, self).__init__()
    propagate = _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()
    dp = parallel_config.data_parallel

    def _sharded(op, strategy):
        """Return ``op`` with ``strategy`` attached, unless sharding propagation picks strategies itself."""
        if propagate:
            return op
        return op.shard(strategy)

    # Plain configuration copied from the arguments.
    self.d_model = d_model
    self.expert_dim = moe_config.expert_num
    self.capacity_factor = moe_config.capacity_factor
    self.num_experts_chosen = moe_config.num_experts_chosen
    self.training = training
    self.dp_group = dp
    self.noisy_policy = None

    # Shape/dtype helpers are never sharded.
    self.cast = P.Cast()
    self.reshape = P.Reshape()
    self.shape = P.Shape()

    # Constant tensors for the one-hot encodings and the loss accumulator.
    self.on_value = Tensor(1.0, mstype.float32)
    self.off_value = Tensor(0.0, mstype.float32)
    self.init_loss = Tensor(0.0, mstype.float32)

    # Data-parallel layouts for rank-2, rank-3 and rank-4 operands.
    dp2 = (dp, 1)
    dp3 = (dp, 1, 1)
    dp4 = (dp, 1, 1, 1)

    self.softmax = _sharded(P.Softmax(axis=-1), (dp3,))
    self.argmax = _sharded(P.ArgMaxWithValue(axis=-1, keep_dims=False), (dp3,))
    self.onehot = _sharded(P.OneHot(), (dp3, (), ()))
    self.onehot2 = _sharded(P.OneHot(), (dp3, (), ()))
    self.onehot3 = _sharded(P.OneHot(), (dp4, (), ()))

    self.reduce_mean = _sharded(P.ReduceMean(keep_dims=False), (dp3,))
    self.reduce_mean2 = _sharded(P.ReduceMean(keep_dims=False), (dp3,))
    self.reduce_mean3 = _sharded(P.ReduceMean(keep_dims=False), (dp2,))

    self.mul = _sharded(P.Mul(), (dp2, dp2))
    self.mul2 = _sharded(P.Mul(), ((), ()))
    self.mul3 = _sharded(P.Mul(), ((), ()))
    self.mul4 = _sharded(P.Mul(), (dp3, dp3))
    self.mul5 = _sharded(P.Mul(), (dp3, dp3))
    self.mul6 = _sharded(P.Mul(), (dp2, dp2))
    self.mul7 = _sharded(P.Mul(), (dp2, dp2))
    # mul8 / mul9 keep explicit strategies even when sharding propagation is active.
    self.mul8 = P.Mul().shard((dp3, dp3))
    self.mul9 = P.Mul().shard((dp4, dp4))

    self.not_equal = _sharded(P.NotEqual(), (dp4, ()))
    self.div1 = _sharded(P.RealDiv(), (dp3, dp3))
    self.div2 = _sharded(P.RealDiv(), (dp4, dp4))

    self.add = _sharded(P.Add(), (dp3, dp3))
    self.add1 = _sharded(P.Add(), (dp3, ()))
    self.add2 = _sharded(P.Add(), (dp4, dp4))
    self.add3 = _sharded(P.Add(), (dp2, dp2))
    self.add4 = _sharded(P.Add(), (dp4, ()))
    self.sub = _sharded(P.Sub(), ((), dp3))

    self.cumsum = _sharded(P.CumSum(exclusive=True), (dp3,))
    self.less = _sharded(P.Less(), (dp3, ()))
    self.reduce_sum = _sharded(P.ReduceSum(keep_dims=False), (dp3,))
    self.reduce_sum_keep = _sharded(P.ReduceSum(keep_dims=True), (dp3,))
    self.reduce_sum_keep2 = _sharded(P.ReduceSum(keep_dims=True), (dp4,))
    self.expand = _sharded(P.ExpandDims(), (dp2,))
    self.expand2 = _sharded(P.ExpandDims(), (dp3,))
    self.add_scala = _sharded(P.Add(), ((), ()))
|
|
603
|
-
|
|
604
|
-
def construct(self, router_logits):
    """
    Route every token to its top-``num_experts_chosen`` experts, respecting expert capacity.

    Args:
        router_logits (Tensor): routing logits whose last dimension is ``expert_dim``. All leading
            dimensions are flattened and regrouped into ``(dp_group, tokens_per_group, expert_dim)``.

    Returns:
        tuple: ``(dispatch_tensor, combine_tensor, loss)``.
        - ``combine_tensor`` has shape (dp_group, tokens_per_group, expert_dim, expert_capacity) and
          holds the gate weight of each kept (token, expert, slot) assignment. When
          ``num_experts_chosen > 1``, each token's weights are divided by their sum (+1e-9).
        - ``dispatch_tensor`` is the boolean mask ``combine_tensor != 0``.
        - ``loss`` is the load-balancing loss accumulated over all routing rounds.
    """
    router_logits_shape = self.shape(router_logits)
    router_logits = self.reshape(router_logits, (-1, router_logits_shape[-1]))
    logits_shape = self.shape(router_logits)
    tokens_per_group = logits_shape[0] // self.dp_group
    expert_capacity = calculate_expert_capacity(self.num_experts_chosen, tokens_per_group, self.capacity_factor,
                                                self.expert_dim)
    router_logits = self.reshape(router_logits, (self.dp_group, tokens_per_group, self.expert_dim))

    # Accumulators start as Python scalars and become tensors after the first routing round.
    accum_expert_mask = 0
    loss = self.init_loss
    mask_count = 0
    accum_combine_tensor = 0
    # Probability of each token being sent to each expert.
    router_prob = self.softmax(router_logits)

    for expert_chosen_index in range(self.num_experts_chosen):
        # for each token, set the router_prob of the selected experts to zero
        router_prob = self.mul4(router_prob, self.sub(self.on_value, accum_expert_mask))
        # shape is : (dp_group, tokens_per_group)
        expert_index, expert_gate = self.argmax(router_prob)
        # expert_mask's shape: (dp_group, tokens_per_group, self.expert_dim)
        expert_mask = self.onehot(expert_index, self.expert_dim, self.on_value, self.off_value)
        # renormalize the rest prob to be of sum 1
        router_prob_normal = self.div1(router_prob, self.add1(self.reduce_sum_keep(router_prob, -1), 1e-9))

        # the balance loss is computed at each routing step
        loss = self.add_scala(loss, self._auxiliary_loss(expert_mask, router_prob_normal))

        output = self._maskout_overflowed_tokens(expert_mask, expert_capacity, expert_gate,
                                                 mask_count, expert_chosen_index)
        expert_mask, expert_gate, expert_mask_flat, position_in_expert = output[0], output[1], output[2], output[3]
        accum_expert_mask = self.add(accum_expert_mask, expert_mask)
        # Slots already used per expert, so the next round's positions continue after them.
        mask_count = self.add(mask_count, self.reduce_sum_keep(expert_mask, 1))

        # combine_tensor's shape: (dp_group, tokens_per_group)
        combine_tensor = self.mul7(expert_gate, expert_mask_flat)
        # combine_tensor's shape: (dp_group, tokens_per_group, self.expert_dim)
        combine_tensor = self.mul8(self.expand(combine_tensor, -1),
                                   self.onehot2(expert_index, self.expert_dim, self.on_value, self.off_value))
        # combine_tensor's shape: (dp_group, tokens_per_group, self.expert_dim, self.expert_capacity)
        combine_tensor = self.mul9(self.expand2(combine_tensor, -1),
                                   self.onehot3(self.cast(position_in_expert, mstype.int32), expert_capacity,
                                                self.on_value, self.off_value))
        accum_combine_tensor = self.add2(accum_combine_tensor, combine_tensor)

    # expert weights normalization when k > 1
    if self.num_experts_chosen > 1:
        combine_tensor_sum = self.reduce_sum_keep2(self.reduce_sum_keep2(accum_combine_tensor, -1), -2)
        accum_combine_tensor = self.div2(accum_combine_tensor, self.add4(combine_tensor_sum, 1e-9))
    # dispatch_tensor is of boolean type. Here, using NotEqual instead of Cast, for that 'Cast to bool' has
    # bad performance
    dispatch_tensor = self.not_equal(accum_combine_tensor, 0.0)
    return dispatch_tensor, accum_combine_tensor, loss
|
|
660
|
-
|
|
661
|
-
def _auxiliary_loss(self, expert_mask, router_prob):
    """
    Computing the load balance loss for one routing step.

    The steps are:
    - Average ``expert_mask`` over the token axis (axis 1), giving the share of tokens routed to
      each expert.
    - Average ``router_prob`` over the same axis, giving the mean routing probability per expert.
    - Multiply the two element-wise, average the product over all remaining entries, and scale the
      result by ``expert_dim`` twice.
    """
    token_share = self.reduce_mean(expert_mask, 1)
    prob_share = self.reduce_mean2(router_prob, 1)
    mean_product = self.reduce_mean3(self.mul(token_share, prob_share))
    return self.mul3(self.mul2(mean_product, self.expert_dim), self.expert_dim)
|
|
673
|
-
|
|
674
|
-
def _maskout_overflowed_tokens(self, expert_mask, expert_capacity, expert_gate, last_num, expert_chosen_index):
    """
    Keeping only the tokens that fit within expert_capacity.

    Each token gets a position inside its chosen expert from an exclusive cumulative sum over the
    token axis. For routing rounds after the first, this position is offset by ``last_num``, the slots
    already used in earlier rounds. Assignments whose position is not below ``expert_capacity`` are
    dropped.

    Returns:
        tuple: ``(expert_mask, expert_gate, expert_mask_flat, position_in_expert)``.
        - ``expert_mask`` has overflowing assignments zeroed.
        - ``expert_mask_flat`` is that mask summed over the expert axis.
        - ``expert_gate`` is the incoming gate multiplied by ``expert_mask_flat``.
        - ``position_in_expert`` is the slot index of each assignment.
    """
    slot_index = self.cumsum(expert_mask, 1)
    if expert_chosen_index > 0:
        # Continue numbering after the slots consumed by the previous routing rounds.
        slot_index = self.add(slot_index, last_num)
    # position_in_expert's shape: (dp_group, tokens_per_group, self.expert_dim)
    position_in_expert = self.mul4(slot_index, expert_mask)
    # Keep an assignment only while its slot index is still inside the capacity.
    kept_mask = self.mul5(self.less(position_in_expert, expert_capacity), expert_mask)
    # kept_flat's shape: (dp_group, tokens_per_group)
    kept_flat = self.reduce_sum(kept_mask, -1)
    kept_gate = self.mul6(expert_gate, kept_flat)
    return kept_mask, kept_gate, kept_flat, position_in_expert
|