mindspore-2.7.0rc1-cp310-cp310-win_amd64.whl → mindspore-2.7.1-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (370)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +5 -2
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +2 -2
  7. mindspore/_extends/builtin_operations.py +3 -3
  8. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  9. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  10. mindspore/_extends/parse/__init__.py +3 -3
  11. mindspore/_extends/parse/compile_config.py +24 -1
  12. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
  13. mindspore/_extends/parse/parser.py +28 -22
  14. mindspore/_extends/parse/resources.py +1 -1
  15. mindspore/_extends/parse/standard_method.py +23 -2
  16. mindspore/_extends/parse/trope.py +2 -1
  17. mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
  18. mindspore/amp.py +0 -18
  19. mindspore/avcodec-59.dll +0 -0
  20. mindspore/avdevice-59.dll +0 -0
  21. mindspore/avfilter-8.dll +0 -0
  22. mindspore/avformat-59.dll +0 -0
  23. mindspore/avutil-57.dll +0 -0
  24. mindspore/boost/base.py +29 -2
  25. mindspore/common/__init__.py +18 -12
  26. mindspore/common/_decorator.py +3 -2
  27. mindspore/common/_grad_function.py +3 -1
  28. mindspore/common/_tensor_cpp_method.py +1 -1
  29. mindspore/common/_tensor_docs.py +371 -96
  30. mindspore/common/_utils.py +7 -43
  31. mindspore/common/api.py +434 -135
  32. mindspore/common/dtype.py +98 -57
  33. mindspore/common/dump.py +7 -108
  34. mindspore/common/dynamic_shape/__init__.py +0 -0
  35. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
  36. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  37. mindspore/common/file_system.py +59 -9
  38. mindspore/common/hook_handle.py +82 -3
  39. mindspore/common/jit_config.py +5 -1
  40. mindspore/common/jit_trace.py +27 -12
  41. mindspore/common/lazy_inline.py +5 -3
  42. mindspore/common/np_dtype.py +3 -3
  43. mindspore/common/parameter.py +17 -127
  44. mindspore/common/recompute.py +4 -13
  45. mindspore/common/tensor.py +50 -217
  46. mindspore/communication/_comm_helper.py +11 -1
  47. mindspore/communication/comm_func.py +138 -4
  48. mindspore/communication/management.py +85 -1
  49. mindspore/config/op_info.config +0 -15
  50. mindspore/context.py +20 -106
  51. mindspore/dataset/__init__.py +1 -1
  52. mindspore/dataset/audio/transforms.py +1 -1
  53. mindspore/dataset/core/config.py +35 -1
  54. mindspore/dataset/engine/datasets.py +338 -319
  55. mindspore/dataset/engine/datasets_user_defined.py +38 -22
  56. mindspore/dataset/engine/datasets_vision.py +1 -1
  57. mindspore/dataset/engine/validators.py +1 -15
  58. mindspore/dataset/transforms/c_transforms.py +2 -2
  59. mindspore/dataset/transforms/transforms.py +3 -3
  60. mindspore/dataset/vision/__init__.py +1 -1
  61. mindspore/dataset/vision/py_transforms.py +8 -8
  62. mindspore/dataset/vision/transforms.py +17 -5
  63. mindspore/dataset/vision/utils.py +632 -21
  64. mindspore/device_context/ascend/op_tuning.py +35 -1
  65. mindspore/dnnl.dll +0 -0
  66. mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
  67. mindspore/graph/custom_pass.py +55 -0
  68. mindspore/include/api/cell.h +28 -4
  69. mindspore/include/api/cfg.h +24 -7
  70. mindspore/include/api/context.h +1 -0
  71. mindspore/include/api/delegate.h +0 -2
  72. mindspore/include/api/dual_abi_helper.h +100 -19
  73. mindspore/include/api/graph.h +14 -1
  74. mindspore/include/api/kernel.h +16 -3
  75. mindspore/include/api/kernel_api.h +9 -1
  76. mindspore/include/api/metrics/accuracy.h +9 -0
  77. mindspore/include/api/model.h +5 -1
  78. mindspore/include/api/model_group.h +4 -0
  79. mindspore/include/api/model_parallel_runner.h +2 -0
  80. mindspore/include/api/status.h +48 -10
  81. mindspore/include/api/types.h +6 -1
  82. mindspore/include/dataset/constants.h +9 -0
  83. mindspore/include/dataset/execute.h +2 -2
  84. mindspore/jpeg62.dll +0 -0
  85. mindspore/mindrecord/__init__.py +3 -3
  86. mindspore/mindrecord/common/exceptions.py +1 -0
  87. mindspore/mindrecord/config.py +1 -1
  88. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  89. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  90. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  91. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  92. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  93. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  94. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  95. mindspore/mindrecord/filereader.py +4 -4
  96. mindspore/mindrecord/filewriter.py +5 -5
  97. mindspore/mindrecord/mindpage.py +2 -2
  98. mindspore/mindrecord/tools/cifar10.py +4 -3
  99. mindspore/mindrecord/tools/cifar100.py +1 -1
  100. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  101. mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
  102. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  103. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  104. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  105. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  106. mindspore/mindspore_backend_common.dll +0 -0
  107. mindspore/mindspore_backend_manager.dll +0 -0
  108. mindspore/mindspore_cluster.dll +0 -0
  109. mindspore/mindspore_common.dll +0 -0
  110. mindspore/mindspore_core.dll +0 -0
  111. mindspore/mindspore_cpu.dll +0 -0
  112. mindspore/mindspore_dump.dll +0 -0
  113. mindspore/mindspore_frontend.dll +0 -0
  114. mindspore/mindspore_glog.dll +0 -0
  115. mindspore/mindspore_hardware_abstract.dll +0 -0
  116. mindspore/mindspore_memory_pool.dll +0 -0
  117. mindspore/mindspore_ms_backend.dll +0 -0
  118. mindspore/mindspore_ops.dll +0 -0
  119. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  120. mindspore/mindspore_profiler.dll +0 -0
  121. mindspore/mindspore_pyboost.dll +0 -0
  122. mindspore/mindspore_pynative.dll +0 -0
  123. mindspore/mindspore_runtime_pipeline.dll +0 -0
  124. mindspore/mindspore_runtime_utils.dll +0 -0
  125. mindspore/mindspore_tools.dll +0 -0
  126. mindspore/mint/__init__.py +15 -10
  127. mindspore/mint/distributed/__init__.py +4 -0
  128. mindspore/mint/distributed/distributed.py +392 -69
  129. mindspore/mint/nn/__init__.py +2 -16
  130. mindspore/mint/nn/functional.py +4 -110
  131. mindspore/mint/nn/layer/__init__.py +0 -2
  132. mindspore/mint/nn/layer/_functions.py +1 -2
  133. mindspore/mint/nn/layer/activation.py +0 -6
  134. mindspore/mint/nn/layer/basic.py +0 -47
  135. mindspore/mint/nn/layer/conv.py +10 -10
  136. mindspore/mint/nn/layer/normalization.py +11 -16
  137. mindspore/mint/nn/layer/pooling.py +0 -4
  138. mindspore/nn/__init__.py +1 -3
  139. mindspore/nn/cell.py +231 -239
  140. mindspore/nn/layer/activation.py +4 -2
  141. mindspore/nn/layer/basic.py +56 -14
  142. mindspore/nn/layer/container.py +16 -0
  143. mindspore/nn/layer/embedding.py +4 -169
  144. mindspore/nn/layer/image.py +1 -1
  145. mindspore/nn/layer/normalization.py +2 -1
  146. mindspore/nn/layer/thor_layer.py +4 -85
  147. mindspore/nn/optim/ada_grad.py +0 -1
  148. mindspore/nn/optim/adafactor.py +0 -1
  149. mindspore/nn/optim/adam.py +32 -127
  150. mindspore/nn/optim/adamax.py +0 -1
  151. mindspore/nn/optim/asgd.py +0 -1
  152. mindspore/nn/optim/ftrl.py +8 -102
  153. mindspore/nn/optim/lamb.py +1 -4
  154. mindspore/nn/optim/lars.py +0 -3
  155. mindspore/nn/optim/lazyadam.py +25 -218
  156. mindspore/nn/optim/momentum.py +5 -43
  157. mindspore/nn/optim/optimizer.py +6 -55
  158. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  159. mindspore/nn/optim/rmsprop.py +0 -1
  160. mindspore/nn/optim/rprop.py +0 -1
  161. mindspore/nn/optim/sgd.py +0 -1
  162. mindspore/nn/optim/tft_wrapper.py +2 -4
  163. mindspore/nn/optim/thor.py +0 -2
  164. mindspore/nn/probability/bijector/bijector.py +7 -8
  165. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  166. mindspore/nn/probability/bijector/power_transform.py +20 -21
  167. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  168. mindspore/nn/probability/bijector/softplus.py +13 -14
  169. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  170. mindspore/nn/wrap/cell_wrapper.py +39 -5
  171. mindspore/nn/wrap/grad_reducer.py +4 -89
  172. mindspore/numpy/array_creations.py +4 -4
  173. mindspore/numpy/fft.py +9 -9
  174. mindspore/numpy/utils_const.py +1 -1
  175. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  176. mindspore/onnx/onnx_export.py +137 -0
  177. mindspore/opencv_core4110.dll +0 -0
  178. mindspore/opencv_imgcodecs4110.dll +0 -0
  179. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  180. mindspore/ops/__init__.py +2 -0
  181. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  182. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  183. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  184. mindspore/ops/_op_impl/cpu/__init__.py +1 -5
  185. mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
  186. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
  187. mindspore/ops/auto_generate/gen_extend_func.py +6 -11
  188. mindspore/ops/auto_generate/gen_ops_def.py +385 -154
  189. mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
  190. mindspore/ops/communication.py +97 -0
  191. mindspore/ops/composite/__init__.py +5 -2
  192. mindspore/ops/composite/base.py +16 -2
  193. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  194. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  195. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  196. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  197. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  198. mindspore/ops/function/__init__.py +2 -0
  199. mindspore/ops/function/array_func.py +24 -18
  200. mindspore/ops/function/comm_func.py +3883 -0
  201. mindspore/ops/function/debug_func.py +7 -6
  202. mindspore/ops/function/grad/grad_func.py +4 -12
  203. mindspore/ops/function/math_func.py +89 -86
  204. mindspore/ops/function/nn_func.py +92 -313
  205. mindspore/ops/function/random_func.py +9 -18
  206. mindspore/ops/functional.py +4 -1
  207. mindspore/ops/functional_overload.py +377 -30
  208. mindspore/ops/operations/__init__.py +2 -5
  209. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  210. mindspore/ops/operations/_inner_ops.py +12 -50
  211. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  212. mindspore/ops/operations/array_ops.py +5 -50
  213. mindspore/ops/operations/comm_ops.py +95 -17
  214. mindspore/ops/operations/custom_ops.py +237 -22
  215. mindspore/ops/operations/debug_ops.py +33 -35
  216. mindspore/ops/operations/manually_defined/ops_def.py +39 -318
  217. mindspore/ops/operations/math_ops.py +5 -5
  218. mindspore/ops/operations/nn_ops.py +3 -3
  219. mindspore/ops/operations/sparse_ops.py +0 -83
  220. mindspore/ops/primitive.py +4 -27
  221. mindspore/ops/tensor_method.py +88 -10
  222. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  223. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  224. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  225. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  226. mindspore/ops_generate/common/gen_constants.py +11 -10
  227. mindspore/ops_generate/common/op_proto.py +18 -1
  228. mindspore/ops_generate/common/template.py +102 -245
  229. mindspore/ops_generate/common/template_utils.py +212 -0
  230. mindspore/ops_generate/gen_custom_ops.py +69 -0
  231. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  232. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  233. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  234. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  235. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  236. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  237. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  238. mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
  239. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  240. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  241. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  242. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  243. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  244. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  245. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  246. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  247. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  248. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  249. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  250. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  251. mindspore/parallel/_auto_parallel_context.py +5 -15
  252. mindspore/parallel/_cell_wrapper.py +1 -1
  253. mindspore/parallel/_parallel_serialization.py +4 -6
  254. mindspore/parallel/_ps_context.py +2 -2
  255. mindspore/parallel/_utils.py +34 -17
  256. mindspore/parallel/auto_parallel.py +23 -9
  257. mindspore/parallel/checkpoint_transform.py +20 -2
  258. mindspore/parallel/cluster/process_entity/_api.py +28 -33
  259. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  260. mindspore/parallel/cluster/run.py +5 -3
  261. mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
  262. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  263. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  264. mindspore/parallel/function/reshard_func.py +6 -5
  265. mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
  266. mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
  267. mindspore/parallel/shard.py +7 -21
  268. mindspore/parallel/strategy.py +336 -0
  269. mindspore/parallel/transform_safetensors.py +127 -20
  270. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
  271. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
  272. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  273. mindspore/profiler/common/constant.py +5 -0
  274. mindspore/profiler/common/file_manager.py +9 -0
  275. mindspore/profiler/common/msprof_cmd_tool.py +40 -4
  276. mindspore/profiler/common/path_manager.py +65 -24
  277. mindspore/profiler/common/profiler_context.py +27 -14
  278. mindspore/profiler/common/profiler_info.py +3 -3
  279. mindspore/profiler/common/profiler_meta_data.py +1 -0
  280. mindspore/profiler/common/profiler_op_analyse.py +10 -6
  281. mindspore/profiler/common/profiler_path_manager.py +13 -0
  282. mindspore/profiler/common/util.py +30 -3
  283. mindspore/profiler/dynamic_profiler.py +91 -46
  284. mindspore/profiler/envprofiler.py +30 -5
  285. mindspore/profiler/experimental_config.py +18 -2
  286. mindspore/profiler/platform/cpu_profiler.py +10 -4
  287. mindspore/profiler/platform/npu_profiler.py +34 -7
  288. mindspore/profiler/profiler.py +193 -145
  289. mindspore/profiler/profiler_action_controller.py +1 -1
  290. mindspore/profiler/profiler_interface.py +2 -2
  291. mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
  292. mindspore/run_check/_check_version.py +108 -24
  293. mindspore/runtime/__init__.py +9 -6
  294. mindspore/runtime/executor.py +35 -0
  295. mindspore/runtime/memory.py +113 -0
  296. mindspore/runtime/thread_bind_core.py +1 -1
  297. mindspore/swresample-4.dll +0 -0
  298. mindspore/swscale-6.dll +0 -0
  299. mindspore/tinyxml2.dll +0 -0
  300. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  301. mindspore/tools/data_dump.py +130 -0
  302. mindspore/tools/sdc_detect.py +91 -0
  303. mindspore/tools/stress_detect.py +63 -0
  304. mindspore/train/__init__.py +6 -6
  305. mindspore/train/_utils.py +8 -21
  306. mindspore/train/amp.py +6 -7
  307. mindspore/train/callback/_callback.py +2 -1
  308. mindspore/train/callback/_checkpoint.py +1 -17
  309. mindspore/train/callback/_flops_collector.py +10 -6
  310. mindspore/train/callback/_train_fault_tolerance.py +72 -25
  311. mindspore/train/data_sink.py +5 -9
  312. mindspore/train/dataset_helper.py +5 -5
  313. mindspore/train/model.py +41 -230
  314. mindspore/train/serialization.py +160 -401
  315. mindspore/train/train_thor/model_thor.py +2 -2
  316. mindspore/turbojpeg.dll +0 -0
  317. mindspore/utils/__init__.py +6 -3
  318. mindspore/utils/dlpack.py +92 -0
  319. mindspore/utils/dryrun.py +1 -1
  320. mindspore/utils/runtime_execution_order_check.py +10 -0
  321. mindspore/utils/sdc_detect.py +14 -12
  322. mindspore/utils/stress_detect.py +43 -0
  323. mindspore/utils/utils.py +152 -16
  324. mindspore/version.py +1 -1
  325. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  326. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
  327. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  328. mindspore/communication/_hccl_management.py +0 -297
  329. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
  330. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  331. mindspore/experimental/llm_boost/atb/__init__.py +0 -23
  332. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  333. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  334. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  335. mindspore/experimental/llm_boost/register.py +0 -130
  336. mindspore/experimental/llm_boost/utils.py +0 -31
  337. mindspore/include/OWNERS +0 -7
  338. mindspore/mindspore_cpu_res_manager.dll +0 -0
  339. mindspore/mindspore_ops_kernel_common.dll +0 -0
  340. mindspore/mindspore_res_manager.dll +0 -0
  341. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  342. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  343. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  344. mindspore/nn/reinforcement/tensor_array.py +0 -145
  345. mindspore/opencv_core452.dll +0 -0
  346. mindspore/opencv_imgcodecs452.dll +0 -0
  347. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  348. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  349. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  350. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  351. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  352. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  353. mindspore/ops/operations/_tensor_array.py +0 -359
  354. mindspore/ops/operations/rl_ops.py +0 -288
  355. mindspore/parallel/_offload_context.py +0 -275
  356. mindspore/parallel/_recovery_context.py +0 -115
  357. mindspore/parallel/_transformer/__init__.py +0 -35
  358. mindspore/parallel/_transformer/layers.py +0 -765
  359. mindspore/parallel/_transformer/loss.py +0 -251
  360. mindspore/parallel/_transformer/moe.py +0 -693
  361. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  362. mindspore/parallel/_transformer/transformer.py +0 -3124
  363. mindspore/parallel/mpi/_mpi_config.py +0 -116
  364. mindspore/profiler/common/validator/validate_path.py +0 -84
  365. mindspore/train/memory_profiling_pb2.py +0 -298
  366. mindspore/utils/hooks.py +0 -81
  367. /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  368. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  369. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  370. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -1,3124 +0,0 @@
1
- # Copyright 2021-2023 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """
16
- Note:
17
- Transformer Networks. This is interface that is subject to change or deletion.
18
- """
19
- from __future__ import absolute_import
20
-
21
- import math
22
- import numpy as np
23
-
24
- from mindspore.common.tensor import Tensor
25
- from mindspore.common.parameter import Parameter
26
- from mindspore.common.initializer import initializer
27
- from mindspore import nn
28
- from mindspore import context
29
- import mindspore.common.dtype as mstype
30
- from mindspore.ops import operations as P
31
- from mindspore.ops import functional as F
32
- from mindspore.nn.cell import Cell
33
- from mindspore import _checkparam as Validator
34
- from mindspore import log as logger
35
- from mindspore.parallel._utils import _get_parallel_mode
36
- from mindspore.context import ParallelMode
37
- from mindspore.log import _LogActionOnce
38
- from mindspore.parallel._transformer.layers import _LayerNorm, _Linear, \
39
- _args_type_validator_check, _valid_type_checks, _valid_value_checks, \
40
- _check_past_none_input_none, _check_input_dtype
41
- from mindspore.parallel._transformer.op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, \
42
- _Config, _check_config, MoEParallelConfig
43
- from mindspore.parallel._transformer.moe import default_moe_config, MoE, _check_moe_config
44
-
45
- __all__ = [
46
- "AttentionMask",
47
- "VocabEmbedding",
48
- "MultiHeadAttention",
49
- "FeedForward",
50
- "TransformerEncoder",
51
- "TransformerDecoder",
52
- "TransformerEncoderLayer",
53
- "TransformerDecoderLayer",
54
- "Transformer",
55
- "TransformerOpParallelConfig",
56
- "EmbeddingOpParallelConfig",
57
- "TransformerRecomputeConfig"]
58
-
59
-
60
- class EmbeddingOpParallelConfig(_Config):
61
- r"""
62
- The parallel config of :class:`VocabEmbedding`
63
- for the setting data parallel or model parallel for the embedding table.
64
-
65
- Args:
66
- data_parallel(int): The data parallel way. The input data will be sliced into n parts for embedding layer
67
- according to this value. Default: 1.
68
- model_parallel(int): The model parallel way. The embedding table parameters
69
- will be sliced at 0-th axis according to the model parallel way. Default: 1.
70
- vocab_emb_dp(bool): Shard embedding in model parallel or data parallel. If True, the embedding lookup
71
- will be a data parallel style training and model_parallel value will be ignored. If false, the
72
- embedding table will be sharded into n parts at the 0-th dimension row slice of the embedding table,
73
- where the n is the model parallel way determined by this parameter. Default: ``True``
74
-
75
- Supported Platforms:
76
- ``Ascend`` ``GPU``
77
-
78
- Examples:
79
- >>> from mindspore.nn.transformer import EmbeddingOpParallelConfig
80
- >>> config=EmbeddingOpParallelConfig(data_parallel=1, model_parallel=1, vocab_emb_dp=True)
81
- """
82
-
83
- def __init__(self, data_parallel=1, model_parallel=1, vocab_emb_dp=True):
84
- self._dp_mp_config = OpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel)
85
- Validator.check_bool(vocab_emb_dp, "vocab_emb_dp")
86
- self.vocab_emb_dp = vocab_emb_dp
87
-
88
- @property
89
- def data_parallel(self):
90
- return self._dp_mp_config.data_parallel
91
-
92
- @data_parallel.setter
93
- def data_parallel(self, value):
94
- self._dp_mp_config.data_parallel = value
95
-
96
- @property
97
- def model_parallel(self):
98
- return self._dp_mp_config.model_parallel
99
-
100
- @model_parallel.setter
101
- def model_parallel(self, value):
102
- self._dp_mp_config.model_parallel = value
103
-
104
- @property
105
- def vocab_emb_dp(self):
106
- return self._vocab_emb_dp
107
-
108
- @vocab_emb_dp.setter
109
- def vocab_emb_dp(self, value):
110
- Validator.check_bool(value, "vocab_emb_dp")
111
- self._vocab_emb_dp = value
112
-
113
- @property
114
- def dp_mp_config(self):
115
- return self._dp_mp_config
116
-
117
-
118
- class TransformerRecomputeConfig(_Config):
119
- r"""
120
- TransformerRecomputeConfig for the setting recompute attributes for encoder/decoder layers.
121
-
122
- Args:
123
- recompute (bool): Enable recomputation of the transformer block or not. Default: ``False``.
124
- parallel_optimizer_comm_recompute (bool): Specifies whether the communication operator allgathers
125
- introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
126
- Default: ``False``.
127
- mp_comm_recompute (bool): Specifies whether the model parallel communication operators
128
- in the cell are recomputed in auto parallel or semi auto parallel mode. Default: ``True``.
129
- recompute_slice_activation (bool): Slice the cell output which would remains in memory. Default: ``False``.
130
-
131
- Supported Platforms:
132
- ``Ascend`` ``GPU``
133
-
134
- Examples:
135
- >>> from mindspore.nn.transformer import TransformerRecomputeConfig
136
- >>> config=TransformerRecomputeConfig(recompute=True, parallel_optimizer_comm_recompute=True, \
137
- ... mp_comm_recompute=True, recompute_slice_activation=True)
138
- """
139
-
140
- def __init__(self, recompute=False, parallel_optimizer_comm_recompute=False,
141
- mp_comm_recompute=True, recompute_slice_activation=False):
142
- Validator.check_bool(recompute, "recompute")
143
- Validator.check_bool(parallel_optimizer_comm_recompute, "parallel_optimizer_comm_recompute")
144
- Validator.check_bool(mp_comm_recompute, "mp_comm_recompute")
145
- Validator.check_bool(recompute_slice_activation, "recompute_slice_activation")
146
- self._recompute = recompute
147
- self._parallel_optimizer_comm_recompute = parallel_optimizer_comm_recompute
148
- self._mp_comm_recompute = mp_comm_recompute
149
- self._recompute_slice_activation = recompute_slice_activation
150
-
151
- @property
152
- def recompute(self):
153
- return self._recompute
154
-
155
- @recompute.setter
156
- def recompute(self, value):
157
- Validator.check_bool(value, "recompute")
158
- self._recompute = value
159
-
160
- @property
161
- def parallel_optimizer_comm_recompute(self):
162
- return self._parallel_optimizer_comm_recompute
163
-
164
- @parallel_optimizer_comm_recompute.setter
165
- def parallel_optimizer_comm_recompute(self, value):
166
- Validator.check_bool(value, "parallel_optimizer_comm_recompute")
167
- self._parallel_optimizer_comm_recompute = value
168
-
169
- @property
170
- def mp_comm_recompute(self):
171
- return self._mp_comm_recompute
172
-
173
- @mp_comm_recompute.setter
174
- def mp_comm_recompute(self, value):
175
- Validator.check_bool(value, "mp_comm_recompute")
176
- self._mp_comm_recompute = value
177
-
178
- @property
179
- def recompute_slice_activation(self):
180
- return self._recompute_slice_activation
181
-
182
- @recompute_slice_activation.setter
183
- def recompute_slice_activation(self, value):
184
- Validator.check_bool(value, "recompute_slice_activation")
185
- self._recompute_slice_activation = value
186
-
187
-
188
- default_transformer_recompute_config = TransformerRecomputeConfig()
189
-
190
-
191
- class TransformerOpParallelConfig(_Config):
192
- r"""
193
- TransformerOpParallelConfig for setting parallel configuration, such as the data parallel and model parallel.
194
-
195
- Note:
196
- Except the recompute argument, other arguments will **not** be effective when the user doesn't set
197
- auto_parallel_context to `SEMI_AUTO_PARALLEL` or `AUTO_PARALLEL`.
198
- The micro_batch_num must be greater than or equal to pipeline_stage when training.
199
- The data_parallel\*model_parallel \*pipeline_stage must be equal or less equal to the device. When setting
200
- the pipeline stage and optimizer_shard, the config will overwrite the auto_parallel_context. When given the
201
- 8 devices and the data_parallel is 1 and model_parallel is 1, the calculation will be repeated on each
202
- device.
203
-
204
- Args:
205
- data_parallel (int): The data parallel way. The input data will be sliced into n parts for each layer
206
- according to the data parallel way. Default: 1.
207
- model_parallel (int): The model parallel way. The parameters of dense layers in MultiheadAttention and
208
- FeedForward layer will be sliced according to the model parallel way. Default: 1.
209
- expert_parallel (int): The expert parallel way. This is effective only when MoE (Mixture of Experts)
210
- is applied. This value specifies the number of partitions to split the experts into.
211
- pipeline_stage (int): The number of the pipeline stage. Should be a positive value. Default: 1.
212
- micro_batch_num (int): The micro size of the batches for the pipeline training. Default: 1.
213
- optimizer_shard (bool): Whether to enable optimizer shard. Default False.
214
- gradient_aggregation_group (int): The fusion group size of the optimizer state sharding. Default: 4.
215
- recompute (Union[TransformerRecomputeConfig, bool]): The configuration of recomputation for
216
- the transformer block. Default: An instance of TransformerRecomputeConfig with default values.
217
- vocab_emb_dp (bool): Shard embedding in model parallel or data parallel. Default: ``True``.
218
-
219
- Supported Platforms:
220
- ``Ascend`` ``GPU``
221
-
222
- Examples:
223
- >>> from mindspore.nn.transformer import TransformerRecomputeConfig
224
- >>> recompute_config=TransformerRecomputeConfig(recompute=True, parallel_optimizer_comm_recompute=True, \
225
- ... mp_comm_recompute=True, recompute_slice_activation=True)
226
- >>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1, recompute=recompute_config)
227
- """
228
-
229
- def __init__(self, data_parallel=1, model_parallel=1, expert_parallel=1, pipeline_stage=1, pipeline_segment=1,
230
- micro_batch_num=1,
231
- recompute=default_transformer_recompute_config,
232
- optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True):
233
- self.recompute = recompute
234
- self.optimizer_shard = optimizer_shard
235
- self.gradient_aggregation_group = gradient_aggregation_group
236
- self._embed_dp_mp_config = EmbeddingOpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
237
- vocab_emb_dp=vocab_emb_dp)
238
- self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num,
239
- pipeline_segment=pipeline_segment)
240
- self._moe_config = MoEParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
241
- expert_parallel=expert_parallel)
242
-
243
- @property
244
- def recompute(self):
245
- return self._recompute
246
-
247
- @recompute.setter
248
- def recompute(self, value):
249
- if not isinstance(value, TransformerRecomputeConfig) and not isinstance(value, bool):
250
- raise TypeError(f"recompute must be a TransformerRecomputeConfig/bool, but got {type(value).__name__}.")
251
- if isinstance(value, bool):
252
- logger.warning(f"TransformerRecomputeConfig is recommended as the recompute configuration type.")
253
- self._recompute = value
254
-
255
- @property
256
- def vocab_emb_dp(self):
257
- return self._embed_dp_mp_config.vocab_emb_dp
258
-
259
- @vocab_emb_dp.setter
260
- def vocab_emb_dp(self, value):
261
- self._embed_dp_mp_config.vocab_emb_dp = value
262
-
263
- @property
264
- def gradient_aggregation_group(self):
265
- return self._gradient_aggregation_group
266
-
267
- @gradient_aggregation_group.setter
268
- def gradient_aggregation_group(self, value):
269
- Validator.check_positive_int(value, "gradient_aggregation_group")
270
- self._gradient_aggregation_group = value
271
-
272
- @property
273
- def micro_batch_num(self):
274
- return self._pp_config.micro_batch_num
275
-
276
- @micro_batch_num.setter
277
- def micro_batch_num(self, value):
278
- self._pp_config.micro_batch_num = value
279
-
280
- @property
281
- def model_parallel(self):
282
- return self._embed_dp_mp_config.model_parallel
283
-
284
- @model_parallel.setter
285
- def model_parallel(self, value):
286
- self._embed_dp_mp_config.model_parallel = value
287
- self._moe_config.model_parallel = value
288
-
289
- @property
290
- def data_parallel(self):
291
- return self._embed_dp_mp_config.data_parallel
292
-
293
- @data_parallel.setter
294
- def data_parallel(self, value):
295
- self._embed_dp_mp_config.data_parallel = value
296
- self._moe_config.data_parallel = value
297
-
298
- @property
299
- def expert_parallel(self):
300
- return self._moe_config.expert_parallel
301
-
302
- @expert_parallel.setter
303
- def expert_parallel(self, value):
304
- self._moe_config.expert_parallel = value
305
-
306
- @property
307
- def pipeline_stage(self):
308
- return self._pp_config.pipeline_stage
309
-
310
- @pipeline_stage.setter
311
- def pipeline_stage(self, value):
312
- self._pp_config.pipeline_stage = value
313
-
314
- @property
315
- def pipeline_segment(self):
316
- return self._pp_config.pipeline_segment
317
-
318
- @pipeline_segment.setter
319
- def pipeline_segment(self, value):
320
- self._pp_config.pipeline_segment = value
321
-
322
- @property
323
- def optimizer_shard(self):
324
- return self._optimizer_shard
325
-
326
- @optimizer_shard.setter
327
- def optimizer_shard(self, value):
328
- Validator.check_bool(value, "optimizer_shard")
329
- self._optimizer_shard = value
330
- context.set_auto_parallel_context(enable_parallel_optimizer=value)
331
-
332
- @property
333
- def embedding_dp_mp_config(self):
334
- return self._embed_dp_mp_config
335
-
336
- @property
337
- def dp_mp_config(self):
338
- return self._embed_dp_mp_config.dp_mp_config
339
-
340
- @property
341
- def moe_parallel_config(self):
342
- return self._moe_config
343
-
344
-
345
- default_transformer_config = TransformerOpParallelConfig()
346
- default_embedding_parallel_config = EmbeddingOpParallelConfig()
347
-
348
-
349
- class FeedForward(Cell):
350
- r"""
351
- The multilayer perceptron with two linear layers with dropout applied at final output. The first linear
352
- will project the input dimension from hidden_size to ffn_hidden_size. The second linear will project the
353
- dimension from ffn_hidden_size to hidden_size. The first linear is sharded on the relative dimension,
354
- and the second linear is sharded on the output dimension. The overview process can be:
355
-
356
- .. math::
357
- Dropout((xW_1+b_1)W_2 + b_2)
358
-
359
- where the :math:`W_1, W_2, b_1` and :math:`b_2` are trainable parameters.
360
-
361
- Args:
362
- hidden_size (int): The dimension of the inputs.
363
- ffn_hidden_size (int): The intermediate hidden size.
364
- dropout_rate (float): The dropout rate for the second linear's output.
365
- hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
366
- 'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
367
- 'hsigmoid', 'logsigmoid' and so on. User can provide custom activition to the argument.
368
- If user wants to run the net in the parallel mode, the custom activation must also provide
369
- the `activation_shard` function. Please see examples. Default: gelu.
370
- expert_num (int): The number of experts used in Linear. For the case expert_num > 1, BatchMatMul is used
371
- and the first dimension in BatchMatMul indicate expert_num. Default: 1.
372
- expert_group_size (int): The number of tokens in each data parallel group. Default: ``None``.
373
- This parameter is effective only when in AUTO_PARALLEL mode, and NOT SHARDING_PROPAGATION.
374
- param_init_type (dtype.Number): The parameter initialization type. Should be mstype.float32 or
375
- mstype.float16. Default: mstype.float32.
376
- parallel_config (OpParallelConfig, MoEParallelConfig): The config of parallel setting, see
377
- `OpParallelConfig` or `MoEParallelConfig`. When MoE is applied, MoEParallelConfig is effective,
378
- otherwise OpParallelConfig is effective. Default `default_dpmp_config`,
379
- an instance of `OpParallelConfig` with default args.
380
-
381
- Inputs:
382
- - **x** (Tensor) - should be `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`.
383
- Float tensor.
384
-
385
- Outputs:
386
- Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size] or
387
- [batch * seq_length, hidden_size]`.
388
-
389
- Raises:
390
- TypeError: `hidden_act` is not a string or nn.Cell.
391
- TypeError: `parallel_config` is not a subclass of OpParallelConfig.
392
- ValueError: `ffn_hidden_size` is not a multiple of the model parallel way.
393
- ValueError: `hidden_size` is not a multiple of the model parallel way.
394
-
395
- Supported Platforms:
396
- ``Ascend`` ``GPU``
397
-
398
- Examples:
399
- >>> import numpy as np
400
- >>> from mindspore.nn.transformer import FeedForward
401
- >>> from mindspore import dtype as mstype
402
- >>> from mindspore import Tensor, nn
403
- >>> from mindspore import ops
404
- >>> model = FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1)
405
- >>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
406
- >>> output = model(tensor)
407
- >>> print(output.shape)
408
- (2, 20, 15)
409
- >>> # Example 2 using custom hidden activation
410
- >>> class MyActivationNoShard(nn.Cell):
411
- ... def __init__(self):
412
- ... super(MyActivationNoShard, self).__init__()
413
- ... self.add = ops.Add()
414
- ... def construct(self, x):
415
- ... return self.add(x, 0.1)
416
- >>> model = FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1,
417
- ... hidden_act=MyActivationNoShard)
418
- >>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
419
- >>> output = model(tensor)
420
- >>> print(output.shape)
421
- (2, 20, 15)
422
- >>> # Example 3 using custom hidden activation with activation_shard
423
- >>> # If user wantss to run on the SEMI/AUTO parallel mode, the custom activation must provide
424
- >>> # a class function named activation_shard. It accepts the argument parallel_config (OpParallelConfig,
425
- >>> # MoEParallelConfig) and set the shard for the primitives used in the construct.
426
- >>> class MyActivationWithShard(nn.Cell):
427
- ... def __init__(self):
428
- ... super(MyActivationWithShard, self).__init__()
429
- ... self.add = ops.Add()
430
- ... def construct(self, x):
431
- ... return self.add(x, 0.1)
432
- ... def activation_shard(self, parallel_config):
433
- ... self.add.shard(((parallel_config.data_parallel, parallel_config.model_parallel), ()))
434
- >>>
435
- >>> model = FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1,
436
- ... hidden_act=MyActivationWithShard)
437
- >>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
438
- >>> output = model(tensor)
439
- >>> print(output.shape)
440
- (2, 20, 15)
441
- """
442
-
443
- @_LogActionOnce(logger=logger, key='FeedForward',
444
- no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
445
- @_args_type_validator_check(hidden_size=Validator.check_positive_int,
446
- ffn_hidden_size=Validator.check_positive_int,
447
- dropout_rate=Validator.check_non_negative_float,
448
- param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
449
- "FeedForward"),
450
- parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig],
451
- "FeedForward"))
452
- def __init__(self, hidden_size,
453
- ffn_hidden_size,
454
- dropout_rate,
455
- hidden_act='gelu',
456
- expert_num=1,
457
- expert_group_size=None,
458
- param_init_type=mstype.float32,
459
- parallel_config=default_dpmp_config):
460
- super(FeedForward, self).__init__()
461
- if hidden_act is None or not (isinstance(hidden_act, str) or issubclass(hidden_act, nn.Cell)):
462
- raise TypeError(f"For FeedForward cell, the hidden_act should str type or nn.Cell type, "
463
- f"but got {hidden_act}.")
464
- if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
465
- _check_config(parallel_config)
466
- mp = parallel_config.model_parallel
467
- if expert_num > 1:
468
- ep = parallel_config.expert_parallel
469
- else:
470
- ep = 1
471
- # ffn use less dp than other ops when use_moe, due to there are ops use dp and ep.
472
- dp = parallel_config.data_parallel // ep
473
- if ffn_hidden_size % mp != 0:
474
- raise ValueError("For 'FeedForward', the class variable 'ffn_hidden_size' must be a multiple of the"
475
- "num of model parallel, but got the ffn_hidden_size is {} and the num of model "
476
- "parallel is {}.".format(ffn_hidden_size, mp))
477
- if hidden_size % mp != 0:
478
- raise ValueError("For 'FeedForward', the class variable 'hidden_size' must be a multiple of the num of "
479
- "model parallel, but got the hidden_size is {} and the num of model parallel is {}."
480
- .format(hidden_size, mp))
481
- if dropout_rate < 0 or dropout_rate >= 1:
482
- raise ValueError("For 'FeedForward', the class variable 'dropout_rate' must be in the range [0, 1.0), "
483
- "but got the value : {}.".format(dropout_rate))
484
- input_size = hidden_size
485
- output_size = ffn_hidden_size
486
-
487
- # Project to ffn_hidden_size
488
- self.mapping = _Linear(in_channels=input_size,
489
- out_channels=output_size,
490
- activation=hidden_act,
491
- transpose_b=False,
492
- expert_num=expert_num,
493
- expert_group_size=expert_group_size,
494
- outer_batch=dp,
495
- param_init_type=param_init_type)
496
-
497
- # Project back to hidden_size
498
- self.projection = _Linear(in_channels=output_size,
499
- out_channels=input_size,
500
- transpose_b=False,
501
- expert_num=expert_num,
502
- expert_group_size=expert_group_size,
503
- outer_batch=dp,
504
- param_init_type=param_init_type)
505
- if expert_num > 1:
506
- self.projection.shard(strategy_matmul=((dp, ep, 1, mp), (ep, mp, 1)))
507
- else:
508
- self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)))
509
- self.projection.bias.parallel_optimizer = False
510
- self.dropout = nn.Dropout(p=dropout_rate)
511
- self.dropout_3d = nn.Dropout(p=dropout_rate)
512
- self.dropout_4d = nn.Dropout(p=dropout_rate)
513
- self.cast = P.Cast()
514
- else:
515
- _check_config(parallel_config)
516
- mp = parallel_config.model_parallel
517
- if expert_num > 1:
518
- ep = parallel_config.expert_parallel
519
- else:
520
- ep = 1
521
- # ffn use less dp than other ops when use_moe, due to there are ops use dp and ep.
522
- dp = parallel_config.data_parallel // ep
523
- if ffn_hidden_size % mp != 0:
524
- raise ValueError("For 'FeedForward', the class variable 'ffn_hidden_size' must be a multiple of the"
525
- "num of model parallel, but got the ffn_hidden_size is {} and the num of model "
526
- "parallel is {}.".format(ffn_hidden_size, mp))
527
- if hidden_size % mp != 0:
528
- raise ValueError("For 'FeedForward', the class variable 'hidden_size' must be a multiple of the num of "
529
- "model parallel, but got the hidden_size is {} and the num of model parallel is {}."
530
- .format(hidden_size, mp))
531
- if dropout_rate < 0 or dropout_rate >= 1:
532
- raise ValueError("For 'FeedForward', the class variable 'dropout_rate' must be in the range [0, 1.0), "
533
- "but got the value : {}.".format(dropout_rate))
534
- input_size = hidden_size
535
- output_size = ffn_hidden_size
536
-
537
- # Project to ffn_hidden_size
538
- self.mapping = _Linear(in_channels=input_size,
539
- out_channels=output_size,
540
- activation=hidden_act,
541
- transpose_b=False,
542
- expert_num=expert_num,
543
- expert_group_size=expert_group_size,
544
- outer_batch=dp,
545
- param_init_type=param_init_type)
546
-
547
- if expert_num > 1:
548
- self.mapping.shard(strategy_matmul=((dp, ep, 1, 1), (ep, 1, mp)),
549
- strategy_bias=((dp, ep, 1, mp), (1, ep, 1, mp)),
550
- strategy_activation=((dp, ep, 1, mp),))
551
- else:
552
- self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)),
553
- strategy_bias=((dp, mp), (mp,)),
554
- strategy_activation=((dp, mp),))
555
- # Project back to hidden_size
556
- self.projection = _Linear(in_channels=output_size,
557
- out_channels=input_size,
558
- transpose_b=False,
559
- expert_num=expert_num,
560
- expert_group_size=expert_group_size,
561
- outer_batch=dp,
562
- param_init_type=param_init_type)
563
- if expert_num > 1:
564
- self.projection.shard(strategy_matmul=((dp, ep, 1, mp), (ep, mp, 1)),
565
- strategy_bias=((dp, ep, 1, 1), (1, ep, 1, 1)))
566
- else:
567
- self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)),
568
- strategy_bias=((dp, 1), (1,)))
569
- self.projection.bias.parallel_optimizer = False
570
- self.dropout = nn.Dropout(p=dropout_rate)
571
- self.dropout.dropout.shard(((dp, 1),))
572
- self.dropout_3d = nn.Dropout(p=dropout_rate)
573
- self.dropout_3d.dropout.shard(((dp, 1, 1),))
574
- self.dropout_4d = nn.Dropout(p=dropout_rate)
575
- self.dropout_4d.dropout.shard(((dp, ep, 1, 1),))
576
- self.cast = P.Cast()
577
- # for grouped pairwise exchange alltoall method in pass
578
- self.mapping.matmul.add_prim_attr("gpea_label", True)
579
- self.projection.matmul.add_prim_attr("gpea_label", True)
580
-
581
- def construct(self, x):
582
- _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
583
- x = self.cast(x, mstype.float16)
584
- # returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]
585
- hidden = self.mapping(x)
586
- output = self.projection(hidden)
587
- # returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]
588
- if len(F.shape(output)) == 3:
589
- output = self.dropout_3d(output)
590
- elif len(F.shape(output)) == 2:
591
- output = self.dropout(output)
592
- else:
593
- output = self.dropout_4d(output)
594
- return output
595
-
596
-
597
- class AttentionMask(Cell):
598
- r"""
599
- Get the Lower triangular matrix from the input mask. The input mask is a 2D tensor (batch_size, seq_length)
600
- with 1 and 0, where 1 indicates the current position is a valid token, otherwise not.
601
-
602
- Args:
603
- seq_length(int): The sequence length of the input tensor.
604
- parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
605
- an instance of `OpParallelConfig` with default args.
606
-
607
- Inputs:
608
- - **input_mask** (Tensor) - The mask indicating whether each position is a valid input with
609
- (batch_size, seq_length).
610
-
611
- Outputs:
612
- Tensor. The attention mask matrix with shape (batch_size, seq_length, seq_length).
613
-
614
- Raises:
615
- TypeError: `seq_length` is not an integer.
616
- ValueError: `seq_length` is not a positive value.
617
- TypeError: `parallel_config` is not a subclass of OpParallelConfig.
618
-
619
- Supported Platforms:
620
- ``Ascend`` ``GPU``
621
-
622
- Examples:
623
- >>> import numpy as np
624
- >>> from mindspore.nn.transformer import AttentionMask
625
- >>> from mindspore import Tensor
626
- >>> mask = AttentionMask(seq_length=4)
627
- >>> mask_array = np.array([[1, 1, 1, 0]], np.float32)
628
- >>> inputs = Tensor(mask_array)
629
- >>> res = mask(inputs)
630
- >>> print(res)
631
- [[[1. 0. 0. 0]
632
- [1. 1. 0. 0]
633
- [1. 1. 1. 0]
634
- [0. 0. 0. 0]]]
635
- """
636
-
637
- @_LogActionOnce(logger=logger, key='AttentionMask',
638
- no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
639
- @_args_type_validator_check(seq_length=Validator.check_positive_int,
640
- parallel_config=_valid_type_checks([OpParallelConfig], "AttentionMask"))
641
- def __init__(self, seq_length, parallel_config=default_dpmp_config):
642
- super(AttentionMask, self).__init__()
643
- self.seq_length = seq_length
644
- self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1), ()))
645
- self.reshape = P.Reshape()
646
- self.mul = P.BatchMatMul().shard(
647
- ((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
648
- self.expand_dim = P.ExpandDims().shard(((1, 1),))
649
- ones = np.ones(shape=(seq_length, seq_length))
650
- # Default lower triangle mask matrix
651
- self.lower_triangle_mask = Tensor(np.tril(ones), mstype.float32)
652
- self.multiply = P.Mul().shard(((parallel_config.data_parallel, 1, 1), (1, 1, 1)))
653
-
654
- def construct(self, input_mask):
655
- _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name)
656
- input_mask = P.Cast()(self.not_equal(input_mask, 0), mstype.float16)
657
- input_shape = P.Shape()(input_mask)
658
- shape_right = (input_shape[0], 1, input_shape[1])
659
- shape_left = input_shape + (1,)
660
- # Mask the padded inputs
661
- mask_left = self.reshape(input_mask, shape_left)
662
- mask_right = self.reshape(input_mask, shape_right)
663
- attention_mask = self.mul(mask_left, mask_right)
664
- lower_traiangle = self.expand_dim(self.lower_triangle_mask, 0)
665
- # the returned shape is [bs, seq_length, seq_length]
666
- attention_mask = self.multiply(
667
- attention_mask, lower_traiangle)
668
- return attention_mask
669
-
670
-
671
- class VocabEmbedding(Cell):
672
- """
673
- The embedding lookup table from the 0-th dim of the parameter table. When the parallel_config.vocab_emb_dp is
674
- True and in the `AUTO_PARALLEL` mode, the embedding lookup will be trained by the data parallel way, as the
675
- parameters will be repeated on each device. If false, the embedding table will be sharded into n parts at
676
- the 0-th dimension of the embedding table, where the n is the model parallel way determined by
677
- `parallel_config.model_parallel` (EmbeddingOpParallelConfig).
678
-
679
- Note:
680
- When `AUTO_PARALLEL` or `SEMI_AUTO_PARALLEL` mode is enabled, this layer support only 2-d dimension inputs,
681
- as the shard is designed for 2d inputs.
682
-
683
- Args:
684
- vocab_size (int): Size of the dictionary of embeddings.
685
- embedding_size (int): The size of each embedding vector.
686
- parallel_config (EmbeddingOpParallelConfig): The parallel config of network. Default
687
- `default_embedding_parallel_config`, an instance of `EmbeddingOpParallelConfig` with default args.
688
- param_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
689
- Refer to class `initializer` for the values of string when a string
690
- is specified. Default: 'normal'.
691
-
692
- Inputs:
693
- - **input_ids** (Tensor) - The tokenized inputs with datatype int32 with shape (batch_size, seq_length)
694
-
695
- Outputs:
696
- Tuple, a tuple contains (`output`, `embedding_table`)
697
-
698
- - **output** (Tensor) - The embedding vector for the input with shape (batch_size,
699
- seq_length, embedding_size).
700
- - **embedding_table** (Tensor) - The embedding table with shape (vocab_size, embedding_size).
701
-
702
- Raises:
703
- ValueError: If the parallel_config.vocab_emb_dp is True, the vocab size is not a multiple of
704
- parallel_config.model_parallel
705
- ValueError: `vocab_size` is not a positive value.
706
- ValueError: `embedding_size` is not a positive value.
707
- TypeError: `parallel_config` is not a subclass of OpParallelConfig.
708
-
709
- Supported Platforms:
710
- ``Ascend`` ``GPU``
711
-
712
- Examples:
713
- >>> import numpy as np
714
- >>> from mindspore.nn.transformer import VocabEmbedding
715
- >>> from mindspore import Tensor
716
- >>> from mindspore import dtype as mstype
717
- >>> model = VocabEmbedding(vocab_size=30, embedding_size=30)
718
- >>> tensor = Tensor(np.ones((20, 15)), mstype.int32)
719
- >>> output, table = model(tensor)
720
- >>> print(output.shape)
721
- (20, 15, 30)
722
- >>> print(table.shape)
723
- (30, 30)
724
- """
725
-
726
- @_LogActionOnce(logger=logger, key='VocabEmbedding',
727
- no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
728
- @_args_type_validator_check(vocab_size=Validator.check_positive_int,
729
- embedding_size=Validator.check_positive_int,
730
- parallel_config=_valid_type_checks([EmbeddingOpParallelConfig], "VocabEmbedding"))
731
- def __init__(self, vocab_size, embedding_size, parallel_config=default_embedding_parallel_config,
732
- param_init='normal'):
733
- super(VocabEmbedding, self).__init__()
734
- _check_config(parallel_config)
735
- self.vocab_size = vocab_size
736
- self.embedding_size = embedding_size
737
- self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
738
- name='embedding_table', parallel_optimizer=False)
739
- if parallel_config.vocab_emb_dp:
740
- self.gather = P.Gather().shard(((1, 1), (parallel_config.data_parallel, 1)))
741
- logger.info(f"Using {parallel_config.data_parallel} data parallel for the embedding lookup.")
742
- else:
743
- if self.vocab_size % parallel_config.model_parallel != 0:
744
- raise ValueError(f"The vocab size of the embedding {self.vocab_size} must be a "
745
- f"multiple of parallel_config.model_parallel {parallel_config.model_parallel}.")
746
- self.gather = P.Gather().shard(((parallel_config.model_parallel, 1), (parallel_config.data_parallel, 1)))
747
- logger.info(f"Using {parallel_config.data_parallel} data parallel and {parallel_config.model_parallel} "
748
- f"model parallel for the embedding lookup.")
749
-
750
- def construct(self, input_ids):
751
- _check_input_dtype(F.dtype(input_ids), "input_ids", [mstype.int32], self.cls_name)
752
- output = self.gather(self.embedding_table, input_ids, 0)
753
- return output, self.embedding_table.value()
754
-
755
-
756
- class MultiHeadAttention(Cell):
757
- r"""
758
- This is an implementation of multihead attention in the paper `Attention is all you need
759
- <https://arxiv.org/pdf/1706.03762v5.pdf>`_. Given the query vector with source length, and the
760
- key and value vector with target length, the attention will be performed as the following
761
-
762
- .. math::
763
- MultiHeadAttention(query, key, vector) = Concat(head_1, \dots, head_h)W^O
764
-
765
- where :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)`. The default is with a bias.
766
-
767
- if query, key and value tensor is same, then it will be self attention.
768
-
769
- Args:
770
- batch_size(int): The batch size of the input tensor when do increnmental prediction. Should be a positive
771
- value. When do training or prediction, the argument will not work and the user can just pass None to
772
- the argument.
773
- src_seq_length(int): The sequence length of the query vector.
774
- tgt_seq_length(int): The sequence length of the key and value vector.
775
- hidden_size(int): The hidden size of the input.
776
- num_heads(int): The number of the heads.
777
- hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default:0.1.
778
- attention_dropout_rate(float): The dropout rate of the attention scores. Default:0.1.
779
- compute_dtype(dtype.Number): The computation type of dense. Default mstype.float16.
780
- Should be mstype.float32 or mstype.float16.
781
- softmax_compute_type(dtype.Number): The type of softmax computation module. Default mstype.float32.
782
- Should be mstype.float32 or mstype.float16.
783
- param_init_type(dtype.Number): The parameter initialization type of the module. Default mstype.float32.
784
- Should be mstype.float32 or mstype.float16.
785
- use_past(bool): Use the past state to compute, used for incremental prediction. For example, if we have two
786
- words and want to generate the ten more words. We just need to compute the two words' state only once,
787
- and generate the next word one by one. When use_past is True, there are two steps to run the prediction.
788
- In the first step, set the is_first_iteration to be True by
789
- `model.add_flags_recursive(is_first_iteration=True)`, and pass the full inputs. Then, set the
790
- is_first_iteration to be False by `model.add_flags_recursive(is_first_iteration=False)`. At this moment,
791
- pass the single step's input tensor, and loop it. Default False.
792
- parallel_config(OpParallelConfig): The parallel configuration. Default `default_dpmp_config`,
793
- an instance of `OpParallelConfig` with default args.
794
-
795
- Inputs:
796
- - **query_tensor** (Tensor) - The query vector with shape (batch_size, src_seq_length, hidden_size) or
797
- (batch_size * src_seq_length, hidden_size), if the use_past is False or is_first_iteration=True.
798
- Otherwise, must be (batch_size, 1, hidden_size)
799
- - **key_tensor** (Tensor) - The key vector with shape (batch_size, tgt_seq_length, hidden_size) or
800
- (batch_size * tgt_seq_length, hidden_size), if the use_past is False or is_first_iteration=True.
801
- Otherwise, must be (batch_size, 1, hidden_size)
802
- - **value_tensor** (Tensor) - The value vector with shape (batch_size, tgt_seq_length, hidden_size) or
803
- (batch_size * tgt_seq_length, hidden_size), if the use_past is False or is_first_iteration=True.
804
- Otherwise, must be (batch_size, 1, hidden_size)
805
- - **attention_mask** (Tensor) - If the use_past is False or is_first_iteration=True, the attention mask
806
- matrix should be (batch_size, src_seq_length, tgt_seq_length), or None. None means there will be no mask
807
- in softmax computation. Otherwise, the mask must be (batch_size, 1, tgt_seq_length)
808
- - **key_past** (Tensor) - float16 tensor with shape (batch_size, num_heads, size_per_head, tgt_seq_length).
809
- The past calculated key vector. Used for incremental prediction when the use_past is True.
810
- Default None.
811
- - **value_past** (Tensor) - float16 tensor with shape
812
- (batch_size, num_heads, tgt_seq_length, size_per_head).
813
- The past calculated value vector. Used for incremental prediction when the use_past is True.
814
- Default None.
815
- - **batch_valid_length** (Tensor) - int32 tensor with shape (batch_size,) holding the index of the tokens already computed.
816
- Used for incremental prediction when the use_past is True. Default None.
817
-
818
- Outputs:
819
- Tuple, a tuple containing (`output`, `layer_present`)
820
-
821
- - **output** (Tensor) - Tensor, the float tensor of the output of the layer with
822
- shape (batch_size, src_seq_length, hidden_size) or (batch_size * src_seq_length, hidden_size),
823
- if the use_past is False or is_first_iteration=True. Otherwise, it will be (batch_size, 1, hidden_size).
824
-
825
- - **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with
826
- ((batch_size, num_heads, size_per_head, tgt_seq_length),
827
- (batch_size, num_heads, tgt_seq_length, size_per_head)).
828
-
829
- Supported Platforms:
830
- ``Ascend`` ``GPU``
831
-
832
- Examples:
833
- >>> import numpy as np
834
- >>> from mindspore.nn.transformer import MultiHeadAttention
835
- >>> from mindspore import dtype as mstype
836
- >>> from mindspore import Tensor
837
- >>> model = MultiHeadAttention(batch_size=None, hidden_size=15, src_seq_length=20, tgt_seq_length=20,
838
- ... num_heads=3)
839
- >>> from_tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
840
- >>> to_tensor = Tensor(np.ones((2, 20, 15)), mstype.float16)
841
- >>> attention_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
842
- >>> attn_out, past = model(from_tensor, to_tensor, to_tensor, attention_mask)
843
- >>> print(attn_out.shape)
844
- (2, 20, 15)
845
- >>> print(past[0].shape)
846
- (2, 3, 5, 20)
847
- >>> print(past[1].shape)
848
- (2, 3, 20, 5)
849
- >>> # When using use_past=True, it takes two steps to implement the incremental prediction.
850
- >>> # Step 1: set is_first_iteration=True, and input the full sequence length's state.
851
- >>> # We need to prepare the memory parameters for saving key and value states firstly.
852
- >>> model = MultiHeadAttention(batch_size=2, hidden_size=15, src_seq_length=20, tgt_seq_length=20,
853
- ... num_heads=3, use_past=True)
854
- >>> key_past = Tensor(np.zeros(shape=(2, 3, 5, 20)), mstype.float16)
855
- >>> value_past = Tensor(np.zeros(shape=(2, 3, 20, 5)), mstype.float16)
856
- >>> batch_valid_length = Tensor(np.ones((2,)), mstype.int32)
857
- >>> # Set is_first_iteration=True to generate the full memory states
858
- >>> model.add_flags_recursive(is_first_iteration=True)
859
- >>> attn_out, past = model(from_tensor, to_tensor, to_tensor, attention_mask, key_past, value_past,
860
- ... batch_valid_length)
861
- >>> print(attn_out.shape)
862
- (2, 20, 15)
863
- >>> print(past[0].shape)
864
- (2, 3, 5, 20)
865
- >>> print(past[1].shape)
866
- (2, 3, 20, 5)
867
- >>> from_tensor = Tensor(np.ones((2, 1, 15)), mstype.float32)
868
- >>> to_tensor = Tensor(np.ones((2, 1, 15)), mstype.float16)
869
- >>> attention_mask = Tensor(np.ones((2, 1, 20)), mstype.float16)
870
- >>> # Step 2: set is_first_iteration=False, and pass the single word to run the prediction rather than the
871
- >>> # full sequence.
872
- >>> model.add_flags_recursive(is_first_iteration=False)
873
- >>> attn_out, past = model(from_tensor, to_tensor, to_tensor, attention_mask, key_past, value_past,
874
- ... batch_valid_length)
875
- >>> print(attn_out.shape)
876
- (2, 1, 15)
877
- >>> print(past[0].shape)
878
- (2, 3, 5, 20)
879
- >>> print(past[1].shape)
880
- (2, 3, 20, 5)
881
- """
882
-
883
- @_LogActionOnce(logger=logger, key='MultiHeadAttention',
884
- no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
885
- @_args_type_validator_check(hidden_size=Validator.check_positive_int,
886
- num_heads=Validator.check_positive_int,
887
- src_seq_length=Validator.check_positive_int,
888
- tgt_seq_length=Validator.check_positive_int,
889
- attention_dropout_rate=Validator.check_non_negative_float,
890
- hidden_dropout_rate=Validator.check_non_negative_float,
891
- compute_dtype=_valid_value_checks([mstype.float32, mstype.float16],
892
- "MultiHeadAttention"),
893
- softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
894
- "MultiHeadAttention"),
895
- param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
896
- "MultiHeadAttention"),
897
- parallel_config=_valid_type_checks([OpParallelConfig],
898
- "MultiHeadAttention"),
899
- use_past=Validator.check_bool)
900
- def __init__(self, batch_size,
901
- src_seq_length,
902
- tgt_seq_length,
903
- hidden_size,
904
- num_heads,
905
- hidden_dropout_rate=0.1,
906
- attention_dropout_rate=0.1,
907
- compute_dtype=mstype.float16,
908
- softmax_compute_type=mstype.float32,
909
- param_init_type=mstype.float32,
910
- use_past=False,
911
- parallel_config=default_dpmp_config):
912
- super(MultiHeadAttention, self).__init__()
913
- self._is_ascend = context.get_context('device_target') in ["Ascend"]
914
- self.dp = parallel_config.data_parallel
915
- self.is_parallel_mode = _get_parallel_mode() in (
916
- ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
917
- if batch_size:
918
- Validator.check_positive_int(batch_size)
919
- if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
920
- _check_config(parallel_config)
921
- self.src_seq_length = src_seq_length
922
- self.tgt_seq_length = tgt_seq_length
923
- self.hidden_size = hidden_size
924
- self.batch_size = batch_size
925
- if hidden_dropout_rate < 0 or hidden_dropout_rate >= 1:
926
- raise ValueError("For 'MultiHeadAttention', the class variable 'hidden_dropout_rate' must be "
927
- "in range [0, 1.0), but got the value : {}.".format(hidden_dropout_rate))
928
- if attention_dropout_rate < 0 or attention_dropout_rate >= 1:
929
- raise ValueError("For 'MultiHeadAttention', the class variable 'attention_dropout_rate' must be "
930
- "in range [0, 1.0), but got the value : {}.".format(attention_dropout_rate))
931
- if hidden_size % num_heads != 0:
932
- raise ValueError("For 'MultiHeadAttention', the class variable 'hidden_size' must be a multiple "
933
- "of 'num_heads', but got the hidden_size is {} and the num_heads is {}."
934
- .format(hidden_size, num_heads))
935
- if num_heads % parallel_config.model_parallel != 0:
936
- raise ValueError("For 'MultiHeadAttention', the class variable 'num_heads' must be a multiple of "
937
- "'parallel_config.model_parallel', but got the num_heads is {} "
938
- "and the parallel_config.model_parallel is {}."
939
- .format(num_heads, parallel_config.model_parallel))
940
- self.is_first_iteration = True
941
- # Output layer
942
- self.projection = _Linear(in_channels=hidden_size,
943
- out_channels=hidden_size,
944
- transpose_b=False,
945
- compute_dtype=compute_dtype,
946
- param_init_type=param_init_type)
947
- self.projection.shard(strategy_bias=((parallel_config.data_parallel, 1), (1,)),
948
- strategy_matmul=((parallel_config.data_parallel, parallel_config.model_parallel),
949
- (parallel_config.model_parallel, 1)))
950
- self.projection.bias.parallel_optimizer = False
951
- self.transpose = P.Transpose()
952
- self.merger_head_transpose = P.Transpose()
953
- self.reshape = P.Reshape()
954
- self.n_head = num_heads
955
- # embedding size per head
956
- self.size_per_head = hidden_size // self.n_head
957
- self.concat_k = P.Concat(axis=3)
958
- self.concat_v = P.Concat(axis=2)
959
- self.multiply_data = Tensor([
960
- -10000.0,
961
- ], dtype=softmax_compute_type)
962
- self.batch_matmul = P.BatchMatMul()
963
- self.real_div = P.RealDiv()
964
- self.sub = P.Sub()
965
- self.mul = P.Mul()
966
- self.add = P.Add()
967
- # Normalize factor for attention, sqrt(dk) as widely used
968
- self.scale_factor = Tensor(math.sqrt(math.sqrt(self.size_per_head)))
969
- self.use_past = use_past
970
- self.dropout = nn.Dropout(p=hidden_dropout_rate)
971
- self.prob_dropout = nn.Dropout(p=attention_dropout_rate)
972
- self.softmax = nn.Softmax().to_float(softmax_compute_type)
973
- self.softmax_3d = nn.Softmax().to_float(softmax_compute_type)
974
- self.expand_dims = P.ExpandDims()
975
-
976
- # Query
977
- self.dense1 = _Linear(hidden_size,
978
- hidden_size,
979
- compute_dtype=compute_dtype,
980
- param_init_type=param_init_type)
981
- # Key
982
- self.dense2 = _Linear(hidden_size,
983
- hidden_size,
984
- compute_dtype=compute_dtype,
985
- param_init_type=param_init_type)
986
- # Value
987
- self.dense3 = _Linear(hidden_size,
988
- hidden_size,
989
- compute_dtype=compute_dtype,
990
- param_init_type=param_init_type)
991
-
992
- self.dtype = compute_dtype
993
- self.softmax_dtype = softmax_compute_type
994
- if self.use_past:
995
- # operators used for state reuse
996
- seq_range = np.arange(src_seq_length).reshape(1, 1, -1)
997
- self.range = Tensor(np.tile(seq_range, (batch_size, 1, 1)), mstype.int32)
998
- self.seq_length = src_seq_length
999
- self.attention_mask = Tensor(np.tril(np.ones(shape=(self.seq_length, self.seq_length))), mstype.int32)
1000
- self.slice = P.StridedSlice().shard(((1, 1, 1, 1),))
1001
- self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ()))
1002
- self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),))
1003
- self.expand_dims = P.ExpandDims().shard(((1, 1, 1),))
1004
- self.tensor_le = P.LessEqual().shard(((1, 1, 1), (1, 1, 1)))
1005
- self.add = P.Add().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
1006
- self.equal = P.Equal().shard(((1, 1, 1), (1, 1, 1)))
1007
- self.sub1 = P.Sub().shard(((1,), ()))
1008
- self.tile = P.Tile().shard(((1, 1, 1, 1),))
1009
- self.less = P.Less().shard(((1, 1, 1), (1, 1, 1)))
1010
- self.mul1 = P.Mul().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
1011
- else:
1012
- _check_config(parallel_config)
1013
- self.src_seq_length = src_seq_length
1014
- self.tgt_seq_length = tgt_seq_length
1015
- self.hidden_size = hidden_size
1016
- self.batch_size = batch_size
1017
- if hidden_dropout_rate < 0 or hidden_dropout_rate >= 1:
1018
- raise ValueError("For 'MultiHeadAttention', the class variable 'hidden_dropout_rate' must be "
1019
- "in range [0, 1.0), but got the value : {}.".format(hidden_dropout_rate))
1020
- if attention_dropout_rate < 0 or attention_dropout_rate >= 1:
1021
- raise ValueError("For 'MultiHeadAttention', the class variable 'attention_dropout_rate' must be "
1022
- "in range [0, 1.0), but got the value : {}.".format(attention_dropout_rate))
1023
- if hidden_size % num_heads != 0:
1024
- raise ValueError("For 'MultiHeadAttention', the class variable 'hidden_size' must be a multiple "
1025
- "of 'num_heads', but got the hidden_size is {} and the num_heads is {}."
1026
- .format(hidden_size, num_heads))
1027
- if num_heads % parallel_config.model_parallel != 0:
1028
- raise ValueError("For 'MultiHeadAttention', the class variable 'num_heads' must be a multiple of "
1029
- "'parallel_config.model_parallel', but got the num_heads is {} "
1030
- "and the parallel_config.model_parallel is {}."
1031
- .format(num_heads, parallel_config.model_parallel))
1032
- self.is_first_iteration = True
1033
- # Output layer
1034
- self.projection = _Linear(in_channels=hidden_size,
1035
- out_channels=hidden_size,
1036
- transpose_b=False,
1037
- compute_dtype=compute_dtype,
1038
- param_init_type=param_init_type)
1039
- self.projection.shard(strategy_bias=((parallel_config.data_parallel, 1), (1,)),
1040
- strategy_matmul=((parallel_config.data_parallel, parallel_config.model_parallel),
1041
- (parallel_config.model_parallel, 1)))
1042
- self.projection.bias.parallel_optimizer = False
1043
- self.transpose = P.Transpose().shard(
1044
- ((parallel_config.data_parallel, 1, parallel_config.model_parallel, 1),))
1045
- self.merger_head_transpose = P.Transpose().shard(
1046
- ((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),))
1047
- self.reshape = P.Reshape()
1048
- self.n_head = num_heads
1049
- # embedding size per head
1050
- self.size_per_head = hidden_size // self.n_head
1051
- self.concat_k = P.Concat(axis=3)
1052
- self.concat_v = P.Concat(axis=2)
1053
- self.multiply_data = Tensor([
1054
- -10000.0,
1055
- ], dtype=softmax_compute_type)
1056
- self.batch_matmul = P.BatchMatMul().shard(
1057
- ((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),
1058
- (parallel_config.data_parallel, parallel_config.model_parallel, 1, 1)))
1059
- self.real_div = P.RealDiv().shard(
1060
- ((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1), ()))
1061
- self.sub = P.Sub().shard(
1062
- ((1,), (parallel_config.data_parallel, 1, 1, 1)))
1063
- self.mul = P.Mul().shard(
1064
- ((parallel_config.data_parallel, 1, 1, 1), (1,)))
1065
- self.add = P.Add().shard(
1066
- ((parallel_config.data_parallel, 1, 1, 1),
1067
- (parallel_config.data_parallel, parallel_config.model_parallel, 1, 1)))
1068
- # Normalize factor for attention, sqrt(dk) as widely used
1069
- self.scale_factor = Tensor(math.sqrt(math.sqrt(self.size_per_head)))
1070
- self.use_past = use_past
1071
- self.dropout = nn.Dropout(p=hidden_dropout_rate)
1072
- self.dropout.dropout.shard(((parallel_config.data_parallel, 1),))
1073
- self.prob_dropout = nn.Dropout(p=attention_dropout_rate)
1074
- self.prob_dropout.dropout.shard(
1075
- ((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),))
1076
- self.softmax = nn.Softmax().to_float(softmax_compute_type)
1077
- self.softmax.softmax.shard(((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),))
1078
- self.softmax_3d = nn.Softmax().to_float(softmax_compute_type)
1079
- self.softmax_3d.softmax.shard(((parallel_config.data_parallel, parallel_config.model_parallel, 1),))
1080
- self.expand_dims = P.ExpandDims().shard(((parallel_config.data_parallel, 1, 1),))
1081
-
1082
- # Query
1083
- self.dense1 = _Linear(hidden_size,
1084
- hidden_size,
1085
- compute_dtype=compute_dtype,
1086
- param_init_type=param_init_type)
1087
- self.dense1.shard(strategy_matmul=((parallel_config.data_parallel, 1), (parallel_config.model_parallel, 1)),
1088
- strategy_bias=((parallel_config.data_parallel, parallel_config.model_parallel),
1089
- (parallel_config.model_parallel,)))
1090
- # Key
1091
- self.dense2 = _Linear(hidden_size,
1092
- hidden_size,
1093
- compute_dtype=compute_dtype,
1094
- param_init_type=param_init_type)
1095
- self.dense2.shard(strategy_matmul=((parallel_config.data_parallel, 1), (parallel_config.model_parallel, 1)),
1096
- strategy_bias=((parallel_config.data_parallel, parallel_config.model_parallel),
1097
- (parallel_config.model_parallel,)))
1098
-
1099
- # Value
1100
- self.dense3 = _Linear(hidden_size,
1101
- hidden_size,
1102
- compute_dtype=compute_dtype,
1103
- param_init_type=param_init_type)
1104
- self.dense3.shard(strategy_matmul=((parallel_config.data_parallel, 1), (parallel_config.model_parallel, 1)),
1105
- strategy_bias=((parallel_config.data_parallel, parallel_config.model_parallel),
1106
- (parallel_config.model_parallel,)))
1107
- self.dtype = compute_dtype
1108
- self.softmax_dtype = softmax_compute_type
1109
- if self.use_past:
1110
- # operators used for state reuse
1111
- seq_range = np.arange(src_seq_length).reshape(1, 1, -1)
1112
- self.range = Tensor(np.tile(seq_range, (batch_size, 1, 1)), mstype.int32)
1113
- self.seq_length = src_seq_length
1114
- self.attention_mask = Tensor(np.tril(np.ones(shape=(self.seq_length, self.seq_length))), mstype.int32)
1115
- self.slice = P.StridedSlice().shard(((1, 1, 1, 1),))
1116
- self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ()))
1117
- self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),))
1118
- self.expand_dims = P.ExpandDims().shard(((1, 1, 1),))
1119
- self.tensor_le = P.LessEqual().shard(((1, 1, 1), (1, 1, 1)))
1120
- self.add = P.Add().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
1121
- self.equal = P.Equal().shard(((1, 1, 1), (1, 1, 1)))
1122
- self.sub1 = P.Sub().shard(((1,), ()))
1123
- self.tile = P.Tile().shard(((1, 1, 1, 1),))
1124
- self.less = P.Less().shard(((1, 1, 1), (1, 1, 1)))
1125
- self.mul1 = P.Mul().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
1126
-
1127
- def construct(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
1128
- value_past=None, batch_valid_length=None):
1129
- self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past,
1130
- value_past, batch_valid_length)
1131
- ori_shape = F.shape(query_tensor)
1132
- batch_size = self._get_batch_size_from_query(query_tensor)
1133
- query_tensor, key_tensor, value_tensor = self._convert_to_2d_tensor(query_tensor,
1134
- key_tensor,
1135
- value_tensor,
1136
- attention_mask)
1137
- ori_dtype = F.dtype(query_tensor)
1138
- query_tensor = F.cast(query_tensor, self.dtype)
1139
- key_tensor = F.cast(key_tensor, self.dtype)
1140
- value_tensor = F.cast(value_tensor, self.dtype)
1141
- # multi head attention: query, key, value are derived from the same inputs
1142
- query = self.dense1(query_tensor)
1143
- key = self.dense2(key_tensor)
1144
- value = self.dense3(value_tensor)
1145
- # the returned shape is [bs, num_heads, seq_length, size_per_head]
1146
- query = self.transpose(
1147
- F.reshape(
1148
- query,
1149
- (batch_size, self._get_seq_length_under_incremental(self.src_seq_length),
1150
- self.n_head, self.size_per_head)),
1151
- (0, 2, 1, 3))
1152
- # the returned shape is [bs, num_heads, size_per_head, seq_length]
1153
- key = self.transpose(
1154
- F.reshape(
1155
- key, (batch_size, self._get_seq_length_under_incremental(self.tgt_seq_length),
1156
- self.n_head, self.size_per_head)),
1157
- (0, 2, 3, 1))
1158
- # the returned shape is [bs, num_heads, seq_length, size_per_head]
1159
- value = self.transpose(
1160
- F.reshape(
1161
- value,
1162
- (batch_size, self._get_seq_length_under_incremental(self.tgt_seq_length),
1163
- self.n_head, self.size_per_head)),
1164
- (0, 2, 1, 3))
1165
- # support input shape is [bs, seq, seq] or [bs, heads, seq, seq]
1166
- if attention_mask is not None and len(F.shape(attention_mask)) == 3:
1167
- # expand attention mask from [bs, seq, seq] -> [bs, 1, seq, seq]
1168
- attention_mask = self.expand_dims(attention_mask, 1)
1169
- # key and value for current token(s)
1170
- key_present = key
1171
- value_present = value
1172
- if self.use_past:
1173
- # The first graph with the input size of (bs, seq_length)
1174
- if self.is_first_iteration:
1175
- # Get the valid input length without padding
1176
- valid_length_vector = F.cast(self.less(self.range, batch_valid_length.view(-1, 1, 1)), self.dtype)
1177
- # Zero out the key and value entries at the padding positions
1178
- key_present = self.mul1(key, self.expand_dims(valid_length_vector, 2))
1179
- value_present = self.mul1(value, self.expand_dims(valid_length_vector, 3))
1180
- # The second graph with the input size of (bs, 1)
1181
- # the shape of query is (bs, num_heads, 1, size_per_head)
1182
- # the shape of key is (bs, num_heads, size_per_head, 1)
1183
- # the shape of value is (bs, num_heads, 1, size_per_head)
1184
- else:
1185
- # Get the current token position index
1186
- valid_length = self.reducesum(F.cast(self.not_equal(self.slice(key_past, (0, 0, 0, 0),
1187
- (F.shape(key_tensor)[0], 1, 1,
1188
- self.src_seq_length),
1189
- (1, 1, 1, 1)),
1190
- 0), mstype.float32), (1, 2, 3))
1191
- valid_length = F.reshape(valid_length, (-1, 1, 1))
1192
- valid_length_vector = F.cast(self.equal(valid_length, self.range), self.dtype)
1193
- # Pad the key and value to seq_length with only the position index not zero
1194
- current_key = self.mul1(self.tile(key, (1, 1, 1, self.seq_length)),
1195
- self.expand_dims(valid_length_vector, 2))
1196
- current_value = self.mul1(self.tile(value, (1, 1, self.seq_length, 1)),
1197
- self.expand_dims(valid_length_vector, 3))
1198
- # Concat the previous saved state and current state
1199
- key = self.add(key_past, current_key)
1200
- value = self.add(value_past, current_value)
1201
- # Update key_present and value_present for state update
1202
- key_present = key
1203
- value_present = value
1204
- attention_mask = F.reshape(self.attention_mask, (self.seq_length, self.seq_length, 1, 1))
1205
-
1206
- layer_present = (key_present, value_present)
1207
- # multi head attention considering attention mask
1208
- # the return shape is [bs * seq_length, hidden_size]
1209
- attention = self._attn(query, key, value, attention_mask)
1210
- # Output
1211
- output = self.projection(attention)
1212
- output = self.dropout(output)
1213
- output = F.reshape(output, ori_shape)
1214
- output = F.cast(output, ori_dtype)
1215
- return output, layer_present
1216
-
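The use_past branch of construct above pads the single-step key and value into the full target length with a one-hot position mask before adding them to the cached states. A hedged NumPy re-creation of that cache update (shapes follow the docstring above; this is an illustration, not the MindSpore operators):

import numpy as np

batch, heads, size_per_head, tgt_len = 2, 3, 5, 20
position_range = np.arange(tgt_len).reshape(1, 1, -1)              # analogue of self.range

key_past = np.zeros((batch, heads, size_per_head, tgt_len), np.float16)
value_past = np.zeros((batch, heads, tgt_len, size_per_head), np.float16)

# key/value produced for the single current token (sequence length 1)
key_step = np.ones((batch, heads, size_per_head, 1), np.float16)
value_step = np.ones((batch, heads, 1, size_per_head), np.float16)

current_index = np.array([4, 7])                                    # next position per batch element
one_hot = (position_range == current_index.reshape(-1, 1, 1)).astype(np.float16)  # (batch, 1, tgt_len)

# broadcast the step tensors over the full length and keep only the current slot
current_key = np.tile(key_step, (1, 1, 1, tgt_len)) * one_hot[:, :, None, :]
current_value = np.tile(value_step, (1, 1, tgt_len, 1)) * one_hot[:, :, :, None]

key_cache = key_past + current_key                                   # caches passed to the next step
value_cache = value_past + current_value
assert key_cache[0, 0, 0, 4] == 1 and key_cache[0, 0, 0, 3] == 0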
1217
- def _get_batch_size_from_query(self, query):
1218
- r"""Get the batch size from query tensor"""
1219
- # For the incremental prediction, the seq length for the input is 1.
1220
- incr_infer = self.use_past and self.is_first_iteration
1221
- if len(F.shape(query)) == 2 and ((incr_infer) or (not self.use_past)):
1222
- return F.shape(query)[0] // self.src_seq_length
1223
- return F.shape(query)[0]
1224
-
1225
- def _get_seq_length_under_incremental(self, length):
1226
- r"""Return the length of the tensor.
1227
- For the incremental prediction, the seq length for the input is 1.
1228
- """
1229
- if self.use_past and not self.is_first_iteration:
1230
- return 1
1231
- return length
1232
-
1233
- def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
1234
- value_past=None, batch_valid_length=None):
1235
- r"""Check inputs"""
1236
- _check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name)
1237
- _check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name)
1238
- _check_input_dtype(F.dtype(value_tensor), "value_tensor", [mstype.float32, mstype.float16], self.cls_name)
1239
- if attention_mask is not None:
1240
- _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16],
1241
- self.cls_name)
1242
-
1243
- key_is_tensor = isinstance(key_past, Tensor)
1244
- value_is_tensor = isinstance(value_past, Tensor)
1245
- batch_valid_length_is_tensor = isinstance(batch_valid_length, Tensor)
1246
- key_is_default = key_past is None
1247
- value_is_default = value_past is None
1248
- batch_is_default = batch_valid_length is None
1249
- _check_past_none_input_none(self.use_past, "key_past", self.cls_name, None, key_is_tensor,
1250
- key_is_default)
1251
- _check_past_none_input_none(self.use_past, "value_past", self.cls_name, None, value_is_tensor,
1252
- value_is_default)
1253
- _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None,
1254
- batch_valid_length_is_tensor, batch_is_default)
1255
- if self.use_past:
1256
- _check_input_dtype(F.dtype(key_past), "key_past", [mstype.float16], self.cls_name)
1257
- _check_input_dtype(F.dtype(value_past), "value_past", [mstype.float16], self.cls_name)
1258
- _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
1259
- return True
1260
-
1261
- def _convert_to_2d_tensor(self, query_tensor, key_tensor, value_tensor, attention_mask):
1262
- """convert a nd tensor to a 2d tensor"""
1263
- query_shape = F.shape(query_tensor)
1264
- query_tensor = F.reshape(query_tensor, (-1, query_shape[-1]))
1265
- key_shape = F.shape(key_tensor)
1266
- key_tensor = F.reshape(key_tensor, (-1, key_shape[-1]))
1267
- value_shape = F.shape(value_tensor)
1268
- value_tensor = F.reshape(value_tensor, (-1, value_shape[-1]))
1269
-
1270
- return query_tensor, key_tensor, value_tensor
1271
-
1272
- def _merge_heads(self, x):
1273
- """
1274
- convert a 4d input to a 2d output
1275
-
1276
- Inputs:
1277
- x: input tensor
1278
-
1279
- Output:
1280
- x_merge: the 2d output
1281
- """
1282
- x = self.merger_head_transpose(
1283
- x, (0, 2, 1, 3)) # bs, seq_length, head, size_per_head
1284
- x_shape = P.Shape()(x)
1285
- new_shape = (-1, x_shape[-2] * x_shape[-1])
1286
- x_merge = self.reshape(x, new_shape)
1287
- return x_merge
1288
-
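The query/key/value projections above split the hidden dimension into heads with a reshape plus transpose, and _merge_heads undoes it. A tiny NumPy round-trip sketch of that layout change (illustrative only):

import numpy as np

bs, seq, heads, size_per_head = 2, 4, 3, 5
x = np.random.rand(bs * seq, heads * size_per_head).astype(np.float32)

# split: (bs * seq, hidden) -> (bs, num_heads, seq, size_per_head), as done for query and value
split = x.reshape(bs, seq, heads, size_per_head).transpose(0, 2, 1, 3)

# merge, as in _merge_heads: (bs, num_heads, seq, size_per_head) -> (bs * seq, hidden)
merged = split.transpose(0, 2, 1, 3).reshape(-1, heads * size_per_head)

assert np.array_equal(merged, x)  # the split/merge round trip is lossless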
1289
- def _softmax(self, attention_scores):
1290
- """
1291
- For performance, choose between the 4d softmax and the reshaped 3d softmax depending on the situation
1292
- :param attention_scores: a 3d tensor before softmax
1293
- :return: the attention scores.
1294
- """
1295
-
1296
- if self._is_ascend and self.softmax_dtype == mstype.float16 or not self._is_ascend:
1297
- attention_probs = self.softmax(attention_scores)
1298
- else:
1299
- shape = F.shape(attention_scores)
1300
- # attention probs
1301
- attention_probs = self.softmax_3d(
1302
- F.reshape(attention_scores,
1303
- (shape[0], -1, shape[-1])))
1304
- attention_probs = F.reshape(attention_probs, shape)
1305
- return attention_probs
1306
-
1307
- def _attn(self, query, key, value, attention_mask):
1308
- """
1309
- Get the weighted score along the seq_length
1310
-
1311
- Inputs:
1312
- query: the query matrix
1313
- key: the key matrix
1314
- value: the value matrix
1315
- attention_mask: the attention mask matrix with shape (batch_size,
1316
- 1, seq_length, seq_length)
1317
- Outputs:
1318
- weighted_values: Tensor, the weighted sum scores
1319
- """
1320
- # Normalize query and key before MatMul, default off
1321
- # Attention score [bs, num_heads, seq_length, seq_length]
1322
- factor = P.Cast()(self.scale_factor, P.DType()(query))
1323
- query = self.real_div(query, factor)
1324
- key = self.real_div(key, factor)
1325
- score = self.batch_matmul(query, key)
1326
-
1327
- ori_dtype = P.DType()(score)
1328
- attention_scores = P.Cast()(score, self.softmax_dtype)
1329
-
1330
- # for input size of (bs, 1) namely the second graph,
1331
- # the shape of attention_mask matrix should be (bs, 1, 1, seq_length)
1332
- if attention_mask is not None:
1333
- if self.use_past and not self.is_first_iteration:
1334
- # Calculate the current total token
1335
- current_index = self.reducesum(F.cast(self.not_equal(self.slice(key, (0, 0, 0, 0),
1336
- (F.shape(query)[0], 1, 1,
1337
- self.seq_length),
1338
- (1, 1, 1, 1)),
1339
- 0), mstype.float32), (1, 2, 3))
1340
- # Get the precise position index
1341
- index = self.sub1(F.cast(current_index, mstype.int32), 1)
1342
- index = F.reshape(index, (-1, 1, 1))
1343
- # Calculate the attention_mask matrix via the position index
1344
- attention_mask = F.cast(self.tensor_le(self.range, index), mstype.int32)
1345
- attention_mask = self.expand_dims(attention_mask, 2)
1346
- # Subtract 10000 at the masked positions to exclude them from the softmax
1347
- multiply_out = self.sub(
1348
- P.Cast()(F.tuple_to_array((1.0,)), P.DType()(attention_scores)),
1349
- P.Cast()(attention_mask, P.DType()(attention_scores)))
1350
-
1351
- adder = self.mul(multiply_out, self.multiply_data)
1352
- attention_scores = self.add(adder, attention_scores)
1353
-
1354
- # attention probs
1355
- attention_probs = self._softmax(attention_scores)
1356
- attention_probs = P.Cast()(attention_probs, ori_dtype)
1357
-
1358
- attention_probs = self.prob_dropout(attention_probs)
1359
- # Weighted sum output [bs, num_heads, seq_length, size_per_head]
1360
- weighted_values = self.batch_matmul(attention_probs, value)
1361
- attention_merge = self._merge_heads(weighted_values)
1362
- return attention_merge
1363
-
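The masking in _attn above turns a 0/1 attention mask into an additive bias of -10000 on the excluded positions, and dividing both query and key by sqrt(sqrt(d)) amounts to the usual 1/sqrt(d_k) scaling of the scores. A compact NumPy sketch of that computation (a sketch under those assumptions, not the MindSpore operators used above):

import numpy as np

def masked_attention(query, key, value, mask):
    # query/value: (bs, heads, seq_q, d); key: (bs, heads, d, seq_k); mask: 1 = keep, 0 = exclude
    d = query.shape[-1]
    scores = (query / np.sqrt(np.sqrt(d))) @ (key / np.sqrt(np.sqrt(d)))  # overall 1/sqrt(d) scaling
    scores = scores + (1.0 - mask) * -10000.0                             # push masked slots to ~0 probability
    scores = scores - scores.max(axis=-1, keepdims=True)                  # numerically stable softmax
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return probs @ value                                                   # (bs, heads, seq_q, d)

bs, heads, seq, d = 2, 3, 4, 5
q = np.random.rand(bs, heads, seq, d).astype(np.float32)
k = np.random.rand(bs, heads, d, seq).astype(np.float32)
v = np.random.rand(bs, heads, seq, d).astype(np.float32)
causal_mask = np.tril(np.ones((seq, seq), np.float32))[None, None]
print(masked_attention(q, k, v, causal_mask).shape)  # (2, 3, 4, 5)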
1364
-
1365
- class TransformerEncoderLayer(Cell):
1366
- r"""
1367
- Transformer Encoder Layer. This is an implementation of the single layer of the transformer
1368
- encoder layer, mainly including Multi-Head Attention, Feed Forward, Add and LayerNorm layer.
1369
-
1370
- The TransformerEncoderLayer structure is shown in the following figure:
1371
-
1372
- .. image:: ../images/TransformerEncoderLayer.png
1373
- :align: center
1374
-
1375
- Args:
1376
- batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
1377
- value. When doing training or prediction, the argument has no effect and the user can just pass None to
1378
- the argument.
1379
- hidden_size(int): The hidden size of the input.
1380
- ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
1381
- num_heads(int): The number of the heads.
1382
- seq_length(int): The input sequence length.
1383
- attention_dropout_rate(float): The dropout rate of the attention scores. Default:0.1.
1384
- hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default:0.1.
1385
- post_layernorm_residual(bool): Whether to apply the residual addition before the layernorm. Default False.
1386
- layernorm_compute_type(dtype.Number): The computation type of the layernorm.
1387
- Should be mstype.float32 or mstype.float16. Default mstype.float32.
1388
- softmax_compute_type(dtype.Number): The computation type of the softmax in the attention.
1389
- Should be mstype.float32 or mstype.float16. Default mstype.float32.
1390
- param_init_type(dtype.Number): The parameter initialization type of the module.
1391
- Should be mstype.float32 or mstype.float16. Default mstype.float32.
1392
- hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
1393
- 'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
1394
- 'hsigmoid', 'logsigmoid' and so on. The user can provide a custom activation to the argument.
1395
- If user wants to run the net in the parallel mode, the custom activation must also provide
1396
- the `activation_shard` function. Please see the examples of the
1397
- :class:`mindspore.nn.transformer.FeedForward`. Default: gelu.
1398
- use_past(bool): Use the past state to compute, used for incremental prediction. For example, if we have two
1399
- words and want to generate ten more words, we just need to compute the two words' state only once,
1400
- and generate the next word one by one. When use_past is True, there are two steps to run the prediction.
1401
- In the first step, set the is_first_iteration to be True by
1402
- `model.add_flags_recursive(is_first_iteration=True)`, and pass the full inputs. Then, set the
1403
- is_first_iteration to be False by `model.add_flags_recursive(is_first_iteration=False)`.
1404
- At this moment, pass the single step's input tensor, and loop it. Default False.
1405
- moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig
1406
- with default values. Please see `MoEConfig`.
1407
- parallel_config(OpParallelConfig, MoEParallelConfig): The parallel configuration. When MoE is applied,
1408
- MoEParallelConfig is effective, otherwise OpParallelConfig is effective. Default `default_dpmp_config`,
1409
- an instance of `OpParallelConfig` with default args.
1410
-
1411
- Inputs:
1412
- - **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size] or
1413
- [batch_size * seq_length, hidden_size], if the use_past is False or is_first_iteration=True. Otherwise,
1414
- should be [batch_size, 1, hidden_size]
1415
- - **input_mask** (Tensor) - Float Tensor, If the use_past is False or is_first_iteration=True,
1416
- the attention mask matrix should be [batch_size, seq_length, seq_length], or None. None means there will
1417
- be no mask in softmax computation. Otherwise, should be [batch_size, 1, hidden_size]
1418
- - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
1419
- past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
1420
- - **batch_valid_length** (Tensor) - int32 tensor with shape [batch_size] holding the index of the tokens already computed.
1421
- Used for incremental prediction when the use_past is True. Default None.
1422
-
1423
- Outputs:
1424
- Tuple, a tuple containing (`output`, `layer_present`).
1425
-
1426
- - **output** (Tensor) - The float tensor of the output of the layer with
1427
- shape (batch_size, seq_length, hidden_size) or (batch_size * seq_length, hidden_size), if the use_past is
1428
- False or is_first_iteration=True. Otherwise, it will be (batch_size, 1, hidden_size)
1429
-
1430
- - **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with
1431
- ((batch_size, num_heads, size_per_head, seq_length),
1432
- (batch_size, num_heads, seq_length, size_per_head)).
1433
-
1434
- Supported Platforms:
1435
- ``Ascend`` ``GPU``
1436
-
1437
- Examples:
1438
- >>> import numpy as np
1439
- >>> from mindspore import dtype as mstype
1440
- >>> from mindspore.nn.transformer import TransformerEncoderLayer
1441
- >>> from mindspore import Tensor
1442
- >>> model = TransformerEncoderLayer(batch_size=2, hidden_size=8, ffn_hidden_size=64, seq_length=16,
1443
- ... num_heads=2)
1444
- >>> encoder_input_value = Tensor(np.ones((2, 16, 8)), mstype.float32)
1445
- >>> encoder_input_mask = Tensor(np.ones((2, 16, 16)), mstype.float16)
1446
- >>> output, past = model(encoder_input_value, encoder_input_mask)
1447
- >>> print(output.shape)
1448
- (2, 16, 8)
1449
- >>> print(past[0].shape)
1450
- (2, 2, 4, 16)
1451
- >>> print(past[1].shape)
1452
- (2, 2, 16, 4)
1453
- >>> # When using use_past=True, it takes two steps to implement the incremental prediction.
1454
- >>> # Step 1: set is_first_iteration=True, and input the full sequence length's state.
1455
- >>> batch_valid_length = Tensor(np.ones((2,)), mstype.int32)
1456
- >>> init_reset = Tensor([True], mstype.bool_)
1457
- >>> # Set is_first_iteration=True to generate the full memory states
1458
- >>> model = TransformerEncoderLayer(batch_size=2, hidden_size=8, ffn_hidden_size=64, seq_length=16,
1459
- ... num_heads=2, use_past=True)
1460
- >>> model.add_flags_recursive(is_first_iteration=True)
1461
- >>> hidden, past = model(encoder_input_value, encoder_input_mask, init_reset, batch_valid_length)
1462
- >>> print(hidden.shape)
1463
- (2, 16, 8)
1464
- >>> print(past[0].shape)
1465
- (2, 2, 4, 16)
1466
- >>> print(past[1].shape)
1467
- (2, 2, 16, 4)
1468
- >>> encoder_input_value = Tensor(np.ones((2, 1, 8)), mstype.float32)
1469
- >>> encoder_input_mask = Tensor(np.ones((2, 1, 16)), mstype.float16)
1470
- >>> init_reset = Tensor([False], mstype.bool_)
1471
- >>> # Step 2: set is_first_iteration=False, and pass the single word to run the prediction rather than
1472
- >>> # the full sequence.
1473
- >>> model.add_flags_recursive(is_first_iteration=False)
1474
- >>> hidden, past = model(encoder_input_value, encoder_input_mask, init_reset, batch_valid_length)
1475
- >>> print(hidden.shape)
1476
- (2, 1, 8)
1477
- >>> print(past[0].shape)
1478
- (2, 2, 4, 16)
1479
- >>> print(past[1].shape)
1480
- (2, 2, 16, 4)
1481
- """
1482
-
1483
- @_LogActionOnce(logger=logger, key='TransformerEncoderLayer',
1484
- no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
1485
- @_args_type_validator_check(hidden_size=Validator.check_positive_int,
1486
- num_heads=Validator.check_positive_int,
1487
- ffn_hidden_size=Validator.check_positive_int,
1488
- seq_length=Validator.check_positive_int,
1489
- attention_dropout_rate=Validator.check_non_negative_float,
1490
- hidden_dropout_rate=Validator.check_non_negative_float,
1491
- post_layernorm_residual=Validator.check_bool,
1492
- layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
1493
- "TransformerEncoderLayer"),
1494
- softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
1495
- "TransformerEncoderLayer"),
1496
- param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
1497
- "TransformerEncoderLayer"),
1498
- parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig],
1499
- "TransformerEncoderLayer"),
1500
- use_past=Validator.check_bool)
1501
- def __init__(self,
1502
- batch_size,
1503
- hidden_size,
1504
- ffn_hidden_size,
1505
- num_heads,
1506
- seq_length,
1507
- attention_dropout_rate=0.1,
1508
- hidden_dropout_rate=0.1,
1509
- post_layernorm_residual=False,
1510
- layernorm_compute_type=mstype.float32,
1511
- softmax_compute_type=mstype.float32,
1512
- param_init_type=mstype.float32,
1513
- hidden_act='gelu',
1514
- use_past=False,
1515
- moe_config=default_moe_config,
1516
- parallel_config=default_dpmp_config):
1517
- super(TransformerEncoderLayer, self).__init__()
1518
- if batch_size or use_past:
1519
- Validator.check_positive_int(batch_size)
1520
- self.batch_size = batch_size
1521
- if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
1522
- _check_config(parallel_config)
1523
- if num_heads % parallel_config.model_parallel != 0:
1524
- raise ValueError(
1525
- "For 'TransformerEncoderLayer', the class variable 'num_heads' must be divisibled by the "
1526
- "'parallel_config.model_parallel', but got the num_heads is {} and "
1527
- "parallel_config.model_parallel is {}.".format(num_heads, parallel_config.model_parallel))
1528
- if hidden_size % parallel_config.model_parallel != 0:
1529
- raise ValueError(
1530
- "For 'TransformerEncoderLayer', the class variable 'hidden_size' must be divisibled by "
1531
- "the 'parallel_config.model_parallel', but got the hidden_size is {} and parallel_config."
1532
- " model_parallel is {}.".format(hidden_size, parallel_config.model_parallel))
1533
- if ffn_hidden_size % parallel_config.model_parallel != 0:
1534
- raise ValueError(
1535
- "For 'TransformerEncoderLayer', the class variable 'ffn_hidden_size' must be divisibled "
1536
- "by the 'parallel_config.model_parallel', but got the ffn_hidden_size is {} "
1537
- "and parallel_config. model_parallel is {}."
1538
- .format(ffn_hidden_size, parallel_config.model_parallel))
1539
- _check_moe_config(moe_config, parallel_config)
1540
- self.use_moe = moe_config.expert_num > 1
1541
- self.use_past = use_past
1542
- self.seq_length = seq_length
1543
- self.hidden_size = hidden_size
1544
- self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
1545
- self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
1546
-
1547
- attention_parallel_config = parallel_config.dpmp if self.use_moe else parallel_config
1548
- self.attention = MultiHeadAttention(batch_size=batch_size,
1549
- src_seq_length=seq_length,
1550
- tgt_seq_length=seq_length,
1551
- hidden_size=hidden_size,
1552
- num_heads=num_heads,
1553
- hidden_dropout_rate=hidden_dropout_rate,
1554
- attention_dropout_rate=attention_dropout_rate,
1555
- softmax_compute_type=softmax_compute_type,
1556
- param_init_type=param_init_type,
1557
- use_past=use_past,
1558
- parallel_config=attention_parallel_config)
1559
- if self.use_moe:
1560
- self.output = MoE(hidden_size=hidden_size,
1561
- dropout_rate=hidden_dropout_rate,
1562
- ffn_hidden_size=ffn_hidden_size,
1563
- param_init_type=param_init_type,
1564
- hidden_act=hidden_act,
1565
- moe_config=moe_config,
1566
- parallel_config=parallel_config)
1567
- else:
1568
- # Feed Forward Network, FFN
1569
- self.output = FeedForward(hidden_size=hidden_size,
1570
- dropout_rate=hidden_dropout_rate,
1571
- ffn_hidden_size=ffn_hidden_size,
1572
- param_init_type=param_init_type,
1573
- hidden_act=hidden_act,
1574
- parallel_config=parallel_config)
1575
- self.post_layernorm_residual = post_layernorm_residual
1576
- self.add = P.Add().shard(((parallel_config.data_parallel, 1), (parallel_config.data_parallel, 1)))
1577
- self.add_3d = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
1578
- self.dtype = mstype.float16
1579
- self.key_past = None
1580
- self.value_past = None
1581
-
1582
- if self.use_past:
1583
- # operator used for state reuse
1584
- self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),))
1585
- self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ()))
1586
- self.slice = P.StridedSlice().shard(((1, 1, 1, 1),))
1587
- size_per_head = hidden_size // num_heads
1588
- self.key_shape = (batch_size, num_heads, size_per_head, seq_length)
1589
- self.value_shape = (batch_size, num_heads, seq_length, size_per_head)
1590
- # parameters saving key and value states
1591
- self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past")
1592
- self.value_past = Parameter(Tensor(np.zeros(shape=self.value_shape), self.dtype), name="value_past")
1593
- self.tile = P.Tile().shard(((1, 1),))
1594
- self.mul = P.Mul().shard(((1, 1, 1, 1), (1,)))
1595
- self.assign = P.Assign().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
1596
- elif _get_parallel_mode() not in (ParallelMode.AUTO_PARALLEL,):
1597
- _check_config(parallel_config)
1598
- if num_heads % parallel_config.model_parallel != 0:
1599
- raise ValueError(
1600
- "For 'TransformerEncoderLayer', the class variable 'num_heads' must be divisibled by the "
1601
- "'parallel_config.model_parallel', but got the num_heads is {} and "
1602
- "parallel_config.model_parallel is {}.".format(num_heads, parallel_config.model_parallel))
1603
- if hidden_size % parallel_config.model_parallel != 0:
1604
- raise ValueError(
1605
- "For 'TransformerEncoderLayer', the class variable 'hidden_size' must be divisibled by "
1606
- "the 'parallel_config.model_parallel', but got the hidden_size is {} and parallel_config."
1607
- " model_parallel is {}.".format(hidden_size, parallel_config.model_parallel))
1608
- if ffn_hidden_size % parallel_config.model_parallel != 0:
1609
- raise ValueError(
1610
- "For 'TransformerEncoderLayer', the class variable 'ffn_hidden_size' must be divisibled "
1611
- "by the 'parallel_config.model_parallel', but got the ffn_hidden_size is {} "
1612
- "and parallel_config. model_parallel is {}."
1613
- .format(ffn_hidden_size, parallel_config.model_parallel))
1614
- _check_moe_config(moe_config, parallel_config)
1615
- self.use_moe = moe_config.expert_num > 1
1616
- self.use_past = use_past
1617
- self.seq_length = seq_length
1618
- self.hidden_size = hidden_size
1619
- self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
1620
- self.layernorm1.shard(((parallel_config.data_parallel, 1),))
1621
- self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
1622
- self.layernorm2.shard(((parallel_config.data_parallel, 1),))
1623
-
1624
- attention_parallel_config = parallel_config.dpmp if self.use_moe else parallel_config
1625
- self.attention = MultiHeadAttention(batch_size=batch_size,
1626
- src_seq_length=seq_length,
1627
- tgt_seq_length=seq_length,
1628
- hidden_size=hidden_size,
1629
- num_heads=num_heads,
1630
- hidden_dropout_rate=hidden_dropout_rate,
1631
- attention_dropout_rate=attention_dropout_rate,
1632
- softmax_compute_type=softmax_compute_type,
1633
- param_init_type=param_init_type,
1634
- use_past=use_past,
1635
- parallel_config=attention_parallel_config)
1636
- if self.use_moe:
1637
- self.output = MoE(hidden_size=hidden_size,
1638
- dropout_rate=hidden_dropout_rate,
1639
- ffn_hidden_size=ffn_hidden_size,
1640
- param_init_type=param_init_type,
1641
- hidden_act=hidden_act,
1642
- moe_config=moe_config,
1643
- parallel_config=parallel_config)
1644
- else:
1645
- # Feed Forward Network, FFN
1646
- self.output = FeedForward(hidden_size=hidden_size,
1647
- dropout_rate=hidden_dropout_rate,
1648
- ffn_hidden_size=ffn_hidden_size,
1649
- param_init_type=param_init_type,
1650
- hidden_act=hidden_act,
1651
- parallel_config=parallel_config)
1652
- self.post_layernorm_residual = post_layernorm_residual
1653
- self.add = P.Add().shard(((parallel_config.data_parallel, 1), (parallel_config.data_parallel, 1)))
1654
- self.add_3d = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
1655
- self.dtype = mstype.float16
1656
- self.key_past = None
1657
- self.value_past = None
1658
-
1659
- if self.use_past:
1660
- # operator used for state reuse
1661
- self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),))
1662
- self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ()))
1663
- self.slice = P.StridedSlice().shard(((1, 1, 1, 1),))
1664
- size_per_head = hidden_size // num_heads
1665
- self.key_shape = (batch_size, num_heads, size_per_head, seq_length)
1666
- self.value_shape = (batch_size, num_heads, seq_length, size_per_head)
1667
- # parameters saving key and value states
1668
- self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past")
1669
- self.value_past = Parameter(Tensor(np.zeros(shape=self.value_shape), self.dtype), name="value_past")
1670
- self.tile = P.Tile().shard(((1, 1),))
1671
- self.mul = P.Mul().shard(((1, 1, 1, 1), (1,)))
1672
- self.assign = P.Assign().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
1673
- else:
1674
- raise RuntimeError(f"The {self.cls_name} only support sharding propagation or "
1675
- f"semi-auto parallel mode now.")
1676
-
1677
- def construct(self, x, input_mask=None, init_reset=True, batch_valid_length=None):
1678
- self._check_input(x, input_mask, init_reset, batch_valid_length)
1679
- x_shape = F.shape(x)
1680
- x = F.reshape(x, (-1, x_shape[-1]))
1681
- if self.post_layernorm_residual:
1682
- input_x = x
1683
- else:
1684
- input_x = self.layernorm1(x)
1685
- input_x = F.cast(input_x, self.dtype)
1686
-
1687
- # indicate whether reset saved states
1688
- key_reset = None
1689
- value_reset = None
1690
-
1691
- if self.use_past:
1692
- # reset states, init_reset True for reuse and False for reset
1693
- self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype)))
1694
- key_reset = self.key_past
1695
- self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype)))
1696
- value_reset = self.value_past
1697
- # add dependency for desired execution order
1698
- input_x = F.depend(input_x, key_reset)
1699
- input_x = F.depend(input_x, value_reset)
1700
-
1701
- attention, layer_present = self.attention(input_x, input_x, input_x, input_mask,
1702
- self.key_past, self.value_past, batch_valid_length)
1703
- # For post-layernorm the inputs for residual path are output of self-attention and output of layernorm
1704
- if self.post_layernorm_residual:
1705
- x = self.add(input_x, attention)
1706
- # For pre-layernorm the inputs for residual path are output of self-attention and input of this layer
1707
- else:
1708
- x = self.add(x, attention)
1709
-
1710
- output_x = self.layernorm2(x)
1711
- output_x = F.cast(output_x, self.dtype)
1712
- aux_loss = None
1713
- if self.use_moe:
1714
- mlp_logit, aux_loss = self.output(output_x)
1715
- else:
1716
- mlp_logit = self.output(output_x)
1717
-
1718
- value_update = None
1719
- key_update = None
1720
- if self.use_past:
1721
- # current key and value
1722
- key_present, value_present = layer_present
1723
- # update key and value calculated this step
1724
- self.assign(self.key_past, key_present)
1725
- key_update = self.key_past
1726
- self.assign(self.value_past, value_present)
1727
- value_update = self.value_past
1728
- # add dependency for desired execution order
1729
- key_update = F.depend(key_update, key_reset)
1730
- value_update = F.depend(value_update, value_reset)
1731
-
1732
- # add dependency for desired execution order
1733
- mlp_logit = F.depend(mlp_logit, value_update)
1734
- mlp_logit = F.depend(mlp_logit, key_update)
1735
-
1736
- # if shape is 3d, we reshape the inputs of the add
1737
- if len(x_shape) == 3:
1738
- output_x = P.Reshape()(output_x, x_shape)
1739
- mlp_logit = P.Reshape()(mlp_logit, x_shape)
1740
- x = P.Reshape()(x, x_shape)
1741
-
1742
- if self.post_layernorm_residual:
1743
- output = self.add_3d(output_x, mlp_logit)
1744
- output = F.reshape(output, (-1, x_shape[-1]))
1745
- output = self.layernorm1(output)
1746
- output = F.reshape(output, x_shape)
1747
- else:
1748
- output = self.add_3d(x, mlp_logit)
1749
- else:
1750
- if self.post_layernorm_residual:
1751
- output = self.add(output_x, mlp_logit)
1752
- output = self.layernorm1(output)
1753
- else:
1754
- output = self.add(x, mlp_logit)
1755
- output = F.reshape(output, x_shape)
1756
-
1757
- if self.use_moe:
1758
- return output, layer_present, aux_loss
1759
- return output, layer_present
1760
-
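The branching on post_layernorm_residual in construct above boils down to two residual layouts. A condensed, hedged pseudo-flow in plain Python (the callables stand in for the attention, feed-forward and layernorm sub-modules; this is not the actual MindSpore graph):

import numpy as np

def encoder_block(x, attn, ffn, ln1, ln2, post_layernorm_residual):
    if post_layernorm_residual:
        # post-LN: attention sees the raw input; a layernorm follows each residual add
        h = x + attn(x)
        h_norm = ln2(h)
        return ln1(h_norm + ffn(h_norm))
    # pre-LN (default): layernorm precedes each sub-layer; the residual carries the raw path
    h = x + attn(ln1(x))
    return h + ffn(ln2(h))

def identity(t):
    return t

x = np.ones((2, 8), np.float32)
print(encoder_block(x, attn=identity, ffn=identity, ln1=identity, ln2=identity,
                    post_layernorm_residual=False).shape)  # (2, 8)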
1761
- def _check_input(self, x, input_mask, init_reset, batch_valid_length):
1762
- r"""Check inputs"""
1763
- _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
1764
- if input_mask is not None:
1765
- _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name)
1766
-
1767
- init_reset_is_tensor = isinstance(init_reset, Tensor)
1768
- init_reset_is_default = init_reset is True
1769
- batch_valid_length_is_tensor = isinstance(batch_valid_length, Tensor)
1770
- batch_is_default = batch_valid_length is None
1771
- _check_past_none_input_none(self.use_past, "init_reset", self.cls_name, True, init_reset_is_tensor,
1772
- init_reset_is_default)
1773
- _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None,
1774
- batch_valid_length_is_tensor, batch_is_default)
1775
-
1776
- if self.use_past:
1777
- _check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name)
1778
- _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
1779
- return True
1780
-
1781
-
1782
- class TransformerDecoderLayer(Cell):
1783
- r"""
1784
- Transformer Decoder Layer. This is an implementation of the single layer of the transformer
1785
- decoder layer, including self-attention, cross attention and a feed forward layer. When the encoder_output is None,
1786
- the cross attention will not be effective.
1787
-
1788
- Args:
1789
- hidden_size(int): The hidden size of the input.
1790
- ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
1791
- num_heads(int): The number of the heads.
1792
- batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
1793
- value. When doing training or prediction, the argument has no effect and the user can just pass None to
1794
- the argument.
1795
- src_seq_length(int): The input source sequence length.
1796
- tgt_seq_length(int): The input target sequence length.
1797
- attention_dropout_rate(float): The dropout rate of the attention scores. Default:0.1.
1798
- hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default:0.1.
1799
- post_layernorm_residual(bool): Whether to apply the residual addition before the layernorm. Default False.
1800
- use_past(bool): Use the past state to compute, used for incremental prediction. Default False.
1801
- layernorm_compute_type(dtype.Number): The computation type of the layernorm.
1802
- Should be dtype.float32 or dtype.float16. Default dtype.float32.
1803
- softmax_compute_type(dtype.Number): The computation type of the softmax in the attention.
1804
- Should be dtype.float32 or dtype.float16. Default mstype.float32.
1805
- param_init_type(dtype.Number): The parameter initialization type of the module.
1806
- Should be dtype.float32 or dtype.float16. Default dtype.float32.
1807
- hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
1808
- 'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
1809
- 'hsigmoid', 'logsigmoid' and so on. The user can provide a custom activation to the argument.
1810
- If user wants to run the net in the parallel mode, the custom activation must also provide
1811
- the `activation_shard` function. Please see the examples of the
1812
- :class:`mindspore.nn.transformer.FeedForward`. Default: gelu.
1813
- moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig
1814
- with default values. Please see `MoEConfig`.
1815
- parallel_config(OpParallelConfig, MoEParallelConfig): The parallel configuration. When MoE is applied,
1816
- MoEParallelConfig is effective, otherwise OpParallelConfig is effective. Default `default_dpmp_config`,
1817
- an instance of `OpParallelConfig` with default args.
1818
-
1819
- Inputs:
1820
- - **hidden_stats** (Tensor) - The input tensor with shape [batch_size, tgt_seq_length, hidden_size] or
1821
- [batch_size * tgt_seq_length, hidden_size].
1822
- - **decoder_mask** (Tensor) - The attention mask for decoder with shape [batch_size, src_seq_length,
1823
- seq_length] or None. None means there will be no mask in softmax computation in self attention.
1824
- - **encoder_output** (Tensor) - The output of the encoder with shape [batch_size, seq_length, hidden_size]
1825
- or [batch_size * seq_length, hidden_size].
1826
- Note that this argument cannot be None when the net is the outermost layer. Default None.
1827
- - **memory_mask** (Tensor) - The memory mask of the cross attention with shape [batch, tgt_seq_length,
1828
- src_seq_length] where tgt_seq_length is the length of the decoder. The user can also pass None. None
1829
- means there will be no mask in softmax computation in cross attention. Default None.
1830
- - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
1831
- past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
1832
- - **batch_valid_length** (Tensor) - int32 tensor with shape [batch_size] holding the index of the tokens already computed.
1833
- Used for incremental prediction when the use_past is True. Default None.
1834
-
1835
- Outputs:
1836
- Tuple, a tuple containing (`output`, `layer_present`)
1837
-
1838
- - **output** (Tensor) - The output logit of this layer. The shape is [batch, seq_length, hidden_size] or
1839
- [batch * seq_length, hidden_size].
1840
- - **layer_present** (Tuple) - A tuple of four tensors: the projected key and value vectors of the
1841
- self attention, with shapes (batch_size, num_heads, size_per_head, tgt_seq_length) and
1842
- (batch_size, num_heads, tgt_seq_length, size_per_head), followed by the projected key and value
1843
- vectors of the cross attention, with shapes (batch_size, num_heads, size_per_head, src_seq_length)
1844
- and (batch_size, num_heads, src_seq_length, size_per_head).
1845
-
1846
- Supported Platforms:
1847
- ``Ascend`` ``GPU``
1848
-
1849
- Examples:
1850
- >>> import numpy as np
1851
- >>> from mindspore import dtype as mstype
1852
- >>> from mindspore.nn.transformer import TransformerDecoderLayer
1853
- >>> from mindspore import Tensor
1854
- >>> model = TransformerDecoderLayer(batch_size=2, hidden_size=64, ffn_hidden_size=64, num_heads=2,
1855
- ... src_seq_length=20, tgt_seq_length=10)
1856
- >>> encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
1857
- >>> decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
1858
- >>> decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
1859
- >>> memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
1860
- >>> output, past = model(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
1861
- >>> print(output.shape)
1862
- (2, 10, 64)
1863
- >>> print(past[0].shape)
1864
- (2, 2, 32, 10)
1865
- >>> print(past[1].shape)
1866
- (2, 2, 10, 32)
1867
- >>> print(past[2].shape)
1868
- (2, 2, 32, 20)
1869
- >>> print(past[3].shape)
1870
- (2, 2, 20, 32)
1871
- """
1872
-
1873
- @_LogActionOnce(logger=logger, key='TransformerDecoderLayer',
1874
- no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
1875
- @_args_type_validator_check(hidden_size=Validator.check_positive_int,
1876
- num_heads=Validator.check_positive_int,
1877
- ffn_hidden_size=Validator.check_positive_int,
1878
- src_seq_length=Validator.check_positive_int,
1879
- tgt_seq_length=Validator.check_positive_int,
1880
- attention_dropout_rate=Validator.check_non_negative_float,
1881
- hidden_dropout_rate=Validator.check_non_negative_float,
1882
- post_layernorm_residual=Validator.check_bool,
1883
- layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
1884
- "TransformerDecoderLayer"),
1885
- softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
1886
- "TransformerDecoderLayer"),
1887
- param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
1888
- "TransformerDecoderLayer"),
1889
- parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig],
1890
- "TransformerDecoderLayer"),
1891
- use_past=Validator.check_bool)
1892
- def __init__(self, hidden_size,
1893
- ffn_hidden_size,
1894
- num_heads,
1895
- batch_size,
1896
- src_seq_length,
1897
- tgt_seq_length,
1898
- attention_dropout_rate=0.1,
1899
- hidden_dropout_rate=0.1,
1900
- post_layernorm_residual=False,
1901
- use_past=False,
1902
- layernorm_compute_type=mstype.float32,
1903
- softmax_compute_type=mstype.float32,
1904
- param_init_type=mstype.float32,
1905
- hidden_act='gelu',
1906
- moe_config=default_moe_config,
1907
- parallel_config=default_dpmp_config):
1908
- super(TransformerDecoderLayer, self).__init__()
1909
- _check_moe_config(moe_config, parallel_config)
1910
- self.use_moe = moe_config.expert_num > 1
1911
- config_to_attention = parallel_config.dpmp if self.use_moe else parallel_config
1912
- if batch_size or use_past:
1913
- Validator.check_positive_int(batch_size)
1914
- if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
1915
- _check_config(parallel_config)
1916
- if num_heads % parallel_config.model_parallel != 0:
1917
- raise ValueError("For 'TransformerDecoderLayer', the class variable 'num_heads' must be divisibled by "
1918
- "'parallel_config.model_parallel', but got the num_heads is {} and "
1919
- "parallel_config.model_parallel is {}.".format(num_heads,
1920
- parallel_config.model_parallel))
1921
- if hidden_size % parallel_config.model_parallel != 0:
1922
- raise ValueError(
1923
- "For 'TransformerDecoderLayer', the class variable 'hidden_size' must be divisibled by "
1924
- "'parallel_config.model_parallel', but got the hidden_size is {} and "
1925
- "parallel_config.model_parallel is {}."
1926
- .format(hidden_size, parallel_config.model_parallel))
1927
- if ffn_hidden_size % parallel_config.model_parallel != 0:
1928
- raise ValueError("For 'TransformerDecoderLayer', the class variable 'ffn_hidden_size' must be "
1929
- "divisibled by 'parallel_config.model_parallel', but got the ffn_hidden_size is {} "
1930
- "and parallel_config.model_parallel is {}."
1931
- .format(ffn_hidden_size, parallel_config.model_parallel))
1932
- if use_past:
1933
- raise ValueError(f"The {self.cls_name} does not support use_past=True.")
1934
- self.batch_size = batch_size
1935
- self.use_past = use_past
1936
- self.softmax_compute_type = softmax_compute_type
1937
-
1938
- self.src_seq_length = src_seq_length
1939
- self.tgt_seq_length = tgt_seq_length
1940
- self.use_past = use_past
1941
- self.hidden_size = hidden_size
1942
-
1943
- self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
1944
- self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
1945
- self.attention = MultiHeadAttention(hidden_size=hidden_size,
1946
- num_heads=num_heads,
1947
- batch_size=batch_size,
1948
- src_seq_length=tgt_seq_length,
1949
- tgt_seq_length=tgt_seq_length,
1950
- hidden_dropout_rate=hidden_dropout_rate,
1951
- attention_dropout_rate=attention_dropout_rate,
1952
- use_past=use_past,
1953
- softmax_compute_type=softmax_compute_type,
1954
- param_init_type=param_init_type,
1955
- parallel_config=config_to_attention)
1956
-
1957
- # Cross attention with the output of encoder as memory tensor
1958
- self.cross_attention = MultiHeadAttention(hidden_size=hidden_size,
1959
- num_heads=num_heads,
1960
- batch_size=batch_size,
1961
- src_seq_length=tgt_seq_length,
1962
- tgt_seq_length=src_seq_length,
1963
- hidden_dropout_rate=hidden_dropout_rate,
1964
- attention_dropout_rate=attention_dropout_rate,
1965
- softmax_compute_type=softmax_compute_type,
1966
- use_past=use_past,
1967
- param_init_type=param_init_type,
1968
- parallel_config=config_to_attention)
1969
- self.cross_attention_layernorm = _LayerNorm((hidden_size,)).to_float(
1970
- layernorm_compute_type)
1971
-
1972
- if self.use_moe:
1973
- self.output = MoE(hidden_size=hidden_size,
1974
- dropout_rate=hidden_dropout_rate,
1975
- ffn_hidden_size=ffn_hidden_size,
1976
- param_init_type=param_init_type,
1977
- hidden_act=hidden_act,
1978
- moe_config=moe_config,
1979
- parallel_config=parallel_config)
1980
- else:
1981
- # Feed Forward Network, FFN
1982
- self.output = FeedForward(hidden_size=hidden_size,
1983
- dropout_rate=hidden_dropout_rate,
1984
- ffn_hidden_size=ffn_hidden_size,
1985
- hidden_act=hidden_act,
1986
- param_init_type=param_init_type,
1987
- parallel_config=parallel_config)
1988
- self.post_layernorm_residual = post_layernorm_residual
1989
- self.add = P.Add()
1990
- self.add_3d = P.Add()
1991
- self.dtype = mstype.float16
1992
- self.key_past = None
1993
- self.value_past = None
1994
- if self.use_past:
1995
- # operator used for state reuse
1996
- self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),))
1997
- self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ()))
1998
- self.slice = P.StridedSlice().shard(((1, 1, 1, 1),))
1999
- size_per_head = hidden_size // num_heads
2000
- self.key_shape = (batch_size, num_heads, size_per_head, tgt_seq_length)
2001
- self.value_shape = (batch_size, num_heads, tgt_seq_length, size_per_head)
2002
- # parameters saving key and value states
2003
- self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past")
2004
- self.value_past = Parameter(Tensor(np.zeros(shape=self.value_shape), self.dtype), name="value_past")
2005
- self.tile = P.Tile().shard(((1, 1),))
2006
- self.mul = P.Mul().shard(((1, 1, 1, 1), (1,)))
2007
- self.assign = P.Assign().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
2008
- elif _get_parallel_mode() not in (ParallelMode.AUTO_PARALLEL,):
2009
- _check_config(parallel_config)
2010
- if num_heads % parallel_config.model_parallel != 0:
2011
- raise ValueError("For 'TransformerDecoderLayer', the class variable 'num_heads' must be divisibled by "
2012
- "'parallel_config.model_parallel', but got the num_heads is {} and "
2013
- "parallel_config.model_parallel is {}.".format(num_heads,
2014
- parallel_config.model_parallel))
2015
- if hidden_size % parallel_config.model_parallel != 0:
2016
- raise ValueError(
2017
- "For 'TransformerDecoderLayer', the class variable 'hidden_size' must be divisibled by "
2018
- "'parallel_config.model_parallel', but got the hidden_size is {} and "
2019
- "parallel_config.model_parallel is {}."
2020
- .format(hidden_size, parallel_config.model_parallel))
2021
- if ffn_hidden_size % parallel_config.model_parallel != 0:
2022
- raise ValueError("For 'TransformerDecoderLayer', the class variable 'ffn_hidden_size' must be "
2023
- "divisibled by 'parallel_config.model_parallel', but got the ffn_hidden_size is {} "
2024
- "and parallel_config.model_parallel is {}."
2025
- .format(ffn_hidden_size, parallel_config.model_parallel))
2026
- if use_past:
2027
- raise ValueError(f"The {self.cls_name} does not support use_past=True.")
2028
- self.batch_size = batch_size
2029
- self.use_past = use_past
2030
- self.softmax_compute_type = softmax_compute_type
2031
-
2032
- self.src_seq_length = src_seq_length
2033
- self.tgt_seq_length = tgt_seq_length
2034
- self.use_past = use_past
2035
- self.hidden_size = hidden_size
2036
-
2037
- self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
2038
- self.layernorm1.shard(((parallel_config.data_parallel, 1),))
2039
- self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
2040
- self.layernorm2.shard(((parallel_config.data_parallel, 1),))
2041
- self.attention = MultiHeadAttention(hidden_size=hidden_size,
2042
- num_heads=num_heads,
2043
- batch_size=batch_size,
2044
- src_seq_length=tgt_seq_length,
2045
- tgt_seq_length=tgt_seq_length,
2046
- hidden_dropout_rate=hidden_dropout_rate,
2047
- attention_dropout_rate=attention_dropout_rate,
2048
- use_past=use_past,
2049
- softmax_compute_type=softmax_compute_type,
2050
- param_init_type=param_init_type,
2051
- parallel_config=config_to_attention)
2052
-
2053
- # Cross attention with the output of encoder as memory tensor
2054
- self.cross_attention = MultiHeadAttention(hidden_size=hidden_size,
2055
- num_heads=num_heads,
2056
- batch_size=batch_size,
2057
- src_seq_length=tgt_seq_length,
2058
- tgt_seq_length=src_seq_length,
2059
- hidden_dropout_rate=hidden_dropout_rate,
2060
- attention_dropout_rate=attention_dropout_rate,
2061
- softmax_compute_type=softmax_compute_type,
2062
- use_past=use_past,
2063
- param_init_type=param_init_type,
2064
- parallel_config=config_to_attention)
2065
- self.cross_attention_layernorm = _LayerNorm((hidden_size,)).to_float(
2066
- layernorm_compute_type)
2067
- self.cross_attention_layernorm.shard(((parallel_config.data_parallel, 1),))
2068
-
2069
- if self.use_moe:
2070
- self.output = MoE(hidden_size=hidden_size,
2071
- dropout_rate=hidden_dropout_rate,
2072
- ffn_hidden_size=ffn_hidden_size,
2073
- param_init_type=param_init_type,
2074
- hidden_act=hidden_act,
2075
- moe_config=moe_config,
2076
- parallel_config=parallel_config)
2077
- else:
2078
- # Feed Forward Network, FFN
2079
- self.output = FeedForward(hidden_size=hidden_size,
2080
- dropout_rate=hidden_dropout_rate,
2081
- ffn_hidden_size=ffn_hidden_size,
2082
- hidden_act=hidden_act,
2083
- param_init_type=param_init_type,
2084
- parallel_config=parallel_config)
2085
- self.post_layernorm_residual = post_layernorm_residual
2086
- self.add = P.Add().shard(((parallel_config.data_parallel, 1), (parallel_config.data_parallel, 1)))
2087
- self.add_3d = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
2088
- self.dtype = mstype.float16
2089
- self.key_past = None
2090
- self.value_past = None
2091
- if self.use_past:
2092
- # operator used for state reuse
2093
- self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),))
2094
- self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ()))
2095
- self.slice = P.StridedSlice().shard(((1, 1, 1, 1),))
2096
- size_per_head = hidden_size // num_heads
2097
- self.key_shape = (batch_size, num_heads, size_per_head, tgt_seq_length)
2098
- self.value_shape = (batch_size, num_heads, tgt_seq_length, size_per_head)
2099
- # parameters saving key and value states
2100
- self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past")
2101
- self.value_past = Parameter(Tensor(np.zeros(shape=self.value_shape), self.dtype), name="value_past")
2102
- self.tile = P.Tile().shard(((1, 1),))
2103
- self.mul = P.Mul().shard(((1, 1, 1, 1), (1,)))
2104
- self.assign = P.Assign().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
2105
- else:
2106
- raise RuntimeError(f"The {self.cls_name} only support sharding propagation or "
2107
- f"semi-auto parallel mode now.")
2108
-
2109
- def construct(self, hidden_stats,
2110
- decoder_mask,
2111
- encoder_output=None,
2112
- memory_mask=None,
2113
- init_reset=True, batch_valid_length=None):
2114
- self._check_input(hidden_stats, decoder_mask, encoder_output, memory_mask, init_reset, batch_valid_length)
2115
- # the returned shape is [bs, seq_length, embedding_size] or [bs * seq_length, embedding_size]
2116
- hidden_shape = F.shape(hidden_stats)
2117
- hidden_stats = F.reshape(hidden_stats, (-1, hidden_shape[-1]))
2118
- input_x = self.layernorm1(hidden_stats)
2119
- input_x = F.cast(input_x, self.dtype)
2120
-
2121
- # indicate whether reset saved states
2122
- key_reset = None
2123
- value_reset = None
2124
- if self.use_past:
2125
- # reset states, init_reset True for reuse and False for reset
2126
- self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype)))
2127
- key_reset = self.key_past
2128
- self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype)))
2129
- value_reset = self.value_past
2130
- # add dependency for desired execution order
2131
- input_x = F.depend(input_x, key_reset)
2132
- input_x = F.depend(input_x, value_reset)
2133
-
2134
- attention, layer_present = self.attention(input_x, input_x, input_x, decoder_mask, self.key_past,
2135
- self.value_past, batch_valid_length)
2136
- # For post-layernorm the inputs for residual path are output of self-attention and output of layernorm
2137
- if self.post_layernorm_residual:
2138
- x = self.add(input_x, attention)
2139
- # For pre-layernorm the inputs for residual path are output of self-attention and input of this layer
2140
- else:
2141
- x = self.add(hidden_stats, attention)
2142
-
2143
- middle_output = None
2144
- if encoder_output is not None:
2145
- middle_output = self.cross_attention_layernorm(x)
2146
- middle_output = F.cast(middle_output, self.dtype)
2147
- encoder_output = F.cast(encoder_output, self.dtype)
2148
- cross_attn_output, cross_layer_present = self.cross_attention(middle_output, encoder_output,
2149
- encoder_output,
2150
- memory_mask, self.key_past,
2151
- self.value_past, batch_valid_length)
2152
- layer_present += cross_layer_present
2153
- if self.post_layernorm_residual:
2154
- x = self.add(middle_output, cross_attn_output)
2155
- else:
2156
- x = self.add(x, cross_attn_output)
2157
-
2158
- output_x = self.layernorm2(x)
2159
- output_x = F.cast(output_x, self.dtype)
2160
- aux_loss = None
2161
- if self.use_moe:
2162
- mlp_logit, aux_loss = self.output(output_x)
2163
- else:
2164
- mlp_logit = self.output(output_x)
2165
-
2166
- value_update = None
2167
- key_update = None
2168
- if self.use_past:
2169
- # current key and value
2170
- key_present, value_present = layer_present
2171
- # update key and value calculated this step
2172
- self.assign(self.key_past, key_present)
2173
- key_update = self.key_past
2174
- self.assign(self.value_past, value_present)
2175
- value_update = self.value_past
2176
- # add dependency for desired execution order
2177
- key_update = F.depend(key_update, key_reset)
2178
- value_update = F.depend(value_update, value_reset)
2179
-
2180
- # add dependency for desired execution order
2181
- mlp_logit = F.depend(mlp_logit, value_update)
2182
- mlp_logit = F.depend(mlp_logit, key_update)
2183
-
2184
- # if shape is 3d, we reshape the inputs of the add
2185
- if len(hidden_shape) == 3:
2186
- output_x = P.Reshape()(output_x, hidden_shape)
2187
- mlp_logit = P.Reshape()(mlp_logit, hidden_shape)
2188
- x = P.Reshape()(x, hidden_shape)
2189
-
2190
- if self.post_layernorm_residual:
2191
- output = self.add_3d(output_x, mlp_logit)
2192
- else:
2193
- output = self.add_3d(x, mlp_logit)
2194
- else:
2195
- if self.post_layernorm_residual:
2196
- output = self.add(output_x, mlp_logit)
2197
- else:
2198
- output = self.add(x, mlp_logit)
2199
- output = F.reshape(output, hidden_shape)
2200
-
2201
- if self.use_moe:
2202
- return output, layer_present, aux_loss
2203
- return output, layer_present
2204
-
2205
- def _check_input(self, hidden_states, attention_mask, encoder_output, memory_mask, init_reset, batch_valid_length):
2206
- r"""Check inputs"""
2207
- _check_input_dtype(F.dtype(hidden_states), "hidden_states", [mstype.float32, mstype.float16], self.cls_name)
2208
- if attention_mask is not None:
2209
- _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16],
2210
- self.cls_name)
2211
- if encoder_output is not None:
2212
- _check_input_dtype(F.dtype(encoder_output), "encoder_output",
2213
- [mstype.float32, mstype.float16], self.cls_name)
2214
- if memory_mask is not None:
2215
- _check_input_dtype(F.dtype(memory_mask), "memory_mask",
2216
- [mstype.float32, mstype.float16], self.cls_name)
2217
-
2218
- init_reset_is_tensor = isinstance(init_reset, Tensor)
2219
- init_reset_is_default = init_reset is True
2220
- batch_valid_length_is_tensor = isinstance(batch_valid_length, Tensor)
2221
- batch_is_default = batch_valid_length is None
2222
- _check_past_none_input_none(self.use_past, "init_reset", self.cls_name, True, init_reset_is_tensor,
2223
- init_reset_is_default)
2224
- _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None,
2225
- batch_valid_length_is_tensor, batch_is_default)
2226
-
2227
- if self.use_past:
2228
- _check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name)
2229
- _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
2230
- return True
2231
-
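The comments inside `construct` above describe how `post_layernorm_residual` decides where the residual branch is taken from. A compact sketch of the two layouts (plain functions with made-up names, not the MindSpore API):

    def pre_layernorm_block(x, norm, sublayer):
        # post_layernorm_residual=False: the residual comes from the raw block input.
        return x + sublayer(norm(x))

    def post_layernorm_block(x, norm, sublayer):
        # post_layernorm_residual=True: the residual comes from the normalized tensor.
        y = norm(x)
        return y + sublayer(y)

The same pattern is applied three times in the layer: self attention, cross attention and the feed-forward (or MoE) output.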
2232
-
2233
- def _get_lambda_func(total_layer=None):
2234
- r"""
2235
- A wrapper function for specifying the pipeline stage and gradient aggregation fusion (a worked sketch follows this function). If the total layer
2236
- is not None, for example, set in the transformer model, the pipeline stage setting function will be
2237
- `(layer_id + 0) // (total_layers / parallel_config.pipeline_stage)` for the encoder and,
2238
- `(layer_id + offset) //
2239
- (total_layers / parallel_config.pipeline_stage)` for the decoder, where `offset` is the layers in the encoder.
2240
- """
2241
-
2242
- def _set_parallel_configure_for_layer(network, layer_id, offset, parallel_config, layers):
2243
- r"""
2244
- Default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`.
2245
-
2246
- Args:
2247
- network(Cell) - Represents the transformer block
2248
- layer_id(int) - Means the layer index for the current module, counts from zero.
2249
- offset(int) - Means the layer_index needs an offset, if there are other modules in the net.
2250
- layers(int) - The total layers used for the model.
2251
- """
2252
- # override the layers
2253
- if total_layer:
2254
- layers = total_layer
2255
- # Used for the pipeline's stages setting
2256
- if layers < parallel_config.pipeline_stage:
2257
- raise ValueError(f"layers {layers} must be larger than pipeline stage {parallel_config.pipeline_stage}")
2258
-
2259
- pp_dis = max(layers // parallel_config.pipeline_stage, 1)
2260
- # the pipeline stage must be in [0, parallel_config.pipeline_stage - 1]
2261
- pp_id = min((layer_id + offset) // pp_dis, parallel_config.pipeline_stage - 1)
2262
- network.pipeline_stage = pp_id
2263
-
2264
- # Used for optimizer's fusion tag
2265
- dis = max(layers // parallel_config.gradient_aggregation_group, 1)
2266
- network.set_comm_fusion((layer_id + offset) // dis + 1)
2267
- # Used for enabling recomputation of the block
2268
- if isinstance(parallel_config.recompute, bool):
2269
- if parallel_config.recompute:
2270
- network.recompute()
2271
- else:
2272
- if parallel_config.recompute.recompute:
2273
- paralel_op_comm_compute = parallel_config.recompute.parallel_optimizer_comm_recompute
2274
- network.recompute(parallel_optimizer_comm_recompute=paralel_op_comm_compute,
2275
- mp_comm_recompute=parallel_config.recompute.mp_comm_recompute,
2276
- recompute_slice_activation=parallel_config.recompute.recompute_slice_activation)
2277
-
2278
- return _set_parallel_configure_for_layer
2279
-
2280
-
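The default callback returned above maps each block to a pipeline stage via `(layer_id + offset) // (layers / pipeline_stage)` and to a gradient-aggregation fusion group via `(layer_id + offset) // (layers / gradient_aggregation_group) + 1`. A minimal, self-contained sketch of that arithmetic in plain Python (the helper name `assign_stage_and_fusion` and the concrete sizes are illustrative, not part of the MindSpore API):

    def assign_stage_and_fusion(layer_id, offset, layers, pipeline_stage, aggregation_group):
        # Evenly split `layers` across stages, clamping to the last stage,
        # mirroring _set_parallel_configure_for_layer above.
        pp_dis = max(layers // pipeline_stage, 1)
        pp_id = min((layer_id + offset) // pp_dis, pipeline_stage - 1)
        # Fusion groups for optimizer communication follow the same even split.
        dis = max(layers // aggregation_group, 1)
        fusion_id = (layer_id + offset) // dis + 1
        return pp_id, fusion_id

    for i in range(8):
        print(i, assign_stage_and_fusion(i, offset=0, layers=8, pipeline_stage=4, aggregation_group=2))
    # layers 0-1 -> stage 0, 2-3 -> stage 1, 4-5 -> stage 2, 6-7 -> stage 3;
    # fusion group 1 for layers 0-3 and group 2 for layers 4-7.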
2281
- class TransformerEncoder(Cell):
2282
- r"""
2283
- Transformer Encoder module consisting of multiple stacked `TransformerEncoderLayer`, including multihead self
2284
- attention and feedforward layer.
2285
-
2286
- Args:
2287
- batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
2288
- value. When doing training or prediction, the argument is not used and the user can simply pass None to
2289
- the argument.
2290
- num_layers(int): The layers of the `TransformerEncoderLayer`
2291
- hidden_size(int): The hidden size of the input.
2292
- ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
2293
- seq_length(int): The seq_length of the input tensor.
2294
- num_heads(int): The number of the heads.
2295
- attention_dropout_rate(float): The dropout rate of the attention scores. Default:0.1.
2296
- hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default: 0.1.
2297
- hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
2298
- 'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
2299
- 'hsigmoid', 'logsigmoid' and so on. User can provide a custom activation to the argument.
2300
- If user wants to run the net in the parallel mode, the custom activation must also provide
2301
- the `activation_shard` function. Please see the examples of the
2302
- class:`mindspore.nn.transformer.FeedForward`. Default: gelu.
2303
- post_layernorm_residual(bool): Whether to apply the residual addition before the layernorm. Default False.
2304
- layernorm_compute_type(dtype.Number): The computation type of the layernorm.
2305
- Should be mstype.float32 or mstype.float16. Default mstype.float32.
2306
- softmax_compute_type(dtype.Number): The computation type of the softmax in the attention.
2307
- Should be mstype.float32 or mstype.float16. Default: mstype.float32.
2308
- param_init_type(dtype.Number): The parameter initialization type of the module.
2309
- Should be mstype.float32 or mstype.float16. Default: mstype.float32.
2310
- lambda_func(function): A function that determines the fusion index,
2311
- pipeline stages and recompute attribute (see the sketch after this docstring). If the
2312
- user wants to determine the pipeline stage and gradient aggregation fusion, the user can pass a
2313
- function that accepts `network`, `layer_id`, `offset`, `parallel_config`, `layers`. The `network(Cell)`
2314
- represents the transformer block, `layer_id(int)` means the layer index for the current module, counts
2315
- from zero, `offset(int)` means the layer_index needs an offset, if there are other modules in the net.
2316
- The default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`.
2317
- Default: ``None``.
2318
- offset(int): The initial layer index for the `encoder`. Used for setting the fusion id and stage id, to not
2319
- overlap with the encoder layer. Default 0.
2320
- use_past(bool): Use the past state to compute, used for incremental prediction. For example, if we have two
2321
- words and want to generate ten more words, we just need to compute the two words' state only once,
2322
- and generate the next word one by one. When use_past is True, there are two steps to run the prediction.
2323
- In the first step, set the is_first_iteration to be True by
2324
- `model.add_flags_recursive(is_first_iteration=True)`, and pass the full inputs. Then, set the
2325
- is_first_iteration to be False by `model.add_flags_recursive(is_first_iteration=False)`. At this moment,
2326
- pass the single step's input tensor, and loop it. Default: ``False``.
2327
- moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig
2328
- with default values. Please see `MoEConfig`.
2329
- parallel_config(TransformerOpParallelConfig): The parallel configuration. Default `default_transformer_config`,
2330
- an instance of `TransformerOpParallelConfig` with default args.
2331
-
2332
- Inputs:
2333
- - **hidden_states** (Tensor) - Tensor, shape should be [batch_size, seq_length, hidden_size] or
2334
- [batch_size * seq_length, hidden_size], if the use_past is False or is_first_iteration=True. Otherwise,
2335
- should be [batch_size, 1, hidden_size].
2336
- - **attention_mask** (Tensor) - Float Tensor, If the use_past is False or is_first_iteration=True,
2337
- the attention mask matrix should be [batch_size, seq_length, seq_length], or None. None means there will
2338
- be no mask in softmax computation. Otherwise, should be [batch_size, 1, seq_length].
2339
- - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
2340
- past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
2341
- - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the index of the already calculated
2342
- tokens. Used for incremental prediction when use_past is True. Default None.
2343
-
2344
- Outputs:
2345
- Tuple, a tuple containing (`output`, `layer_present`)
2346
-
2347
- - **output** (Tensor) - The float tensor of the output of the layer with
2348
- shape (batch_size, seq_length, hidden_size) or (batch_size * seq_length, hidden_size), if the use_past is
2349
- False or is_first_iteration=True. Otherwise, it will be (batch_size, 1, hidden_size).
2350
- - **layer_present** (Tuple) - A tuple with size of num_layers, where each tuple contains the Tensor of the
2351
- projected key and value vector with shape ((batch_size, num_heads, size_per_head, seq_length),
2352
- and (batch_size, num_heads, seq_length, size_per_head)).
2353
-
2354
- Supported Platforms:
2355
- ``Ascend`` ``GPU``
2356
-
2357
- Examples:
2358
- >>> import numpy as np
2359
- >>> from mindspore import dtype as mstype
2360
- >>> from mindspore.nn.transformer import TransformerEncoder
2361
- >>> from mindspore import Tensor
2362
- >>> model = TransformerEncoder(batch_size=2, num_layers=2, hidden_size=8, ffn_hidden_size=64,
2363
- ... seq_length=16, num_heads=2)
2364
- >>> encoder_input_value = Tensor(np.ones((2, 16, 8)), mstype.float32)
2365
- >>> encoder_input_mask = Tensor(np.ones((2, 16, 16)), mstype.float16)
2366
- >>> output, past = model(encoder_input_value, encoder_input_mask)
2367
- >>> print(output.shape)
2368
- (2, 16, 8)
2369
- >>> print(len(past))
2370
- 2
2371
- >>> print(past[0][0].shape)
2372
- (2, 2, 4, 16)
2373
- >>> print(past[0][1].shape)
2374
- (2, 2, 16, 4)
2375
- >>> # When use use_past=True, it includes two steps to implement the incremental prediction.
2376
- >>> # Step 1: set is_first_iteration=True, and input the full sequence length's state.
2377
- >>> batch_valid_length = Tensor(np.ones((2,)), mstype.int32)
2378
- >>> init_reset = Tensor([True], mstype.bool_)
2379
- >>> # Set is_first_iteration=True to generate the full memory states
2380
- >>> model = TransformerEncoder(batch_size=2, hidden_size=8, ffn_hidden_size=64, seq_length=16,
2381
- ... num_heads=2, num_layers=2, use_past=True)
2382
- >>> model.add_flags_recursive(is_first_iteration=True)
2383
- >>> hidden, past = model(encoder_input_value, encoder_input_mask, init_reset, batch_valid_length)
2384
- >>> print(hidden.shape)
2385
- (2, 16, 8)
2386
- >>> print(past[0][0].shape)
2387
- (2, 2, 4, 16)
2388
- >>> print(past[0][1].shape)
2389
- (2, 2, 16, 4)
2390
- >>> encoder_input_value = Tensor(np.ones((2, 1, 8)), mstype.float32)
2391
- >>> encoder_input_mask = Tensor(np.ones((2, 1, 16)), mstype.float16)
2392
- >>> init_reset = Tensor([False], mstype.bool_)
2393
- >>> # Step 2: set is_first_iteration=False, and pass the single word to run the prediction rather than
2394
- >>> # the full sequence.
2395
- >>> model.add_flags_recursive(is_first_iteration=False)
2396
- >>> hidden, past = model(encoder_input_value, encoder_input_mask, init_reset, batch_valid_length)
2397
- >>> print(hidden.shape)
2398
- (2, 1, 8)
2399
- >>> print(past[0][0].shape)
2400
- (2, 2, 4, 16)
2401
- >>> print(past[0][1].shape)
2402
- (2, 2, 16, 4)
2403
- """
2404
-
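As described in the docstring above, a custom `lambda_func` receives `network`, `layer_id`, `offset`, `parallel_config` and `layers`. A minimal sketch of such a callback (illustrative only; it assumes the default `parallel_config`, which exposes `pipeline_stage` as used elsewhere in this file, and the imports from the Examples above):

    def my_lambda_func(network, layer_id, offset, parallel_config, layers):
        # Put the first half of the blocks on stage 0 and the rest on the last stage.
        half = max(layers // 2, 1)
        network.pipeline_stage = 0 if (layer_id + offset) < half else parallel_config.pipeline_stage - 1
        # Give every block its own gradient-aggregation fusion group.
        network.set_comm_fusion(layer_id + offset + 1)

    encoder = TransformerEncoder(batch_size=2, num_layers=2, hidden_size=8, ffn_hidden_size=64,
                                 seq_length=16, num_heads=2, lambda_func=my_lambda_func)

The callback is applied to every `TransformerEncoderLayer` as it is built, exactly where the default `_get_lambda_func()` result would otherwise be used.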
2405
- @_LogActionOnce(logger=logger, key='TransformerEncoder',
2406
- no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
2407
- @_args_type_validator_check(batch_size=Validator.check_positive_int,
2408
- hidden_size=Validator.check_positive_int,
2409
- num_heads=Validator.check_positive_int,
2410
- ffn_hidden_size=Validator.check_positive_int,
2411
- seq_length=Validator.check_positive_int,
2412
- num_layers=Validator.check_positive_int,
2413
- offset=Validator.check_non_negative_int,
2414
- attention_dropout_rate=Validator.check_non_negative_float,
2415
- hidden_dropout_rate=Validator.check_non_negative_float,
2416
- post_layernorm_residual=Validator.check_bool,
2417
- layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
2418
- "TransformerEncoder"),
2419
- softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
2420
- "TransformerEncoder"),
2421
- param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
2422
- "TransformerEncoder"),
2423
- parallel_config=_valid_type_checks([TransformerOpParallelConfig],
2424
- "TransformerEncoder"),
2425
- use_past=Validator.check_bool)
2426
- def __init__(self,
2427
- batch_size,
2428
- num_layers,
2429
- hidden_size,
2430
- ffn_hidden_size,
2431
- seq_length,
2432
- num_heads,
2433
- attention_dropout_rate=0.1,
2434
- hidden_dropout_rate=0.1,
2435
- hidden_act='gelu',
2436
- post_layernorm_residual=False,
2437
- layernorm_compute_type=mstype.float32,
2438
- softmax_compute_type=mstype.float32,
2439
- param_init_type=mstype.float32,
2440
- lambda_func=None,
2441
- offset=0,
2442
- use_past=False,
2443
- moe_config=default_moe_config,
2444
- parallel_config=default_transformer_config):
2445
- super(TransformerEncoder, self).__init__()
2446
- _check_config(parallel_config)
2447
- _check_moe_config(moe_config, parallel_config)
2448
- self.use_moe = moe_config.expert_num > 1
2449
- config_to_layer = parallel_config.moe_parallel_config if self.use_moe else parallel_config.dp_mp_config
2450
- if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
2451
- self.add = P.Add()
2452
- self.aux_loss = Tensor(0.0, mstype.float32)
2453
- self.num_layers = num_layers
2454
- self.blocks = nn.CellList()
2455
- for i in range(num_layers):
2456
- block = TransformerEncoderLayer(hidden_size=hidden_size,
2457
- batch_size=batch_size,
2458
- ffn_hidden_size=ffn_hidden_size,
2459
- seq_length=seq_length,
2460
- attention_dropout_rate=attention_dropout_rate,
2461
- hidden_dropout_rate=hidden_dropout_rate,
2462
- layernorm_compute_type=layernorm_compute_type,
2463
- softmax_compute_type=softmax_compute_type,
2464
- num_heads=num_heads,
2465
- hidden_act=hidden_act,
2466
- post_layernorm_residual=post_layernorm_residual,
2467
- param_init_type=param_init_type,
2468
- use_past=use_past,
2469
- moe_config=moe_config,
2470
- parallel_config=config_to_layer)
2471
- # If the user doesn't pass the fusion function, use the default one
2472
- if not lambda_func:
2473
- lambda_func = _get_lambda_func()
2474
-
2475
- lambda_func(block, layer_id=i, layers=num_layers,
2476
- offset=offset, parallel_config=parallel_config)
2477
- self.blocks.append(block)
2478
- elif _get_parallel_mode() not in (ParallelMode.AUTO_PARALLEL,):
2479
- self.add = P.Add().shard(((), ()))
2480
- self.aux_loss = Tensor(0.0, mstype.float32)
2481
- logger.warning("For parallel mode, sharding propagation is recommended, you can use it by setting "
2482
- "'set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, "
2483
- "search_mode=\"sharding_propagation\")' and "
2484
- "'set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)'")
2485
- self.num_layers = num_layers
2486
- self.blocks = nn.CellList()
2487
- for i in range(num_layers):
2488
- block = TransformerEncoderLayer(hidden_size=hidden_size,
2489
- batch_size=batch_size,
2490
- ffn_hidden_size=ffn_hidden_size,
2491
- seq_length=seq_length,
2492
- attention_dropout_rate=attention_dropout_rate,
2493
- hidden_dropout_rate=hidden_dropout_rate,
2494
- layernorm_compute_type=layernorm_compute_type,
2495
- softmax_compute_type=softmax_compute_type,
2496
- num_heads=num_heads,
2497
- hidden_act=hidden_act,
2498
- post_layernorm_residual=post_layernorm_residual,
2499
- param_init_type=param_init_type,
2500
- use_past=use_past,
2501
- moe_config=moe_config,
2502
- parallel_config=config_to_layer)
2503
- # If the user doesn't pass the fusion function, use the default one
2504
- if not lambda_func:
2505
- lambda_func = _get_lambda_func()
2506
-
2507
- lambda_func(block, layer_id=i, layers=num_layers,
2508
- offset=offset, parallel_config=parallel_config)
2509
- self.blocks.append(block)
2510
- else:
2511
- raise RuntimeError(f"The {self.cls_name} only support sharding propagation or "
2512
- f"semi-auto parallel mode now.")
2513
-
2514
- def construct(self, hidden_states, attention_mask, init_reset=True, batch_valid_length=None):
2515
- present_layer = ()
2516
- if self.use_moe:
2517
- accum_loss = self.aux_loss
2518
- for i in range(self.num_layers):
2519
- hidden_states, present, aux_loss = self.blocks[i](hidden_states,
2520
- attention_mask,
2521
- init_reset,
2522
- batch_valid_length)
2523
- present_layer = present_layer + (present,)
2524
- accum_loss = self.add(accum_loss, aux_loss)
2525
- return hidden_states, present_layer, accum_loss
2526
-
2527
- for i in range(self.num_layers):
2528
- hidden_states, present = self.blocks[i](hidden_states,
2529
- attention_mask,
2530
- init_reset,
2531
- batch_valid_length)
2532
- present_layer = present_layer + (present,)
2533
-
2534
- return hidden_states, present_layer
2535
-
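When `moe_config.expert_num > 1`, the loop above also accumulates the per-layer auxiliary loss and `construct` returns a third value. A hedged usage sketch (`task_loss` and the call site are illustrative, not part of the API):

    # MoE enabled: three outputs instead of two.
    hidden, present_layer, accum_loss = encoder(hidden_states, attention_mask)
    # Fold the load-balancing auxiliary loss into the training objective.
    total_loss = task_loss + accum_loss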
2536
-
2537
- class TransformerDecoder(Cell):
2538
- r"""
2539
- Transformer Decoder module consisting of multiple stacked `TransformerDecoderLayer`, including multihead self
2540
- attention, cross attention and feedforward layer.
2541
-
2542
- Args:
2543
- num_layers(int): The layers of the `TransformerDecoderLayer`.
2544
- batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
2545
- value. When doing training or prediction, the argument is not used and the user can simply pass None to
2546
- the argument.
2547
- hidden_size(int): The hidden size of the input.
2548
- ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
2549
- src_seq_length(int): The input source sequence length.
2550
- tgt_seq_length(int): The input target sequence length.
2551
- num_heads(int): The number of the heads.
2552
- attention_dropout_rate(float): The dropout rate of the attention scores. Default:0.1.
2553
- hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default:0.1.
2554
- post_layernorm_residual(bool): Whether to apply the residual addition before the layernorm. Default False.
2555
- layernorm_compute_type(dtype.Number): The computation type of the layernorm.
2556
- Should be mstype.float32 or mstype.float16. Default mstype.float32.
2557
- softmax_compute_type(dtype.Number): The computation type of the softmax in the attention.
2558
- Should be mstype.float32 or mstype.float16. Default mstype.float32.
2559
- param_init_type(dtype.Number): The parameter initialization type of the module.
2560
- Should be mstype.float32 or mstype.float16. Default mstype.float32.
2561
- hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
2562
- 'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
2563
- 'hsigmoid', 'logsigmoid' and so on. User can provide a custom activation to the argument.
2564
- If user wants to run the net in the parallel mode, the custom activation must also provide
2565
- the `activation_shard` function. Please see the examples of the
2566
- class:`mindspore.nn.transformer.FeedForward`. Default: gelu.
2567
- lambda_func(function): A function that determines the fusion index,
2568
- pipeline stages and recompute attribute. If the
2569
- user wants to determine the pipeline stage and gradient aggregation fusion, the user can pass a
2570
- function that accepts `network`, `layer_id`, `offset`, `parallel_config`, `layers`. The `network(Cell)`
2571
- represents the transformer block, `layer_id(int)` means the layer index for the current module, counts
2572
- from zero, `offset(int)` means the layer_index needs an offset, if there are other modules in the net.
2573
- The default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`.
2574
- Default: ``None``.
2575
- use_past(bool): Use the past state to compute, used for incremental prediction. Default False.
2576
- offset(int): The initial layer index for the `decoder`. Used for setting the fusion id and stage id, to not
2577
- overlap with the encoder layer. Default 0.
2578
- moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig
2579
- with default values. Please see `MoEConfig`.
2580
- parallel_config(TransformerOpParallelConfig): The parallel configuration. Default `default_transformer_config`,
2581
- an instance of `TransformerOpParallelConfig` with default args.
2582
-
2583
- Inputs:
2584
- - **hidden_stats** (Tensor) - The input tensor with shape [batch_size, seq_length, hidden_size] or
2585
- [batch_size * seq_length, hidden_size]
2586
- - **attention_mask** (Tensor) - The attention mask for decoder with shape
2587
- [batch_size, seq_length, seq_length] or None. None means there will be no mask in softmax
2588
- computation in self attention.
2589
- - **encoder_output** (Tensor) - The output of the encoder with shape [batch_size, seq_length, hidden_size]
2590
- or [batch_size * seq_length, hidden_size]. Note that this argument cannot be None when the net is the
2591
- outermost layer. Default None.
2592
- - **memory_mask** (Tensor) - The memory mask of the cross attention with shape [batch, tgt_seq_length,
2593
- src_seq_length] where tgt_seq_length is the length of the decoder. The user can also pass None. None
2594
- means there will be no mask in softmax computation in cross attention. Default None.
2595
- - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
2596
- past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
2597
- - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the index of the already calculated
2598
- tokens. Used for incremental prediction when use_past is True. Default None.
2599
-
2600
- Outputs:
2601
- Tuple, a tuple containing (`output`, `layer_present`)
2602
-
2603
- - **output** (Tensor) - The output logit of this layer. The shape is [batch, tgt_seq_length, hidden_size] or
2604
- [batch * tgt_seq_length, hidden_size]
2605
- - **layer_present** (Tuple) - A tuple with size of num_layers, where each tuple is the tensor of the
2606
- projected key and value vector in self attention with shape ((batch_size, num_heads, size_per_head,
2607
- tgt_seq_length), (batch_size, num_heads, tgt_seq_length, size_per_head), and of the projected key
2608
- and value vector in cross attention with shape (batch_size, num_heads, size_per_head, src_seq_length),
2609
- (batch_size, num_heads, src_seq_length, size_per_head)).
2610
-
2611
- Supported Platforms:
2612
- ``Ascend`` ``GPU``
2613
-
2614
- Examples:
2615
- >>> import numpy as np
2616
- >>> from mindspore import dtype as mstype
2617
- >>> from mindspore.nn.transformer import TransformerDecoder
2618
- >>> from mindspore import Tensor
2619
- >>> model = TransformerDecoder(batch_size=2, num_layers=1, hidden_size=64, ffn_hidden_size=64,
2620
- ... num_heads=2, src_seq_length=20, tgt_seq_length=10)
2621
- >>> encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
2622
- >>> decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
2623
- >>> decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
2624
- >>> memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
2625
- >>> output, past = model(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
2626
- >>> print(output.shape)
2627
- (2, 10, 64)
2628
- >>> print(len(past))
2629
- 1
2630
- >>> print(past[0][0].shape)
2631
- (2, 2, 32, 10)
2632
- >>> print(past[0][1].shape)
2633
- (2, 2, 10, 32)
2634
- >>> print(past[0][2].shape)
2635
- (2, 2, 32, 20)
2636
- >>> print(past[0][3].shape)
2637
- (2, 2, 20, 32)
2638
- """
2639
-
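Continuing the example above, each entry of `layer_present` is the 4-tuple of key/value caches for one decoder layer, first from self attention and then from cross attention (the unpacked names below are illustrative):

    # Uses model, decoder_input_value, decoder_input_mask, encoder_input_value and
    # memory_mask exactly as defined in the Examples above.
    output, past = model(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
    self_k, self_v, cross_k, cross_v = past[0]   # caches of layer 0
    # self_k:  (2, 2, 32, 10) -- keys over tgt_seq_length
    # cross_k: (2, 2, 32, 20) -- keys over src_seq_length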
2640
- @_LogActionOnce(logger=logger, key='TransformerDecoder',
2641
- no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
2642
- @_args_type_validator_check(batch_size=Validator.check_positive_int,
2643
- hidden_size=Validator.check_positive_int,
2644
- num_heads=Validator.check_positive_int,
2645
- ffn_hidden_size=Validator.check_positive_int,
2646
- src_seq_length=Validator.check_positive_int,
2647
- num_layers=Validator.check_positive_int,
2648
- tgt_seq_length=Validator.check_positive_int,
2649
- offset=Validator.check_non_negative_int,
2650
- attention_dropout_rate=Validator.check_non_negative_float,
2651
- hidden_dropout_rate=Validator.check_non_negative_float,
2652
- post_layernorm_residual=Validator.check_bool,
2653
- layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
2654
- "TransformerDecoder"),
2655
- softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
2656
- "TransformerDecoder"),
2657
- param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
2658
- "TransformerDecoder"),
2659
- parallel_config=_valid_type_checks([TransformerOpParallelConfig],
2660
- "TransformerDecoder"),
2661
- use_past=Validator.check_bool)
2662
- def __init__(self,
2663
- num_layers,
2664
- batch_size,
2665
- hidden_size,
2666
- ffn_hidden_size,
2667
- src_seq_length,
2668
- tgt_seq_length,
2669
- num_heads,
2670
- attention_dropout_rate=0.1,
2671
- hidden_dropout_rate=0.1,
2672
- post_layernorm_residual=False,
2673
- layernorm_compute_type=mstype.float32,
2674
- softmax_compute_type=mstype.float32,
2675
- param_init_type=mstype.float32,
2676
- hidden_act='gelu',
2677
- lambda_func=None,
2678
- use_past=False,
2679
- offset=0,
2680
- moe_config=default_moe_config,
2681
- parallel_config=default_transformer_config):
2682
- super(TransformerDecoder, self).__init__()
2683
- _check_moe_config(moe_config, parallel_config)
2684
- _check_config(parallel_config)
2685
- self.use_moe = moe_config.expert_num > 1
2686
- config_to_layer = parallel_config.moe_parallel_config if self.use_moe else parallel_config.dp_mp_config
2687
- if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
2688
- self.add = P.Add()
2689
- self.aux_loss = Tensor(0.0, mstype.float32)
2690
- self.num_layers = num_layers
2691
- self.blocks = nn.CellList()
2692
-
2693
- for i in range(num_layers):
2694
- block = TransformerDecoderLayer(hidden_size=hidden_size,
2695
- batch_size=batch_size,
2696
- ffn_hidden_size=ffn_hidden_size,
2697
- src_seq_length=src_seq_length,
2698
- tgt_seq_length=tgt_seq_length,
2699
- attention_dropout_rate=attention_dropout_rate,
2700
- hidden_dropout_rate=hidden_dropout_rate,
2701
- num_heads=num_heads,
2702
- layernorm_compute_type=layernorm_compute_type,
2703
- softmax_compute_type=softmax_compute_type,
2704
- hidden_act=hidden_act,
2705
- use_past=use_past,
2706
- param_init_type=param_init_type,
2707
- post_layernorm_residual=post_layernorm_residual,
2708
- moe_config=moe_config,
2709
- parallel_config=config_to_layer)
2710
- # If the user doesn't pass the fusion function, use the default one
2711
- if not lambda_func:
2712
- lambda_func = _get_lambda_func()
2713
-
2714
- lambda_func(block, layer_id=i, layers=num_layers,
2715
- offset=offset, parallel_config=parallel_config)
2716
-
2717
- self.blocks.append(block)
2718
- elif _get_parallel_mode() not in (ParallelMode.AUTO_PARALLEL,):
2719
- self.add = P.Add().shard(((), ()))
2720
- self.aux_loss = Tensor(0.0, mstype.float32)
2721
- logger.warning("For parallel mode, sharding propagation is recommended, you can use it by setting "
2722
- "'set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, "
2723
- "search_mode=\"sharding_propagation\")' and "
2724
- "'set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)'")
2725
- self.num_layers = num_layers
2726
- self.blocks = nn.CellList()
2727
- for i in range(num_layers):
2728
- block = TransformerDecoderLayer(hidden_size=hidden_size,
2729
- batch_size=batch_size,
2730
- ffn_hidden_size=ffn_hidden_size,
2731
- src_seq_length=src_seq_length,
2732
- tgt_seq_length=tgt_seq_length,
2733
- attention_dropout_rate=attention_dropout_rate,
2734
- hidden_dropout_rate=hidden_dropout_rate,
2735
- num_heads=num_heads,
2736
- layernorm_compute_type=layernorm_compute_type,
2737
- softmax_compute_type=softmax_compute_type,
2738
- hidden_act=hidden_act,
2739
- use_past=use_past,
2740
- param_init_type=param_init_type,
2741
- post_layernorm_residual=post_layernorm_residual,
2742
- moe_config=moe_config,
2743
- parallel_config=config_to_layer)
2744
- # If the user doesn't pass the fusion function, use the default one
2745
- if not lambda_func:
2746
- lambda_func = _get_lambda_func()
2747
-
2748
- lambda_func(block, layer_id=i, layers=num_layers,
2749
- offset=offset, parallel_config=parallel_config)
2750
-
2751
- self.blocks.append(block)
2752
- else:
2753
- raise RuntimeError(f"The {self.cls_name} only support sharding propagation or "
2754
- f"semi-auto parallel mode now.")
2755
-
2756
- def construct(self, hidden_states, attention_mask, encoder_output=None, memory_mask=None,
2757
- init_reset=True, batch_valid_length=None):
2758
- present_layer = ()
2759
- if self.use_moe:
2760
- accum_loss = self.aux_loss
2761
- for i in range(self.num_layers):
2762
- hidden_states, present, aux_loss = self.blocks[i](hidden_states,
2763
- attention_mask,
2764
- encoder_output,
2765
- memory_mask,
2766
- init_reset,
2767
- batch_valid_length)
2768
- present_layer = present_layer + (present,)
2769
- accum_loss = self.add(accum_loss, aux_loss)
2770
- return hidden_states, present_layer, accum_loss
2771
-
2772
- # Loop through each self-attention layer
2773
- for i in range(self.num_layers):
2774
- hidden_states, present = self.blocks[i](hidden_states,
2775
- attention_mask,
2776
- encoder_output,
2777
- memory_mask,
2778
- init_reset,
2779
- batch_valid_length)
2780
- present_layer = present_layer + (present,)
2781
-
2782
- return hidden_states, present_layer
2783
-
2784
-
2785
- class Transformer(Cell):
2786
- r"""
2787
- Transformer module including encoder and decoder. The difference from the original implementation is that this
2788
- module uses the residual addition before the layer normalization, and the default hidden act is `gelu`.
2789
- The details can be found in `Attention is all you need <https://arxiv.org/pdf/1706.03762v5.pdf>`_.
2790
-
2791
- .. warning::
2792
- This is an experimental API that is subject to change or deletion.
2793
-
2794
- Args:
2795
- hidden_size(int): The hidden size of the input.
2796
- batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
2797
- value. When doing training or prediction, the argument is not used and the user can simply pass None to
2798
- the argument.
2799
- ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
2800
- src_seq_length(int): The seq_length of the encoder's input tensor.
2801
- tgt_seq_length(int): The seq_length of the decoder's input tensor.
2802
- encoder_layers(int): The layers of the `TransformerEncoderLayer`. Default 3.
2803
- decoder_layers(int): The layers of the `TransformerDecoderLayer`. Default 3.
2804
- num_heads(int): The number of the heads. Default: 2.
2805
- attention_dropout_rate(float): The dropout rate of the attention scores. Default:0.1.
2806
- hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default:0.1.
2807
- hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
2808
- 'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
2809
- 'hsigmoid', 'logsigmoid' and so on. User can provide a custom activation to the argument.
2810
- If user wants to run the net in the parallel mode, the custom activation must also provide
2811
- the `activation_shard` function. Please see the examples of the
2812
- class:`mindspore.nn.transformer.FeedForward`. Default: gelu.
2813
- post_layernorm_residual(bool): Whether to apply the residual addition before the layernorm. Default False.
2814
- layernorm_compute_type(dtype.Number): The computation type of the layernorm.
2815
- Should be dtype.float32 or dtype.float16. Default dtype.float32.
2816
- softmax_compute_type(dtype.Number): The computation type of the softmax in the attention.
2817
- Should be dtype.float32 or dtype.float16. Default mstype.float32.
2818
- param_init_type(dtype.Number): The parameter initialization type of the module.
2819
- Should be dtype.float32 or dtype.float16. Default dtype.float32.
2820
- lambda_func: A function that determines the fusion index, pipeline stages and recompute attribute (see the sketch after this docstring). If the user
2821
- wants to determine the pipeline stage and gradient aggregation fusion, the user can pass a function
2822
- that accepts `network`, `layer_id`, `offset`, `parallel_config`, `layers`. The `network(Cell)`
2823
- represents the transformer block, `layer_id(int)` means the layer index for the current module, counts
2824
- from zero, `offset(int)` means the layer_index needs an offset, if there are other modules in the net.
2825
- The default setting for the pipeline is: `(layer_id + offset) // ((encoder_layers + decoder_layers)
2826
- / pipeline_stage)`. Default None.
2827
- use_past(bool): Use the past state to compute, used for incremental prediction. Default False.
2828
- moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig
2829
- with default values. Please see `MoEConfig`.
2830
- parallel_config(TransformerOpParallelConfig): The parallel configuration. Default `default_transformer_config`,
2831
- an instance of `TransformerOpParallelConfig` with default args.
2832
-
2833
- Inputs:
2834
- - **encoder_inputs** (Tensor) - The input tensor with shape [batch_size, seq_length, hidden_size] or
2835
- [batch_size * seq_length, hidden_size].
2836
- - **encoder_masks** (Tensor) - The attention mask for encoder with shape
2837
- [batch_size, seq_length, seq_length] or None. None means there will be no mask in softmax computation
2838
- in self attention of the encoder module.
2839
- - **decoder_inputs** (Tensor) - The input tensor of the decoder with shape [batch_size, tgt_seq_length,
2840
- hidden_size] or [batch_size * tgt_seq_length, hidden_size], this should be none if the decoder layer is 0.
2841
- - **decoder_masks** (Tensor) - The attention mask for decoder with shape
2842
- [batch_size, seq_length, seq_length] or None. None means there will be no mask in softmax computation
2843
- in self attention of the decoder module.
2844
- - **memory_mask** (Tensor) - The memory mask of the cross attention with shape [batch, tgt_seq_length,
2845
- src_seq_length],
2846
- where tgt_seq_length is the length of the decoder. This should be none if the decoder layer is 0
2847
- or the user wants no mask.
2848
- - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
2849
- past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
2850
- - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the index of the already calculated
2851
- tokens. Used for incremental prediction when use_past is True. Default None.
2852
-
2853
- Outputs:
2854
- Tuple, a tuple containing (`output`, `encoder_layer_present`, `decoder_layer_present`, `accum_loss`)
2855
-
2856
- - **output** (Tensor) - If there is only encoder, the output logit of the encoder layer. The shape is
2857
- [batch, src_seq_length, hidden_size] or [batch * src_seq_length, hidden_size], if there are encoder and
2858
- decoders, the output is from the decoder layer. The shape is [batch, tgt_seq_length, hidden_size] or
2859
- [batch * tgt_seq_length, hidden_size].
2860
- - **encoder_layer_present** (Tuple) - A tuple with size of num_layers, where each tuple is the tensor of the
2861
- projected key and value vector in self attention with shape ((batch_size, num_heads, size_per_head,
2862
- src_seq_length), (batch_size, num_heads, src_seq_length, size_per_head)).
2863
- - **decoder_layer_present** (Tuple) - A tuple with size of num_layers, where each tuple is the tensor
2864
- of the projected key and value vector in self attention with shape ((batch_size, num_heads, size_per_head,
2865
- tgt_seq_length), (batch_size, num_heads, tgt_seq_length, size_per_head)), and the
2866
- projected key and value vector in cross attention with shape
2867
- ((batch_size, num_heads, size_per_head, src_seq_length),
2868
- (batch_size, num_heads, src_seq_length, size_per_head)). If the decoder is not set, the
2869
- returned value will be None.
2870
- - **accum_loss** (Tensor) - A Tensor indicates an auxiliary loss to minimize the mean square of the data
2871
- part routed to each expert, and only returned if the number of experts is greater than 1.
2872
-
2873
- Supported Platforms:
2874
- ``Ascend`` ``GPU``
2875
-
2876
- Examples:
2877
- >>> import numpy as np
2878
- >>> from mindspore import dtype as mstype
2879
- >>> from mindspore.nn.transformer import Transformer
2880
- >>> from mindspore import Tensor
2881
- >>> model = Transformer(batch_size=2, encoder_layers=1, decoder_layers=2, hidden_size=64,
2882
- ... ffn_hidden_size=64, src_seq_length=20, tgt_seq_length=10)
2883
- >>> encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
2884
- >>> encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
2885
- >>> decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
2886
- >>> decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
2887
- >>> memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
2888
- >>> output, en_past, de_past = model(encoder_input_value, encoder_input_mask, decoder_input_value,
2889
- ... decoder_input_mask, memory_mask)
2890
- >>> print(output.shape)
2891
- (2, 10, 64)
2892
- >>> print(len(en_past))
2893
- 1
2894
- >>> print(len(de_past))
2895
- 2
2896
- >>> print(en_past[0][0].shape)
2897
- (2, 2, 32, 20)
2898
- >>> print(en_past[0][1].shape)
2899
- (2, 2, 20, 32)
2900
- >>> print(de_past[0][0].shape)
2901
- (2, 2, 32, 10)
2902
- >>> print(de_past[0][1].shape)
2903
- (2, 2, 10, 32)
2904
- >>> print(de_past[0][2].shape)
2905
- (2, 2, 32, 20)
2906
- >>> print(de_past[0][3].shape)
2907
- (2, 2, 20, 32)
2908
- """
2909
-
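The default `lambda_func` for this class is built with `total_layer = encoder_layers + decoder_layers`, and the decoder is constructed with `offset=encoder_layers`, so encoder and decoder blocks share one pipeline numbering. A small sketch of the resulting stage ids, reusing the illustrative `assign_stage_and_fusion` helper from the earlier sketch:

    encoder_layers, decoder_layers, pipeline_stage = 1, 2, 3
    total = encoder_layers + decoder_layers
    # Encoder blocks use offset=0, decoder blocks use offset=encoder_layers.
    enc_stages = [assign_stage_and_fusion(i, 0, total, pipeline_stage, 1)[0] for i in range(encoder_layers)]
    dec_stages = [assign_stage_and_fusion(i, encoder_layers, total, pipeline_stage, 1)[0]
                  for i in range(decoder_layers)]
    print(enc_stages, dec_stages)   # [0] [1, 2] -> encoder and decoder stages do not overlap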
2910
- @_LogActionOnce(logger=logger, key='Transformer',
2911
- no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
2912
- @_args_type_validator_check(batch_size=Validator.check_positive_int,
2913
- hidden_size=Validator.check_positive_int,
2914
- num_heads=Validator.check_positive_int,
2915
- ffn_hidden_size=Validator.check_positive_int,
2916
- src_seq_length=Validator.check_positive_int,
2917
- encoder_layers=Validator.check_positive_int,
2918
- decoder_layers=Validator.check_non_negative_int,
2919
- tgt_seq_length=Validator.check_positive_int,
2920
- attention_dropout_rate=Validator.check_non_negative_float,
2921
- hidden_dropout_rate=Validator.check_non_negative_float,
2922
- post_layernorm_residual=Validator.check_bool,
2923
- layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
2924
- "Transformer"),
2925
- softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
2926
- "Transformer"),
2927
- param_init_type=_valid_value_checks([mstype.float32, mstype.float16], "Transformer"),
2928
- parallel_config=_valid_type_checks([TransformerOpParallelConfig], "Transformer"),
2929
- use_past=Validator.check_bool)
2930
- def __init__(self,
2931
- hidden_size,
2932
- batch_size,
2933
- ffn_hidden_size,
2934
- src_seq_length,
2935
- tgt_seq_length,
2936
- encoder_layers=3,
2937
- decoder_layers=3,
2938
- num_heads=2,
2939
- attention_dropout_rate=0.1,
2940
- hidden_dropout_rate=0.1,
2941
- hidden_act='gelu',
2942
- post_layernorm_residual=False,
2943
- layernorm_compute_type=mstype.float32,
2944
- softmax_compute_type=mstype.float32,
2945
- param_init_type=mstype.float32,
2946
- lambda_func=None,
2947
- use_past=False,
2948
- moe_config=default_moe_config,
2949
- parallel_config=default_transformer_config):
2950
- super(Transformer, self).__init__()
2951
- if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
2952
- _check_config(parallel_config)
2953
- self.batch_size = batch_size
2954
- self.hidden_size = hidden_size
2955
- self.src_seq_length = src_seq_length
2956
- self.tgt_seq_length = tgt_seq_length
2957
- self.use_past = use_past
2958
- if encoder_layers <= 0 < decoder_layers:
2959
- raise ValueError(f"Transformer doest support encoder layer {encoder_layers} and decoder"
2960
- f"layer {decoder_layers}, please use TransformerDecoder")
2961
- if encoder_layers > 0 and decoder_layers > 0 and use_past:
2962
- raise ValueError(f"The {self.cls_name} with encoder and decoder does not support use_past=True.")
2963
- # The shard setting of Transformer is set within the TransformerEncoderLayer
2964
- if not lambda_func:
2965
- lambda_func = _get_lambda_func(total_layer=encoder_layers + decoder_layers)
2966
- _check_moe_config(moe_config, parallel_config)
2967
- self.use_moe = moe_config.expert_num > 1
2968
- self.add = P.Add()
2969
- self.aux_loss = Tensor(0.0, mstype.float32)
2970
- if encoder_layers > 0:
2971
- self.encoder = TransformerEncoder(num_layers=encoder_layers,
2972
- batch_size=batch_size,
2973
- hidden_size=hidden_size,
2974
- ffn_hidden_size=ffn_hidden_size,
2975
- num_heads=num_heads,
2976
- seq_length=src_seq_length,
2977
- attention_dropout_rate=attention_dropout_rate,
2978
- hidden_dropout_rate=hidden_dropout_rate,
2979
- hidden_act=hidden_act,
2980
- layernorm_compute_type=layernorm_compute_type,
2981
- softmax_compute_type=softmax_compute_type,
2982
- post_layernorm_residual=post_layernorm_residual,
2983
- param_init_type=param_init_type,
2984
- lambda_func=lambda_func,
2985
- use_past=use_past,
2986
- moe_config=moe_config,
2987
- parallel_config=parallel_config)
2988
- else:
2989
- self.encoder = None
2990
-
2991
- # Offset is needed as the encoder has consumed some flags.
2992
- # so the decoder need to increase the flags based on the encoder layer
2993
- self.decoder = None
2994
- if decoder_layers > 0:
2995
- self.decoder = TransformerDecoder(num_layers=decoder_layers,
2996
- batch_size=batch_size,
2997
- hidden_size=hidden_size,
2998
- ffn_hidden_size=ffn_hidden_size,
2999
- num_heads=num_heads,
3000
- src_seq_length=src_seq_length,
3001
- tgt_seq_length=tgt_seq_length,
3002
- attention_dropout_rate=attention_dropout_rate,
3003
- hidden_dropout_rate=hidden_dropout_rate,
3004
- hidden_act=hidden_act,
3005
- post_layernorm_residual=post_layernorm_residual,
3006
- layernorm_compute_type=layernorm_compute_type,
3007
- softmax_compute_type=softmax_compute_type,
3008
- lambda_func=lambda_func,
3009
- use_past=use_past,
3010
- param_init_type=param_init_type,
3011
- offset=encoder_layers,
3012
- moe_config=moe_config,
3013
- parallel_config=parallel_config)
3014
- elif _get_parallel_mode() not in (ParallelMode.AUTO_PARALLEL,):
3015
- _check_config(parallel_config)
3016
- self.batch_size = batch_size
3017
- self.hidden_size = hidden_size
3018
- self.src_seq_length = src_seq_length
3019
- self.tgt_seq_length = tgt_seq_length
3020
- self.use_past = use_past
3021
- if encoder_layers <= 0 < decoder_layers:
3022
- raise ValueError(f"Transformer doest support encoder layer {encoder_layers} and decoder"
3023
- f"layer {decoder_layers}, please use TransformerDecoder")
3024
- if encoder_layers > 0 and decoder_layers > 0 and use_past:
3025
- raise ValueError(f"The {self.cls_name} with encoder and decoder does not support use_past=True.")
3026
- logger.warning("For parallel mode, sharding propagation is recommended, you can use it by setting "
3027
- "'set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, "
3028
- "search_mode=\"sharding_propagation\")' and "
3029
- "'set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)'")
3030
- # The shard setting of Transformer is set within the TransformerEncoderLayer
3031
- if not lambda_func:
3032
- lambda_func = _get_lambda_func(total_layer=encoder_layers + decoder_layers)
3033
- _check_moe_config(moe_config, parallel_config)
3034
- self.use_moe = moe_config.expert_num > 1
3035
- self.add = P.Add().shard(((), ()))
3036
- self.aux_loss = Tensor(0.0, mstype.float32)
3037
- if encoder_layers > 0:
3038
- self.encoder = TransformerEncoder(num_layers=encoder_layers,
3039
- batch_size=batch_size,
3040
- hidden_size=hidden_size,
3041
- ffn_hidden_size=ffn_hidden_size,
3042
- num_heads=num_heads,
3043
- seq_length=src_seq_length,
3044
- attention_dropout_rate=attention_dropout_rate,
3045
- hidden_dropout_rate=hidden_dropout_rate,
3046
- hidden_act=hidden_act,
3047
- layernorm_compute_type=layernorm_compute_type,
3048
- softmax_compute_type=softmax_compute_type,
3049
- post_layernorm_residual=post_layernorm_residual,
3050
- param_init_type=param_init_type,
3051
- lambda_func=lambda_func,
3052
- use_past=use_past,
3053
- moe_config=moe_config,
3054
- parallel_config=parallel_config)
3055
- else:
3056
- self.encoder = None
3057
-
3058
- # Offset is needed as the encoder has consumed some flags.
3059
- # so the decoder need to increase the flags based on the encoder layer
3060
- self.decoder = None
3061
- if decoder_layers > 0:
3062
- self.decoder = TransformerDecoder(num_layers=decoder_layers,
3063
- batch_size=batch_size,
3064
- hidden_size=hidden_size,
3065
- ffn_hidden_size=ffn_hidden_size,
3066
- num_heads=num_heads,
3067
- src_seq_length=src_seq_length,
3068
- tgt_seq_length=tgt_seq_length,
3069
- attention_dropout_rate=attention_dropout_rate,
3070
- hidden_dropout_rate=hidden_dropout_rate,
3071
- hidden_act=hidden_act,
3072
- post_layernorm_residual=post_layernorm_residual,
3073
- layernorm_compute_type=layernorm_compute_type,
3074
- softmax_compute_type=softmax_compute_type,
3075
- lambda_func=lambda_func,
3076
- use_past=use_past,
3077
- param_init_type=param_init_type,
3078
- offset=encoder_layers,
3079
- moe_config=moe_config,
3080
- parallel_config=parallel_config)
3081
- else:
3082
- raise RuntimeError(f"The {self.cls_name} only support sharding propagation or "
3083
- f"semi-auto parallel mode now.")
3084
-
3085
- def construct(self, encoder_inputs,
3086
- encoder_masks,
3087
- decoder_inputs=None,
3088
- decoder_masks=None,
3089
- memory_mask=None,
3090
- init_reset=True,
3091
- batch_valid_length=None):
3092
-
3093
- encoder_output = None
3094
- output = None
3095
- encoder_layer_present = None
3096
- decoder_layer_present = None
3097
- accum_loss = self.aux_loss
3098
- if self.encoder is not None:
3099
- if self.use_moe:
3100
- encoder_output, encoder_layer_present, encoder_aux_loss = self.encoder(encoder_inputs, encoder_masks,
3101
- init_reset, batch_valid_length)
3102
- accum_loss = self.add(accum_loss, encoder_aux_loss)
3103
- else:
3104
- encoder_output, encoder_layer_present = self.encoder(encoder_inputs, encoder_masks, init_reset,
3105
- batch_valid_length)
3106
- output = encoder_output
3107
-
3108
- if self.decoder is not None:
3109
- # decoder mask should be created outside of the model
3110
- if self.use_moe:
3111
- decoder_output, decoder_layer_present, decoder_aux_loss = self.decoder(decoder_inputs, decoder_masks,
3112
- encoder_output, memory_mask,
3113
- init_reset, batch_valid_length)
3114
- accum_loss = self.add(accum_loss, decoder_aux_loss)
3115
- else:
3116
- decoder_output, decoder_layer_present = self.decoder(decoder_inputs,
3117
- decoder_masks,
3118
- encoder_output,
3119
- memory_mask, init_reset,
3120
- batch_valid_length)
3121
- output = decoder_output
3122
- if self.use_moe:
3123
- return output, encoder_layer_present, decoder_layer_present, accum_loss
3124
- return output, encoder_layer_present, decoder_layer_present
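For readers tracking this removal: the deleted Transformer cell can still be exercised on releases that ship it (2.7.0rc1 and earlier). The sketch below is reconstructed from the constructor signature and the docstring shapes retained in the deleted lines above; the import path, the ffn_hidden_size value, and the encoder-side input shapes are assumptions rather than facts recorded in this diff.

# Minimal usage sketch of the removed Transformer cell; only runs on MindSpore
# versions that still include it. The module path below is an assumption.
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.parallel._transformer import Transformer  # assumed location of the deleted class

model = Transformer(batch_size=2, encoder_layers=1, decoder_layers=2,
                    hidden_size=64, ffn_hidden_size=64,  # ffn size assumed
                    src_seq_length=20, tgt_seq_length=10)
# Hidden-state inputs are (batch, seq_length, hidden_size); attention masks are
# (batch, query_len, key_len). Encoder-side shapes are inferred from the key/value
# shapes in the removed docstring, so they are assumptions.
encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
output, en_past, de_past = model(encoder_input_value, encoder_input_mask,
                                 decoder_input_value, decoder_input_mask, memory_mask)
print(output.shape)  # (2, 10, 64) per the removed docstring

With the default moe_config (a single expert), the cell returns three values; en_past and de_past hold one tuple of key/value tensors per encoder and decoder layer, matching the shapes printed in the removed docstring example.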