mindspore 1.10.0__cp37-cp37m-win_amd64.whl → 2.0.0rc1__cp37-cp37m-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (966):
  1. mindspore/.commit_id +1 -1
  2. mindspore/ConcurrencyCheck.dll +0 -0
  3. mindspore/CppBuildInsights.dll +0 -0
  4. mindspore/CppCoreCheck.dll +0 -0
  5. mindspore/EnumIndex.dll +0 -0
  6. mindspore/EspXEngine.dll +0 -0
  7. mindspore/HResultCheck.dll +0 -0
  8. mindspore/KernelTraceControl.dll +0 -0
  9. mindspore/LocalESPC.dll +0 -0
  10. mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
  11. mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
  12. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  13. mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
  14. mindspore/Newtonsoft.Json.dll +0 -0
  15. mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
  16. mindspore/VariantClear.dll +0 -0
  17. mindspore/__init__.py +9 -4
  18. mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
  19. mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
  20. mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
  21. mindspore/_check_jit_forbidden_api.py +102 -0
  22. mindspore/_checkparam.py +1066 -1001
  23. mindspore/_extends/builtin_operations.py +32 -4
  24. mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
  25. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
  26. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
  27. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
  28. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
  29. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
  30. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  31. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
  32. mindspore/_extends/parse/__init__.py +5 -3
  33. mindspore/_extends/parse/namespace.py +17 -2
  34. mindspore/_extends/parse/parser.py +193 -34
  35. mindspore/_extends/parse/resources.py +7 -8
  36. mindspore/_extends/parse/standard_method.py +1780 -435
  37. mindspore/_extends/parse/trope.py +3 -1
  38. mindspore/amp.py +53 -58
  39. mindspore/atlprov.dll +0 -0
  40. mindspore/boost/adasum.py +3 -2
  41. mindspore/boost/boost.py +2 -2
  42. mindspore/boost/boost_cell_wrapper.py +46 -26
  43. mindspore/boost/dim_reduce.py +6 -5
  44. mindspore/boost/grad_accumulation.py +2 -1
  45. mindspore/boost/group_loss_scale_manager.py +1 -1
  46. mindspore/c1.dll +0 -0
  47. mindspore/c1xx.dll +0 -0
  48. mindspore/c2.dll +0 -0
  49. mindspore/cfgpersist.dll +0 -0
  50. mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
  51. mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
  52. mindspore/common/__init__.py +11 -10
  53. mindspore/common/_decorator.py +2 -0
  54. mindspore/common/_register_for_adapter.py +55 -0
  55. mindspore/common/_stub_tensor.py +201 -0
  56. mindspore/common/_utils.py +57 -0
  57. mindspore/common/api.py +582 -297
  58. mindspore/common/dtype.py +66 -18
  59. mindspore/common/dump.py +2 -2
  60. mindspore/common/initializer.py +38 -1
  61. mindspore/common/jit_config.py +25 -13
  62. mindspore/common/mutable.py +53 -24
  63. mindspore/common/parameter.py +60 -37
  64. mindspore/common/seed.py +8 -24
  65. mindspore/common/sparse_tensor.py +927 -0
  66. mindspore/common/tensor.py +1627 -3900
  67. mindspore/communication/__init__.py +10 -5
  68. mindspore/communication/_comm_helper.py +78 -214
  69. mindspore/communication/_hccl_management.py +2 -1
  70. mindspore/communication/management.py +136 -47
  71. mindspore/config/op_info.config +501 -1008
  72. mindspore/context.py +291 -56
  73. mindspore/d3dcompiler_47.dll +0 -0
  74. mindspore/dataset/__init__.py +12 -8
  75. mindspore/dataset/audio/__init__.py +9 -9
  76. mindspore/dataset/audio/transforms.py +1090 -228
  77. mindspore/dataset/audio/utils.py +87 -39
  78. mindspore/dataset/audio/validators.py +223 -1
  79. mindspore/dataset/callback/ds_callback.py +17 -15
  80. mindspore/dataset/core/config.py +246 -17
  81. mindspore/dataset/core/py_util_helpers.py +4 -3
  82. mindspore/dataset/core/validator_helpers.py +10 -10
  83. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  84. mindspore/dataset/debug/debug_hook.py +65 -0
  85. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  86. mindspore/dataset/engine/__init__.py +7 -3
  87. mindspore/dataset/engine/cache_client.py +9 -9
  88. mindspore/dataset/engine/datasets.py +648 -477
  89. mindspore/dataset/engine/datasets_audio.py +165 -167
  90. mindspore/dataset/engine/datasets_standard_format.py +93 -67
  91. mindspore/dataset/engine/datasets_text.py +492 -342
  92. mindspore/dataset/engine/datasets_user_defined.py +85 -50
  93. mindspore/dataset/engine/datasets_vision.py +1224 -699
  94. mindspore/dataset/engine/graphdata.py +134 -69
  95. mindspore/dataset/engine/iterators.py +50 -9
  96. mindspore/dataset/engine/offload.py +52 -31
  97. mindspore/dataset/engine/samplers.py +27 -24
  98. mindspore/dataset/engine/serializer_deserializer.py +14 -15
  99. mindspore/dataset/engine/validators.py +213 -52
  100. mindspore/dataset/text/__init__.py +10 -8
  101. mindspore/dataset/text/transforms.py +152 -57
  102. mindspore/dataset/text/utils.py +98 -49
  103. mindspore/dataset/text/validators.py +25 -0
  104. mindspore/dataset/transforms/__init__.py +4 -2
  105. mindspore/dataset/transforms/c_transforms.py +11 -13
  106. mindspore/dataset/transforms/py_transforms.py +2 -2
  107. mindspore/dataset/transforms/py_transforms_util.py +10 -0
  108. mindspore/dataset/transforms/transforms.py +13 -15
  109. mindspore/dataset/transforms/validators.py +7 -7
  110. mindspore/dataset/utils/__init__.py +2 -1
  111. mindspore/dataset/utils/browse_dataset.py +13 -13
  112. mindspore/dataset/utils/line_reader.py +121 -0
  113. mindspore/dataset/vision/__init__.py +8 -7
  114. mindspore/dataset/vision/c_transforms.py +125 -126
  115. mindspore/dataset/vision/py_transforms.py +37 -37
  116. mindspore/dataset/vision/py_transforms_util.py +23 -20
  117. mindspore/dataset/vision/transforms.py +316 -315
  118. mindspore/dataset/vision/utils.py +313 -17
  119. mindspore/dataset/vision/validators.py +6 -6
  120. mindspore/default_config.py +0 -1
  121. mindspore/dpcmi.dll +0 -0
  122. mindspore/{compression → experimental}/__init__.py +6 -5
  123. mindspore/experimental/map_parameter.py +275 -0
  124. mindspore/include/OWNERS +0 -1
  125. mindspore/include/api/callback/callback.h +9 -13
  126. mindspore/include/api/callback/ckpt_saver.h +2 -2
  127. mindspore/include/api/callback/loss_monitor.h +2 -2
  128. mindspore/include/api/callback/lr_scheduler.h +5 -5
  129. mindspore/include/api/callback/time_monitor.h +2 -2
  130. mindspore/include/api/callback/train_accuracy.h +4 -6
  131. mindspore/include/api/cfg.h +19 -6
  132. mindspore/include/api/context.h +70 -9
  133. mindspore/include/api/delegate.h +8 -1
  134. mindspore/include/api/dual_abi_helper.h +8 -24
  135. mindspore/include/api/metrics/accuracy.h +2 -2
  136. mindspore/include/api/metrics/metrics.h +4 -3
  137. mindspore/include/api/model.h +9 -4
  138. mindspore/include/api/model_group.h +68 -0
  139. mindspore/include/api/model_parallel_runner.h +17 -17
  140. mindspore/include/api/net.h +12 -11
  141. mindspore/include/api/serialization.h +20 -4
  142. mindspore/include/api/status.h +7 -1
  143. mindspore/include/api/types.h +25 -21
  144. mindspore/include/api/visible.h +4 -0
  145. mindspore/include/c_api/model_c.h +5 -0
  146. mindspore/include/c_api/status_c.h +1 -1
  147. mindspore/include/dataset/config.h +1 -1
  148. mindspore/include/dataset/constants.h +14 -0
  149. mindspore/include/dataset/text.h +59 -0
  150. mindspore/include/dataset/vision.h +56 -117
  151. mindspore/include/dataset/vision_lite.h +102 -0
  152. mindspore/jpeg62.dll +0 -0
  153. mindspore/log.py +28 -28
  154. mindspore/mindrecord/common/exceptions.py +2 -4
  155. mindspore/mindrecord/filereader.py +19 -1
  156. mindspore/mindrecord/filewriter.py +250 -88
  157. mindspore/mindrecord/mindpage.py +13 -13
  158. mindspore/mindrecord/shardheader.py +15 -15
  159. mindspore/mindrecord/shardreader.py +9 -0
  160. mindspore/mindrecord/shardwriter.py +29 -29
  161. mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
  162. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
  163. mindspore/mindrecord/tools/csv_to_mr.py +4 -4
  164. mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
  165. mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
  166. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  167. mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
  168. mindspore/mindspore_common.dll +0 -0
  169. mindspore/mindspore_core.dll +0 -0
  170. mindspore/mindspore_glog.dll +0 -0
  171. mindspore/mindspore_shared_lib.dll +0 -0
  172. mindspore/msobj140.dll +0 -0
  173. mindspore/mspdb140.dll +0 -0
  174. mindspore/mspdbcore.dll +0 -0
  175. mindspore/mspdbst.dll +0 -0
  176. mindspore/mspft140.dll +0 -0
  177. mindspore/msvcdis140.dll +0 -0
  178. mindspore/msvcp140_1.dll +0 -0
  179. mindspore/msvcp140_2.dll +0 -0
  180. mindspore/msvcp140_atomic_wait.dll +0 -0
  181. mindspore/msvcp140_codecvt_ids.dll +0 -0
  182. mindspore/nn/__init__.py +1 -5
  183. mindspore/nn/cell.py +297 -234
  184. mindspore/nn/dynamic_lr.py +1 -1
  185. mindspore/nn/grad/cell_grad.py +17 -42
  186. mindspore/nn/layer/__init__.py +7 -4
  187. mindspore/nn/layer/activation.py +131 -88
  188. mindspore/nn/layer/basic.py +313 -613
  189. mindspore/nn/layer/channel_shuffle.py +103 -0
  190. mindspore/nn/layer/combined.py +1 -1
  191. mindspore/nn/layer/container.py +52 -6
  192. mindspore/nn/layer/conv.py +112 -43
  193. mindspore/nn/layer/dense.py +10 -9
  194. mindspore/nn/layer/embedding.py +36 -34
  195. mindspore/nn/layer/image.py +123 -27
  196. mindspore/nn/layer/math.py +108 -107
  197. mindspore/nn/layer/normalization.py +212 -366
  198. mindspore/nn/layer/padding.py +370 -42
  199. mindspore/nn/layer/pooling.py +1443 -219
  200. mindspore/nn/layer/rnn_cells.py +11 -16
  201. mindspore/nn/layer/rnns.py +38 -39
  202. mindspore/nn/layer/thor_layer.py +24 -25
  203. mindspore/nn/layer/timedistributed.py +5 -5
  204. mindspore/nn/layer/transformer.py +701 -0
  205. mindspore/nn/learning_rate_schedule.py +8 -8
  206. mindspore/nn/loss/__init__.py +9 -6
  207. mindspore/nn/loss/loss.py +678 -142
  208. mindspore/nn/metrics.py +53 -0
  209. mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
  210. mindspore/nn/optim/ada_grad.py +8 -8
  211. mindspore/nn/optim/adadelta.py +2 -3
  212. mindspore/nn/optim/adafactor.py +18 -14
  213. mindspore/nn/optim/adam.py +429 -87
  214. mindspore/nn/optim/adamax.py +5 -6
  215. mindspore/nn/optim/adasum.py +10 -8
  216. mindspore/nn/optim/asgd.py +7 -7
  217. mindspore/nn/optim/ftrl.py +81 -11
  218. mindspore/nn/optim/lamb.py +7 -8
  219. mindspore/nn/optim/lars.py +4 -4
  220. mindspore/nn/optim/lazyadam.py +82 -7
  221. mindspore/nn/optim/momentum.py +8 -7
  222. mindspore/nn/optim/optimizer.py +19 -10
  223. mindspore/nn/optim/proximal_ada_grad.py +6 -5
  224. mindspore/nn/optim/rmsprop.py +3 -3
  225. mindspore/nn/optim/rprop.py +20 -16
  226. mindspore/nn/optim/sgd.py +21 -15
  227. mindspore/nn/optim/thor.py +23 -21
  228. mindspore/nn/probability/__init__.py +0 -2
  229. mindspore/nn/probability/bijector/bijector.py +7 -6
  230. mindspore/nn/probability/bijector/invert.py +4 -2
  231. mindspore/nn/probability/bijector/softplus.py +2 -2
  232. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  233. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  234. mindspore/nn/probability/distribution/__init__.py +6 -0
  235. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
  236. mindspore/nn/probability/distribution/_utils/utils.py +11 -17
  237. mindspore/nn/probability/distribution/bernoulli.py +6 -6
  238. mindspore/nn/probability/distribution/beta.py +1 -1
  239. mindspore/nn/probability/distribution/categorical.py +9 -9
  240. mindspore/nn/probability/distribution/cauchy.py +8 -8
  241. mindspore/nn/probability/distribution/distribution.py +12 -6
  242. mindspore/nn/probability/distribution/exponential.py +5 -5
  243. mindspore/nn/probability/distribution/gamma.py +3 -3
  244. mindspore/nn/probability/distribution/geometric.py +6 -5
  245. mindspore/nn/probability/distribution/gumbel.py +5 -5
  246. mindspore/nn/probability/distribution/half_normal.py +133 -0
  247. mindspore/nn/probability/distribution/laplace.py +128 -0
  248. mindspore/nn/probability/distribution/log_normal.py +0 -1
  249. mindspore/nn/probability/distribution/logistic.py +4 -5
  250. mindspore/nn/probability/distribution/normal.py +11 -15
  251. mindspore/nn/probability/distribution/poisson.py +6 -2
  252. mindspore/nn/probability/distribution/student_t.py +150 -0
  253. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  254. mindspore/nn/probability/distribution/uniform.py +5 -5
  255. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  256. mindspore/nn/reinforcement/tensor_array.py +2 -2
  257. mindspore/nn/sparse/sparse.py +8 -1
  258. mindspore/nn/wrap/cell_wrapper.py +55 -27
  259. mindspore/nn/wrap/grad_reducer.py +20 -11
  260. mindspore/nn/wrap/loss_scale.py +47 -30
  261. mindspore/numpy/array_creations.py +33 -22
  262. mindspore/numpy/array_ops.py +46 -42
  263. mindspore/numpy/logic_ops.py +6 -27
  264. mindspore/numpy/math_ops.py +26 -19
  265. mindspore/numpy/utils.py +1 -8
  266. mindspore/numpy/utils_const.py +112 -62
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/opencv_imgproc452.dll +0 -0
  270. mindspore/ops/__init__.py +6 -3
  271. mindspore/ops/_constants.py +0 -6
  272. mindspore/ops/_grad/__init__.py +2 -1
  273. mindspore/ops/_grad/grad_array_ops.py +209 -152
  274. mindspore/ops/_grad/grad_base.py +55 -17
  275. mindspore/ops/_grad/grad_clip_ops.py +11 -3
  276. mindspore/ops/_grad/grad_comm_ops.py +58 -47
  277. mindspore/ops/_grad/grad_implementations.py +21 -61
  278. mindspore/ops/_grad/grad_inner_ops.py +48 -6
  279. mindspore/ops/_grad/grad_math_ops.py +306 -161
  280. mindspore/ops/_grad/grad_nn_ops.py +192 -181
  281. mindspore/ops/_grad/grad_other_ops.py +1 -1
  282. mindspore/ops/_grad/grad_quant_ops.py +5 -5
  283. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  284. mindspore/ops/_grad/grad_sparse.py +15 -9
  285. mindspore/ops/_grad_experimental/__init__.py +1 -0
  286. mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
  287. mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
  288. mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
  289. mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
  290. mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
  291. mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
  292. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  293. mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
  294. mindspore/ops/_op_impl/__init__.py +3 -3
  295. mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
  296. mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
  297. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  298. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
  299. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  300. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  301. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
  302. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  303. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  304. mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
  305. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  306. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
  307. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  308. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  309. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  310. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  311. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  312. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  313. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  314. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  315. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  316. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  317. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  318. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  319. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  320. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  321. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  322. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  323. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  325. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
  326. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
  327. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  328. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  329. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  330. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  331. mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
  332. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  333. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  334. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  335. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  336. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  337. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  338. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  339. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  340. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  341. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  342. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  343. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  344. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  345. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  346. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  347. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  348. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  349. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  350. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  351. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  352. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
  353. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  354. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
  355. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  356. mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
  357. mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
  358. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  359. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  360. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  361. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  362. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  363. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  364. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  365. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
  366. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  367. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  368. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  369. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  370. mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
  371. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  372. mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
  373. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  374. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  375. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  376. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  377. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  378. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  379. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  380. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  381. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  382. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  383. mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
  384. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  385. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  386. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  387. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  388. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  389. mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
  390. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  391. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  392. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  393. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  394. mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
  395. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  396. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  397. mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
  398. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  399. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  400. mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
  401. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  402. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  403. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  404. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  405. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  406. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  407. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  408. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  409. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  410. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  411. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  412. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  413. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  414. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  415. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  416. mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
  417. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  418. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  419. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  420. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  421. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  422. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  423. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  424. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  425. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  426. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  427. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  428. mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
  429. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  430. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  431. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  432. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  433. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  434. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  435. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  436. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  437. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  438. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  439. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  440. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  441. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  442. mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
  443. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  444. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  445. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  446. mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
  447. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
  448. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
  449. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  450. mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
  451. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  452. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  453. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  454. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  455. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  456. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  457. mindspore/ops/_op_impl/cpu/__init__.py +1 -2
  458. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  459. mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
  460. mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
  461. mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
  462. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  463. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  464. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  465. mindspore/ops/_op_impl/tbe/__init__.py +27 -608
  466. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
  467. mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
  468. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  469. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  470. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  471. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
  472. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  473. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  474. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
  475. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
  476. mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
  477. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  478. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
  479. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  480. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  481. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  482. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  483. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  484. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
  485. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
  486. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  487. mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
  488. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
  489. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  490. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  491. mindspore/ops/_op_impl/tbe/greater.py +2 -0
  492. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  493. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
  494. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  495. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  496. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
  497. mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
  498. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
  499. mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
  500. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
  501. mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
  502. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
  503. mindspore/ops/_op_impl/tbe/slice.py +26 -15
  504. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  505. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  506. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
  507. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  508. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
  509. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
  510. mindspore/ops/_primitive_cache.py +3 -2
  511. mindspore/ops/_register_for_op.py +11 -0
  512. mindspore/ops/_utils/__init__.py +1 -1
  513. mindspore/ops/_utils/utils.py +20 -41
  514. mindspore/ops/_vmap/__init__.py +2 -2
  515. mindspore/ops/_vmap/vmap_array_ops.py +170 -78
  516. mindspore/ops/_vmap/vmap_base.py +24 -10
  517. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  518. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  519. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
  520. mindspore/ops/_vmap/vmap_image_ops.py +52 -0
  521. mindspore/ops/_vmap/vmap_math_ops.py +77 -6
  522. mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
  523. mindspore/ops/_vmap/vmap_other_ops.py +3 -1
  524. mindspore/ops/_vmap/vmap_random_ops.py +55 -3
  525. mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
  526. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  527. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  528. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
  529. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
  530. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
  531. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
  532. mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
  533. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  534. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  535. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
  537. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
  539. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  541. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
  542. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
  543. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  546. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  547. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  548. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  549. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  550. mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
  551. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  552. mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
  553. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
  554. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  555. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
  556. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
  559. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
  560. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  561. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
  565. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  566. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  567. mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
  568. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
  569. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  570. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
  571. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
  572. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
  573. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
  574. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  575. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
  576. mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
  577. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  578. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  579. mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
  580. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
  582. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
  583. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
  584. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  585. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  586. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  587. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  588. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  589. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
  590. mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
  591. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
  592. mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
  593. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  594. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
  595. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
  596. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
  597. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  598. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  599. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  600. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  601. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  602. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  603. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  604. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  605. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  606. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  607. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  608. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
  609. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
  610. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
  611. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
  612. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  613. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  614. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  615. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  616. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  617. mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
  618. mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
  619. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  620. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  621. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
  622. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
  623. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
  624. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
  625. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  626. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
  627. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
  628. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
  629. mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
  630. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  631. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  632. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
  633. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
  634. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
  635. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  636. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  637. mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
  638. mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
  639. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  640. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  642. mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
  643. mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
  644. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  645. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  646. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  647. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  648. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  649. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
  650. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
  651. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  652. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  653. mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
  654. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
  655. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
  656. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
  657. mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
  658. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  659. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  660. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
  661. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
  662. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
  663. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  664. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  665. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
  666. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
  667. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
  668. mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
  669. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
  670. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  671. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  672. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
  673. mindspore/ops/bprop_mindir/__init__.py +1 -4
  674. mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
  675. mindspore/ops/composite/__init__.py +12 -13
  676. mindspore/ops/composite/base.py +261 -254
  677. mindspore/ops/composite/env_ops.py +41 -0
  678. mindspore/ops/composite/math_ops.py +197 -156
  679. mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
  680. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
  681. mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
  682. mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
  683. mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
  684. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
  685. mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
  686. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  687. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  688. mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
  689. mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
  690. mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
  691. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
  692. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  693. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
  694. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
  695. mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
  696. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  697. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  698. mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
  699. mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
  700. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
  701. mindspore/ops/function/__init__.py +323 -8
  702. mindspore/ops/function/array_func.py +3511 -780
  703. mindspore/ops/function/clip_func.py +329 -0
  704. mindspore/ops/function/debug_func.py +6 -6
  705. mindspore/ops/function/grad/__init__.py +5 -1
  706. mindspore/ops/function/grad/grad_func.py +736 -65
  707. mindspore/ops/function/image_func.py +270 -0
  708. mindspore/ops/function/linalg_func.py +268 -8
  709. mindspore/ops/function/math_func.py +8032 -3164
  710. mindspore/ops/function/nn_func.py +5619 -1855
  711. mindspore/ops/function/other_func.py +115 -0
  712. mindspore/ops/function/parameter_func.py +11 -10
  713. mindspore/ops/function/random_func.py +939 -77
  714. mindspore/ops/function/sparse_func.py +249 -84
  715. mindspore/ops/function/sparse_unary_func.py +2303 -0
  716. mindspore/ops/function/spectral_func.py +146 -0
  717. mindspore/ops/function/vmap_func.py +114 -0
  718. mindspore/ops/functional.py +182 -254
  719. mindspore/ops/op_info_register.py +79 -34
  720. mindspore/ops/operations/__init__.py +210 -118
  721. mindspore/ops/operations/_csr_ops.py +7 -7
  722. mindspore/ops/operations/_embedding_cache_ops.py +25 -15
  723. mindspore/ops/operations/_grad_ops.py +447 -322
  724. mindspore/ops/operations/_inner_ops.py +547 -176
  725. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  726. mindspore/ops/operations/_ms_kernel.py +29 -27
  727. mindspore/ops/operations/_ocr_ops.py +11 -11
  728. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  729. mindspore/ops/operations/_quant_ops.py +186 -101
  730. mindspore/ops/operations/_rl_inner_ops.py +122 -61
  731. mindspore/ops/operations/_scalar_ops.py +466 -0
  732. mindspore/ops/operations/_sequence_ops.py +1047 -0
  733. mindspore/ops/operations/_tensor_array.py +10 -11
  734. mindspore/ops/operations/_thor_ops.py +4 -4
  735. mindspore/ops/operations/array_ops.py +1428 -1226
  736. mindspore/ops/operations/comm_ops.py +180 -117
  737. mindspore/ops/operations/control_ops.py +4 -2
  738. mindspore/ops/operations/custom_ops.py +185 -98
  739. mindspore/ops/operations/debug_ops.py +92 -54
  740. mindspore/ops/operations/image_ops.py +406 -211
  741. mindspore/ops/operations/inner_ops.py +42 -53
  742. mindspore/ops/operations/linalg_ops.py +32 -29
  743. mindspore/ops/operations/math_ops.py +2076 -897
  744. mindspore/ops/operations/nn_ops.py +1282 -1252
  745. mindspore/ops/operations/other_ops.py +124 -278
  746. mindspore/ops/operations/random_ops.py +345 -178
  747. mindspore/ops/operations/rl_ops.py +8 -9
  748. mindspore/ops/operations/sparse_ops.py +502 -157
  749. mindspore/ops/operations/spectral_ops.py +107 -0
  750. mindspore/ops/primitive.py +192 -15
  751. mindspore/ops/vm_impl_registry.py +23 -2
  752. mindspore/parallel/__init__.py +6 -1
  753. mindspore/parallel/_auto_parallel_context.py +199 -92
  754. mindspore/parallel/_cell_wrapper.py +4 -2
  755. mindspore/parallel/_cost_model_context.py +3 -0
  756. mindspore/parallel/_dp_allreduce_fusion.py +2 -1
  757. mindspore/parallel/_offload_context.py +185 -0
  758. mindspore/parallel/_parallel_serialization.py +167 -28
  759. mindspore/parallel/_ps_context.py +9 -5
  760. mindspore/parallel/_recovery_context.py +1 -1
  761. mindspore/parallel/_tensor.py +9 -1
  762. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  763. mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
  764. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  765. mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
  766. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  767. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
  768. mindspore/parallel/_utils.py +47 -7
  769. mindspore/parallel/algo_parameter_config.py +5 -1
  770. mindspore/parallel/checkpoint_transform.py +329 -0
  771. mindspore/parallel/shard.py +229 -0
  772. mindspore/perf_msvcbuildinsights.dll +0 -0
  773. mindspore/pgodb140.dll +0 -0
  774. mindspore/pgort140.dll +0 -0
  775. mindspore/profiler/__init__.py +2 -1
  776. mindspore/profiler/common/util.py +4 -3
  777. mindspore/profiler/common/validator/validate_path.py +2 -2
  778. mindspore/profiler/envprofiling.py +249 -0
  779. mindspore/profiler/parser/aicpu_data_parser.py +38 -39
  780. mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
  781. mindspore/profiler/parser/base_timeline_generator.py +471 -0
  782. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
  783. mindspore/profiler/parser/framework_parser.py +42 -16
  784. mindspore/profiler/parser/hccl_parser.py +158 -158
  785. mindspore/profiler/parser/hwts_log_parser.py +7 -6
  786. mindspore/profiler/parser/integrator.py +18 -1579
  787. mindspore/profiler/parser/minddata_analyzer.py +8 -8
  788. mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
  789. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  790. mindspore/profiler/parser/optime_parser.py +17 -18
  791. mindspore/profiler/parser/profiler_info.py +108 -0
  792. mindspore/profiler/parser/step_trace_parser.py +1 -1
  793. mindspore/profiler/profiling.py +396 -194
  794. mindspore/rewrite/__init__.py +6 -2
  795. mindspore/rewrite/api/node.py +51 -110
  796. mindspore/rewrite/api/node_type.py +10 -6
  797. mindspore/rewrite/api/pattern_engine.py +51 -7
  798. mindspore/rewrite/api/scoped_value.py +64 -53
  799. mindspore/rewrite/api/symbol_tree.py +108 -61
  800. mindspore/rewrite/api/tree_node_helper.py +2 -3
  801. mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
  802. mindspore/rewrite/ast_helpers/__init__.py +6 -3
  803. mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
  804. mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
  805. mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
  806. mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
  807. mindspore/rewrite/ast_transformers/__init__.py +0 -1
  808. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
  809. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
  810. mindspore/rewrite/common/__init__.py +2 -0
  811. mindspore/rewrite/common/event.py +1 -1
  812. mindspore/rewrite/common/observable.py +1 -1
  813. mindspore/rewrite/common/observer.py +1 -1
  814. mindspore/rewrite/common/rewrite_elog.py +35 -0
  815. mindspore/rewrite/namer.py +2 -2
  816. mindspore/rewrite/namespace.py +14 -4
  817. mindspore/rewrite/node.py +161 -13
  818. mindspore/rewrite/parser.py +0 -1
  819. mindspore/rewrite/parser_register.py +0 -1
  820. mindspore/rewrite/parsers/arguments_parser.py +3 -2
  821. mindspore/rewrite/parsers/assign_parser.py +267 -67
  822. mindspore/rewrite/parsers/attribute_parser.py +56 -0
  823. mindspore/rewrite/parsers/class_def_parser.py +191 -108
  824. mindspore/rewrite/parsers/constant_parser.py +101 -0
  825. mindspore/rewrite/parsers/container_parser.py +88 -0
  826. mindspore/rewrite/parsers/for_parser.py +28 -15
  827. mindspore/rewrite/parsers/function_def_parser.py +21 -5
  828. mindspore/rewrite/parsers/if_parser.py +11 -28
  829. mindspore/rewrite/parsers/module_parser.py +9 -6
  830. mindspore/rewrite/parsers/return_parser.py +3 -2
  831. mindspore/rewrite/sparsify/__init__.py +0 -0
  832. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  833. mindspore/rewrite/sparsify/sparsify.py +109 -0
  834. mindspore/rewrite/sparsify/utils.py +173 -0
  835. mindspore/rewrite/symbol_tree.py +322 -109
  836. mindspore/rewrite/symbol_tree_builder.py +45 -8
  837. mindspore/rewrite/symbol_tree_dumper.py +0 -1
  838. mindspore/rewrite/topological_manager.py +1 -2
  839. mindspore/run_check/_check_version.py +209 -112
  840. mindspore/run_check/run_check.py +2 -1
  841. mindspore/tbbmalloc.dll +0 -0
  842. mindspore/tinyxml2.dll +0 -0
  843. mindspore/train/__init__.py +6 -4
  844. mindspore/train/_utils.py +28 -5
  845. mindspore/train/amp.py +321 -50
  846. mindspore/train/callback/__init__.py +3 -1
  847. mindspore/train/callback/_backup_and_restore.py +120 -0
  848. mindspore/train/callback/_callback.py +8 -8
  849. mindspore/train/callback/_checkpoint.py +12 -9
  850. mindspore/train/callback/_early_stop.py +13 -7
  851. mindspore/train/callback/_history.py +8 -8
  852. mindspore/train/callback/_lambda_callback.py +6 -6
  853. mindspore/train/callback/_landscape.py +36 -38
  854. mindspore/train/callback/_loss_monitor.py +12 -6
  855. mindspore/train/callback/_lr_scheduler_callback.py +2 -4
  856. mindspore/train/callback/_on_request_exit.py +212 -0
  857. mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
  858. mindspore/train/callback/_summary_collector.py +27 -19
  859. mindspore/train/callback/_time_monitor.py +13 -7
  860. mindspore/train/checkpoint_pb2.py +68 -8
  861. mindspore/train/data_sink.py +122 -33
  862. mindspore/train/dataset_helper.py +28 -87
  863. mindspore/train/loss_scale_manager.py +4 -7
  864. mindspore/{nn → train}/metrics/__init__.py +20 -20
  865. mindspore/{nn → train}/metrics/accuracy.py +12 -10
  866. mindspore/{nn → train}/metrics/auc.py +4 -4
  867. mindspore/{nn → train}/metrics/bleu_score.py +4 -4
  868. mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
  869. mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
  870. mindspore/{nn → train}/metrics/dice.py +6 -5
  871. mindspore/{nn → train}/metrics/error.py +7 -5
  872. mindspore/{nn → train}/metrics/fbeta.py +9 -7
  873. mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
  874. mindspore/{nn → train}/metrics/loss.py +4 -3
  875. mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
  876. mindspore/{nn → train}/metrics/metric.py +6 -5
  877. mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
  878. mindspore/{nn → train}/metrics/perplexity.py +5 -4
  879. mindspore/{nn → train}/metrics/precision.py +5 -4
  880. mindspore/{nn → train}/metrics/recall.py +5 -4
  881. mindspore/{nn → train}/metrics/roc.py +7 -6
  882. mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
  883. mindspore/{nn → train}/metrics/topk.py +7 -5
  884. mindspore/train/mind_ir_pb2.py +339 -32
  885. mindspore/train/model.py +113 -84
  886. mindspore/train/serialization.py +547 -167
  887. mindspore/train/summary/_summary_adapter.py +1 -1
  888. mindspore/train/summary/summary_record.py +43 -12
  889. mindspore/train/train_thor/convert_utils.py +7 -1
  890. mindspore/train/train_thor/dataset_helper.py +3 -3
  891. mindspore/train/train_thor/model_thor.py +0 -4
  892. mindspore/turbojpeg.dll +0 -0
  893. mindspore/vcmeta.dll +0 -0
  894. mindspore/vcruntime140.dll +0 -0
  895. mindspore/vcruntime140_1.dll +0 -0
  896. mindspore/version.py +1 -1
  897. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
  898. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
  899. mindspore/compression/common/constant.py +0 -124
  900. mindspore/compression/export/__init__.py +0 -19
  901. mindspore/compression/export/quant_export.py +0 -514
  902. mindspore/compression/quant/qat.py +0 -636
  903. mindspore/compression/quant/quant_utils.py +0 -462
  904. mindspore/compression/quant/quantizer.py +0 -68
  905. mindspore/libatomic-1.dll +0 -0
  906. mindspore/libgcc_s_seh-1.dll +0 -0
  907. mindspore/libgfortran-4.dll +0 -0
  908. mindspore/libgomp-1.dll +0 -0
  909. mindspore/libjpeg-62.dll +0 -0
  910. mindspore/libmindspore.dll +0 -0
  911. mindspore/libmindspore_common.dll +0 -0
  912. mindspore/libmindspore_core.dll +0 -0
  913. mindspore/libmindspore_glog.dll +0 -0
  914. mindspore/libnnacl.dll +0 -0
  915. mindspore/libopencv_core452.dll +0 -0
  916. mindspore/libopencv_imgcodecs452.dll +0 -0
  917. mindspore/libopencv_imgproc452.dll +0 -0
  918. mindspore/libquadmath-0.dll +0 -0
  919. mindspore/libsqlite3.dll +0 -0
  920. mindspore/libssp-0.dll +0 -0
  921. mindspore/libstdc++-6.dll +0 -0
  922. mindspore/libtinyxml2.dll +0 -0
  923. mindspore/libturbojpeg.dll +0 -0
  924. mindspore/libwinpthread-1.dll +0 -0
  925. mindspore/nn/layer/quant.py +0 -1868
  926. mindspore/nn/layer/rnn_utils.py +0 -90
  927. mindspore/nn/probability/dpn/__init__.py +0 -22
  928. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  929. mindspore/nn/probability/dpn/vae/cvae.py +0 -138
  930. mindspore/nn/probability/dpn/vae/vae.py +0 -122
  931. mindspore/nn/probability/infer/__init__.py +0 -22
  932. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  933. mindspore/nn/probability/infer/variational/svi.py +0 -84
  934. mindspore/nn/probability/toolbox/__init__.py +0 -22
  935. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  936. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
  937. mindspore/nn/probability/transforms/__init__.py +0 -22
  938. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  939. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  940. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  941. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  942. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  943. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  944. mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
  945. mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
  946. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
  947. mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
  948. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
  949. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
  950. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
  951. mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
  952. mindspore/ops/composite/array_ops.py +0 -210
  953. mindspore/ops/composite/clip_ops.py +0 -238
  954. mindspore/ops/composite/random_ops.py +0 -426
  955. mindspore/ops/composite/vmap_ops.py +0 -38
  956. mindspore/ops/operations/sponge_ops.py +0 -3531
  957. mindspore/ops/operations/sponge_update_ops.py +0 -2546
  958. mindspore/parallel/nn/__init__.py +0 -42
  959. mindspore/parallel/nn/loss.py +0 -22
  960. mindspore/parallel/nn/moe.py +0 -21
  961. mindspore/parallel/nn/op_parallel_config.py +0 -22
  962. mindspore/parallel/nn/transformer.py +0 -31
  963. mindspore/run_check/_check_deps_version.py +0 -84
  964. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  965. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  966. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -16,25 +16,26 @@
