mindspore 1.10.0__cp37-cp37m-win_amd64.whl → 2.0.0rc1__cp37-cp37m-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (966)
  1. mindspore/.commit_id +1 -1
  2. mindspore/ConcurrencyCheck.dll +0 -0
  3. mindspore/CppBuildInsights.dll +0 -0
  4. mindspore/CppCoreCheck.dll +0 -0
  5. mindspore/EnumIndex.dll +0 -0
  6. mindspore/EspXEngine.dll +0 -0
  7. mindspore/HResultCheck.dll +0 -0
  8. mindspore/KernelTraceControl.dll +0 -0
  9. mindspore/LocalESPC.dll +0 -0
  10. mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
  11. mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
  12. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  13. mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
  14. mindspore/Newtonsoft.Json.dll +0 -0
  15. mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
  16. mindspore/VariantClear.dll +0 -0
  17. mindspore/__init__.py +9 -4
  18. mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
  19. mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
  20. mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
  21. mindspore/_check_jit_forbidden_api.py +102 -0
  22. mindspore/_checkparam.py +1066 -1001
  23. mindspore/_extends/builtin_operations.py +32 -4
  24. mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
  25. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
  26. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
  27. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
  28. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
  29. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
  30. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  31. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
  32. mindspore/_extends/parse/__init__.py +5 -3
  33. mindspore/_extends/parse/namespace.py +17 -2
  34. mindspore/_extends/parse/parser.py +193 -34
  35. mindspore/_extends/parse/resources.py +7 -8
  36. mindspore/_extends/parse/standard_method.py +1780 -435
  37. mindspore/_extends/parse/trope.py +3 -1
  38. mindspore/amp.py +53 -58
  39. mindspore/atlprov.dll +0 -0
  40. mindspore/boost/adasum.py +3 -2
  41. mindspore/boost/boost.py +2 -2
  42. mindspore/boost/boost_cell_wrapper.py +46 -26
  43. mindspore/boost/dim_reduce.py +6 -5
  44. mindspore/boost/grad_accumulation.py +2 -1
  45. mindspore/boost/group_loss_scale_manager.py +1 -1
  46. mindspore/c1.dll +0 -0
  47. mindspore/c1xx.dll +0 -0
  48. mindspore/c2.dll +0 -0
  49. mindspore/cfgpersist.dll +0 -0
  50. mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
  51. mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
  52. mindspore/common/__init__.py +11 -10
  53. mindspore/common/_decorator.py +2 -0
  54. mindspore/common/_register_for_adapter.py +55 -0
  55. mindspore/common/_stub_tensor.py +201 -0
  56. mindspore/common/_utils.py +57 -0
  57. mindspore/common/api.py +582 -297
  58. mindspore/common/dtype.py +66 -18
  59. mindspore/common/dump.py +2 -2
  60. mindspore/common/initializer.py +38 -1
  61. mindspore/common/jit_config.py +25 -13
  62. mindspore/common/mutable.py +53 -24
  63. mindspore/common/parameter.py +60 -37
  64. mindspore/common/seed.py +8 -24
  65. mindspore/common/sparse_tensor.py +927 -0
  66. mindspore/common/tensor.py +1627 -3900
  67. mindspore/communication/__init__.py +10 -5
  68. mindspore/communication/_comm_helper.py +78 -214
  69. mindspore/communication/_hccl_management.py +2 -1
  70. mindspore/communication/management.py +136 -47
  71. mindspore/config/op_info.config +501 -1008
  72. mindspore/context.py +291 -56
  73. mindspore/d3dcompiler_47.dll +0 -0
  74. mindspore/dataset/__init__.py +12 -8
  75. mindspore/dataset/audio/__init__.py +9 -9
  76. mindspore/dataset/audio/transforms.py +1090 -228
  77. mindspore/dataset/audio/utils.py +87 -39
  78. mindspore/dataset/audio/validators.py +223 -1
  79. mindspore/dataset/callback/ds_callback.py +17 -15
  80. mindspore/dataset/core/config.py +246 -17
  81. mindspore/dataset/core/py_util_helpers.py +4 -3
  82. mindspore/dataset/core/validator_helpers.py +10 -10
  83. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  84. mindspore/dataset/debug/debug_hook.py +65 -0
  85. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  86. mindspore/dataset/engine/__init__.py +7 -3
  87. mindspore/dataset/engine/cache_client.py +9 -9
  88. mindspore/dataset/engine/datasets.py +648 -477
  89. mindspore/dataset/engine/datasets_audio.py +165 -167
  90. mindspore/dataset/engine/datasets_standard_format.py +93 -67
  91. mindspore/dataset/engine/datasets_text.py +492 -342
  92. mindspore/dataset/engine/datasets_user_defined.py +85 -50
  93. mindspore/dataset/engine/datasets_vision.py +1224 -699
  94. mindspore/dataset/engine/graphdata.py +134 -69
  95. mindspore/dataset/engine/iterators.py +50 -9
  96. mindspore/dataset/engine/offload.py +52 -31
  97. mindspore/dataset/engine/samplers.py +27 -24
  98. mindspore/dataset/engine/serializer_deserializer.py +14 -15
  99. mindspore/dataset/engine/validators.py +213 -52
  100. mindspore/dataset/text/__init__.py +10 -8
  101. mindspore/dataset/text/transforms.py +152 -57
  102. mindspore/dataset/text/utils.py +98 -49
  103. mindspore/dataset/text/validators.py +25 -0
  104. mindspore/dataset/transforms/__init__.py +4 -2
  105. mindspore/dataset/transforms/c_transforms.py +11 -13
  106. mindspore/dataset/transforms/py_transforms.py +2 -2
  107. mindspore/dataset/transforms/py_transforms_util.py +10 -0
  108. mindspore/dataset/transforms/transforms.py +13 -15
  109. mindspore/dataset/transforms/validators.py +7 -7
  110. mindspore/dataset/utils/__init__.py +2 -1
  111. mindspore/dataset/utils/browse_dataset.py +13 -13
  112. mindspore/dataset/utils/line_reader.py +121 -0
  113. mindspore/dataset/vision/__init__.py +8 -7
  114. mindspore/dataset/vision/c_transforms.py +125 -126
  115. mindspore/dataset/vision/py_transforms.py +37 -37
  116. mindspore/dataset/vision/py_transforms_util.py +23 -20
  117. mindspore/dataset/vision/transforms.py +316 -315
  118. mindspore/dataset/vision/utils.py +313 -17
  119. mindspore/dataset/vision/validators.py +6 -6
  120. mindspore/default_config.py +0 -1
  121. mindspore/dpcmi.dll +0 -0
  122. mindspore/{compression → experimental}/__init__.py +6 -5
  123. mindspore/experimental/map_parameter.py +275 -0
  124. mindspore/include/OWNERS +0 -1
  125. mindspore/include/api/callback/callback.h +9 -13
  126. mindspore/include/api/callback/ckpt_saver.h +2 -2
  127. mindspore/include/api/callback/loss_monitor.h +2 -2
  128. mindspore/include/api/callback/lr_scheduler.h +5 -5
  129. mindspore/include/api/callback/time_monitor.h +2 -2
  130. mindspore/include/api/callback/train_accuracy.h +4 -6
  131. mindspore/include/api/cfg.h +19 -6
  132. mindspore/include/api/context.h +70 -9
  133. mindspore/include/api/delegate.h +8 -1
  134. mindspore/include/api/dual_abi_helper.h +8 -24
  135. mindspore/include/api/metrics/accuracy.h +2 -2
  136. mindspore/include/api/metrics/metrics.h +4 -3
  137. mindspore/include/api/model.h +9 -4
  138. mindspore/include/api/model_group.h +68 -0
  139. mindspore/include/api/model_parallel_runner.h +17 -17
  140. mindspore/include/api/net.h +12 -11
  141. mindspore/include/api/serialization.h +20 -4
  142. mindspore/include/api/status.h +7 -1
  143. mindspore/include/api/types.h +25 -21
  144. mindspore/include/api/visible.h +4 -0
  145. mindspore/include/c_api/model_c.h +5 -0
  146. mindspore/include/c_api/status_c.h +1 -1
  147. mindspore/include/dataset/config.h +1 -1
  148. mindspore/include/dataset/constants.h +14 -0
  149. mindspore/include/dataset/text.h +59 -0
  150. mindspore/include/dataset/vision.h +56 -117
  151. mindspore/include/dataset/vision_lite.h +102 -0
  152. mindspore/jpeg62.dll +0 -0
  153. mindspore/log.py +28 -28
  154. mindspore/mindrecord/common/exceptions.py +2 -4
  155. mindspore/mindrecord/filereader.py +19 -1
  156. mindspore/mindrecord/filewriter.py +250 -88
  157. mindspore/mindrecord/mindpage.py +13 -13
  158. mindspore/mindrecord/shardheader.py +15 -15
  159. mindspore/mindrecord/shardreader.py +9 -0
  160. mindspore/mindrecord/shardwriter.py +29 -29
  161. mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
  162. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
  163. mindspore/mindrecord/tools/csv_to_mr.py +4 -4
  164. mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
  165. mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
  166. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  167. mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
  168. mindspore/mindspore_common.dll +0 -0
  169. mindspore/mindspore_core.dll +0 -0
  170. mindspore/mindspore_glog.dll +0 -0
  171. mindspore/mindspore_shared_lib.dll +0 -0
  172. mindspore/msobj140.dll +0 -0
  173. mindspore/mspdb140.dll +0 -0
  174. mindspore/mspdbcore.dll +0 -0
  175. mindspore/mspdbst.dll +0 -0
  176. mindspore/mspft140.dll +0 -0
  177. mindspore/msvcdis140.dll +0 -0
  178. mindspore/msvcp140_1.dll +0 -0
  179. mindspore/msvcp140_2.dll +0 -0
  180. mindspore/msvcp140_atomic_wait.dll +0 -0
  181. mindspore/msvcp140_codecvt_ids.dll +0 -0
  182. mindspore/nn/__init__.py +1 -5
  183. mindspore/nn/cell.py +297 -234
  184. mindspore/nn/dynamic_lr.py +1 -1
  185. mindspore/nn/grad/cell_grad.py +17 -42
  186. mindspore/nn/layer/__init__.py +7 -4
  187. mindspore/nn/layer/activation.py +131 -88
  188. mindspore/nn/layer/basic.py +313 -613
  189. mindspore/nn/layer/channel_shuffle.py +103 -0
  190. mindspore/nn/layer/combined.py +1 -1
  191. mindspore/nn/layer/container.py +52 -6
  192. mindspore/nn/layer/conv.py +112 -43
  193. mindspore/nn/layer/dense.py +10 -9
  194. mindspore/nn/layer/embedding.py +36 -34
  195. mindspore/nn/layer/image.py +123 -27
  196. mindspore/nn/layer/math.py +108 -107
  197. mindspore/nn/layer/normalization.py +212 -366
  198. mindspore/nn/layer/padding.py +370 -42
  199. mindspore/nn/layer/pooling.py +1443 -219
  200. mindspore/nn/layer/rnn_cells.py +11 -16
  201. mindspore/nn/layer/rnns.py +38 -39
  202. mindspore/nn/layer/thor_layer.py +24 -25
  203. mindspore/nn/layer/timedistributed.py +5 -5
  204. mindspore/nn/layer/transformer.py +701 -0
  205. mindspore/nn/learning_rate_schedule.py +8 -8
  206. mindspore/nn/loss/__init__.py +9 -6
  207. mindspore/nn/loss/loss.py +678 -142
  208. mindspore/nn/metrics.py +53 -0
  209. mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
  210. mindspore/nn/optim/ada_grad.py +8 -8
  211. mindspore/nn/optim/adadelta.py +2 -3
  212. mindspore/nn/optim/adafactor.py +18 -14
  213. mindspore/nn/optim/adam.py +429 -87
  214. mindspore/nn/optim/adamax.py +5 -6
  215. mindspore/nn/optim/adasum.py +10 -8
  216. mindspore/nn/optim/asgd.py +7 -7
  217. mindspore/nn/optim/ftrl.py +81 -11
  218. mindspore/nn/optim/lamb.py +7 -8
  219. mindspore/nn/optim/lars.py +4 -4
  220. mindspore/nn/optim/lazyadam.py +82 -7
  221. mindspore/nn/optim/momentum.py +8 -7
  222. mindspore/nn/optim/optimizer.py +19 -10
  223. mindspore/nn/optim/proximal_ada_grad.py +6 -5
  224. mindspore/nn/optim/rmsprop.py +3 -3
  225. mindspore/nn/optim/rprop.py +20 -16
  226. mindspore/nn/optim/sgd.py +21 -15
  227. mindspore/nn/optim/thor.py +23 -21
  228. mindspore/nn/probability/__init__.py +0 -2
  229. mindspore/nn/probability/bijector/bijector.py +7 -6
  230. mindspore/nn/probability/bijector/invert.py +4 -2
  231. mindspore/nn/probability/bijector/softplus.py +2 -2
  232. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  233. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  234. mindspore/nn/probability/distribution/__init__.py +6 -0
  235. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
  236. mindspore/nn/probability/distribution/_utils/utils.py +11 -17
  237. mindspore/nn/probability/distribution/bernoulli.py +6 -6
  238. mindspore/nn/probability/distribution/beta.py +1 -1
  239. mindspore/nn/probability/distribution/categorical.py +9 -9
  240. mindspore/nn/probability/distribution/cauchy.py +8 -8
  241. mindspore/nn/probability/distribution/distribution.py +12 -6
  242. mindspore/nn/probability/distribution/exponential.py +5 -5
  243. mindspore/nn/probability/distribution/gamma.py +3 -3
  244. mindspore/nn/probability/distribution/geometric.py +6 -5
  245. mindspore/nn/probability/distribution/gumbel.py +5 -5
  246. mindspore/nn/probability/distribution/half_normal.py +133 -0
  247. mindspore/nn/probability/distribution/laplace.py +128 -0
  248. mindspore/nn/probability/distribution/log_normal.py +0 -1
  249. mindspore/nn/probability/distribution/logistic.py +4 -5
  250. mindspore/nn/probability/distribution/normal.py +11 -15
  251. mindspore/nn/probability/distribution/poisson.py +6 -2
  252. mindspore/nn/probability/distribution/student_t.py +150 -0
  253. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  254. mindspore/nn/probability/distribution/uniform.py +5 -5
  255. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  256. mindspore/nn/reinforcement/tensor_array.py +2 -2
  257. mindspore/nn/sparse/sparse.py +8 -1
  258. mindspore/nn/wrap/cell_wrapper.py +55 -27
  259. mindspore/nn/wrap/grad_reducer.py +20 -11
  260. mindspore/nn/wrap/loss_scale.py +47 -30
  261. mindspore/numpy/array_creations.py +33 -22
  262. mindspore/numpy/array_ops.py +46 -42
  263. mindspore/numpy/logic_ops.py +6 -27
  264. mindspore/numpy/math_ops.py +26 -19
  265. mindspore/numpy/utils.py +1 -8
  266. mindspore/numpy/utils_const.py +112 -62
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/opencv_imgproc452.dll +0 -0
  270. mindspore/ops/__init__.py +6 -3
  271. mindspore/ops/_constants.py +0 -6
  272. mindspore/ops/_grad/__init__.py +2 -1
  273. mindspore/ops/_grad/grad_array_ops.py +209 -152
  274. mindspore/ops/_grad/grad_base.py +55 -17
  275. mindspore/ops/_grad/grad_clip_ops.py +11 -3
  276. mindspore/ops/_grad/grad_comm_ops.py +58 -47
  277. mindspore/ops/_grad/grad_implementations.py +21 -61
  278. mindspore/ops/_grad/grad_inner_ops.py +48 -6
  279. mindspore/ops/_grad/grad_math_ops.py +306 -161
  280. mindspore/ops/_grad/grad_nn_ops.py +192 -181
  281. mindspore/ops/_grad/grad_other_ops.py +1 -1
  282. mindspore/ops/_grad/grad_quant_ops.py +5 -5
  283. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  284. mindspore/ops/_grad/grad_sparse.py +15 -9
  285. mindspore/ops/_grad_experimental/__init__.py +1 -0
  286. mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
  287. mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
  288. mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
  289. mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
  290. mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
  291. mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
  292. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  293. mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
  294. mindspore/ops/_op_impl/__init__.py +3 -3
  295. mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
  296. mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
  297. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  298. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
  299. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  300. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  301. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
  302. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  303. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  304. mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
  305. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  306. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
  307. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  308. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  309. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  310. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  311. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  312. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  313. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  314. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  315. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  316. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  317. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  318. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  319. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  320. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  321. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  322. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  323. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  325. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
  326. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
  327. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  328. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  329. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  330. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  331. mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
  332. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  333. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  334. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  335. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  336. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  337. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  338. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  339. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  340. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  341. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  342. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  343. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  344. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  345. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  346. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  347. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  348. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  349. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  350. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  351. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  352. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
  353. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  354. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
  355. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  356. mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
  357. mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
  358. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  359. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  360. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  361. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  362. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  363. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  364. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  365. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
  366. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  367. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  368. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  369. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  370. mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
  371. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  372. mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
  373. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  374. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  375. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  376. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  377. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  378. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  379. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  380. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  381. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  382. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  383. mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
  384. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  385. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  386. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  387. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  388. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  389. mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
  390. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  391. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  392. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  393. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  394. mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
  395. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  396. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  397. mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
  398. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  399. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  400. mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
  401. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  402. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  403. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  404. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  405. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  406. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  407. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  408. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  409. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  410. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  411. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  412. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  413. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  414. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  415. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  416. mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
  417. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  418. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  419. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  420. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  421. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  422. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  423. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  424. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  425. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  426. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  427. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  428. mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
  429. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  430. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  431. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  432. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  433. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  434. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  435. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  436. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  437. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  438. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  439. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  440. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  441. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  442. mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
  443. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  444. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  445. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  446. mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
  447. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
  448. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
  449. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  450. mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
  451. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  452. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  453. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  454. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  455. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  456. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  457. mindspore/ops/_op_impl/cpu/__init__.py +1 -2
  458. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  459. mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
  460. mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
  461. mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
  462. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  463. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  464. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  465. mindspore/ops/_op_impl/tbe/__init__.py +27 -608
  466. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
  467. mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
  468. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  469. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  470. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  471. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
  472. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  473. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  474. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
  475. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
  476. mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
  477. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  478. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
  479. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  480. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  481. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  482. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  483. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  484. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
  485. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
  486. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  487. mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
  488. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
  489. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  490. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  491. mindspore/ops/_op_impl/tbe/greater.py +2 -0
  492. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  493. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
  494. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  495. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  496. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
  497. mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
  498. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
  499. mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
  500. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
  501. mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
  502. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
  503. mindspore/ops/_op_impl/tbe/slice.py +26 -15
  504. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  505. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  506. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
  507. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  508. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
  509. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
  510. mindspore/ops/_primitive_cache.py +3 -2
  511. mindspore/ops/_register_for_op.py +11 -0
  512. mindspore/ops/_utils/__init__.py +1 -1
  513. mindspore/ops/_utils/utils.py +20 -41
  514. mindspore/ops/_vmap/__init__.py +2 -2
  515. mindspore/ops/_vmap/vmap_array_ops.py +170 -78
  516. mindspore/ops/_vmap/vmap_base.py +24 -10
  517. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  518. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  519. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
  520. mindspore/ops/_vmap/vmap_image_ops.py +52 -0
  521. mindspore/ops/_vmap/vmap_math_ops.py +77 -6
  522. mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
  523. mindspore/ops/_vmap/vmap_other_ops.py +3 -1
  524. mindspore/ops/_vmap/vmap_random_ops.py +55 -3
  525. mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
  526. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  527. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  528. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
  529. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
  530. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
  531. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
  532. mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
  533. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  534. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  535. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
  537. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
  539. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  541. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
  542. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
  543. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  546. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  547. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  548. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  549. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  550. mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
  551. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  552. mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
  553. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
  554. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  555. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
  556. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
  559. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
  560. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  561. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
  565. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  566. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  567. mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
  568. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
  569. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  570. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
  571. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
  572. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
  573. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
  574. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  575. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
  576. mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
  577. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  578. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  579. mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
  580. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
  582. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
  583. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
  584. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  585. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  586. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  587. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  588. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  589. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
  590. mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
  591. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
  592. mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
  593. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  594. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
  595. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
  596. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
  597. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  598. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  599. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  600. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  601. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  602. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  603. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  604. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  605. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  606. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  607. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  608. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
  609. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
  610. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
  611. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
  612. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  613. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  614. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  615. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  616. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  617. mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
  618. mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
  619. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  620. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  621. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
  622. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
  623. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
  624. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
  625. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  626. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
  627. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
  628. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
  629. mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
  630. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  631. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  632. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
  633. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
  634. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
  635. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  636. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  637. mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
  638. mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
  639. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  640. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  642. mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
  643. mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
  644. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  645. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  646. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  647. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  648. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  649. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
  650. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
  651. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  652. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  653. mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
  654. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
  655. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
  656. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
  657. mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
  658. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  659. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  660. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
  661. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
  662. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
  663. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  664. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  665. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
  666. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
  667. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
  668. mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
  669. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
  670. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  671. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  672. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
  673. mindspore/ops/bprop_mindir/__init__.py +1 -4
  674. mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
  675. mindspore/ops/composite/__init__.py +12 -13
  676. mindspore/ops/composite/base.py +261 -254
  677. mindspore/ops/composite/env_ops.py +41 -0
  678. mindspore/ops/composite/math_ops.py +197 -156
  679. mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
  680. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
  681. mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
  682. mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
  683. mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
  684. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
  685. mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
  686. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  687. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  688. mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
  689. mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
  690. mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
  691. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
  692. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  693. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
  694. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
  695. mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
  696. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  697. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  698. mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
  699. mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
  700. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
  701. mindspore/ops/function/__init__.py +323 -8
  702. mindspore/ops/function/array_func.py +3511 -780
  703. mindspore/ops/function/clip_func.py +329 -0
  704. mindspore/ops/function/debug_func.py +6 -6
  705. mindspore/ops/function/grad/__init__.py +5 -1
  706. mindspore/ops/function/grad/grad_func.py +736 -65
  707. mindspore/ops/function/image_func.py +270 -0
  708. mindspore/ops/function/linalg_func.py +268 -8
  709. mindspore/ops/function/math_func.py +8032 -3164
  710. mindspore/ops/function/nn_func.py +5619 -1855
  711. mindspore/ops/function/other_func.py +115 -0
  712. mindspore/ops/function/parameter_func.py +11 -10
  713. mindspore/ops/function/random_func.py +939 -77
  714. mindspore/ops/function/sparse_func.py +249 -84
  715. mindspore/ops/function/sparse_unary_func.py +2303 -0
  716. mindspore/ops/function/spectral_func.py +146 -0
  717. mindspore/ops/function/vmap_func.py +114 -0
  718. mindspore/ops/functional.py +182 -254
  719. mindspore/ops/op_info_register.py +79 -34
  720. mindspore/ops/operations/__init__.py +210 -118
  721. mindspore/ops/operations/_csr_ops.py +7 -7
  722. mindspore/ops/operations/_embedding_cache_ops.py +25 -15
  723. mindspore/ops/operations/_grad_ops.py +447 -322
  724. mindspore/ops/operations/_inner_ops.py +547 -176
  725. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  726. mindspore/ops/operations/_ms_kernel.py +29 -27
  727. mindspore/ops/operations/_ocr_ops.py +11 -11
  728. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  729. mindspore/ops/operations/_quant_ops.py +186 -101
  730. mindspore/ops/operations/_rl_inner_ops.py +122 -61
  731. mindspore/ops/operations/_scalar_ops.py +466 -0
  732. mindspore/ops/operations/_sequence_ops.py +1047 -0
  733. mindspore/ops/operations/_tensor_array.py +10 -11
  734. mindspore/ops/operations/_thor_ops.py +4 -4
  735. mindspore/ops/operations/array_ops.py +1428 -1226
  736. mindspore/ops/operations/comm_ops.py +180 -117
  737. mindspore/ops/operations/control_ops.py +4 -2
  738. mindspore/ops/operations/custom_ops.py +185 -98
  739. mindspore/ops/operations/debug_ops.py +92 -54
  740. mindspore/ops/operations/image_ops.py +406 -211
  741. mindspore/ops/operations/inner_ops.py +42 -53
  742. mindspore/ops/operations/linalg_ops.py +32 -29
  743. mindspore/ops/operations/math_ops.py +2076 -897
  744. mindspore/ops/operations/nn_ops.py +1282 -1252
  745. mindspore/ops/operations/other_ops.py +124 -278
  746. mindspore/ops/operations/random_ops.py +345 -178
  747. mindspore/ops/operations/rl_ops.py +8 -9
  748. mindspore/ops/operations/sparse_ops.py +502 -157
  749. mindspore/ops/operations/spectral_ops.py +107 -0
  750. mindspore/ops/primitive.py +192 -15
  751. mindspore/ops/vm_impl_registry.py +23 -2
  752. mindspore/parallel/__init__.py +6 -1
  753. mindspore/parallel/_auto_parallel_context.py +199 -92
  754. mindspore/parallel/_cell_wrapper.py +4 -2
  755. mindspore/parallel/_cost_model_context.py +3 -0
  756. mindspore/parallel/_dp_allreduce_fusion.py +2 -1
  757. mindspore/parallel/_offload_context.py +185 -0
  758. mindspore/parallel/_parallel_serialization.py +167 -28
  759. mindspore/parallel/_ps_context.py +9 -5
  760. mindspore/parallel/_recovery_context.py +1 -1
  761. mindspore/parallel/_tensor.py +9 -1
  762. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  763. mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
  764. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  765. mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
  766. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  767. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
  768. mindspore/parallel/_utils.py +47 -7
  769. mindspore/parallel/algo_parameter_config.py +5 -1
  770. mindspore/parallel/checkpoint_transform.py +329 -0
  771. mindspore/parallel/shard.py +229 -0
  772. mindspore/perf_msvcbuildinsights.dll +0 -0
  773. mindspore/pgodb140.dll +0 -0
  774. mindspore/pgort140.dll +0 -0
  775. mindspore/profiler/__init__.py +2 -1
  776. mindspore/profiler/common/util.py +4 -3
  777. mindspore/profiler/common/validator/validate_path.py +2 -2
  778. mindspore/profiler/envprofiling.py +249 -0
  779. mindspore/profiler/parser/aicpu_data_parser.py +38 -39
  780. mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
  781. mindspore/profiler/parser/base_timeline_generator.py +471 -0
  782. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
  783. mindspore/profiler/parser/framework_parser.py +42 -16
  784. mindspore/profiler/parser/hccl_parser.py +158 -158
  785. mindspore/profiler/parser/hwts_log_parser.py +7 -6
  786. mindspore/profiler/parser/integrator.py +18 -1579
  787. mindspore/profiler/parser/minddata_analyzer.py +8 -8
  788. mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
  789. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  790. mindspore/profiler/parser/optime_parser.py +17 -18
  791. mindspore/profiler/parser/profiler_info.py +108 -0
  792. mindspore/profiler/parser/step_trace_parser.py +1 -1
  793. mindspore/profiler/profiling.py +396 -194
  794. mindspore/rewrite/__init__.py +6 -2
  795. mindspore/rewrite/api/node.py +51 -110
  796. mindspore/rewrite/api/node_type.py +10 -6
  797. mindspore/rewrite/api/pattern_engine.py +51 -7
  798. mindspore/rewrite/api/scoped_value.py +64 -53
  799. mindspore/rewrite/api/symbol_tree.py +108 -61
  800. mindspore/rewrite/api/tree_node_helper.py +2 -3
  801. mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
  802. mindspore/rewrite/ast_helpers/__init__.py +6 -3
  803. mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
  804. mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
  805. mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
  806. mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
  807. mindspore/rewrite/ast_transformers/__init__.py +0 -1
  808. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
  809. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
  810. mindspore/rewrite/common/__init__.py +2 -0
  811. mindspore/rewrite/common/event.py +1 -1
  812. mindspore/rewrite/common/observable.py +1 -1
  813. mindspore/rewrite/common/observer.py +1 -1
  814. mindspore/rewrite/common/rewrite_elog.py +35 -0
  815. mindspore/rewrite/namer.py +2 -2
  816. mindspore/rewrite/namespace.py +14 -4
  817. mindspore/rewrite/node.py +161 -13
  818. mindspore/rewrite/parser.py +0 -1
  819. mindspore/rewrite/parser_register.py +0 -1
  820. mindspore/rewrite/parsers/arguments_parser.py +3 -2
  821. mindspore/rewrite/parsers/assign_parser.py +267 -67
  822. mindspore/rewrite/parsers/attribute_parser.py +56 -0
  823. mindspore/rewrite/parsers/class_def_parser.py +191 -108
  824. mindspore/rewrite/parsers/constant_parser.py +101 -0
  825. mindspore/rewrite/parsers/container_parser.py +88 -0
  826. mindspore/rewrite/parsers/for_parser.py +28 -15
  827. mindspore/rewrite/parsers/function_def_parser.py +21 -5
  828. mindspore/rewrite/parsers/if_parser.py +11 -28
  829. mindspore/rewrite/parsers/module_parser.py +9 -6
  830. mindspore/rewrite/parsers/return_parser.py +3 -2
  831. mindspore/rewrite/sparsify/__init__.py +0 -0
  832. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  833. mindspore/rewrite/sparsify/sparsify.py +109 -0
  834. mindspore/rewrite/sparsify/utils.py +173 -0
  835. mindspore/rewrite/symbol_tree.py +322 -109
  836. mindspore/rewrite/symbol_tree_builder.py +45 -8
  837. mindspore/rewrite/symbol_tree_dumper.py +0 -1
  838. mindspore/rewrite/topological_manager.py +1 -2
  839. mindspore/run_check/_check_version.py +209 -112
  840. mindspore/run_check/run_check.py +2 -1
  841. mindspore/tbbmalloc.dll +0 -0
  842. mindspore/tinyxml2.dll +0 -0
  843. mindspore/train/__init__.py +6 -4
  844. mindspore/train/_utils.py +28 -5
  845. mindspore/train/amp.py +321 -50
  846. mindspore/train/callback/__init__.py +3 -1
  847. mindspore/train/callback/_backup_and_restore.py +120 -0
  848. mindspore/train/callback/_callback.py +8 -8
  849. mindspore/train/callback/_checkpoint.py +12 -9
  850. mindspore/train/callback/_early_stop.py +13 -7
  851. mindspore/train/callback/_history.py +8 -8
  852. mindspore/train/callback/_lambda_callback.py +6 -6
  853. mindspore/train/callback/_landscape.py +36 -38
  854. mindspore/train/callback/_loss_monitor.py +12 -6
  855. mindspore/train/callback/_lr_scheduler_callback.py +2 -4
  856. mindspore/train/callback/_on_request_exit.py +212 -0
  857. mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
  858. mindspore/train/callback/_summary_collector.py +27 -19
  859. mindspore/train/callback/_time_monitor.py +13 -7
  860. mindspore/train/checkpoint_pb2.py +68 -8
  861. mindspore/train/data_sink.py +122 -33
  862. mindspore/train/dataset_helper.py +28 -87
  863. mindspore/train/loss_scale_manager.py +4 -7
  864. mindspore/{nn → train}/metrics/__init__.py +20 -20
  865. mindspore/{nn → train}/metrics/accuracy.py +12 -10
  866. mindspore/{nn → train}/metrics/auc.py +4 -4
  867. mindspore/{nn → train}/metrics/bleu_score.py +4 -4
  868. mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
  869. mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
  870. mindspore/{nn → train}/metrics/dice.py +6 -5
  871. mindspore/{nn → train}/metrics/error.py +7 -5
  872. mindspore/{nn → train}/metrics/fbeta.py +9 -7
  873. mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
  874. mindspore/{nn → train}/metrics/loss.py +4 -3
  875. mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
  876. mindspore/{nn → train}/metrics/metric.py +6 -5
  877. mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
  878. mindspore/{nn → train}/metrics/perplexity.py +5 -4
  879. mindspore/{nn → train}/metrics/precision.py +5 -4
  880. mindspore/{nn → train}/metrics/recall.py +5 -4
  881. mindspore/{nn → train}/metrics/roc.py +7 -6
  882. mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
  883. mindspore/{nn → train}/metrics/topk.py +7 -5
  884. mindspore/train/mind_ir_pb2.py +339 -32
  885. mindspore/train/model.py +113 -84
  886. mindspore/train/serialization.py +547 -167
  887. mindspore/train/summary/_summary_adapter.py +1 -1
  888. mindspore/train/summary/summary_record.py +43 -12
  889. mindspore/train/train_thor/convert_utils.py +7 -1
  890. mindspore/train/train_thor/dataset_helper.py +3 -3
  891. mindspore/train/train_thor/model_thor.py +0 -4
  892. mindspore/turbojpeg.dll +0 -0
  893. mindspore/vcmeta.dll +0 -0
  894. mindspore/vcruntime140.dll +0 -0
  895. mindspore/vcruntime140_1.dll +0 -0
  896. mindspore/version.py +1 -1
  897. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
  898. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
  899. mindspore/compression/common/constant.py +0 -124
  900. mindspore/compression/export/__init__.py +0 -19
  901. mindspore/compression/export/quant_export.py +0 -514
  902. mindspore/compression/quant/qat.py +0 -636
  903. mindspore/compression/quant/quant_utils.py +0 -462
  904. mindspore/compression/quant/quantizer.py +0 -68
  905. mindspore/libatomic-1.dll +0 -0
  906. mindspore/libgcc_s_seh-1.dll +0 -0
  907. mindspore/libgfortran-4.dll +0 -0
  908. mindspore/libgomp-1.dll +0 -0
  909. mindspore/libjpeg-62.dll +0 -0
  910. mindspore/libmindspore.dll +0 -0
  911. mindspore/libmindspore_common.dll +0 -0
  912. mindspore/libmindspore_core.dll +0 -0
  913. mindspore/libmindspore_glog.dll +0 -0
  914. mindspore/libnnacl.dll +0 -0
  915. mindspore/libopencv_core452.dll +0 -0
  916. mindspore/libopencv_imgcodecs452.dll +0 -0
  917. mindspore/libopencv_imgproc452.dll +0 -0
  918. mindspore/libquadmath-0.dll +0 -0
  919. mindspore/libsqlite3.dll +0 -0
  920. mindspore/libssp-0.dll +0 -0
  921. mindspore/libstdc++-6.dll +0 -0
  922. mindspore/libtinyxml2.dll +0 -0
  923. mindspore/libturbojpeg.dll +0 -0
  924. mindspore/libwinpthread-1.dll +0 -0
  925. mindspore/nn/layer/quant.py +0 -1868
  926. mindspore/nn/layer/rnn_utils.py +0 -90
  927. mindspore/nn/probability/dpn/__init__.py +0 -22
  928. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  929. mindspore/nn/probability/dpn/vae/cvae.py +0 -138
  930. mindspore/nn/probability/dpn/vae/vae.py +0 -122
  931. mindspore/nn/probability/infer/__init__.py +0 -22
  932. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  933. mindspore/nn/probability/infer/variational/svi.py +0 -84
  934. mindspore/nn/probability/toolbox/__init__.py +0 -22
  935. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  936. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
  937. mindspore/nn/probability/transforms/__init__.py +0 -22
  938. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  939. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  940. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  941. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  942. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  943. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  944. mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
  945. mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
  946. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
  947. mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
  948. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
  949. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
  950. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
  951. mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
  952. mindspore/ops/composite/array_ops.py +0 -210
  953. mindspore/ops/composite/clip_ops.py +0 -238
  954. mindspore/ops/composite/random_ops.py +0 -426
  955. mindspore/ops/composite/vmap_ops.py +0 -38
  956. mindspore/ops/operations/sponge_ops.py +0 -3531
  957. mindspore/ops/operations/sponge_update_ops.py +0 -2546
  958. mindspore/parallel/nn/__init__.py +0 -42
  959. mindspore/parallel/nn/loss.py +0 -22
  960. mindspore/parallel/nn/moe.py +0 -21
  961. mindspore/parallel/nn/op_parallel_config.py +0 -22
  962. mindspore/parallel/nn/transformer.py +0 -31
  963. mindspore/run_check/_check_deps_version.py +0 -84
  964. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  965. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  966. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -14,16 +14,17 @@
