mindspore 2.0.0a0__cp37-none-any.whl → 2.0.0rc1__cp37-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (693)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +9064 -0
  3. mindspore/__init__.py +4 -2
  4. mindspore/_akg/akg/composite/build_module.py +11 -0
  5. mindspore/_akg/akg/config/repository_cuda.json +11 -0
  6. mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -3
  7. mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
  8. mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
  9. mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
  10. mindspore/_check_jit_forbidden_api.py +102 -0
  11. mindspore/_checkparam.py +1066 -1001
  12. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
  13. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
  14. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
  15. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
  16. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  17. mindspore/_extends/parse/__init__.py +5 -3
  18. mindspore/_extends/parse/namespace.py +16 -1
  19. mindspore/_extends/parse/parser.py +107 -22
  20. mindspore/_extends/parse/resources.py +0 -7
  21. mindspore/_extends/parse/standard_method.py +885 -413
  22. mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
  23. mindspore/amp.py +52 -57
  24. mindspore/bin/cache_admin +0 -0
  25. mindspore/bin/cache_server +0 -0
  26. mindspore/boost/boost.py +2 -2
  27. mindspore/boost/boost_cell_wrapper.py +38 -20
  28. mindspore/boost/dim_reduce.py +3 -3
  29. mindspore/boost/group_loss_scale_manager.py +1 -1
  30. mindspore/common/__init__.py +4 -6
  31. mindspore/common/_decorator.py +2 -0
  32. mindspore/common/_register_for_adapter.py +55 -0
  33. mindspore/common/_stub_tensor.py +201 -0
  34. mindspore/common/_utils.py +41 -7
  35. mindspore/common/api.py +215 -141
  36. mindspore/common/dtype.py +8 -1
  37. mindspore/common/dump.py +2 -2
  38. mindspore/common/initializer.py +4 -2
  39. mindspore/common/jit_config.py +17 -13
  40. mindspore/common/mutable.py +33 -13
  41. mindspore/common/parameter.py +23 -21
  42. mindspore/common/seed.py +8 -24
  43. mindspore/common/sparse_tensor.py +62 -41
  44. mindspore/common/tensor.py +852 -1154
  45. mindspore/communication/__init__.py +2 -2
  46. mindspore/communication/_comm_helper.py +11 -4
  47. mindspore/communication/management.py +22 -21
  48. mindspore/config/op_info.config +501 -1008
  49. mindspore/config/super_bar_config.json +512 -0
  50. mindspore/context.py +201 -23
  51. mindspore/dataset/__init__.py +6 -6
  52. mindspore/dataset/audio/__init__.py +7 -7
  53. mindspore/dataset/audio/transforms.py +670 -30
  54. mindspore/dataset/audio/utils.py +47 -4
  55. mindspore/dataset/audio/validators.py +223 -1
  56. mindspore/dataset/callback/ds_callback.py +2 -2
  57. mindspore/dataset/core/config.py +210 -14
  58. mindspore/dataset/core/validator_helpers.py +2 -2
  59. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  60. mindspore/dataset/debug/debug_hook.py +65 -0
  61. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  62. mindspore/dataset/engine/__init__.py +7 -3
  63. mindspore/dataset/engine/cache_client.py +1 -1
  64. mindspore/dataset/engine/datasets.py +322 -66
  65. mindspore/dataset/engine/datasets_audio.py +80 -76
  66. mindspore/dataset/engine/datasets_standard_format.py +51 -38
  67. mindspore/dataset/engine/datasets_text.py +232 -118
  68. mindspore/dataset/engine/datasets_user_defined.py +41 -17
  69. mindspore/dataset/engine/datasets_vision.py +746 -225
  70. mindspore/dataset/engine/graphdata.py +75 -10
  71. mindspore/dataset/engine/iterators.py +45 -5
  72. mindspore/dataset/engine/offload.py +48 -28
  73. mindspore/dataset/engine/validators.py +117 -8
  74. mindspore/dataset/text/__init__.py +6 -5
  75. mindspore/dataset/text/transforms.py +86 -3
  76. mindspore/dataset/text/utils.py +6 -4
  77. mindspore/dataset/text/validators.py +25 -0
  78. mindspore/dataset/transforms/__init__.py +3 -2
  79. mindspore/dataset/transforms/c_transforms.py +1 -1
  80. mindspore/dataset/transforms/transforms.py +2 -2
  81. mindspore/dataset/utils/__init__.py +2 -1
  82. mindspore/dataset/utils/line_reader.py +121 -0
  83. mindspore/dataset/vision/__init__.py +2 -3
  84. mindspore/dataset/vision/c_transforms.py +9 -9
  85. mindspore/dataset/vision/py_transforms.py +5 -5
  86. mindspore/dataset/vision/py_transforms_util.py +2 -0
  87. mindspore/dataset/vision/transforms.py +160 -161
  88. mindspore/dataset/vision/utils.py +3 -3
  89. mindspore/experimental/map_parameter.py +38 -26
  90. mindspore/include/OWNERS +0 -1
  91. mindspore/include/api/callback/callback.h +9 -13
  92. mindspore/include/api/callback/ckpt_saver.h +2 -2
  93. mindspore/include/api/callback/loss_monitor.h +2 -2
  94. mindspore/include/api/callback/lr_scheduler.h +5 -5
  95. mindspore/include/api/callback/time_monitor.h +2 -2
  96. mindspore/include/api/callback/train_accuracy.h +4 -6
  97. mindspore/include/api/cfg.h +19 -6
  98. mindspore/include/api/context.h +44 -9
  99. mindspore/include/api/delegate.h +1 -1
  100. mindspore/include/api/metrics/accuracy.h +2 -2
  101. mindspore/include/api/metrics/metrics.h +4 -3
  102. mindspore/include/api/model.h +9 -4
  103. mindspore/include/api/model_parallel_runner.h +2 -2
  104. mindspore/include/api/net.h +12 -11
  105. mindspore/include/api/serialization.h +19 -3
  106. mindspore/include/api/types.h +3 -3
  107. mindspore/include/dataset/constants.h +7 -0
  108. mindspore/include/dataset/text.h +59 -0
  109. mindspore/include/mindapi/base/type_id.h +1 -0
  110. mindspore/lib/libdnnl.so.2 +0 -0
  111. mindspore/lib/libicudata.so.69 +0 -0
  112. mindspore/lib/libicui18n.so.69 +0 -0
  113. mindspore/lib/libicuuc.so.69 +0 -0
  114. mindspore/lib/libmindspore.so +0 -0
  115. mindspore/lib/libmindspore_backend.so +0 -0
  116. mindspore/lib/libmindspore_common.so +0 -0
  117. mindspore/lib/libmindspore_core.so +0 -0
  118. mindspore/lib/libmindspore_glog.so.0 +0 -0
  119. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  120. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  121. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  122. mindspore/lib/libmindspore_shared_lib.so +0 -0
  123. mindspore/lib/libmpi_adapter.so +0 -0
  124. mindspore/lib/libmpi_collective.so +0 -0
  125. mindspore/lib/libnnacl.so +0 -0
  126. mindspore/lib/libopencv_core.so.4.5 +0 -0
  127. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  128. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  129. mindspore/lib/libps_cache.so +0 -0
  130. mindspore/lib/plugin/ascend/libakg.so +0 -0
  131. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  132. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  133. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  134. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  135. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  136. mindspore/lib/plugin/cpu/libakg.so +0 -0
  137. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  138. mindspore/lib/plugin/{libmindspore_ascend.so → libmindspore_ascend.so.2} +0 -0
  139. mindspore/log.py +1 -1
  140. mindspore/mindrecord/filereader.py +18 -0
  141. mindspore/mindrecord/filewriter.py +197 -34
  142. mindspore/mindrecord/shardreader.py +9 -0
  143. mindspore/mindrecord/shardwriter.py +1 -1
  144. mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
  145. mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
  146. mindspore/mindrecord/tools/csv_to_mr.py +3 -3
  147. mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
  148. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  149. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  150. mindspore/nn/__init__.py +0 -4
  151. mindspore/nn/cell.py +204 -132
  152. mindspore/nn/dynamic_lr.py +1 -1
  153. mindspore/nn/grad/cell_grad.py +7 -6
  154. mindspore/nn/layer/__init__.py +5 -4
  155. mindspore/nn/layer/activation.py +40 -89
  156. mindspore/nn/layer/basic.py +255 -624
  157. mindspore/nn/layer/channel_shuffle.py +7 -6
  158. mindspore/nn/layer/combined.py +1 -1
  159. mindspore/nn/layer/container.py +41 -4
  160. mindspore/nn/layer/conv.py +64 -28
  161. mindspore/nn/layer/dense.py +9 -8
  162. mindspore/nn/layer/embedding.py +27 -25
  163. mindspore/nn/layer/image.py +53 -46
  164. mindspore/nn/layer/math.py +97 -105
  165. mindspore/nn/layer/normalization.py +117 -86
  166. mindspore/nn/layer/padding.py +185 -95
  167. mindspore/nn/layer/pooling.py +817 -414
  168. mindspore/nn/layer/rnn_cells.py +10 -15
  169. mindspore/nn/layer/rnns.py +37 -38
  170. mindspore/nn/layer/thor_layer.py +11 -12
  171. mindspore/nn/layer/timedistributed.py +5 -5
  172. mindspore/nn/layer/transformer.py +701 -0
  173. mindspore/nn/learning_rate_schedule.py +8 -8
  174. mindspore/nn/loss/__init__.py +5 -4
  175. mindspore/nn/loss/loss.py +334 -199
  176. mindspore/nn/optim/ada_grad.py +6 -6
  177. mindspore/nn/optim/adadelta.py +2 -3
  178. mindspore/nn/optim/adafactor.py +4 -5
  179. mindspore/nn/optim/adam.py +126 -62
  180. mindspore/nn/optim/adamax.py +3 -4
  181. mindspore/nn/optim/adasum.py +6 -6
  182. mindspore/nn/optim/asgd.py +2 -2
  183. mindspore/nn/optim/ftrl.py +67 -38
  184. mindspore/nn/optim/lamb.py +4 -5
  185. mindspore/nn/optim/lars.py +2 -2
  186. mindspore/nn/optim/lazyadam.py +43 -4
  187. mindspore/nn/optim/momentum.py +6 -5
  188. mindspore/nn/optim/optimizer.py +3 -1
  189. mindspore/nn/optim/proximal_ada_grad.py +2 -2
  190. mindspore/nn/optim/rmsprop.py +1 -1
  191. mindspore/nn/optim/rprop.py +8 -9
  192. mindspore/nn/optim/sgd.py +19 -13
  193. mindspore/nn/optim/thor.py +10 -15
  194. mindspore/nn/probability/__init__.py +0 -2
  195. mindspore/nn/probability/bijector/bijector.py +4 -4
  196. mindspore/nn/probability/bijector/invert.py +1 -1
  197. mindspore/nn/probability/bijector/softplus.py +2 -2
  198. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  199. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  200. mindspore/nn/probability/distribution/_utils/utils.py +9 -15
  201. mindspore/nn/probability/distribution/bernoulli.py +3 -3
  202. mindspore/nn/probability/distribution/beta.py +1 -1
  203. mindspore/nn/probability/distribution/categorical.py +5 -7
  204. mindspore/nn/probability/distribution/cauchy.py +3 -3
  205. mindspore/nn/probability/distribution/distribution.py +2 -2
  206. mindspore/nn/probability/distribution/exponential.py +2 -2
  207. mindspore/nn/probability/distribution/gamma.py +3 -3
  208. mindspore/nn/probability/distribution/geometric.py +1 -1
  209. mindspore/nn/probability/distribution/gumbel.py +3 -3
  210. mindspore/nn/probability/distribution/half_normal.py +15 -11
  211. mindspore/nn/probability/distribution/laplace.py +16 -13
  212. mindspore/nn/probability/distribution/logistic.py +2 -2
  213. mindspore/nn/probability/distribution/normal.py +1 -1
  214. mindspore/nn/probability/distribution/poisson.py +1 -1
  215. mindspore/nn/probability/distribution/student_t.py +20 -15
  216. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  217. mindspore/nn/probability/distribution/uniform.py +2 -2
  218. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  219. mindspore/nn/reinforcement/tensor_array.py +2 -2
  220. mindspore/nn/sparse/sparse.py +2 -2
  221. mindspore/nn/wrap/cell_wrapper.py +27 -10
  222. mindspore/nn/wrap/grad_reducer.py +2 -2
  223. mindspore/nn/wrap/loss_scale.py +40 -24
  224. mindspore/numpy/array_creations.py +33 -22
  225. mindspore/numpy/array_ops.py +35 -30
  226. mindspore/numpy/logic_ops.py +6 -27
  227. mindspore/numpy/math_ops.py +22 -19
  228. mindspore/numpy/utils.py +1 -1
  229. mindspore/numpy/utils_const.py +108 -58
  230. mindspore/ops/_constants.py +0 -6
  231. mindspore/ops/_grad/__init__.py +2 -1
  232. mindspore/ops/_grad/grad_array_ops.py +86 -117
  233. mindspore/ops/_grad/grad_base.py +23 -1
  234. mindspore/ops/_grad/grad_clip_ops.py +2 -3
  235. mindspore/ops/_grad/grad_comm_ops.py +34 -24
  236. mindspore/ops/_grad/grad_implementations.py +9 -45
  237. mindspore/ops/_grad/grad_inner_ops.py +47 -4
  238. mindspore/ops/_grad/grad_math_ops.py +142 -117
  239. mindspore/ops/_grad/grad_nn_ops.py +71 -165
  240. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  241. mindspore/ops/_grad/grad_sparse.py +7 -6
  242. mindspore/ops/_grad_experimental/__init__.py +1 -0
  243. mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
  244. mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
  245. mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
  246. mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
  247. mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
  248. mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
  249. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  250. mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
  251. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  252. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
  253. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  254. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  255. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
  256. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  257. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  258. mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
  259. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  260. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
  261. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  262. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  263. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  264. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  265. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  266. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  267. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  268. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  269. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  270. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  271. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  272. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  273. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  274. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  275. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  276. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  277. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  278. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  279. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
  280. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  281. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  282. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  283. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  284. mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
  285. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  286. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
  287. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  288. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  289. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  290. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  291. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  292. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
  293. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  294. mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
  295. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  296. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  297. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
  298. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  299. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  300. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
  301. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  302. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  303. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  304. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  305. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  306. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  307. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  308. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  309. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  310. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
  311. mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
  312. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  313. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  314. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  315. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  316. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  317. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  318. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  319. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  320. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  321. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  322. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  323. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  324. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  325. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  326. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  327. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  328. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  329. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  331. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  332. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  333. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  334. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  335. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  336. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  337. mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
  338. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
  339. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  340. mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
  341. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  342. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  343. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  345. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  346. mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
  347. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
  348. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  349. mindspore/ops/_op_impl/tbe/__init__.py +27 -611
  350. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  351. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  352. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  353. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
  354. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  355. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  356. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  357. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  358. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  359. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  360. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  361. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  362. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  363. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
  364. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  365. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  366. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  367. mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
  368. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  369. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  370. mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
  371. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
  372. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  373. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  374. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  375. mindspore/ops/_register_for_op.py +1 -0
  376. mindspore/ops/_utils/__init__.py +1 -2
  377. mindspore/ops/_utils/utils.py +19 -40
  378. mindspore/ops/_vmap/vmap_array_ops.py +116 -38
  379. mindspore/ops/_vmap/vmap_base.py +16 -9
  380. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  381. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  382. mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
  383. mindspore/ops/_vmap/vmap_image_ops.py +12 -5
  384. mindspore/ops/_vmap/vmap_math_ops.py +46 -5
  385. mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
  386. mindspore/ops/_vmap/vmap_random_ops.py +1 -1
  387. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  388. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  389. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  390. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  391. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  392. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  393. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  394. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  395. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
  396. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  397. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  398. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  399. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  400. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  401. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  402. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  403. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  404. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  405. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  406. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
  407. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
  408. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  409. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  410. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  411. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  412. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  413. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  414. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
  415. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  416. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
  417. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
  418. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  419. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  420. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  421. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  422. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  423. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  424. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  425. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  426. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  427. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  428. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  429. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  430. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  431. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  432. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  433. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  434. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  435. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  436. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  437. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
  438. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  439. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  440. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  441. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  442. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  443. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  444. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  445. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
  446. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
  447. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  448. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
  449. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  450. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  451. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  452. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  453. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  454. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  455. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  456. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  457. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  458. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
  459. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
  460. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  461. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  462. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  463. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  464. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
  465. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  466. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  467. mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
  468. mindspore/ops/composite/__init__.py +7 -8
  469. mindspore/ops/composite/base.py +101 -47
  470. mindspore/ops/composite/math_ops.py +188 -158
  471. mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
  472. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
  473. mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
  474. mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
  475. mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
  476. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  477. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  478. mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
  479. mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
  480. mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
  481. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
  482. mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
  483. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  484. mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
  485. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
  486. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
  487. mindspore/ops/function/__init__.py +152 -8
  488. mindspore/ops/function/array_func.py +2555 -674
  489. mindspore/ops/function/clip_func.py +209 -13
  490. mindspore/ops/function/debug_func.py +2 -2
  491. mindspore/ops/function/grad/__init__.py +2 -1
  492. mindspore/ops/function/grad/grad_func.py +147 -62
  493. mindspore/ops/function/image_func.py +54 -38
  494. mindspore/ops/function/linalg_func.py +167 -16
  495. mindspore/ops/function/math_func.py +4849 -1492
  496. mindspore/ops/function/nn_func.py +2573 -988
  497. mindspore/ops/function/other_func.py +115 -0
  498. mindspore/ops/function/parameter_func.py +3 -3
  499. mindspore/ops/function/random_func.py +790 -73
  500. mindspore/ops/function/sparse_func.py +98 -78
  501. mindspore/ops/function/sparse_unary_func.py +54 -53
  502. mindspore/ops/function/spectral_func.py +27 -24
  503. mindspore/ops/function/vmap_func.py +22 -2
  504. mindspore/ops/functional.py +97 -37
  505. mindspore/ops/op_info_register.py +70 -28
  506. mindspore/ops/operations/__init__.py +47 -14
  507. mindspore/ops/operations/_csr_ops.py +7 -7
  508. mindspore/ops/operations/_embedding_cache_ops.py +5 -5
  509. mindspore/ops/operations/_grad_ops.py +276 -187
  510. mindspore/ops/operations/_inner_ops.py +319 -113
  511. mindspore/ops/operations/_ms_kernel.py +10 -8
  512. mindspore/ops/operations/_ocr_ops.py +9 -9
  513. mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
  514. mindspore/ops/operations/_quant_ops.py +137 -102
  515. mindspore/ops/operations/_rl_inner_ops.py +121 -60
  516. mindspore/ops/operations/_scalar_ops.py +466 -0
  517. mindspore/ops/operations/_sequence_ops.py +1004 -2
  518. mindspore/ops/operations/_tensor_array.py +10 -11
  519. mindspore/ops/operations/_thor_ops.py +1 -1
  520. mindspore/ops/operations/array_ops.py +801 -466
  521. mindspore/ops/operations/comm_ops.py +51 -49
  522. mindspore/ops/operations/control_ops.py +2 -2
  523. mindspore/ops/operations/custom_ops.py +123 -44
  524. mindspore/ops/operations/debug_ops.py +24 -24
  525. mindspore/ops/operations/image_ops.py +240 -153
  526. mindspore/ops/operations/inner_ops.py +34 -50
  527. mindspore/ops/operations/linalg_ops.py +31 -9
  528. mindspore/ops/operations/math_ops.py +988 -757
  529. mindspore/ops/operations/nn_ops.py +965 -819
  530. mindspore/ops/operations/other_ops.py +51 -40
  531. mindspore/ops/operations/random_ops.py +204 -122
  532. mindspore/ops/operations/rl_ops.py +8 -9
  533. mindspore/ops/operations/sparse_ops.py +254 -93
  534. mindspore/ops/operations/spectral_ops.py +35 -3
  535. mindspore/ops/primitive.py +111 -9
  536. mindspore/parallel/_auto_parallel_context.py +189 -83
  537. mindspore/parallel/_offload_context.py +185 -0
  538. mindspore/parallel/_parallel_serialization.py +99 -7
  539. mindspore/parallel/_ps_context.py +9 -5
  540. mindspore/parallel/_recovery_context.py +1 -1
  541. mindspore/parallel/_tensor.py +7 -1
  542. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  543. mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
  544. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  545. mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
  546. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  547. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
  548. mindspore/parallel/_utils.py +1 -2
  549. mindspore/parallel/algo_parameter_config.py +1 -1
  550. mindspore/parallel/checkpoint_transform.py +37 -34
  551. mindspore/parallel/shard.py +17 -18
  552. mindspore/profiler/common/validator/validate_path.py +2 -2
  553. mindspore/profiler/envprofiling.py +69 -47
  554. mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
  555. mindspore/profiler/parser/base_timeline_generator.py +49 -56
  556. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
  557. mindspore/profiler/parser/hwts_log_parser.py +1 -1
  558. mindspore/profiler/parser/integrator.py +15 -14
  559. mindspore/profiler/parser/minddata_analyzer.py +2 -2
  560. mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
  561. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  562. mindspore/profiler/parser/optime_parser.py +17 -18
  563. mindspore/profiler/parser/profiler_info.py +2 -1
  564. mindspore/profiler/profiling.py +218 -186
  565. mindspore/rewrite/__init__.py +3 -1
  566. mindspore/rewrite/api/node.py +1 -114
  567. mindspore/rewrite/api/node_type.py +3 -0
  568. mindspore/rewrite/api/pattern_engine.py +31 -1
  569. mindspore/rewrite/api/scoped_value.py +4 -4
  570. mindspore/rewrite/api/symbol_tree.py +3 -78
  571. mindspore/rewrite/api/tree_node_helper.py +1 -1
  572. mindspore/rewrite/ast_creator_register.py +1 -0
  573. mindspore/rewrite/ast_helpers/__init__.py +2 -2
  574. mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
  575. mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
  576. mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
  577. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
  578. mindspore/rewrite/namespace.py +0 -2
  579. mindspore/rewrite/node.py +157 -11
  580. mindspore/rewrite/parsers/assign_parser.py +231 -53
  581. mindspore/rewrite/parsers/class_def_parser.py +187 -109
  582. mindspore/rewrite/parsers/for_parser.py +24 -14
  583. mindspore/rewrite/parsers/function_def_parser.py +21 -4
  584. mindspore/rewrite/parsers/if_parser.py +6 -2
  585. mindspore/rewrite/sparsify/__init__.py +0 -0
  586. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  587. mindspore/rewrite/sparsify/sparsify.py +109 -0
  588. mindspore/rewrite/sparsify/utils.py +173 -0
  589. mindspore/rewrite/symbol_tree.py +256 -133
  590. mindspore/rewrite/symbol_tree_builder.py +38 -1
  591. mindspore/run_check/_check_version.py +69 -63
  592. mindspore/run_check/run_check.py +2 -1
  593. mindspore/scipy/linalg.py +10 -114
  594. mindspore/scipy/ops.py +2 -2
  595. mindspore/scipy/ops_wrapper.py +1 -1
  596. mindspore/scipy/optimize/_bfgs.py +1 -1
  597. mindspore/scipy/optimize/_lagrange.py +200 -0
  598. mindspore/scipy/optimize/line_search.py +3 -2
  599. mindspore/scipy/optimize/minimize.py +41 -2
  600. mindspore/scipy/sparse/__init__.py +2 -2
  601. mindspore/scipy/sparse/linalg.py +4 -464
  602. mindspore/scipy/utils.py +1 -1
  603. mindspore/scipy/utils_const.py +7 -1
  604. mindspore/train/__init__.py +1 -1
  605. mindspore/train/_utils.py +28 -5
  606. mindspore/train/amp.py +273 -102
  607. mindspore/train/callback/_backup_and_restore.py +5 -5
  608. mindspore/train/callback/_callback.py +2 -2
  609. mindspore/train/callback/_checkpoint.py +3 -3
  610. mindspore/train/callback/_early_stop.py +3 -3
  611. mindspore/train/callback/_lambda_callback.py +2 -2
  612. mindspore/train/callback/_landscape.py +29 -31
  613. mindspore/train/callback/_loss_monitor.py +3 -3
  614. mindspore/train/callback/_on_request_exit.py +3 -3
  615. mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
  616. mindspore/train/callback/_summary_collector.py +23 -16
  617. mindspore/train/callback/_time_monitor.py +3 -3
  618. mindspore/train/checkpoint_pb2.py +68 -8
  619. mindspore/train/data_sink.py +15 -3
  620. mindspore/train/dataset_helper.py +10 -15
  621. mindspore/train/loss_scale_manager.py +8 -11
  622. mindspore/train/metrics/__init__.py +1 -1
  623. mindspore/train/metrics/bleu_score.py +1 -1
  624. mindspore/train/metrics/confusion_matrix.py +1 -1
  625. mindspore/train/metrics/cosine_similarity.py +1 -1
  626. mindspore/train/metrics/dice.py +2 -2
  627. mindspore/train/metrics/fbeta.py +1 -1
  628. mindspore/train/metrics/hausdorff_distance.py +4 -3
  629. mindspore/train/metrics/mean_surface_distance.py +2 -2
  630. mindspore/train/metrics/occlusion_sensitivity.py +1 -1
  631. mindspore/train/metrics/perplexity.py +1 -1
  632. mindspore/train/metrics/precision.py +1 -1
  633. mindspore/train/metrics/recall.py +1 -1
  634. mindspore/train/metrics/roc.py +2 -2
  635. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  636. mindspore/train/mind_ir_pb2.py +116 -37
  637. mindspore/train/model.py +45 -28
  638. mindspore/train/serialization.py +295 -188
  639. mindspore/train/summary/_summary_adapter.py +1 -1
  640. mindspore/train/summary/summary_record.py +43 -13
  641. mindspore/train/train_thor/convert_utils.py +2 -2
  642. mindspore/train/train_thor/dataset_helper.py +3 -3
  643. mindspore/version.py +1 -1
  644. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
  645. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +648 -574
  646. mindspore/compression/__init__.py +0 -19
  647. mindspore/compression/common/constant.py +0 -124
  648. mindspore/compression/export/__init__.py +0 -19
  649. mindspore/compression/export/quant_export.py +0 -515
  650. mindspore/compression/quant/__init__.py +0 -28
  651. mindspore/compression/quant/qat.py +0 -634
  652. mindspore/compression/quant/quant_utils.py +0 -462
  653. mindspore/compression/quant/quantizer.py +0 -68
  654. mindspore/nn/layer/quant.py +0 -1868
  655. mindspore/nn/layer/rnn_utils.py +0 -90
  656. mindspore/nn/probability/dpn/__init__.py +0 -22
  657. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  658. mindspore/nn/probability/dpn/vae/cvae.py +0 -140
  659. mindspore/nn/probability/dpn/vae/vae.py +0 -124
  660. mindspore/nn/probability/infer/__init__.py +0 -22
  661. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  662. mindspore/nn/probability/infer/variational/svi.py +0 -84
  663. mindspore/nn/probability/toolbox/__init__.py +0 -22
  664. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  665. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
  666. mindspore/nn/probability/transforms/__init__.py +0 -22
  667. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  668. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  669. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  670. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  671. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  672. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  673. mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
  674. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  675. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
  676. mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
  677. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
  678. mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
  679. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
  680. mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
  681. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
  682. mindspore/ops/composite/array_ops.py +0 -241
  683. mindspore/ops/composite/clip_ops.py +0 -134
  684. mindspore/ops/composite/random_ops.py +0 -426
  685. mindspore/ops/composite/vmap_ops.py +0 -38
  686. mindspore/parallel/nn/__init__.py +0 -42
  687. mindspore/parallel/nn/loss.py +0 -22
  688. mindspore/parallel/nn/moe.py +0 -21
  689. mindspore/parallel/nn/op_parallel_config.py +0 -22
  690. mindspore/parallel/nn/transformer.py +0 -31
  691. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  692. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  693. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -16,25 +16,26 @@
