mindspore 2.0.0a0__cp39-cp39-win_amd64.whl → 2.0.0rc1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. See the package's registry page for more details.

Files changed (655)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -2
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +102 -0
  7. mindspore/_checkparam.py +1066 -1001
  8. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
  9. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
  10. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
  11. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
  12. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  13. mindspore/_extends/parse/__init__.py +5 -3
  14. mindspore/_extends/parse/namespace.py +16 -1
  15. mindspore/_extends/parse/parser.py +107 -22
  16. mindspore/_extends/parse/resources.py +0 -7
  17. mindspore/_extends/parse/standard_method.py +885 -413
  18. mindspore/amp.py +52 -57
  19. mindspore/boost/boost.py +2 -2
  20. mindspore/boost/boost_cell_wrapper.py +38 -20
  21. mindspore/boost/dim_reduce.py +3 -3
  22. mindspore/boost/group_loss_scale_manager.py +1 -1
  23. mindspore/common/__init__.py +4 -6
  24. mindspore/common/_decorator.py +2 -0
  25. mindspore/common/_register_for_adapter.py +55 -0
  26. mindspore/common/_stub_tensor.py +201 -0
  27. mindspore/common/_utils.py +41 -7
  28. mindspore/common/api.py +215 -141
  29. mindspore/common/dtype.py +8 -1
  30. mindspore/common/dump.py +2 -2
  31. mindspore/common/initializer.py +4 -2
  32. mindspore/common/jit_config.py +17 -13
  33. mindspore/common/mutable.py +33 -13
  34. mindspore/common/parameter.py +23 -21
  35. mindspore/common/seed.py +8 -24
  36. mindspore/common/sparse_tensor.py +62 -41
  37. mindspore/common/tensor.py +852 -1154
  38. mindspore/communication/__init__.py +2 -2
  39. mindspore/communication/_comm_helper.py +11 -4
  40. mindspore/communication/management.py +22 -21
  41. mindspore/config/op_info.config +501 -1008
  42. mindspore/context.py +201 -23
  43. mindspore/dataset/__init__.py +6 -6
  44. mindspore/dataset/audio/__init__.py +7 -7
  45. mindspore/dataset/audio/transforms.py +670 -30
  46. mindspore/dataset/audio/utils.py +47 -4
  47. mindspore/dataset/audio/validators.py +223 -1
  48. mindspore/dataset/callback/ds_callback.py +2 -2
  49. mindspore/dataset/core/config.py +210 -14
  50. mindspore/dataset/core/validator_helpers.py +2 -2
  51. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  52. mindspore/dataset/debug/debug_hook.py +65 -0
  53. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  54. mindspore/dataset/engine/__init__.py +7 -3
  55. mindspore/dataset/engine/cache_client.py +1 -1
  56. mindspore/dataset/engine/datasets.py +322 -66
  57. mindspore/dataset/engine/datasets_audio.py +80 -76
  58. mindspore/dataset/engine/datasets_standard_format.py +51 -38
  59. mindspore/dataset/engine/datasets_text.py +232 -118
  60. mindspore/dataset/engine/datasets_user_defined.py +41 -17
  61. mindspore/dataset/engine/datasets_vision.py +746 -225
  62. mindspore/dataset/engine/graphdata.py +75 -10
  63. mindspore/dataset/engine/iterators.py +45 -5
  64. mindspore/dataset/engine/offload.py +48 -28
  65. mindspore/dataset/engine/validators.py +117 -8
  66. mindspore/dataset/text/__init__.py +6 -5
  67. mindspore/dataset/text/transforms.py +86 -3
  68. mindspore/dataset/text/utils.py +6 -4
  69. mindspore/dataset/text/validators.py +25 -0
  70. mindspore/dataset/transforms/__init__.py +3 -2
  71. mindspore/dataset/transforms/c_transforms.py +1 -1
  72. mindspore/dataset/transforms/transforms.py +2 -2
  73. mindspore/dataset/utils/__init__.py +2 -1
  74. mindspore/dataset/utils/line_reader.py +121 -0
  75. mindspore/dataset/vision/__init__.py +2 -3
  76. mindspore/dataset/vision/c_transforms.py +9 -9
  77. mindspore/dataset/vision/py_transforms.py +5 -5
  78. mindspore/dataset/vision/py_transforms_util.py +2 -0
  79. mindspore/dataset/vision/transforms.py +160 -161
  80. mindspore/dataset/vision/utils.py +3 -3
  81. mindspore/experimental/map_parameter.py +38 -26
  82. mindspore/include/OWNERS +0 -1
  83. mindspore/include/api/callback/callback.h +9 -13
  84. mindspore/include/api/callback/ckpt_saver.h +2 -2
  85. mindspore/include/api/callback/loss_monitor.h +2 -2
  86. mindspore/include/api/callback/lr_scheduler.h +5 -5
  87. mindspore/include/api/callback/time_monitor.h +2 -2
  88. mindspore/include/api/callback/train_accuracy.h +4 -6
  89. mindspore/include/api/cfg.h +19 -6
  90. mindspore/include/api/context.h +44 -9
  91. mindspore/include/api/delegate.h +1 -1
  92. mindspore/include/api/metrics/accuracy.h +2 -2
  93. mindspore/include/api/metrics/metrics.h +4 -3
  94. mindspore/include/api/model.h +9 -4
  95. mindspore/include/api/model_parallel_runner.h +2 -2
  96. mindspore/include/api/net.h +12 -11
  97. mindspore/include/api/serialization.h +19 -3
  98. mindspore/include/api/types.h +3 -3
  99. mindspore/include/dataset/constants.h +7 -0
  100. mindspore/include/dataset/text.h +59 -0
  101. mindspore/jpeg62.dll +0 -0
  102. mindspore/log.py +1 -1
  103. mindspore/mindrecord/filereader.py +18 -0
  104. mindspore/mindrecord/filewriter.py +197 -34
  105. mindspore/mindrecord/shardreader.py +9 -0
  106. mindspore/mindrecord/shardwriter.py +1 -1
  107. mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
  108. mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
  109. mindspore/mindrecord/tools/csv_to_mr.py +3 -3
  110. mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
  111. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  112. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  113. mindspore/mindspore_backend.dll +0 -0
  114. mindspore/mindspore_common.dll +0 -0
  115. mindspore/mindspore_core.dll +0 -0
  116. mindspore/mindspore_glog.dll +0 -0
  117. mindspore/mindspore_shared_lib.dll +0 -0
  118. mindspore/nn/__init__.py +0 -4
  119. mindspore/nn/cell.py +204 -132
  120. mindspore/nn/dynamic_lr.py +1 -1
  121. mindspore/nn/grad/cell_grad.py +7 -6
  122. mindspore/nn/layer/__init__.py +5 -4
  123. mindspore/nn/layer/activation.py +40 -89
  124. mindspore/nn/layer/basic.py +255 -624
  125. mindspore/nn/layer/channel_shuffle.py +7 -6
  126. mindspore/nn/layer/combined.py +1 -1
  127. mindspore/nn/layer/container.py +41 -4
  128. mindspore/nn/layer/conv.py +64 -28
  129. mindspore/nn/layer/dense.py +9 -8
  130. mindspore/nn/layer/embedding.py +27 -25
  131. mindspore/nn/layer/image.py +53 -46
  132. mindspore/nn/layer/math.py +97 -105
  133. mindspore/nn/layer/normalization.py +117 -86
  134. mindspore/nn/layer/padding.py +185 -95
  135. mindspore/nn/layer/pooling.py +817 -414
  136. mindspore/nn/layer/rnn_cells.py +10 -15
  137. mindspore/nn/layer/rnns.py +37 -38
  138. mindspore/nn/layer/thor_layer.py +11 -12
  139. mindspore/nn/layer/timedistributed.py +5 -5
  140. mindspore/nn/layer/transformer.py +701 -0
  141. mindspore/nn/learning_rate_schedule.py +8 -8
  142. mindspore/nn/loss/__init__.py +5 -4
  143. mindspore/nn/loss/loss.py +334 -199
  144. mindspore/nn/optim/ada_grad.py +6 -6
  145. mindspore/nn/optim/adadelta.py +2 -3
  146. mindspore/nn/optim/adafactor.py +4 -5
  147. mindspore/nn/optim/adam.py +126 -62
  148. mindspore/nn/optim/adamax.py +3 -4
  149. mindspore/nn/optim/adasum.py +6 -6
  150. mindspore/nn/optim/asgd.py +2 -2
  151. mindspore/nn/optim/ftrl.py +67 -38
  152. mindspore/nn/optim/lamb.py +4 -5
  153. mindspore/nn/optim/lars.py +2 -2
  154. mindspore/nn/optim/lazyadam.py +43 -4
  155. mindspore/nn/optim/momentum.py +6 -5
  156. mindspore/nn/optim/optimizer.py +3 -1
  157. mindspore/nn/optim/proximal_ada_grad.py +2 -2
  158. mindspore/nn/optim/rmsprop.py +1 -1
  159. mindspore/nn/optim/rprop.py +8 -9
  160. mindspore/nn/optim/sgd.py +19 -13
  161. mindspore/nn/optim/thor.py +10 -15
  162. mindspore/nn/probability/__init__.py +0 -2
  163. mindspore/nn/probability/bijector/bijector.py +4 -4
  164. mindspore/nn/probability/bijector/invert.py +1 -1
  165. mindspore/nn/probability/bijector/softplus.py +2 -2
  166. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  167. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  168. mindspore/nn/probability/distribution/_utils/utils.py +9 -15
  169. mindspore/nn/probability/distribution/bernoulli.py +3 -3
  170. mindspore/nn/probability/distribution/beta.py +1 -1
  171. mindspore/nn/probability/distribution/categorical.py +5 -7
  172. mindspore/nn/probability/distribution/cauchy.py +3 -3
  173. mindspore/nn/probability/distribution/distribution.py +2 -2
  174. mindspore/nn/probability/distribution/exponential.py +2 -2
  175. mindspore/nn/probability/distribution/gamma.py +3 -3
  176. mindspore/nn/probability/distribution/geometric.py +1 -1
  177. mindspore/nn/probability/distribution/gumbel.py +3 -3
  178. mindspore/nn/probability/distribution/half_normal.py +15 -11
  179. mindspore/nn/probability/distribution/laplace.py +16 -13
  180. mindspore/nn/probability/distribution/logistic.py +2 -2
  181. mindspore/nn/probability/distribution/normal.py +1 -1
  182. mindspore/nn/probability/distribution/poisson.py +1 -1
  183. mindspore/nn/probability/distribution/student_t.py +20 -15
  184. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  185. mindspore/nn/probability/distribution/uniform.py +2 -2
  186. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  187. mindspore/nn/reinforcement/tensor_array.py +2 -2
  188. mindspore/nn/sparse/sparse.py +2 -2
  189. mindspore/nn/wrap/cell_wrapper.py +27 -10
  190. mindspore/nn/wrap/grad_reducer.py +2 -2
  191. mindspore/nn/wrap/loss_scale.py +40 -24
  192. mindspore/numpy/array_creations.py +33 -22
  193. mindspore/numpy/array_ops.py +35 -30
  194. mindspore/numpy/logic_ops.py +6 -27
  195. mindspore/numpy/math_ops.py +22 -19
  196. mindspore/numpy/utils.py +1 -1
  197. mindspore/numpy/utils_const.py +108 -58
  198. mindspore/opencv_core452.dll +0 -0
  199. mindspore/opencv_imgcodecs452.dll +0 -0
  200. mindspore/opencv_imgproc452.dll +0 -0
  201. mindspore/ops/_constants.py +0 -6
  202. mindspore/ops/_grad/__init__.py +2 -1
  203. mindspore/ops/_grad/grad_array_ops.py +86 -117
  204. mindspore/ops/_grad/grad_base.py +23 -1
  205. mindspore/ops/_grad/grad_clip_ops.py +2 -3
  206. mindspore/ops/_grad/grad_comm_ops.py +34 -24
  207. mindspore/ops/_grad/grad_implementations.py +9 -45
  208. mindspore/ops/_grad/grad_inner_ops.py +47 -4
  209. mindspore/ops/_grad/grad_math_ops.py +142 -117
  210. mindspore/ops/_grad/grad_nn_ops.py +71 -165
  211. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  212. mindspore/ops/_grad/grad_sparse.py +7 -6
  213. mindspore/ops/_grad_experimental/__init__.py +1 -0
  214. mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
  215. mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
  216. mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
  217. mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
  218. mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
  219. mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
  220. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  221. mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
  222. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  223. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
  224. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  225. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  226. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
  227. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  228. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  229. mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
  230. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  231. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
  232. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  233. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  234. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  235. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  236. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  237. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  238. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  239. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  240. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  241. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  242. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  243. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  244. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  245. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  246. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  247. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  248. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  249. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  250. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
  251. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  252. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  253. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  254. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  255. mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
  256. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  257. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
  258. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  259. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  260. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  261. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  262. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  263. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
  264. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  265. mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
  266. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  267. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  268. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
  269. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  270. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  271. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
  272. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  273. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  274. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  275. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  276. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  277. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  278. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  279. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  280. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  281. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
  282. mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
  283. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  284. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  285. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  286. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  287. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  288. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  289. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  290. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  291. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  292. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  293. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  294. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  295. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  296. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  297. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  298. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  299. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  300. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  301. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  302. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  303. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  304. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  305. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  306. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  307. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  308. mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
  309. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
  310. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  311. mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
  312. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  313. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  314. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  315. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  316. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  317. mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
  318. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
  319. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  320. mindspore/ops/_op_impl/tbe/__init__.py +27 -611
  321. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  322. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  323. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  324. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
  325. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  326. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  327. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  328. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  329. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  330. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  331. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  332. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  333. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  334. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
  335. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  336. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  337. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  338. mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
  339. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  340. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  341. mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
  342. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
  343. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  344. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  345. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  346. mindspore/ops/_register_for_op.py +1 -0
  347. mindspore/ops/_utils/__init__.py +1 -2
  348. mindspore/ops/_utils/utils.py +19 -40
  349. mindspore/ops/_vmap/vmap_array_ops.py +116 -38
  350. mindspore/ops/_vmap/vmap_base.py +16 -9
  351. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  352. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  353. mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
  354. mindspore/ops/_vmap/vmap_image_ops.py +12 -5
  355. mindspore/ops/_vmap/vmap_math_ops.py +46 -5
  356. mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
  357. mindspore/ops/_vmap/vmap_random_ops.py +1 -1
  358. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  359. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  360. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  361. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  362. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  363. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  364. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  365. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  366. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
  367. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  368. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  369. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  370. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  371. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  372. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  373. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  374. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  375. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  376. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  377. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
  378. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
  379. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  380. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  381. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  382. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  383. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  384. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  385. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
  386. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  387. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
  388. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
  389. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  390. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  391. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  392. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  393. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  394. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  395. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  396. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  397. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  398. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  399. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  400. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  401. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  402. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  403. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  404. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  405. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  406. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  407. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  408. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
  409. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  410. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  411. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  412. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  413. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  414. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  415. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  416. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
  417. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
  418. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  419. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
  420. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  421. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  422. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  423. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  424. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  425. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  426. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  427. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  428. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  429. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
  430. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
  431. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  432. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  433. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  434. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  435. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
  436. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  437. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  438. mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
  439. mindspore/ops/composite/__init__.py +7 -8
  440. mindspore/ops/composite/base.py +101 -47
  441. mindspore/ops/composite/math_ops.py +188 -158
  442. mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
  443. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
  444. mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
  445. mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
  446. mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
  447. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  448. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  449. mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
  450. mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
  451. mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
  452. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
  453. mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
  454. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  455. mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
  456. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
  457. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
  458. mindspore/ops/function/__init__.py +152 -8
  459. mindspore/ops/function/array_func.py +2555 -674
  460. mindspore/ops/function/clip_func.py +209 -13
  461. mindspore/ops/function/debug_func.py +2 -2
  462. mindspore/ops/function/grad/__init__.py +2 -1
  463. mindspore/ops/function/grad/grad_func.py +147 -62
  464. mindspore/ops/function/image_func.py +54 -38
  465. mindspore/ops/function/linalg_func.py +167 -16
  466. mindspore/ops/function/math_func.py +4849 -1492
  467. mindspore/ops/function/nn_func.py +2573 -988
  468. mindspore/ops/function/other_func.py +115 -0
  469. mindspore/ops/function/parameter_func.py +3 -3
  470. mindspore/ops/function/random_func.py +790 -73
  471. mindspore/ops/function/sparse_func.py +98 -78
  472. mindspore/ops/function/sparse_unary_func.py +54 -53
  473. mindspore/ops/function/spectral_func.py +27 -24
  474. mindspore/ops/function/vmap_func.py +22 -2
  475. mindspore/ops/functional.py +97 -37
  476. mindspore/ops/op_info_register.py +70 -28
  477. mindspore/ops/operations/__init__.py +47 -14
  478. mindspore/ops/operations/_csr_ops.py +7 -7
  479. mindspore/ops/operations/_embedding_cache_ops.py +5 -5
  480. mindspore/ops/operations/_grad_ops.py +276 -187
  481. mindspore/ops/operations/_inner_ops.py +319 -113
  482. mindspore/ops/operations/_ms_kernel.py +10 -8
  483. mindspore/ops/operations/_ocr_ops.py +9 -9
  484. mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
  485. mindspore/ops/operations/_quant_ops.py +137 -102
  486. mindspore/ops/operations/_rl_inner_ops.py +121 -60
  487. mindspore/ops/operations/_scalar_ops.py +466 -0
  488. mindspore/ops/operations/_sequence_ops.py +1004 -2
  489. mindspore/ops/operations/_tensor_array.py +10 -11
  490. mindspore/ops/operations/_thor_ops.py +1 -1
  491. mindspore/ops/operations/array_ops.py +801 -466
  492. mindspore/ops/operations/comm_ops.py +51 -49
  493. mindspore/ops/operations/control_ops.py +2 -2
  494. mindspore/ops/operations/custom_ops.py +123 -44
  495. mindspore/ops/operations/debug_ops.py +24 -24
  496. mindspore/ops/operations/image_ops.py +240 -153
  497. mindspore/ops/operations/inner_ops.py +34 -50
  498. mindspore/ops/operations/linalg_ops.py +31 -9
  499. mindspore/ops/operations/math_ops.py +988 -757
  500. mindspore/ops/operations/nn_ops.py +965 -819
  501. mindspore/ops/operations/other_ops.py +51 -40
  502. mindspore/ops/operations/random_ops.py +204 -122
  503. mindspore/ops/operations/rl_ops.py +8 -9
  504. mindspore/ops/operations/sparse_ops.py +254 -93
  505. mindspore/ops/operations/spectral_ops.py +35 -3
  506. mindspore/ops/primitive.py +111 -9
  507. mindspore/parallel/_auto_parallel_context.py +189 -83
  508. mindspore/parallel/_offload_context.py +185 -0
  509. mindspore/parallel/_parallel_serialization.py +99 -7
  510. mindspore/parallel/_ps_context.py +9 -5
  511. mindspore/parallel/_recovery_context.py +1 -1
  512. mindspore/parallel/_tensor.py +7 -1
  513. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  514. mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
  515. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  516. mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
  517. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  518. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
  519. mindspore/parallel/_utils.py +1 -2
  520. mindspore/parallel/algo_parameter_config.py +1 -1
  521. mindspore/parallel/checkpoint_transform.py +37 -34
  522. mindspore/parallel/shard.py +17 -18
  523. mindspore/profiler/common/validator/validate_path.py +2 -2
  524. mindspore/profiler/envprofiling.py +69 -47
  525. mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
  526. mindspore/profiler/parser/base_timeline_generator.py +49 -56
  527. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
  528. mindspore/profiler/parser/hwts_log_parser.py +1 -1
  529. mindspore/profiler/parser/integrator.py +15 -14
  530. mindspore/profiler/parser/minddata_analyzer.py +2 -2
  531. mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
  532. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  533. mindspore/profiler/parser/optime_parser.py +17 -18
  534. mindspore/profiler/parser/profiler_info.py +2 -1
  535. mindspore/profiler/profiling.py +218 -186
  536. mindspore/rewrite/__init__.py +3 -1
  537. mindspore/rewrite/api/node.py +1 -114
  538. mindspore/rewrite/api/node_type.py +3 -0
  539. mindspore/rewrite/api/pattern_engine.py +31 -1
  540. mindspore/rewrite/api/scoped_value.py +4 -4
  541. mindspore/rewrite/api/symbol_tree.py +3 -78
  542. mindspore/rewrite/api/tree_node_helper.py +1 -1
  543. mindspore/rewrite/ast_creator_register.py +1 -0
  544. mindspore/rewrite/ast_helpers/__init__.py +2 -2
  545. mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
  546. mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
  547. mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
  548. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
  549. mindspore/rewrite/namespace.py +0 -2
  550. mindspore/rewrite/node.py +157 -11
  551. mindspore/rewrite/parsers/assign_parser.py +231 -53
  552. mindspore/rewrite/parsers/class_def_parser.py +187 -109
  553. mindspore/rewrite/parsers/for_parser.py +24 -14
  554. mindspore/rewrite/parsers/function_def_parser.py +21 -4
  555. mindspore/rewrite/parsers/if_parser.py +6 -2
  556. mindspore/rewrite/sparsify/__init__.py +0 -0
  557. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  558. mindspore/rewrite/sparsify/sparsify.py +109 -0
  559. mindspore/rewrite/sparsify/utils.py +173 -0
  560. mindspore/rewrite/symbol_tree.py +256 -133
  561. mindspore/rewrite/symbol_tree_builder.py +38 -1
  562. mindspore/run_check/_check_version.py +69 -63
  563. mindspore/run_check/run_check.py +2 -1
  564. mindspore/tinyxml2.dll +0 -0
  565. mindspore/train/__init__.py +1 -1
  566. mindspore/train/_utils.py +28 -5
  567. mindspore/train/amp.py +273 -102
  568. mindspore/train/callback/_backup_and_restore.py +5 -5
  569. mindspore/train/callback/_callback.py +2 -2
  570. mindspore/train/callback/_checkpoint.py +3 -3
  571. mindspore/train/callback/_early_stop.py +3 -3
  572. mindspore/train/callback/_lambda_callback.py +2 -2
  573. mindspore/train/callback/_landscape.py +29 -31
  574. mindspore/train/callback/_loss_monitor.py +3 -3
  575. mindspore/train/callback/_on_request_exit.py +3 -3
  576. mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
  577. mindspore/train/callback/_summary_collector.py +23 -16
  578. mindspore/train/callback/_time_monitor.py +3 -3
  579. mindspore/train/checkpoint_pb2.py +68 -8
  580. mindspore/train/data_sink.py +15 -3
  581. mindspore/train/dataset_helper.py +10 -15
  582. mindspore/train/loss_scale_manager.py +8 -11
  583. mindspore/train/metrics/__init__.py +1 -1
  584. mindspore/train/metrics/bleu_score.py +1 -1
  585. mindspore/train/metrics/confusion_matrix.py +1 -1
  586. mindspore/train/metrics/cosine_similarity.py +1 -1
  587. mindspore/train/metrics/dice.py +2 -2
  588. mindspore/train/metrics/fbeta.py +1 -1
  589. mindspore/train/metrics/hausdorff_distance.py +4 -3
  590. mindspore/train/metrics/mean_surface_distance.py +2 -2
  591. mindspore/train/metrics/occlusion_sensitivity.py +1 -1
  592. mindspore/train/metrics/perplexity.py +1 -1
  593. mindspore/train/metrics/precision.py +1 -1
  594. mindspore/train/metrics/recall.py +1 -1
  595. mindspore/train/metrics/roc.py +2 -2
  596. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  597. mindspore/train/mind_ir_pb2.py +116 -37
  598. mindspore/train/model.py +45 -28
  599. mindspore/train/serialization.py +295 -188
  600. mindspore/train/summary/_summary_adapter.py +1 -1
  601. mindspore/train/summary/summary_record.py +43 -13
  602. mindspore/train/train_thor/convert_utils.py +2 -2
  603. mindspore/train/train_thor/dataset_helper.py +3 -3
  604. mindspore/turbojpeg.dll +0 -0
  605. mindspore/version.py +1 -1
  606. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
  607. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
  608. mindspore/compression/__init__.py +0 -19
  609. mindspore/compression/common/constant.py +0 -124
  610. mindspore/compression/export/__init__.py +0 -19
  611. mindspore/compression/export/quant_export.py +0 -515
  612. mindspore/compression/quant/__init__.py +0 -28
  613. mindspore/compression/quant/qat.py +0 -634
  614. mindspore/compression/quant/quant_utils.py +0 -462
  615. mindspore/compression/quant/quantizer.py +0 -68
  616. mindspore/nn/layer/quant.py +0 -1868
  617. mindspore/nn/layer/rnn_utils.py +0 -90
  618. mindspore/nn/probability/dpn/__init__.py +0 -22
  619. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  620. mindspore/nn/probability/dpn/vae/cvae.py +0 -140
  621. mindspore/nn/probability/dpn/vae/vae.py +0 -124
  622. mindspore/nn/probability/infer/__init__.py +0 -22
  623. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  624. mindspore/nn/probability/infer/variational/svi.py +0 -84
  625. mindspore/nn/probability/toolbox/__init__.py +0 -22
  626. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  627. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
  628. mindspore/nn/probability/transforms/__init__.py +0 -22
  629. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  630. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  631. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  632. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  633. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  634. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  635. mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
  636. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  637. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
  638. mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
  639. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
  640. mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
  642. mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
  643. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
  644. mindspore/ops/composite/array_ops.py +0 -241
  645. mindspore/ops/composite/clip_ops.py +0 -134
  646. mindspore/ops/composite/random_ops.py +0 -426
  647. mindspore/ops/composite/vmap_ops.py +0 -38
  648. mindspore/parallel/nn/__init__.py +0 -42
  649. mindspore/parallel/nn/loss.py +0 -22
  650. mindspore/parallel/nn/moe.py +0 -21
  651. mindspore/parallel/nn/op_parallel_config.py +0 -22
  652. mindspore/parallel/nn/transformer.py +0 -31
  653. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  654. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  655. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -13,22 +13,22 @@
