mindspore 2.1.0__cp39-none-any.whl → 2.2.10__cp39-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (569)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +46 -19
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  25. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  26. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +38 -0
  31. mindspore/_c_dataengine.cpython-39-aarch64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-39-aarch64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-39-aarch64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +12 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +61 -71
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +74 -104
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-39-aarch64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +13 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +28 -5
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  196. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  197. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  198. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  199. mindspore/nn/__init__.py +0 -2
  200. mindspore/nn/cell.py +313 -74
  201. mindspore/nn/dynamic_lr.py +21 -21
  202. mindspore/nn/layer/activation.py +22 -30
  203. mindspore/nn/layer/basic.py +15 -13
  204. mindspore/nn/layer/channel_shuffle.py +1 -1
  205. mindspore/nn/layer/container.py +271 -9
  206. mindspore/nn/layer/conv.py +323 -204
  207. mindspore/nn/layer/dense.py +8 -5
  208. mindspore/nn/layer/embedding.py +33 -27
  209. mindspore/nn/layer/flash_attention.py +141 -88
  210. mindspore/nn/layer/image.py +8 -6
  211. mindspore/nn/layer/math.py +16 -25
  212. mindspore/nn/layer/normalization.py +107 -66
  213. mindspore/nn/layer/padding.py +1 -1
  214. mindspore/nn/layer/pooling.py +131 -109
  215. mindspore/nn/layer/rnn_cells.py +27 -22
  216. mindspore/nn/layer/rnns.py +13 -16
  217. mindspore/nn/layer/thor_layer.py +1 -1
  218. mindspore/nn/layer/transformer.py +221 -154
  219. mindspore/nn/learning_rate_schedule.py +9 -1
  220. mindspore/nn/loss/loss.py +235 -174
  221. mindspore/nn/optim/ada_grad.py +2 -1
  222. mindspore/nn/optim/adadelta.py +1 -0
  223. mindspore/nn/optim/adafactor.py +2 -1
  224. mindspore/nn/optim/adam.py +7 -4
  225. mindspore/nn/optim/adamax.py +3 -2
  226. mindspore/nn/optim/adasum.py +2 -2
  227. mindspore/nn/optim/asgd.py +2 -3
  228. mindspore/nn/optim/ftrl.py +6 -5
  229. mindspore/nn/optim/lamb.py +7 -4
  230. mindspore/nn/optim/lars.py +1 -1
  231. mindspore/nn/optim/lazyadam.py +5 -3
  232. mindspore/nn/optim/momentum.py +2 -1
  233. mindspore/nn/optim/optimizer.py +53 -4
  234. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  235. mindspore/nn/optim/rmsprop.py +4 -3
  236. mindspore/nn/optim/rprop.py +23 -12
  237. mindspore/nn/optim/sgd.py +26 -11
  238. mindspore/nn/optim/thor.py +9 -7
  239. mindspore/nn/probability/bijector/bijector.py +5 -5
  240. mindspore/nn/probability/bijector/power_transform.py +27 -27
  241. mindspore/nn/probability/bijector/softplus.py +3 -3
  242. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  243. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  244. mindspore/nn/probability/distribution/beta.py +3 -3
  245. mindspore/nn/probability/distribution/categorical.py +7 -7
  246. mindspore/nn/probability/distribution/cauchy.py +0 -1
  247. mindspore/nn/probability/distribution/distribution.py +3 -3
  248. mindspore/nn/probability/distribution/gamma.py +3 -3
  249. mindspore/nn/probability/distribution/geometric.py +4 -4
  250. mindspore/nn/probability/distribution/gumbel.py +4 -4
  251. mindspore/nn/probability/distribution/log_normal.py +2 -2
  252. mindspore/nn/probability/distribution/logistic.py +2 -2
  253. mindspore/nn/probability/distribution/poisson.py +4 -4
  254. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  255. mindspore/nn/probability/distribution/uniform.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +84 -34
  257. mindspore/nn/wrap/grad_reducer.py +8 -5
  258. mindspore/nn/wrap/loss_scale.py +105 -42
  259. mindspore/numpy/array_creations.py +1 -2
  260. mindspore/numpy/array_ops.py +3 -2
  261. mindspore/numpy/utils_const.py +5 -5
  262. mindspore/offline_debug/convert_async.py +2 -2
  263. mindspore/ops/_grad_experimental/__init__.py +0 -5
  264. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  265. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  266. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  267. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  268. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  269. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  270. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  271. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  272. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  273. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  274. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  275. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  276. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  277. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  278. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  279. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  280. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  281. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  282. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  283. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  284. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  285. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  286. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  287. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  288. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  289. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  290. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  291. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  292. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  293. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  294. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  295. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  296. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  297. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  298. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  299. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  300. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  301. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  302. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  303. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  304. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  305. mindspore/ops/_primitive_cache.py +1 -1
  306. mindspore/ops/_tracefunc.py +45 -13
  307. mindspore/ops/_utils/utils.py +6 -1
  308. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  309. mindspore/ops/_vmap/vmap_base.py +3 -3
  310. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  311. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  312. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  313. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  314. mindspore/ops/arg_dtype_cast.py +54 -0
  315. mindspore/ops/composite/base.py +37 -10
  316. mindspore/ops/composite/math_ops.py +5 -4
  317. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  318. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  319. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  320. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  321. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  324. mindspore/ops/deprecated.py +304 -0
  325. mindspore/ops/function/__init__.py +4 -1
  326. mindspore/ops/function/array_func.py +174 -193
  327. mindspore/ops/function/clip_func.py +81 -13
  328. mindspore/ops/function/debug_func.py +1 -1
  329. mindspore/ops/function/grad/grad_func.py +18 -9
  330. mindspore/ops/function/image_func.py +10 -4
  331. mindspore/ops/function/linalg_func.py +5 -5
  332. mindspore/ops/function/math_func.py +575 -386
  333. mindspore/ops/function/nn_func.py +568 -260
  334. mindspore/ops/function/random_func.py +88 -57
  335. mindspore/ops/function/sparse_func.py +1 -1
  336. mindspore/ops/function/sparse_unary_func.py +14 -12
  337. mindspore/ops/function/vmap_func.py +6 -5
  338. mindspore/ops/functional.py +15 -10
  339. mindspore/ops/op_info_register.py +244 -25
  340. mindspore/ops/operations/__init__.py +28 -19
  341. mindspore/ops/operations/_grad_ops.py +72 -7
  342. mindspore/ops/operations/_inner_ops.py +350 -17
  343. mindspore/ops/operations/_quant_ops.py +4 -8
  344. mindspore/ops/operations/_sequence_ops.py +42 -0
  345. mindspore/ops/operations/array_ops.py +68 -282
  346. mindspore/ops/operations/comm_ops.py +107 -59
  347. mindspore/ops/operations/custom_ops.py +94 -70
  348. mindspore/ops/operations/debug_ops.py +8 -4
  349. mindspore/ops/operations/image_ops.py +18 -12
  350. mindspore/ops/operations/inner_ops.py +26 -3
  351. mindspore/ops/operations/math_ops.py +189 -141
  352. mindspore/ops/operations/nn_ops.py +794 -489
  353. mindspore/ops/operations/other_ops.py +0 -22
  354. mindspore/ops/operations/random_ops.py +53 -111
  355. mindspore/ops/operations/sparse_ops.py +3 -1
  356. mindspore/ops/primitive.py +24 -18
  357. mindspore/parallel/_auto_parallel_context.py +68 -8
  358. mindspore/parallel/_cost_model_context.py +2 -2
  359. mindspore/parallel/_offload_context.py +17 -3
  360. mindspore/parallel/_parallel_serialization.py +12 -5
  361. mindspore/parallel/_ps_context.py +12 -0
  362. mindspore/parallel/_tensor.py +18 -13
  363. mindspore/parallel/_transformer/layers.py +5 -3
  364. mindspore/parallel/_transformer/loss.py +1 -0
  365. mindspore/parallel/_transformer/moe.py +2 -2
  366. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  367. mindspore/parallel/_transformer/transformer.py +23 -3
  368. mindspore/parallel/_utils.py +11 -7
  369. mindspore/parallel/algo_parameter_config.py +85 -5
  370. mindspore/parallel/checkpoint_transform.py +19 -12
  371. mindspore/parallel/shard.py +21 -14
  372. mindspore/profiler/common/struct_type.py +3 -3
  373. mindspore/profiler/common/util.py +4 -2
  374. mindspore/profiler/envprofiling.py +1 -1
  375. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  376. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  377. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  378. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  379. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  380. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  381. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  382. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  383. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  384. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  385. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  386. mindspore/profiler/parser/flops_parser.py +15 -11
  387. mindspore/profiler/parser/framework_parser.py +38 -22
  388. mindspore/profiler/parser/hccl_parser.py +16 -12
  389. mindspore/profiler/parser/integrator.py +22 -11
  390. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  391. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  392. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  393. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  394. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  395. mindspore/profiler/parser/optime_parser.py +1 -1
  396. mindspore/profiler/parser/profiler_info.py +21 -2
  397. mindspore/profiler/parser/step_trace_parser.py +11 -14
  398. mindspore/profiler/profiling.py +179 -89
  399. mindspore/rewrite/api/node.py +102 -19
  400. mindspore/rewrite/api/node_type.py +5 -1
  401. mindspore/rewrite/api/pattern_engine.py +1 -1
  402. mindspore/rewrite/api/scoped_value.py +9 -17
  403. mindspore/rewrite/api/symbol_tree.py +131 -47
  404. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  405. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  406. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  407. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  408. mindspore/rewrite/common/rewrite_elog.py +5 -1
  409. mindspore/rewrite/namer.py +33 -24
  410. mindspore/rewrite/namespace.py +14 -5
  411. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  412. mindspore/rewrite/node/call_function.py +79 -0
  413. mindspore/rewrite/node/cell_container.py +135 -0
  414. mindspore/rewrite/node/control_flow.py +88 -0
  415. mindspore/rewrite/{node.py → node/node.py} +273 -234
  416. mindspore/rewrite/node/node_manager.py +254 -0
  417. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  418. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  419. mindspore/rewrite/parsers/assign_parser.py +216 -221
  420. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  421. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  422. mindspore/rewrite/parsers/constant_parser.py +9 -6
  423. mindspore/rewrite/parsers/container_parser.py +9 -7
  424. mindspore/rewrite/parsers/for_parser.py +36 -15
  425. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  426. mindspore/rewrite/parsers/if_parser.py +28 -24
  427. mindspore/rewrite/parsers/module_parser.py +196 -25
  428. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  429. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  430. mindspore/rewrite/parsers/return_parser.py +6 -6
  431. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  432. mindspore/rewrite/sparsify/utils.py +1 -1
  433. mindspore/rewrite/symbol_tree.py +523 -578
  434. mindspore/rewrite/symbol_tree_builder.py +9 -193
  435. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  436. mindspore/run_check/_check_version.py +6 -4
  437. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  438. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  439. mindspore/scipy/linalg.py +1 -1
  440. mindspore/scipy/optimize/minimize.py +7 -3
  441. mindspore/train/_utils.py +7 -3
  442. mindspore/train/amp.py +323 -123
  443. mindspore/train/anf_ir_pb2.py +14 -2
  444. mindspore/train/callback/_backup_and_restore.py +2 -12
  445. mindspore/train/callback/_callback.py +29 -4
  446. mindspore/train/callback/_checkpoint.py +23 -8
  447. mindspore/train/callback/_early_stop.py +2 -2
  448. mindspore/train/callback/_landscape.py +4 -4
  449. mindspore/train/callback/_loss_monitor.py +2 -2
  450. mindspore/train/callback/_on_request_exit.py +2 -2
  451. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  452. mindspore/train/callback/_summary_collector.py +15 -8
  453. mindspore/train/callback/_time_monitor.py +58 -5
  454. mindspore/train/data_sink.py +5 -11
  455. mindspore/train/dataset_helper.py +84 -57
  456. mindspore/train/loss_scale_manager.py +2 -2
  457. mindspore/train/metrics/__init__.py +3 -3
  458. mindspore/train/metrics/cosine_similarity.py +1 -1
  459. mindspore/train/metrics/hausdorff_distance.py +3 -2
  460. mindspore/train/metrics/mean_surface_distance.py +3 -2
  461. mindspore/train/metrics/metric.py +39 -19
  462. mindspore/train/metrics/roc.py +2 -2
  463. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  464. mindspore/train/mind_ir_pb2.py +85 -36
  465. mindspore/train/model.py +187 -47
  466. mindspore/train/serialization.py +487 -161
  467. mindspore/train/summary/_summary_adapter.py +1 -1
  468. mindspore/train/summary/_writer_pool.py +3 -2
  469. mindspore/train/summary/summary_record.py +37 -17
  470. mindspore/train/train_thor/convert_utils.py +3 -3
  471. mindspore/train/train_thor/dataset_helper.py +1 -1
  472. mindspore/version.py +1 -1
  473. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
  474. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +477 -517
  475. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
  476. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  477. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  478. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  479. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  480. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  481. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  482. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  483. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  484. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  485. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  486. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  487. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  488. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  489. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  490. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  491. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  492. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  493. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  494. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  495. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  496. mindspore/_extends/graph_kernel/expander.py +0 -80
  497. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  498. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  499. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  500. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  501. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  502. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  503. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  504. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  505. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  506. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  507. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  508. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  509. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  510. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  511. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  512. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  513. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  514. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  515. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  516. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  517. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  518. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  519. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  520. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  521. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  522. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  523. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  524. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  525. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  526. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  527. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  528. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  529. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  530. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  531. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  532. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  533. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  534. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  535. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  536. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  537. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  538. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  539. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  540. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  541. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  542. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  543. mindspore/dataset/datapreprocess/__init__.py +0 -20
  544. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  545. mindspore/include/api/net.h +0 -142
  546. mindspore/nn/lr_scheduler.py +0 -262
  547. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  548. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  549. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  550. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  551. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  552. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  553. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  554. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  555. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  556. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  559. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  560. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  561. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  567. mindspore/rewrite/node_visitor.py +0 -44
  568. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  569. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
@@ -16,11 +16,11 @@
 Transformer Cells module, include TransformerEncoderLayer, TransformerDecoderLayer,
 TransformerEncoder, TransformerDecoder, Transformer.
 """
-import copy
 import math
 from typing import Union, Optional
 import mindspore
 import mindspore.ops as ops
+import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer, XavierNormal, XavierUniform, \
@@ -36,30 +36,20 @@ __all__ = ['MultiheadAttention', 'TransformerEncoderLayer', 'TransformerDecoderL
            'TransformerEncoder', 'TransformerDecoder', 'Transformer']
 
 
-class _Linear(Dense):
-    def __init__(self, in_channels, out_channels, has_bias=True):
-        fan_in, _ = _calculate_fan_in_and_fan_out((out_channels, in_channels))
-        bound = 1 / math.sqrt(fan_in)
-        super().__init__(in_channels, out_channels, weight_init=HeUniform(math.sqrt(5)),
-                         bias_init=Uniform(bound), has_bias=has_bias, activation=None)
-
-
 class MultiheadAttention(Cell):
     r"""
     This is an implementation of multihead attention in the paper `Attention is all you need
-    <https://arxiv.org/pdf/1706.03762v5.pdf>`_. Given the query vector with source length, and the
-    key and value vector with target length, the attention will be performed as the following
+    <https://arxiv.org/pdf/1706.03762v5.pdf>`_. Given the query vector, the key vector and value vector,
+    the attention will be performed as the following:
 
     .. math::
-        MultiHeadAttention(query, key, vector) = Concat(head_1, \dots, head_h)W^O
+        MultiHeadAttention(query, key, value) = Concat(head_1, \dots, head_h)W^O
 
-    where :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)`. The default is with a bias.
+    where :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)`, and :math:`W^O` , :math:`W_i^Q` , :math:`W_i^K` ,
+    :math:`W_i^V` are weight matrices. The default input / output projection layers is with a bias.
 
     if query, key and value tensor is same, then it will be self attention.
 
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
     Args:
         embed_dim (int): Total dimension of MultiheadAttention.
         num_heads (int): Number of attention heads. Note that `embed_dim` will be split
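The `_Linear` helper deleted above is only inlined, not dropped: each former call site now builds a plain Dense with the same initialization (HeUniform(sqrt(5)) weights, Uniform bias bounded by 1/sqrt(fan_in)), plus the new `dtype` keyword. A minimal sketch of the inlined pattern, assuming a MindSpore 2.2 install where `Dense` accepts `dtype`; the `make_proj` name is hypothetical, and `_calculate_fan_in_and_fan_out` is the same private helper this file already imports:

    import math
    import mindspore.common.dtype as mstype
    from mindspore.nn import Dense
    from mindspore.common.initializer import HeUniform, Uniform, _calculate_fan_in_and_fan_out

    def make_proj(in_channels, out_channels, has_bias=True, dtype=mstype.float32):
        # Same recipe the deleted _Linear passed to its Dense superclass:
        # HeUniform(sqrt(5)) weights and a Uniform bias bound of 1/sqrt(fan_in).
        fan_in, _ = _calculate_fan_in_and_fan_out((out_channels, in_channels))
        bound = 1 / math.sqrt(fan_in)
        return Dense(in_channels, out_channels, weight_init=HeUniform(math.sqrt(5)),
                     bias_init=Uniform(bound), has_bias=has_bias, dtype=dtype)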
@@ -73,36 +63,37 @@ class MultiheadAttention(Cell):
         vdim (int): Total number of features for values. Default: ``None`` (`vdim=embed_dim`).
         batch_first (bool): If ``True``, then the input and output shape are :math:`(batch, seq, feature)` ,
             else :math:`(seq, batch, feature)` . Default: ``False``.
+        dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .
 
     Inputs:
         - **query** (Tensor): The query embeddings. If `query` is unbatched, the shape is :math:`(L, E_q)`,
           otherwise the shape is :math:`(L, N, E_q)` when `batch_first=False` or :math:`(N, L, E_q)` when
-          `batch_first=True`, where :math:`L` is the target sequence length, :math:`N` is the batch size,
-          and :math:`E_q` is the query embedding dimension `embed_dim`. Queries are compared against
-          key-value pairs to produce the output. See "Attention Is All You Need" for more details.
+          `batch_first=True` , where :math:`L` is the target sequence length, :math:`N` is the batch size,
+          and :math:`E_q` is the query embedding dimension `embed_dim`. Supported types: float16, float32,
+          float64. Queries are compared against key-value pairs to produce the output.
         - **key** (Tensor): The key embeddings. If `key` is unbatched, the shape is :math:`(S, E_k)`, otherwise
           the shape is :math:`(S, N, E_k)` when `batch_first=False` or :math:`(N, S, E_k)` when
-          `batch_first=True`, where :math:`S` is the source sequence length, :math:`N` is the batch size,
-          and :math:`E_k` is the key embedding dimension `kdim`. See "Attention Is All You Need" for more details.
+          `batch_first=True` , where :math:`S` is the source sequence length, :math:`N` is the batch size,
+          and :math:`E_k` is the key embedding dimension `kdim`. Supported types: float16, float32, float64.
         - **value** (Tensor): The value embeddings. If `value` is unbatched, the shape is :math:`(S, E_v)`,
           otherwise the shape is :math:`(S, N, E_v)` when `batch_first=False` or :math:`(N, S, E_v)` when
-          `batch_first=True`, where :math:`S` is the source sequence length, :math:`N` is the batch size,
-          and :math:`E_v` is the value embedding dimension `vdim`. See "Attention Is All You Need" for more details.
+          `batch_first=True` , where :math:`S` is the source sequence length, :math:`N` is the batch size,
+          and :math:`E_v` is the value embedding dimension `vdim`. Supported types: float16, float32, float64.
         - **key_padding_mask** (Tensor, optional): If specified, a mask of shape :math:`(N, S)` indicating which
           elements within `key` to ignore for the purpose of attention (i.e. treat as "padding").
-          For unbatched `query`, shape should be :math:`(S)`. Binary and byte masks are supported.
+          For unbatched `query`, shape should be :math:`(S)`. Binary and float masks are supported.
           For a binary mask, a ``True`` value indicates that the corresponding `key` value will be ignored for
           the purpose of attention. For a float mask, it will be directly added to the corresponding `key` value.
+          Supported float types: float16, float32, float64. Default: ``None``.
         - **need_weights** (bool): Whether returns `attn_output_weights` in addition to `attn_outputs`.
           Default: ``True``.
         - **attn_mask** (Tensor, optional): If specified, a 2D or 3D mask preventing attention to certain positions.
           Must be of shape :math:`(L, S)` or :math:`(N\cdot\text{num_heads}, L, S)`, where :math:`N` is the
           batch size, :math:`L` is the target sequence length, and :math:`S` is the source sequence length.
           A 2D mask will be broadcasted across the batch while a 3D mask allows for a different mask for each entry
-          in the batch. Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates
-          that the corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that
-          the corresponding position is not allowed to attend. For a float mask, the mask values will be added to
-          the attention weight.
+          in the batch. For a binary mask, a ``True`` value indicates that the corresponding position is not allowed
+          to attend. For a float mask, the mask values will be added to the attention weight.
+          Supported float types: float16, float32, float64. Default: ``None``.
         - **average_attn_weights** (bool): If true, indicates that the returned `attn_weights` should be averaged
           across heads. Otherwise, `attn_weights` are provided separately per head. Note that this flag only
           has an effect when `need_weights=True`. Default: ``True`` (i.e. average weights across heads)
@@ -112,7 +103,7 @@ class MultiheadAttention(Cell):
 
         - **attn_output** - Attention outputs. If input is unbatched, the output shape is :math:`(L, E)`, otherwise
           the output shape is :math:`(L, N, E)` when `batch_first=False` or :math:`(N, L, E)` when
-          `batch_first=True`, where :math:`L` is the target sequence length, :math:`N` is the batch size,
+          `batch_first=True` , where :math:`L` is the target sequence length, :math:`N` is the batch size,
           and :math:`E` is the embedding dimension `embed_dim`.
         - **attn_output_weights** - Only returned when `need_weights=True`. If `average_attn_weights=True`,
           returns attention weights averaged across heads with shape :math:`(L, S)` when input is unbatched or
@@ -143,8 +134,8 @@ class MultiheadAttention(Cell):
         (10, 8, 128)
     """
 
-    def __init__(self, embed_dim, num_heads, dropout=0., has_bias=True, add_bias_kv=False,
-                 add_zero_attn=False, kdim=None, vdim=None, batch_first=False):
+    def __init__(self, embed_dim, num_heads, dropout=0.0, has_bias=True, add_bias_kv=False,
+                 add_zero_attn=False, kdim=None, vdim=None, batch_first=False, dtype=mstype.float32):
         super().__init__()
         self.embed_dim = embed_dim
         self.kdim = kdim if kdim is not None else embed_dim
@@ -158,32 +149,39 @@ class MultiheadAttention(Cell):
         if self.head_dim * num_heads != self.embed_dim:
             raise ValueError("The init argument 'embed_dim' must be divisible by 'num_heads'.")
 
+        if dtype is None:
+            dtype = mindspore.float32
         if not self._qkv_same_embed_dim:
-            self.q_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, embed_dim)), 'q_proj_weight')
-            self.k_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, self.kdim)), 'k_proj_weight')
-            self.v_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, self.vdim)), 'v_proj_weight')
+            self.q_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, embed_dim), dtype), 'q_proj_weight')
+            self.k_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, self.kdim), dtype), 'k_proj_weight')
+            self.v_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, self.vdim), dtype), 'v_proj_weight')
             self.in_proj_weight = None
         else:
-            self.in_proj_weight = Parameter(initializer(XavierUniform(), (3 * embed_dim, embed_dim)), 'in_proj_weight')
+            self.in_proj_weight = Parameter(initializer(XavierUniform(), (3 * embed_dim, embed_dim), dtype),
+                                            'in_proj_weight')
             self.q_proj_weight = None
             self.k_proj_weight = None
             self.v_proj_weight = None
 
         if has_bias:
-            self.in_proj_bias = Parameter(initializer('zeros', (3 * embed_dim)), 'in_proj_bias')
+            self.in_proj_bias = Parameter(initializer('zeros', (3 * embed_dim), dtype), 'in_proj_bias')
         else:
             self.in_proj_bias = None
-        self.out_proj = _Linear(embed_dim, embed_dim, has_bias=has_bias)
+        fan_in, _ = _calculate_fan_in_and_fan_out((embed_dim, embed_dim))
+        bound = 1 / math.sqrt(fan_in)
+        self.out_proj = Dense(embed_dim, embed_dim, has_bias=has_bias, weight_init=HeUniform(math.sqrt(5)),
+                              bias_init=Uniform(bound), dtype=dtype)
 
         if add_bias_kv:
-            self.bias_k = Parameter(initializer(XavierNormal(), (1, 1, embed_dim)), 'bias_k')
-            self.bias_v = Parameter(initializer(XavierNormal(), (1, 1, embed_dim)), 'bias_v')
+            self.bias_k = Parameter(initializer(XavierNormal(), (1, 1, embed_dim), dtype), 'bias_k')
+            self.bias_v = Parameter(initializer(XavierNormal(), (1, 1, embed_dim), dtype), 'bias_v')
         else:
            self.bias_k = self.bias_v = None
 
         self.add_zero_attn = add_zero_attn
         self.k_is_v = False
         self.q_is_k = False
+        self.dtype = dtype
 
     def __call__(self, *args, **kwargs):
         query = kwargs.get('query', args[0])
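The hunk above threads the new `dtype` keyword through every Parameter initializer and the output projection. A short usage sketch of the 2.2.10 constructor, reusing the shapes from the class docstring (assumes MindSpore 2.2.10 on a supported backend):

    import numpy as np
    import mindspore as ms
    import mindspore.common.dtype as mstype

    # dtype now fixes the Parameter dtype of all projection weights at construction.
    attn = ms.nn.MultiheadAttention(embed_dim=128, num_heads=8, dtype=mstype.float16)
    query = ms.Tensor(np.random.rand(10, 8, 128), ms.float16)
    attn_output, attn_output_weights = attn(query, query, query)
    print(attn_output.shape)  # (10, 8, 128)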
@@ -224,7 +222,7 @@ class MultiheadAttention(Cell):
                 attn_mask=attn_mask, use_separate_proj_weight=True,
                 q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                 v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights,
-                k_is_v=self.k_is_v, q_is_k=self.q_is_k)
+                k_is_v=self.k_is_v, q_is_k=self.q_is_k, dtype=self.dtype)
         else:
             attn_output, attn_output_weights = multi_head_attention_forward(
                 query, key, value, self.embed_dim, self.num_heads,
@@ -234,7 +232,7 @@ class MultiheadAttention(Cell):
                 training=self.training,
                 key_padding_mask=key_padding_mask,
                 attn_mask=attn_mask, average_attn_weights=average_attn_weights,
-                k_is_v=self.k_is_v, q_is_k=self.q_is_k)
+                k_is_v=self.k_is_v, q_is_k=self.q_is_k, dtype=self.dtype)
 
         if self.batch_first and is_batched:
             attn_output = attn_output.swapaxes(1, 0)
@@ -248,38 +246,44 @@ class TransformerEncoderLayer(Cell):
     Transformer Encoder Layer. This is an implementation of the single layer of the transformer
     encoder layer, including multihead attention and feedward layer.
 
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
     Args:
         d_model (int): The number of features in the input tensor.
         nhead (int): The number of heads in the MultiheadAttention modules.
         dim_feedforward (int): The dimension of the feedforward layer. Default: ``2048``.
         dropout (float): The dropout value. Default: ``0.1``.
         activation (Union[str, callable, Cell]): The activation function of the intermediate layer,
-            can be a string (``"relu"`` or ``"gelu"``), Cell instance (``nn.ReLU()`` or ``nn.GELU()``) or
-            a callable (``ops.relu`` or ``ops.gelu``). Default: ``"relu"``.
+            can be a string (``"relu"`` or ``"gelu"``), Cell instance (:class:`mindspore.nn.ReLU` or
+            :class:`mindspore.nn.GELU` ) or a callable ( :func:`mindspore.ops.relu` or
+            :func:`mindspore.ops.gelu` ). Default: ``"relu"``.
         layer_norm_eps (float): The epsilon value in LayerNorm modules. Default: ``1e-5``.
-        batch_first (bool): If `batch_first = True`, then the shape of input and output tensors is
+        batch_first (bool): If `batch_first=True` , then the shape of input and output tensors is
             :math:`(batch, seq, feature)` , otherwise the shape is :math:`(seq, batch, feature)` .
             Default: ``False``.
-        norm_first (bool): If `norm_first = True`, layer norm is done prior to attention and feedforward
-            operations, respectively. Default: ``False``.
+        norm_first (bool): If `norm_first = True`, layer norm is located prior to attention and feedforward
+            operations; if `norm_first = False`, layer norm is located after the attention and feedforward
+            operations. Default: ``False``.
+        dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .
 
     Inputs:
-        - **src** (Tensor): the sequence to the encoder layer.
-        - **src_mask** (Tensor, optional): the mask for the src sequence. Default: ``None``.
-        - **src_key_padding_mask** (Tensor, optional): the mask for the src keys per batch.
-          Default: ``None``.
+        - **src** (Tensor): the sequence to the encoder layer. For unbatched input, the shape is
+          :math:`(S, E)` ; otherwise if `batch_first=False` , the shape is :math:`(S, N, E)` and if
+          `batch_first=True` , the shape is :math:`(N, S, E)`, where :math:`(S)` is the source sequence
+          length, :math:`(N)` is the batch number and :math:`(E)` is the feature number.
+          Supported types: float16, float32, float64.
+        - **src_mask** (Tensor, optional): the mask for the src sequence. The shape is :math:`(S, S)`
+          or :math:`(N*nhead, S, S)`. Supported types: float16, float32, float64, bool. Default: ``None``.
+        - **src_key_padding_mask** (Tensor, optional): the mask for the src keys per batch. The shape is
+          :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
+          float64, bool. Default: ``None``.
 
     Outputs:
-        Tensor.
+        Tensor. The shape and dtype of Tensor is the same with `src` .
 
     Raises:
         ValueError: If the init argument `activation` is not str, callable or Cell instance.
         ValueError: If the init argument `activation` is not :class:`mindspore.nn.ReLU`,
             :class:`mindspore.nn.GELU` instance, :func:`mindspore.ops.relu`,
-            :func:`mindspore.ops.gelu` instance, "relu" or "gelu" .
+            :func:`mindspore.ops.gelu`, "relu" or "gelu" .
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
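The reworded `norm_first` description is easiest to see as code. A schematic sketch of the two sub-layer orderings, not the class's actual `construct` (which additionally handles masks and dropout):

    def encoder_block(x, attn, ff, norm1, norm2, norm_first):
        if norm_first:
            # pre-LN: normalize before each sub-layer; the residual adds the raw input
            x = x + attn(norm1(x))
            x = x + ff(norm2(x))
        else:
            # post-LN: normalize after each residual addition
            x = norm1(x + attn(x))
            x = norm2(x + ff(x))
        return x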
@@ -290,6 +294,8 @@ class TransformerEncoderLayer(Cell):
         >>> encoder_layer = ms.nn.TransformerEncoderLayer(d_model=512, nhead=8)
         >>> src = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
         >>> out = encoder_layer(src)
+        >>> print(out.shape)
+        (10, 32, 512)
         >>> # Alternatively, when batch_first=True:
         >>> encoder_layer = ms.nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
         >>> src = ms.Tensor(np.random.rand(32, 10, 512), ms.float32)
@@ -297,33 +303,39 @@ class TransformerEncoderLayer(Cell):
         >>> print(out.shape)
         (32, 10, 512)
     """
-    __constants__ = ['batch_first', 'norm_first']
 
     def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                  activation: Union[str, Cell, callable] = 'relu', layer_norm_eps: float = 1e-5,
-                 batch_first: bool = False, norm_first: bool = False):
+                 batch_first: bool = False, norm_first: bool = False, dtype=mstype.float32):
         super().__init__()
-        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
+        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, dtype=dtype)
         # feedforward layer
-        self.linear1 = _Linear(d_model, dim_feedforward)
+        fan_in, _ = _calculate_fan_in_and_fan_out((dim_feedforward, d_model))
+        bound = 1 / math.sqrt(fan_in)
+        self.dense1 = Dense(d_model, dim_feedforward, weight_init=HeUniform(math.sqrt(5)),
+                            bias_init=Uniform(bound), dtype=dtype)
         self.dropout = Dropout(p=dropout)
-        self.linear2 = _Linear(dim_feedforward, d_model)
+        fan_in1, _ = _calculate_fan_in_and_fan_out((d_model, dim_feedforward))
+        bound1 = 1 / math.sqrt(fan_in1)
+        self.dense2 = Dense(dim_feedforward, d_model, weight_init=HeUniform(math.sqrt(5)),
+                            bias_init=Uniform(bound1), dtype=dtype)
 
         self.norm_first = norm_first
-        self.norm1 = LayerNorm((d_model,), epsilon=layer_norm_eps)
-        self.norm2 = LayerNorm((d_model,), epsilon=layer_norm_eps)
+        self.norm1 = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
+        self.norm2 = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
         self.dropout1 = Dropout(p=dropout)
         self.dropout2 = Dropout(p=dropout)
+        self.activation1 = activation
 
         if not isinstance(activation, str) and not isinstance(activation, Cell) \
                 and not callable(activation):
             raise ValueError(f"The argument 'activation' must be str, callable or Cell instance,"
                              f" but get {activation}.")
-        if isinstance(activation, Cell) and (not isinstance(activation, ReLU) or \
+        if isinstance(activation, Cell) and (not isinstance(activation, ReLU) and \
                 not isinstance(activation, GELU)):
             raise ValueError(f"The argument 'activation' must be nn.ReLU or nn.GELU instance,"
                              f" but get {activation}.")
-        if callable(activation) and (activation is not ops.relu or \
+        if callable(activation) and (activation is not ops.relu and \
                 activation is not ops.gelu):
             raise ValueError(f"The argument 'activation' must be ops.relu or ops.gelu instance,"
                              f" but get {activation}.")
@@ -331,6 +343,14 @@ class TransformerEncoderLayer(Cell):
         if isinstance(activation, str):
             activation = _get_activation_fn(activation)
         self.activation = activation
+        self.d_model = d_model
+        self.nhead = nhead
+        self.dim_feedforward = dim_feedforward
+        self.dropout_num = dropout
+        self.layernorm_eps = layer_norm_eps
+        self.batch_first = batch_first
+        self.norm_first = norm_first
+        self.dtype = dtype
 
     def construct(self, src: Tensor, src_mask: Optional[Tensor] = None,
                   src_key_padding_mask: Optional[Tensor] = None):
@@ -358,7 +378,7 @@ class TransformerEncoderLayer(Cell):
         return self.dropout1(x)
 
     def _ff_block(self, x):
-        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+        x = self.dense2(self.dropout(self.activation(self.dense1(x))))
         return self.dropout2(x)
 
 
@@ -367,42 +387,50 @@ class TransformerDecoderLayer(Cell):
     Transformer Decoder Layer. This is an implementation of the single layer of the transformer
     decoder layer, including self-attention, cross attention and feedward layer.
 
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
    Args:
        d_model (int): The number of expected features in the input tensor.
        nhead (int): The number of heads in the MultiheadAttention modules.
        dim_feedforward (int): The dimension of the feedforward layer. Default: ``2048``.
        dropout (float): The dropout value. Default: ``0.1``.
        activation (Union[str, callable, Cell]): The activation function of the intermediate layer,
-            can be a string (`"relu"` or `"gelu"`), Cell instance (`nn.ReLU()` or `nn.GELU()`) or
-            a callable (`ops.relu` or `ops.gelu`). Default: ``"relu"``
+            can be a string (``"relu"`` or ``"gelu"``), Cell instance (:class:`mindspore.nn.ReLU` or
+            :class:`mindspore.nn.GELU` ) or a callable ( :func:`mindspore.ops.relu` or
+            :func:`mindspore.ops.gelu` ). Default: ``"relu"``.
        layer_norm_eps (float): The epsilon value in LayerNorm modules. Default: ``1e-5``.
-        batch_first (bool): If `batch_first = True`, then the shape of input and output tensors is
+        batch_first (bool): If `batch_first=True` , then the shape of input and output tensors is
            :math:`(batch, seq, feature)` , otherwise the shape is :math:`(seq, batch, feature)`.
            Default: ``False``.
-        norm_first (bool): If `norm_first = True`, layer norm is done prior to attention and feedforward
-            operations, respectively. Default: ``False``.
+        norm_first (bool): If `norm_first = True`, layer norm is located prior to attention and feedforward
+            operations; if `norm_first = False`, layer norm is located after the attention and feedforward
+            operations. Default: ``False``.
+        dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .
 
    Inputs:
-        - **tgt** (Tensor): The sequence to the decoder layer.
-        - **memory** (Tensor): The sequence from the last layer of the encoder.
-        - **tgt_mask** (Tensor, optional): The mask of the tgt sequence. Default: ``None``.
-        - **memory_mask** (Tensor, optional): The mask of the memory sequence. Default: ``None``.
-        - **tgt_key_padding_mask** (Tensor, optional): The mask of the tgt keys per batch.
-          Default: ``None``.
-        - **memory_key_padding_mask** (Tensor, optional): The mask of the memory keys per batch.
-          Default: ``None``.
+        - **tgt** (Tensor): The sequence to the decoder layer. For unbatched input, the shape is
+          :math:`(T, E)` ; otherwise if `batch_first=False` , the shape is :math:`(T, N, E)` and if
+          `batch_first=True` , the shape is :math:`(N, T, E)`, where :math:`(T)` is the target sequence
+          length. Supported types: float16, float32, float64.
+        - **memory** (Tensor): The sequence from the last layer of the encoder. Supported types: float16,
+          float32, float64.
+        - **tgt_mask** (Tensor, optional): The mask of the tgt sequence. The shape is :math:`(T, T)`
+          or :math:`(N*nhead, T, T)`. Supported types: float16, float32, float64, bool. Default: ``None``.
+        - **memory_mask** (Tensor, optional): The mask of the memory sequence. The shape is
+          :math:`(T, S)` . Supported types: float16, float32, float64, bool. Default: ``None``.
+        - **tgt_key_padding_mask** (Tensor, optional): The mask of the tgt keys per batch. The shape is
+          :math:`(T)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
+          float64, bool. Default: ``None``.
+        - **memory_key_padding_mask** (Tensor, optional): The mask of the memory keys per batch. The shape
+          is :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
+          float64, bool. Default: ``None``.
 
    Outputs:
-        Tensor.
+        Tensor. The shape and dtype of Tensor is the same with `tgt` .
 
    Raises:
        ValueError: If the init argument `activation` is not str, callable or Cell instance.
        ValueError: If the init argument `activation` is not :class:`mindspore.nn.ReLU`,
            :class:`mindspore.nn.GELU` instance, :func:`mindspore.ops.relu`,
-            :func:`mindspore.ops.gelu` instance, "relu" or "gelu" .
+            :func:`mindspore.ops.gelu` , "relu" or "gelu" .
 
    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
@@ -414,6 +442,8 @@ class TransformerDecoderLayer(Cell):
414
442
  >>> memory = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
415
443
  >>> tgt = ms.Tensor(np.random.rand(20, 32, 512), ms.float32)
416
444
  >>> out = decoder_layer(tgt, memory)
445
+ >>> print(out.shape)
446
+ (20, 32, 512)
417
447
  >>> # Alternatively, when `batch_first` is ``True``:
418
448
  >>> decoder_layer = ms.nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
419
449
  >>> memory = ms.Tensor(np.random.rand(32, 10, 512), ms.float32)
@@ -422,36 +452,42 @@ class TransformerDecoderLayer(Cell):
422
452
  >>> print(out.shape)
423
453
  (32, 20, 512)
424
454
  """
- __constants__ = ['batch_first', 'norm_first']

  def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
  activation: Union[str, Cell, callable] = 'relu', layer_norm_eps: float = 1e-5,
- batch_first: bool = False, norm_first: bool = False):
+ batch_first: bool = False, norm_first: bool = False, dtype=mstype.float32):
  super().__init__()
- self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
- self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, dtype=dtype)
+ self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, dtype=dtype)
  # feedforward layer
- self.linear1 = _Linear(d_model, dim_feedforward)
+ fan_in, _ = _calculate_fan_in_and_fan_out((dim_feedforward, d_model))
+ bound = 1 / math.sqrt(fan_in)
+ self.dense1 = Dense(d_model, dim_feedforward, weight_init=HeUniform(math.sqrt(5)),
+ bias_init=Uniform(bound), dtype=dtype)
  self.dropout = Dropout(p=dropout)
- self.linear2 = _Linear(dim_feedforward, d_model)
+ fan_in1, _ = _calculate_fan_in_and_fan_out((d_model, dim_feedforward))
+ bound1 = 1 / math.sqrt(fan_in1)
+ self.dense2 = Dense(dim_feedforward, d_model, weight_init=HeUniform(math.sqrt(5)),
+ bias_init=Uniform(bound1), dtype=dtype)

  self.norm_first = norm_first
- self.norm1 = LayerNorm((d_model,), epsilon=layer_norm_eps)
- self.norm2 = LayerNorm((d_model,), epsilon=layer_norm_eps)
- self.norm3 = LayerNorm((d_model,), epsilon=layer_norm_eps)
+ self.norm1 = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
+ self.norm2 = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
+ self.norm3 = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
  self.dropout1 = Dropout(p=dropout)
  self.dropout2 = Dropout(p=dropout)
  self.dropout3 = Dropout(p=dropout)
+ self.activation1 = activation

  if not isinstance(activation, str) and not isinstance(activation, Cell) \
  and not callable(activation):
  raise ValueError(f"The argument 'activation' must be str, callable or Cell instance,"
  f" but get {activation}.")
- if isinstance(activation, Cell) and (not isinstance(activation, ReLU) or \
+ if isinstance(activation, Cell) and (not isinstance(activation, ReLU) and \
  not isinstance(activation, GELU)):
  raise ValueError(f"The argument 'activation' must be nn.ReLU or nn.GELU instance,"
  f" but get {activation}.")
- if callable(activation) and (activation is not ops.relu or \
+ if callable(activation) and (activation is not ops.relu and \
  activation is not ops.gelu):
  raise ValueError(f"The argument 'activation' must be ops.relu or ops.gelu instance,"
  f" but get {activation}.")
@@ -459,6 +495,14 @@ class TransformerDecoderLayer(Cell):
  if isinstance(activation, str):
  activation = _get_activation_fn(activation)
  self.activation = activation
+ self.d_model = d_model
+ self.nhead = nhead
+ self.dim_feedforward = dim_feedforward
+ self.dropout_num = dropout
+ self.layernorm_eps = layer_norm_eps
+ self.batch_first = batch_first
+ self.norm_first = norm_first
+ self.dtype = dtype

  def construct(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
  memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
@@ -490,32 +534,36 @@ class TransformerDecoderLayer(Cell):
  return self.dropout2(x)

  def _ff_block(self, x):
- x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+ x = self.dense2(self.dropout(self.activation(self.dense1(x))))
  return self.dropout3(x)


  class TransformerEncoder(Cell):
  r"""
- Transformer Encoder module with multi-layer stacked of `TransformerEncoderLayer`, including multihead self
+ Transformer Encoder module with multi-layer stacked of `TransformerEncoderLayer`, including multihead
  attention and feedforward layer. Users can build the
  BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.

- .. warning::
- This is an experimental API that is subject to change or deletion.
-
  Args:
- encoder_layer (Cell): An instance of the TransformerEncoderLayer() class.
+ encoder_layer (Cell): An instance of the :class:`mindspore.nn.TransformerEncoderLayer` class.
  num_layers (int): The number of encoder-layers in the encoder.
  norm (Cell, optional): The layer normalization module. Default: ``None``.

  Inputs:
- - **src** (Tensor): The sequence to the encoder.
- - **src_mask** (Tensor, optional): The mask of the src sequence. Default: ``None``.
- - **src_key_padding_mask** (Tensor, optional): the mask of the src keys per batch .
- Default: ``None``.
+ - **src** (Tensor): The sequence to the encoder. For unbatched input, the shape is
+ :math:`(S, E)` ; otherwise if `batch_first=False` in TransformerEncoderLayer, the shape is
+ :math:`(S, N, E)` and if `batch_first=True` , the shape is :math:`(N, S, E)`, where :math:`(S)` is the
+ source sequence length, :math:`(N)` is the batch number and :math:`(E)` is the feature number.
+ Supported types: float16, float32, float64.
+ - **src_mask** (Tensor, optional): The mask of the src sequence. The shape is :math:`(S, S)`
+ or :math:`(N*nhead, S, S)` , where `nhead` is the argument in TransformerEncoderLayer.
+ Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **src_key_padding_mask** (Tensor, optional): the mask of the src keys per batch. The shape is
+ :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
+ float64, bool. Default: ``None``.

  Outputs:
- Tensor.
+ Tensor. The shape and dtype are the same as `src` .

  Raises:
  AssertionError: If the input argument `src_key_padding_mask` is not bool or floating types.
@@ -533,11 +581,14 @@ class TransformerEncoder(Cell):
  >>> print(out.shape)
  (10, 32, 512)
  """
- __constants__ = ['norm']

  def __init__(self, encoder_layer, num_layers, norm=None):
  super(TransformerEncoder, self).__init__()
- self.layers = _get_clones(encoder_layer, num_layers)
+ layers = TransformerEncoderLayer(encoder_layer.d_model, encoder_layer.nhead, encoder_layer.dim_feedforward,
+ encoder_layer.dropout_num, encoder_layer.activation1,
+ encoder_layer.layernorm_eps, encoder_layer.batch_first,
+ encoder_layer.norm_first, dtype=encoder_layer.dtype)
+ self.layers = CellList([layers for _ in range(num_layers)])
  self.num_layers = num_layers
  self.norm = norm
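Note that the comprehension above places the same `TransformerEncoderLayer` instance in the `CellList` `num_layers` times, so the stacked layers share one set of parameters, whereas the removed `_get_clones` helper deep-copied the layer. A minimal sketch, if independently initialized layers are wanted (illustrative only, not part of the package):

    >>> import mindspore as ms
    >>> num_layers = 2
    >>> layers = ms.nn.CellList([ms.nn.TransformerEncoderLayer(d_model=512, nhead=8)
    ...                          for _ in range(num_layers)])
    >>> layers[0] is layers[1]
    False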

@@ -563,26 +614,31 @@ class TransformerDecoder(Cell):
  Transformer Decoder module with multi-layer stacked of `TransformerDecoderLayer`, including multihead self
  attention, cross attention and feedforward layer.

- .. warning::
- This is an experimental API that is subject to change or deletion.
-
  Args:
  decoder_layer (Cell): An instance of the :class:`mindspore.nn.TransformerDecoderLayer` class.
  num_layers (int): The number of decoder-layers in the decoder.
  norm (Cell, optional): The layer normalization module. Default: ``None``.

  Inputs:
- - **tgt** (Tensor): The sequence to the decoder.
- - **memory** (Tensor): The sequence from the last layer of the encoder.
- - **tgt_mask** (Tensor, optional): the mask of the tgt sequence. Default: ``None``.
- - **memory_mask** (Tensor, optional): the mask of the memory sequence. Default: ``None``.
- - **tgt_key_padding_mask** (Tensor, optional): the mask of the tgt keys per batch.
- Default: ``None``.
- - **memory_key_padding_mask** (Tensor, optional): the mask of the memory keys per batch.
- Default: ``None``.
+ - **tgt** (Tensor): The sequence to the decoder. For unbatched input, the shape is
+ :math:`(T, E)` ; otherwise if `batch_first=False` in TransformerDecoderLayer, the shape is
+ :math:`(T, N, E)` and if `batch_first=True` , the shape is :math:`(N, T, E)`, where :math:`(T)` is the
+ target sequence length. Supported types: float16, float32, float64.
+ - **memory** (Tensor): The sequence from the last layer of the encoder. Supported types: float16,
+ float32, float64.
+ - **tgt_mask** (Tensor, optional): the mask of the tgt sequence. The shape is :math:`(T, T)`
+ or :math:`(N*nhead, T, T)` , where `nhead` is the argument in TransformerDecoderLayer.
+ Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **memory_mask** (Tensor, optional): the mask of the memory sequence. The shape is
+ :math:`(T, S)` . Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **tgt_key_padding_mask** (Tensor, optional): the mask of the tgt keys per batch. Supported
+ types: float16, float32, float64, bool. Default: ``None``.
+ - **memory_key_padding_mask** (Tensor, optional): the mask of the memory keys per batch. The shape
+ is :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
+ float64, bool. Default: ``None``.

  Outputs:
- Tensor.
+ Tensor. The shape and dtype are the same as `tgt` .

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``
@@ -598,11 +654,14 @@ class TransformerDecoder(Cell):
  >>> print(out.shape)
  (20, 32, 512)
  """
- __constants__ = ['norm']

  def __init__(self, decoder_layer, num_layers, norm=None):
  super(TransformerDecoder, self).__init__()
- self.layers = _get_clones(decoder_layer, num_layers)
+ layers = TransformerDecoderLayer(decoder_layer.d_model, decoder_layer.nhead, decoder_layer.dim_feedforward,
+ decoder_layer.dropout_num, decoder_layer.activation1,
+ decoder_layer.layernorm_eps, decoder_layer.batch_first,
+ decoder_layer.norm_first, dtype=decoder_layer.dtype)
+ self.layers = CellList([layers for _ in range(num_layers)])
  self.num_layers = num_layers
  self.norm = norm

@@ -610,7 +669,6 @@ class TransformerDecoder(Cell):
  memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
  memory_key_padding_mask: Optional[Tensor] = None):
  output = tgt
-
  for mod in self.layers:
  output = mod(output, memory, tgt_mask=tgt_mask,
  memory_mask=memory_mask,
@@ -626,47 +684,60 @@ class TransformerDecoder(Cell):
  class Transformer(Cell):
  r"""
  Transformer module including encoder and decoder. The difference from the original implementation is that the module uses
- the residual addition before the layer normalization. And the default hidden act is `gelu`.
+ the residual addition before the layer normalization. And the default hidden activation is `gelu`.
  The details can be found in `Attention is all you need <https://arxiv.org/pdf/1706.03762v5.pdf>`_.

- .. warning::
- This is an experimental API that is subject to change or deletion.
-
  Args:
- d_model (int): The number of expected features in the inputs tensor. Default: ``512``.
+ d_model (int): The number of expected features in the inputs tensor for Encoder and Decoder. Default: ``512``.
  nhead (int): The number of heads in the MultiheadAttention modules. Default: ``8``.
  num_encoder_layers (int): The number of encoder-layers in the encoder. Default: ``6``.
  num_decoder_layers (int): The number of decoder-layers in the decoder. Default: ``6``.
  dim_feedforward (int): The dimension of the feedforward layer. Default: ``2048``.
  dropout (float): The dropout value. Default: ``0.1``.
  activation (Union[str, callable, Cell]): The activation function of the intermediate layer,
- can be a string (`"relu"` or `"gelu"`), Cell instance (`nn.ReLU()` or `nn.GELU()`) or
- a callable (`ops.relu` or `ops.gelu`). Default: ``"relu"``
+ can be a string (``"relu"`` or ``"gelu"``), Cell instance (:class:`mindspore.nn.ReLU` or
+ :class:`mindspore.nn.GELU` ) or a callable ( :func:`mindspore.ops.relu` or
+ :func:`mindspore.ops.gelu` ). Default: ``"relu"``.
  custom_encoder (Cell): Custom encoder. Default: ``None``.
  custom_decoder (Cell): Custom decoder. Default: ``None``.
  layer_norm_eps (float): the epsilon value in layer normalization module. Default: ``1e-5``.
- batch_first (bool): If `batch_first = True`, then the shape of input and output tensors is
+ batch_first (bool): If `batch_first=True`, then the shape of input and output tensors is
  :math:`(batch, seq, feature)` , otherwise the shape is :math:`(seq, batch, feature)` .
  Default: ``False``.
- norm_first (bool): If `norm_first = True`, layer norm is done prior to attention and feedforward
- operations, respectively. Default: ``False``.
+ norm_first (bool): If `norm_first = True`, layer norm is located prior to attention and feedforward
+ operations; if `norm_first = False`, layer norm is located after the attention and feedforward
+ operations. Default: ``False``.
+ dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .

  Inputs:
- - **src** (Tensor): The source sequence to the encoder.
- - **tgt** (Tensor): The target sequence to the decoder.
- - **src_mask** (Tensor, optional): The mask of the src sequence. Default: ``None``.
- - **tgt_mask** (Tensor, optional): The mask of the tgt sequence. Default: ``None``.
- - **memory_mask** (Tensor, optional): The additive mask of the encoder output.
- Default: ``None``.
- - **src_key_padding_mask** (Tensor, optional): The mask of src keys per batch.
- Default: ``None``.
- - **tgt_key_padding_mask** (Tensor, optional): The mask of tgt keys per batch.
- Default: ``None``.
- - **memory_key_padding_mask** (Tensor, optional): The mask of memory keys per batch.
- Default: ``None``.
+ - **src** (Tensor): The source sequence to the encoder. For unbatched input, the shape is
+ :math:`(S, E)` ; otherwise if `batch_first=False` , the shape is :math:`(S, N, E)` and if
+ `batch_first=True` , the shape is :math:`(N, S, E)`, where :math:`(S)` is the source sequence
+ length, :math:`(N)` is the batch number and :math:`(E)` is the feature number. Supported
+ types: float16, float32, float64.
+ - **tgt** (Tensor): The target sequence to the decoder. For unbatched input, the shape is
+ :math:`(T, E)` ; otherwise if `batch_first=False` , the shape is :math:`(T, N, E)` and if
+ `batch_first=True` , the shape is :math:`(N, T, E)`, where :math:`(T)` is the target sequence
+ length. Supported types: float16, float32, float64.
+ - **src_mask** (Tensor, optional): The mask of the src sequence. The shape is :math:`(S, S)`
+ or :math:`(N*nhead, S, S)`. Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **tgt_mask** (Tensor, optional): The mask of the tgt sequence. The shape is :math:`(T, T)`
+ or :math:`(N*nhead, T, T)`. Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **memory_mask** (Tensor, optional): The additive mask of the encoder output. The shape is
+ :math:`(T, S)` . Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **src_key_padding_mask** (Tensor, optional): The mask of src keys per batch. The shape is
+ :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
+ float64, bool. Default: ``None``.
+ - **tgt_key_padding_mask** (Tensor, optional): The mask of tgt keys per batch. The shape is
+ :math:`(T)` for unbatched input, otherwise :math:`(N, T)` . Supported types: float16, float32,
+ float64, bool. Default: ``None``.
+ - **memory_key_padding_mask** (Tensor, optional): The mask of memory keys per batch. The shape
+ is :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16,
+ float32, float64, bool. Default: ``None``.

  Outputs:
- Tensor.
+ Tensor. The shape is :math:`(T, E)` for unbatched input, otherwise if `batch_first=False` , the shape is
+ :math:`(T, N, E)` and if `batch_first=True` , the shape is :math:`(N, T, E)`.

  Raises:
  ValueError: If the batch sizes of the init argument `src` and `tgt` are not equal.
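An end-to-end sketch tying the documented shapes together: with `batch_first=False` (the default), `src` is `(S, N, E)`, `tgt` is `(T, N, E)`, and the output matches `tgt`:

    >>> import numpy as np
    >>> import mindspore as ms
    >>> model = ms.nn.Transformer(d_model=512, nhead=8, num_encoder_layers=2,
    ...                           num_decoder_layers=2)
    >>> src = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)   # (S, N, E)
    >>> tgt = ms.Tensor(np.random.rand(20, 32, 512), ms.float32)   # (T, N, E)
    >>> out = model(src, tgt)
    >>> print(out.shape)
    (20, 32, 512)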
@@ -690,23 +761,23 @@ class Transformer(Cell):
  num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
  activation: Union[str, Cell, callable] = 'relu', custom_encoder: Optional[Cell] = None,
  custom_decoder: Optional[Cell] = None, layer_norm_eps: float = 1e-5,
- batch_first: bool = False, norm_first: bool = False):
+ batch_first: bool = False, norm_first: bool = False, dtype=mstype.float32):
  super(Transformer, self).__init__()

  if custom_encoder is not None:
  self.encoder = custom_encoder
  else:
  encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
- activation, layer_norm_eps, batch_first, norm_first)
- encoder_norm = LayerNorm((d_model,), epsilon=layer_norm_eps)
+ activation, layer_norm_eps, batch_first, norm_first, dtype=dtype)
+ encoder_norm = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
  self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

  if custom_decoder is not None:
  self.decoder = custom_decoder
  else:
  decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
- activation, layer_norm_eps, batch_first, norm_first)
- decoder_norm = LayerNorm((d_model,), epsilon=layer_norm_eps)
+ activation, layer_norm_eps, batch_first, norm_first, dtype=dtype)
+ decoder_norm = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
  self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

  for _, p in self.parameters_and_names():
@@ -748,7 +819,3 @@ def _get_activation_fn(activation: str):
  return ops.gelu

  raise ValueError(f"The activation must be relu/gelu, but get {activation}")
-
-
- def _get_clones(module, N):
- return CellList([copy.deepcopy(module) for i in range(N)])