14
14
  # ============================================================================
15
15
 
16
16
  """The vmap implement of grad operator corresponding to nn_ops."""
17
-
18
17
  from __future__ import absolute_import
18
+
19
19
  from __future__ import division
20
20
  from functools import reduce
21
21
  import mindspore.numpy as mnp
22
22
  from mindspore.ops.operations import _grad_ops as G
23
23
  from mindspore.ops import functional as F
24
24
  from mindspore.ops import constexpr
25
+ from mindspore.ops.primitive import _primexpr
25
26
  from mindspore.ops.primitive import Primitive
26
- from mindspore.ops.composite import _VmapGeneralRule
27
+ from mindspore.ops.function import _VmapGeneralRule
27
28
  from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error, \
28
29
  _bdim_at_front, _vmap_clone_prim, _vmap_update_prim_attr, _bdim_at_any, _handle_broadcasting
29
30
 
@@ -38,7 +39,7 @@ def get_nll_loss_grad_vmap_rule(prim, axis_size):
38
39
  2. And weight only support shape as (C,), while total_weight should be a scalar.
39
40
  """
40
41
 
41
- @constexpr
42
+ @_primexpr
42
43
  def _get_reshape_shape(shape, keep_dim=0):
43
44
  new_batch_size = reduce(
44
45
  lambda x, y: x * y, shape if keep_dim == 0 else shape[:-keep_dim])
@@ -104,6 +105,7 @@ def get_nll_loss_grad_vmap_rule(prim, axis_size):
104
105
  return vmap_rule
105
106
 
106
107
 
108
+ @vmap_rules_getters.register(G.MaxPoolGrad)
107
109
  @vmap_rules_getters.register(G.AvgPoolGrad)
108
110
  def get_avg_pool_grad_vmap_rule(prim, axis_size):
109
111
  """VmapRule for `AvgPoolGrad`."""
@@ -225,11 +227,15 @@ def get_cdist_grad_vmap_rule(prim, axis_size):
225
227
  return vmap_rule
226
228
 
227
229
 
230
+ @vmap_rules_getters.register(G.AdaptiveMaxPool3DGrad)
228
231
  @vmap_rules_getters.register(G.AdaptiveMaxPool2DGrad)
229
232
  def get_adaptive_avgpool2d_vmap_rule(prim, axis_size):
230
- """VmapRule for `AdaptiveMaxPool2DGrad` operation."""
233
+ """VmapRule for `AdaptiveMaxPool2DGrad` and `AdaptiveMaxPool3DGrad` operation."""
231
234
  chw_reverse_index = -3
232
- hw_reverse_index = -2
235
+ if prim.name == "AdaptiveMaxPool2DGrad":
236
+ hw_reverse_index = -2
237
+ else:
238
+ hw_reverse_index = -3
233
239
 
234
240
  def vmap_rule(ygrad_bdim, x_bdim, max_index_bdim):
235
241
  is_all_none, result = vmap_general_preprocess(prim, ygrad_bdim, x_bdim, max_index_bdim)
@@ -352,7 +358,7 @@ def get_batchnorm_grad_vmap_rule(prim, axis_size):
352
358
  if is_all_none:
353
359
  return result
354
360
  if data_format == "NHWC":
355
- #BatchNormGrad with NHWC format is a GPU backend operation and not supported for now.
361
+ # BatchNormGrad with NHWC format is a GPU backend operation and not supported for now.
356
362
  return batchnorm_grad_nhwc_vmap(grad_bdim, x_bdim, scale_bdim, rsv_1_bdim, rsv_2_bdim, rsv_3_bdim)
357
363
  grad, grad_dim = grad_bdim
358
364
  input_x, input_x_dim = x_bdim
@@ -392,8 +398,9 @@ def get_batchnorm_grad_vmap_rule(prim, axis_size):
392
398
 
393
399
  @vmap_rules_getters.register(G.MaxPoolGradGrad)
394
400
  @vmap_rules_getters.register(G.MaxPoolGradGradWithArgmax)
401
+ @vmap_rules_getters.register(G.MaxPoolGradWithArgmaxV2)
395
402
  def get_maxpool_grad_grad_vmap_rule(prim, axis_size):
396
- """VmapRule for `MaxPoolGradGrad` and `MaxPoolGradGradWithArgmax`."""
403
+ """VmapRule for `MaxPoolGradGrad`, `MaxPoolGradGradWithArgmax` and `MaxPoolGradWithArgmaxV2`."""
397
404
  chw_reverse_index = -3
398
405
 
399
406
  def vmap_rule(in0_bdim, in1_bdim, in2_bdim):
@@ -552,7 +559,7 @@ def get_layernormgrad_vmap_rule(prim, axis_size):
552
559
  return prim_attr_axis
553
560
  return prim_attr_axis + 1
554
561
 
555
- @constexpr
562
+ @_primexpr
556
563
  def get_batch_params_reduce_axes(begin_params_axis, x_shape):
557
564
  if begin_params_axis < 0:
558
565
  x_rank = len(x_shape)
@@ -560,7 +567,7 @@ def get_layernormgrad_vmap_rule(prim, axis_size):
560
567
  batch_params_reduce_axes = tuple(range(1, begin_params_axis))
561
568
  return batch_params_reduce_axes
562
569
 
563
- @constexpr
570
+ @_primexpr
564
571
  def get_logical_shape(var_shape):
565
572
  return var_shape[1:]
566
573
 
@@ -682,3 +689,28 @@ def get_upsample_grad_vmap_rule(prim, axis_size):
682
689
  out = F.reshape(out, real_out_shape)
683
690
  return out, 0
684
691
  return vmap_rule
692
+
693
+
694
+ @vmap_rules_getters.register(G.LogSoftmaxGrad)
695
+ def get_log_softmax_vmap_rule(prim, axis_size):
696
+ """VmapRule for 'LogSoftmaxGrad' operation."""
697
+ if isinstance(prim, str):
698
+ axis = -1
699
+ else:
700
+ axis = prim.axis
701
+
702
+ def vmap_rule(x_bdim, grad_bdim):
703
+ is_all_none, result = vmap_general_preprocess(prim, x_bdim)
704
+ if is_all_none:
705
+ return result
706
+ x, x_dim = x_bdim
707
+ grad, _ = grad_bdim
708
+ x_ndim = F.rank(x) - 1
709
+
710
+ batch_axis = axis + x_ndim if axis < 0 else axis
711
+ batch_axis = batch_axis if batch_axis < x_dim else batch_axis + 1
712
+
713
+ dx = G.LogSoftmaxGrad(axis=batch_axis)(x, grad)
714
+ return dx, x_dim
715
+
716
+ return vmap_rule
@@ -16,9 +16,12 @@
16
16
  """image_ops vmap impl."""
17
17
  from __future__ import absolute_import
18
18
 
19
+ import numpy as np
20
+ from mindspore import Tensor
19
21
  from mindspore.ops import functional as F
20
22
  from mindspore.ops.operations import _grad_ops as G
21
23
  from mindspore.ops.operations import image_ops as IMG
24
+ from mindspore.ops import constexpr
22
25
  from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, \
23
26
  _raise_value_error
24
27
 
@@ -83,3 +86,52 @@ def get_resize_grad_dynamic_rule(prim, axis_size):
83
86
  return out, 0
84
87
 
85
88
  return vmap_rule
89
+
90
+
91
+ @vmap_rules_getters.register(IMG.CropAndResize)
92
+ def get_crop_and_resize_vmap_rule(prim, axis_size):
93
+ """VmapRule for `CropAndResize` operation."""
94
+
95
+ @constexpr
96
+ def get_box_indices_offsets(axis_size, batch_size, num_boxes):
97
+ offsets = np.arange(0, axis_size * batch_size, batch_size).astype(np.int32)
98
+ offsets = np.reshape(offsets, (axis_size, 1))
99
+ offsets = np.broadcast_to(offsets, (axis_size, num_boxes))
100
+ return Tensor(offsets)
101
+
102
+ def vmap_rule(x_bdim, boxes_bdim, box_indices_bdim, crop_size_bdim):
103
+ is_all_none, result = vmap_general_preprocess(x_bdim, boxes_bdim, box_indices_bdim, crop_size_bdim)
104
+ if is_all_none:
105
+ return result
106
+
107
+ x, x_dim = x_bdim
108
+ boxes, boxes_dim = boxes_bdim
109
+ box_indices, box_indices_dim = box_indices_bdim
110
+ crop_size, crop_size_dim = crop_size_bdim
111
+ if crop_size_dim is not None:
112
+ _raise_value_error(
113
+ "The axis of `crop_size` in `{}` must be None, but got {}.".format(prim.name, crop_size_dim))
114
+
115
+ boxes = _bdim_at_front(boxes, boxes_dim, axis_size)
116
+ box_indices = _bdim_at_front(box_indices, box_indices_dim, axis_size)
117
+ boxes = F.reshape(boxes, (-1, 4))
118
+ num_boxes = F.shape(box_indices)[-1]
119
+
120
+ if x_dim is None:
121
+ box_indices = F.reshape(box_indices, (-1,))
122
+ out = prim(x, boxes, box_indices, crop_size)
123
+ else:
124
+ x = _bdim_at_front(x, x_dim, axis_size)
125
+ x_shape = F.shape(x)
126
+ x = F.reshape(x, (-1,) + x_shape[2:])
127
+ offsets = get_box_indices_offsets(axis_size, x_shape[1], num_boxes)
128
+ box_indices = F.add(box_indices, offsets)
129
+ box_indices = F.reshape(box_indices, (-1,))
130
+ out = prim(x, boxes, box_indices, crop_size)
131
+
132
+ out_shape = F.shape(out)
133
+ out = F.reshape(out, (-1, num_boxes) + out_shape[1:])
134
+ return out, 0
135
+
136
+
137
+ return vmap_rule
@@ -19,13 +19,13 @@ from __future__ import absolute_import
19
19
  import mindspore.numpy as mnp
20
20
  from mindspore.ops import operations as P
21
21
  from mindspore.ops import functional as F
22
- from mindspore.ops import constexpr
22
+ from mindspore.ops.primitive import _primexpr
23
23
  from mindspore.common import Tensor
24
24
  from mindspore.ops.operations import math_ops
25
25
  from mindspore.ops.operations import linalg_ops
26
26
  from mindspore.ops.operations import _inner_ops
27
27
  from mindspore.ops.primitive import Primitive
28
- from mindspore.ops.composite import _VmapGeneralRule
28
+ from mindspore.ops.function import _VmapGeneralRule
29
29
  from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_assign_vmap_rule, \
30
30
  get_unop_vmap_rule, _raise_value_error, _bdim_at_front, _broadcast_by_axis, _handle_broadcasting, \
31
31
  _vmap_clone_prim, _bdim_at_any, _get_reduce_batch_axis, _get_reduce_out_dim
@@ -33,7 +33,7 @@ from mindspore.ops.operations.math_ops import Bernoulli, BesselI0, BesselI1, Bes
33
33
  BesselK0, BesselK0e, BesselY0, BesselY1, BesselK1, BesselK1e, Median
34
34
 
35
35
 
36
- @constexpr
36
+ @_primexpr
37
37
  def _broadcast_shape(nd, x_ndim, x_shape):
38
38
  return x_shape + (1,) * (nd - x_ndim)
39
39
 
@@ -97,6 +97,35 @@ def get_broadcast_binary_op_vmap_rule(prim, axis_size):
97
97
  return vmap_rule
98
98
 
99
99
 
100
+ @vmap_rules_getters.register(P.Addcdiv)
101
+ @vmap_rules_getters.register(P.Addcmul)
102
+ def get_addcxxx_vmap_rule(prim, axis_size):
103
+ """VmapRule for addcxxx, such as `Addcdiv` and `Addcmul`."""
104
+
105
+ def vmap_rule(input_data_bdim, x1_bdim, x2_bdim, value_bdim):
106
+ is_all_none, result = vmap_general_preprocess(prim, input_data_bdim, x1_bdim, x2_bdim, value_bdim)
107
+ if is_all_none:
108
+ return result
109
+
110
+ input_data, input_data_dim = input_data_bdim
111
+ x1, x1_dim = x1_bdim
112
+ x2, x2_dim = x2_bdim
113
+ value, value_dim = value_bdim
114
+ if input_data_dim == x1_dim and x1_dim == x2_dim and x2_dim == value_dim:
115
+ out = prim(input_data, x1, x2, value)
116
+ return out, input_data_dim
117
+
118
+ input_data = _bdim_at_front(input_data, input_data_dim, axis_size)
119
+ x1 = _bdim_at_front(x1, x1_dim, axis_size)
120
+ x2 = _bdim_at_front(x2, x2_dim, axis_size)
121
+ value = _bdim_at_front(value, value_dim, axis_size)
122
+
123
+ out = prim(input_data, x1, x2, value)
124
+ return out, 0
125
+
126
+ return vmap_rule
127
+
128
+
100
129
  @vmap_rules_getters.register(P.Cdist)
101
130
  def get_cdist_vmap_rule(prim, axis_size):
102
131
  """VmapRule for `cdist` operation."""
@@ -358,6 +387,8 @@ def get_inplace_ops_vmap_rule(prim, axis_size):
358
387
  @vmap_rules_getters.register(P.ReduceMin)
359
388
  @vmap_rules_getters.register(P.ReduceMean)
360
389
  @vmap_rules_getters.register(P.ReduceProd)
390
+ @vmap_rules_getters.register(P.ReduceAll)
391
+ @vmap_rules_getters.register(P.ReduceAny)
361
392
  def get_reducer_vmap_rule(prim, axis_size):
362
393
  """VmapRule for reduce operations, such as `ReduceSum`."""
363
394
  reduce_op_map = {
@@ -365,7 +396,9 @@ def get_reducer_vmap_rule(prim, axis_size):
365
396
  "ReduceMax": P.ReduceMax,
366
397
  "ReduceMin": P.ReduceMin,
367
398
  "ReduceMean": P.ReduceMean,
368
- "ReduceProd": P.ReduceProd
399
+ "ReduceProd": P.ReduceProd,
400
+ "ReduceAll": P.ReduceAll,
401
+ "ReduceAny": P.ReduceAny,
369
402
  }
370
403
 
371
404
  if isinstance(prim, str):
@@ -403,7 +436,7 @@ def get_median_vmap_rule(prim, axis_size):
403
436
  axis = prim.axis
404
437
  keep_dims = prim.keep_dims
405
438
 
406
- @constexpr
439
+ @_primexpr
407
440
  def trans_axis(axis, rank, dim, keep_dims):
408
441
  if axis < 0:
409
442
  axis += rank - 1
@@ -431,7 +464,7 @@ def get_index_add_vmap_rule(prim, axis_size):
431
464
  """VmapRule for IndexAdd."""
432
465
  axis = prim.axis
433
466
 
434
- @constexpr
467
+ @_primexpr
435
468
  def _get_index_add_batch_axis(axis, x_dim, x_ndim):
436
469
  """get batch_axis for IndexAdd."""
437
470
  # case1: batch not exists
@@ -770,6 +803,44 @@ def get_square_sum_all_vmap_rule(prim, axis_size):
770
803
  return vmap_rule
771
804
 
772
805
 
806
+ @vmap_rules_getters.register(math_ops.FFTWithSize)
807
+ def get_fft_with_size_vmap_rule(prim, axis_size):
808
+ """VmapRule for `FFTWithSize` operation"""
809
+ if isinstance(prim, str):
810
+ prim_name = prim
811
+ prim = Primitive(prim)
812
+ signal_ndim = 1
813
+ inverse = False
814
+ real = False
815
+ norm = "backward"
816
+ oneside = True
817
+ signal_sizes = ()
818
+ else:
819
+ prim_name = prim.name
820
+ signal_ndim = prim.signal_ndim
821
+ inverse = prim.inverse
822
+ real = prim.real
823
+ norm = prim.norm
824
+ oneside = prim.oneside
825
+ signal_sizes = prim.signal_sizes
826
+
827
+ fft = math_ops.FFTWithSize(signal_ndim, inverse, real, norm, oneside, signal_sizes)
828
+
829
+ def vmap_rule(x_bdim):
830
+ is_all_none, result = vmap_general_preprocess(prim, x_bdim)
831
+ if is_all_none:
832
+ return result
833
+ x, x_dim = x_bdim
834
+ x_ndim = F.rank(x)
835
+ if x_dim < 0 or x_dim >= x_ndim - signal_ndim:
836
+ _raise_value_error("The source axi of `x` in `{} must be`in range of ({} {}), "
837
+ "but got {}.".format(prim_name, 0, x_ndim - signal_ndim, x_dim))
838
+ out = fft(x)
839
+ return (out, x_dim)
840
+
841
+ return vmap_rule
842
+
843
+
773
844
  get_assign_vmap_rule = vmap_rules_getters.register(P.AssignAdd)(get_assign_vmap_rule)
774
845
  get_assign_vmap_rule = vmap_rules_getters.register(P.AssignSub)(get_assign_vmap_rule)
775
846
 
@@ -23,6 +23,7 @@ from mindspore.ops.operations import _grad_ops as G
23
23
  from mindspore.ops.operations import nn_ops as NN
24
24
  from mindspore.ops import functional as F
25
25
  from mindspore.ops import constexpr
26
+ from mindspore.ops.primitive import _primexpr
26
27
  from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_unop_vmap_rule, \
27
28
  _bdim_at_any, _bdim_at_front, _bdim_at_back, _handle_broadcasting, get_unary_grad_vmap_rule, _raise_value_error, \
28
29
  _vmap_clone_prim, _get_reduce_batch_axis
@@ -375,7 +376,7 @@ def get_bias_add_vmap_rule(prim, axis_size):
375
376
  def get_channal_pos_in_x(d_format):
376
377
  return d_format.find('C') + 1
377
378
 
378
- @constexpr
379
+ @_primexpr
379
380
  def get_bias_dst_shape(x_shape, n_dims, d_format, has_b_dim: bool):
380
381
  pos = get_channal_pos_in_x(d_format)
381
382
 
@@ -430,7 +431,7 @@ def get_bias_add_grad_vmap_rule(prim, axis_size):
430
431
  def get_channal_pos(d_format):
431
432
  return d_format.find('C') + 1
432
433
 
433
- @constexpr
434
+ @_primexpr
434
435
  def get_axis_for_reduce(x_shape_rank, data_format):
435
436
  channal_pos = get_channal_pos(data_format)
436
437
  axis_list = ()
@@ -1072,24 +1073,24 @@ def get_pad_v3_vmap_rule(prim, axis_size):
1072
1073
  if is_all_none:
1073
1074
  return result
1074
1075
  if len(params_bdim) < 2:
1075
- _raise_value_error("The input params in `{}` must >= 2, "
1076
- "but got {}.".format(prim.name, len(params_bdim)))
1076
+ _raise_value_error("The input params in `PadV3` must >= 2, "
1077
+ "but got {}.".format(len(params_bdim)))
1077
1078
  input_x, input_x_dim = params_bdim[0]
1078
1079
  paddings, paddings_dim = params_bdim[1]
1079
1080
  values = None
1080
1081
  out = None
1081
1082
  x = _bdim_at_front(input_x, input_x_dim, axis_size)
1082
1083
  if paddings_dim is not None:
1083
- _raise_value_error("The source axis of `paddings` in `{}` must be None, "
1084
- "but got {}.".format(prim.name, paddings_dim))
1084
+ _raise_value_error("The source axis of `paddings` in `PadV3` must be None, "
1085
+ "but got {}.".format(paddings_dim))
1085
1086
  if mode == "constant":
1086
1087
  if len(params_bdim) != 3:
1087
- _raise_value_error("The input params in `{}` of constant mode must be 3, "
1088
- "but got {}.".format(prim.name, len(params_bdim)))
1088
+ _raise_value_error("The input params in `PadV3` of constant mode must be 3, "
1089
+ "but got {}.".format(len(params_bdim)))
1089
1090
  values, values_dim = params_bdim[2]
1090
1091
  if values_dim is not None:
1091
- _raise_value_error("The source axis of `values_dim` in `{}` must be None, "
1092
- "but got {}.".format(prim.name, values_dim))
1092
+ _raise_value_error("The source axis of `values_dim` in `PadV3` must be None, "
1093
+ "but got {}.".format(values_dim))
1093
1094
  if isinstance(paddings, Tensor):
1094
1095
  pad_dim = F.shape(paddings)[0] / pad_pair
1095
1096
  else:
@@ -1101,7 +1102,7 @@ def get_pad_v3_vmap_rule(prim, axis_size):
1101
1102
  out = prim(x, paddings, values)
1102
1103
  else:
1103
1104
  out = prim(x, paddings)
1104
- elif x_ndim > input_max_dim:
1105
+ elif x_ndim >= input_max_dim:
1105
1106
  # reshape to 4 dims
1106
1107
  x_shape = F.shape(x)
1107
1108
  diff_dim = x_ndim - input_max_dim
@@ -1118,8 +1119,8 @@ def get_pad_v3_vmap_rule(prim, axis_size):
1118
1119
  real_out_shape = x_shape[:diff_dim + 1] + out_shape[1:]
1119
1120
  out = F.reshape(out, real_out_shape)
1120
1121
  else:
1121
- _raise_value_error("The dim of `input_x` in `{}` must be bigger than {}, "
1122
- "but got {}.".format(prim.name, pad_dim, x_ndim))
1122
+ _raise_value_error("The dim of `input_x` in `PadV3` must be bigger than {}, "
1123
+ "but got {}.".format(pad_dim, x_ndim))
1123
1124
  return out, 0
1124
1125
 
1125
1126
  return vmap_rule
@@ -1308,6 +1309,60 @@ def get_apply_adam_with_amsgrad_rule(prim, axis_size):
1308
1309
  return vmap_rule
1309
1310
 
1310
1311
 
1312
+ @vmap_rules_getters.register(P.Adam)
1313
+ def get_adam_rule(prim, axis_size):
1314
+ """VmapRule for `Adam` operation"""
1315
+ if hasattr(prim, "batch_rank"):
1316
+ batch_rank = prim.batch_rank + 1
1317
+ else:
1318
+ batch_rank = 1
1319
+ prim_name = prim.name
1320
+ batch_prim = _vmap_clone_prim(prim)
1321
+ batch_prim.add_prim_attr("batch_rank", batch_rank)
1322
+
1323
+ def vmap_rule(var_bdim, m_bdim, v_bdim, beta1_power_bdim, beta2_power_bdim, lr_bdim, beta1_bdim,
1324
+ beta2_bdim, epsilon_bdim, grad_bdim, u_monad):
1325
+ var, var_dim = var_bdim
1326
+ m, m_dim = m_bdim
1327
+ v, v_dim = v_bdim
1328
+ beta1_power, beta1_power_dim = beta1_power_bdim
1329
+ beta2_power, beta2_power_dim = beta2_power_bdim
1330
+ lr, lr_dim = lr_bdim
1331
+ beta1, beta1_dim = beta1_bdim
1332
+ beta2, beta2_dim = beta2_bdim
1333
+ epsilon, epsilon_dim = epsilon_bdim
1334
+ grad, grad_dim = grad_bdim
1335
+
1336
+ all_dim = [m_dim, v_dim, beta1_power_dim, beta2_power_dim, lr_dim, beta1_dim, beta2_dim, epsilon_dim, grad_dim]
1337
+ if var_dim is None:
1338
+ if any(dim is not None for dim in all_dim):
1339
+ raise ValueError("The source axis of `var` is None, "
1340
+ "but the source axis of `m/v/vhat/beta1_power/beta2_power/lr/beta1/beta2/epsilon grad"
1341
+ " is not None. The execution of operator `{}` cannot be guaranteed.".format(prim_name))
1342
+ out_var, out_m, out_v = prim(
1343
+ var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, u_monad)
1344
+ return ((out_var, None), (out_m, None), (out_v, None))
1345
+
1346
+ if any(dim != 0 for dim in [var_dim, m_dim, v_dim]):
1347
+ raise ValueError("For `{}`, the source axis of `var/m/v` must be 0, "
1348
+ "but get `var`: {}, `m`: {}, `v`: {}".format(prim_name, var_dim,
1349
+ m_dim, v_dim))
1350
+
1351
+ beta1_power = _bdim_at_front(beta1_power, beta1_power_dim, axis_size)
1352
+ beta2_power = _bdim_at_front(beta2_power, beta2_power_dim, axis_size)
1353
+ lr = _bdim_at_front(lr, lr_dim, axis_size)
1354
+ beta1 = _bdim_at_front(beta1, beta1_dim, axis_size)
1355
+ beta2 = _bdim_at_front(beta2, beta2_dim, axis_size)
1356
+ epsilon = _bdim_at_front(epsilon, epsilon_dim, axis_size)
1357
+ grad = _bdim_at_front(grad, grad_dim, axis_size)
1358
+
1359
+ out_var, out_m, out_v = batch_prim(
1360
+ var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, u_monad)
1361
+ return ((out_var, 0), (out_m, 0), (out_v, 0))
1362
+
1363
+ return vmap_rule
1364
+
1365
+
1311
1366
  @vmap_rules_getters.register(P.ApplyPowerSign)
1312
1367
  def get_apply_power_sign_rule(prim, axis_size):
1313
1368
  """VmapRule for `ApplyPowerSign` operation."""
@@ -1461,10 +1516,9 @@ def get_adaptive_max_pool_2d_vmap_rule(prim, axis_size):
1461
1516
  nchw_index = 4
1462
1517
  chw_reverse_index = -3
1463
1518
  hw_size = 2
1464
- return_indices = prim.return_indices
1465
1519
  output_size = prim.output_size
1466
1520
 
1467
- @constexpr
1521
+ @_primexpr
1468
1522
  def get_output_shape(x_ori_shape, output_size):
1469
1523
  if isinstance(output_size, tuple):
1470
1524
  h_out, w_out = output_size
@@ -1499,20 +1553,14 @@ def get_adaptive_max_pool_2d_vmap_rule(prim, axis_size):
1499
1553
  x_ori_shape = F.shape(x)
1500
1554
  x = F.reshape(x, (-1,) + x_ori_shape[chw_reverse_index:])
1501
1555
  output_shape = get_output_shape(x_ori_shape, output_size)
1502
- if return_indices:
1503
- out, indices = prim(x)
1504
- out = F.reshape(out, output_shape)
1505
- indices = F.reshape(indices, output_shape)
1506
- return (out, 0), (indices, 0)
1507
- out = prim(x)
1508
- out = F.reshape(out, output_shape)
1509
- return out, 0
1510
- # for the case of CHW
1511
- if return_indices:
1512
1556
  out, indices = prim(x)
1557
+ out = F.reshape(out, output_shape)
1558
+ indices = F.reshape(indices, output_shape)
1513
1559
  return (out, 0), (indices, 0)
1514
- out = prim(x)
1515
- return out, 0
1560
+
1561
+ # for the case of CHW
1562
+ out, indices = prim(x)
1563
+ return (out, 0), (indices, 0)
1516
1564
 
1517
1565
  return vmap_rule
1518
1566
 
@@ -1657,6 +1705,7 @@ def get_apply_centered_rmsprop_vmap_rule(prim, axis_size):
1657
1705
 
1658
1706
  @vmap_rules_getters.register(P.MaxPool)
1659
1707
  @vmap_rules_getters.register(P.MaxPoolWithArgmax)
1708
+ @vmap_rules_getters.register(P.MaxPoolWithArgmaxV2)
1660
1709
  def get_max_pool_vmap_rule(prim, axis_size):
1661
1710
  """VmapRule for `MaxPool` operation."""
1662
1711
  if isinstance(prim, str):
@@ -1664,7 +1713,7 @@ def get_max_pool_vmap_rule(prim, axis_size):
1664
1713
 
1665
1714
  prim_name = prim.name
1666
1715
 
1667
- @constexpr
1716
+ @_primexpr
1668
1717
  def get_original_shape(x_shape, out_shape):
1669
1718
  h_new = out_shape[2]
1670
1719
  w_new = out_shape[3]
@@ -1709,7 +1758,7 @@ def get_layernorm_vmap_rule(prim, axis_size):
1709
1758
  params_axis = process_attr_axis(prim.begin_params_axis)
1710
1759
  batch_prim = P.LayerNorm(norm_axis, params_axis, prim.epsilon)
1711
1760
 
1712
- @constexpr
1761
+ @_primexpr
1713
1762
  def get_logical_shape(var_shape):
1714
1763
  return var_shape[1:]
1715
1764
 
@@ -83,7 +83,9 @@ def get_partical_vmap_rule(prim, axis_size):
83
83
  else:
84
84
  val, dim = val_bdim
85
85
  if dim is not None:
86
- _raise_value_error("The source axis of args in {} must be None, "
86
+ _raise_value_error("In the scenario where vmap contains control flow, currently only the "
87
+ "case of each batch branch with the same processing operations is "
88
+ "supported, so that the source axis of args in {} must be None, "
87
89
  "but got {}.".format(prim_name, dim))
88
90
  vals = vals + (val,)
89
91
 
@@ -1,4 +1,3 @@
1
-
2
1
  # Copyright 2022 Huawei Technologies Co., Ltd
3
2
  #
4
3
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,8 +16,11 @@
17
16
  """random_ops vmap impl."""
18
17
  from __future__ import absolute_import
19
18
 
20
- from mindspore.ops.operations.random_ops import UniformCandidateSampler, RandomShuffle
21
- from mindspore.ops._vmap.vmap_base import vmap_rules_getters, _bdim_at_front, _vmap_clone_prim, vmap_general_preprocess
19
+ from mindspore.ops.operations.random_ops import UniformCandidateSampler, RandomShuffle, Multinomial, \
20
+ RandomChoiceWithMask
21
+ from mindspore.ops.function import _VmapGeneralRule
22
+ from mindspore.ops._vmap.vmap_base import vmap_rules_getters, _bdim_at_front, _vmap_clone_prim, \
23
+ vmap_general_preprocess, _raise_value_error
22
24
 
23
25
 
24
26
  @vmap_rules_getters.register(UniformCandidateSampler)
@@ -68,3 +70,53 @@ def get_random_shuffle_vmap_rule(prim, axis_size):
68
70
  return out, 0
69
71
 
70
72
  return vmap_rule
73
+
74
+
75
+ @vmap_rules_getters.register(Multinomial)
76
+ def get_multinomial_vmap_rule(prim, axis_size):
77
+ """VmapRule for `Multinomial` operation."""
78
+ prim_name = prim.name
79
+ prim_vmap = _VmapGeneralRule(prim, axis_size)
80
+
81
+ def vmap_rule(x_bdim, num_samples_bdim):
82
+ is_all_none, result = vmap_general_preprocess(
83
+ prim, x_bdim, num_samples_bdim)
84
+ if is_all_none:
85
+ return result
86
+
87
+ x, x_dim = x_bdim
88
+ num_samples, num_samples_dim = num_samples_bdim
89
+ if len(x.shape) > 2:
90
+ out = prim_vmap(x_bdim, num_samples_bdim)
91
+ return out
92
+ if num_samples_dim is not None:
93
+ _raise_value_error("The source axis of args in {} must be None, "
94
+ "but got {}.".format(prim_name, num_samples_dim))
95
+ x = _bdim_at_front(x, x_dim, axis_size)
96
+ out = prim(x, num_samples)
97
+ return (out, 0)
98
+
99
+ return vmap_rule
100
+
101
+
102
+ @vmap_rules_getters.register(RandomChoiceWithMask)
103
+ def get_random_choice_with_mask(prim, axis_size):
104
+ """VmapRule for 'RandomChoiceWithMask' operation."""
105
+ if hasattr(prim, 'batch_rank'):
106
+ batch_rank = prim.batch_rank + 1
107
+ else:
108
+ batch_rank = 1
109
+
110
+ batch_prim = _vmap_clone_prim(prim)
111
+ batch_prim.add_prim_attr('batch_rank', batch_rank)
112
+
113
+ def vmap_rule(x_bdim):
114
+ is_all_none, result = vmap_general_preprocess(prim, x_bdim)
115
+ if is_all_none:
116
+ return result
117
+ x_data, x_dim = x_bdim
118
+ x = _bdim_at_front(x_data, x_dim, axis_size)
119
+ index, mask = batch_prim(x)
120
+ return (index, 0), (mask, 0)
121
+
122
+ return vmap_rule
@@ -14,6 +14,7 @@
14
14
  # ============================================================================
15
15
 
16
16
  """sparse_ops vmap impl."""
17
+ from __future__ import absolute_import
17
18
 
18
19
  from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error
19
20
  from mindspore.ops.primitive import Primitive
@@ -1,20 +1,19 @@
1
1
 
2
- 0.1.1 MindSpore*1.9.0:�
3
- �
4
-
5
- bprop.33:xbprop.33:[CNode]34:1bprop.33:[CNode]34:1"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op23
6
- �
7
-
8
- bprop.33:ybprop.33:[CNode]35:3bprop.33:[CNode]35:3"(REF::S-Prim-hyper_map[zeros_like_leaf]:2:.Default/S-Prim-hyper_map[zeros_like_leaf]-op24
9
- �
10
- bprop.33:[CNode]34:1
11
- bprop.33:[CNode]35:3bprop.33:[CNode]36:4bprop.33:[CNode]36:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op25bprop.33*
12
-
13
- bprop.33:x*
14
-
15
- bprop.33:y*
16
- bprop.33:out*
17
-
18
- bprop.33:[CNode]36:4:@157abcda70ad669686a320be427dedf2a17498d59a38b65d04c272c1a40296e0Pb&
19
- S-Prim-MakeTuple:5S-Prim-MakeTuplebH
20
- #S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
2
+ 0.1.1 MindSpore*2.0.0:�
3
+ �.get_bprop_approximate_equal.1184:[CNode]1185:1.get_bprop_approximate_equal.1184:[CNode]1185:1"REF::bprop.1186:Default/bprop.1186-op927 get_bprop_approximate_equal.1184*'
4
+ %get_bprop_approximate_equal.1184:self*$
5
+ "get_bprop_approximate_equal.1184:x*$
6
+ "get_bprop_approximate_equal.1184:y*&
7
+ $get_bprop_approximate_equal.1184:out*'
8
+ %get_bprop_approximate_equal.1184:dout20
9
+ .get_bprop_approximate_equal.1184:[CNode]1185:1:@7fb54a66e55c2c40cd92783044880b792666a0d7fc794bb717bee3544337d6a0J/grad_math_ops.pyB�
10
+ �
11
+ "get_bprop_approximate_equal.1184:xbprop.1186:[CNode]1187:2bprop.1186:[CNode]1187:2".REF::MetaFuncGraph::hyper_map[zeros_like_leaf]:/Default/S-Prim-hyper_map[zeros_like_leaf]-op928
12
+ �
13
+ "get_bprop_approximate_equal.1184:ybprop.1186:[CNode]1188:3bprop.1186:[CNode]1188:3".REF::MetaFuncGraph::hyper_map[zeros_like_leaf]:/Default/S-Prim-hyper_map[zeros_like_leaf]-op929
14
+ �
15
+ bprop.1186:[CNode]1187:2
16
+ bprop.1186:[CNode]1188:3bprop.1186:[CNode]1189:4bprop.1186:[CNode]1189:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op930
17
+ bprop.11862
18
+ bprop.1186:[CNode]1189:4Pb&
19
+ S-Prim-MakeTuple:5S-Prim-MakeTupleh