13
13
  # limitations under the License.
14
14
  # ============================================================================
15
15
  """math Operations."""
16
- import numpy as np
17
16
  from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
18
17
  from mindspore.common import dtype as mstype
19
- from mindspore._checkparam import Validator as validator
20
- from mindspore.ops.primitive import constexpr
18
+ from mindspore import _checkparam as validator
19
+ from mindspore.ops.primitive import constexpr, _primexpr
21
20
  from mindspore.ops import functional as F
22
- from mindspore.ops.operations._inner_ops import DynamicResizeNearestNeighbor
23
21
  from mindspore.ops.function.math_func import cummin as cummin_
24
22
  from mindspore.ops import operations as P
25
23
 
26
24
 
27
- @constexpr
25
+ @_primexpr
28
26
  def _check_validate_axis(axis, name):
29
- if isinstance(axis, (tuple, list)):
30
- for idx, item in enumerate(axis):
31
- validator.check_value_type("axis[%d]" % idx, item, [int], name)
27
+ def _check(axis):
28
+ if isinstance(axis, (tuple, list)):
29
+ for idx, item in enumerate(axis):
30
+ validator.check_value_type("axis[%d]" % idx, item, [int], name)
31
+ _check(axis)
32
32
  axis = validator.check_value_type('axis', axis, [int, tuple, list], name)