16
16
  """array_ops vmap impl."""
17
17
  from __future__ import absolute_import
18
18
 
19
- import numpy as np
20
19
  import mindspore
21
20
  import mindspore.numpy as mnp
22
21
  from mindspore import ops
23
22
  from mindspore.common import Tensor
23
+ from mindspore._c_expression import Tensor as Tensor_
24
24
  from mindspore.ops import operations as P
25
25
  from mindspore.ops import functional as F
26
- from mindspore.ops import constexpr
26
+ from mindspore.ops.primitive import constexpr, _primexpr
27
27
  from mindspore.ops.operations._grad_ops import MaskedSelectGrad
28
28
  from mindspore.ops.operations import _grad_ops as G
29
29
  from mindspore.ops.operations.array_ops import Fills, UniqueConsecutive, Col2Im, NonZero, IndexFill, \
30
30
  TensorScatterElements
31
31
  from mindspore.ops.operations.random_ops import RandomPoisson
32
+ from mindspore.ops.operations._inner_ops import DynamicBroadcastTo
32
33
  from mindspore.ops.primitive import Primitive
33
34
  from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, \
34
35
  _raise_value_error, _vmap_clone_prim, _handle_broadcasting, get_unsupported_dynamic_vmap_rule, _broadcast_by_axis, \
