mindspore-2.7.0-cp310-cp310-win_amd64.whl → mindspore-2.7.1-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (290)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_extends/parse/compile_config.py +24 -1
  7. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
  8. mindspore/_extends/parse/resources.py +1 -1
  9. mindspore/_extends/parse/standard_method.py +8 -1
  10. mindspore/_extends/parse/trope.py +2 -1
  11. mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/boost/base.py +29 -2
  18. mindspore/common/_decorator.py +3 -2
  19. mindspore/common/_grad_function.py +3 -1
  20. mindspore/common/_tensor_cpp_method.py +1 -1
  21. mindspore/common/_tensor_docs.py +275 -64
  22. mindspore/common/_utils.py +0 -44
  23. mindspore/common/api.py +285 -35
  24. mindspore/common/dump.py +7 -108
  25. mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
  26. mindspore/common/hook_handle.py +60 -0
  27. mindspore/common/jit_config.py +5 -1
  28. mindspore/common/jit_trace.py +27 -12
  29. mindspore/common/lazy_inline.py +5 -3
  30. mindspore/common/parameter.py +13 -107
  31. mindspore/common/recompute.py +4 -11
  32. mindspore/common/tensor.py +16 -169
  33. mindspore/communication/_comm_helper.py +11 -1
  34. mindspore/communication/comm_func.py +138 -4
  35. mindspore/communication/management.py +85 -1
  36. mindspore/config/op_info.config +0 -15
  37. mindspore/context.py +5 -85
  38. mindspore/dataset/engine/datasets.py +8 -4
  39. mindspore/dataset/engine/datasets_vision.py +1 -1
  40. mindspore/dataset/engine/validators.py +1 -15
  41. mindspore/dnnl.dll +0 -0
  42. mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
  43. mindspore/graph/custom_pass.py +55 -0
  44. mindspore/include/dataset/execute.h +2 -2
  45. mindspore/jpeg62.dll +0 -0
  46. mindspore/mindrecord/__init__.py +3 -3
  47. mindspore/mindrecord/common/exceptions.py +1 -0
  48. mindspore/mindrecord/config.py +1 -1
  49. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  50. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  51. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  52. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  53. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  54. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  55. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  56. mindspore/mindrecord/filereader.py +4 -4
  57. mindspore/mindrecord/filewriter.py +5 -5
  58. mindspore/mindrecord/mindpage.py +2 -2
  59. mindspore/mindrecord/tools/cifar10.py +1 -1
  60. mindspore/mindrecord/tools/cifar100.py +1 -1
  61. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  62. mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
  63. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  64. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  65. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  66. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  67. mindspore/mindspore_backend_common.dll +0 -0
  68. mindspore/mindspore_backend_manager.dll +0 -0
  69. mindspore/mindspore_cluster.dll +0 -0
  70. mindspore/mindspore_common.dll +0 -0
  71. mindspore/mindspore_core.dll +0 -0
  72. mindspore/mindspore_cpu.dll +0 -0
  73. mindspore/mindspore_dump.dll +0 -0
  74. mindspore/mindspore_frontend.dll +0 -0
  75. mindspore/mindspore_glog.dll +0 -0
  76. mindspore/mindspore_hardware_abstract.dll +0 -0
  77. mindspore/mindspore_memory_pool.dll +0 -0
  78. mindspore/mindspore_ms_backend.dll +0 -0
  79. mindspore/mindspore_ops.dll +0 -0
  80. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  81. mindspore/mindspore_profiler.dll +0 -0
  82. mindspore/mindspore_pyboost.dll +0 -0
  83. mindspore/mindspore_pynative.dll +0 -0
  84. mindspore/mindspore_runtime_pipeline.dll +0 -0
  85. mindspore/mindspore_runtime_utils.dll +0 -0
  86. mindspore/mindspore_tools.dll +0 -0
  87. mindspore/mint/__init__.py +15 -10
  88. mindspore/mint/distributed/distributed.py +182 -62
  89. mindspore/mint/nn/__init__.py +2 -16
  90. mindspore/mint/nn/functional.py +4 -110
  91. mindspore/mint/nn/layer/__init__.py +0 -2
  92. mindspore/mint/nn/layer/activation.py +0 -6
  93. mindspore/mint/nn/layer/basic.py +0 -47
  94. mindspore/mint/nn/layer/conv.py +4 -4
  95. mindspore/mint/nn/layer/normalization.py +8 -13
  96. mindspore/mint/nn/layer/pooling.py +0 -4
  97. mindspore/nn/__init__.py +1 -3
  98. mindspore/nn/cell.py +16 -66
  99. mindspore/nn/layer/basic.py +49 -1
  100. mindspore/nn/layer/container.py +16 -0
  101. mindspore/nn/layer/embedding.py +4 -169
  102. mindspore/nn/layer/normalization.py +2 -1
  103. mindspore/nn/layer/thor_layer.py +4 -85
  104. mindspore/nn/optim/ada_grad.py +0 -1
  105. mindspore/nn/optim/adafactor.py +0 -1
  106. mindspore/nn/optim/adam.py +31 -124
  107. mindspore/nn/optim/adamax.py +0 -1
  108. mindspore/nn/optim/asgd.py +0 -1
  109. mindspore/nn/optim/ftrl.py +8 -102
  110. mindspore/nn/optim/lamb.py +0 -1
  111. mindspore/nn/optim/lars.py +0 -3
  112. mindspore/nn/optim/lazyadam.py +25 -218
  113. mindspore/nn/optim/momentum.py +5 -43
  114. mindspore/nn/optim/optimizer.py +6 -55
  115. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  116. mindspore/nn/optim/rmsprop.py +0 -1
  117. mindspore/nn/optim/rprop.py +0 -1
  118. mindspore/nn/optim/sgd.py +0 -1
  119. mindspore/nn/optim/tft_wrapper.py +0 -1
  120. mindspore/nn/optim/thor.py +0 -2
  121. mindspore/nn/probability/bijector/bijector.py +7 -8
  122. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  123. mindspore/nn/probability/bijector/power_transform.py +20 -21
  124. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  125. mindspore/nn/probability/bijector/softplus.py +13 -14
  126. mindspore/nn/wrap/grad_reducer.py +4 -74
  127. mindspore/numpy/array_creations.py +2 -2
  128. mindspore/numpy/fft.py +9 -9
  129. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  130. mindspore/onnx/onnx_export.py +137 -0
  131. mindspore/opencv_core4110.dll +0 -0
  132. mindspore/opencv_imgcodecs4110.dll +0 -0
  133. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  134. mindspore/ops/__init__.py +2 -0
  135. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  136. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  137. mindspore/ops/_op_impl/cpu/__init__.py +0 -5
  138. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
  139. mindspore/ops/auto_generate/gen_extend_func.py +2 -7
  140. mindspore/ops/auto_generate/gen_ops_def.py +98 -141
  141. mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
  142. mindspore/ops/communication.py +97 -0
  143. mindspore/ops/composite/__init__.py +5 -2
  144. mindspore/ops/composite/base.py +15 -1
  145. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  146. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  147. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  148. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  149. mindspore/ops/function/__init__.py +1 -0
  150. mindspore/ops/function/array_func.py +14 -12
  151. mindspore/ops/function/comm_func.py +3883 -0
  152. mindspore/ops/function/debug_func.py +3 -4
  153. mindspore/ops/function/math_func.py +45 -54
  154. mindspore/ops/function/nn_func.py +75 -294
  155. mindspore/ops/function/random_func.py +9 -18
  156. mindspore/ops/functional.py +2 -0
  157. mindspore/ops/functional_overload.py +354 -18
  158. mindspore/ops/operations/__init__.py +2 -5
  159. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  160. mindspore/ops/operations/_inner_ops.py +1 -38
  161. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  162. mindspore/ops/operations/array_ops.py +1 -0
  163. mindspore/ops/operations/comm_ops.py +94 -2
  164. mindspore/ops/operations/custom_ops.py +228 -19
  165. mindspore/ops/operations/debug_ops.py +27 -29
  166. mindspore/ops/operations/manually_defined/ops_def.py +27 -306
  167. mindspore/ops/operations/nn_ops.py +2 -2
  168. mindspore/ops/operations/sparse_ops.py +0 -83
  169. mindspore/ops/primitive.py +1 -17
  170. mindspore/ops/tensor_method.py +72 -3
  171. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  172. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  173. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  174. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  175. mindspore/ops_generate/common/gen_constants.py +11 -10
  176. mindspore/ops_generate/common/op_proto.py +18 -1
  177. mindspore/ops_generate/common/template.py +102 -245
  178. mindspore/ops_generate/common/template_utils.py +212 -0
  179. mindspore/ops_generate/gen_custom_ops.py +69 -0
  180. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  181. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  182. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  183. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  184. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  185. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  186. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  187. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  188. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  189. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  190. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  191. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  192. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  193. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  194. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  195. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  196. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  197. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  198. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  199. mindspore/parallel/_cell_wrapper.py +1 -1
  200. mindspore/parallel/_parallel_serialization.py +1 -4
  201. mindspore/parallel/_utils.py +29 -6
  202. mindspore/parallel/checkpoint_transform.py +18 -2
  203. mindspore/parallel/cluster/process_entity/_api.py +24 -32
  204. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  205. mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
  206. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  207. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  208. mindspore/parallel/strategy.py +336 -0
  209. mindspore/parallel/transform_safetensors.py +117 -16
  210. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
  211. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  212. mindspore/profiler/common/constant.py +5 -0
  213. mindspore/profiler/common/file_manager.py +9 -0
  214. mindspore/profiler/common/msprof_cmd_tool.py +38 -2
  215. mindspore/profiler/common/path_manager.py +56 -24
  216. mindspore/profiler/common/profiler_context.py +2 -12
  217. mindspore/profiler/common/profiler_info.py +3 -3
  218. mindspore/profiler/common/profiler_path_manager.py +13 -0
  219. mindspore/profiler/common/util.py +30 -3
  220. mindspore/profiler/experimental_config.py +2 -1
  221. mindspore/profiler/platform/npu_profiler.py +33 -6
  222. mindspore/run_check/_check_version.py +108 -24
  223. mindspore/runtime/__init__.py +3 -2
  224. mindspore/runtime/executor.py +11 -3
  225. mindspore/runtime/memory.py +112 -0
  226. mindspore/swresample-4.dll +0 -0
  227. mindspore/swscale-6.dll +0 -0
  228. mindspore/tinyxml2.dll +0 -0
  229. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  230. mindspore/tools/data_dump.py +130 -0
  231. mindspore/tools/sdc_detect.py +91 -0
  232. mindspore/tools/stress_detect.py +63 -0
  233. mindspore/train/__init__.py +6 -6
  234. mindspore/train/_utils.py +5 -18
  235. mindspore/train/amp.py +6 -4
  236. mindspore/train/callback/_checkpoint.py +0 -9
  237. mindspore/train/callback/_train_fault_tolerance.py +69 -18
  238. mindspore/train/data_sink.py +1 -5
  239. mindspore/train/model.py +38 -211
  240. mindspore/train/serialization.py +126 -387
  241. mindspore/turbojpeg.dll +0 -0
  242. mindspore/utils/__init__.py +6 -3
  243. mindspore/utils/dlpack.py +92 -0
  244. mindspore/utils/dryrun.py +1 -1
  245. mindspore/utils/runtime_execution_order_check.py +10 -0
  246. mindspore/utils/sdc_detect.py +14 -12
  247. mindspore/utils/stress_detect.py +43 -0
  248. mindspore/utils/utils.py +144 -8
  249. mindspore/version.py +1 -1
  250. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  251. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
  252. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
  253. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  254. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  255. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  256. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  257. mindspore/experimental/llm_boost/register.py +0 -130
  258. mindspore/experimental/llm_boost/utils.py +0 -31
  259. mindspore/include/OWNERS +0 -7
  260. mindspore/mindspore_cpu_res_manager.dll +0 -0
  261. mindspore/mindspore_ops_kernel_common.dll +0 -0
  262. mindspore/mindspore_res_manager.dll +0 -0
  263. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  264. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  265. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  266. mindspore/nn/reinforcement/tensor_array.py +0 -145
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  270. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  271. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  272. mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
  273. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  274. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  275. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  276. mindspore/ops/operations/_tensor_array.py +0 -359
  277. mindspore/ops/operations/rl_ops.py +0 -288
  278. mindspore/parallel/_offload_context.py +0 -275
  279. mindspore/parallel/_recovery_context.py +0 -115
  280. mindspore/parallel/_transformer/__init__.py +0 -35
  281. mindspore/parallel/_transformer/layers.py +0 -765
  282. mindspore/parallel/_transformer/loss.py +0 -251
  283. mindspore/parallel/_transformer/moe.py +0 -693
  284. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  285. mindspore/parallel/_transformer/transformer.py +0 -3124
  286. mindspore/parallel/mpi/_mpi_config.py +0 -116
  287. mindspore/train/memory_profiling_pb2.py +0 -298
  288. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  289. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  290. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -1,693 +0,0 @@
1
- # Copyright 2023 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """
16
- Note:
17
- Mixture of Expert (MoE) structure.
18
- These are experimental APIs that are subject to change or deletion.
19
- """
20
- from __future__ import absolute_import
21
- from __future__ import division
22
-
23
- import numpy as np
24
-
25
- from mindspore.common.tensor import Tensor
26
- import mindspore.common.dtype as mstype
27
- import mindspore.communication.management as D
28
- from mindspore import _checkparam as Validator
29
- from mindspore.ops import operations as P
30
- from mindspore.ops import functional as F
31
- from mindspore.ops.primitive import _primexpr
32
- from mindspore.nn.cell import Cell
33
- from mindspore.nn.layer import Dense
34
- from mindspore.context import ParallelMode
35
- from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
36
- from mindspore.parallel._transformer.op_parallel_config import default_moeparallel_config
37
-
38
- __all__ = [
39
- "MoEConfig"]
40
-
41
-
42
- class MoEConfig:
43
- r"""
44
- The configuration of MoE (Mixture of Expert).
45
-
46
- Args:
47
- expert_num (int): The number of experts employed. Default: 1
48
- capacity_factor (float): The factor used to indicate how much to expand expert capacity,
49
- which must be >= 1.0. Default: 1.1.
50
- aux_loss_factor (float): The factor used to indicate how much of the load balance loss (produced by the
51
- router) is added to the entire model loss, which must be < 1.0. Default: 0.05.
52
- num_experts_chosen (int): The number of experts chosen by each token, which should not be larger
53
- than expert_num. Default: 1.
54
- expert_group_size (int): The number of tokens in each data parallel group. Default: ``None``.
55
- This parameter is effective only when in AUTO_PARALLEL mode, and NOT SHARDING_PROPAGATION.
56
- group_wise_a2a (bool): Whether to enable group-wise alltoall communication, which can reduce communication
57
- time by converting part of inter-node communication into intra-node communication. Default: ``False``.
58
- This parameter is effective only when model_parallel > 1 and data_parallel is equal to expert_parallel.
59
- comp_comm_parallel (bool): Whether to enable FFN compute and communication parallelism, which can reduce pure
60
- communication time by splitting and overlapping compute and communication. Default: ``False``.
61
- comp_comm_parallel_degree (int): The number of splits for compute and communication. The larger the number,
62
- the more overlap there will be, but more memory will be consumed. Default: 2. This parameter is effective
63
- only when comp_comm_parallel is enabled.
64
-
65
- Supported Platforms:
66
- ``Ascend`` ``GPU``
67
-
68
- Examples:
69
- >>> from mindspore.nn.transformer import MoEConfig
70
- >>> moe_config = MoEConfig(expert_num=4, capacity_factor=5.0, aux_loss_factor=0.05, num_experts_chosen=1,
71
- ... expert_group_size=64, group_wise_a2a=True, comp_comm_parallel=False,
72
- ... comp_comm_parallel_degree=2)
73
- """
74
-
75
- def __init__(self, expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05, num_experts_chosen=1,
76
- expert_group_size=None, group_wise_a2a=False, comp_comm_parallel=False, comp_comm_parallel_degree=2):
77
- Validator.check_positive_int(expert_num, "expert_num")
78
- Validator.check_positive_float(capacity_factor, "capacity_factor")
79
- Validator.check_positive_float(aux_loss_factor, "aux_loss_factor")
80
- Validator.check_positive_int(num_experts_chosen, "num_experts_chosen")
81
- Validator.check_bool(group_wise_a2a, "group_wise_a2a")
82
- Validator.check_bool(comp_comm_parallel, "comp_comm_parallel")
83
- Validator.check_positive_int(comp_comm_parallel_degree, "comp_comm_parallel_degree")
84
- if expert_group_size is not None:
85
- Validator.check_positive_int(expert_group_size, "expert_group_size")
86
- if capacity_factor < 1.0:
87
- raise ValueError(f"'capacity_factor' must be equal to or greater than 1.0, "
88
- f"but got {capacity_factor}.")
89
- if aux_loss_factor >= 1.0:
90
- raise ValueError(f"'aux_loss_factor' must be less than 1.0, "
91
- f"but got {aux_loss_factor}.")
92
- if num_experts_chosen > expert_num:
93
- raise ValueError(f"'num_experts_chosen' must not be larger than 'expert_num', "
94
- f"but got {num_experts_chosen}.")
95
- self.expert_num = expert_num
96
- self.capacity_factor = capacity_factor
97
- self.aux_loss_factor = aux_loss_factor
98
- self.num_experts_chosen = num_experts_chosen
99
- self.expert_group_size = expert_group_size
100
- self.group_wise_a2a = group_wise_a2a
101
- self.comp_comm_parallel = comp_comm_parallel
102
- self.comp_comm_parallel_degree = comp_comm_parallel_degree
103
-
104
-
105
- default_moe_config = MoEConfig()
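The constructor above enforces the documented constraints at creation time. A minimal sketch, assuming a MindSpore 2.7.0 install where this module still exists (the import path mirrors the file location shown in this diff):

from mindspore.parallel._transformer.moe import MoEConfig

cfg = MoEConfig(expert_num=8, capacity_factor=1.2, num_experts_chosen=2)
assert cfg.num_experts_chosen <= cfg.expert_num

# Each of the following violates a constraint checked in __init__ and raises ValueError:
# MoEConfig(capacity_factor=0.5)                  # capacity_factor must be >= 1.0
# MoEConfig(aux_loss_factor=1.5)                  # aux_loss_factor must be < 1.0
# MoEConfig(expert_num=2, num_experts_chosen=4)   # num_experts_chosen must be <= expert_num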
106
-
107
-
108
- def _check_moe_config(moe_config=None, parallel_config=None):
109
- """
110
- Check whether MoE is configured correctly.
111
- """
112
- if not isinstance(moe_config, MoEConfig):
113
- raise TypeError(f"'moe_config' must be an instance of MoEConfig, but got {type(moe_config).__name__}.")
114
- use_moe = moe_config.expert_num > 1
115
- if use_moe is False:
116
- return
117
- if moe_config.expert_num % parallel_config.expert_parallel != 0:
118
- raise ValueError(f"When using MoE, the 'expert_num' in {type(moe_config).__name__} must be a multiple "
119
- f"of 'expert_parallel' value in {type(parallel_config).__name__}, but got "
120
- f"{moe_config.expert_num} for 'expert_num' and {parallel_config.expert_parallel} for "
121
- f"'expert_parallel'.")
122
-
123
- device_num = D.get_group_size()
124
- if device_num % parallel_config.expert_parallel != 0:
125
- raise ValueError(f"device_num: {device_num} must be a multiple of expert_parallel: "
126
- f"{parallel_config.expert_parallel}.")
127
- if parallel_config.data_parallel % parallel_config.expert_parallel != 0:
128
- raise ValueError(f"data parallel: {parallel_config.data_parallel} must be a multiple of "
129
- f"expert_parallel: {parallel_config.expert_parallel} when using MoE.")
130
- if parallel_config.data_parallel * parallel_config.model_parallel > device_num:
131
- raise ValueError(f"The product of the data parallel: {parallel_config.data_parallel} and "
132
- f"model parallel: {parallel_config.model_parallel} "
133
- f"should be less than device_num: {device_num}.")
134
-
135
-
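The checks in _check_moe_config reduce to simple divisibility arithmetic. A small sketch with assumed device counts (the 8-device layout is illustrative only):

# A hypothetical 8-device layout that satisfies every check above:
device_num = 8
expert_num, expert_parallel = 4, 2       # expert_num % expert_parallel == 0
data_parallel, model_parallel = 4, 2     # data_parallel % expert_parallel == 0

assert expert_num % expert_parallel == 0
assert device_num % expert_parallel == 0
assert data_parallel % expert_parallel == 0
assert data_parallel * model_parallel <= device_num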
136
- @_primexpr
137
- def calculate_expert_capacity(k, tokens_per_group, capacity_factor, expert_dim):
138
- res = k * tokens_per_group * capacity_factor / expert_dim
139
- res_int = int(res)
140
- return res_int if res < 0 or res == res_int else res_int + 1
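calculate_expert_capacity is a ceiling of k * tokens_per_group * capacity_factor / expert_dim. A worked example with assumed values:

k, tokens_per_group, capacity_factor, expert_dim = 2, 1024, 1.1, 8
res = k * tokens_per_group * capacity_factor / expert_dim   # 281.6
expert_capacity = int(res) if res == int(res) else int(res) + 1
assert expert_capacity == 282                               # rounded up to the next whole token slot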
141
-
142
-
143
- class MoE(Cell):
144
- """
145
- The mixture of experts (MoE) implementation. The implementation includes a router and a FeedForward layer.
146
- The router dispatches tokens to experts in FeedForward, then FeedForward does computation, and the final output is
147
- obtained by multiplying FeedForward's output and router's combine weight.
148
-
149
- Args:
150
- hidden_size (int): The dimension of the inputs.
151
- ffn_hidden_size (int): The intermediate hidden size.
152
- dropout_rate (float): The dropout rate for the second linear's output.
153
- hidden_act (str): The activation of the internal feedforward layer. Supports 'relu',
154
- 'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
155
- 'hsigmoid', 'logsigmoid' and so on. Default: gelu.
156
- param_init_type (dtype.Number): The parameter initialization type. Can be dtype.float32 or dtype.float16.
157
- moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig with
158
- default values. Please see `MoEConfig`.
159
- parallel_config(MoEParallelConfig): The parallel config for MoE, see `MoEParallelConfig`.
160
- Default `default_moeparallel_config`, an instance of `MoEParallelConfig` with default args.
161
-
162
- Inputs:
163
- - **x** (Tensor) - should be `[batch, seq_length, hidden_size]`. Float tensor.
164
-
165
- Outputs:
166
- Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size]`.
167
- """
168
-
169
- def __init__(self, hidden_size,
170
- ffn_hidden_size,
171
- dropout_rate,
172
- hidden_act='gelu',
173
- param_init_type=mstype.float32,
174
- moe_config=default_moe_config,
175
- parallel_config=default_moeparallel_config):
176
- super(MoE, self).__init__()
177
- if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
178
- self.hidden_size = hidden_size
179
- self.expert_dim = moe_config.expert_num
180
- self.capacity_factor = moe_config.capacity_factor
181
- self.aux_loss_factor = moe_config.aux_loss_factor
182
- self.num_experts_chosen = moe_config.num_experts_chosen
183
- self.expert_group_size = moe_config.expert_group_size
184
- self.dp_group = parallel_config.data_parallel
185
- self.dp = parallel_config.data_parallel
186
- self.ep = parallel_config.expert_parallel
187
- self.mp = parallel_config.model_parallel
188
- self.comp_comm_parallel = moe_config.comp_comm_parallel
189
- self.comp_comm_parallel_degree = moe_config.comp_comm_parallel_degree
190
- self.group_wise_a2a = moe_config.group_wise_a2a
191
- if not (self.mp > 1 and self.dp == self.ep):
192
- self.group_wise_a2a = False
193
- from mindspore.parallel._transformer import FeedForward
194
-
195
- self.ffn = FeedForward(hidden_size=hidden_size,
196
- ffn_hidden_size=ffn_hidden_size,
197
- dropout_rate=dropout_rate,
198
- hidden_act=hidden_act,
199
- expert_num=self.expert_dim,
200
- expert_group_size=self.expert_group_size,
201
- param_init_type=param_init_type,
202
- parallel_config=parallel_config)
203
- self.reshape = P.Reshape()
204
- self.shape = P.Shape()
205
- self.transpose_2dim = P.Transpose().shard(((self.dp, 1),))
206
- self.transpose_3dim = P.Transpose().shard(((self.dp, 1, 1),))
207
- self.transpose_4dim = P.Transpose().shard(((1, self.dp, 1, 1),))
208
- self.transpose_4dim_dp = P.Transpose().shard(((1, 1, self.dp, 1),))
209
- self.batch_mm = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
210
- self.batch_mm2 = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
211
- self.mul = P.Mul()
212
- self.router = Router(d_model=hidden_size, moe_config=moe_config, routing_policy=None,
213
- training=True, parallel_config=parallel_config)
214
- self.cast = P.Cast()
215
- self.concat = P.Concat(3).shard(tuple((self.dp, 1, 1, 1) for _ in range(self.comp_comm_parallel_degree)))
216
- self.concat_dp = P.Concat(2).shard(((1, self.dp, 1, 1), (1, self.dp, 1, 1)))
217
- self.split = P.Split(axis=2, output_num=self.comp_comm_parallel_degree).shard(((1, self.dp, 1, 1),))
218
- self.stride_slice = P.StridedSlice().shard(((self.dp, 1, 1, 1),))
219
- self.stride_slice_dp = P.StridedSlice().shard(((1, self.dp, 1, 1),))
220
- self.stride_slice_ep = P.StridedSlice().shard(((self.ep, 1, 1, 1),))
221
- self.stride_slice_dp_mp = P.StridedSlice().shard(((1, self.dp, self.mp, 1),))
222
- self.stride_slice_ep_mp = P.StridedSlice().shard(((self.ep, 1, self.mp, 1),))
223
- else:
224
- self.hidden_size = hidden_size
225
- self.expert_dim = moe_config.expert_num
226
- self.capacity_factor = moe_config.capacity_factor
227
- self.aux_loss_factor = moe_config.aux_loss_factor
228
- self.num_experts_chosen = moe_config.num_experts_chosen
229
- self.dp_group = parallel_config.data_parallel
230
- self.dp = parallel_config.data_parallel
231
- self.ep = parallel_config.expert_parallel
232
- self.mp = parallel_config.model_parallel
233
- self.comp_comm_parallel = moe_config.comp_comm_parallel
234
- self.comp_comm_parallel_degree = moe_config.comp_comm_parallel_degree
235
- self.group_wise_a2a = moe_config.group_wise_a2a
236
- if not (self.mp > 1 and self.dp == self.ep):
237
- self.group_wise_a2a = False
238
- from mindspore.parallel._transformer import FeedForward
239
-
240
- self.ffn = FeedForward(hidden_size=hidden_size,
241
- ffn_hidden_size=ffn_hidden_size,
242
- dropout_rate=dropout_rate,
243
- hidden_act=hidden_act,
244
- expert_num=self.expert_dim,
245
- param_init_type=param_init_type,
246
- parallel_config=parallel_config)
247
- self.reshape = P.Reshape()
248
- self.shape = P.Shape()
249
- self.transpose_2dim = P.Transpose().shard(((self.dp, 1),))
250
- self.transpose_3dim = P.Transpose().shard(((self.dp, 1, 1),))
251
- self.transpose_4dim = P.Transpose().shard(((1, self.dp, 1, 1),))
252
- self.transpose_4dim_dp = P.Transpose().shard(((1, 1, self.dp, 1),))
253
- self.batch_mm = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
254
- self.batch_mm2 = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
255
- self.mul = P.Mul().shard(((), ()))
256
- self.router = Router(d_model=hidden_size, moe_config=moe_config, routing_policy=None,
257
- training=True, parallel_config=parallel_config)
258
- self.cast = P.Cast()
259
- self.concat = P.Concat(3).shard(tuple((self.dp, 1, 1, 1) for _ in range(self.comp_comm_parallel_degree)))
260
- self.concat_dp = P.Concat(2).shard(((1, self.dp, 1, 1), (1, self.dp, 1, 1)))
261
- self.split = P.Split(axis=2, output_num=self.comp_comm_parallel_degree).shard(((1, self.dp, 1, 1),))
262
- self.stride_slice = P.StridedSlice().shard(((self.dp, 1, 1, 1),))
263
- self.stride_slice_dp = P.StridedSlice().shard(((1, self.dp, 1, 1),))
264
- self.stride_slice_ep = P.StridedSlice().shard(((self.ep, 1, 1, 1),))
265
- self.stride_slice_dp_mp = P.StridedSlice().shard(((1, self.dp, self.mp, 1),))
266
- self.stride_slice_ep_mp = P.StridedSlice().shard(((self.ep, 1, self.mp, 1),))
267
-
268
- def ffn_infer(self, expert_input, capacity):
269
- """
270
- Computing the FFN.
271
- """
272
- pad_size = 0
273
- if self.group_wise_a2a:
274
- # If capacity is not divisible by mp, pad for the mp shard.
275
- if capacity % self.mp != 0:
276
- pad_size = self.mp - (capacity % self.mp)
277
- if pad_size != 0:
278
- capacity += pad_size
279
- pad_tensor = self.stride_slice_dp(expert_input, (0, 0, 0, 0),
280
- (self.expert_dim, self.dp_group, pad_size, self.hidden_size),
281
- (1, 1, 1, 1))
282
- expert_input = self.concat_dp((expert_input, pad_tensor))
283
- # capacity shard by mp
284
- expert_input = self.stride_slice_dp_mp(expert_input, (0, 0, 0, 0),
285
- (self.expert_dim, self.dp_group, capacity, self.hidden_size),
286
- (1, 1, 1, 1))
287
- # group-wise alltoall
288
- expert_input = self.stride_slice_ep_mp(expert_input, (0, 0, 0, 0),
289
- (self.expert_dim, self.dp_group, capacity, self.hidden_size),
290
- (1, 1, 1, 1))
291
- # allgather
292
- expert_input = self.stride_slice_ep(expert_input, (0, 0, 0, 0),
293
- (self.expert_dim, self.dp_group, capacity, self.hidden_size),
294
- (1, 1, 1, 1))
295
-
296
- expert_input = self.reshape(expert_input, (self.expert_dim * self.dp_group * capacity,
297
- self.hidden_size))
298
- # expert_output's shape: (self.expert_dim, self.dp_group*expert_capacity, self.hidden_size)
299
- expert_output = self.ffn(expert_input)
300
- expert_output = self.reshape(expert_output, (self.expert_dim, self.dp_group,
301
- capacity, self.hidden_size))
302
-
303
- if self.group_wise_a2a:
304
- # capacity shard by mp
305
- expert_output = self.stride_slice_ep_mp(expert_output, (0, 0, 0, 0),
306
- (self.expert_dim, self.dp_group, capacity, self.hidden_size),
307
- (1, 1, 1, 1))
308
- # group-wise alltoall
309
- expert_output = self.stride_slice_dp_mp(expert_output, (0, 0, 0, 0),
310
- (self.expert_dim, self.dp_group, capacity, self.hidden_size),
311
- (1, 1, 1, 1))
312
- # allgather
313
- expert_output = self.stride_slice_dp(expert_output, (0, 0, 0, 0),
314
- (self.expert_dim, self.dp_group, capacity, self.hidden_size),
315
- (1, 1, 1, 1))
316
- # Slice capacity back to org shape.
317
- if pad_size != 0:
318
- capacity -= pad_size
319
- expert_output = self.stride_slice_dp(expert_output, (0, 0, 0, 0),
320
- (self.expert_dim, self.dp_group, capacity, self.hidden_size),
321
- (1, 1, 1, 1))
322
- # expert_output's shape: (self.dp_group, self.hidden_size, self.expert_dim, expert_capacity)
323
- expert_output = self.transpose_4dim(expert_output, (1, 3, 0, 2))
324
- return expert_output
325
-
326
- def ffn_parallel_infer(self, expert_input, capacity):
327
- """
328
- Split and overlap FFN compute and communication.
329
- """
330
- # Pad capacity for comp_comm_parallel_degree split.
331
- pad_size = 0
332
- if capacity % self.comp_comm_parallel_degree != 0:
333
- pad_size = self.comp_comm_parallel_degree - (capacity % self.comp_comm_parallel_degree)
334
- capacity += pad_size
335
- pad_tensor = self.stride_slice_dp(expert_input, (0, 0, 0, 0),
336
- (self.expert_dim, self.dp_group, pad_size, self.hidden_size),
337
- (1, 1, 1, 1))
338
- expert_input = self.concat_dp((expert_input, pad_tensor))
339
-
340
- sub_capacity = capacity // self.comp_comm_parallel_degree
341
- output_list = []
342
- for sub_expert_input in self.split(expert_input):
343
- sub_expert_output = self.ffn_infer(sub_expert_input, sub_capacity)
344
- output_list.append(sub_expert_output)
345
- expert_output = self.concat(output_list)
346
-
347
- # Slice capacity back to org shape.
348
- if pad_size != 0:
349
- capacity -= pad_size
350
- expert_output = self.stride_slice(expert_output, (0, 0, 0, 0),
351
- (self.dp_group, self.hidden_size, self.expert_dim, capacity),
352
- (1, 1, 1, 1))
353
- return expert_output
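ffn_parallel_infer pads the capacity up to a multiple of comp_comm_parallel_degree, splits along the capacity axis, runs ffn_infer on each chunk so compute and communication overlap, then concatenates and slices the padding back off; ffn_infer uses the same pad-to-a-multiple trick for the model-parallel dimension. A NumPy sketch of just the padding and splitting arithmetic (shapes and values are assumptions for illustration):

import numpy as np

capacity, degree = 10, 4                          # capacity not divisible by the split degree
pad_size = degree - capacity % degree             # 2
padded_capacity = capacity + pad_size             # 12
sub_capacity = padded_capacity // degree          # each of the 4 chunks handles 3 token slots

# (expert_dim, dp_group, capacity, hidden_size) layout, as in the code above
expert_input = np.zeros((2, 1, padded_capacity, 16))
chunks = np.split(expert_input, degree, axis=2)
assert len(chunks) == degree and chunks[0].shape[2] == sub_capacity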
354
-
355
- def construct(self, input_tensor):
356
- input_shape = F.shape(input_tensor)
357
- input_tensor = self.reshape(input_tensor, (-1, self.hidden_size))
358
- bs_and_dmodel = self.shape(input_tensor)
359
- tokens_per_group = bs_and_dmodel[0] // self.dp_group
360
- input_tensor = self.reshape(input_tensor, (self.dp_group, tokens_per_group, self.hidden_size))
361
-
362
- expert_capacity = calculate_expert_capacity(self.num_experts_chosen, tokens_per_group,
363
- self.capacity_factor, self.expert_dim)
364
- # dispatch_tensor's shape: (self.dp_group, tokens_per_group, self.expert_dim, expert_capacity)
365
- # combine_tensor's shape: (self.dp_group, tokens_per_group, self.expert_dim, expert_capacity)
366
- dispatch_tensor, combine_tensor, aux_loss = self.router(input_tensor)
367
-
368
- # after transpose, input_tensor's shape: (self.dp_group, self.hidden_size, tokens_per_group)
369
- input_tensor = self.transpose_3dim(input_tensor, (0, 2, 1))
370
- dispatch_tensor = self.reshape(dispatch_tensor, (self.dp_group, tokens_per_group,
371
- self.expert_dim * expert_capacity))
372
- dispatch_tensor = self.cast(dispatch_tensor, F.dtype(input_tensor))
373
- # expert_input's shape: (self.dp_group, self.hidden_size, self.expert_dim * expert_capacity)
374
- expert_input = self.batch_mm(input_tensor, dispatch_tensor)
375
- expert_input = self.reshape(expert_input, (self.dp_group, self.hidden_size, self.expert_dim,
376
- expert_capacity))
377
- # The following four ops implement transpose(expert_input, (2, 0, 3, 1)), because a single transpose
378
- # has bad performance
379
- expert_input = self.reshape(expert_input, (self.dp_group * self.hidden_size,
380
- self.expert_dim * expert_capacity))
381
- expert_input = self.transpose_2dim(expert_input, (1, 0))
382
- expert_input = self.reshape(expert_input, (self.expert_dim, expert_capacity, self.dp_group,
383
- self.hidden_size))
384
- # expert_input's shape: (self.expert_dim, self.dp_group, expert_capacity, self.hidden_size)
385
- expert_input = self.transpose_4dim_dp(expert_input, (0, 2, 1, 3))
386
-
387
- # expert_output's shape: (self.dp_group, self.hidden_size, self.expert_dim, expert_capacity)
388
- if self.comp_comm_parallel:
389
- expert_output = self.ffn_parallel_infer(expert_input, expert_capacity)
390
- else:
391
- expert_output = self.ffn_infer(expert_input, expert_capacity)
392
-
393
- expert_output = self.reshape(expert_output, (self.dp_group, self.hidden_size,
394
- self.expert_dim * expert_capacity))
395
- combine_tensor = self.reshape(combine_tensor, (self.dp_group, tokens_per_group,
396
- self.expert_dim * expert_capacity))
397
- # combine_tensor's shape: (self.dp_group, self.expert_dim*expert_capacity, tokens_per_group)
398
- combine_tensor = self.transpose_3dim(combine_tensor, (0, 2, 1))
399
- combine_tensor = self.cast(combine_tensor, F.dtype(expert_output))
400
-
401
- # combined_output's shape: (self.dp_group, self.hidden_size, tokens_per_group)
402
- combined_output = self.batch_mm2(expert_output, combine_tensor)
403
- # combined_output's shape: (self.dp_group, tokens_per_group, self.hidden_size)
404
- combined_output = self.transpose_3dim(combined_output, (0, 2, 1))
405
- combined_output = self.reshape(combined_output, (bs_and_dmodel[0], bs_and_dmodel[1]))
406
- combined_output = self.reshape(combined_output, input_shape)
407
-
408
- aux_loss = self.mul(self.aux_loss_factor, aux_loss)
409
- return combined_output, aux_loss
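The dispatch and combine steps in construct are plain batched matrix multiplications between the (transposed) token activations and the routing tensors produced by the router. A NumPy shape sketch with assumed sizes:

import numpy as np

dp_group, tokens, hidden, experts, capacity = 2, 8, 4, 2, 6
x = np.random.randn(dp_group, hidden, tokens)                   # input after transpose_3dim
dispatch = np.random.rand(dp_group, tokens, experts * capacity)
expert_input = x @ dispatch                                     # (dp_group, hidden, experts*capacity)

expert_output = expert_input                                    # stand-in for the per-expert FFN result
combine = np.random.rand(dp_group, experts * capacity, tokens)  # combine_tensor after transpose_3dim
combined_output = expert_output @ combine                       # back to (dp_group, hidden, tokens)
assert combined_output.shape == (dp_group, hidden, tokens)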
410
-
411
-
412
- class Router(Cell):
413
- r"""
414
- A router backbone used to calculate logits of each token, which should be cascaded by router implementations
415
- mapping tokens to experts.
416
- When moe_config.num_experts_chosen = 1, top-1 routing is used;
417
- when moe_config.num_experts_chosen > 1, top-k routing is used.
418
-
419
- Args:
420
- d_model (int): The hidden size of each token.
421
- moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
422
- routing_policy: The policy of mapping tokens to experts. Default: TopkRouter.
423
- training (bool): The value indicating whether it is in the training phase.
424
- parallel_config: The parallel-related configuration.
425
- Inputs:
426
- - **input_tensor** (Tensor) - Tensor of shape :math:`(expert\_parallel, tokens\_per\_device,
427
- hidden\_size)`.
428
-
429
- Outputs:
430
- Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim)`.
431
- """
432
-
433
- def __init__(self,
434
- d_model,
435
- moe_config,
436
- routing_policy=None,
437
- training=True,
438
- parallel_config=None):
439
- super(Router, self).__init__()
440
- dp = parallel_config.data_parallel
441
- self.d_model = d_model
442
- self.expert_dim = moe_config.expert_num
443
- self.capacity_factor = moe_config.capacity_factor
444
- self.num_experts_chosen = moe_config.num_experts_chosen
445
- self.training = training
446
- self.routing_policy = routing_policy
447
- self.noisy_policy = None # candidate: ["jitter", "rsample", "None"]
448
- self.noisy_epsilon = 1e-2
449
- self.noise = Tensor(np.random.uniform(1 - self.noisy_epsilon, 1 + self.noisy_epsilon, (d_model,)))
450
-
451
- self.dense = Dense(in_channels=self.d_model, out_channels=self.expert_dim, has_bias=False)
452
- self.dense.matmul.shard(((dp, 1), (1, 1)))
453
- self.mul = P.Mul()
454
- self.cast = P.Cast()
455
-
456
- if self.routing_policy is None:
457
- self.router = TopkRouter(d_model=d_model, moe_config=moe_config, training=training,
458
- parallel_config=parallel_config)
459
- else:
460
- self.router = routing_policy
461
-
462
- if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()):
463
- self.mul.shard(((dp, 1, 1), (dp,)))
464
-
465
- def construct(self, input_tensor):
466
- input_tensor = self.cast(input_tensor, mstype.float32)
467
- if self.noisy_policy == "jitter" and self.training:
468
- # Here, we temporarily implement the multiplicative jitter this way,
469
- # due to the lack of a parallel UniformReal operator.
470
- input_tensor = self.mul(input_tensor, self.noise)
471
-
472
- router_logits = self.dense(input_tensor)
473
- return self.router(router_logits)
474
-
475
-
476
- class TopkRouter(Cell):
477
- r"""
478
- A router implementation which maps each token to the top-k experts.
479
-
480
- Args:
481
- d_model (int): The hidden size of each token.
482
- moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
483
- training (bool): The value indicating whether it is in the training phase.
484
- parallel_config: The parallel-related configuration.
485
- Inputs:
486
- - **input_tensor** (Tensor) - Tensor of shape :math:`(expert\_parallel, tokens\_per\_device,
487
- hidden\_size)`.
488
-
489
- Outputs:
490
- Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim, expert\_capacity)`,
491
- Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim, expert\_capacity)`,
492
- Tensor of shape :math:`(1)`.
493
- """
494
-
495
- def __init__(self,
496
- d_model,
497
- moe_config,
498
- training=True,
499
- parallel_config=None):
500
- super(TopkRouter, self).__init__()
501
- if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
502
- dp = parallel_config.data_parallel
503
- self.d_model = d_model
504
- self.expert_dim = moe_config.expert_num
505
- self.capacity_factor = moe_config.capacity_factor
506
- self.training = training
507
- self.dp_group = dp
508
- self.noisy_policy = None
509
- self.cast = P.Cast()
510
- self.reshape = P.Reshape()
511
- self.shape = P.Shape()
512
- self.softmax = P.Softmax(axis=-1)
513
- self.argmax = P.ArgMaxWithValue(axis=-1, keep_dims=False)
514
- self.num_experts_chosen = moe_config.num_experts_chosen
515
- self.onehot = P.OneHot()
516
- self.onehot2 = P.OneHot()
517
- self.onehot3 = P.OneHot()
518
- self.on_value = Tensor(1.0, mstype.float32)
519
- self.off_value = Tensor(0.0, mstype.float32)
520
-
521
- self.reduce_mean = P.ReduceMean(keep_dims=False)
522
- self.reduce_mean2 = P.ReduceMean(keep_dims=False)
523
- self.reduce_mean3 = P.ReduceMean(keep_dims=False)
524
- self.mul = P.Mul()
525
- self.mul2 = P.Mul()
526
- self.mul3 = P.Mul()
527
- self.mul4 = P.Mul()
528
- self.mul5 = P.Mul()
529
- self.mul6 = P.Mul()
530
- self.mul7 = P.Mul()
531
- self.mul8 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
532
- self.mul9 = P.Mul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
533
- self.not_equal = P.NotEqual()
534
- self.div1 = P.RealDiv()
535
- self.div2 = P.RealDiv()
536
- self.add = P.Add()
537
- self.add1 = P.Add()
538
- self.add2 = P.Add()
539
- self.add3 = P.Add()
540
- self.add4 = P.Add()
541
- self.sub = P.Sub()
542
-
543
- self.cumsum = P.CumSum(exclusive=True)
544
- self.less = P.Less()
545
- self.reduce_sum = P.ReduceSum(keep_dims=False)
546
- self.reduce_sum_keep = P.ReduceSum(keep_dims=True)
547
- self.reduce_sum_keep2 = P.ReduceSum(keep_dims=True)
548
- self.expand = P.ExpandDims()
549
- self.expand2 = P.ExpandDims()
550
- self.add_scala = P.Add()
551
- self.init_loss = Tensor(0.0, mstype.float32)
552
- else:
553
- dp = parallel_config.data_parallel
554
- self.d_model = d_model
555
- self.expert_dim = moe_config.expert_num
556
- self.capacity_factor = moe_config.capacity_factor
557
- self.training = training
558
- self.dp_group = dp
559
- self.noisy_policy = None
560
- self.cast = P.Cast()
561
- self.reshape = P.Reshape()
562
- self.shape = P.Shape()
563
- self.softmax = P.Softmax(axis=-1).shard(((dp, 1, 1,),))
564
- self.argmax = P.ArgMaxWithValue(axis=-1, keep_dims=False).shard(((dp, 1, 1),))
565
- self.num_experts_chosen = moe_config.num_experts_chosen
566
- self.onehot = P.OneHot().shard(((dp, 1, 1), (), ()))
567
- self.onehot2 = P.OneHot().shard(((dp, 1, 1), (), ()))
568
- self.onehot3 = P.OneHot().shard(((dp, 1, 1, 1), (), ()))
569
- self.on_value = Tensor(1.0, mstype.float32)
570
- self.off_value = Tensor(0.0, mstype.float32)
571
-
572
- self.reduce_mean = P.ReduceMean(keep_dims=False).shard(((dp, 1, 1),))
573
- self.reduce_mean2 = P.ReduceMean(keep_dims=False).shard(((dp, 1, 1),))
574
- self.reduce_mean3 = P.ReduceMean(keep_dims=False).shard(((dp, 1),))
575
- self.mul = P.Mul().shard(((dp, 1), (dp, 1)))
576
- self.mul2 = P.Mul().shard(((), ()))
577
- self.mul3 = P.Mul().shard(((), ()))
578
- self.mul4 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
579
- self.mul5 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
580
- self.mul6 = P.Mul().shard(((dp, 1), (dp, 1)))
581
- self.mul7 = P.Mul().shard(((dp, 1), (dp, 1)))
582
- self.mul8 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
583
- self.mul9 = P.Mul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
584
- self.not_equal = P.NotEqual().shard(((dp, 1, 1, 1), ()))
585
- self.div1 = P.RealDiv().shard(((dp, 1, 1), (dp, 1, 1)))
586
- self.div2 = P.RealDiv().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
587
- self.add = P.Add().shard(((dp, 1, 1), (dp, 1, 1)))
588
- self.add1 = P.Add().shard(((dp, 1, 1), ()))
589
- self.add2 = P.Add().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
590
- self.add3 = P.Add().shard(((dp, 1), (dp, 1)))
591
- self.add4 = P.Add().shard(((dp, 1, 1, 1), ()))
592
- self.sub = P.Sub().shard(((), (dp, 1, 1)))
593
-
594
- self.cumsum = P.CumSum(exclusive=True).shard(((dp, 1, 1),))
595
- self.less = P.Less().shard(((dp, 1, 1), ()))
596
- self.reduce_sum = P.ReduceSum(keep_dims=False).shard(((dp, 1, 1),))
597
- self.reduce_sum_keep = P.ReduceSum(keep_dims=True).shard(((dp, 1, 1),))
598
- self.reduce_sum_keep2 = P.ReduceSum(keep_dims=True).shard(((dp, 1, 1, 1),))
599
- self.expand = P.ExpandDims().shard(((dp, 1),))
600
- self.expand2 = P.ExpandDims().shard(((dp, 1, 1),))
601
- self.add_scala = P.Add().shard(((), ()))
602
- self.init_loss = Tensor(0.0, mstype.float32)
603
-
604
- def construct(self, router_logits):
605
- router_logits_shape = self.shape(router_logits)
606
- router_logits = self.reshape(router_logits, (-1, router_logits_shape[-1]))
607
- logits_shape = self.shape(router_logits)
608
- tokens_per_group = logits_shape[0] // self.dp_group
609
- expert_capacity = calculate_expert_capacity(self.num_experts_chosen, tokens_per_group, self.capacity_factor,
610
- self.expert_dim)
611
- router_logits = self.reshape(router_logits, (self.dp_group, tokens_per_group, self.expert_dim))
612
-
613
- accum_expert_mask = 0
614
- accum_expert_gate = 0
615
- loss = self.init_loss
616
- mask_count = 0
617
- accum_combine_tensor = 0
618
- # Probabilities for each token of which expert it should be sent to
619
- router_prob = self.softmax(router_logits)
620
-
621
- for expert_chosen_index in range(self.num_experts_chosen):
622
- # for each token, set the router_prob of the selected experts to zero
623
- router_prob = self.mul4(router_prob, self.sub(self.on_value, accum_expert_mask))
624
- # shape is : (dp_group, tokens_per_group)
625
- expert_index, expert_gate = self.argmax(router_prob)
626
- # expert_mask's shape: (dp_group, tokens_per_group, self.expert_dim)
627
- expert_mask = self.onehot(expert_index, self.expert_dim, self.on_value, self.off_value)
628
- # renormalize the remaining probabilities to sum to 1
629
- router_prob_normal = self.div1(router_prob, self.add1(self.reduce_sum_keep(router_prob, -1), 1e-9))
630
-
631
- # the balance loss is computed at each routing step
632
- loss = self.add_scala(loss, self._auxiliary_loss(expert_mask, router_prob_normal))
633
-
634
- output = self._maskout_overflowed_tokens(expert_mask, expert_capacity, expert_gate,
635
- mask_count, expert_chosen_index)
636
- expert_mask, expert_gate, expert_mask_flat, position_in_expert = output[0], output[1], output[2], output[3]
637
- accum_expert_mask = self.add(accum_expert_mask, expert_mask)
638
- accum_expert_gate = self.add3(accum_expert_gate, expert_gate)
639
- mask_count = self.add(mask_count, self.reduce_sum_keep(expert_mask, 1))
640
-
641
- # combine_tensor's shape: (dp_group, tokens_per_group)
642
- combine_tensor = self.mul7(expert_gate, expert_mask_flat)
643
- # combine_tensor's shape: (dp_group, tokens_per_group, self.expert_dim)
644
- combine_tensor = self.mul8(self.expand(combine_tensor, -1),
645
- self.onehot2(expert_index, self.expert_dim, self.on_value, self.off_value))
646
- # combine_tensor's shape: (dp_group, tokens_per_group, self.expert_dim, self.expert_capacity)
647
- combine_tensor = self.mul9(self.expand2(combine_tensor, -1),
648
- self.onehot3(self.cast(position_in_expert, mstype.int32), expert_capacity,
649
- self.on_value, self.off_value))
650
- accum_combine_tensor = self.add2(accum_combine_tensor, combine_tensor)
651
-
652
- # expert weights normalization when k > 1
653
- if self.num_experts_chosen > 1:
654
- combine_tensor_sum = self.reduce_sum_keep2(self.reduce_sum_keep2(accum_combine_tensor, -1), -2)
655
- accum_combine_tensor = self.div2(accum_combine_tensor, self.add4(combine_tensor_sum, 1e-9))
656
- # dispatch_tensor is of boolean type. Here, NotEqual is used instead of Cast, because 'Cast to bool' has
657
- # bad performance
658
- dispatch_tensor = self.not_equal(accum_combine_tensor, 0.0)
659
- return dispatch_tensor, accum_combine_tensor, loss
660
-
661
- def _auxiliary_loss(self, expert_mask, router_prob):
662
- """
663
- Computing the load balance loss.
664
- """
665
- # density_1's shape: (dp_group, self.expert_dim)
666
- density_1 = self.reduce_mean(expert_mask, 1)
667
- # density_1_proxy's shape: (dp_group, self.expert_dim)
668
- density_1_proxy = self.reduce_mean2(router_prob, 1)
669
- loss = self.mul(density_1, density_1_proxy)
670
- loss = self.reduce_mean3(loss)
671
- loss = self.mul3(self.mul2(loss, self.expert_dim), self.expert_dim)
672
- return loss
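With this expression a perfectly balanced router scores exactly 1.0: each expert's token density and mean router probability are both 1/expert_dim, so their product averages to 1/expert_dim**2 and the two multiplications by expert_dim cancel it. A NumPy check with assumed sizes:

import numpy as np

dp_group, tokens, experts = 1, 16, 4
expert_mask = np.tile(np.eye(experts), (dp_group, tokens // experts, 1))  # each expert picked equally often
router_prob = np.full((dp_group, tokens, experts), 1.0 / experts)         # uniform routing probabilities

density_1 = expert_mask.mean(axis=1)
density_1_proxy = router_prob.mean(axis=1)
loss = (density_1 * density_1_proxy).mean() * experts * experts
assert abs(loss - 1.0) < 1e-9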
673
-
674
- def _maskout_overflowed_tokens(self, expert_mask, expert_capacity, expert_gate, last_num, expert_chosen_index):
675
- """
676
- Keeping only the tokens that fit within expert_capacity.
677
- """
678
- cumsum = self.cumsum(expert_mask, 1)
679
- if expert_chosen_index > 0:
680
- cumsum = self.add(cumsum, last_num)
681
- # position_in_expert's shape: (dp_group, tokens_per_group, self.expert_dim)
682
- position_in_expert = self.mul4(cumsum, expert_mask)
683
- less_result = self.less(position_in_expert, expert_capacity)
684
- # expert_mask's shape: (dp_group, tokens_per_group, self.expert_dim)
685
- expert_mask = self.mul5(less_result, expert_mask)
686
- # expert_mask_flat's shape: (dp_group, tokens_per_group)
687
- expert_mask_flat = self.reduce_sum(expert_mask, -1)
688
-
689
- # Mask out the experts that have overflowed the expert_capacity.
690
- # expert_gate's shape: (dp_group, tokens_per_group)
691
- expert_gate = self.mul6(expert_gate, expert_mask_flat)
692
- output = (expert_mask, expert_gate, expert_mask_flat, position_in_expert)
693
- return output
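The capacity masking above is an exclusive cumulative count of how many earlier tokens already chose the same expert; anything at or beyond expert_capacity is dropped. A NumPy sketch with assumed numbers (five tokens all routed to expert 0, capacity 3):

import numpy as np

expert_mask = np.array([[[1, 0], [1, 0], [1, 0], [1, 0], [1, 0]]], dtype=float)
expert_capacity = 3

exclusive_cumsum = np.cumsum(expert_mask, axis=1) - expert_mask   # emulates P.CumSum(exclusive=True)
position_in_expert = exclusive_cumsum * expert_mask
kept_mask = (position_in_expert < expert_capacity) * expert_mask  # first 3 tokens kept, last 2 dropped
assert kept_mask[0, :, 0].tolist() == [1.0, 1.0, 1.0, 0.0, 0.0]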