33
33
  return axis
34
34
 
@@ -46,24 +46,26 @@ def is_const(x):
46
46
 
47
47
  def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
48
48
  r"""
49
- Count number of nonzero elements across axis of input tensor
49
+ Count number of nonzero elements across axis of input tensor.
50
50
 
51
51
  Args:
52
- x (Tensor): Input data is used to count non-zero numbers.
53
- :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
54
- axis (Union[int, tuple(int), list(int)]): The dimensions to reduce. Only constant value is allowed.
55
- Default: (), reduce all dimensions.
56
- keep_dims (bool): If true, keep these reduced dimensions and the length is 1.
57
- If false, don't keep these dimensions. Default: False.
58
- dtype (Union[Number, mindspore.bool\_]): The data type of the output tensor. Only constant value is allowed.
59
- Default: mindspore.int32
52
+ x (Tensor): Input data is used to count non-zero numbers. With shape
53
+ :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
54
+ axis (Union[int, tuple(int), list(int)], optional): The dimensions to reduce.
55
+ Default: (), reduce all dimensions.
56
+ keep_dims (bool, optional): Whether to maintain dimensions specified by `axis`.
57
+ If true, keep these reduced dimensions and the length is 1.
58
+ If false, don't keep these dimensions. Default: False.
59
+ dtype (Union[Number, mindspore.bool\_], optional): The data type of the output tensor.
60
+ Default: mindspore.int32.
60
61
 
61
62
  Returns:
62
- Tensor, number of nonzero element. The data type is `dtype`.
63
+ Tensor, number of nonzero element across axis specified by `axis`.
64
+ The data type is specified by `dtype`.
63
65
 
64
66
  Raises:
65
- TypeError: If axis is not int or tuple.
66
- ValueError: If axis is not in range [-x.ndim, x.ndim).
67
+ TypeError: If `axis` is not int, tuple or list.
68
+ ValueError: If any value in `axis` is not in range [-x.ndim, x.ndim).
67
69
 
68
70
  Supported Platforms:
69
71
  ``Ascend`` ``GPU`` ``CPU``
@@ -116,7 +118,7 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
116
118
  return nonzero_num
117
119
 
118
120
 
119
- @constexpr
121
+ @_primexpr
120
122
  def _int_to_tuple_conv(axes):
121
123
  """