35
- get_unop_vmap_rule, _get_reduce_out_dim, _get_reduce_batch_axis, _vmap_update_prim_attr, \
36
+ get_unop_vmap_rule, _get_reduce_out_dim, _get_reduce_batch_axis, \
36
37
  _bdim_at_any
37
- from mindspore.ops.composite import _VmapGeneralRule
38
+ from mindspore.ops.function import _VmapGeneralRule
38
39
 
39
40
 
40
41
  @vmap_rules_getters.register(P.NoRepeatNGram)
@@ -137,7 +138,7 @@ def get_arg_min_max_with_value_vmap_rule(prim, axis_size):
137
138
  return vmap_rule
138
139
 
139
140
 
140
- @constexpr
141
+ @_primexpr
141
142
  def _get_prefix(indices_shape, axis_size, indices_dtype):
142
143
  """
143
144
  Generate prefix by indices shape, whose -1 axis value is the index value of axis 0.
@@ -147,14 +148,16 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
147
148
  the generated prefix is a Tensor([[[0], [0]],
148
149
  [[1], [1]]])
149
150
  """
150
- if not indices_shape:
151
- raise ValueError("indices_shape is empty in _get_prefix.")
151
+ def _check(indices_shape):
152
+ if not indices_shape:
153
+ raise ValueError("indices_shape is empty in _get_prefix.")
152
154
 