16
16
  """array_ops vmap impl."""
17
17
  from __future__ import absolute_import
18
18
 
19
- import numpy as np
20
19
  import mindspore
21
20
  import mindspore.numpy as mnp
22
21
  from mindspore import ops
23
22
  from mindspore.common import Tensor
23
+ from mindspore._c_expression import Tensor as Tensor_
24
24
  from mindspore.ops import operations as P
25
25
  from mindspore.ops import functional as F
26
- from mindspore.ops import constexpr
26
+ from mindspore.ops.primitive import constexpr, _primexpr
27
27
  from mindspore.ops.operations._grad_ops import MaskedSelectGrad
28
28
  from mindspore.ops.operations import _grad_ops as G
29
29
  from mindspore.ops.operations.array_ops import Fills, UniqueConsecutive, Col2Im, NonZero, IndexFill, \
30
30
  TensorScatterElements
31
31
  from mindspore.ops.operations.random_ops import RandomPoisson
32
+ from mindspore.ops.operations._inner_ops import DynamicBroadcastTo
32
33
  from mindspore.ops.primitive import Primitive
33
34
  from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, \
34
35
  _raise_value_error, _vmap_clone_prim, _handle_broadcasting, get_unsupported_dynamic_vmap_rule, _broadcast_by_axis, \