122
124
  Converts ints to tuples in input axes, expected by most validation checks.
@@ -127,7 +129,7 @@ def _int_to_tuple_conv(axes):
127
129
  return axes
128
130
 
129
131
 
130
- @constexpr
132
+ @_primexpr
131
133
  def _check_axes(axes, prim_name=None):
132
134
  """
133
135
  Check for validity and type of axes passed to function.
@@ -160,21 +162,29 @@ def _typecheck_input(x1_type, x2_type, prim_name=None):
160
162
  f"and x2_type: {x2_type}.")
161
163
 
162
164
 
163
- @constexpr
165
+ @_primexpr
164
166
  def _axes_int_check(x1_shape, x2_shape, axes, prim_name=None):
165
167
  """
166
168
  Convert from single int axes to 2d tuple if required
167
169
  """
168
170
  msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
169
- if isinstance(axes, int):
171
+
172
+ def _check_lt_zero(axes):
170
173
  if axes < 0:
171
174
  raise ValueError(f"{msg_prefix} 'axes' must be at least 0, but got {axes}.")
172
- if axes == 0:
173
- # outer product, no input validation required
174
- return [], []
175
+
176
+ def _check_len(axes, x1_shape, x2_shape):
175
177
  if axes > len(x1_shape) or axes > len(x2_shape):
176
178
  raise ValueError(f"{msg_prefix} 'axes' cannot be greater than the length of 'x1_shape' and 'x2_shape', "
177
179
  f"but got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")
180
+
181
+
182
+ if isinstance(axes, int):
183
+ _check_lt_zero(axes)
184
+ if axes == 0:
185
+ # outer product, no input validation required
186
+ return [], []
187
+ _check_len(axes, x1_shape, x2_shape)
178
188
  x1_ind = tuple(range(len(x1_shape))[-1 * axes:])
179
189
  x2_ind = tuple(range(len(x2_shape))[:axes])
180
190
  axes = tuple((x1_ind, x2_ind))
@@ -182,7 +192,7 @@ def _axes_int_check(x1_shape, x2_shape, axes, prim_name=None):
182
192
  return axes
183
193
 
184
194
 
185
- @constexpr
195
+ @_primexpr
186
196
  def _validate_axes(x1_shape, x2_shape, axes, prim_name=None):
187
197
  """