155
+ _check(indices_shape)
153
156
  indices_len = len(indices_shape)
154
-
155
157
  if indices_len == 1:
156
- prefix = np.arange(axis_size)
157
- return Tensor(prefix, indices_dtype)
158
+ prefix = P.Range()(Tensor(0, indices_dtype), P.Fill()(
159
+ indices_dtype, (), axis_size), Tensor(1, indices_dtype))
160
+ return prefix
158
161
 
159
162
  indices_end = indices_len - 1
160
163
  prefix_shape = ()
@@ -169,8 +172,9 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
169
172
  else:
170
173
  expand_shape = expand_shape + (1,)
171
174
 
172
- prefix = np.broadcast_to(np.arange(axis_size).reshape(expand_shape), prefix_shape)
173
- return Tensor(prefix, indices_dtype)
175
+ prefix = P.BroadcastTo(prefix_shape)(P.Reshape()(P.Range()(Tensor(
176
+ 0, indices_dtype), Tensor(axis_size, indices_dtype), Tensor(1, indices_dtype)), expand_shape))
177
+ return prefix
174
178
 
175
179
 
176
180
  @vmap_rules_getters.register(P.Transpose)
@@ -179,7 +183,7 @@ def get_transpose_vmap_rule(prim, axis_size):
179
183
  if isinstance(prim, str):