35
36
  get_unop_vmap_rule, _get_reduce_out_dim, _get_reduce_batch_axis, \
36
37
  _bdim_at_any
37
- from mindspore.ops.composite import _VmapGeneralRule
38
+ from mindspore.ops.function import _VmapGeneralRule
38
39
 
39
40
 
40
41
  @vmap_rules_getters.register(P.NoRepeatNGram)
@@ -137,7 +138,7 @@ def get_arg_min_max_with_value_vmap_rule(prim, axis_size):
137
138
  return vmap_rule
138
139
 
139
140
 
140
- @constexpr
141
+ @_primexpr
141
142
  def _get_prefix(indices_shape, axis_size, indices_dtype):
142
143
  """
143
144
  Generate prefix by indices shape, whose -1 axis value is the index value of axis 0.
@@ -147,14 +148,16 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
147
148
  the generated prefix is a Tensor([[[0], [0]],
148
149
  [[1], [1]]])
149
150
  """
150
- if not indices_shape:
151
- raise ValueError("indices_shape is empty in _get_prefix.")
151
+ def _check(indices_shape):
152
+ if not indices_shape:
153
+ raise ValueError("indices_shape is empty in _get_prefix.")
152
154
 
155
+ _check(indices_shape)
153
156
  indices_len = len(indices_shape)