188
198
  Checks for axes having the correct length according to input, for any value in axis
@@ -190,25 +200,32 @@ def _validate_axes(x1_shape, x2_shape, axes, prim_name=None):
190
200
  with given inputs.
191
201
  """
192
202
  msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
203
+
204
+ def _check_len(axes_len, shape_dim_len, x_axes):
205
+ if axes_len > shape_dim_len:
206
+ raise ValueError(f"{msg_prefix} length of element {x_axes} in 'axes' must be less than or equal to "
207
+ f"{shape_dim_len}, but got {axes_len}.")
208
+
209
+ def _check_value(x_axes, min_val, max_val):
210
+ for _, x_value in enumerate(x_axes):
211
+ if x_value > max_val or x_value < min_val:
212
+ raise ValueError(f"{msg_prefix} value in 'axes' must be in range: [{min_val}, {max_val}], "
213
+ f"but got {x_value}.")
214
+
193
215
  shapes = [x1_shape, x2_shape]
194
216
 
195
217
  # axis length check
196
218
  for ix_input, x_axes in enumerate(axes):
197
219
  axes_len = len(x_axes)
198
220
  shape_dim_len = len(shapes[ix_input])
199
- if axes_len > shape_dim_len:
200
- raise ValueError(f"{msg_prefix} length of element {x_axes} in 'axes' must be less than or equal to "
201
- f"{shape_dim_len}, but got {axes_len}.")
221
+ _check_len(axes_len, shape_dim_len, x_axes)
202
222
 
203
223
  # axis values range check
204
224
  for ix_input, x_axes in enumerate(axes):
205
225
  comp_shape = shapes[ix_input]
206
226
  max_val = len(comp_shape) - 1
207
227
  min_val = -1 * len(comp_shape)
208
- for _, x_value in enumerate(x_axes):
209
- if not min_val <= x_value <= max_val:
210
- raise ValueError(f"{msg_prefix} value in 'axes' must be in range: [{min_val}, {max_val}], "
211
- f"but got {x_value}.")
228
+ _check_value(x_axes, min_val, max_val)
212
229
 
213
230
  # check axis value with input shape - both ways for axis valid
214
231
  invalid_a = False
@@ -218,23 +235,31 @@ def _validate_axes(x1_shape, x2_shape, axes, prim_name=None):
218
235
  invalid_a = True
219
236
  if x1_shape[axes[0][i]] != x2_shape[axes[1][len(axes[0]) - 1 - i]]:
220
237
  invalid_b = True
221
- if invalid_a and invalid_b:
222
- raise ValueError(f"{msg_prefix} 'i' should exist such that 'x1_shape[axes[0][i]]' is equal to "
223
- f"'x2_shape[axes[1][i]]' or 'x2_shape[axes[1][len(axes[0])-1-i]]', but got "
224
- f"'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}, 'axes': {axes}.")
225
238
 
239
+ def _check(invalid_a, invalid_b, x1_shape, x2_shape, axes):
240
+ if invalid_a and invalid_b:
241
+ raise ValueError(f"{msg_prefix} 'i' should exist such that 'x1_shape[axes[0][i]]' is equal to "
242
+ f"'x2_shape[axes[1][i]]' or 'x2_shape[axes[1][len(axes[0])-1-i]]', but got "
243
+ f"'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}, 'axes': {axes}.")
226
244
 
227
- @constexpr
245
+ _check(invalid_a, invalid_b, x1_shape, x2_shape, axes)
246
+
247
+
248
+ @_primexpr
228
249
  def _calc_new_shape(shape, axes, position=0):
229
250
  """