180
184
  prim = Primitive(prim)
181
185
 
182
- @constexpr
186
+ @_primexpr
183
187
  def _get_transpose_batch_perm(dim, perm, x_rank):
184
188
  """Generate batch_perm based on the original perm of transpose operation and dim of the input."""
185
189
  if dim < 0:
@@ -223,24 +227,20 @@ def get_tile_vmap_rule(prim, axis_size):
223
227
  if isinstance(prim, str):
224
228
  prim = Primitive(prim)
225
229
 
226
- @constexpr
227
- def _get_tile_shape(input_shape, multiples):
230
+ @_primexpr
231
+ def _get_batch_multiples(input_shape, dim, multiples):
228
232
  input_ndim = len(input_shape)
229
233
  multiples_ndim = len(multiples)
230
- input_shape = (input_shape[0],) + (1,) * (multiples_ndim - input_ndim + 1) + input_shape[1:]
231
- multiples = (1,) * (input_ndim - multiples_ndim) + multiples
232
- input_expand_shape = (input_shape[0],) + tuple([
233
- j
234
- for i in input_shape[1:]
235
- for j in [1, i]
236
- ])
237
- repeat_shape = (input_shape[0],) + tuple([
238
- k
239
- for pair in zip(multiples[1:], input_shape[1:])
240
- for k in pair
241
- ])
242
- output_shape = tuple([a * b for a, b in zip(input_shape, multiples)])
243
- return input_expand_shape, repeat_shape, output_shape
234
+ if multiples_ndim < input_ndim - 1:
235
+ multiples = (1,) * (input_ndim - 1 - multiples_ndim) + multiples
236
+
237
+ rev_dim = input_ndim - 1 - dim
238
+ if rev_dim == 0:
239
+ return multiples + (1,), multiples_ndim
240
+
241
+ batch_multiples = list(multiples)
242
+ batch_multiples.insert(-rev_dim, 1)
243
+ return tuple(batch_multiples), multiples_ndim - rev_dim
244
244
 
245
245
  def vmap_rule(input_bdim, multiples_bdim):
246
246
  is_all_none, result = vmap_general_preprocess(prim, input_bdim, multiples_bdim)
@@ -252,13 +252,10 @@ def get_tile_vmap_rule(prim, axis_size):
252
252
  if multiples_dim is not None:
253
253
  _raise_value_error("The source axis of shape in `Tile` must be None, but got {}.".format(multiples_dim))
254
254
 
255
- input_x = _bdim_at_front(input_x, dim, axis_size)
256
255
  input_shape = F.shape(input_x)
257
- input_expand_shape, repeat_shape, output_shape = _get_tile_shape(input_shape, multiples)
258
- expand_input = F.reshape(input_x, input_expand_shape)
259
- repeat_tensor = P.BroadcastTo(repeat_shape)(expand_input)
260
- output = F.reshape(repeat_tensor, output_shape)
261
- return output, 0
256
+ batch_multiples, out_dim = _get_batch_multiples(input_shape, dim, multiples)
257
+ repeat_tensor = P.Tile()(input_x, batch_multiples)
258
+ return repeat_tensor, out_dim
262
259
 
263
260
  return vmap_rule
264
261
 
@@ -359,8 +356,13 @@ def get_unstack_vmap_rule(prim, axis_size):
359
356
  def get_reshape_vmap_rule(prim, axis_size):
360
357
  """VmapRule for `Reshape` operation."""
361
358
 
362
- @constexpr
359
+
360
+ @_primexpr
363
361
  def get_batch_shape(x_shape, x_dim, target_shape, axis_size):
362
+ def _check(neg_index, target_shape):
363
+ if neg_index != -1:
364
+ raise ValueError(f'The shape can only has one -1 at most, but {target_shape}.')
365
+
364
366
  if x_dim == 0:
365
367
  return (axis_size,) + target_shape, 0, False
366
368
 
@@ -371,19 +373,21 @@ def get_reshape_vmap_rule(prim, axis_size):
371
373
  dim_prod = 1
372
374
  for i, shp_i in enumerate(target_shape):
373
375
  if shp_i == -1:
374
- if neg_index != -1:
375
- raise ValueError(f'The shape can only has one -1 at most, but {target_shape}.')
376
+ _check(neg_index, target_shape)
376
377
  neg_index = i
377
378
  else:
378
379
  dim_prod *= shp_i
379
- arr_prod = np.prod(x_shape)
380
+ arr_prod = 1
381
+ for i in x_shape:
382
+ arr_prod *= i
380
383
  target_shape_list = list(target_shape)
381
384
  if neg_index != -1:
382
385
  neg_index_size = int(arr_prod // (dim_prod * axis_size))
383
386
  target_shape_list[neg_index] = neg_index_size
384
387
 
385
- arr_prod_before_dim = np.prod(x_shape[:x_dim])
386
-
388
+ arr_prod_before_dim = 1
389
+ for i in x_shape[:x_dim]:
390
+ arr_prod_before_dim *= i
387
391
  dim_prod = 1
388
392
  for i, shp_i in enumerate(target_shape_list, start=1):
389
393
  dim_prod *= shp_i
@@ -428,7 +432,7 @@ def get_reverse_sequence_vmap_rule(prim, axis_size):
428
432
  batch_dim = prim.batch_dim_
429
433
  seq_dim = prim.seq_dim_
430
434
 
431
- @constexpr
435
+ @_primexpr
432
436
  def get_batch_seq_dim(dim, batch_dim_, seq_dim_):
433
437
  if dim is None:
434
438
  batch_dim_ += 1
@@ -444,7 +448,7 @@ def get_reverse_sequence_vmap_rule(prim, axis_size):
444
448
  seq_dim_ += 1
445
449
  return batch_dim_, seq_dim_
446
450
 
447
- @constexpr
451
+ @_primexpr
448
452
  def get_seq_dim(dim, batch_dim_, seq_dim_):
449
453
  if dim is None:
450
454
  return seq_dim_
@@ -564,20 +568,19 @@ def get_scatter_nd_vmap_rule(prim, axis_size):
564
568
  Reshape the output tensor to `[10, 6, 4, 5]`
565
569
  """
566
570
 
567
- @constexpr
571
+ @_primexpr
568
572
  def _refine_shape(shape, bdim_size):
569
573
  offset = shape[0]
570
574
  return (bdim_size * shape[0],) + tuple(shape[1:]), offset, (bdim_size,) + tuple(shape)
571
575
 
572
- @constexpr
576
+ @_primexpr
573
577
  def _gen_indices_offset(shape, offset):
574
578
  # original rank(indices.shape) is required >= 2, so indices with batch dim's rank >= 3.
575
- shape = [shape[0]] + [1] * (len(shape) - 2) + [shape[-1]]
576
- val = np.zeros(shape, np.int32) # the dtype will be changed when creating Tensor
577
- val = np.reshape(val, (shape[0], shape[-1]))
579
+ shape = (shape[0],) + (1,) * (len(shape) - 2) + (shape[-1],)
580
+ val = P.Zeros()((shape[0], shape[-1]), mindspore.int32)
578
581
  for i in range(shape[0]):
579
582
  val[i, 0] = i * offset
580
- return np.reshape(val, shape)
583
+ return P.Reshape()(val, shape)
581
584
 
582
585
  if isinstance(prim, str):
583
586
  prim = Primitive(prim)
@@ -598,7 +601,7 @@ def get_scatter_nd_vmap_rule(prim, axis_size):
598
601
  indices_shape = F.shape(indices)
599
602
  indices_dtype = F.dtype(indices)
600
603
  offset_val = _gen_indices_offset(indices_shape, offset)
601
- indices_offset = Tensor(offset_val, indices_dtype)
604
+ indices_offset = P.Cast()(offset_val, indices_dtype)
602
605
  new_indices = P.Add()(indices, indices_offset)
603
606
  out = prim(new_indices, updates, new_shape)
604
607
  real_out = P.Reshape()(out, out_shape)
@@ -846,6 +849,62 @@ def get_fill_vmap_rule(prim, axis_size):
846
849
  return vmap_rule
847
850
 
848
851
 
852
@constexpr
def to_tensor_with_type(x, type):
    """Build a `Tensor` from `x`, passing `type` as the second constructor argument."""
    converted = Tensor(x, type)
    return converted
856
+
857
+
858
@vmap_rules_getters.register(P.FillV2)
def get_fill_v2_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `FillV2` operation.

    Args:
        prim: The primitive, or its name as a string (wrapped in `Primitive`).
        axis_size: Size used for the leading axis of the broadcast output.

    Returns:
        A `vmap_rule(shape_bdim, value_bdim)` function whose result is `(out, 0)`.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    def vmap_rule(shape_bdim, value_bdim):
        all_none, early_result = vmap_general_preprocess(prim, shape_bdim, value_bdim)
        if all_none:
            return early_result

        # `shape` is rejected when it carries a source axis.
        value_shape, shape_axis = shape_bdim
        if shape_axis is not None:
            _raise_value_error(
                "The source axis of `shape` in `P.FillV2` must be None, but got {}."
                .format(shape_axis))

        # `value` is only accepted as a rank-1 input whose source axis is 0.
        value, value_axis = value_bdim
        rank_of_value = F.rank(value)
        if rank_of_value != 1 or value_axis != 0:
            _raise_value_error(
                "The `value` in `P.FillV2` must be constant value, thus the value only "
                "can be rank: 1 with source axis: 0 in vmap scope, but got value rank: "
                "{} with source axis: {}.".format(rank_of_value, value_axis))

        # Leading axis of size `axis_size`, followed by one size-1 axis per entry of `shape`.
        trailing_ones = (1,) * len(value_shape)
        value = F.reshape(value, (axis_size,) + trailing_ones)

        out = None
        if isinstance(value_shape, tuple):
            # Tuple shape: the target shape is built directly as a tuple.
            static_shape = (axis_size,) + value_shape
            out = P.BroadcastTo(static_shape)(value)
        elif isinstance(value_shape, (Tensor_, Tensor)):
            # Tensor shape: must be 1-D; `axis_size` is prepended by concatenation.
            shape_rank = F.rank(value_shape)
            if shape_rank != 1:
                _raise_value_error(
                    "The `shape` in `P.FillV2` must be 1-D tensor, thus the shape only "
                    "can be rank: 1, but got shape rank: "
                    "{}.".format(shape_rank))
            shape_dtype = F.dtype(value_shape)
            batch_size_tensor = to_tensor_with_type((axis_size,), shape_dtype)
            dynamic_shape = F.concat((batch_size_tensor, value_shape))
            out = DynamicBroadcastTo()(value, dynamic_shape)
        else:
            _raise_value_error(
                f"For `P.FillV2`, the input `shape` should be Tuple or Tensor, but got `shape`: {value_shape}."
            )

        return out, 0

    return vmap_rule
906
+
907
+
849
908
  @vmap_rules_getters.register(Fills)
850
909
  def get_fills_vmap_rule(prim, axis_size):
851
910
  """VmapRule for `Fills` operation."""
@@ -1299,12 +1358,7 @@ def get_gatherd_grad_v2_vmap_rule(prim, axis_size):
1299
1358
  if isinstance(prim, str):
1300
1359
  prim = Primitive(prim)
1301
1360
 
1302
- dim = 0
1303
- if hasattr(prim, 'dim'):
1304
- dim = prim.dim
1305
-
1306
- @constexpr
1307
- def _update_attr(x_rank, batch_dim):
1361
+ def _update_dim(dim, x_rank, batch_dim):
1308
1362
  pdim = dim
1309
1363
  if pdim < 0:
1310
1364
  pdim += x_rank
@@ -1312,19 +1366,22 @@ def get_gatherd_grad_v2_vmap_rule(prim, axis_size):
1312
1366
  _raise_value_error(
1313
1367
  "The `dim` in `GatherDGradV2` must be in range [{}, {}], but got {}.".format(-x_rank, x_rank - 1, dim))
1314
1368
  if pdim >= batch_dim:
1315
- _vmap_update_prim_attr(prim, 'dim', pdim + 1)
1316
- elif dim < 0:
1317
- _vmap_update_prim_attr(prim, 'dim', pdim)
1369
+ return pdim + 1
1370
+ if dim < 0:
1371
+ return pdim
1372
+ return dim
1318
1373
 
1319
- def vmap_rule(x_bdim, index_bdim, grad_bdim):
1320
- is_all_none, result = vmap_general_preprocess(prim, x_bdim, index_bdim, grad_bdim)
1374
+ def vmap_rule(x_bdim, dim_bdim, index_bdim, grad_bdim):
1375
+ is_all_none, result = vmap_general_preprocess(prim, x_bdim, dim_bdim, index_bdim, grad_bdim)
1321
1376
  if is_all_none:
1322
1377
  return result
1323
1378
 
1324
1379
  x, x_dim = x_bdim
1380
+ dim, dim_dim = dim_bdim
1381
+ if dim_dim is not None:
1382
+ _raise_value_error("The dim of 'dim' in `GatherDGradV2` must be None, but got {}.".format(dim_dim))
1325
1383
  index, index_dim = index_bdim
1326
1384
  grad, grad_dim = grad_bdim
1327
-
1328
1385
  batch_dim = 0
1329
1386
  if x_dim is not None:
1330
1387
  batch_dim = x_dim
@@ -1336,12 +1393,10 @@ def get_gatherd_grad_v2_vmap_rule(prim, axis_size):
1336
1393
  x = _bdim_at_any(x, x_dim, batch_dim, axis_size)
1337
1394
  index = _bdim_at_any(index, index_dim, batch_dim, axis_size)
1338
1395
  grad = _bdim_at_any(grad, grad_dim, batch_dim, axis_size)
1339
-
1340
- # Adjust dim-attr if needed
1341
1396
  x_rank = F.rank(x) - 1
1342
- _update_attr(x_rank, batch_dim)
1343
-
1344
- out = prim(x, index, grad)
1397
+ # Adjust dim if needed
1398
+ dim = _update_dim(dim, x_rank, batch_dim)
1399
+ out = prim(x, dim, index, grad)
1345
1400
  return (out, batch_dim)
1346
1401
 
1347
1402
  return vmap_rule
@@ -1425,6 +1480,7 @@ def get_meshgrid_vmap_rule(prim, axis_size):
1425
1480
  "The input number of P.Meshgrid must be greater than 1.")
1426
1481
 
1427
1482
  output_shape = []
1483
+ ones_shape = []
1428
1484
  for each_arg in args:
1429
1485
  x, bdim = each_arg
1430
1486
  if bdim is None:
@@ -1435,19 +1491,30 @@ def get_meshgrid_vmap_rule(prim, axis_size):
1435
1491
  _raise_value_error(
1436
1492
  "Each input of Meshgrid must be 1D, but got {}.".format(F.rank(x) - 1))
1437
1493
  output_shape.append(F.shape(x)[-1])
1494
+ ones_shape.append(1)
1438
1495
  output_shape.insert(0, axis_size)
1496
+ ones_shape.insert(0, axis_size)
1439
1497
 
1440
1498
  if indexing == "xy":
1441
1499
  output_shape[1], output_shape[2] = output_shape[2], output_shape[1]
1442
-
1443
1500
  shape = tuple(output_shape)
1501
+
1502
+ input_0, _ = args[0]
1503
+ dtype = F.dtype(input_0)
1504
+ ones_tensor = F.fill(dtype, shape, 1)
1505
+
1506
+ index = 0
1444
1507
  vals_out_tuple = ()
1445
1508
  for each_arg in args:
1446
1509
  x, bdim = each_arg
1447
1510
  x = _bdim_at_front(x, bdim, axis_size)
1448
- x = _handle_broadcasting(x, F.shape(x), output_shape)
1449
- output = P.BroadcastTo(shape)(x)
1511
+ shape_index = (1 - index) if (index <= 1 and indexing == "xy") else index
1512
+ ones_shape[shape_index + 1] = output_shape[shape_index + 1]
1513
+ x = P.Reshape()(x, tuple(ones_shape))
1514
+ output = P.Mul()(x, ones_tensor)
1450
1515
  vals_out_tuple = vals_out_tuple + ((output, 0),)
1516
+ ones_shape[shape_index + 1] = 1
1517
+ index = index + 1
1451
1518
 
1452
1519
  return vals_out_tuple
1453
1520
 
@@ -1491,7 +1558,7 @@ def get_gather_vmap_rule(prim, axis_size):
1491
1558
  else:
1492
1559
  prim_name = prim.name
1493
1560
 
1494
- @constexpr
1561
+ @_primexpr
1495
1562
  def process_axis(axis, x_shape_size, has_xdim: bool, has_idim: bool):
1496
1563
  if has_xdim and has_idim:
1497
1564
  if axis < 0:
@@ -1505,7 +1572,7 @@ def get_gather_vmap_rule(prim, axis_size):
1505
1572
 
1506
1573
  return axis
1507
1574
 
1508
- @constexpr
1575
+ @_primexpr
1509
1576
  def get_x_dst_shape(x_shape, axis):
1510
1577
  target_axis_size = x_shape[axis + 1]
1511
1578
  x_dst_shape = x_shape[0:axis] + (axis_size * target_axis_size,) + x_shape[axis + 2:]
@@ -1705,7 +1772,7 @@ def get_data_format_dim_map_vmap_rule(prim, axis_size):
1705
1772
  def get_expand_dims_vmap_rule(prim, axis_size):
1706
1773
  """VmapRule for `ExpandDims`."""
1707
1774
 
1708
- @constexpr
1775
+ @_primexpr
1709
1776
  def process_axis(axis, rank, x_dim):
1710
1777
  if axis < 0:
1711
1778
  axis += rank
@@ -1799,7 +1866,7 @@ def get_squeeze_vmap_rule(prim, axis_size):
1799
1866
  else:
1800
1867
  prim_axis = None
1801
1868
 
1802
- @constexpr
1869
+ @_primexpr
1803
1870
  def move_axis(axes):
1804
1871
  new_axis = ()
1805
1872
  for axis in axes:
@@ -1809,7 +1876,7 @@ def get_squeeze_vmap_rule(prim, axis_size):
1809
1876
  new_axis = new_axis + (axis + 1,)
1810
1877
  return new_axis
1811
1878
 
1812
- @constexpr
1879
+ @_primexpr
1813
1880
  def generate_all_axis_except_first(x_rank):
1814
1881
  new_axis = ()
1815
1882
  for i in range(1, x_rank, 1):
@@ -1838,6 +1905,7 @@ def get_squeeze_vmap_rule(prim, axis_size):
1838
1905
  batch_squeeze = P.Squeeze(axis=new_axis)
1839
1906
  out = batch_squeeze(x)
1840
1907
  return out, 0
1908
+
1841
1909
  return vmap_rule
1842
1910
 
1843
1911
 
@@ -1852,7 +1920,7 @@ def get_stridedslice_vmap_rule(prim, axis_size):
1852
1920
  batch_stridedslice = P.StridedSlice(new_begin_mask, new_end_mask, new_ellipsis_mask, new_new_axis_mask, \
1853
1921
  new_shrink_axis_mask)
1854
1922
 
1855
- @constexpr
1923
+ @_primexpr
1856
1924
  def get_new_begin_end_strided(begin, end, strided):
1857
1925
  new_begin = (0,) + begin
1858
1926
  new_end = (0,) + end
@@ -1891,9 +1959,9 @@ def get_stridedslice_grad_vmap_rule(prim, axis_size):
1891
1959
  new_new_axis_mask = prim.new_axis_mask * 2
1892
1960
  new_shrink_axis_mask = prim.shrink_axis_mask * 2
1893
1961
  batch_stridedslice_grad = G.StridedSliceGrad(new_begin_mask, new_end_mask, new_ellipsis_mask, new_new_axis_mask, \
1894
- new_shrink_axis_mask)
1962
+ new_shrink_axis_mask)
1895
1963
 
1896
- @constexpr
1964
+ @_primexpr
1897
1965
  def get_new_xshape_begin_end_strided(xshape, begin, end, strided):
1898
1966
  new_xshape = (axis_size,) + xshape
1899
1967
  new_begin = (0,) + begin
@@ -1984,6 +2052,30 @@ def get_im2col_vmap_rule(prim, axis_size):
1984
2052
  return vmap_rule
1985
2053
 
1986
2054
 
2055
@vmap_rules_getters.register(P.Split)
def get_split_vmap_rule(prim, axis_size):
    """Build the vmap rule for the `Split` operation.

    Args:
        prim: The `Split` primitive; its `axis` and `output_num` attributes are read.
        axis_size: Forwarded to `_bdim_at_front` for the input.

    Returns:
        A `vmap_rule(x_bdim)` function returning a tuple of `(output, 0)` pairs.
    """
    # A non-negative split axis is shifted by one; a negative axis is used unchanged.
    split_axis = prim.axis + 1 if prim.axis >= 0 else prim.axis
    batched_split = P.Split(split_axis, prim.output_num)

    def vmap_rule(x_bdim):
        all_none, early_result = vmap_general_preprocess(prim, x_bdim)
        if all_none:
            return early_result

        data, src_axis = x_bdim
        pieces = batched_split(_bdim_at_front(data, src_axis, axis_size))

        # Pair every split output with axis 0.
        paired = ()
        for piece in pieces:
            paired = paired + ((piece, 0),)
        return paired

    return vmap_rule
2077
+
2078
+
1987
2079
  get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(NonZero)(get_unsupported_dynamic_vmap_rule)
1988
2080
  get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.Unique)(get_unsupported_dynamic_vmap_rule)
1989
2081
  get_unsupported_dynamic_vmap_rule = \
@@ -21,12 +21,13 @@ from mindspore.common import Tensor
21
21
  from mindspore.ops import operations as P
22
22
  from mindspore.ops import functional as F
23
23
  from mindspore.ops import constexpr
24
+ from mindspore.ops.primitive import _primexpr
24
25
  from mindspore.ops.operations import math_ops
25
26
  from mindspore.ops.operations import _grad_ops as G
26
27
  from mindspore.ops.operations import nn_ops as nps
27
- from mindspore.ops.composite import _VmapGeneralPreprocess
28
- from mindspore.ops.primitive import Primitive
29
- from mindspore.ops.operations.random_ops import UniformCandidateSampler
28
+ from mindspore.ops.function import _VmapGeneralPreprocess
29
+ from mindspore.ops.primitive import Primitive, _PrimitiveC
30
+ from mindspore.ops.operations.random_ops import UniformCandidateSampler, RandomShuffle
30
31
  from mindspore.ops._grad.grad_base import BpropRegistry as VmapRuleRegistry
31
32
 
32
33
 
@@ -41,7 +42,7 @@ def get_vmap_rule(prim, axis_size):
41
42
  return None
42
43
 
43
44
 
44
- @constexpr
45
+ @_primexpr
45
46
  def _get_broadcast_shape_with_front_axis(x_shape, y_shape):
46
47
  """ Explicitly matched with the broadcast shape, that is, 1 is added to the broadcast position. """
47
48
  x_len = len(x_shape)
@@ -86,7 +87,7 @@ def _handle_broadcasting(x, x_shape, y_shape):
86
87
  return F.reshape(x, broadcast_shape)
87
88
 
88
89
 
89
- @constexpr
90
+ @_primexpr
90
91
  def _get_broadcasting_with_front_axis_additional_axis(x_shape, y_shape):
91
92
  """ Get the axes that are inserted after broadcasting.
92
93
  Args:
@@ -129,15 +130,19 @@ def _raise_value_error(info, param=None):
129
130
  raise ValueError(info + f"{param}")
130
131
 
131
132
 
132
- @constexpr
133
+ @_primexpr
133
134
  def _get_broadcast_shape(x_shape, dst, axis_size):
134
135
  """Get the target shape for broadcast array."""
136
+ def _check(dst, broadcast_ndim):
137
+ if dst < -broadcast_ndim or dst >= broadcast_ndim:
138
+ _raise_value_error("Destination axis {} is out of bounds for array of dimension"
139
+ " [{}, {}).".format(dst, -broadcast_ndim, broadcast_ndim))
140
+
135
141
  x_ndim = len(x_shape)
136
142
  broadcast_ndim = x_ndim + 1
137
143
 
138
- if dst < -broadcast_ndim or dst >= broadcast_ndim:
139
- _raise_value_error("Destination axis {} is out of bounds for array of dimension"
140
- " [{}, {}).".format(dst, -broadcast_ndim, broadcast_ndim))
144
+ _check(dst, broadcast_ndim)
145
+
141
146
  if dst < 0:
142
147
  dst = broadcast_ndim + dst
143
148
 
@@ -190,6 +195,10 @@ def vmap_unstack(dim, val):
190
195
  return P.Unstack(dim)(val)
191
196
 
192
197
 
198
def vmap_stack(val):
    """Apply a default-constructed `P.Stack` to `val` and return its output."""
    stack_op = P.Stack()
    return stack_op(val)
200
+
201
+
193
202
  def vmap_general_output_process(output):
194
203
  """ Match output to axis 0"""
195
204
  vals_out_tuple = ()
@@ -416,6 +425,8 @@ def _vmap_clone_prim(prim):
416
425
  """
417
426
  Cloning a new primitive object same as `prim`.
418
427
  """
428
+ if isinstance(prim, _PrimitiveC):
429
+ return _PrimitiveC(prim.name, prim.attrs)
419
430
  new_ops = _ops_vmap_clone_prim_dict.get(prim.name, None)
420
431
  if new_ops is None:
421
432
  raise ValueError("Failed to get the primitive object of {} from `_ops_vmap_clone_prim_dict`. Please register "
@@ -433,7 +444,7 @@ def _vmap_clone_prim(prim):
433
444
  return cloned
434
445
 
435
446
 
436
- @constexpr
447
+ @_primexpr
437
448
  def _get_reduce_batch_axis(axis, x_dim, x_ndim):
438
449
  """get batch_axis for reduce* operation."""
439
450
  # For axis, it's value in Union[int, list, tuple]
@@ -481,6 +492,7 @@ _ops_vmap_clone_prim_dict = {
481
492
  "ApplyAdaMax": P.ApplyAdaMax,
482
493
  "ApplyAdadelta": P.ApplyAdadelta,
483
494
  "ApplyRMSProp": P.ApplyRMSProp,
495
+ 'Adam': P.Adam,
484
496
  "ApplyCenteredRMSProp": P.ApplyCenteredRMSProp,
485
497
  "ApplyFtrl": P.ApplyFtrl,
486
498
  "ApplyGradientDescent": P.ApplyGradientDescent,
@@ -508,4 +520,6 @@ _ops_vmap_clone_prim_dict = {
508
520
  "SparseApplyAdagrad": P.SparseApplyAdagrad,
509
521
  "SparseApplyAdagradV2": P.SparseApplyAdagradV2,
510
522
  "SparseApplyFtrl": P.SparseApplyFtrl,
523
+ "RandomShuffle": RandomShuffle,
524
+ "RandomChoiceWithMask": P.RandomChoiceWithMask
511
525
  }
@@ -16,9 +16,9 @@
16
16
  """convolution vmap impl"""
17
17
  from __future__ import absolute_import
18
18
 
19
- import numpy as np
20
19
  import mindspore.numpy as mnp
21
20
  from mindspore.ops import constexpr
21
+ from mindspore.ops.primitive import _primexpr
22
22
  from mindspore.ops import operations as P
23
23
  from mindspore.ops import functional as F
24
24
  from mindspore.ops.operations import nn_ops as nps
@@ -142,7 +142,7 @@ def get_conv3d_backprop_filter_vmap_rule(prim, axis_size):
142
142
  return vmap_rule
143
143
 
144
144
 
145
- @constexpr
145
+ @_primexpr
146
146
  def _get_reshape_src_dim(data_dim, cmp_dim):
147
147
  """Get source dim for reshape"""
148
148
  if data_dim > cmp_dim:
@@ -154,7 +154,7 @@ def _get_reshape_src_dim(data_dim, cmp_dim):
154
154
  return expand_dim, merge_dim
155
155
 
156
156
 
157
- @constexpr
157
+ @_primexpr
158
158
  def _get_merge_shape(src_dim, dst_dim, shape):
159
159
  """Get new shape for merging the src_dim and dst_dim. The dst_dim is the value after removing src_dim."""
160
160
  new_shape = [shape[i] for i in range(len(shape)) if i != src_dim]
@@ -171,13 +171,10 @@ def _reshape_merge_dims(src_dim, dst_dim, target):
171
171
  return output, new_shape
172
172
 
173
173
 
174
- @constexpr
174
+ @_primexpr
175
175
  def _get_expand_shape(src_dim, dst_size, shape, prim_name):
176
176
  """Get new shape for splitting src_dim into dst_size parts."""
177
- dst_size2, remainder = np.divmod(shape[src_dim], dst_size)
178
- if remainder != 0:
179
- _raise_value_error("The remainder of {} / {} should be 0, "
180
- "but got {} in {}.".format(shape[src_dim], dst_size, remainder, prim_name))
177
+ dst_size2 = shape[src_dim] // dst_size
181
178
  new_shape = list(shape)
182
179
  new_shape[src_dim:(src_dim + 1)] = [dst_size, dst_size2]
183
180
  return tuple(new_shape)
@@ -190,7 +187,7 @@ def _reshape_expand_dims(src_dim, dst_size, target, prim_name):
190
187
  return F.reshape(target, new_shape)
191
188
 
192
189
 
193
- @constexpr
190
+ @_primexpr
194
191
  def _get_new_size_by_index(input_size, batch_size, index):
195
192
  """Get the new size of input_size by multiplying input_size[index] by batch_size."""
196
193
  new_size = ()
@@ -201,7 +198,7 @@ def _get_new_size_by_index(input_size, batch_size, index):
201
198
  return tuple(new_size)
202
199
 
203
200
 
204
- @constexpr
201
+ @_primexpr
205
202
  def _update_group_attr(prim, groups, batch_size):
206
203
  """Set new value for 'group' attribute of the convolution primitive."""
207
204
  group = groups * batch_size
@@ -17,9 +17,9 @@
17
17
  from __future__ import absolute_import
18
18
 
19
19
  from mindspore.ops import functional as F
20
- from mindspore.ops import constexpr
20
+ from mindspore.ops.primitive import _primexpr
21
21
  from mindspore.ops.operations import _grad_ops as G
22
- from mindspore.ops.composite import _VmapGeneralRule
22
+ from mindspore.ops.function import _VmapGeneralRule
23
23
  from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, \
24
24
  _handle_broadcasting, get_unary_grad_vmap_rule, _get_broadcasting_with_front_axis_additional_axis
25
25
 
@@ -36,7 +36,7 @@ def get_broadcast_binary_op_grad_vmap_rule(prim, axis_size):
36
36
  if isinstance(prim, str):
37
37
  prim = broadcast_binary_op_grad_map.get(prim)()
38
38
 
39
- @constexpr
39
+ @_primexpr
40
40
  def get_longest_shape(x_shape, y_shape, g_shape):
41
41
  x_rank = len(x_shape)
42
42
  y_rank = len(y_shape)
@@ -148,7 +148,7 @@ def get_median_grad_vmap_rule(prim, axis_size):
148
148
  axis = prim.axis
149
149
  keep_dims = prim.keep_dims
150
150
 
151
- @constexpr
151
+ @_primexpr
152
152
  def trans_grad_axis(axis, rank, dim, keep_dims):
153
153
  if axis < 0:
154
154
  axis += rank - 1