154
-
155
157
  if indices_len == 1:
156
- prefix = np.arange(axis_size)
157
- return Tensor(prefix, indices_dtype)
158
+ prefix = P.Range()(Tensor(0, indices_dtype), P.Fill()(
159
+ indices_dtype, (), axis_size), Tensor(1, indices_dtype))
160
+ return prefix
158
161
 
159
162
  indices_end = indices_len - 1
160
163
  prefix_shape = ()
@@ -169,8 +172,9 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
169
172
  else:
170
173
  expand_shape = expand_shape + (1,)
171
174
 
172
- prefix = np.broadcast_to(np.arange(axis_size).reshape(expand_shape), prefix_shape)
173
- return Tensor(prefix, indices_dtype)
175
+ prefix = P.BroadcastTo(prefix_shape)(P.Reshape()(P.Range()(Tensor(
176
+ 0, indices_dtype), Tensor(axis_size, indices_dtype), Tensor(1, indices_dtype)), expand_shape))
177
+ return prefix
174
178
 
175
179
 
176
180
  @vmap_rules_getters.register(P.Transpose)
@@ -179,7 +183,7 @@ def get_transpose_vmap_rule(prim, axis_size):
179
183
  if isinstance(prim, str):
180
184
  prim = Primitive(prim)
181
185
 