230
251
  Calculate transpose and reshape parameters for input transformations,
231
252
  'position' refers to whether tensor is first or second in the op.
232
253
  """
233
254
  contraction_axes = tuple(i if i >= 0 else i + len(shape) for i in axes[position])
234
- prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
255
+ prod_contraction = 1
256
+ for i in contraction_axes:
257
+ prod_contraction *= shape[i]
235
258
  free_axes = tuple(i for i in range(len(shape)) if i not in contraction_axes)
236
- free_dims = tuple(shape[i] for i in free_axes)
237
- prod_free = int(np.prod(free_dims))
259
+ free_dims = tuple(shape[i] if shape[i] is not None else -1 for i in free_axes)
260
+ prod_free = 1
261
+ for free_dim in free_dims:
262
+ prod_free *= free_dim
238
263
 
239
264
  transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
240
265
  new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction)
@@ -294,10 +319,7 @@ def tensor_dot(x1, x2, axes):
294
319
  # input validity checks
295
320
  x1_shape = shape_op(x1)
296
321
  x2_shape = shape_op(x2)
297
- x1_type = F.dtype(x1)
298
- x2_type = F.dtype(x2)
299
322
  axes = _check_axes(axes, 'tensor_dot')
300
- _typecheck_input(x1_type, x2_type, 'tensor_dot')
301
323
  # input compatibility check & axes format update
302
324
  axes = _axes_int_check(x1_shape, x2_shape, axes, 'tensor_dot')
303
325
  _validate_axes(x1_shape, x2_shape, axes, 'tensor_dot')
@@ -314,7 +336,7 @@ def tensor_dot(x1, x2, axes):
314
336
  return final_result
315
337
 
316
338
 
317
- @constexpr
339
+ @_primexpr
318
340
  def _check_invalid_input(x1_shape, x2_shape, prim_name=None):
319
341
  msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
320
342
  if len(x1_shape) < 2 or len(x2_shape) < 2:
@@ -335,30 +357,30 @@ def _typecheck_input_dot(x1_type, x2_type, prim_name=None):
335
357
  f"x1_type: {x1_type} and x2_type: {x2_type}.")
336
358
 
337
359
 
338
- @constexpr
360
+ @_primexpr
339
361
  def _get_transpose_shape(x2_shape):
340
362
  x2_shape_range = tuple(range(len(x2_shape)))
341
363
  x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:]
342
364
  return x2_shape_transpose
343
365
 
344
366
 
345
- def dot(x1, x2):
367
+ def dot(input, other):
346
368
  """
347
369
  Computation a dot product between samples in two tensors.
348
370
 
349
371
  Args:
350
- x1 (Tensor): First tensor in Dot op with datatype float16 or float32,
372
+ input (Tensor): First tensor in Dot op with datatype float16 or float32,
351
373
  The rank must be greater than or equal to 2.
352
- x2 (Tensor): Second tensor in Dot op with datatype float16 or float32,
374
+ other (Tensor): Second tensor in Dot op with datatype float16 or float32,
353
375
  The rank must be greater than or equal to 2.
354
376
 
355
377
  Returns:
356
- Tensor, dot product of x1 and x2.
378
+ Tensor, dot product of input and other.
357
379
 
358
380
  Raises:
359
- TypeError: If type of x1 and x2 are not the same.
360
- TypeError: If dtype of x1 or x2 is not float16 or float32.
361
- ValueError: If rank of x1 or x2 less than 2.
381
+ TypeError: If type of input and other are not the same.
382
+ TypeError: If dtype of input or other is not float16 or float32.
383
+ ValueError: If rank of input or other less than 2.
362
384
 
363
385
  Supported Platforms:
364
386
  ``Ascend`` ``GPU`` ``CPU``
@@ -367,25 +389,25 @@ def dot(x1, x2):
367
389
  >>> import numpy as np
368
390
  >>> import mindspore
369
391
  >>> from mindspore import Tensor, ops
370
- >>> input_x1 = Tensor(np.ones(shape=[2, 3]), mindspore.float32)
371
- >>> input_x2 = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
372
- >>> output = ops.dot(input_x1, input_x2)
392
+ >>> input = Tensor(np.ones(shape=[2, 3]), mindspore.float32)
393
+ >>> other = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
394
+ >>> output = ops.dot(input, other)
373
395
  >>> print(output)
374
396
  [[[3. 3.]]
375
397
  [[3. 3.]]]
376
398
  >>> print(output.shape)
377
399
  (2, 1, 2)
378
- >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
379
- >>> input_x2 = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
380
- >>> output = ops.dot(input_x1, input_x2)
400
+ >>> input = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
401
+ >>> other = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
402
+ >>> output = ops.dot(input, other)
381
403
  >>> print(output)
382
404
  [[[[3. 3.]]
383
405
  [[3. 3.]]]]
384
406
  >>> print(output.shape)
385
407
  (1, 2, 1, 2)
386
- >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
387
- >>> input_x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
388
- >>> output = ops.dot(input_x1, input_x2)
408
+ >>> input = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
409
+ >>> other = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
410
+ >>> output = ops.dot(input, other)
389
411
  >>> print(output)
390
412
  [[[[3. 3.]
391
413
  [3. 3.]]
@@ -393,9 +415,9 @@ def dot(x1, x2):
393
415
  [3. 3.]]]]
394
416
  >>> print(output.shape)
395
417
  (1, 2, 2, 2)
396
- >>> input_x1 = Tensor(np.ones(shape=[3, 2, 3]), mindspore.float32)
397
- >>> input_x2 = Tensor(np.ones(shape=[2, 1, 3, 2]), mindspore.float32)
398
- >>> output = ops.dot(input_x1, input_x2)
418
+ >>> input = Tensor(np.ones(shape=[3, 2, 3]), mindspore.float32)
419
+ >>> other = Tensor(np.ones(shape=[2, 1, 3, 2]), mindspore.float32)
420
+ >>> output = ops.dot(input, other)
399
421
  >>> print(output)
