mindspore 1.10.0__cp37-none-any.whl → 2.0.0rc1__cp37-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of mindspore might be problematic.

Files changed (944)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +9064 -0
  3. mindspore/__init__.py +9 -4
  4. mindspore/_akg/akg/composite/build_module.py +11 -0
  5. mindspore/_akg/akg/config/repository_cuda.json +11 -0
  6. mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -3
  7. mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
  8. mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
  9. mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
  10. mindspore/_check_jit_forbidden_api.py +102 -0
  11. mindspore/_checkparam.py +1066 -1001
  12. mindspore/_extends/builtin_operations.py +32 -4
  13. mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
  14. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
  15. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
  16. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
  17. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
  18. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
  19. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  20. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
  21. mindspore/_extends/parse/__init__.py +5 -3
  22. mindspore/_extends/parse/namespace.py +17 -2
  23. mindspore/_extends/parse/parser.py +193 -34
  24. mindspore/_extends/parse/resources.py +7 -8
  25. mindspore/_extends/parse/standard_method.py +1780 -435
  26. mindspore/_extends/parse/trope.py +3 -1
  27. mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
  28. mindspore/amp.py +53 -58
  29. mindspore/bin/cache_admin +0 -0
  30. mindspore/bin/cache_server +0 -0
  31. mindspore/boost/adasum.py +3 -2
  32. mindspore/boost/boost.py +2 -2
  33. mindspore/boost/boost_cell_wrapper.py +46 -26
  34. mindspore/boost/dim_reduce.py +6 -5
  35. mindspore/boost/grad_accumulation.py +2 -1
  36. mindspore/boost/group_loss_scale_manager.py +1 -1
  37. mindspore/common/__init__.py +11 -10
  38. mindspore/common/_decorator.py +2 -0
  39. mindspore/common/_register_for_adapter.py +55 -0
  40. mindspore/common/_stub_tensor.py +201 -0
  41. mindspore/common/_utils.py +57 -0
  42. mindspore/common/api.py +582 -297
  43. mindspore/common/dtype.py +66 -18
  44. mindspore/common/dump.py +2 -2
  45. mindspore/common/initializer.py +38 -1
  46. mindspore/common/jit_config.py +25 -13
  47. mindspore/common/mutable.py +53 -24
  48. mindspore/common/parameter.py +60 -37
  49. mindspore/common/seed.py +8 -24
  50. mindspore/common/sparse_tensor.py +927 -0
  51. mindspore/common/tensor.py +1627 -3900
  52. mindspore/communication/__init__.py +10 -5
  53. mindspore/communication/_comm_helper.py +78 -214
  54. mindspore/communication/_hccl_management.py +2 -1
  55. mindspore/communication/management.py +136 -47
  56. mindspore/config/op_info.config +501 -1008
  57. mindspore/config/super_bar_config.json +512 -0
  58. mindspore/context.py +291 -56
  59. mindspore/dataset/__init__.py +12 -8
  60. mindspore/dataset/audio/__init__.py +9 -9
  61. mindspore/dataset/audio/transforms.py +1090 -228
  62. mindspore/dataset/audio/utils.py +87 -39
  63. mindspore/dataset/audio/validators.py +223 -1
  64. mindspore/dataset/callback/ds_callback.py +17 -15
  65. mindspore/dataset/core/config.py +246 -17
  66. mindspore/dataset/core/py_util_helpers.py +4 -3
  67. mindspore/dataset/core/validator_helpers.py +10 -10
  68. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  69. mindspore/dataset/debug/debug_hook.py +65 -0
  70. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  71. mindspore/dataset/engine/__init__.py +7 -3
  72. mindspore/dataset/engine/cache_client.py +9 -9
  73. mindspore/dataset/engine/datasets.py +648 -477
  74. mindspore/dataset/engine/datasets_audio.py +165 -167
  75. mindspore/dataset/engine/datasets_standard_format.py +93 -67
  76. mindspore/dataset/engine/datasets_text.py +492 -342
  77. mindspore/dataset/engine/datasets_user_defined.py +85 -50
  78. mindspore/dataset/engine/datasets_vision.py +1224 -699
  79. mindspore/dataset/engine/graphdata.py +134 -69
  80. mindspore/dataset/engine/iterators.py +50 -9
  81. mindspore/dataset/engine/offload.py +52 -31
  82. mindspore/dataset/engine/samplers.py +27 -24
  83. mindspore/dataset/engine/serializer_deserializer.py +14 -15
  84. mindspore/dataset/engine/validators.py +213 -52
  85. mindspore/dataset/text/__init__.py +10 -8
  86. mindspore/dataset/text/transforms.py +152 -57
  87. mindspore/dataset/text/utils.py +98 -49
  88. mindspore/dataset/text/validators.py +25 -0
  89. mindspore/dataset/transforms/__init__.py +4 -2
  90. mindspore/dataset/transforms/c_transforms.py +11 -13
  91. mindspore/dataset/transforms/py_transforms.py +2 -2
  92. mindspore/dataset/transforms/py_transforms_util.py +10 -0
  93. mindspore/dataset/transforms/transforms.py +13 -15
  94. mindspore/dataset/transforms/validators.py +7 -7
  95. mindspore/dataset/utils/__init__.py +2 -1
  96. mindspore/dataset/utils/browse_dataset.py +13 -13
  97. mindspore/dataset/utils/line_reader.py +121 -0
  98. mindspore/dataset/vision/__init__.py +8 -7
  99. mindspore/dataset/vision/c_transforms.py +125 -126
  100. mindspore/dataset/vision/py_transforms.py +37 -37
  101. mindspore/dataset/vision/py_transforms_util.py +23 -20
  102. mindspore/dataset/vision/transforms.py +316 -315
  103. mindspore/dataset/vision/utils.py +313 -17
  104. mindspore/dataset/vision/validators.py +6 -6
  105. mindspore/default_config.py +0 -1
  106. mindspore/{compression → experimental}/__init__.py +6 -5
  107. mindspore/experimental/map_parameter.py +275 -0
  108. mindspore/include/OWNERS +0 -1
  109. mindspore/include/api/callback/callback.h +9 -13
  110. mindspore/include/api/callback/ckpt_saver.h +2 -2
  111. mindspore/include/api/callback/loss_monitor.h +2 -2
  112. mindspore/include/api/callback/lr_scheduler.h +5 -5
  113. mindspore/include/api/callback/time_monitor.h +2 -2
  114. mindspore/include/api/callback/train_accuracy.h +4 -6
  115. mindspore/include/api/cfg.h +19 -6
  116. mindspore/include/api/context.h +70 -9
  117. mindspore/include/api/delegate.h +8 -1
  118. mindspore/include/api/dual_abi_helper.h +8 -24
  119. mindspore/include/api/metrics/accuracy.h +2 -2
  120. mindspore/include/api/metrics/metrics.h +4 -3
  121. mindspore/include/api/model.h +9 -4
  122. mindspore/include/api/model_group.h +68 -0
  123. mindspore/include/api/model_parallel_runner.h +17 -17
  124. mindspore/include/api/net.h +12 -11
  125. mindspore/include/api/serialization.h +20 -4
  126. mindspore/include/api/status.h +7 -1
  127. mindspore/include/api/types.h +25 -21
  128. mindspore/include/api/visible.h +4 -0
  129. mindspore/include/c_api/model_c.h +5 -0
  130. mindspore/include/c_api/status_c.h +1 -1
  131. mindspore/include/dataset/config.h +1 -1
  132. mindspore/include/dataset/constants.h +14 -0
  133. mindspore/include/dataset/text.h +59 -0
  134. mindspore/include/dataset/vision.h +56 -117
  135. mindspore/include/dataset/vision_lite.h +102 -0
  136. mindspore/include/mindapi/base/type_id.h +42 -3
  137. mindspore/lib/libdnnl.so.2 +0 -0
  138. mindspore/lib/libicudata.so.69 +0 -0
  139. mindspore/lib/libicui18n.so.69 +0 -0
  140. mindspore/lib/libicuuc.so.69 +0 -0
  141. mindspore/lib/libmindspore.so +0 -0
  142. mindspore/lib/libmindspore_backend.so +0 -0
  143. mindspore/lib/libmindspore_common.so +0 -0
  144. mindspore/lib/libmindspore_core.so +0 -0
  145. mindspore/lib/libmindspore_glog.so.0 +0 -0
  146. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  147. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  148. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  149. mindspore/lib/libmindspore_shared_lib.so +0 -0
  150. mindspore/lib/libmpi_adapter.so +0 -0
  151. mindspore/lib/libmpi_collective.so +0 -0
  152. mindspore/lib/libnnacl.so +0 -0
  153. mindspore/lib/libopencv_core.so.4.5 +0 -0
  154. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  155. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  156. mindspore/lib/libps_cache.so +0 -0
  157. mindspore/lib/plugin/ascend/libakg.so +0 -0
  158. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  159. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  160. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  161. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  163. mindspore/lib/{libakg.so → plugin/cpu/libakg.so} +0 -0
  164. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  165. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  166. mindspore/log.py +28 -28
  167. mindspore/mindrecord/common/exceptions.py +2 -4
  168. mindspore/mindrecord/filereader.py +19 -1
  169. mindspore/mindrecord/filewriter.py +250 -88
  170. mindspore/mindrecord/mindpage.py +13 -13
  171. mindspore/mindrecord/shardheader.py +15 -15
  172. mindspore/mindrecord/shardreader.py +9 -0
  173. mindspore/mindrecord/shardwriter.py +29 -29
  174. mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
  175. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
  176. mindspore/mindrecord/tools/csv_to_mr.py +4 -4
  177. mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
  178. mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
  179. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  180. mindspore/nn/__init__.py +1 -5
  181. mindspore/nn/cell.py +297 -234
  182. mindspore/nn/dynamic_lr.py +1 -1
  183. mindspore/nn/grad/cell_grad.py +17 -42
  184. mindspore/nn/layer/__init__.py +7 -4
  185. mindspore/nn/layer/activation.py +131 -88
  186. mindspore/nn/layer/basic.py +313 -613
  187. mindspore/nn/layer/channel_shuffle.py +103 -0
  188. mindspore/nn/layer/combined.py +1 -1
  189. mindspore/nn/layer/container.py +52 -6
  190. mindspore/nn/layer/conv.py +112 -43
  191. mindspore/nn/layer/dense.py +10 -9
  192. mindspore/nn/layer/embedding.py +36 -34
  193. mindspore/nn/layer/image.py +123 -27
  194. mindspore/nn/layer/math.py +108 -107
  195. mindspore/nn/layer/normalization.py +212 -366
  196. mindspore/nn/layer/padding.py +370 -42
  197. mindspore/nn/layer/pooling.py +1443 -219
  198. mindspore/nn/layer/rnn_cells.py +11 -16
  199. mindspore/nn/layer/rnns.py +38 -39
  200. mindspore/nn/layer/thor_layer.py +24 -25
  201. mindspore/nn/layer/timedistributed.py +5 -5
  202. mindspore/nn/layer/transformer.py +701 -0
  203. mindspore/nn/learning_rate_schedule.py +8 -8
  204. mindspore/nn/loss/__init__.py +9 -6
  205. mindspore/nn/loss/loss.py +678 -142
  206. mindspore/nn/metrics.py +53 -0
  207. mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
  208. mindspore/nn/optim/ada_grad.py +8 -8
  209. mindspore/nn/optim/adadelta.py +2 -3
  210. mindspore/nn/optim/adafactor.py +18 -14
  211. mindspore/nn/optim/adam.py +429 -87
  212. mindspore/nn/optim/adamax.py +5 -6
  213. mindspore/nn/optim/adasum.py +10 -8
  214. mindspore/nn/optim/asgd.py +7 -7
  215. mindspore/nn/optim/ftrl.py +81 -11
  216. mindspore/nn/optim/lamb.py +7 -8
  217. mindspore/nn/optim/lars.py +4 -4
  218. mindspore/nn/optim/lazyadam.py +82 -7
  219. mindspore/nn/optim/momentum.py +8 -7
  220. mindspore/nn/optim/optimizer.py +19 -10
  221. mindspore/nn/optim/proximal_ada_grad.py +6 -5
  222. mindspore/nn/optim/rmsprop.py +3 -3
  223. mindspore/nn/optim/rprop.py +20 -16
  224. mindspore/nn/optim/sgd.py +21 -15
  225. mindspore/nn/optim/thor.py +23 -21
  226. mindspore/nn/probability/__init__.py +0 -2
  227. mindspore/nn/probability/bijector/bijector.py +7 -6
  228. mindspore/nn/probability/bijector/invert.py +4 -2
  229. mindspore/nn/probability/bijector/softplus.py +2 -2
  230. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  231. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  232. mindspore/nn/probability/distribution/__init__.py +6 -0
  233. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
  234. mindspore/nn/probability/distribution/_utils/utils.py +11 -17
  235. mindspore/nn/probability/distribution/bernoulli.py +6 -6
  236. mindspore/nn/probability/distribution/beta.py +1 -1
  237. mindspore/nn/probability/distribution/categorical.py +9 -9
  238. mindspore/nn/probability/distribution/cauchy.py +8 -8
  239. mindspore/nn/probability/distribution/distribution.py +12 -6
  240. mindspore/nn/probability/distribution/exponential.py +5 -5
  241. mindspore/nn/probability/distribution/gamma.py +3 -3
  242. mindspore/nn/probability/distribution/geometric.py +6 -5
  243. mindspore/nn/probability/distribution/gumbel.py +5 -5
  244. mindspore/nn/probability/distribution/half_normal.py +133 -0
  245. mindspore/nn/probability/distribution/laplace.py +128 -0
  246. mindspore/nn/probability/distribution/log_normal.py +0 -1
  247. mindspore/nn/probability/distribution/logistic.py +4 -5
  248. mindspore/nn/probability/distribution/normal.py +11 -15
  249. mindspore/nn/probability/distribution/poisson.py +6 -2
  250. mindspore/nn/probability/distribution/student_t.py +150 -0
  251. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  252. mindspore/nn/probability/distribution/uniform.py +5 -5
  253. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  254. mindspore/nn/reinforcement/tensor_array.py +2 -2
  255. mindspore/nn/sparse/sparse.py +8 -1
  256. mindspore/nn/wrap/cell_wrapper.py +55 -27
  257. mindspore/nn/wrap/grad_reducer.py +20 -11
  258. mindspore/nn/wrap/loss_scale.py +47 -30
  259. mindspore/numpy/array_creations.py +33 -22
  260. mindspore/numpy/array_ops.py +46 -42
  261. mindspore/numpy/logic_ops.py +6 -27
  262. mindspore/numpy/math_ops.py +26 -19
  263. mindspore/numpy/utils.py +1 -8
  264. mindspore/numpy/utils_const.py +112 -62
  265. mindspore/ops/__init__.py +6 -3
  266. mindspore/ops/_constants.py +0 -6
  267. mindspore/ops/_grad/__init__.py +2 -1
  268. mindspore/ops/_grad/grad_array_ops.py +209 -152
  269. mindspore/ops/_grad/grad_base.py +55 -17
  270. mindspore/ops/_grad/grad_clip_ops.py +11 -3
  271. mindspore/ops/_grad/grad_comm_ops.py +58 -47
  272. mindspore/ops/_grad/grad_implementations.py +21 -61
  273. mindspore/ops/_grad/grad_inner_ops.py +48 -6
  274. mindspore/ops/_grad/grad_math_ops.py +306 -161
  275. mindspore/ops/_grad/grad_nn_ops.py +192 -181
  276. mindspore/ops/_grad/grad_other_ops.py +1 -1
  277. mindspore/ops/_grad/grad_quant_ops.py +5 -5
  278. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  279. mindspore/ops/_grad/grad_sparse.py +15 -9
  280. mindspore/ops/_grad_experimental/__init__.py +1 -0
  281. mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
  282. mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
  283. mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
  284. mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
  285. mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
  286. mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
  287. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  288. mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
  289. mindspore/ops/_op_impl/__init__.py +3 -3
  290. mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
  291. mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
  292. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  293. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
  294. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  295. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  296. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
  297. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  298. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  299. mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
  300. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  301. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
  302. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  303. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  304. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  305. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  306. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  307. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  308. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  309. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  310. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  311. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  312. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  313. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  314. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  315. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  316. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  317. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  318. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  319. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  320. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
  321. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
  322. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  323. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  324. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  325. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  326. mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
  327. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  328. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  329. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  330. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  331. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  332. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  333. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  334. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  335. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  336. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  337. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  338. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  339. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  340. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  341. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  342. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  343. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  344. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  345. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  346. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  347. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
  348. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  349. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
  350. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  351. mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
  352. mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
  353. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  354. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  355. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  356. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  357. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  358. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  359. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  360. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
  361. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  362. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  363. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  364. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  365. mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
  366. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  367. mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
  368. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  369. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  370. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  371. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  372. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  373. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  374. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  375. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  376. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  377. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  378. mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
  379. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  380. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  381. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  382. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  383. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  384. mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
  385. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  386. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  387. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  388. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  389. mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
  390. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  391. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  392. mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
  393. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  394. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  395. mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
  396. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  397. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  398. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  399. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  400. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  401. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  402. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  403. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  404. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  405. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  406. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  407. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  408. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  409. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  410. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  411. mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
  412. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  413. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  414. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  415. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  416. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  417. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  418. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  419. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  420. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  421. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  422. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  423. mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
  424. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  425. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  426. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  427. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  428. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  429. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  430. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  431. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  432. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  433. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  434. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  435. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  436. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  437. mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
  438. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  439. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  440. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  441. mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
  442. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
  443. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
  444. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  445. mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
  446. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  447. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  448. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  449. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  450. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  451. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  452. mindspore/ops/_op_impl/cpu/__init__.py +1 -2
  453. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  454. mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
  455. mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
  456. mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
  457. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  458. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  459. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  460. mindspore/ops/_op_impl/tbe/__init__.py +27 -608
  461. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
  462. mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
  463. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  464. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  465. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  466. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
  467. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  468. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  469. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
  470. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
  471. mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
  472. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  473. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
  474. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  475. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  476. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  477. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  478. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  479. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
  480. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
  481. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  482. mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
  483. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
  484. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  485. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  486. mindspore/ops/_op_impl/tbe/greater.py +2 -0
  487. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  488. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
  489. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  490. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  491. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
  492. mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
  493. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
  494. mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
  495. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
  496. mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
  497. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
  498. mindspore/ops/_op_impl/tbe/slice.py +26 -15
  499. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  500. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  501. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
  502. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  503. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
  504. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
  505. mindspore/ops/_primitive_cache.py +3 -2
  506. mindspore/ops/_register_for_op.py +11 -0
  507. mindspore/ops/_utils/__init__.py +1 -1
  508. mindspore/ops/_utils/utils.py +20 -41
  509. mindspore/ops/_vmap/__init__.py +2 -2
  510. mindspore/ops/_vmap/vmap_array_ops.py +170 -78
  511. mindspore/ops/_vmap/vmap_base.py +24 -10
  512. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  513. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  514. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
  515. mindspore/ops/_vmap/vmap_image_ops.py +52 -0
  516. mindspore/ops/_vmap/vmap_math_ops.py +77 -6
  517. mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
  518. mindspore/ops/_vmap/vmap_other_ops.py +3 -1
  519. mindspore/ops/_vmap/vmap_random_ops.py +55 -3
  520. mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
  521. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  522. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  523. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
  524. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
  525. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
  526. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
  527. mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
  528. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  529. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  530. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  531. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
  532. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  533. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
  534. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  536. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
  537. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
  538. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  539. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  541. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  542. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  543. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  544. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  545. mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
  546. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  547. mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
  548. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
  549. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  550. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
  551. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  552. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  553. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
  554. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
  555. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  556. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  559. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
  560. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  561. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
  563. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
  564. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  565. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
  566. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
  567. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
  568. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
  569. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
  571. mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
  572. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  573. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  574. mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
  575. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
  577. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
  578. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
  579. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  580. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  581. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  582. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  583. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  584. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
  585. mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
  586. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
  587. mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
  588. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  589. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
  590. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
  591. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
  592. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  593. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  594. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  595. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  596. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  597. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  598. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  599. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  600. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  601. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  602. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  603. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
  604. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
  605. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
  606. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
  607. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  608. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  609. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  610. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  611. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  612. mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
  613. mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
  614. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  615. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  616. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
  617. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
  618. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
  619. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
  620. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  621. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
  622. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
  623. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
  624. mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
  625. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  626. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  627. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
  628. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
  629. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
  630. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  631. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  632. mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
  633. mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
  634. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  635. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  636. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  637. mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
  638. mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
  639. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  640. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  642. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  643. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  644. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
  645. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
  646. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  647. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  648. mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
  649. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
  650. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
  651. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
  652. mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
  653. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  654. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  655. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
  656. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
  657. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
  658. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  659. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  660. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
  661. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
  662. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
  663. mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
  664. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
  665. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  666. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  667. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
  668. mindspore/ops/bprop_mindir/__init__.py +1 -4
  669. mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
  670. mindspore/ops/composite/__init__.py +12 -13
  671. mindspore/ops/composite/base.py +261 -254
  672. mindspore/ops/composite/env_ops.py +41 -0
  673. mindspore/ops/composite/math_ops.py +197 -156
  674. mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
  675. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
  676. mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
  677. mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
  678. mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
  679. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
  680. mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
  681. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  682. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  683. mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
  684. mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
  685. mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
  686. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
  687. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  688. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
  689. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
  690. mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
  691. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  692. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  693. mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
  694. mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
  695. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
  696. mindspore/ops/function/__init__.py +323 -8
  697. mindspore/ops/function/array_func.py +3511 -780
  698. mindspore/ops/function/clip_func.py +329 -0
  699. mindspore/ops/function/debug_func.py +6 -6
  700. mindspore/ops/function/grad/__init__.py +5 -1
  701. mindspore/ops/function/grad/grad_func.py +736 -65
  702. mindspore/ops/function/image_func.py +270 -0
  703. mindspore/ops/function/linalg_func.py +268 -8
  704. mindspore/ops/function/math_func.py +8032 -3164
  705. mindspore/ops/function/nn_func.py +5619 -1855
  706. mindspore/ops/function/other_func.py +115 -0
  707. mindspore/ops/function/parameter_func.py +11 -10
  708. mindspore/ops/function/random_func.py +939 -77
  709. mindspore/ops/function/sparse_func.py +249 -84
  710. mindspore/ops/function/sparse_unary_func.py +2303 -0
  711. mindspore/ops/function/spectral_func.py +146 -0
  712. mindspore/ops/function/vmap_func.py +114 -0
  713. mindspore/ops/functional.py +182 -254
  714. mindspore/ops/op_info_register.py +79 -34
  715. mindspore/ops/operations/__init__.py +210 -118
  716. mindspore/ops/operations/_csr_ops.py +7 -7
  717. mindspore/ops/operations/_embedding_cache_ops.py +25 -15
  718. mindspore/ops/operations/_grad_ops.py +447 -322
  719. mindspore/ops/operations/_inner_ops.py +547 -176
  720. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  721. mindspore/ops/operations/_ms_kernel.py +29 -27
  722. mindspore/ops/operations/_ocr_ops.py +11 -11
  723. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  724. mindspore/ops/operations/_quant_ops.py +186 -101
  725. mindspore/ops/operations/_rl_inner_ops.py +122 -61
  726. mindspore/ops/operations/_scalar_ops.py +466 -0
  727. mindspore/ops/operations/_sequence_ops.py +1047 -0
  728. mindspore/ops/operations/_tensor_array.py +10 -11
  729. mindspore/ops/operations/_thor_ops.py +4 -4
  730. mindspore/ops/operations/array_ops.py +1428 -1226
  731. mindspore/ops/operations/comm_ops.py +180 -117
  732. mindspore/ops/operations/control_ops.py +4 -2
  733. mindspore/ops/operations/custom_ops.py +185 -98
  734. mindspore/ops/operations/debug_ops.py +92 -54
  735. mindspore/ops/operations/image_ops.py +406 -211
  736. mindspore/ops/operations/inner_ops.py +42 -53
  737. mindspore/ops/operations/linalg_ops.py +32 -29
  738. mindspore/ops/operations/math_ops.py +2076 -897
  739. mindspore/ops/operations/nn_ops.py +1282 -1252
  740. mindspore/ops/operations/other_ops.py +124 -278
  741. mindspore/ops/operations/random_ops.py +345 -178
  742. mindspore/ops/operations/rl_ops.py +8 -9
  743. mindspore/ops/operations/sparse_ops.py +502 -157
  744. mindspore/ops/operations/spectral_ops.py +107 -0
  745. mindspore/ops/primitive.py +192 -15
  746. mindspore/ops/vm_impl_registry.py +23 -2
  747. mindspore/parallel/__init__.py +6 -1
  748. mindspore/parallel/_auto_parallel_context.py +199 -92
  749. mindspore/parallel/_cell_wrapper.py +4 -2
  750. mindspore/parallel/_cost_model_context.py +3 -0
  751. mindspore/parallel/_dp_allreduce_fusion.py +2 -1
  752. mindspore/parallel/_offload_context.py +185 -0
  753. mindspore/parallel/_parallel_serialization.py +167 -28
  754. mindspore/parallel/_ps_context.py +9 -5
  755. mindspore/parallel/_recovery_context.py +1 -1
  756. mindspore/parallel/_tensor.py +9 -1
  757. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  758. mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
  759. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  760. mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
  761. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  762. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
  763. mindspore/parallel/_utils.py +47 -7
  764. mindspore/parallel/algo_parameter_config.py +5 -1
  765. mindspore/parallel/checkpoint_transform.py +329 -0
  766. mindspore/parallel/shard.py +229 -0
  767. mindspore/profiler/__init__.py +2 -1
  768. mindspore/profiler/common/util.py +4 -3
  769. mindspore/profiler/common/validator/validate_path.py +2 -2
  770. mindspore/profiler/envprofiling.py +249 -0
  771. mindspore/profiler/parser/aicpu_data_parser.py +38 -39
  772. mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
  773. mindspore/profiler/parser/base_timeline_generator.py +471 -0
  774. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
  775. mindspore/profiler/parser/framework_parser.py +42 -16
  776. mindspore/profiler/parser/hccl_parser.py +158 -158
  777. mindspore/profiler/parser/hwts_log_parser.py +7 -6
  778. mindspore/profiler/parser/integrator.py +18 -1579
  779. mindspore/profiler/parser/minddata_analyzer.py +8 -8
  780. mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
  781. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  782. mindspore/profiler/parser/optime_parser.py +17 -18
  783. mindspore/profiler/parser/profiler_info.py +108 -0
  784. mindspore/profiler/parser/step_trace_parser.py +1 -1
  785. mindspore/profiler/profiling.py +396 -194
  786. mindspore/rewrite/__init__.py +6 -2
  787. mindspore/rewrite/api/node.py +51 -110
  788. mindspore/rewrite/api/node_type.py +10 -6
  789. mindspore/rewrite/api/pattern_engine.py +51 -7
  790. mindspore/rewrite/api/scoped_value.py +64 -53
  791. mindspore/rewrite/api/symbol_tree.py +108 -61
  792. mindspore/rewrite/api/tree_node_helper.py +2 -3
  793. mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
  794. mindspore/rewrite/ast_helpers/__init__.py +6 -3
  795. mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
  796. mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
  797. mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
  798. mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
  799. mindspore/rewrite/ast_transformers/__init__.py +0 -1
  800. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
  801. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
  802. mindspore/rewrite/common/__init__.py +2 -0
  803. mindspore/rewrite/common/event.py +1 -1
  804. mindspore/rewrite/common/observable.py +1 -1
  805. mindspore/rewrite/common/observer.py +1 -1
  806. mindspore/rewrite/common/rewrite_elog.py +35 -0
  807. mindspore/rewrite/namer.py +2 -2
  808. mindspore/rewrite/namespace.py +14 -4
  809. mindspore/rewrite/node.py +161 -13
  810. mindspore/rewrite/parser.py +0 -1
  811. mindspore/rewrite/parser_register.py +0 -1
  812. mindspore/rewrite/parsers/arguments_parser.py +3 -2
  813. mindspore/rewrite/parsers/assign_parser.py +267 -67
  814. mindspore/rewrite/parsers/attribute_parser.py +56 -0
  815. mindspore/rewrite/parsers/class_def_parser.py +191 -108
  816. mindspore/rewrite/parsers/constant_parser.py +101 -0
  817. mindspore/rewrite/parsers/container_parser.py +88 -0
  818. mindspore/rewrite/parsers/for_parser.py +28 -15
  819. mindspore/rewrite/parsers/function_def_parser.py +21 -5
  820. mindspore/rewrite/parsers/if_parser.py +11 -28
  821. mindspore/rewrite/parsers/module_parser.py +9 -6
  822. mindspore/rewrite/parsers/return_parser.py +3 -2
  823. mindspore/rewrite/sparsify/__init__.py +0 -0
  824. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  825. mindspore/rewrite/sparsify/sparsify.py +109 -0
  826. mindspore/rewrite/sparsify/utils.py +173 -0
  827. mindspore/rewrite/symbol_tree.py +322 -109
  828. mindspore/rewrite/symbol_tree_builder.py +45 -8
  829. mindspore/rewrite/symbol_tree_dumper.py +0 -1
  830. mindspore/rewrite/topological_manager.py +1 -2
  831. mindspore/run_check/_check_version.py +209 -112
  832. mindspore/run_check/run_check.py +2 -1
  833. mindspore/scipy/linalg.py +13 -117
  834. mindspore/scipy/ops.py +5 -71
  835. mindspore/scipy/ops_grad.py +1 -25
  836. mindspore/scipy/ops_wrapper.py +1 -1
  837. mindspore/scipy/optimize/_bfgs.py +1 -1
  838. mindspore/scipy/optimize/_lagrange.py +200 -0
  839. mindspore/scipy/optimize/line_search.py +3 -2
  840. mindspore/scipy/optimize/minimize.py +43 -6
  841. mindspore/scipy/sparse/__init__.py +2 -2
  842. mindspore/scipy/sparse/linalg.py +5 -465
  843. mindspore/scipy/utils.py +2 -1
  844. mindspore/scipy/utils_const.py +7 -1
  845. mindspore/train/__init__.py +6 -4
  846. mindspore/train/_utils.py +28 -5
  847. mindspore/train/amp.py +321 -50
  848. mindspore/train/callback/__init__.py +3 -1
  849. mindspore/train/callback/_backup_and_restore.py +120 -0
  850. mindspore/train/callback/_callback.py +8 -8
  851. mindspore/train/callback/_checkpoint.py +12 -9
  852. mindspore/train/callback/_early_stop.py +13 -7
  853. mindspore/train/callback/_history.py +8 -8
  854. mindspore/train/callback/_lambda_callback.py +6 -6
  855. mindspore/train/callback/_landscape.py +36 -38
  856. mindspore/train/callback/_loss_monitor.py +12 -6
  857. mindspore/train/callback/_lr_scheduler_callback.py +2 -4
  858. mindspore/train/callback/_on_request_exit.py +212 -0
  859. mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
  860. mindspore/train/callback/_summary_collector.py +27 -19
  861. mindspore/train/callback/_time_monitor.py +13 -7
  862. mindspore/train/checkpoint_pb2.py +68 -8
  863. mindspore/train/data_sink.py +122 -33
  864. mindspore/train/dataset_helper.py +28 -87
  865. mindspore/train/loss_scale_manager.py +4 -7
  866. mindspore/{nn → train}/metrics/__init__.py +20 -20
  867. mindspore/{nn → train}/metrics/accuracy.py +12 -10
  868. mindspore/{nn → train}/metrics/auc.py +4 -4
  869. mindspore/{nn → train}/metrics/bleu_score.py +4 -4
  870. mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
  871. mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
  872. mindspore/{nn → train}/metrics/dice.py +6 -5
  873. mindspore/{nn → train}/metrics/error.py +7 -5
  874. mindspore/{nn → train}/metrics/fbeta.py +9 -7
  875. mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
  876. mindspore/{nn → train}/metrics/loss.py +4 -3
  877. mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
  878. mindspore/{nn → train}/metrics/metric.py +6 -5
  879. mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
  880. mindspore/{nn → train}/metrics/perplexity.py +5 -4
  881. mindspore/{nn → train}/metrics/precision.py +5 -4
  882. mindspore/{nn → train}/metrics/recall.py +5 -4
  883. mindspore/{nn → train}/metrics/roc.py +7 -6
  884. mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
  885. mindspore/{nn → train}/metrics/topk.py +7 -5
  886. mindspore/train/mind_ir_pb2.py +339 -32
  887. mindspore/train/model.py +113 -84
  888. mindspore/train/serialization.py +547 -167
  889. mindspore/train/summary/_summary_adapter.py +1 -1
  890. mindspore/train/summary/summary_record.py +43 -12
  891. mindspore/train/train_thor/convert_utils.py +7 -1
  892. mindspore/train/train_thor/dataset_helper.py +3 -3
  893. mindspore/train/train_thor/model_thor.py +0 -4
  894. mindspore/version.py +1 -1
  895. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
  896. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +899 -675
  897. mindspore/compression/common/constant.py +0 -124
  898. mindspore/compression/export/__init__.py +0 -19
  899. mindspore/compression/export/quant_export.py +0 -514
  900. mindspore/compression/quant/qat.py +0 -636
  901. mindspore/compression/quant/quant_utils.py +0 -462
  902. mindspore/compression/quant/quantizer.py +0 -68
  903. mindspore/nn/layer/quant.py +0 -1868
  904. mindspore/nn/layer/rnn_utils.py +0 -90
  905. mindspore/nn/probability/dpn/__init__.py +0 -22
  906. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  907. mindspore/nn/probability/dpn/vae/cvae.py +0 -138
  908. mindspore/nn/probability/dpn/vae/vae.py +0 -122
  909. mindspore/nn/probability/infer/__init__.py +0 -22
  910. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  911. mindspore/nn/probability/infer/variational/svi.py +0 -84
  912. mindspore/nn/probability/toolbox/__init__.py +0 -22
  913. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  914. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
  915. mindspore/nn/probability/transforms/__init__.py +0 -22
  916. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  917. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  918. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  919. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  920. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  921. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  922. mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
  923. mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
  924. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
  925. mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
  926. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
  927. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
  928. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
  929. mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
  930. mindspore/ops/composite/array_ops.py +0 -210
  931. mindspore/ops/composite/clip_ops.py +0 -238
  932. mindspore/ops/composite/random_ops.py +0 -426
  933. mindspore/ops/composite/vmap_ops.py +0 -38
  934. mindspore/ops/operations/sponge_ops.py +0 -3531
  935. mindspore/ops/operations/sponge_update_ops.py +0 -2546
  936. mindspore/parallel/nn/__init__.py +0 -42
  937. mindspore/parallel/nn/loss.py +0 -22
  938. mindspore/parallel/nn/moe.py +0 -21
  939. mindspore/parallel/nn/op_parallel_config.py +0 -22
  940. mindspore/parallel/nn/transformer.py +0 -31
  941. mindspore/run_check/_check_deps_version.py +0 -84
  942. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  943. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  944. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
1
+ # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -30,17 +30,17 @@ import mindspore.common.dtype as mstype
30
30
  from mindspore.ops import operations as P
31
31
  from mindspore.ops import functional as F
32
32
  from mindspore.nn.cell import Cell
33
- from mindspore._checkparam import Validator
33
+ from mindspore import _checkparam as Validator
34
34
  from mindspore import log as logger
35
- from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
35
+ from mindspore.parallel._utils import _get_parallel_mode
36
36
  from mindspore.context import ParallelMode
37
37
  from mindspore.log import _LogActionOnce
38
- from mindspore.nn.transformer.layers import _LayerNorm, _Linear, _check_input_shape, \
38
+ from mindspore.parallel._transformer.layers import _LayerNorm, _Linear, \
39
39
  _args_type_validator_check, _valid_type_checks, _valid_value_checks, \
40
- _check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value
41
- from mindspore.nn.transformer.op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig,\
40
+ _check_past_none_input_none, _check_input_dtype
41
+ from mindspore.parallel._transformer.op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, \
42
42
  _Config, _check_config, MoEParallelConfig
43
- from mindspore.nn.transformer.moe import default_moe_config, MoE, _check_moe_config
43
+ from mindspore.parallel._transformer.moe import default_moe_config, MoE, _check_moe_config
44
44
 
45
45
  __all__ = [
46
46
  "AttentionMask",
@@ -352,9 +352,11 @@ class FeedForward(Cell):
352
352
  hidden_size (int): The dimension of the inputs.
353
353
  ffn_hidden_size (int): The intermediate hidden size.
354
354
  dropout_rate (float): The dropout rate for the second linear's output.
355
- hidden_act (str): The activation of the internal feedforward layer. Supports 'relu',
355
+ hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
356
356
  'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
357
- 'hsigmoid', 'logsigmoid' and so on. Default: gelu.
357
+ 'hsigmoid', 'logsigmoid' and so on. The user can provide a custom activation to the argument.
358
+ If the user wants to run the net in parallel mode, the custom activation must also provide
359
+ the `activation_shard` function. Please see the examples below. Default: gelu.
358
360
  expert_num (int): The number of experts used in Linear. For the case expert_num > 1, BatchMatMul is used
359
361
  and the first dimension in BatchMatMul indicate expert_num. Default: 1.
360
362
  expert_group_size (int): The number of tokens in each data parallel group. Default: None. This parameter is
@@ -375,7 +377,7 @@ class FeedForward(Cell):
375
377
  [batch * seq_length, hidden_size]`.
376
378
 
377
379
  Raises:
378
- ValueError: `hidden_act` is not a string.
380
+ TypeError: `hidden_act` is not a string or nn.Cell.
379
381
  TypeError: `parallel_config` is not a subclass of OpParallelConfig.
380
382
  ValueError: `ffn_hidden_size` is not a multiple of the model parallel way.
381
383
  ValueError: `hidden_size` is not a multiple of the model parallel way.
@@ -387,19 +389,51 @@ class FeedForward(Cell):
387
389
  >>> import numpy as np
388
390
  >>> from mindspore.nn.transformer import FeedForward
389
391
  >>> from mindspore import dtype as mstype
390
- >>> from mindspore import Tensor
392
+ >>> from mindspore import Tensor, nn
393
+ >>> import mindspore.ops as ops
391
394
  >>> model = FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1)
392
395
  >>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
393
396
  >>> output = model(tensor)
394
397
  >>> print(output.shape)
395
398
  (2, 20, 15)
399
+ >>> # Example 2 using custom hidden activation
400
+ >>> class MyActivationNoShard(nn.Cell):
401
+ ... def __init__(self):
402
+ ... super(MyActivationNoShard, self).__init__()
403
+ ... self.add = ops.Add()
404
+ ... def construct(self, x):
405
+ ... return self.add(x, 0.1)
406
+ >>> model = FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1,
407
+ ... hidden_act=MyActivationNoShard)
408
+ >>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
409
+ >>> output = model(tensor)
410
+ >>> print(output.shape)
411
+ (2, 20, 15)
412
+ >>> # Example 3 using custom hidden activation with activation_shard
413
+ >>> # If the user wants to run in the SEMI/AUTO parallel mode, the custom activation must provide
414
+ >>> # a class function named activation_shard. It accepts the argument parallel_config (OpParallelConfig,
415
+ >>> # MoEParallelConfig) and sets the shard for the primitives used in the construct.
416
+ >>> class MyActivationWithShard(nn.Cell):
417
+ ... def __init__(self):
418
+ ... super(MyActivationWithShard, self).__init__()
419
+ ... self.add = ops.Add()
420
+ ... def construct(self, x):
421
+ ... return self.add(x, 0.1)
422
+ ... def activation_shard(self, parallel_config):
423
+ ... self.add.shard(((parallel_config.data_parallel, parallel_config.model_parallel), ()))
424
+ >>>
425
+ >>> model = FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1,
426
+ ... hidden_act=MyActivationWithShard)
427
+ >>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
428
+ >>> output = model(tensor)
429
+ >>> print(output.shape)
430
+ (2, 20, 15)
396
431
  """
397
432
  @_LogActionOnce(logger=logger, key='FeedForward',
398
433
  no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
399
434
  @_args_type_validator_check(hidden_size=Validator.check_positive_int,
400
435
  ffn_hidden_size=Validator.check_positive_int,
401
436
  dropout_rate=Validator.check_non_negative_float,
402
- hidden_act=_valid_type_checks([str], "FeedForward"),
403
437
  param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
404
438
  "FeedForward"),
405
439
  parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig],
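The updated docstring above lets `hidden_act` be either an activation name or an `nn.Cell` subclass, and requires an `activation_shard` method for SEMI/AUTO parallel runs. A hedged sketch of how such an argument could be resolved; the helper below and the `nn.get_activation` lookup are assumptions for illustration, not the cell's actual internals:

import mindspore.nn as nn

def resolve_hidden_act(hidden_act, parallel_config=None):
    # String names map to the built-in activations ('gelu', 'relu', ...).
    if isinstance(hidden_act, str):
        return nn.get_activation(hidden_act)
    # Custom activations are passed as an nn.Cell subclass and instantiated here.
    activation = hidden_act()
    if parallel_config is not None and hasattr(activation, "activation_shard"):
        # In SEMI/AUTO parallel mode the cell shards its own primitives.
        activation.activation_shard(parallel_config)
    return activation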
@@ -413,7 +447,10 @@ class FeedForward(Cell):
413
447
  param_init_type=mstype.float32,
414
448
  parallel_config=default_dpmp_config):
415
449
  super(FeedForward, self).__init__()
416
- if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
450
+ if hidden_act is None or not (isinstance(hidden_act, str) or issubclass(hidden_act, nn.Cell)):
451
+ raise TypeError(f"For FeedForward cell, the hidden_act should str type or nn.Cell type, "
452
+ f"but got {hidden_act}.")
453
+ if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
417
454
  _check_config(parallel_config)
418
455
  mp = parallel_config.model_parallel
419
456
  if expert_num > 1:
@@ -459,9 +496,9 @@ class FeedForward(Cell):
459
496
  else:
460
497
  self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)))
461
498
  self.projection.bias.parallel_optimizer = False
462
- self.dropout = nn.Dropout(1 - dropout_rate)
463
- self.dropout_3d = nn.Dropout(1 - dropout_rate)
464
- self.dropout_4d = nn.Dropout(1 - dropout_rate)
499
+ self.dropout = nn.Dropout(p=dropout_rate)
500
+ self.dropout_3d = nn.Dropout(p=dropout_rate)
501
+ self.dropout_4d = nn.Dropout(p=dropout_rate)
465
502
  self.cast = P.Cast()
466
503
  else:
467
504
  _check_config(parallel_config)
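The dropout lines above switch from the 1.x keep-probability argument `nn.Dropout(1 - dropout_rate)` to the 2.0 drop-probability keyword `nn.Dropout(p=dropout_rate)`. A small sketch of the equivalence using only the public interface:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, dtype as mstype

dropout_rate = 0.1
# 1.10.0 style: the positional argument was the probability of KEEPING an element.
# dropout = nn.Dropout(1 - dropout_rate)
# 2.0.0rc1 style: p is the probability of ZEROING an element, so it equals the rate itself.
dropout = nn.Dropout(p=dropout_rate)
dropout.set_train()                      # dropout only takes effect in training mode
x = Tensor(np.ones((2, 20, 15)), mstype.float32)
print(dropout(x).shape)                  # (2, 20, 15)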
@@ -519,16 +556,18 @@ class FeedForward(Cell):
519
556
  self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)),
520
557
  strategy_bias=((dp, 1), (1,)))
521
558
  self.projection.bias.parallel_optimizer = False
522
- self.dropout = nn.Dropout(1 - dropout_rate)
559
+ self.dropout = nn.Dropout(p=dropout_rate)
523
560
  self.dropout.dropout.shard(((dp, 1),))
524
- self.dropout_3d = nn.Dropout(1 - dropout_rate)
561
+ self.dropout_3d = nn.Dropout(p=dropout_rate)
525
562
  self.dropout_3d.dropout.shard(((dp, 1, 1),))
526
- self.dropout_4d = nn.Dropout(1 - dropout_rate)
563
+ self.dropout_4d = nn.Dropout(p=dropout_rate)
527
564
  self.dropout_4d.dropout.shard(((dp, ep, 1, 1),))
528
565
  self.cast = P.Cast()
566
+ # for grouped pairwise exchange alltoall method in pass
567
+ self.mapping.matmul.add_prim_attr("gpea_label", True)
568
+ self.projection.matmul.add_prim_attr("gpea_label", True)
529
569
 
530
570
  def construct(self, x):
531
- _check_input_shape(F.shape(x), "x", self.cls_name, [2, 3])
532
571
  _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
533
572
  x = self.cast(x, mstype.float16)
534
573
  # returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]
@@ -601,9 +640,7 @@ class AttentionMask(Cell):
601
640
  self.multiply = P.Mul().shard(((parallel_config.data_parallel, 1, 1), (1, 1, 1)))
602
641
 
603
642
  def construct(self, input_mask):
604
- _check_input_shape(F.shape(input_mask), "input_mask", self.cls_name, 2)
605
643
  _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name)
606
- _check_input_shape_value(F.shape(input_mask), 1, "input_mask", self.cls_name, self.seq_length)
607
644
  input_mask = P.Cast()(self.not_equal(input_mask, 0), mstype.float16)
608
645
  input_shape = P.Shape()(input_mask)
609
646
  shape_right = (input_shape[0], 1, input_shape[1])
@@ -698,10 +735,9 @@ class VocabEmbedding(Cell):
698
735
  f"model parallel for the embedding lookup.")
699
736
 
700
737
  def construct(self, input_ids):
701
- _check_input_shape(F.shape(input_ids), "input_ids", self.cls_name, 2)
702
738
  _check_input_dtype(F.dtype(input_ids), "input_ids", [mstype.int32], self.cls_name)
703
739
  output = self.gather(self.embedding_table, input_ids, 0)
704
- return output, self.embedding_table
740
+ return output, self.embedding_table.value()
705
741
 
706
742
 
707
743
  class MultiHeadAttention(Cell):
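The VocabEmbedding hunk above changes the second return value from the `embedding_table` Parameter itself to `embedding_table.value()`, so callers now receive a plain Tensor snapshot of the table. A hedged usage sketch of that behaviour, reusing the shapes from the class docstring:

import numpy as np
from mindspore import Tensor, Parameter, dtype as mstype
from mindspore.nn.transformer import VocabEmbedding

model = VocabEmbedding(vocab_size=30, embedding_size=30)
input_ids = Tensor(np.ones((20, 15)), mstype.int32)
output, table = model(input_ids)
# Under this change `table` is a Tensor copy of the current weights; code that
# needs the trainable Parameter should read model.embedding_table directly.
print(output.shape)                      # (20, 15, 30)
print(isinstance(table, Parameter))      # expected False after this change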
@@ -718,7 +754,9 @@ class MultiHeadAttention(Cell):
718
754
  if query, key and value tensor is same, then it will be self attention.
719
755
 
720
756
  Args:
721
- batch_size(int): The batch size of the input tensor.
757
+ batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
758
+ value. When doing training or prediction, the argument will not take effect and the user can just pass None to
759
+ the argument.
722
760
  src_seq_length(int): The sequence length of the query vector.
723
761
  tgt_seq_length(int): The sequence length of the key and value vector.
724
762
  hidden_size(int): The hidden size of the input.
@@ -751,9 +789,9 @@ class MultiHeadAttention(Cell):
751
789
  - **value_tensor** (Tensor) - The value vector with shape (batch_size, tgt_seq_length, hidden_size) or
752
790
  (batch_size * tgt_seq_length, hidden_size), if the use_past is False or is_first_iteration=True.
753
791
  Otherwise, must be (batch_size, 1, hidden_size)
754
- - **attention_mask** (Tensor) - The attention mask matrix with shape (batch_size, src_seq_length,
755
- tgt_seq_length), if the use_past is False or is_first_iteration=True. Otherwise,
756
- must be (batch_size, 1, tgt_seq_length)
792
+ - **attention_mask** (Tensor) - If the use_past is False or is_first_iteration=True, the attention mask
793
+ matrix should be (batch_size, src_seq_length, tgt_seq_length), or None. None means there will be no mask
794
+ in softmax computation. Otherwise, the mask must be (batch_size, 1, tgt_seq_length)
757
795
  - **key_past** (Tensor) - Float16 tensor with shape (batch_size, num_heads, size_per_head, tgt_seq_length).
758
796
  The past calculated key vector. Used for incremental prediction when the use_past is True.
759
797
  Default None.
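As the rewritten argument descriptions above state, `batch_size` now only matters for incremental (`use_past`) prediction, and `attention_mask` may be None to skip masking entirely. A hedged training/prediction call built from the example values in this docstring:

import numpy as np
from mindspore import Tensor, dtype as mstype
from mindspore.nn.transformer import MultiHeadAttention

# batch_size=None: no incremental prediction, so the value is unused.
model = MultiHeadAttention(batch_size=None, hidden_size=15, src_seq_length=20,
                           tgt_seq_length=20, num_heads=3)
query = Tensor(np.ones((2, 20, 15)), mstype.float32)
key = Tensor(np.ones((2, 20, 15)), mstype.float16)
value = Tensor(np.ones((2, 20, 15)), mstype.float16)
# Passing None as attention_mask means nothing is added before the softmax.
attn_out, layer_present = model(query, key, value, None)
print(attn_out.shape)                    # (2, 20, 15)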
@@ -783,7 +821,7 @@ class MultiHeadAttention(Cell):
783
821
  >>> from mindspore.nn.transformer import MultiHeadAttention
784
822
  >>> from mindspore import dtype as mstype
785
823
  >>> from mindspore import Tensor
786
- >>> model = MultiHeadAttention(batch_size=2, hidden_size=15, src_seq_length=20, tgt_seq_length=20,
824
+ >>> model = MultiHeadAttention(batch_size=None, hidden_size=15, src_seq_length=20, tgt_seq_length=20,
787
825
  ... num_heads=3)
788
826
  >>> from_tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
789
827
  >>> to_tensor = Tensor(np.ones((2, 20, 15)), mstype.float16)
@@ -830,8 +868,7 @@ class MultiHeadAttention(Cell):
830
868
  """
831
869
  @_LogActionOnce(logger=logger, key='MultiHeadAttention',
832
870
  no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
833
- @_args_type_validator_check(batch_size=Validator.check_positive_int,
834
- hidden_size=Validator.check_positive_int,
871
+ @_args_type_validator_check(hidden_size=Validator.check_positive_int,
835
872
  num_heads=Validator.check_positive_int,
836
873
  src_seq_length=Validator.check_positive_int,
837
874
  tgt_seq_length=Validator.check_positive_int,
@@ -860,10 +897,13 @@ class MultiHeadAttention(Cell):
860
897
  parallel_config=default_dpmp_config):
861
898
  super(MultiHeadAttention, self).__init__()
862
899
  self._is_ascend = context.get_context('device_target') in ["Ascend"]
863
- if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
900
+ self.dp = parallel_config.data_parallel
901
+ self.is_parallel_mode = _get_parallel_mode() in (
902
+ ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
903
+ if batch_size:
904
+ Validator.check_positive_int(batch_size)
905
+ if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
864
906
  _check_config(parallel_config)
865
- self.is_parallel_mode = _get_parallel_mode() in (
866
- ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
867
907
  self.src_seq_length = src_seq_length
868
908
  self.tgt_seq_length = tgt_seq_length
869
909
  self.hidden_size = hidden_size
@@ -883,11 +923,6 @@ class MultiHeadAttention(Cell):
883
923
  "'parallel_config.model_parallel', but got the num_heads is {} "
884
924
  "and the parallel_config.model_parallel is {}."
885
925
  .format(num_heads, parallel_config.model_parallel))
886
- if self.is_parallel_mode and batch_size % parallel_config.data_parallel != 0:
887
- raise ValueError("For 'MultiHeadAttention', the class variable 'batch_size' must be a multiple of "
888
- "'parallel_config.data_parallel', but got the batch_size is {} "
889
- "and the parallel_config.data_parallel is {}."
890
- .format(batch_size, parallel_config.data_parallel))
891
926
  self.is_first_iteration = True
892
927
  # Output layer
893
928
  self.projection = _Linear(in_channels=hidden_size,
@@ -918,8 +953,8 @@ class MultiHeadAttention(Cell):
918
953
  # Normalize factor for attention, sqrt(dk) as widely used
919
954
  self.scale_factor = Tensor(math.sqrt(math.sqrt(self.size_per_head)))
920
955
  self.use_past = use_past
921
- self.dropout = nn.Dropout(1 - hidden_dropout_rate)
922
- self.prob_dropout = nn.Dropout(1 - attention_dropout_rate)
956
+ self.dropout = nn.Dropout(p=hidden_dropout_rate)
957
+ self.prob_dropout = nn.Dropout(p=attention_dropout_rate)
923
958
  self.softmax = nn.Softmax().to_float(softmax_compute_type)
924
959
  self.softmax_3d = nn.Softmax().to_float(softmax_compute_type)
925
960
  self.expand_dims = P.ExpandDims()
@@ -961,8 +996,6 @@ class MultiHeadAttention(Cell):
961
996
  self.mul1 = P.Mul().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
962
997
  else:
963
998
  _check_config(parallel_config)
964
- self.is_parallel_mode = _get_parallel_mode() in (
965
- ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
966
999
  self.src_seq_length = src_seq_length
967
1000
  self.tgt_seq_length = tgt_seq_length
968
1001
  self.hidden_size = hidden_size
@@ -982,11 +1015,6 @@ class MultiHeadAttention(Cell):
982
1015
  "'parallel_config.model_parallel', but got the num_heads is {} "
983
1016
  "and the parallel_config.model_parallel is {}."
984
1017
  .format(num_heads, parallel_config.model_parallel))
985
- if self.is_parallel_mode and batch_size % parallel_config.data_parallel != 0:
986
- raise ValueError("For 'MultiHeadAttention', the class variable 'batch_size' must be a multiple of "
987
- "'parallel_config.data_parallel', but got the batch_size is {} "
988
- "and the parallel_config.data_parallel is {}."
989
- .format(batch_size, parallel_config.data_parallel))
990
1018
  self.is_first_iteration = True
991
1019
  # Output layer
992
1020
  self.projection = _Linear(in_channels=hidden_size,
@@ -1026,9 +1054,9 @@ class MultiHeadAttention(Cell):
1026
1054
  # Normalize factor for attention, sqrt(dk) as widely used
1027
1055
  self.scale_factor = Tensor(math.sqrt(math.sqrt(self.size_per_head)))
1028
1056
  self.use_past = use_past
1029
- self.dropout = nn.Dropout(1 - hidden_dropout_rate)
1057
+ self.dropout = nn.Dropout(p=hidden_dropout_rate)
1030
1058
  self.dropout.dropout.shard(((parallel_config.data_parallel, 1),))
1031
- self.prob_dropout = nn.Dropout(1 - attention_dropout_rate)
1059
+ self.prob_dropout = nn.Dropout(p=attention_dropout_rate)
1032
1060
  self.prob_dropout.dropout.shard(
1033
1061
  ((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),))
1034
1062
  self.softmax = nn.Softmax().to_float(softmax_compute_type)
@@ -1086,10 +1114,12 @@ class MultiHeadAttention(Cell):
1086
1114
  value_past=None, batch_valid_length=None):
1087
1115
  self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past,
1088
1116
  value_past, batch_valid_length)
1089
- query_tensor, key_tensor, value_tensor, batch_size, ori_shape = self._convert_to_2d_tensor(query_tensor,
1090
- key_tensor,
1091
- value_tensor,
1092
- attention_mask)
1117
+ ori_shape = F.shape(query_tensor)
1118
+ batch_size = self._get_batch_size_from_query(query_tensor)
1119
+ query_tensor, key_tensor, value_tensor = self._convert_to_2d_tensor(query_tensor,
1120
+ key_tensor,
1121
+ value_tensor,
1122
+ attention_mask)
1093
1123
  ori_dtype = F.dtype(query_tensor)
1094
1124
  query_tensor = F.cast(query_tensor, self.dtype)
1095
1125
  key_tensor = F.cast(key_tensor, self.dtype)
@@ -1102,21 +1132,24 @@ class MultiHeadAttention(Cell):
1102
1132
  query = self.transpose(
1103
1133
  F.reshape(
1104
1134
  query,
1105
- (batch_size, -1, self.n_head, self.size_per_head)),
1135
+ (batch_size, self._get_seq_length_under_incremental(self.src_seq_length),
1136
+ self.n_head, self.size_per_head)),
1106
1137
  (0, 2, 1, 3))
1107
1138
  # the returned shape is [bs, size_per_head, seq_length, num_heads]
1108
1139
  key = self.transpose(
1109
1140
  F.reshape(
1110
- key, (batch_size, -1, self.n_head, self.size_per_head)),
1141
+ key, (batch_size, self._get_seq_length_under_incremental(self.tgt_seq_length),
1142
+ self.n_head, self.size_per_head)),
1111
1143
  (0, 2, 3, 1))
1112
1144
  # the returned shape is [bs, num_heads, seq_length, size_per_head]
1113
1145
  value = self.transpose(
1114
1146
  F.reshape(
1115
1147
  value,
1116
- (batch_size, -1, self.n_head, self.size_per_head)),
1148
+ (batch_size, self._get_seq_length_under_incremental(self.tgt_seq_length),
1149
+ self.n_head, self.size_per_head)),
1117
1150
  (0, 2, 1, 3))
1118
1151
  # support input shape is [bs, seq, seq] or [bs, heads, seq, seq]
1119
- if len(F.shape(attention_mask)) == 3:
1152
+ if attention_mask is not None and len(F.shape(attention_mask)) == 3:
1120
1153
  # expand attention mask from [bs, seq, seq] -> [bs, 1, seq, seq]
1121
1154
  attention_mask = self.expand_dims(attention_mask, 1)
1122
1155
  # key and value for current token(s)
@@ -1167,35 +1200,30 @@ class MultiHeadAttention(Cell):
1167
1200
  output = F.cast(output, ori_dtype)
1168
1201
  return output, layer_present
1169
1202
 
1203
+ def _get_batch_size_from_query(self, query):
1204
+ r"""Get the batch size from query tensor"""
1205
+ # For the incremental prediction, the seq length for the input is 1.
1206
+ if len(F.shape(query)) == 2 and ((self.use_past and self.is_first_iteration) or (not self.use_past)):
1207
+ return F.shape(query)[0] // self.src_seq_length
1208
+ return F.shape(query)[0]
1209
+
1210
+ def _get_seq_length_under_incremental(self, length):
1211
+ r"""Return the length of the tensor.
1212
+ For the incremental prediction, the seq length for the input is 1.
1213
+ """
1214
+ if self.use_past and not self.is_first_iteration:
1215
+ return 1
1216
+ return length
1217
+
1170
1218
  def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
1171
1219
  value_past=None, batch_valid_length=None):
1172
1220
  r"""Check inputs"""
1173
- if not self.use_past or (self.use_past and self.is_first_iteration):
1174
- _check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
1175
- [[self.batch_size, self.src_seq_length, self.hidden_size],
1176
- [self.batch_size * self.src_seq_length, self.hidden_size]])
1177
- _check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
1178
- [[self.batch_size, self.tgt_seq_length, self.hidden_size],
1179
- [self.batch_size * self.tgt_seq_length, self.hidden_size]])
1180
- _check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
1181
- [[self.batch_size, self.tgt_seq_length, self.hidden_size],
1182
- [self.batch_size * self.tgt_seq_length, self.hidden_size]])
1183
- _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
1184
- [self.batch_size, self.src_seq_length, self.tgt_seq_length])
1185
- else:
1186
- _check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
1187
- [[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]])
1188
- _check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
1189
- [[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]])
1190
- _check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
1191
- [[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]])
1192
- _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
1193
- [[self.batch_size, 1, self.tgt_seq_length], [self.batch_size, self.hidden_size]])
1194
-
1195
1221
  _check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name)
1196
1222
  _check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name)
1197
1223
  _check_input_dtype(F.dtype(value_tensor), "value_tensor", [mstype.float32, mstype.float16], self.cls_name)
1198
- _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
1224
+ if attention_mask is not None:
1225
+ _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16],
1226
+ self.cls_name)
1199
1227
 
1200
1228
  key_is_tensor = isinstance(key_past, Tensor)
1201
1229
  value_is_tensor = isinstance(value_past, Tensor)
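The new `_get_batch_size_from_query` helper above replaces the deleted shape checks: for a 2-D `[batch * seq_length, hidden]` input outside incremental steps it divides the leading dimension by `src_seq_length`, otherwise the leading dimension already is the batch. A plain-Python sketch of that rule (names local to this example):

def infer_batch_size(query_shape, src_seq_length, use_past, is_first_iteration):
    # Training or the first incremental step may pass a flattened 2-D tensor.
    if len(query_shape) == 2 and (not use_past or is_first_iteration):
        return query_shape[0] // src_seq_length
    # Later incremental steps keep the batch as the leading dimension.
    return query_shape[0]

print(infer_batch_size((2 * 20, 15), src_seq_length=20, use_past=False, is_first_iteration=True))  # 2
print(infer_batch_size((2, 1, 15), src_seq_length=20, use_past=True, is_first_iteration=False))    # 2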
@@ -1210,13 +1238,8 @@ class MultiHeadAttention(Cell):
1210
1238
  _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None,
1211
1239
  batch_valid_length_is_tensor, batch_is_default)
1212
1240
  if self.use_past:
1213
- _check_shape_equal(F.shape(key_past), "key_past", self.cls_name,
1214
- [self.batch_size, self.n_head, self.size_per_head, self.tgt_seq_length])
1215
1241
  _check_input_dtype(F.dtype(key_past), "key_past", [mstype.float16], self.cls_name)
1216
- _check_shape_equal(F.shape(value_past), "value_past", self.cls_name,
1217
- [self.batch_size, self.n_head, self.tgt_seq_length, self.size_per_head])
1218
1242
  _check_input_dtype(F.dtype(value_past), "value_past", [mstype.float16], self.cls_name)
1219
- _check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size])
1220
1243
  _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
1221
1244
  return True
1222
1245
 
@@ -1228,7 +1251,8 @@ class MultiHeadAttention(Cell):
1228
1251
  key_tensor = F.reshape(key_tensor, (-1, key_shape[-1]))
1229
1252
  value_shape = F.shape(value_tensor)
1230
1253
  value_tensor = F.reshape(value_tensor, (-1, value_shape[-1]))
1231
- return query_tensor, key_tensor, value_tensor, F.shape(attention_mask)[0], query_shape
1254
+
1255
+ return query_tensor, key_tensor, value_tensor
1232
1256
 
1233
1257
  def _merge_heads(self, x):
1234
1258
  """
@@ -1286,30 +1310,31 @@ class MultiHeadAttention(Cell):
1286
1310
  score = self.batch_matmul(query, key)
1287
1311
 
1288
1312
  ori_dtype = P.DType()(score)
1289
- score = P.Cast()(score, self.softmax_dtype)
1313
+ attention_scores = P.Cast()(score, self.softmax_dtype)
1290
1314
 
1291
1315
  # for input size of (bs, 1) namely the second graph,
1292
1316
  # the shape of attention_mask matrix should be (bs, 1, 1, seq_length)
1293
- if self.use_past and not self.is_first_iteration:
1294
- # Calculate the current total token
1295
- current_index = self.reducesum(F.cast(self.not_equal(self.slice(key, (0, 0, 0, 0),
1296
- (F.shape(query)[0], 1, 1, self.seq_length),
1297
- (1, 1, 1, 1)),
1298
- 0), mstype.float32), (1, 2, 3))
1299
- # Get the precise position index
1300
- index = self.sub1(F.cast(current_index, mstype.int32), 1)
1301
- index = F.reshape(index, (-1, 1, 1))
1302
- # Calculate the attention_mask matrix via the position index
1303
- attention_mask = F.cast(self.tensor_le(self.range, index), mstype.int32)
1304
- attention_mask = self.expand_dims(attention_mask, 2)
1305
-
1306
- # Minus 10000 for the position where masked to exclude them from softmax
1307
- multiplu_out = self.sub(
1308
- P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)),
1309
- P.Cast()(attention_mask, P.DType()(score)))
1310
-
1311
- adder = self.mul(multiplu_out, self.multiply_data)
1312
- attention_scores = self.add(adder, score)
1317
+ if attention_mask is not None:
1318
+ if self.use_past and not self.is_first_iteration:
1319
+ # Calculate the current total token
1320
+ current_index = self.reducesum(F.cast(self.not_equal(self.slice(key, (0, 0, 0, 0),
1321
+ (F.shape(query)[0], 1, 1,
1322
+ self.seq_length),
1323
+ (1, 1, 1, 1)),
1324
+ 0), mstype.float32), (1, 2, 3))
1325
+ # Get the precise position index
1326
+ index = self.sub1(F.cast(current_index, mstype.int32), 1)
1327
+ index = F.reshape(index, (-1, 1, 1))
1328
+ # Calculate the attention_mask matrix via the position index
1329
+ attention_mask = F.cast(self.tensor_le(self.range, index), mstype.int32)
1330
+ attention_mask = self.expand_dims(attention_mask, 2)
1331
+ # Minus 10000 for the position where masked to exclude them from softmax
1332
+ multiplu_out = self.sub(
1333
+ P.Cast()(F.tuple_to_array((1.0,)), P.DType()(attention_scores)),
1334
+ P.Cast()(attention_mask, P.DType()(attention_scores)))
1335
+
1336
+ adder = self.mul(multiplu_out, self.multiply_data)
1337
+ attention_scores = self.add(adder, attention_scores)
1313
1338
 
1314
1339
  # attention probs
1315
1340
  attention_probs = self._softmax(attention_scores)
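The `_attn` hunk above only builds the additive mask when `attention_mask` is not None; masked positions receive a large negative bias (roughly `(1 - mask) * -10000`) before the softmax, which drives their probabilities towards zero. A small numpy sketch of that masking rule (values illustrative only):

import numpy as np

def masked_softmax(scores, mask=None, multiply_data=-10000.0):
    # mask == 1 keeps a position, mask == 0 effectively removes it.
    if mask is not None:
        scores = scores + (1.0 - mask) * multiply_data
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

scores = np.array([[1.0, 2.0, 3.0]])
mask = np.array([[1.0, 1.0, 0.0]])                 # hide the last position
print(masked_softmax(scores, mask).round(3))       # last probability ~0.0
print(masked_softmax(scores).round(3))             # plain softmax when mask is None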
@@ -1328,7 +1353,9 @@ class TransformerEncoderLayer(Cell):
1328
1353
  encoder layer, including multihead attention and feedforward layer.
1329
1354
 
1330
1355
  Args:
1331
- batch_size(int): The batch size of the input tensor.
1356
+ batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
1357
+ value. When doing training or prediction, the argument will not take effect and the user can just pass None to
1358
+ the argument.
1332
1359
  hidden_size(int): The hidden size of the input.
1333
1360
  ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
1334
1361
  num_heads(int): The number of the heads.
@@ -1342,9 +1369,12 @@ class TransformerEncoderLayer(Cell):
1342
1369
  Should be mstype.float32 or mstype.float16. Default mstype.float32.
1343
1370
  param_init_type(dtype.Number): The parameter initialization type of the module.
1344
1371
  Should be mstype.float32 or mstype.float16. Default mstype.float32.
1345
- hidden_act(str): The activation of the internal feedforward layer. Supports 'relu',
1372
+ hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
1346
1373
  'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
1347
- 'hsigmoid', 'logsigmoid' and so on. Default: gelu.
1374
+ 'hsigmoid', 'logsigmoid' and so on. The user can provide a custom activation to the argument.
1375
+ If the user wants to run the net in parallel mode, the custom activation must also provide
1376
+ the `activation_shard` function. Please see the examples of the
1377
+ :class:`mindspore.nn.transformer.FeedForward`. Default: gelu.
1348
1378
  use_past(bool): Use the past state to compute, used for incremental prediction. For example, if we have two
1349
1379
  words and want to generate the ten more words. We just need to compute the two words' state only once,
1350
1380
  and generate the next word one by one. When use_past is True, there are two steps to run the prediction.
@@ -1362,8 +1392,9 @@ class TransformerEncoderLayer(Cell):
1362
1392
  - **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size] or
1363
1393
  [batch_size * seq_length, hidden_size], if the use_past is False or is_first_iteration=True. Otherwise,
1364
1394
  should be [batch_size, 1, hidden_size]
1365
- - **input_mask** (Tensor) - Float Tensor, attention mask with shape [batch_size, seq_length, seq_length],
1366
- if the use_past is False or is_first_iteration=True. Otherwise, should be [batch_size, 1, hidden_size]
1395
+ - **input_mask** (Tensor) - Float Tensor. If the use_past is False or is_first_iteration=True,
1396
+ the attention mask matrix should be [batch_size, seq_length, seq_length], or None. None means there will
1397
+ be no mask in softmax computation. Otherwise, should be [batch_size, 1, hidden_size]
1367
1398
  - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
1368
1399
  past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
1369
1400
  - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index.
@@ -1430,14 +1461,12 @@ class TransformerEncoderLayer(Cell):
1430
1461
  """
1431
1462
  @_LogActionOnce(logger=logger, key='TransformerEncoderLayer',
1432
1463
  no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
1433
- @_args_type_validator_check(batch_size=Validator.check_positive_int,
1434
- hidden_size=Validator.check_positive_int,
1464
+ @_args_type_validator_check(hidden_size=Validator.check_positive_int,
1435
1465
  num_heads=Validator.check_positive_int,
1436
1466
  ffn_hidden_size=Validator.check_positive_int,
1437
1467
  seq_length=Validator.check_positive_int,
1438
1468
  attention_dropout_rate=Validator.check_non_negative_float,
1439
1469
  hidden_dropout_rate=Validator.check_non_negative_float,
1440
- hidden_act=_valid_type_checks([str], "TransformerEncoderLayer"),
1441
1470
  post_layernorm_residual=Validator.check_bool,
1442
1471
  layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
1443
1472
  "TransformerEncoderLayer"),
@@ -1465,7 +1494,10 @@ class TransformerEncoderLayer(Cell):
1465
1494
  moe_config=default_moe_config,
1466
1495
  parallel_config=default_dpmp_config):
1467
1496
  super(TransformerEncoderLayer, self).__init__()
1468
- if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
1497
+ if batch_size or use_past:
1498
+ Validator.check_positive_int(batch_size)
1499
+ self.batch_size = batch_size
1500
+ if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
1469
1501
  _check_config(parallel_config)
1470
1502
  if num_heads % parallel_config.model_parallel != 0:
1471
1503
  raise ValueError(
@@ -1488,7 +1520,6 @@ class TransformerEncoderLayer(Cell):
1488
1520
  self.use_past = use_past
1489
1521
  self.seq_length = seq_length
1490
1522
  self.hidden_size = hidden_size
1491
- self.batch_size = batch_size
1492
1523
  self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
1493
1524
  self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
1494
1525
 
@@ -1564,7 +1595,6 @@ class TransformerEncoderLayer(Cell):
1564
1595
  self.use_past = use_past
1565
1596
  self.seq_length = seq_length
1566
1597
  self.hidden_size = hidden_size
1567
- self.batch_size = batch_size
1568
1598
  self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
1569
1599
  self.layernorm1.shard(((parallel_config.data_parallel, 1),))
1570
1600
  self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
@@ -1623,11 +1653,14 @@ class TransformerEncoderLayer(Cell):
1623
1653
  raise RuntimeError(f"The {self.cls_name} only support sharding propagation or "
1624
1654
  f"semi-auto parallel mode now.")
1625
1655
 
1626
- def construct(self, x, input_mask, init_reset=True, batch_valid_length=None):
1656
+ def construct(self, x, input_mask=None, init_reset=True, batch_valid_length=None):
1627
1657
  self._check_input(x, input_mask, init_reset, batch_valid_length)
1628
1658
  x_shape = F.shape(x)
1629
1659
  x = F.reshape(x, (-1, x_shape[-1]))
1630
- input_x = self.layernorm1(x)
1660
+ if self.post_layernorm_residual:
1661
+ input_x = x
1662
+ else:
1663
+ input_x = self.layernorm1(x)
1631
1664
  input_x = F.cast(input_x, self.dtype)
1632
1665
 
1633
1666
  # indicate whether reset saved states
@@ -1636,8 +1669,10 @@ class TransformerEncoderLayer(Cell):
1636
1669
 
1637
1670
  if self.use_past:
1638
1671
  # reset states, init_reset True for reuse and False for reset
1639
- key_reset = self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype)))
1640
- value_reset = self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype)))
1672
+ self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype)))
1673
+ key_reset = self.key_past
1674
+ self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype)))
1675
+ value_reset = self.value_past
1641
1676
  # add dependency for desired execution order
1642
1677
  input_x = F.depend(input_x, key_reset)
1643
1678
  input_x = F.depend(input_x, value_reset)
@@ -1665,8 +1700,10 @@ class TransformerEncoderLayer(Cell):
1665
1700
  # current key and value
1666
1701
  key_present, value_present = layer_present
1667
1702
  # update key and value calculated this step
1668
- key_update = self.assign(self.key_past, key_present)
1669
- value_update = self.assign(self.value_past, value_present)
1703
+ self.assign(self.key_past, key_present)
1704
+ key_update = self.key_past
1705
+ self.assign(self.value_past, value_present)
1706
+ value_update = self.value_past
1670
1707
  # add dependency for desired execution order
1671
1708
  key_update = F.depend(key_update, key_reset)
1672
1709
  value_update = F.depend(value_update, value_reset)
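In the two hunks above the return value of `Assign` is no longer used directly: the code writes the parameter in place, reads it back, and keeps `F.depend` to pin the write-before-read order in graph mode. A minimal, hedged illustration of that pattern with a toy cell (not code from this diff):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter, ops, dtype as mstype
from mindspore.ops import functional as F

class CachedState(nn.Cell):
    def __init__(self):
        super(CachedState, self).__init__()
        self.state = Parameter(Tensor(np.zeros((2, 2)), mstype.float32), name="state")
        self.assign = ops.Assign()

    def construct(self, new_value, x):
        # write first, then read the parameter back instead of using Assign's output
        self.assign(self.state, new_value)
        snapshot = self.state
        # depend guarantees the assign above executes before x is consumed
        x = F.depend(x, snapshot)
        return x + snapshot

net = CachedState()
out = net(Tensor(np.ones((2, 2)), mstype.float32), Tensor(np.ones((2, 2)), mstype.float32))
print(out)                               # a 2x2 tensor of 2.0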
@@ -1683,11 +1720,15 @@ class TransformerEncoderLayer(Cell):
1683
1720
 
1684
1721
  if self.post_layernorm_residual:
1685
1722
  output = self.add_3d(output_x, mlp_logit)
1723
+ output = F.reshape(output, (-1, x_shape[-1]))
1724
+ output = self.layernorm1(output)
1725
+ output = F.reshape(output, x_shape)
1686
1726
  else:
1687
1727
  output = self.add_3d(x, mlp_logit)
1688
1728
  else:
1689
1729
  if self.post_layernorm_residual:
1690
1730
  output = self.add(output_x, mlp_logit)
1731
+ output = self.layernorm1(output)
1691
1732
  else:
1692
1733
  output = self.add(x, mlp_logit)
1693
1734
  output = F.reshape(output, x_shape)
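The construct changes above move the first layer normalization when `post_layernorm_residual=True`: the raw input now feeds the sub-layers and `layernorm1` is applied to the residual sum afterwards, while the default path keeps the pre-layernorm ordering. A schematic sketch of the two orderings; the callables below are placeholders, not the real cells:

def residual_block(x, layernorm1, sublayer, post_layernorm_residual):
    # entry: post-LN feeds the raw input to the sub-layer, pre-LN normalizes first
    input_x = x if post_layernorm_residual else layernorm1(x)
    sub_out = sublayer(input_x)
    if post_layernorm_residual:
        # exit: the residual sum is normalized after the addition
        return layernorm1(input_x + sub_out)
    return x + sub_out

identity = lambda t: t
print(residual_block(1.0, identity, identity, post_layernorm_residual=True))   # 2.0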
@@ -1698,18 +1739,9 @@ class TransformerEncoderLayer(Cell):
1698
1739
 
1699
1740
  def _check_input(self, x, input_mask, init_reset, batch_valid_length):
1700
1741
  r"""Check inputs"""
1701
- if not self.use_past or (self.use_past and self.is_first_iteration):
1702
- _check_shape_equal(F.shape(x), "x", self.cls_name,
1703
- [[self.batch_size, self.seq_length, self.hidden_size],
1704
- [self.batch_size * self.seq_length, self.hidden_size]])
1705
- _check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
1706
- [self.batch_size, self.seq_length, self.seq_length])
1707
- else:
1708
- _check_shape_equal(F.shape(x), "x", self.cls_name, [self.batch_size, 1, self.hidden_size])
1709
- _check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
1710
- [self.batch_size, 1, self.seq_length])
1711
1742
  _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
1712
- _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name)
1743
+ if input_mask is not None:
1744
+ _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name)
1713
1745
 
1714
1746
  init_reset_is_tensor = isinstance(init_reset, Tensor)
1715
1747
  init_reset_is_default = init_reset is True
@@ -1721,9 +1753,7 @@ class TransformerEncoderLayer(Cell):
1721
1753
  batch_valid_length_is_tensor, batch_is_default)
1722
1754
 
1723
1755
  if self.use_past:
1724
- _check_shape_equal(F.shape(init_reset), "init_reset", self.cls_name, [1])
1725
1756
  _check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name)
1726
- _check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size])
1727
1757
  _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
1728
1758
  return True
1729
1759
 
@@ -1738,7 +1768,9 @@ class TransformerDecoderLayer(Cell):
1738
1768
  hidden_size(int): The hidden size of the input.
1739
1769
  ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
1740
1770
  num_heads(int): The number of the heads.
1741
- batch_size(int): The batch size of the input tensor.
1771
+ batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
1772
+ value. When doing training or prediction, the argument will not take effect and the user can just pass None to
1773
+ the argument.
1742
1774
  src_seq_length(int): The input source sequence length.
1743
1775
  tgt_seq_length(int): The input target sequence length.
1744
1776
  attention_dropout_rate(float): The dropout rate of the attention scores. Default:0.1.
@@ -1751,9 +1783,12 @@ class TransformerDecoderLayer(Cell):
1751
1783
  Should be dtype.float32 or dtype.float16. Default mstype.float32.
1752
1784
  param_init_type(dtype.Number): The parameter initialization type of the module.
1753
1785
  Should be dtype.float32 or dtype.float16. Default dtype.float32.
1754
- hidden_act(str): The activation of the internal feedforward layer. Supports 'relu',
1786
+ hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
1755
1787
  'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
1756
- 'hsigmoid', 'logsigmoid' and so on. Default: gelu.
1788
+ 'hsigmoid', 'logsigmoid' and so on. The user can provide a custom activation to the argument.
1789
+ If the user wants to run the net in parallel mode, the custom activation must also provide
1790
+ the `activation_shard` function. Please see the examples of the
1791
+ :class:`mindspore.nn.transformer.FeedForward`. Default: gelu.
1757
1792
  moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig
1758
1793
  with default values. Please see `MoEConfig`.
1759
1794
  parallel_config(OpParallelConfig, MoEParallelConfig): The parallel configure. When MoE is applied,
@@ -1764,13 +1799,13 @@ class TransformerDecoderLayer(Cell):
1764
1799
  - **hidden_stats** (Tensor) - The input tensor with shape [batch_size, tgt_seq_length, hidden_size] or
1765
1800
  [batch_size * tgt_seq_length, hidden_size].
1766
1801
  - **decoder_mask** (Tensor) - The attention mask for decoder with shape [batch_size, src_seq_length,
1767
- seq_length].
1802
+ seq_length] or None. None means there will be no mask in softmax computation in self attention.
1768
1803
  - **encoder_output** (Tensor) - The output of the encoder with shape [batch_size, seq_length, hidden_size]
1769
1804
  or [batch_size * seq_length, hidden_size].
1770
1805
  Note this args can not be passed by None when the net is in outermost layer. Default None.
1771
1806
  - **memory_mask** (Tensor) - The memory mask of the cross attention with shape [batch, tgt_seq_length,
1772
- src_seq_length] where tgt_seq_length is the length of the decoder. Note this args can not be passed by
1773
- None when the net is in outermost layer. Default None.
1807
+ src_seq_length] where tgt_seq_length is the length of the decoder. The user can also pass None. None
1808
+ means there will be no mask in softmax computation in cross attention. Default None.
1774
1809
  - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
1775
1810
  past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
1776
1811
  - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index.
@@ -1815,15 +1850,13 @@ class TransformerDecoderLayer(Cell):
1815
1850
  """
1816
1851
  @_LogActionOnce(logger=logger, key='TransformerDecoderLayer',
1817
1852
  no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
1818
- @_args_type_validator_check(batch_size=Validator.check_positive_int,
1819
- hidden_size=Validator.check_positive_int,
1853
+ @_args_type_validator_check(hidden_size=Validator.check_positive_int,
1820
1854
  num_heads=Validator.check_positive_int,
1821
1855
  ffn_hidden_size=Validator.check_positive_int,
1822
1856
  src_seq_length=Validator.check_positive_int,
1823
1857
  tgt_seq_length=Validator.check_positive_int,
1824
1858
  attention_dropout_rate=Validator.check_non_negative_float,
1825
1859
  hidden_dropout_rate=Validator.check_non_negative_float,
1826
- hidden_act=_valid_type_checks([str], "TransformerDecoderLayer"),
1827
1860
  post_layernorm_residual=Validator.check_bool,
1828
1861
  layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
1829
1862
  "TransformerDecoderLayer"),
@@ -1854,7 +1887,9 @@ class TransformerDecoderLayer(Cell):
1854
1887
  _check_moe_config(moe_config, parallel_config)
1855
1888
  self.use_moe = (moe_config.expert_num > 1)
1856
1889
  config_to_attention = parallel_config.dpmp if self.use_moe else parallel_config
1857
- if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
1890
+ if batch_size or use_past:
1891
+ Validator.check_positive_int(batch_size)
1892
+ if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
1858
1893
  _check_config(parallel_config)
1859
1894
  if num_heads % parallel_config.model_parallel != 0:
1860
1895
  raise ValueError("For 'TransformerDecoderLayer', the class variable 'num_heads' must be divisibled by "
@@ -2066,8 +2101,10 @@ class TransformerDecoderLayer(Cell):
2066
2101
  value_reset = None
2067
2102
  if self.use_past:
2068
2103
  # reset states, init_reset True for reuse and False for reset
2069
- key_reset = self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype)))
2070
- value_reset = self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype)))
2104
+ self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype)))
2105
+ key_reset = self.key_past
2106
+ self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype)))
2107
+ value_reset = self.value_past
2071
2108
  # add dependency for desired execution order
2072
2109
  input_x = F.depend(input_x, key_reset)
2073
2110
  input_x = F.depend(input_x, value_reset)
@@ -2110,8 +2147,10 @@ class TransformerDecoderLayer(Cell):
2110
2147
  # current key and value
2111
2148
  key_present, value_present = layer_present
2112
2149
  # update key and value calculated this step
2113
- key_update = self.assign(self.key_past, key_present)
2114
- value_update = self.assign(self.value_past, value_present)
2150
+ self.assign(self.key_past, key_present)
2151
+ key_update = self.key_past
2152
+ self.assign(self.value_past, value_present)
2153
+ value_update = self.value_past
2115
2154
  # add dependency for desired execution order
2116
2155
  key_update = F.depend(key_update, key_reset)
2117
2156
  value_update = F.depend(value_update, value_reset)
@@ -2143,29 +2182,14 @@ class TransformerDecoderLayer(Cell):
2143
2182
 
2144
2183
  def _check_input(self, hidden_states, attention_mask, encoder_output, memory_mask, init_reset, batch_valid_length):
2145
2184
  r"""Check inputs"""
2146
- if not self.use_past or (self.use_past and self.is_first_iteration):
2147
- _check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name,
2148
- [[self.batch_size, self.tgt_seq_length, self.hidden_size],
2149
- [self.batch_size * self.tgt_seq_length, self.hidden_size]])
2150
- _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
2151
- [self.batch_size, self.tgt_seq_length, self.tgt_seq_length])
2152
-
2153
- else:
2154
- _check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name,
2155
- [self.batch_size, 1, self.hidden_size])
2156
- _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
2157
- [self.batch_size, 1, self.tgt_seq_length])
2158
2185
  _check_input_dtype(F.dtype(hidden_states), "hidden_states", [mstype.float32, mstype.float16], self.cls_name)
2159
- _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
2186
+ if attention_mask is not None:
2187
+ _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16],
2188
+ self.cls_name)
2160
2189
  if encoder_output is not None:
2161
- _check_shape_equal(F.shape(encoder_output), "encoder_output", self.cls_name,
2162
- [[self.batch_size, self.src_seq_length, self.hidden_size],
2163
- [self.batch_size * self.src_seq_length, self.hidden_size]])
2164
2190
  _check_input_dtype(F.dtype(encoder_output), "encoder_output",
2165
2191
  [mstype.float32, mstype.float16], self.cls_name)
2166
2192
  if memory_mask is not None:
2167
- _check_shape_equal(F.shape(memory_mask), "memory_mask", self.cls_name,
2168
- [self.batch_size, self.tgt_seq_length, self.src_seq_length])
2169
2193
  _check_input_dtype(F.dtype(memory_mask), "memory_mask",
2170
2194
  [mstype.float32, mstype.float16], self.cls_name)
2171
2195
 
@@ -2179,9 +2203,7 @@ class TransformerDecoderLayer(Cell):
2179
2203
  batch_valid_length_is_tensor, batch_is_default)
2180
2204
 
2181
2205
  if self.use_past:
2182
- _check_shape_equal(F.shape(init_reset), "init_reset", self.cls_name, [1])
2183
2206
  _check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name)
2184
- _check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size])
2185
2207
  _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
2186
2208
  return True
2187
2209
 
@@ -2240,7 +2262,9 @@ class TransformerEncoder(Cell):
2240
2262
  attention and feedforward layer.
2241
2263
 
2242
2264
  Args:
2243
- batch_size(int): The batch size of the input tensor.
2265
+ batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
2266
+ value. When doing training or prediction, the argument will not take effect and the user can just pass None to
2267
+ the argument.
2244
2268
  num_layers(int): The layers of the `TransformerEncoderLayer`
2245
2269
  hidden_size(int): The hidden size of the input.
2246
2270
  ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
@@ -2248,9 +2272,12 @@ class TransformerEncoder(Cell):
2248
2272
  num_heads(int): The number of the heads.
2249
2273
  attention_dropout_rate(float): The dropout rate of the attention scores. Default:0.1.
2250
2274
  hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default: 0.1.
2251
- hidden_act(str): The activation of the internal feedforward layer. Supports 'relu',
2275
+ hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
2252
2276
  'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
2253
- 'hsigmoid', 'logsigmoid' and so on. Default: gelu.
2277
+ 'hsigmoid', 'logsigmoid' and so on. The user can provide a custom activation to the argument.
2278
+ If the user wants to run the net in parallel mode, the custom activation must also provide
2279
+ the `activation_shard` function. Please see the examples of the
2280
+ :class:`mindspore.nn.transformer.FeedForward`. Default: gelu.
2254
2281
  post_layernorm_residual(bool): Do residuals adds before the layernorm. Default False.
2255
2282
  layernorm_compute_type(dtype.Number): The computation type of the layernorm.
2256
2283
  Should be mstype.float32 or mstype.float16. Default mstype.float32.
@@ -2284,7 +2311,9 @@ class TransformerEncoder(Cell):
2284
2311
  - **hidden_states** (Tensor) - Tensor, shape should be [batch_size, seq_length, hidden_size] or
2285
2312
  [batch_size * seq_length, hidden_size], if the use_past is False or is_first_iteration=True. Otherwise,
2286
2313
  should be [batch_size, 1, hidden_size].
2287
- - **attention_mask** (Tensor) - Tensor, attention mask with shape [batch_size, seq_length, seq_length]
2314
+ - **attention_mask** (Tensor) - Float Tensor. If the use_past is False or is_first_iteration=True,
2315
+ the attention mask matrix should be [batch_size, seq_length, seq_length], or None. None means there will
2316
+ be no mask in softmax computation. Otherwise, should be [batch_size, 1, hidden_size]
2288
2317
  - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
2289
2318
  past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
2290
2319
  - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index.
@@ -2361,7 +2390,6 @@ class TransformerEncoder(Cell):
2361
2390
  offset=Validator.check_non_negative_int,
2362
2391
  attention_dropout_rate=Validator.check_non_negative_float,
2363
2392
  hidden_dropout_rate=Validator.check_non_negative_float,
2364
- hidden_act=_valid_type_checks([str], "TransformerEncoder"),
2365
2393
  post_layernorm_residual=Validator.check_bool,
2366
2394
  layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
2367
2395
  "TransformerEncoder"),
@@ -2396,7 +2424,7 @@ class TransformerEncoder(Cell):
2396
2424
  _check_moe_config(moe_config, parallel_config)
2397
2425
  self.use_moe = (moe_config.expert_num > 1)
2398
2426
  config_to_layer = parallel_config.moe_parallel_config if self.use_moe else parallel_config.dp_mp_config
2399
- if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
2427
+ if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
2400
2428
  self.add = P.Add()
2401
2429
  self.aux_loss = Tensor(0.0, mstype.float32)
2402
2430
  self.num_layers = num_layers
@@ -2490,7 +2518,9 @@ class TransformerDecoder(Cell):
2490
2518
 
2491
2519
  Args:
2492
2520
  num_layers(int): The layers of the `TransformerDecoderLayer`.
2493
- batch_size(int): The batch size of the input tensor.
2521
+ batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
2522
+ value. When doing training or prediction, the argument will not take effect and the user can just pass None to
2523
+ the argument.
2494
2524
  hidden_size(int): The hidden size of the input.
2495
2525
  ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
2496
2526
  src_seq_length(int): The input source sequence length.
@@ -2505,9 +2535,12 @@ class TransformerDecoder(Cell):
2505
2535
  Should be mstype.float32 or mstype.float16. Default mstype.float32.
2506
2536
  param_init_type(dtype.Number): The parameter initialization type of the module.
2507
2537
  Should be mstype.float32 or mstype.float16. Default mstype.float32.
2508
- hidden_act(str): The activation of the internal feedforward layer. Supports 'relu',
2538
+ hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
2509
2539
  'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
2510
- 'hsigmoid', 'logsigmoid' and so on. Default: gelu.
2540
+ 'hsigmoid', 'logsigmoid' and so on. The user can provide a custom activation to the argument.
2541
+ If the user wants to run the net in parallel mode, the custom activation must also provide
2542
+ the `activation_shard` function. Please see the examples of the
2543
+ :class:`mindspore.nn.transformer.FeedForward`. Default: gelu.
2511
2544
  lambda_func(function): A function can determine the fusion index,
2512
2545
  pipeline stages and recompute attribute. If the
2513
2546
  user wants to determine the pipeline stage and gradient aggregation fusion, the user can pass a
@@ -2528,13 +2561,14 @@ class TransformerDecoder(Cell):
2528
2561
  - **hidden_stats** (Tensor) - The input tensor with shape [batch_size, seq_length, hidden_size] or
2529
2562
  [batch_size * seq_length, hidden_size]
2530
2563
  - **attention_mask** (Tensor) - The attention mask for decoder with shape
2531
- [batch_size, seq_length, seq_length]
2564
+ [batch_size, seq_length, seq_length] or None. None means there will be no mask in softmax
2565
+ computation in self attention.
2532
2566
  - **encoder_output** (Tensor) - The output of the encoder with shape [batch_size, seq_length, hidden_size]
2533
2567
  or [batch_size * seq_length, hidden_size]. Note this args can not be passed by None when the net is in
2534
2568
  outermost layer. Default None.
2535
2569
  - **memory_mask** (Tensor) - The memory mask of the cross attention with shape [batch, tgt_seq_length,
2536
- src_seq_length] where tgt_seq_length is the length of the decoder. Note this args can not be passed by
2537
- None when the net is in outermost layer. Default None.
2570
+ src_seq_length] where tgt_seq_length is the length of the decoder. The user can also pass None. None
2571
+ means there will be no mask in softmax computation in cross attention. Default None.
2538
2572
  - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
2539
2573
  past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
2540
2574
  - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index.
@@ -2591,7 +2625,6 @@ class TransformerDecoder(Cell):
2591
2625
  offset=Validator.check_non_negative_int,
2592
2626
  attention_dropout_rate=Validator.check_non_negative_float,
2593
2627
  hidden_dropout_rate=Validator.check_non_negative_float,
2594
- hidden_act=_valid_type_checks([str], "TransformerDecoder"),
2595
2628
  post_layernorm_residual=Validator.check_bool,
2596
2629
  layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
2597
2630
  "TransformerDecoder"),
@@ -2627,7 +2660,7 @@ class TransformerDecoder(Cell):
2627
2660
  _check_config(parallel_config)
2628
2661
  self.use_moe = (moe_config.expert_num > 1)
2629
2662
  config_to_layer = parallel_config.moe_parallel_config if self.use_moe else parallel_config.dp_mp_config
2630
- if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
2663
+ if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
2631
2664
  self.add = P.Add()
2632
2665
  self.aux_loss = Tensor(0.0, mstype.float32)
2633
2666
  self.num_layers = num_layers
@@ -2731,12 +2764,14 @@ class Transformer(Cell):
2731
2764
  the residual addition before the layer normalization. And the default hidden act is `gelu`.
2732
2765
  The details can be found in `Attention is all you need <https://arxiv.org/pdf/1706.03762v5.pdf>`_.
2733
2766
 
2734
- Note:
2735
- This is an experimental interface that is subject to change or deletion.
2767
+ .. warning::
2768
+ This is an experimental API that is subject to change or deletion.
2736
2769
 
2737
2770
  Args:
2738
2771
  hidden_size(int): The hidden size of the input.
2739
- batch_size(int): The batch size of the input tensor.
2772
+ batch_size(int): The batch size of the input tensor when do increnmental prediction. Should be a positive
2773
+ value. When do training or prediction, the argument will not work and the user can just pass None to
2774
+ the argument.
2740
2775
  ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
2741
2776
  src_seq_length(int): The seq_length of the encoder's input tensor.
2742
2777
  tgt_seq_length(int): The seq_length of the decoder's input tensor.
@@ -2745,9 +2780,12 @@ class Transformer(Cell):
2745
2780
  num_heads(int): The number of the heads. Default: 2.
2746
2781
  attention_dropout_rate(float): The dropout rate of the attention scores. Default:0.1.
2747
2782
  hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default:0.1.
2748
- hidden_act(str): The activation of the internal feedforward layer. Supports 'relu',
2783
+ hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
2749
2784
  'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
2750
- 'hsigmoid', 'logsigmoid' and so on. Default: gelu.
2785
+ 'hsigmoid', 'logsigmoid' and so on. User can provide custom activition to the argument.
2786
+ If user wants to run the net in the parallel mode, the custom activation must also provide
2787
+ the `activation_shard` function. Please see the examples of the
2788
+ class:`mindspore.nn.transformer.FeedForward`. Default: gelu.
2751
2789
  post_layernorm_residual(bool): Do residuals adds before the layernorm. Default False.
2752
2790
  layernorm_compute_type(dtype.Number): The computation type of the layernorm.
2753
2791
  Should be dtype.float32 or dtype.float16. Default dtype.float32.
@@ -2772,15 +2810,17 @@ class Transformer(Cell):
2772
2810
  - **encoder_inputs** (Tensor) - The input tensor with shape [batch_size, seq_length, hidden_size] or
2773
2811
  [batch_size * seq_length, hidden_size].
2774
2812
  - **encoder_masks** (Tensor) - The attention mask for decoder with shape
2775
- [batch_size, seq_length, seq_length].
2813
+ [batch_size, seq_length, seq_length] or None. None means there will be no mask in softmax computation
2814
+ in self attention of the encoder module.
2776
2815
  - **decoder_inputs** (Tensor) - The output of the encoder with shape [batch_size, seq_length, hidden_size]
2777
2816
  or [batch_size * seq_length, hidden_size], this should be none if the decoder layer is 0.
2778
2817
  - **decoder_masks** (Tensor) - The attention mask for decoder with shape
2779
- [batch_size, seq_length, seq_length]
2818
+ [batch_size, seq_length, seq_length] or None. None means there will be no mask in softmax computation
2819
+ in self attention of the decoder module.
2780
2820
  - **memory_mask** (Tensor) - The memory mask of the cross attention with shape [batch, tgt_seq_length,
2781
2821
  src_seq_length]
2782
2822
  where tgt_seq_length is the length of the decoder. The output of the encoder with shape [batch_size,
2783
- seq_length, hidden_size], this should be none if the decoder layer is 0.
2823
+ seq_length, hidden_size], this should be none if the decoder layer is 0 or the user wants no mask.
2784
2824
  - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
2785
2825
  past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
2786
2826
  - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index.
@@ -2854,7 +2894,6 @@ class Transformer(Cell):
2854
2894
  tgt_seq_length=Validator.check_positive_int,
2855
2895
  attention_dropout_rate=Validator.check_non_negative_float,
2856
2896
  hidden_dropout_rate=Validator.check_non_negative_float,
2857
- hidden_act=_valid_type_checks([str], "Transformer"),
2858
2897
  post_layernorm_residual=Validator.check_bool,
2859
2898
  layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
2860
2899
  "Transformer"),
@@ -2884,7 +2923,7 @@ class Transformer(Cell):
2884
2923
  moe_config=default_moe_config,
2885
2924
  parallel_config=default_transformer_config):
2886
2925
  super(Transformer, self).__init__()
2887
- if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
2926
+ if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
2888
2927
  _check_config(parallel_config)
2889
2928
  self.batch_size = batch_size
2890
2929
  self.hidden_size = hidden_size