182
- @constexpr
186
+ @_primexpr
183
187
  def _get_transpose_batch_perm(dim, perm, x_rank):
184
188
  """Generate batch_perm based on the original perm of transpose operation and dim of the input."""
185
189
  if dim < 0:
@@ -223,7 +227,7 @@ def get_tile_vmap_rule(prim, axis_size):
223
227
  if isinstance(prim, str):
224
228
  prim = Primitive(prim)
225
229
 
226
- @constexpr
230
+ @_primexpr
227
231
  def _get_batch_multiples(input_shape, dim, multiples):
228
232
  input_ndim = len(input_shape)
229
233
  multiples_ndim = len(multiples)
@@ -352,8 +356,13 @@ def get_unstack_vmap_rule(prim, axis_size):
352
356
  def get_reshape_vmap_rule(prim, axis_size):
353
357
  """VmapRule for `Reshape` operation."""
354
358
 
355
- @constexpr
359
+
360
+ @_primexpr
356
361
  def get_batch_shape(x_shape, x_dim, target_shape, axis_size):
362
+ def _check(neg_index, target_shape):
363
+ if neg_index != -1:
364
+ raise ValueError(f'The shape can only has one -1 at most, but {target_shape}.')
365
+
357
366
  if x_dim == 0:
358
367
  return (axis_size,) + target_shape, 0, False
359
368
 
@@ -364,19 +373,21 @@ def get_reshape_vmap_rule(prim, axis_size):
364
373
  dim_prod = 1
365
374
  for i, shp_i in enumerate(target_shape):
366
375
  if shp_i == -1:
367
- if neg_index != -1:
368
- raise ValueError(f'The shape can only has one -1 at most, but {target_shape}.')
376
+ _check(neg_index, target_shape)
369
377
  neg_index = i
370
378
  else:
371
379
  dim_prod *= shp_i
372
- arr_prod = np.prod(x_shape)
380
+ arr_prod = 1
381
+ for i in x_shape:
382
+ arr_prod *= i
373
383
  target_shape_list = list(target_shape)
374
384
  if neg_index != -1:
375
385
  neg_index_size = int(arr_prod // (dim_prod * axis_size))
376
386
  target_shape_list[neg_index] = neg_index_size
377
387
 
378
- arr_prod_before_dim = np.prod(x_shape[:x_dim])
379
-
388
+ arr_prod_before_dim = 1
389
+ for i in x_shape[:x_dim]:
390
+ arr_prod_before_dim *= i
380
391
  dim_prod = 1
381
392
  for i, shp_i in enumerate(target_shape_list, start=1):
382
393
  dim_prod *= shp_i
@@ -421,7 +432,7 @@ def get_reverse_sequence_vmap_rule(prim, axis_size):
421
432
  batch_dim = prim.batch_dim_
422
433
  seq_dim = prim.seq_dim_
423
434
 
424
- @constexpr
435
+ @_primexpr
425
436
  def get_batch_seq_dim(dim, batch_dim_, seq_dim_):
426
437
  if dim is None:
427
438
  batch_dim_ += 1
@@ -437,7 +448,7 @@ def get_reverse_sequence_vmap_rule(prim, axis_size):
437
448
  seq_dim_ += 1
438
449
  return batch_dim_, seq_dim_
439
450
 
440
- @constexpr
451
+ @_primexpr
441
452
  def get_seq_dim(dim, batch_dim_, seq_dim_):
442
453
  if dim is None:
443
454
  return seq_dim_
@@ -557,20 +568,19 @@ def get_scatter_nd_vmap_rule(prim, axis_size):
557
568
  Reshape the output tensor to `[10, 6, 4, 5]`
558
569
  """
559
570
 
560
- @constexpr
571
+ @_primexpr
561
572
  def _refine_shape(shape, bdim_size):
562
573
  offset = shape[0]
563
574
  return (bdim_size * shape[0],) + tuple(shape[1:]), offset, (bdim_size,) + tuple(shape)
564
575
 
565
- @constexpr
576
+ @_primexpr
566
577
  def _gen_indices_offset(shape, offset):
567
578
  # original rank(indices.shape) is required >= 2, so indices with batch dim's rank >= 3.
568
- shape = [shape[0]] + [1] * (len(shape) - 2) + [shape[-1]]
569
- val = np.zeros(shape, np.int32) # the dtype will be changed when creating Tensor
570
- val = np.reshape(val, (shape[0], shape[-1]))
579
+ shape = (shape[0],) + (1,) * (len(shape) - 2) + (shape[-1],)
580
+ val = P.Zeros()((shape[0], shape[-1]), mindspore.int32)
571
581
  for i in range(shape[0]):
572
582
  val[i, 0] = i * offset
573
- return np.reshape(val, shape)
583
+ return P.Reshape()(val, shape)
574
584
 
575
585
  if isinstance(prim, str):
576
586
  prim = Primitive(prim)
@@ -591,7 +601,7 @@ def get_scatter_nd_vmap_rule(prim, axis_size):
591
601
  indices_shape = F.shape(indices)
592
602
  indices_dtype = F.dtype(indices)
593
603
  offset_val = _gen_indices_offset(indices_shape, offset)
594
- indices_offset = Tensor(offset_val, indices_dtype)
604
+ indices_offset = P.Cast()(offset_val, indices_dtype)
595
605
  new_indices = P.Add()(indices, indices_offset)
596
606
  out = prim(new_indices, updates, new_shape)
597
607
  real_out = P.Reshape()(out, out_shape)
@@ -839,6 +849,62 @@ def get_fill_vmap_rule(prim, axis_size):
839
849
  return vmap_rule
840
850
 
841
851
 
852
+ @constexpr
853
+ def to_tensor_with_type(x, type):
854
+ """x to Tensor with type"""
855
+ return Tensor(x, type)
856
+
857
+
858
+ @vmap_rules_getters.register(P.FillV2)
859
+ def get_fill_v2_vmap_rule(prim, axis_size):
860
+ """VmapRule for `FillV2` operation."""
861
+ if isinstance(prim, str):
862
+ prim = Primitive(prim)
863
+
864
+ def vmap_rule(shape_bdim, value_bdim):
865
+ is_all_none, result = vmap_general_preprocess(prim, shape_bdim, value_bdim)
866
+ if is_all_none:
867
+ return result
868
+
869
+ value_shape, shape_dim = shape_bdim
870
+ if shape_dim is not None:
871
+ _raise_value_error(
872
+ "The source axis of `shape` in `P.FillV2` must be None, but got {}."
873
+ .format(shape_dim))
874
+
875
+ value, vdim = value_bdim
876
+ value_rank = F.rank(value)
877
+ if value_rank != 1 or vdim != 0:
878
+ _raise_value_error(
879
+ "The `value` in `P.FillV2` must be constant value, thus the value only "
880
+ "can be rank: 1 with source axis: 0 in vmap scope, but got value rank: "
881
+ "{} with source axis: {}.".format(value_rank, vdim))
882
+ value = F.reshape(value, (axis_size,) + (1,) * len(value_shape))
883
+
884
+ out = None
885
+ if isinstance(value_shape, (Tensor_, Tensor)):
886
+ value_shape_rank = F.rank(value_shape)
887
+ if value_shape_rank != 1:
888
+ _raise_value_error(
889
+ "The `shape` in `P.FillV2` must be 1-D tensor, thus the shape only "
890
+ "can be rank: 1, but got shape rank: "
891
+ "{}.".format(value_shape_rank))
892
+ axis_size_tensor = to_tensor_with_type((axis_size,),
893
+ F.dtype(value_shape))
894
+ broad_cast_shape = F.concat((axis_size_tensor, value_shape))
895
+ out = DynamicBroadcastTo()(value, broad_cast_shape)
896
+ elif isinstance(value_shape, tuple):
897
+ out = P.BroadcastTo((axis_size,) + value_shape)(value)
898
+ else:
899
+ _raise_value_error(
900
+ f"For `P.FillV2`, the input `shape` should be Tuple or Tensor, but got `shape`: {value_shape}."
901
+ )
902
+
903
+ return out, 0
904
+
905
+ return vmap_rule
906
+
907
+
842
908
  @vmap_rules_getters.register(Fills)
843
909
  def get_fills_vmap_rule(prim, axis_size):
844
910
  """VmapRule for `Fills` operation."""
@@ -1414,6 +1480,7 @@ def get_meshgrid_vmap_rule(prim, axis_size):
1414
1480
  "The input number of P.Meshgrid must be greater than 1.")
1415
1481
 
1416
1482
  output_shape = []
1483
+ ones_shape = []
1417
1484
  for each_arg in args:
1418
1485
  x, bdim = each_arg
1419
1486
  if bdim is None:
@@ -1424,19 +1491,30 @@ def get_meshgrid_vmap_rule(prim, axis_size):
1424
1491
  _raise_value_error(
1425
1492
  "Each input of Meshgrid must be 1D, but got {}.".format(F.rank(x) - 1))
1426
1493
  output_shape.append(F.shape(x)[-1])
1494
+ ones_shape.append(1)
1427
1495
  output_shape.insert(0, axis_size)
1496
+ ones_shape.insert(0, axis_size)
1428
1497
 
1429
1498
  if indexing == "xy":
1430
1499
  output_shape[1], output_shape[2] = output_shape[2], output_shape[1]
1431
-
1432
1500
  shape = tuple(output_shape)
1501
+
1502
+ input_0, _ = args[0]
1503
+ dtype = F.dtype(input_0)
1504
+ ones_tensor = F.fill(dtype, shape, 1)
1505
+
1506
+ index = 0
1433
1507
  vals_out_tuple = ()
1434
1508
  for each_arg in args:
1435
1509
  x, bdim = each_arg
1436
1510
  x = _bdim_at_front(x, bdim, axis_size)
1437
- x = _handle_broadcasting(x, F.shape(x), output_shape)
1438
- output = P.BroadcastTo(shape)(x)
1511
+ shape_index = (1 - index) if (index <= 1 and indexing == "xy") else index
1512
+ ones_shape[shape_index + 1] = output_shape[shape_index + 1]
1513
+ x = P.Reshape()(x, tuple(ones_shape))
1514
+ output = P.Mul()(x, ones_tensor)
1439
1515
  vals_out_tuple = vals_out_tuple + ((output, 0),)
1516
+ ones_shape[shape_index + 1] = 1
1517
+ index = index + 1
1440
1518
 
1441
1519
  return vals_out_tuple
1442
1520
 
@@ -1480,7 +1558,7 @@ def get_gather_vmap_rule(prim, axis_size):
1480
1558
  else:
1481
1559
  prim_name = prim.name
1482
1560
 
1483
- @constexpr
1561
+ @_primexpr
1484
1562
  def process_axis(axis, x_shape_size, has_xdim: bool, has_idim: bool):
1485
1563
  if has_xdim and has_idim:
1486
1564
  if axis < 0:
@@ -1494,7 +1572,7 @@ def get_gather_vmap_rule(prim, axis_size):
1494
1572
 
1495
1573
  return axis
1496
1574
 
1497
- @constexpr
1575
+ @_primexpr
1498
1576
  def get_x_dst_shape(x_shape, axis):
1499
1577
  target_axis_size = x_shape[axis + 1]
1500
1578
  x_dst_shape = x_shape[0:axis] + (axis_size * target_axis_size,) + x_shape[axis + 2:]
@@ -1694,7 +1772,7 @@ def get_data_format_dim_map_vmap_rule(prim, axis_size):
1694
1772
  def get_expand_dims_vmap_rule(prim, axis_size):
1695
1773
  """VmapRule for `ExpandDims`."""
1696
1774
 
1697
- @constexpr
1775
+ @_primexpr
1698
1776
  def process_axis(axis, rank, x_dim):
1699
1777
  if axis < 0:
1700
1778
  axis += rank
@@ -1788,7 +1866,7 @@ def get_squeeze_vmap_rule(prim, axis_size):
1788
1866
  else:
1789
1867
  prim_axis = None
1790
1868
 
1791
- @constexpr
1869
+ @_primexpr
1792
1870
  def move_axis(axes):
1793
1871
  new_axis = ()
1794
1872
  for axis in axes:
@@ -1798,7 +1876,7 @@ def get_squeeze_vmap_rule(prim, axis_size):
1798
1876
  new_axis = new_axis + (axis + 1,)
1799
1877
  return new_axis
1800
1878
 
1801
- @constexpr
1879
+ @_primexpr
1802
1880
  def generate_all_axis_except_first(x_rank):
1803
1881
  new_axis = ()
1804
1882
  for i in range(1, x_rank, 1):
@@ -1842,7 +1920,7 @@ def get_stridedslice_vmap_rule(prim, axis_size):
1842
1920
  batch_stridedslice = P.StridedSlice(new_begin_mask, new_end_mask, new_ellipsis_mask, new_new_axis_mask, \
1843
1921
  new_shrink_axis_mask)
1844
1922
 
1845
- @constexpr
1923
+ @_primexpr
1846
1924
  def get_new_begin_end_strided(begin, end, strided):
1847
1925
  new_begin = (0,) + begin
1848
1926
  new_end = (0,) + end
@@ -1883,7 +1961,7 @@ def get_stridedslice_grad_vmap_rule(prim, axis_size):
1883
1961
  batch_stridedslice_grad = G.StridedSliceGrad(new_begin_mask, new_end_mask, new_ellipsis_mask, new_new_axis_mask, \
1884
1962
  new_shrink_axis_mask)
1885
1963
 
1886
- @constexpr
1964
+ @_primexpr
1887
1965
  def get_new_xshape_begin_end_strided(xshape, begin, end, strided):
1888
1966
  new_xshape = (axis_size,) + xshape
1889
1967
  new_begin = (0,) + begin
@@ -21,11 +21,12 @@ from mindspore.common import Tensor
21
21
  from mindspore.ops import operations as P
22
22
  from mindspore.ops import functional as F
23
23
  from mindspore.ops import constexpr
24
+ from mindspore.ops.primitive import _primexpr
24
25
  from mindspore.ops.operations import math_ops
25
26
  from mindspore.ops.operations import _grad_ops as G
26
27
  from mindspore.ops.operations import nn_ops as nps
27
- from mindspore.ops.composite import _VmapGeneralPreprocess
28
- from mindspore.ops.primitive import Primitive
28
+ from mindspore.ops.function import _VmapGeneralPreprocess
29
+ from mindspore.ops.primitive import Primitive, _PrimitiveC
29
30
  from mindspore.ops.operations.random_ops import UniformCandidateSampler, RandomShuffle
30
31
  from mindspore.ops._grad.grad_base import BpropRegistry as VmapRuleRegistry
31
32
 
@@ -41,7 +42,7 @@ def get_vmap_rule(prim, axis_size):
41
42
  return None
42
43
 
43
44
 
44
- @constexpr
45
+ @_primexpr
45
46
  def _get_broadcast_shape_with_front_axis(x_shape, y_shape):
46
47
  """ Explicitly matched with the broadcast shape, that is, 1 is added to the broadcast position. """
47
48
  x_len = len(x_shape)
@@ -86,7 +87,7 @@ def _handle_broadcasting(x, x_shape, y_shape):
86
87
  return F.reshape(x, broadcast_shape)
87
88
 
88
89
 
89
- @constexpr
90
+ @_primexpr
90
91
  def _get_broadcasting_with_front_axis_additional_axis(x_shape, y_shape):
91
92
  """ Get the axes that are inserted after broadcasting.
92
93
  Args:
@@ -129,15 +130,19 @@ def _raise_value_error(info, param=None):
129
130
  raise ValueError(info + f"{param}")
130
131
 
131
132
 
132
- @constexpr
133
+ @_primexpr
133
134
  def _get_broadcast_shape(x_shape, dst, axis_size):
134
135
  """Get the target shape for broadcast array."""
136
+ def _check(dst, broadcast_ndim):
137
+ if dst < -broadcast_ndim or dst >= broadcast_ndim:
138
+ _raise_value_error("Destination axis {} is out of bounds for array of dimension"
139
+ " [{}, {}).".format(dst, -broadcast_ndim, broadcast_ndim))
140
+
135
141
  x_ndim = len(x_shape)
136
142
  broadcast_ndim = x_ndim + 1
137
143
 
138
- if dst < -broadcast_ndim or dst >= broadcast_ndim:
139
- _raise_value_error("Destination axis {} is out of bounds for array of dimension"
140
- " [{}, {}).".format(dst, -broadcast_ndim, broadcast_ndim))
144
+ _check(dst, broadcast_ndim)
145
+
141
146
  if dst < 0:
142
147
  dst = broadcast_ndim + dst
143
148
 
@@ -420,6 +425,8 @@ def _vmap_clone_prim(prim):
420
425
  """
421
426
  Cloning a new primitive object same as `prim`.
422
427
  """
428
+ if isinstance(prim, _PrimitiveC):
429
+ return _PrimitiveC(prim.name, prim.attrs)
423
430
  new_ops = _ops_vmap_clone_prim_dict.get(prim.name, None)
424
431
  if new_ops is None:
425
432
  raise ValueError("Failed to get the primitive object of {} from `_ops_vmap_clone_prim_dict`. Please register "
@@ -437,7 +444,7 @@ def _vmap_clone_prim(prim):
437
444
  return cloned
438
445
 
439
446
 
440
- @constexpr
447
+ @_primexpr
441
448
  def _get_reduce_batch_axis(axis, x_dim, x_ndim):
442
449
  """get batch_axis for reduce* operation."""
443
450
  # For axis, it's value in Union[int, list, tuple]
@@ -16,9 +16,9 @@
16
16
  """convolution vmap impl"""
17
17
  from __future__ import absolute_import
18
18
 
19
- import numpy as np
20
19
  import mindspore.numpy as mnp
21
20
  from mindspore.ops import constexpr
21
+ from mindspore.ops.primitive import _primexpr
22
22
  from mindspore.ops import operations as P
23
23
  from mindspore.ops import functional as F
24
24
  from mindspore.ops.operations import nn_ops as nps
@@ -142,7 +142,7 @@ def get_conv3d_backprop_filter_vmap_rule(prim, axis_size):
142
142
  return vmap_rule
143
143
 
144
144
 
145
- @constexpr
145
+ @_primexpr
146
146
  def _get_reshape_src_dim(data_dim, cmp_dim):
147
147
  """Get source dim for reshape"""
148
148
  if data_dim > cmp_dim:
@@ -154,7 +154,7 @@ def _get_reshape_src_dim(data_dim, cmp_dim):
154
154
  return expand_dim, merge_dim
155
155
 
156
156
 
157
- @constexpr
157
+ @_primexpr
158
158
  def _get_merge_shape(src_dim, dst_dim, shape):
159
159
  """Get new shape for merging the src_dim and dst_dim. The dst_dim is the value after removing src_dim."""
160
160
  new_shape = [shape[i] for i in range(len(shape)) if i != src_dim]
@@ -171,13 +171,10 @@ def _reshape_merge_dims(src_dim, dst_dim, target):
171
171
  return output, new_shape
172
172
 
173
173
 
174
- @constexpr
174
+ @_primexpr
175
175
  def _get_expand_shape(src_dim, dst_size, shape, prim_name):
176
176
  """Get new shape for splitting src_dim into dst_size parts."""
177
- dst_size2, remainder = np.divmod(shape[src_dim], dst_size)
178
- if remainder != 0:
179
- _raise_value_error("The remainder of {} / {} should be 0, "
180
- "but got {} in {}.".format(shape[src_dim], dst_size, remainder, prim_name))
177
+ dst_size2 = shape[src_dim] // dst_size
181
178
  new_shape = list(shape)
182
179
  new_shape[src_dim:(src_dim + 1)] = [dst_size, dst_size2]
183
180
  return tuple(new_shape)
@@ -190,7 +187,7 @@ def _reshape_expand_dims(src_dim, dst_size, target, prim_name):
190
187
  return F.reshape(target, new_shape)
191
188
 
192
189
 
193
- @constexpr
190
+ @_primexpr
194
191
  def _get_new_size_by_index(input_size, batch_size, index):
195
192
  """Get the new size of input_size by multiplying input_size[index] by batch_size."""
196
193
  new_size = ()
@@ -201,7 +198,7 @@ def _get_new_size_by_index(input_size, batch_size, index):
201
198
  return tuple(new_size)
202
199
 
203
200
 
204
- @constexpr
201
+ @_primexpr
205
202
  def _update_group_attr(prim, groups, batch_size):
206
203
  """Set new value for 'group' attribute of the convolution primitive."""
207
204
  group = groups * batch_size
@@ -17,9 +17,9 @@
17
17
  from __future__ import absolute_import
18
18
 
19
19
  from mindspore.ops import functional as F
20
- from mindspore.ops import constexpr
20
+ from mindspore.ops.primitive import _primexpr
21
21
  from mindspore.ops.operations import _grad_ops as G
22
- from mindspore.ops.composite import _VmapGeneralRule
22
+ from mindspore.ops.function import _VmapGeneralRule
23
23
  from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, \
24
24
  _handle_broadcasting, get_unary_grad_vmap_rule, _get_broadcasting_with_front_axis_additional_axis
25
25
 
@@ -36,7 +36,7 @@ def get_broadcast_binary_op_grad_vmap_rule(prim, axis_size):
36
36
  if isinstance(prim, str):
37
37
  prim = broadcast_binary_op_grad_map.get(prim)()
38
38
 
39
- @constexpr
39
+ @_primexpr
40
40
  def get_longest_shape(x_shape, y_shape, g_shape):
41
41
  x_rank = len(x_shape)
42
42
  y_rank = len(y_shape)
@@ -148,7 +148,7 @@ def get_median_grad_vmap_rule(prim, axis_size):
148
148
  axis = prim.axis
149
149
  keep_dims = prim.keep_dims
150
150
 
151
- @constexpr
151
+ @_primexpr
152
152
  def trans_grad_axis(axis, rank, dim, keep_dims):
153
153
  if axis < 0:
154
154
  axis += rank - 1
@@ -22,8 +22,9 @@ import mindspore.numpy as mnp
22
22
  from mindspore.ops.operations import _grad_ops as G
23
23
  from mindspore.ops import functional as F
24
24
  from mindspore.ops import constexpr
25
+ from mindspore.ops.primitive import _primexpr
25
26
  from mindspore.ops.primitive import Primitive
26
- from mindspore.ops.composite import _VmapGeneralRule
27
+ from mindspore.ops.function import _VmapGeneralRule
27
28
  from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error, \
28
29
  _bdim_at_front, _vmap_clone_prim, _vmap_update_prim_attr, _bdim_at_any, _handle_broadcasting
29
30
 
@@ -38,7 +39,7 @@ def get_nll_loss_grad_vmap_rule(prim, axis_size):
38
39
  2. And weight only support shape as (C,), while total_weight should be a scalar.
39
40
  """
40
41
 
41
- @constexpr
42
+ @_primexpr
42
43
  def _get_reshape_shape(shape, keep_dim=0):
43
44
  new_batch_size = reduce(
44
45
  lambda x, y: x * y, shape if keep_dim == 0 else shape[:-keep_dim])
@@ -397,8 +398,9 @@ def get_batchnorm_grad_vmap_rule(prim, axis_size):
397
398
 
398
399
  @vmap_rules_getters.register(G.MaxPoolGradGrad)
399
400
  @vmap_rules_getters.register(G.MaxPoolGradGradWithArgmax)
401
+ @vmap_rules_getters.register(G.MaxPoolGradWithArgmaxV2)
400
402
  def get_maxpool_grad_grad_vmap_rule(prim, axis_size):
401
- """VmapRule for `MaxPoolGradGrad` and `MaxPoolGradGradWithArgmax`."""
403
+ """VmapRule for `MaxPoolGradGrad`, `MaxPoolGradGradWithArgmax` and `MaxPoolGradWithArgmaxV2`."""
402
404
  chw_reverse_index = -3
403
405
 
404
406
  def vmap_rule(in0_bdim, in1_bdim, in2_bdim):
@@ -557,7 +559,7 @@ def get_layernormgrad_vmap_rule(prim, axis_size):
557
559
  return prim_attr_axis
558
560
  return prim_attr_axis + 1
559
561
 
560
- @constexpr
562
+ @_primexpr
561
563
  def get_batch_params_reduce_axes(begin_params_axis, x_shape):
562
564
  if begin_params_axis < 0:
563
565
  x_rank = len(x_shape)
@@ -565,7 +567,7 @@ def get_layernormgrad_vmap_rule(prim, axis_size):
565
567
  batch_params_reduce_axes = tuple(range(1, begin_params_axis))
566
568
  return batch_params_reduce_axes
567
569
 
568
- @constexpr
570
+ @_primexpr
569
571
  def get_logical_shape(var_shape):
570
572
  return var_shape[1:]
571
573
 
@@ -16,10 +16,12 @@
16
16
  """image_ops vmap impl."""
17
17
  from __future__ import absolute_import
18
18
 
19
- import mindspore.numpy as mnp
19
+ import numpy as np
20
+ from mindspore import Tensor
20
21
  from mindspore.ops import functional as F
21
22
  from mindspore.ops.operations import _grad_ops as G
22
23
  from mindspore.ops.operations import image_ops as IMG
24
+ from mindspore.ops import constexpr
23
25
  from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, \
24
26
  _raise_value_error
25
27
 
@@ -90,6 +92,13 @@ def get_resize_grad_dynamic_rule(prim, axis_size):
90
92
  def get_crop_and_resize_vmap_rule(prim, axis_size):
91
93
  """VmapRule for `CropAndResize` operation."""
92
94
 
95
+ @constexpr
96
+ def get_box_indices_offsets(axis_size, batch_size, num_boxes):
97
+ offsets = np.arange(0, axis_size * batch_size, batch_size).astype(np.int32)
98
+ offsets = np.reshape(offsets, (axis_size, 1))
99
+ offsets = np.broadcast_to(offsets, (axis_size, num_boxes))
100
+ return Tensor(offsets)
101
+
93
102
  def vmap_rule(x_bdim, boxes_bdim, box_indices_bdim, crop_size_bdim):
94
103
  is_all_none, result = vmap_general_preprocess(x_bdim, boxes_bdim, box_indices_bdim, crop_size_bdim)
95
104
  if is_all_none:
@@ -115,10 +124,8 @@ def get_crop_and_resize_vmap_rule(prim, axis_size):
115
124
  x = _bdim_at_front(x, x_dim, axis_size)
116
125
  x_shape = F.shape(x)
117
126
  x = F.reshape(x, (-1,) + x_shape[2:])
118
- counts = mnp.arange(0, axis_size * x_shape[1], x_shape[1])
119
- counts = F.reshape(counts, (axis_size, 1))
120
- counts = F.broadcast_to(counts, (axis_size, num_boxes))
121
- box_indices = F.add(box_indices, counts)
127
+ offsets = get_box_indices_offsets(axis_size, x_shape[1], num_boxes)
128
+ box_indices = F.add(box_indices, offsets)
122
129
  box_indices = F.reshape(box_indices, (-1,))
123
130
  out = prim(x, boxes, box_indices, crop_size)
124
131