400
422
  [[[[[3. 3.]]
401
423
  [[3. 3.]]]
@@ -416,34 +438,36 @@ def dot(x1, x2):
416
438
  reshape_op = P.Reshape()
417
439
  transpose_op = P.Transpose()
418
440
  matmul_op = P.MatMul(False, False)
419
- x1_shape = shape_op(x1)
420
- x2_shape = shape_op(x2)
421
- x1_type = F.dtype(x1)
422
- x2_type = F.dtype(x2)
423
- _typecheck_input_dot(x1_type, x2_type, 'dot')
424
- _check_invalid_input(x1_shape, x2_shape, 'dot')
425
-
426
- if len(x1_shape) > 2 or len(x2_shape) > 2:
427
- x2_shape_transpose = _get_transpose_shape(x2_shape)
428
- x2_transpose = transpose_op(x2, x2_shape_transpose)
429
- x1_reshape = reshape_op(x1, (-1, x1_shape[-1]))
430
- x2_reshape = reshape_op(x2_transpose, (x2_shape[-2], -1))
431
- mul_result = matmul_op(x1_reshape, x2_reshape)
432
- reshape_shape = x1_shape[:-1] + x2_shape[:-2] + x2_shape[-1:]
441
+ input_shape = shape_op(input)
442
+ other_shape = shape_op(other)
443
+ input_type = F.dtype(input)
444
+ other_type = F.dtype(other)
445
+ _typecheck_input_dot(input_type, other_type, 'dot')
446
+ _check_invalid_input(input_shape, other_shape, 'dot')
447
+
448
+ if len(input_shape) > 2 or len(other_shape) > 2:
449
+ other_shape_transpose = _get_transpose_shape(other_shape)
450
+ other_transpose = transpose_op(other, other_shape_transpose)
451
+ input_reshape = reshape_op(input, (-1, input_shape[-1]))
452
+ other_reshape = reshape_op(other_transpose, (other_shape[-2], -1))
453
+ mul_result = matmul_op(input_reshape, other_reshape)
454
+ reshape_shape = input_shape[:-1] + other_shape[:-2] + other_shape[-1:]
433
455
  reshape_shape = (-1,) + reshape_shape[1:]
434
456
  return reshape_op(mul_result, reshape_shape)
435
- return matmul_op(x1, x2)
457
+ return matmul_op(input, other)
436
458
 
437
459
 
438
- @constexpr
460
+ @_primexpr
439
461
  def _get_batch_size(x1_shape, x2_shape, prim_name=None):
440
462
  """
441
463
  Get batch sizes from two inputs
442
464
  """
443
- msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
444
- if len(x1_shape) < 2 or len(x2_shape) < 2:
445
- raise ValueError(f"{msg_prefix} inputs x1, x2 should have 'dimension >= 2', "
446
- f"but got 'len(x1_shape)': ({len(x1_shape)}) and 'len(x2_shape)': ({len(x2_shape)}).")
465
+ def _check():
466
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
467
+ if len(x1_shape) < 2 or len(x2_shape) < 2:
468
+ raise ValueError(f"{msg_prefix} inputs x1, x2 should have 'dimension >= 2', "
469
+ f"but got 'len(x1_shape)': ({len(x1_shape)}) and 'len(x2_shape)': ({len(x2_shape)}).")
470
+ _check()
447
471
  return x1_shape[0], x2_shape[0]
448
472
 
449
473
 
@@ -460,12 +484,33 @@ def _typecheck_input_batch_dot(x1_type, x2_type, prim_name=None):
460
484
  f"x2_type: {x2_type}.")
461
485
 
462
486
 
463
- @constexpr
487
+ @_primexpr
464
488
  def _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name=None):
465
489
  """
466
490
  Check whether axes are valid and cast axes from tuple to list
467
491
  """
468
492
  msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
493
+
494
+ def _check_1(axes):
495
+ if 0 in axes:
496
+ raise ValueError(f"{msg_prefix} 'axes' cannot contain 0, but got axes: {axes}.")
497
+ if len(axes) != 2:
498
+ raise ValueError(f"{msg_prefix} length of 'axes' must be equal to 2, but got {len(axes)}.")
499
+
500
+ def _check_2(axes, x1_shape, x2_shape):
501
+ if axes[0] > len(x1_shape) or axes[1] > len(x2_shape):
502
+ raise ValueError(f"{msg_prefix} axes[0] must be less than or equal to len(x1_shape), "
503
+ f"and axes[1] must be less than or equal to len(x2_shape)."
504
+ f"But got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")
505
+
506
+ def _check_3(axes, x1_shape, x2_shape):
507
+ if axes == 0:
508
+ raise ValueError(f"{msg_prefix} 'axes' should not be equal to 0, but got {axes}.")
509
+
510
+ if axes > len(x1_shape) or axes > len(x2_shape):
511
+ raise ValueError(f"{msg_prefix} 'axes' cannot be greater than the length of 'x1_shape' and 'x2_shape', "
512
+ f"but got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")
513
+
469
514
  if axes is None:
470
515
  if len(x2_shape) == 2:
471
516
  axes = [len(x1_shape) - 1, len(x2_shape) - 1]
@@ -473,10 +518,7 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name=None):
473
518
  axes = [len(x1_shape) - 1, len(x2_shape) - 2]
474
519
 
475
520
  if isinstance(axes, (list, tuple)):
476
- if 0 in axes:
477
- raise ValueError(f"{msg_prefix} 'axes' cannot contain 0, but got axes: {axes}.")
478
- if len(axes) != 2:
479
- raise ValueError(f"{msg_prefix} length of 'axes' must be equal to 2, but got {len(axes)}.")
521
+ _check_1(axes)
480
522
  if isinstance(axes, tuple):
481
523
  axes = list(axes)
482
524
  validator.check_value_type('axes[0]', axes[0], [int], 'batch_dot')
@@ -488,19 +530,12 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name=None):
488
530
  axes[1] += len(x2_shape)
489
531
  validator.check_non_negative_int(axes[0], 'reversed axes[0]', 'batch_dot')
490
532
  validator.check_non_negative_int(axes[1], 'reversed axes[1]', 'batch_dot')
491
- if axes[0] > len(x1_shape) or axes[1] > len(x2_shape):
492
- raise ValueError(f"{msg_prefix} axes[0] must be less than or equal to len(x1_shape), "
493
- f"and axes[1] must be less than or equal to len(x2_shape)."
494
- f"But got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")
533
+ _check_2(axes, x1_shape, x2_shape)
495
534
  elif isinstance(axes, int):
496
- if axes == 0:
497
- raise ValueError(f"{msg_prefix} 'axes' should not be equal to 0, but got {axes}.")
535
+ _check_3(axes, x1_shape, x2_shape)
498
536
  if axes < 0:
499
537
  axes = [axes + len(x1_shape), axes + len(x2_shape)]
500
538
  validator.check_non_negative_int(axes[0], 'reversed axes', 'batch_dot')
501
- elif axes > len(x1_shape) or axes > len(x2_shape):
502
- raise ValueError(f"{msg_prefix} 'axes' cannot be greater than the length of 'x1_shape' and 'x2_shape', "
503
- f"but got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")
504
539
  else:
505
540
  axes = [axes, axes]
506
541
  else:
@@ -509,7 +544,7 @@ def _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name=None):
509
544
  return axes
510
545
 
511
546
 
512
- @constexpr
547
+ @_primexpr
513
548
  def _calc_new_shape_batchdot(shape, axes, position=0):
514
549
  """
515
550
  Calculate transpose and reshape parameters for input transformations,
@@ -517,10 +552,14 @@ def _calc_new_shape_batchdot(shape, axes, position=0):
517
552
  """
518
553
  axis = axes[position]
519
554
  contraction_axes = tuple([axis])
520
- prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
555
+ prod_contraction = 1
556
+ for i in contraction_axes:
557
+ prod_contraction *= shape[i]
521
558
  free_axes = tuple(i for i in range(1, len(shape)) if i not in contraction_axes)
522
559
  free_dims = tuple(shape[i] for i in free_axes)
523
- prod_free = int(np.prod(free_dims))
560
+ prod_free = 1
561
+ for free_dim in free_dims:
562
+ prod_free *= free_dim
524
563
 
525
564
  transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
526
565
  transpose_perm = tuple([0]) + transpose_perm
@@ -529,7 +568,7 @@ def _calc_new_shape_batchdot(shape, axes, position=0):
529
568
  return new_shape, transpose_perm, free_dims
530
569
 
531
570
 
532
- @constexpr
571
+ @_primexpr
533
572
  def _check_batch_size(x1_batch_size, x2_batch_size, prim_name=None):
534
573
  """
535
574
  Check whether batch size of two inputs are the same
@@ -540,7 +579,7 @@ def _check_batch_size(x1_batch_size, x2_batch_size, prim_name=None):
540
579
  f"'x1_batch_size': {x1_batch_size} and 'x2_batch_size': {x2_batch_size}.")
541
580
 
542
581
 
543
- @constexpr
582
+ @_primexpr
544
583
  def _get_output_shape(batch_size, x1_ret, x2_ret):
545
584
  """
546
585
  Compute output shape for batch dot
@@ -732,6 +771,49 @@ def matmul(x1, x2, dtype=None):
732
771
  return res
733
772
 
734
773
 
774
+ def mm(input, mat2):
775
+ r"""
776
+ Returns the matrix product of two arrays.
777
+ If `input` is a :math:`(n \times m)` Tensor, `mat2` is a
778
+ :math:`(m \times p)` Tensor, `out` will be a :math:`(n \times p)` Tensor.
779
+
780
+ Note:
781
+ This function cannot support broadcasting.
782
+ Refer to :func:`mindspore.ops.matmul` instead if you need a broadcastable function.
783
+
784
+ Args:
785
+ input (Tensor): The first matrix of matrix multiplication.
786
+ The last dimension of `input` must be the same size as the first dimension of `mat2`.
787
+ mat2 (Tensor): The second matrix of matrix multiplication.
788
+ The last dimension of `input` must be the same size as the first dimension of `mat2`.
789
+
790
+ Returns:
791
+ Tensor or scalar, the matrix product of the inputs.
792
+
793
+ Raises:
794
+ ValueError: If the last dimension of `input` is not the same size as the
795
+ second-to-last dimension of `mat2`.
796
+ ValueError: If `input` or `mat2` is not a matrix.
797
+
798
+ Supported Platforms:
799
+ ``Ascend`` ``GPU`` ``CPU``
800
+
801
+ Examples:
802
+ >>> import mindspore as ms
803
+ >>> import mindspore.ops as ops
804
+ >>> import numpy as np
805
+ >>> x1 = ms.Tensor(np.random.rand(2, 3))
806
+ >>> x2 = ms.Tensor(np.random.rand(3, 4))
807
+ >>> out = ops.mm(x1, x2)
808
+ >>> print(out.shape)
809
+ (2, 4)
810
+ """
811
+ if input.ndim != 2 or mat2.ndim != 2:
812
+ raise ValueError(f"For mm, the input tensor must be a matrix, "
813
+ f"but got mat1.ndim:{input.ndim}, mat2.ndim:{mat2.ndim}")
814
+ return matmul(input, mat2)
815
+
816
+
735
817
  def cummin(x, axis):
736
818
  r"""
737
819
  Returns a tuple (values,indices) where 'values' is the cumulative minimum value of input Tensor `x`
@@ -770,55 +852,3 @@ def cummin(x, axis):
770
852
  [0 1 1 1 4 4]
771
853
  """
772
854
  return cummin_(x, axis)
773
-
774
-
775
- def resize_nearest_neighbor(input_x, size, align_corners=False):
776
- r"""
777
- Resizes the input tensor by using the nearest neighbor algorithm.
778
-
779
- Resizes the input tensor to a given size by using the nearest neighbor algorithm. The nearest
780
- neighbor algorithm selects the value of the nearest point and does not consider the
781
- values of neighboring points at all, yielding a piecewise-constant interpolant.
782
-
783
- Args:
784
- input_x (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`.
785
- size (Union[Tensor, tuple, list]): The target size. The dimension of size must be 2.
786
- align_corners (bool): Whether the centers of the 4 corner pixels of the input
787
- and output tensors are aligned. Default: False.
788
-
789
- Returns:
790
- Tensor, the shape of the output tensor is :math:`(N, C, NEW\_H, NEW\_W)`.
791
- The data type is the same as the `input_x`.
792
-
793
- Raises:
794
- TypeError: If `input_x` is not a Tensor.
795
- TypeError: If `size` is neither tuple nor list.
796
- TypeError: If `align_corners` is not a bool.
797
- ValueError: If length of `size` is not equal to 2.
798
-
799
- Supported Platforms:
800
- ``Ascend`` ``GPU`` ``CPU``
801
-
802
- Examples:
803
- >>> import numpy as np
804
- >>> import mindspore
805
- >>> from mindspore import Tensor, ops
806
- >>> input_tensor = Tensor(np.array([[[[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]]]), mindspore.float32)
807
- >>> size = (2, 2)
808
- >>> output = ops.ResizeNearestNeighbor(size=size)(input_tensor)
809
- >>> print(output)
810
- [[[[-0.1 0.3]
811
- [ 0.4 0.5]]]]
812
- """
813
- if size is None:
814
- raise ValueError(f'For ResizeNearestNeighbor, size could not be None.')
815
- if isinstance(size, (tuple, list)):
816
- resize = P.ResizeNearestNeighbor(size, align_corners)
817
- return resize(input_x)
818
- if is_const(size):
819
- size = size.asnumpy()
820
- resize = P.ResizeNearestNeighbor(size, align_corners)
821
- return resize(input_x)
822
-
823
- resize = DynamicResizeNearestNeighbor(align_corners)
824
- return resize(input_x, size)