mindspore 2.0.0a0__cp38-cp38-win_amd64.whl → 2.0.0rc1__cp38-cp38-win_amd64.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (655)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -2
  3. mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +102 -0
  7. mindspore/_checkparam.py +1066 -1001
  8. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
  9. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
  10. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
  11. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
  12. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  13. mindspore/_extends/parse/__init__.py +5 -3
  14. mindspore/_extends/parse/namespace.py +16 -1
  15. mindspore/_extends/parse/parser.py +107 -22
  16. mindspore/_extends/parse/resources.py +0 -7
  17. mindspore/_extends/parse/standard_method.py +885 -413
  18. mindspore/amp.py +52 -57
  19. mindspore/boost/boost.py +2 -2
  20. mindspore/boost/boost_cell_wrapper.py +38 -20
  21. mindspore/boost/dim_reduce.py +3 -3
  22. mindspore/boost/group_loss_scale_manager.py +1 -1
  23. mindspore/common/__init__.py +4 -6
  24. mindspore/common/_decorator.py +2 -0
  25. mindspore/common/_register_for_adapter.py +55 -0
  26. mindspore/common/_stub_tensor.py +201 -0
  27. mindspore/common/_utils.py +41 -7
  28. mindspore/common/api.py +215 -141
  29. mindspore/common/dtype.py +8 -1
  30. mindspore/common/dump.py +2 -2
  31. mindspore/common/initializer.py +4 -2
  32. mindspore/common/jit_config.py +17 -13
  33. mindspore/common/mutable.py +33 -13
  34. mindspore/common/parameter.py +23 -21
  35. mindspore/common/seed.py +8 -24
  36. mindspore/common/sparse_tensor.py +62 -41
  37. mindspore/common/tensor.py +852 -1154
  38. mindspore/communication/__init__.py +2 -2
  39. mindspore/communication/_comm_helper.py +11 -4
  40. mindspore/communication/management.py +22 -21
  41. mindspore/config/op_info.config +501 -1008
  42. mindspore/context.py +201 -23
  43. mindspore/dataset/__init__.py +6 -6
  44. mindspore/dataset/audio/__init__.py +7 -7
  45. mindspore/dataset/audio/transforms.py +670 -30
  46. mindspore/dataset/audio/utils.py +47 -4
  47. mindspore/dataset/audio/validators.py +223 -1
  48. mindspore/dataset/callback/ds_callback.py +2 -2
  49. mindspore/dataset/core/config.py +210 -14
  50. mindspore/dataset/core/validator_helpers.py +2 -2
  51. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  52. mindspore/dataset/debug/debug_hook.py +65 -0
  53. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  54. mindspore/dataset/engine/__init__.py +7 -3
  55. mindspore/dataset/engine/cache_client.py +1 -1
  56. mindspore/dataset/engine/datasets.py +322 -66
  57. mindspore/dataset/engine/datasets_audio.py +80 -76
  58. mindspore/dataset/engine/datasets_standard_format.py +51 -38
  59. mindspore/dataset/engine/datasets_text.py +232 -118
  60. mindspore/dataset/engine/datasets_user_defined.py +41 -17
  61. mindspore/dataset/engine/datasets_vision.py +746 -225
  62. mindspore/dataset/engine/graphdata.py +75 -10
  63. mindspore/dataset/engine/iterators.py +45 -5
  64. mindspore/dataset/engine/offload.py +48 -28
  65. mindspore/dataset/engine/validators.py +117 -8
  66. mindspore/dataset/text/__init__.py +6 -5
  67. mindspore/dataset/text/transforms.py +86 -3
  68. mindspore/dataset/text/utils.py +6 -4
  69. mindspore/dataset/text/validators.py +25 -0
  70. mindspore/dataset/transforms/__init__.py +3 -2
  71. mindspore/dataset/transforms/c_transforms.py +1 -1
  72. mindspore/dataset/transforms/transforms.py +2 -2
  73. mindspore/dataset/utils/__init__.py +2 -1
  74. mindspore/dataset/utils/line_reader.py +121 -0
  75. mindspore/dataset/vision/__init__.py +2 -3
  76. mindspore/dataset/vision/c_transforms.py +9 -9
  77. mindspore/dataset/vision/py_transforms.py +5 -5
  78. mindspore/dataset/vision/py_transforms_util.py +2 -0
  79. mindspore/dataset/vision/transforms.py +160 -161
  80. mindspore/dataset/vision/utils.py +3 -3
  81. mindspore/experimental/map_parameter.py +38 -26
  82. mindspore/include/OWNERS +0 -1
  83. mindspore/include/api/callback/callback.h +9 -13
  84. mindspore/include/api/callback/ckpt_saver.h +2 -2
  85. mindspore/include/api/callback/loss_monitor.h +2 -2
  86. mindspore/include/api/callback/lr_scheduler.h +5 -5
  87. mindspore/include/api/callback/time_monitor.h +2 -2
  88. mindspore/include/api/callback/train_accuracy.h +4 -6
  89. mindspore/include/api/cfg.h +19 -6
  90. mindspore/include/api/context.h +44 -9
  91. mindspore/include/api/delegate.h +1 -1
  92. mindspore/include/api/metrics/accuracy.h +2 -2
  93. mindspore/include/api/metrics/metrics.h +4 -3
  94. mindspore/include/api/model.h +9 -4
  95. mindspore/include/api/model_parallel_runner.h +2 -2
  96. mindspore/include/api/net.h +12 -11
  97. mindspore/include/api/serialization.h +19 -3
  98. mindspore/include/api/types.h +3 -3
  99. mindspore/include/dataset/constants.h +7 -0
  100. mindspore/include/dataset/text.h +59 -0
  101. mindspore/jpeg62.dll +0 -0
  102. mindspore/log.py +1 -1
  103. mindspore/mindrecord/filereader.py +18 -0
  104. mindspore/mindrecord/filewriter.py +197 -34
  105. mindspore/mindrecord/shardreader.py +9 -0
  106. mindspore/mindrecord/shardwriter.py +1 -1
  107. mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
  108. mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
  109. mindspore/mindrecord/tools/csv_to_mr.py +3 -3
  110. mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
  111. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  112. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  113. mindspore/mindspore_backend.dll +0 -0
  114. mindspore/mindspore_common.dll +0 -0
  115. mindspore/mindspore_core.dll +0 -0
  116. mindspore/mindspore_glog.dll +0 -0
  117. mindspore/mindspore_shared_lib.dll +0 -0
  118. mindspore/nn/__init__.py +0 -4
  119. mindspore/nn/cell.py +204 -132
  120. mindspore/nn/dynamic_lr.py +1 -1
  121. mindspore/nn/grad/cell_grad.py +7 -6
  122. mindspore/nn/layer/__init__.py +5 -4
  123. mindspore/nn/layer/activation.py +40 -89
  124. mindspore/nn/layer/basic.py +255 -624
  125. mindspore/nn/layer/channel_shuffle.py +7 -6
  126. mindspore/nn/layer/combined.py +1 -1
  127. mindspore/nn/layer/container.py +41 -4
  128. mindspore/nn/layer/conv.py +64 -28
  129. mindspore/nn/layer/dense.py +9 -8
  130. mindspore/nn/layer/embedding.py +27 -25
  131. mindspore/nn/layer/image.py +53 -46
  132. mindspore/nn/layer/math.py +97 -105
  133. mindspore/nn/layer/normalization.py +117 -86
  134. mindspore/nn/layer/padding.py +185 -95
  135. mindspore/nn/layer/pooling.py +817 -414
  136. mindspore/nn/layer/rnn_cells.py +10 -15
  137. mindspore/nn/layer/rnns.py +37 -38
  138. mindspore/nn/layer/thor_layer.py +11 -12
  139. mindspore/nn/layer/timedistributed.py +5 -5
  140. mindspore/nn/layer/transformer.py +701 -0
  141. mindspore/nn/learning_rate_schedule.py +8 -8
  142. mindspore/nn/loss/__init__.py +5 -4
  143. mindspore/nn/loss/loss.py +334 -199
  144. mindspore/nn/optim/ada_grad.py +6 -6
  145. mindspore/nn/optim/adadelta.py +2 -3
  146. mindspore/nn/optim/adafactor.py +4 -5
  147. mindspore/nn/optim/adam.py +126 -62
  148. mindspore/nn/optim/adamax.py +3 -4
  149. mindspore/nn/optim/adasum.py +6 -6
  150. mindspore/nn/optim/asgd.py +2 -2
  151. mindspore/nn/optim/ftrl.py +67 -38
  152. mindspore/nn/optim/lamb.py +4 -5
  153. mindspore/nn/optim/lars.py +2 -2
  154. mindspore/nn/optim/lazyadam.py +43 -4
  155. mindspore/nn/optim/momentum.py +6 -5
  156. mindspore/nn/optim/optimizer.py +3 -1
  157. mindspore/nn/optim/proximal_ada_grad.py +2 -2
  158. mindspore/nn/optim/rmsprop.py +1 -1
  159. mindspore/nn/optim/rprop.py +8 -9
  160. mindspore/nn/optim/sgd.py +19 -13
  161. mindspore/nn/optim/thor.py +10 -15
  162. mindspore/nn/probability/__init__.py +0 -2
  163. mindspore/nn/probability/bijector/bijector.py +4 -4
  164. mindspore/nn/probability/bijector/invert.py +1 -1
  165. mindspore/nn/probability/bijector/softplus.py +2 -2
  166. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  167. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  168. mindspore/nn/probability/distribution/_utils/utils.py +9 -15
  169. mindspore/nn/probability/distribution/bernoulli.py +3 -3
  170. mindspore/nn/probability/distribution/beta.py +1 -1
  171. mindspore/nn/probability/distribution/categorical.py +5 -7
  172. mindspore/nn/probability/distribution/cauchy.py +3 -3
  173. mindspore/nn/probability/distribution/distribution.py +2 -2
  174. mindspore/nn/probability/distribution/exponential.py +2 -2
  175. mindspore/nn/probability/distribution/gamma.py +3 -3
  176. mindspore/nn/probability/distribution/geometric.py +1 -1
  177. mindspore/nn/probability/distribution/gumbel.py +3 -3
  178. mindspore/nn/probability/distribution/half_normal.py +15 -11
  179. mindspore/nn/probability/distribution/laplace.py +16 -13
  180. mindspore/nn/probability/distribution/logistic.py +2 -2
  181. mindspore/nn/probability/distribution/normal.py +1 -1
  182. mindspore/nn/probability/distribution/poisson.py +1 -1
  183. mindspore/nn/probability/distribution/student_t.py +20 -15
  184. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  185. mindspore/nn/probability/distribution/uniform.py +2 -2
  186. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  187. mindspore/nn/reinforcement/tensor_array.py +2 -2
  188. mindspore/nn/sparse/sparse.py +2 -2
  189. mindspore/nn/wrap/cell_wrapper.py +27 -10
  190. mindspore/nn/wrap/grad_reducer.py +2 -2
  191. mindspore/nn/wrap/loss_scale.py +40 -24
  192. mindspore/numpy/array_creations.py +33 -22
  193. mindspore/numpy/array_ops.py +35 -30
  194. mindspore/numpy/logic_ops.py +6 -27
  195. mindspore/numpy/math_ops.py +22 -19
  196. mindspore/numpy/utils.py +1 -1
  197. mindspore/numpy/utils_const.py +108 -58
  198. mindspore/opencv_core452.dll +0 -0
  199. mindspore/opencv_imgcodecs452.dll +0 -0
  200. mindspore/opencv_imgproc452.dll +0 -0
  201. mindspore/ops/_constants.py +0 -6
  202. mindspore/ops/_grad/__init__.py +2 -1
  203. mindspore/ops/_grad/grad_array_ops.py +86 -117
  204. mindspore/ops/_grad/grad_base.py +23 -1
  205. mindspore/ops/_grad/grad_clip_ops.py +2 -3
  206. mindspore/ops/_grad/grad_comm_ops.py +34 -24
  207. mindspore/ops/_grad/grad_implementations.py +9 -45
  208. mindspore/ops/_grad/grad_inner_ops.py +47 -4
  209. mindspore/ops/_grad/grad_math_ops.py +142 -117
  210. mindspore/ops/_grad/grad_nn_ops.py +71 -165
  211. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  212. mindspore/ops/_grad/grad_sparse.py +7 -6
  213. mindspore/ops/_grad_experimental/__init__.py +1 -0
  214. mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
  215. mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
  216. mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
  217. mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
  218. mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
  219. mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
  220. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  221. mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
  222. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  223. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
  224. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  225. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  226. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
  227. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  228. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  229. mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
  230. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  231. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
  232. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  233. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  234. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  235. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  236. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  237. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  238. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  239. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  240. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  241. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  242. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  243. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  244. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  245. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  246. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  247. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  248. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  249. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  250. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
  251. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  252. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  253. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  254. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  255. mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
  256. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  257. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
  258. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  259. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  260. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  261. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  262. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  263. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
  264. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  265. mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
  266. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  267. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  268. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
  269. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  270. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  271. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
  272. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  273. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  274. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  275. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  276. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  277. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  278. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  279. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  280. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  281. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
  282. mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
  283. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  284. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  285. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  286. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  287. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  288. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  289. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  290. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  291. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  292. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  293. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  294. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  295. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  296. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  297. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  298. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  299. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  300. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  301. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  302. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  303. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  304. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  305. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  306. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  307. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  308. mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
  309. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
  310. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  311. mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
  312. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  313. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  314. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  315. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  316. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  317. mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
  318. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
  319. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  320. mindspore/ops/_op_impl/tbe/__init__.py +27 -611
  321. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  322. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  323. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  324. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
  325. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  326. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  327. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  328. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  329. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  330. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  331. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  332. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  333. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  334. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
  335. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  336. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  337. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  338. mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
  339. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  340. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  341. mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
  342. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
  343. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  344. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  345. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  346. mindspore/ops/_register_for_op.py +1 -0
  347. mindspore/ops/_utils/__init__.py +1 -2
  348. mindspore/ops/_utils/utils.py +19 -40
  349. mindspore/ops/_vmap/vmap_array_ops.py +116 -38
  350. mindspore/ops/_vmap/vmap_base.py +16 -9
  351. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  352. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  353. mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
  354. mindspore/ops/_vmap/vmap_image_ops.py +12 -5
  355. mindspore/ops/_vmap/vmap_math_ops.py +46 -5
  356. mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
  357. mindspore/ops/_vmap/vmap_random_ops.py +1 -1
  358. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  359. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  360. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  361. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  362. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  363. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  364. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  365. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  366. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
  367. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  368. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  369. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  370. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  371. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  372. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  373. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  374. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  375. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  376. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  377. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
  378. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
  379. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  380. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  381. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  382. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  383. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  384. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  385. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
  386. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  387. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
  388. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
  389. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  390. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  391. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  392. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  393. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  394. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  395. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  396. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  397. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  398. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  399. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  400. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  401. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  402. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  403. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  404. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  405. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  406. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  407. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  408. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
  409. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  410. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  411. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  412. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  413. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  414. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  415. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  416. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
  417. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
  418. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  419. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
  420. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  421. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  422. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  423. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  424. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  425. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  426. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  427. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  428. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  429. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
  430. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
  431. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  432. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  433. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  434. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  435. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
  436. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  437. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  438. mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
  439. mindspore/ops/composite/__init__.py +7 -8
  440. mindspore/ops/composite/base.py +101 -47
  441. mindspore/ops/composite/math_ops.py +188 -158
  442. mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
  443. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
  444. mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
  445. mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
  446. mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
  447. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  448. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  449. mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
  450. mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
  451. mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
  452. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
  453. mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
  454. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  455. mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
  456. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
  457. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
  458. mindspore/ops/function/__init__.py +152 -8
  459. mindspore/ops/function/array_func.py +2555 -674
  460. mindspore/ops/function/clip_func.py +209 -13
  461. mindspore/ops/function/debug_func.py +2 -2
  462. mindspore/ops/function/grad/__init__.py +2 -1
  463. mindspore/ops/function/grad/grad_func.py +147 -62
  464. mindspore/ops/function/image_func.py +54 -38
  465. mindspore/ops/function/linalg_func.py +167 -16
  466. mindspore/ops/function/math_func.py +4849 -1492
  467. mindspore/ops/function/nn_func.py +2573 -988
  468. mindspore/ops/function/other_func.py +115 -0
  469. mindspore/ops/function/parameter_func.py +3 -3
  470. mindspore/ops/function/random_func.py +790 -73
  471. mindspore/ops/function/sparse_func.py +98 -78
  472. mindspore/ops/function/sparse_unary_func.py +54 -53
  473. mindspore/ops/function/spectral_func.py +27 -24
  474. mindspore/ops/function/vmap_func.py +22 -2
  475. mindspore/ops/functional.py +97 -37
  476. mindspore/ops/op_info_register.py +70 -28
  477. mindspore/ops/operations/__init__.py +47 -14
  478. mindspore/ops/operations/_csr_ops.py +7 -7
  479. mindspore/ops/operations/_embedding_cache_ops.py +5 -5
  480. mindspore/ops/operations/_grad_ops.py +276 -187
  481. mindspore/ops/operations/_inner_ops.py +319 -113
  482. mindspore/ops/operations/_ms_kernel.py +10 -8
  483. mindspore/ops/operations/_ocr_ops.py +9 -9
  484. mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
  485. mindspore/ops/operations/_quant_ops.py +137 -102
  486. mindspore/ops/operations/_rl_inner_ops.py +121 -60
  487. mindspore/ops/operations/_scalar_ops.py +466 -0
  488. mindspore/ops/operations/_sequence_ops.py +1004 -2
  489. mindspore/ops/operations/_tensor_array.py +10 -11
  490. mindspore/ops/operations/_thor_ops.py +1 -1
  491. mindspore/ops/operations/array_ops.py +801 -466
  492. mindspore/ops/operations/comm_ops.py +51 -49
  493. mindspore/ops/operations/control_ops.py +2 -2
  494. mindspore/ops/operations/custom_ops.py +123 -44
  495. mindspore/ops/operations/debug_ops.py +24 -24
  496. mindspore/ops/operations/image_ops.py +240 -153
  497. mindspore/ops/operations/inner_ops.py +34 -50
  498. mindspore/ops/operations/linalg_ops.py +31 -9
  499. mindspore/ops/operations/math_ops.py +988 -757
  500. mindspore/ops/operations/nn_ops.py +965 -819
  501. mindspore/ops/operations/other_ops.py +51 -40
  502. mindspore/ops/operations/random_ops.py +204 -122
  503. mindspore/ops/operations/rl_ops.py +8 -9
  504. mindspore/ops/operations/sparse_ops.py +254 -93
  505. mindspore/ops/operations/spectral_ops.py +35 -3
  506. mindspore/ops/primitive.py +111 -9
  507. mindspore/parallel/_auto_parallel_context.py +189 -83
  508. mindspore/parallel/_offload_context.py +185 -0
  509. mindspore/parallel/_parallel_serialization.py +99 -7
  510. mindspore/parallel/_ps_context.py +9 -5
  511. mindspore/parallel/_recovery_context.py +1 -1
  512. mindspore/parallel/_tensor.py +7 -1
  513. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  514. mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
  515. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  516. mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
  517. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  518. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
  519. mindspore/parallel/_utils.py +1 -2
  520. mindspore/parallel/algo_parameter_config.py +1 -1
  521. mindspore/parallel/checkpoint_transform.py +37 -34
  522. mindspore/parallel/shard.py +17 -18
  523. mindspore/profiler/common/validator/validate_path.py +2 -2
  524. mindspore/profiler/envprofiling.py +69 -47
  525. mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
  526. mindspore/profiler/parser/base_timeline_generator.py +49 -56
  527. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
  528. mindspore/profiler/parser/hwts_log_parser.py +1 -1
  529. mindspore/profiler/parser/integrator.py +15 -14
  530. mindspore/profiler/parser/minddata_analyzer.py +2 -2
  531. mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
  532. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  533. mindspore/profiler/parser/optime_parser.py +17 -18
  534. mindspore/profiler/parser/profiler_info.py +2 -1
  535. mindspore/profiler/profiling.py +218 -186
  536. mindspore/rewrite/__init__.py +3 -1
  537. mindspore/rewrite/api/node.py +1 -114
  538. mindspore/rewrite/api/node_type.py +3 -0
  539. mindspore/rewrite/api/pattern_engine.py +31 -1
  540. mindspore/rewrite/api/scoped_value.py +4 -4
  541. mindspore/rewrite/api/symbol_tree.py +3 -78
  542. mindspore/rewrite/api/tree_node_helper.py +1 -1
  543. mindspore/rewrite/ast_creator_register.py +1 -0
  544. mindspore/rewrite/ast_helpers/__init__.py +2 -2
  545. mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
  546. mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
  547. mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
  548. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
  549. mindspore/rewrite/namespace.py +0 -2
  550. mindspore/rewrite/node.py +157 -11
  551. mindspore/rewrite/parsers/assign_parser.py +231 -53
  552. mindspore/rewrite/parsers/class_def_parser.py +187 -109
  553. mindspore/rewrite/parsers/for_parser.py +24 -14
  554. mindspore/rewrite/parsers/function_def_parser.py +21 -4
  555. mindspore/rewrite/parsers/if_parser.py +6 -2
  556. mindspore/rewrite/sparsify/__init__.py +0 -0
  557. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  558. mindspore/rewrite/sparsify/sparsify.py +109 -0
  559. mindspore/rewrite/sparsify/utils.py +173 -0
  560. mindspore/rewrite/symbol_tree.py +256 -133
  561. mindspore/rewrite/symbol_tree_builder.py +38 -1
  562. mindspore/run_check/_check_version.py +69 -63
  563. mindspore/run_check/run_check.py +2 -1
  564. mindspore/tinyxml2.dll +0 -0
  565. mindspore/train/__init__.py +1 -1
  566. mindspore/train/_utils.py +28 -5
  567. mindspore/train/amp.py +273 -102
  568. mindspore/train/callback/_backup_and_restore.py +5 -5
  569. mindspore/train/callback/_callback.py +2 -2
  570. mindspore/train/callback/_checkpoint.py +3 -3
  571. mindspore/train/callback/_early_stop.py +3 -3
  572. mindspore/train/callback/_lambda_callback.py +2 -2
  573. mindspore/train/callback/_landscape.py +29 -31
  574. mindspore/train/callback/_loss_monitor.py +3 -3
  575. mindspore/train/callback/_on_request_exit.py +3 -3
  576. mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
  577. mindspore/train/callback/_summary_collector.py +23 -16
  578. mindspore/train/callback/_time_monitor.py +3 -3
  579. mindspore/train/checkpoint_pb2.py +68 -8
  580. mindspore/train/data_sink.py +15 -3
  581. mindspore/train/dataset_helper.py +10 -15
  582. mindspore/train/loss_scale_manager.py +8 -11
  583. mindspore/train/metrics/__init__.py +1 -1
  584. mindspore/train/metrics/bleu_score.py +1 -1
  585. mindspore/train/metrics/confusion_matrix.py +1 -1
  586. mindspore/train/metrics/cosine_similarity.py +1 -1
  587. mindspore/train/metrics/dice.py +2 -2
  588. mindspore/train/metrics/fbeta.py +1 -1
  589. mindspore/train/metrics/hausdorff_distance.py +4 -3
  590. mindspore/train/metrics/mean_surface_distance.py +2 -2
  591. mindspore/train/metrics/occlusion_sensitivity.py +1 -1
  592. mindspore/train/metrics/perplexity.py +1 -1
  593. mindspore/train/metrics/precision.py +1 -1
  594. mindspore/train/metrics/recall.py +1 -1
  595. mindspore/train/metrics/roc.py +2 -2
  596. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  597. mindspore/train/mind_ir_pb2.py +116 -37
  598. mindspore/train/model.py +45 -28
  599. mindspore/train/serialization.py +295 -188
  600. mindspore/train/summary/_summary_adapter.py +1 -1
  601. mindspore/train/summary/summary_record.py +43 -13
  602. mindspore/train/train_thor/convert_utils.py +2 -2
  603. mindspore/train/train_thor/dataset_helper.py +3 -3
  604. mindspore/turbojpeg.dll +0 -0
  605. mindspore/version.py +1 -1
  606. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
  607. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
  608. mindspore/compression/__init__.py +0 -19
  609. mindspore/compression/common/constant.py +0 -124
  610. mindspore/compression/export/__init__.py +0 -19
  611. mindspore/compression/export/quant_export.py +0 -515
  612. mindspore/compression/quant/__init__.py +0 -28
  613. mindspore/compression/quant/qat.py +0 -634
  614. mindspore/compression/quant/quant_utils.py +0 -462
  615. mindspore/compression/quant/quantizer.py +0 -68
  616. mindspore/nn/layer/quant.py +0 -1868
  617. mindspore/nn/layer/rnn_utils.py +0 -90
  618. mindspore/nn/probability/dpn/__init__.py +0 -22
  619. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  620. mindspore/nn/probability/dpn/vae/cvae.py +0 -140
  621. mindspore/nn/probability/dpn/vae/vae.py +0 -124
  622. mindspore/nn/probability/infer/__init__.py +0 -22
  623. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  624. mindspore/nn/probability/infer/variational/svi.py +0 -84
  625. mindspore/nn/probability/toolbox/__init__.py +0 -22
  626. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  627. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
  628. mindspore/nn/probability/transforms/__init__.py +0 -22
  629. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  630. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  631. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  632. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  633. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  634. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  635. mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
  636. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  637. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
  638. mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
  639. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
  640. mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
  642. mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
  643. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
  644. mindspore/ops/composite/array_ops.py +0 -241
  645. mindspore/ops/composite/clip_ops.py +0 -134
  646. mindspore/ops/composite/random_ops.py +0 -426
  647. mindspore/ops/composite/vmap_ops.py +0 -38
  648. mindspore/parallel/nn/__init__.py +0 -42
  649. mindspore/parallel/nn/loss.py +0 -22
  650. mindspore/parallel/nn/moe.py +0 -21
  651. mindspore/parallel/nn/op_parallel_config.py +0 -22
  652. mindspore/parallel/nn/transformer.py +0 -31
  653. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  654. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  655. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2020-2022 Huawei Technologies Co., Ltd
1
+ # Copyright 2020-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,25 +15,32 @@
15
15
 
16
16
  """constexpr util"""
17
17
  from __future__ import absolute_import
18
+ from enum import IntEnum
19
+
18
20
 
19
21
  from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
20
22
  from mindspore.ops import functional as F
21
23
  from mindspore.ops import operations as P
22
24
  from mindspore.ops.composite import base
23
25
  from mindspore.ops._primitive_cache import _get_cache_prim
24
- from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem, DynamicBroadcastTo, \
25
- TopTypeof, issubclass_
26
+ from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem, \
27
+ TopTypeof, issubclass_, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo
26
28
  from mindspore.common import dtype as mstype
27
29
  from mindspore.common._register_for_tensor import tensor_operator_registry
30
+ from mindspore.common.initializer import Zero
28
31
  from mindspore.common import Tensor, CSRTensor, COOTensor
29
- from mindspore.common._utils import is_shape_unknown
32
+ from mindspore.common import mutable
33
+ from mindspore import ops
34
+ from mindspore.ops.primitive import _primexpr
30
35
 
31
36
  slice_get_item = SliceGetItem()
32
37
  hyper_map = base.HyperMap()
33
38
  stack = P.Stack(axis=-1)
34
39
  copy_slice = TensorCopySlices()
35
- dynamic_broadcast_to = DynamicBroadcastTo()
36
40
  toptypeof = TopTypeof()
41
+ is_parameter = IsParameter()
42
+ getitem_tensor_index_info = GetitemTensorIndexInfo(const_utils.is_ascend())
43
+ setitem_tensor_index_info = SetitemTensorIndexInfo(const_utils.is_ascend())
37
44
 
38
45
 
39
46
  def strided_slice(data, begin_strides, end_strides, step_strides, begin_mask=0, end_mask=0, ellipsis_mask=0,
@@ -44,50 +51,138 @@ def strided_slice(data, begin_strides, end_strides, step_strides, begin_mask=0,
44
51
  return strided_slice_(data, begin_strides, end_strides, step_strides)
45
52
 
46
53
 
54
+ class ValueTransferType(IntEnum):
55
+ """Transfer op types of handling tensor getitem/setitem"""
56
+ kUnknown = 0
57
+ kTensorScatterUpdate = 1
58
+ kExpandDims = 2
59
+ kBroadCast = 3
60
+ kCast = 4
61
+ kSelect = 5
62
+ kGather = 6
63
+ kStrideSlice = 7
64
+ kStrideSliceWithMask = 8
65
+ kGatherND = 9
66
+ kScatterNdUpdate = 10
67
+ kReshape = 11
68
+ kScatterND = 12
69
+ kNumberToTensor = 13
70
+ kHandleSequenceValue = 14
71
+ kByPass = 15
72
+ kReSetItemByIndex = 16
73
+ kCopySlice = 17
74
+ kSetItemByBool = 18
75
+ kEmptyTensor = 19
76
+ kSetItemByEllipsis = 20
77
+ kRaiseIndexError = 21
78
+
79
+
80
+ def data_update(transfer_types, args, data, new_index, value=None):
81
+ """
82
+ We finally generate a new tensor when handling tensor getitem/setitem
83
+ by transfer data and value with index.
84
+ """
85
+ for transfer_type, arg in zip(transfer_types, args):
86
+ if transfer_type == ValueTransferType.kUnknown:
87
+ raise IndexError(f"Inlvaid transfer type {transfer_type}.")
88
+ if transfer_type <= ValueTransferType.kScatterND:
89
+ data = data_update_by_ops(transfer_type, arg, data, new_index, value)
90
+ if transfer_type == ValueTransferType.kSetItemByBool:
91
+ return tensor_setitem_by_bool(data, new_index, value)
92
+ if transfer_type == ValueTransferType.kCopySlice:
93
+ return copy_slice(data, value.astype(data.dtype), arg[0], arg[1], arg[2])
94
+ if transfer_type == ValueTransferType.kSetItemByEllipsis:
95
+ return tensor_setitem_by_ellipsis(data, new_index, value)
96
+ if transfer_type == ValueTransferType.kReSetItemByIndex:
97
+ data[new_index] = value
98
+ return data
99
+ if transfer_type == ValueTransferType.kEmptyTensor:
100
+ return handle_empty_tensor(arg, data)
101
+ if transfer_type == ValueTransferType.kRaiseIndexError:
102
+ raise IndexError(
103
+ f'index {arg[0]} is out of bounds for dimension with size {arg[1]}')
104
+ return data
105
+
106
+
107
+ def data_update_by_ops(transfer_type, arg, data, new_index, value=None):
108
+ """
109
+ Generate a new tensor when handling tensor getitem/setitem
110
+ by ops.
111
+ """
112
+ if transfer_type == ValueTransferType.kStrideSliceWithMask:
113
+ stride_info, mask_index = arg[0], arg[1]
114
+ data = strided_slice(data, stride_info[0], stride_info[1], stride_info[2],
115
+ mask_index[0], mask_index[1], 0, 0, mask_index[2])
116
+ elif transfer_type == ValueTransferType.kGatherND:
117
+ if isinstance(new_index, list):
118
+ new_index = handle_multi_dim_index_tensor(new_index, arg)
119
+ data = F.gather_nd(data, Tensor(new_index))
120
+ elif transfer_type == ValueTransferType.kTensorScatterUpdate:
121
+ if isinstance(new_index, list):
122
+ new_index = handle_multi_dim_index_tensor(new_index, arg)
123
+ data = F.tensor_scatter_update(data, new_index, value)
124
+ elif transfer_type == ValueTransferType.kScatterNdUpdate:
125
+ F.scatter_nd_update(data, new_index, value)
126
+ elif transfer_type == ValueTransferType.kSelect:
127
+ data = F.select(Tensor(new_index), value, data)
128
+ elif transfer_type == ValueTransferType.kReshape:
129
+ data = F.reshape(data, arg)
130
+ elif transfer_type == ValueTransferType.kGather:
131
+ data = F.gather(data, new_index, 0)
132
+ elif transfer_type == ValueTransferType.kExpandDims:
133
+ data = F.expand_dims(data, 0)
134
+ elif transfer_type == ValueTransferType.kStrideSlice:
135
+ data = F.strided_slice(data, arg[0], arg[1], arg[2])
136
+ else:
137
+ raise IndexError(f"Inlvaid transfer type {transfer_type}.")
138
+ return data
139
+
140
+
141
+ def value_update(transfer_types, args, data, value):
142
+ """Transfer value before set value to tensor when handling tensor setitem"""
143
+ for transfer_type, arg in zip(transfer_types, args):
144
+ if transfer_type == ValueTransferType.kByPass:
145
+ continue
146
+ if transfer_type == ValueTransferType.kNumberToTensor:
147
+ value = F.fill(F.dtype(data), (), value)
148
+ elif transfer_type == ValueTransferType.kHandleSequenceValue:
149
+ op_type, index = arg
150
+ if op_type == const_utils.SET_ITEM_BY_ONE_TENSOR:
151
+ index = Tensor(index)
152
+ value = _generate_updates_from_sequence(
153
+ data, index, value, op_type)
154
+ elif transfer_type == ValueTransferType.kExpandDims:
155
+ value = F.expand_dims(value, arg)
156
+ elif transfer_type == ValueTransferType.kBroadCast:
157
+ value = _broadcast(arg, value.astype(F.dtype(data)))
158
+ elif transfer_type == ValueTransferType.kCast:
159
+ value = F.cast(value, F.dtype(data))
160
+ elif transfer_type == ValueTransferType.kReshape:
161
+ value = F.reshape(value, arg)
162
+ elif transfer_type == ValueTransferType.kScatterND:
163
+ value = F.scatter_nd(arg[0], value, arg[1])
164
+ else:
165
+ raise IndexError(f"Inlvaid transfer type {transfer_type}.")
166
+ return value
167
+
168
+
47
169
  def _tensor_getitem(self, index):
48
170
  """Handle tensor getitem"""
49
- if isinstance(index, Tensor):
50
- return tensor_index_by_tensor(self, index)
51
- if isinstance(index, list):
52
- return tensor_index_by_list(self, index)
53
- if isinstance(index, tuple):
54
- return tensor_index_by_tuple(self, index)
55
- if isinstance(index, bool):
56
- return _tensor_index_by_bool(self, index)
57
- if isinstance(index, int):
58
- return _tensor_index_by_integer(self, index)
59
- if isinstance(index, slice):
60
- return tensor_index_by_slice(self, index)
61
- if index is None:
62
- return F.expand_dims(self, 0)
63
- if index is ...:
64
- return self
65
- raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor with int, "
66
- f"list and tuple ,but got {index} with type {type(index)}.")
171
+ new_index, tensor_update_types, tensor_update_args = getitem_tensor_index_info(
172
+ self, index)
173
+ return data_update(tensor_update_types, tensor_update_args, self, new_index)
67
174
 
68
175
 
69
176
  def _tensor_setitem(self, index, value):
70
177
  """Handle tensor setitem"""
71
- if not isinstance(value, (int, float, bool, list, tuple, Tensor)):
72
- raise ValueError(f"only support numbers, Tensor, tuple, list as value,"
73
- f"but got {value} with type {type(value)}.")
74
- if isinstance(index, list):
75
- index = format_list_indices(index, F.shape(self)[0])
76
- if isinstance(index, Tensor):
77
- return tensor_setitem_by_tensor(self, index, value)
78
- if isinstance(index, tuple):
79
- return tensor_setitem_by_tuple(self, index, value)
80
- if isinstance(index, bool):
81
- return tensor_setitem_by_bool(self, index, value)
82
- if isinstance(index, int):
83
- return tensor_setitem_by_number(self, index, value)
84
- if isinstance(index, slice):
85
- return tensor_setitem_by_slice(self, index, value)
86
- if index in (None, ...):
87
- return tensor_setitem_by_ellipsis(self, index, value)
88
-
89
- raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), bool, tensor, \
90
- list and tuple, but got {index} with type{type(index)}")
178
+ setitem_info = setitem_tensor_index_info(self, index, value)
179
+ new_index = setitem_info[0]
180
+ v_transfer_types = setitem_info[1]
181
+ v_transfer_args = setitem_info[2]
182
+ data_update_types = setitem_info[3]
183
+ data_update_args = setitem_info[4]
184
+ value = value_update(v_transfer_types, v_transfer_args, self, value)
185
+ return data_update(data_update_types, data_update_args, self, new_index, value)
91
186
 
92
187
 
93
188
  tensor_operator_registry.register("__getitem__", _tensor_getitem)
@@ -171,6 +266,13 @@ tensor_operator_registry.register('__rpow__', _tensor_rpow)
171
266
  tensor_operator_registry.register('__floordiv__', _tensor_floordiv)
172
267
 
173
268
 
269
+ def _scalar_to_tensor(input_x):
270
+ if ops.isconstant(input_x):
271
+ return P.ScalarToTensor()(input_x, ops.dtype(input_x))
272
+ # use add Tensor([0]) cast scalar to tensor.
273
+ return ops.add(input_x, mutable(Tensor(0)))
274
+
275
+
174
276
  def tensor_item(data, *args):
175
277
  """Tensor getitem by index whose dtype is int or tuple with int."""
176
278
  # transform a.item(tuple(int)) -> a.item(int1,int2...intN)
@@ -245,13 +347,9 @@ def tensor_itemset_by_tuple_with_number(data, tuple_index, nubmer_value):
245
347
 
246
348
  def _broadcast(broadcast_shape, x):
247
349
  """Broadcast tensor to the required shape."""
248
- if not const_utils.check_two_shapes_need_broadcast(broadcast_shape, F.shape(x)):
350
+ if F.shape(x) == broadcast_shape:
249
351
  return x
250
- multiples = const_utils.compute_multiples(F.shape(x), broadcast_shape)
251
- if multiples:
252
- x = F.reshape(x, const_utils.expanded_shape(F.shape(x), len(multiples) - F.rank(x)))
253
- return F.tile(x, multiples)
254
- return x
352
+ return F.broadcast_to(x, broadcast_shape)
255
353
 
256
354
 
257
355
  def _transform_indexing_tensor(broadcast_shape, final_shape, new_shape, item):
@@ -291,6 +389,46 @@ def _transform_ellipsis_to_slice(data, tuple_index, op_name):
291
389
  return tuple_index_new
292
390
 
293
391
 
392
+ def handle_empty_tensor(arg, data):
393
+ """handle data update with empty tensor"""
394
+ if 0 in arg:
395
+ init_func = Zero()
396
+ init_func.__enable_zero_dim__ = True
397
+ return Tensor(shape=arg, dtype=data.dtype, init=init_func)
398
+ return const_utils.make_tensor([], data.dtype, arg)
399
+
400
+
401
+ def handle_multi_dim_index_tensor(new_index, arg):
402
+ """handle data update with multi dim index tensor"""
403
+ slice_cnt = 0
404
+ new_indies_tensor = []
405
+ if len(arg) == 1:
406
+ broadcast_shape = arg[0]
407
+ new_index = hyper_map(F.partial(Tensor), new_index)
408
+ broadcast_tensors = hyper_map(
409
+ F.partial(_broadcast, broadcast_shape), new_index)
410
+ new_broadcast_tensors = ()
411
+ for tensor in broadcast_tensors:
412
+ new_broadcast_tensors += (F.cast(tensor, mstype.int64),)
413
+ new_index = stack(new_broadcast_tensors)
414
+ return new_index
415
+ broadcast_shape, final_shape, index_tensor_new_shape, slice_shapes, tensor_positions, fancy_position = arg
416
+ for i, index in enumerate(new_index):
417
+ if i in tensor_positions:
418
+ transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
419
+ Tensor(index))
420
+ new_indies_tensor.append(F.cast(transform_tensor, mstype.int64))
421
+ else:
422
+ shape = const_utils.compute_slice_shape(
423
+ slice_shapes, len(broadcast_shape), slice_cnt, fancy_position)
424
+ array = Tensor(index).reshape(shape)
425
+ slice_index_tensor = _broadcast(final_shape, array)
426
+ new_indies_tensor.append(F.cast(slice_index_tensor, mstype.int64))
427
+ slice_cnt += 1
428
+ new_index = stack(new_indies_tensor)
429
+ return new_index
430
+
431
+
294
432
  def _expand_data_dims(data, tuple_index):
295
433
  """expand the data's dim with 'None' and 'Boolean' in tuple_index"""
296
434
  indexes_types = hyper_map(toptypeof, tuple_index)
@@ -313,12 +451,34 @@ def _expand_data_dims(data, tuple_index):
313
451
  return data, tuple_index_new
314
452
 
315
453
 
454
+ def convert_variable_to_tensor_slice(slice_index):
455
+ """convert mutable scalar to tensor"""
456
+ start = slice_get_item(slice_index, "start")
457
+ stop = slice_get_item(slice_index, "stop")
458
+ step = slice_get_item(slice_index, "step")
459
+ find_mutable_scalar = False
460
+ if isinstance(start, int) and not F.isconstant(start):
461
+ start = ops.Cast()(start, mstype.int64)
462
+ find_mutable_scalar = True
463
+ if isinstance(stop, int) and not F.isconstant(stop):
464
+ stop = ops.Cast()(stop, mstype.int64)
465
+ find_mutable_scalar = True
466
+ if isinstance(step, int) and not F.isconstant(step):
467
+ step = ops.Cast()(step, mstype.int64)
468
+ find_mutable_scalar = True
469
+ if find_mutable_scalar:
470
+ return F.make_slice(start, stop, step)
471
+ return slice_index
472
+
473
+
316
474
  def tensor_index_by_slice(data, slice_index):
317
475
  """Tensor getitem by a slice."""
318
476
  min_data_dim, max_data_dim = 1, 8
319
477
  const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
320
478
  data_shape = F.shape(data)
321
- is_dynamic = (is_shape_unknown(data_shape)
479
+ slice_index = convert_variable_to_tensor_slice(slice_index)
480
+
481
+ is_dynamic = (F.is_sequence_value_unknown(data_shape)
322
482
  or isinstance(slice_get_item(slice_index, "start"), Tensor)
323
483
  or isinstance(slice_get_item(slice_index, "stop"), Tensor)
324
484
  or isinstance(slice_get_item(slice_index, "step"), Tensor))
@@ -341,6 +501,12 @@ def get_stride_info_from_slice(data, slice_index):
341
501
  data_shape = F.dyn_shape(data)
342
502
  begin_strides, end_strides, step_strides = [], [], []
343
503
  start, stop, step = get_slice_stride(slice_index, data_shape[0])
504
+ if start.ndim > 0:
505
+ start = start.item()
506
+ if stop.ndim > 0:
507
+ stop = stop.item()
508
+ if step.ndim > 0:
509
+ step = step.item()
344
510
  begin_strides.append(start)
345
511
  end_strides.append(stop)
346
512
  step_strides.append(step)
@@ -370,19 +536,10 @@ def _tensor_index_by_bool(data, bool_value):
370
536
      return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")


- def check_range(x, dim_size):
-     """Check whether x is within the range of dim_size"""
-     tensor_x = const_utils.make_tensor(x)
-     if tensor_x >= dim_size or tensor_x < -dim_size:
-         return tensor_x
-     tensor_x = tensor_x % dim_size
-     return tensor_x
-
-
  def get_stride_info_from_integer(tensor_int):
      """Convert integer to slice"""
      begin_strides = [tensor_int]
-     end_strides = [tensor_int + const_utils.make_tensor(1)]
+     end_strides = [tensor_int + 1]
      step_strides = [const_utils.make_tensor(1)]
      begin_tensor = stack(begin_strides)
      end_tensor = stack(end_strides)
@@ -398,10 +555,9 @@ def _tensor_index_by_integer(data, int_index):
      if data.ndim < 1 or data.ndim > 8:
          const_utils.raise_value_error("Expect Tensor to have dimension between 1 and 8.")

-     if is_shape_unknown(data_shape):
-         data_shape = F.dyn_shape(data)
-         transformed_tensor = check_range(int_index, data_shape[0])
-         begin_strides, end_strides, step_strides = get_stride_info_from_integer(transformed_tensor)
+     if F.is_sequence_value_unknown(data_shape) or not F.isconstant(int_index):
+         tensor_index = _scalar_to_tensor(int_index)
+         begin_strides, end_strides, step_strides = get_stride_info_from_integer(tensor_index)
      else:
          transformed_number = const_utils.check_range(int_index, data_shape[0])
          begin_strides, end_strides, step_strides = \
@@ -415,16 +571,35 @@ def _tensor_index_by_integer(data, int_index):
      return strided_slice(data, begin_strides, end_strides, step_strides, begin_mask, end_mask, 0, 0, shrink_axis_mask)


+ def _check_dim_shape_valid(data, tensor_index):
+     """check dim and shape of tensor_index for tensor(bool) indexing"""
+     if data.ndim < tensor_index.ndim:
+         raise IndexError(f"The dim of index cannot be greater than indexed data, but got "
+                          f"dim of index:{tensor_index.ndim}, dim of data:{data.ndim}")
+     if data.shape[:tensor_index.ndim] != tensor_index.shape[:]:
+         raise IndexError(f"The shape of index {tensor_index.shape} does not match the shape "
+                          f"of the indexed data {data.shape}")
+
+
+ def tensor_index_by_bool_tensor(data, tensor_index):
+     """Tensor getitem by a bool tensor"""
+     _check_dim_shape_valid(data, tensor_index)
+     tensor_index = tensor_index.nonzero()
+     return F.gather_nd(data, tensor_index)
+
+
  def tensor_index_by_tensor(data, tensor_index):
      """Tensor getitem by a single tensor"""
      min_data_dim, max_data_dim = 0, 7
      const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
-     valid = const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Int)
-     if valid is False:
-         exp_msg = const_utils.gen_exception_msg(
-             "The tensor index must be int type, but got {}.", F.dtype(tensor_index))
-         const_utils.raise_index_error(exp_msg)
-     return F.gather(data, tensor_index, 0)
+     if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Int):
+         return F.gather(data, tensor_index, 0)
+     if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Bool):
+         return tensor_index_by_bool_tensor(data, tensor_index)
+     exp_msg = const_utils.gen_exception_msg(
+         "The tensor index must be int or bool type, but got {}.", F.dtype(tensor_index))
+     const_utils.raise_index_error(exp_msg)
+     return data
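
Illustrative note (not part of the diff): with the new tensor_index_by_bool_tensor path above, a Tensor of bool dtype can be used directly as a getitem index; the mask is converted with nonzero() and gathered with gather_nd. A minimal sketch of the intended user-level behaviour, assuming a backend where this path is enabled:

    import mindspore as ms
    from mindspore import Tensor

    x = Tensor([[1, 2], [3, 4], [5, 6]], ms.int32)
    mask = Tensor([True, False, True])
    # mask.shape must match a leading prefix of x.shape (checked by _check_dim_shape_valid)
    y = x[mask]    # roughly gather_nd(x, mask.nonzero()) -> [[1, 2], [5, 6]]
    print(y)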


  def tensor_index_by_list(data, list_index):
@@ -435,10 +610,13 @@ def tensor_index_by_list(data, list_index):
      data_shape = F.shape(data)
      indexes_types = hyper_map(toptypeof, list_index)
      if const_utils.check_type_isinstance(indexes_types, (mstype.Bool, mstype.Int)):
-         if data_shape[0] == -1 and all(isinstance(i, bool) for i in list_index):
-             const_utils.raise_unimplemented_error(
-                 "Not supported to the dynamic shape tensor slice by using list of Boolean type")
-         tensor_index = const_utils.sequence_to_index(list_index, data_shape[0])
+         if not F.isconstant(data_shape[0]):
+             if all(isinstance(i, bool) for i in list_index):
+                 const_utils.raise_unimplemented_error(
+                     "Not supported to the dynamic shape tensor slice by using list of Boolean type")
+             tensor_index = const_utils.sequence_to_index(list_index, None)
+         else:
+             tensor_index = const_utils.sequence_to_index(list_index, data_shape[0])
          if tensor_index is False:
              const_utils.raise_index_error("When tensor is indexed by list, the list can't be empty.")
          return F.gather(data, tensor_index, 0)
@@ -449,18 +627,28 @@ def tensor_index_by_list(data, list_index):
      return tensor_index_by_tuple(data, tuple_index_new)
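
Illustrative note (not part of the diff): tensor_index_by_list keeps the existing behaviour for statically known shapes and only changes how a dynamic first dimension is handled. A minimal sketch of the user-level behaviour, assuming a statically shaped tensor:

    import mindspore as ms
    from mindspore import Tensor

    x = Tensor([[0, 1], [2, 3], [4, 5]], ms.float32)
    print(x[[2, 0]])                 # all-int list -> index tensor, dispatched to gather
    print(x[[True, False, True]])    # all-bool list -> treated like a 1-D mask on dim 0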


+ def convert_tupleslice_to_tensor(tuple_index):
+     """convert mutable scalar in slice to tensor"""
+     new_tuple_index = []
+     for item in tuple_index:
+         if isinstance(item, slice):
+             item = convert_variable_to_tensor_slice(item)
+         new_tuple_index.append(item)
+     return tuple(new_tuple_index)
+
+
  def tensor_index_by_tuple(data, tuple_index):
      """Tensor getitem by tuple of various types with None"""
      if not tuple_index:
          return data

+     tuple_index = convert_tupleslice_to_tensor(tuple_index)
      op_name = const_utils.TENSOR_GETITEM
      tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
      data, tuple_index = _expand_data_dims(data, tuple_index)

      min_data_dim, max_data_dim = 1, 8
      const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
-
      indexes_types = hyper_map(toptypeof, tuple_index)
      contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
      if contain_type == const_utils.ALL_BASIC:
@@ -468,31 +656,6 @@ def tensor_index_by_tuple(data, tuple_index):
      return _tensor_getitem_by_tuple(data, tuple_index, op_name)


- def _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name):
-     """Tensor getitem by a tuple of tensor."""
-     data_shape = F.shape(data)
-     tuple_index_len = len(tuple_index)
-
-     indexes_types = hyper_map(F.dtype, tuple_index)
-     const_utils.check_indexes_types_valid(indexes_types, mstype.int_type, op_name)
-     tensor_index_shape = hyper_map(F.shape, tuple_index)
-     broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name)
-     if 0 in broadcast_shape:
-         res_shape = broadcast_shape
-         if tuple_index_len < len(data_shape):
-             res_shape += data_shape[tuple_index_len:]
-         res = const_utils.make_tensor([], data.dtype, res_shape)
-         return res
-
-     broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index)
-     new_broadcast_tensors = ()
-     for tensor in broadcast_tensors:
-         new_broadcast_tensors += (F.cast(tensor, mstype.int64),)
-     indices = stack(new_broadcast_tensors)
-     result = F.gather_nd(data, indices)
-     return result
-
-
  def get_slice_stride(slice_index, dim_size):
      """Get slice stride info"""
      start = slice_get_item(slice_index, "start")
@@ -551,7 +714,7 @@ def _get_stride_info_from_tuple(data, tuple_index):
              step_strides.append(step)
              index_count = index_count + 1
          elif isinstance(index, int):
-             int_tensor = check_range(index, dim_size)
+             int_tensor = _scalar_to_tensor(index)
              begin_strides.append(int_tensor)
              end_strides.append(int_tensor + const_utils.make_tensor(1))
              step_strides.append(const_utils.make_tensor(1))
@@ -585,7 +748,7 @@ def _get_stride_info_from_tuple(data, tuple_index):
  def _tensor_getitem_by_tuple_slice(data, tuple_index):
      """Tensor getitem by a tuple of slice"""
      data_shape = F.shape(data)
-     is_dynamic = is_shape_unknown(data_shape)
+     is_dynamic = F.is_sequence_value_unknown(data_shape)
      for item in tuple_index:
          if isinstance(item, slice):
              is_dynamic = is_dynamic or isinstance(slice_get_item(item, "start"), Tensor) \
@@ -607,6 +770,39 @@ def _tensor_getitem_by_tuple_slice(data, tuple_index):
      return strided_slice(data, begin_v, end_v, step_v, begin_mask, end_mask, 0, 0, shrink_axis_mask)


+ @_primexpr
+ def _tensor_getitem_by_tuple_parse_bool_tensor_index(index, tuple_index_new, tensor_indexes,
+                                                      tensor_positions_new):
+     """ parse index of bool tensor type """
+     indices = index.nonzero()
+     if indices.shape[0] == 0:
+         return None, tensor_indexes, tensor_positions_new
+     indices = F.cast(indices, mstype.int64)
+     indices = indices.T
+     for sub_index in indices:
+         tensor_positions_new.append(len(tuple_index_new))
+         tuple_index_new += (sub_index,)
+         tensor_indexes.append(sub_index)
+     return tuple_index_new, tensor_indexes, tensor_positions_new
+
+
+ def _tensor_getitem_by_tuple_parse_tensor_index(index, tuple_index_new, tensor_indexes, tensor_positions_new):
+     """ parse index of tensor type """
+     if F.dtype(index) in mstype.int_type:
+         tensor_index = F.cast(index, mstype.int64)
+         tensor_positions_new.append(len(tuple_index_new))
+         tuple_index_new += (tensor_index,)
+         tensor_indexes.append(tensor_index)
+     elif F.dtype(index) == mstype.bool_:
+         return _tensor_getitem_by_tuple_parse_bool_tensor_index(index, tuple_index_new, tensor_indexes,
+                                                                 tensor_positions_new)
+     else:
+         exp_msg = const_utils.gen_exception_msg(
+             "The tensor element in tuple index must be int or bool type, but got {}.", F.dtype(index))
+         const_utils.raise_index_error(exp_msg)
+     return tuple_index_new, tensor_indexes, tensor_positions_new
+
+
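Illustrative note (not part of the diff): the two helpers above let a bool tensor appear inside a tuple index; its nonzero() result is split into one int64 index tensor per covered dimension. A sketch of the intended NumPy-style behaviour, under that assumption:

    import mindspore as ms
    from mindspore import Tensor

    x = Tensor([[1, 2], [3, 4], [5, 6]], ms.int32)
    mask = Tensor([True, False, True])
    # the bool element covers dim 0, the int 1 selects a column
    print(x[mask, 1])    # expected, NumPy-style: [2, 6]
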
  def _tensor_getitem_by_tuple(data, tuple_index, op_name):
      """Tensor getitem by a tuple of mixed tensor."""
      slice_is_tensor = False
@@ -617,51 +813,49 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name):
                                or isinstance(slice_get_item(item, "step"), Tensor)
      if slice_is_tensor:
          const_utils.raise_index_error("Not supported when slice has tensor")
-     tuple_index_len = len(tuple_index)
-     tensor_indexes, slice_indexes = [], []
+
      indexes_types = hyper_map(toptypeof, tuple_index)
      slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
          const_utils.get_pos_of_indexes_types(indexes_types, op_name)
-     tuple_index_new, slice_shapes = (), ()
      data_shape = F.shape(data)
+     tensor_indexes, slice_indexes = [], []
+     tuple_index_new, slice_shapes = (), ()
+     slice_positions_new, tensor_positions_new = [], []
      for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
          if i in int_positions:
              int_index = const_utils.check_range(index, dim_size)
              tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
-             if is_shape_unknown(data_shape):
-                 dyn_shape = F.dyn_shape(data)
-                 tensor_index = check_range(index, dyn_shape[i])
+             if F.is_sequence_value_unknown(data_shape):
+                 tensor_index = _scalar_to_tensor(int_index)
                  tensor_index = F.cast(tensor_index, mstype.int64)
+             tensor_positions_new.append(len(tuple_index_new))
              tuple_index_new += (tensor_index,)
              tensor_indexes.append(tensor_index)
-             tensor_positions += (i,)
          elif i in sequence_positions:
              tensor_index = const_utils.sequence_to_index(index, dim_size)
              if tensor_index is False:
                  const_utils.raise_index_error("The sequence element(tuple/list) in tuple index can't be empty.")
+             tensor_positions_new.append(len(tuple_index_new))
              tuple_index_new += (tensor_index,)
              tensor_indexes.append(tensor_index)
-             tensor_positions += (i,)
          elif i in tensor_positions:
-             invalid = const_utils.check_type_invalid(F.dtype(index), mstype.int_type)
-             if invalid:
-                 exp_msg = const_utils.gen_exception_msg(
-                     "The tensor element in tuple index must be int type, but got {}.", F.dtype(index))
-                 const_utils.raise_index_error(exp_msg)
-             tensor_index = F.cast(index, mstype.int64)
-             tuple_index_new += (tensor_index,)
-             tensor_indexes.append(tensor_index)
+             tuple_index_new, tensor_indexes, tensor_positions_new = \
+                 _tensor_getitem_by_tuple_parse_tensor_index(index, tuple_index_new,
+                                                             tensor_indexes, tensor_positions_new)
+             if tuple_index_new is None:
+                 return Tensor([])
          elif i in slice_positions:
              slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size)
              slice_shapes += (len(slice_ele_list_index),)
+             slice_positions_new.append(len(tuple_index_new))
              tuple_index_new += (slice_ele_list_index,)
              slice_indexes.append(slice_ele_list_index)
-
      tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
      broadcast_shape, index_tensor_new_shape, final_shape, fancy_position = \
-         const_utils.generate_index_info_from_tuple_of_mixed_tensors(tensor_positions, tensor_indexes_shapes,
+         const_utils.generate_index_info_from_tuple_of_mixed_tensors(tensor_positions_new, tensor_indexes_shapes,
                                                                      slice_shapes, op_name)

+     tuple_index_len = len(tuple_index)
      if 0 in final_shape + data_shape:
          if tuple_index_len < len(data_shape):
              final_shape = final_shape + data_shape[tuple_index_len:]
@@ -670,11 +864,11 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name):
      final_index_tensors = []
      slice_cnt = 0
      for i, index in enumerate(tuple_index_new):
-         if i in tensor_positions:
+         if i in tensor_positions_new:
              transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
                                                            index)
              final_index_tensors.append(transform_tensor)
-         elif i in slice_positions:
+         elif i in slice_positions_new:
              slice_index_tensor = convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape,
                                                           slice_shapes, fancy_position)
              final_index_tensors.append(slice_index_tensor)
@@ -709,7 +903,6 @@ def _generate_indices_from_tuple(data, tuple_index, op_name, fancy_position):
      slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
          const_utils.get_pos_of_indexes_types(indexes_types, op_name)
      tuple_index_new, slice_shapes = (), ()
-
      for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
          if i in int_positions:
              int_index = const_utils.check_range(index, dim_size)
@@ -726,7 +919,7 @@ def _generate_indices_from_tuple(data, tuple_index, op_name, fancy_position):
              invalid = const_utils.check_type_invalid(F.dtype(index), mstype.int_type)
              if invalid:
                  exp_msg = const_utils.gen_exception_msg(
-                     "The tensor element in tuple index must be int type, but got {}.", F.dtype(index))
+                     "The tensor element in tuple index must be int or bool type, but got {}.", F.dtype(index))
                  const_utils.raise_index_error(exp_msg)
              tensor_index = F.cast(index, mstype.int64)
              tuple_index_new += (tensor_index,)
@@ -791,11 +984,11 @@ def _generate_updates_from_sequence(data, index, value, op_type):
  def _generate_updates_from_tensor(data, index, value, op_type):
      """Generate an updates tensor from a tensor."""
      value = value.astype(data.dtype)
-     if is_shape_unknown(F.shape(data)):
+     if F.is_sequence_value_unknown(F.shape(data)):
          data_shape = F.dyn_shape(data)
          index_shape = F.dyn_shape(index)
          updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type, True)
-         updates = dynamic_broadcast_to(value, updates_shape)
+         updates = ops.broadcast_to(value, updates_shape)
          return updates
      updates_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type, False)
      need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value.shape)
@@ -815,6 +1008,7 @@ def tensor_setitem_by_tensor(self, index, value):


  def tensor_setitem_by_tuple(self, index, value):
+     index = convert_tupleslice_to_tensor(index)
      if isinstance(value, (int, float, bool)):
          index = format_tuple_indices(index)
          return tensor_setitem_by_tuple_with_number(self, index, value)
@@ -832,6 +1026,7 @@ def tensor_setitem_by_number(self, index, value):


  def tensor_setitem_by_slice(self, index, value):
+     index = convert_variable_to_tensor_slice(index)
      if isinstance(value, (int, float, bool)):
          return tensor_setitem_by_slice_with_number(self, index, value)
      if isinstance(value, Tensor):
@@ -852,28 +1047,29 @@ def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
      if F.rank(index) == 0:
          index = F.expand_dims(index, -1)
      updates = _generate_updates_from_tensor(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
-     index = F.select(index < 0, index + F.shape(data)[0], index)
+     data_shape = F.shape(data)
+     first_val = data_shape[0]
+     if not F.isconstant(first_val):
+         first_val = -1
+     index = F.select(index < 0, index + first_val, index)
      index = F.expand_dims(index, -1)
      if F.rank(index) < 2:
          index = F.expand_dims(index, 0)
          updates = F.expand_dims(updates, 0)
+     if is_parameter(data):
+         F.scatter_nd_update(data, index, updates)
+         return data
      return F.tensor_scatter_update(data, index, updates)
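
Illustrative note (not part of the diff): setting items through an int tensor index still goes through tensor_scatter_update, while the new is_parameter branch switches to scatter_nd_update so a Parameter is updated in place. A minimal sketch of the plain Tensor case:

    import mindspore as ms
    from mindspore import Tensor

    x = Tensor([[1., 1.], [2., 2.], [3., 3.]], ms.float32)
    idx = Tensor([0, 2], ms.int32)
    x[idx] = Tensor([[9., 9.], [8., 8.]], ms.float32)
    print(x)    # rows 0 and 2 replaced; negative indices are wrapped by the select() above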


  def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
      """Set a tensor item by a bool tensor with a tensor."""
-     index_shape = F.shape(index)
-     data_shape = F.shape(data)
-     const_utils.check_equal(data_shape, index_shape,
-                             "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
-     size = F.shape_mul(F.shape(value))
-     const_utils.check_equal(1, size,
-                             "When assign value is a tensor, its size should be {}, but current size is {}.")
-     dtype = F.dtype(data)
-     u_cast = F.cast(value, dtype)
-     one_data = F.ones_like(data)
-     u = F.tensor_mul(one_data, u_cast)
-     result = F.select(index, u, data)
+     index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape)))
+     index = F.broadcast_to(index, data.shape)
+     value = F.cast(value, F.dtype(data))
+     value = value.reshape(const_utils.generate_padding_shape(value.shape, len(data.shape)))
+     value = F.broadcast_to(value, data.shape)
+     result = F.select(index, value, data)
      return result
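
Illustrative note (not part of the diff): the rewritten bool-mask setitem no longer requires the assigned value to be a single element; index and value are padded and broadcast to data.shape and merged with select. A sketch, assuming the value broadcasts against the data:

    import mindspore as ms
    from mindspore import Tensor

    x = Tensor([[1., 1.], [2., 2.]], ms.float32)
    mask = x > 1.5                      # bool tensor with the same shape as x
    x[mask] = Tensor(0.0, ms.float32)   # value broadcast to x.shape, then select(mask, value, x)
    print(x)                            # [[1., 1.], [0., 0.]]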


@@ -884,7 +1080,7 @@ def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
      if tensor_dtype == const_utils.INT_:
          return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)

-     if is_shape_unknown(F.shape(data)):
+     if F.is_sequence_value_unknown(F.shape(data)):
          const_utils.raise_unimplemented_error(
              "Not supported to the dynamic shape tensor slice by using tensor of Boolean type")
      return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
@@ -898,11 +1094,13 @@ def tensor_setitem_by_tensor_with_number(data, index, value):
  def tensor_setitem_by_tensor_with_sequence(data, index, value):
      """Assigns the tensor by tensor with tuple value."""
      index_dtype = F.dtype(index)
-     invalid = const_utils.check_type_invalid(index_dtype, (mstype.int32, mstype.int64))
-     if invalid:
-         exp_msg = const_utils.gen_exception_msg("The tensor index must be int type, but got {}.", index_dtype)
-         const_utils.raise_index_error(exp_msg)
-     return _tensor_setitem_by_tensor_with_sequence(data, index, value)
+     if index_dtype in (mstype.int32, mstype.int64):
+         return _tensor_setitem_by_tensor_with_sequence(data, index, value)
+     if index_dtype == mstype.bool_:
+         return _tensor_setitem_by_bool_tensor_with_sequence(data, index, value)
+     exp_msg = const_utils.gen_exception_msg("The tensor index must be int or bool type, but got {}.", index_dtype)
+     const_utils.raise_index_error(exp_msg)
+     return None
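
Illustrative note (not part of the diff): tensor_setitem_by_tensor_with_sequence now also accepts a bool index; the sequence is first converted to a tensor and then follows the bool-mask path above. A sketch, with a single-element list to keep the broadcast trivial:

    import mindspore as ms
    from mindspore import Tensor

    x = Tensor([[1., 1.], [2., 2.]], ms.float32)
    mask = Tensor([[True, False], [False, True]])
    x[mask] = [0.0]    # list -> tensor via sequence_to_tensor, then the bool-mask assignment
    print(x)           # masked positions set to 0.0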


  def _tensor_setitem_by_tensor_with_sequence(data, index, value):
@@ -912,6 +1110,12 @@ def _tensor_setitem_by_tensor_with_sequence(data, index, value):
      return F.tensor_scatter_update(data, index, updates)


+ def _tensor_setitem_by_bool_tensor_with_sequence(data, index, value):
+     """Set a tensor item by a bool tensor with a tuple."""
+     value = sequence_to_tensor(value, F.dtype(data))
+     return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value)
+
+
  def tensor_setitem_by_slice_with_number(data, input_slice, value):
      """Givens a scalar assign to tensor by slice"""
      value = F.fill(F.dtype(data), (), value)
@@ -937,7 +1141,7 @@ def tensor_copy_slice_from_slice(data, input_slice, value):
      if dim0_size >= data_shape[0]:
          dim0_size = data_shape[0:1]
      value_shape = P.Concat(-1)((dim0_size, data_shape[1:]))
-     value = dynamic_broadcast_to(value, value_shape)
+     value = ops.broadcast_to(value, value_shape)
      return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)


@@ -948,8 +1152,8 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
      if check_result:
          data_shape = F.shape(data)
          step = const_utils.get_step_from_slice(input_slice)
-         if step == 1:
-             if is_shape_unknown(data_shape):
+         if step == 1 and not const_utils.is_ascend():
+             if F.is_sequence_value_unknown(data_shape):
                  return tensor_copy_slice_from_slice(data, input_slice, value)
              start, stop, step = const_utils.normalize_slice(input_slice, data.shape[0])
              dim0_size = stop - start
@@ -958,7 +1162,7 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
              value_shape = (dim0_size,) + const_utils.tuple_slice(data.shape, 1, None)
              value = _broadcast(value_shape, value)
              return copy_slice(data, value.astype(data.dtype), (start,), (stop,), (step,))
-         if is_shape_unknown(data_shape):
+         if F.is_sequence_value_unknown(data_shape):
              const_utils.raise_unimplemented_error(
                  "Not supported to take the subscript of dynamic shape tensor slice setitem")
          indices = const_utils.slice2indices(input_slice, data_shape)
@@ -982,7 +1186,7 @@ def tensor_copy_slice_from_tuple(data, tuple_index, value):
      dim1_start, dim1_stop, _ = get_slice_stride(tuple_index[1], data_shape[1])
      if dim1_stop - dim1_start <= 0:
          return data
-     dim0_start = check_range(tuple_index[0], data_shape[0])
+     dim0_start = _scalar_to_tensor(tuple_index[0])
      dim0_stop = dim0_start + const_utils.make_tensor(1)
      start = (dim0_start, dim1_start)
      stop = (dim0_stop, dim1_stop)
@@ -994,7 +1198,7 @@ def tensor_copy_slice_from_tuple(data, tuple_index, value):
      if dim1_size > data_shape[1]:
          dim1_size = data_shape[1:2]
      value_shape = P.Concat(-1)((dim1_size, data_shape[2:]))
-     value = dynamic_broadcast_to(value, value_shape)
+     value = ops.broadcast_to(value, value_shape)
      return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)


@@ -1003,8 +1207,8 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
      op_name = const_utils.TENSOR_SETITEM
      tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)

-     if const_utils.use_copy_slice(tuple_index):
-         if is_shape_unknown(F.shape(data)):
+     if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
+         if F.is_sequence_value_unknown(F.shape(data)):
              return tensor_copy_slice_from_tuple(data, tuple_index, value)
          dim1_start, dim1_stop, _ = const_utils.normalize_slice(tuple_index[1], data.shape[1])
          if dim1_stop - dim1_start <= 0:
@@ -1024,7 +1228,6 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
      if len(tuple_index) == 1:
          data[tuple_index[0]] = value
          return data
-
      indexes_types = hyper_map(toptypeof, tuple_index)
      contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)

@@ -1058,14 +1261,20 @@ def tensor_setitem_by_number_with_sequence(data, index, value):
  def tensor_setitem_by_number_with_tensor(data, index, value):
      """Assigns the tensor by number with tensor value."""
      data_shape = F.shape(data)
-     if is_shape_unknown(data_shape):
-         index = check_range(index, F.dyn_shape(data)[0])
+     if F.is_sequence_value_unknown(data_shape):
+         index = _scalar_to_tensor(index)
          index = F.expand_dims(index, -1)
          return _tensor_setitem_by_int_tensor_with_tensor(data, index, value)

+     dim_size = data_shape[0]
+     if index < -dim_size or index >= dim_size:
+         raise IndexError(f'index {index} is out of bounds for axis 0 with size {dim_size}')
      index = const_utils.int_to_index(index, data_shape)
      value_shape = const_utils.tuple_slice(F.shape(index), None, -1)
      value = _broadcast(value_shape, value.astype(F.dtype(data)))
+     if is_parameter(data):
+         F.scatter_nd_update(data, index, value)
+         return data
      return F.tensor_scatter_update(data, index, value)
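
Illustrative note (not part of the diff): with a statically known shape, an int setitem index is now range-checked before int_to_index, mirroring NumPy's error message. A sketch of the expected behaviour:

    import mindspore as ms
    from mindspore import Tensor

    x = Tensor([1., 2., 3.], ms.float32)
    x[1] = 10.0      # in range, fine
    x[-1] = 7.0      # negative indices within [-3, 3) are still accepted
    # x[3] = 0.0     # would raise IndexError: index 3 is out of bounds for axis 0 with size 3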


@@ -1073,7 +1282,7 @@ def tensor_setitem_by_ellipsis_with_number(data, value):
      """Assigns the tensor by ellipsis with number value."""
      data_shape = F.shape(data)
      data_dtype = F.dtype(data)
-     if is_shape_unknown(data_shape):
+     if F.is_sequence_value_unknown(data_shape):
          value = F.fill(F.dtype(data), (), value)
          return tensor_setitem_by_ellipsis_with_tensor(data, value)
      return F.fill(data_dtype, data_shape, value)
@@ -1085,9 +1294,9 @@ def tensor_setitem_by_ellipsis_with_tensor(data, value):
      data_dtype = F.dtype(data)
      value = value.astype(data_dtype)

-     if is_shape_unknown(data_shape):
+     if F.is_sequence_value_unknown(data_shape):
          data_shape = F.dyn_shape(data)
-         data = dynamic_broadcast_to(value, data_shape)
+         data = ops.broadcast_to(value, data_shape)
          return data
      value_shape = F.shape(value)
      source_shape = const_utils.get_source_shape(data_shape, value_shape)
@@ -1115,9 +1324,9 @@ def tensor_setitem_by_bool(data, index, value):
      elif isinstance(value, float):
          value = const_utils.make_tensor(value, mstype.float32)

-     if is_shape_unknown(data_shape) and index:
+     if F.is_sequence_value_unknown(data_shape) and index:
          data_shape = F.dyn_shape(data)
-         data = dynamic_broadcast_to(value, data_shape)
+         data = ops.broadcast_to(value, data_shape)
          return data
      value_shape = F.shape(value)
      source_shape = const_utils.get_source_shape(data_shape, value_shape)
@@ -1143,6 +1352,8 @@ def format_list_indices(list_indices, length):
      # If eyery element in list is bool, it's treated as 1-D bool tensor.
      # If every element in list is int(not all bool), it's treated as int tensor.
      if const_utils.judge_indexes_types(indices_types, mstype.int_type + (mstype.bool_,)):
+         if not F.isconstant(length):
+             return const_utils.sequence_to_index(list_indices, None)
          return const_utils.sequence_to_index(list_indices, length)
      # If list contains other types(.../list/tuple/None), it's treated as a tuple
      return const_utils.deep_tuple(list_indices)
@@ -1162,10 +1373,34 @@ def format_tuple_indices(tuple_indices):
      return res


+ @_primexpr
+ def remove_expanded_dims_parse_bool_tensor_index(index_out, indices_out, shapes, cur_dim):
+     """ Parse bool tensor index """
+     index_out = index_out.nonzero()
+     if index_out.shape[0] == 0:
+         return None, shapes, cur_dim
+     for i in range(index_out.shape[1]):
+         out = index_out[:, i]
+         indices_out += (out,)
+         shapes.append(F.shape(out))
+         cur_dim += 1
+     return indices_out, shapes, cur_dim
+
+
+ def remove_expanded_dims_parse_tensor_index(index_out, indices_out, shapes, cur_dim):
+     """ Parse tensor index """
+     if index_out.dtype == mstype.bool_:
+         return remove_expanded_dims_parse_bool_tensor_index(index_out, indices_out, shapes, cur_dim)
+     indices_out += (index_out,)
+     shapes.append(F.shape(index_out))
+     cur_dim += 1
+     return indices_out, shapes, cur_dim
+
+
  def remove_expanded_dims(tuple_index, data_shape, value):
      """Removes expanded dimensions in tuple_index and value."""
      not_expanded_dim = ()
-     shapes = ()
+     shapes = []
      has_true = False
      has_false = False
      has_sequence = False
@@ -1192,11 +1427,12 @@ def remove_expanded_dims(tuple_index, data_shape, value):
                  idx_advanced = 0
              idx_tensor = i
              if isinstance(index_out, Tensor):
-                 if F.rank(index_out) > 0:
+                 indices_out, shapes, cur_dim = \
+                     remove_expanded_dims_parse_tensor_index(index_out, indices_out, shapes, cur_dim)
+                 if indices_out is None:
+                     return False, value, 0
+                 if index_out.dtype != mstype.bool_ and F.rank(index_out) > 0:
                      has_sequence = True
-                 indices_out += (index_out,)
-                 shapes += (F.shape(index_out),)
-                 cur_dim += 1
                  has_true = has_true or index_out is True
                  has_false = has_false or index_out is False
              else:
@@ -1229,11 +1465,21 @@ def format_index(idx, data_shape, cur_dim):
      elif isinstance(idx, int) and not isinstance(idx, bool):
          idx = const_utils.make_tensor(idx, mstype.int64, None, data_shape[cur_dim])
      elif isinstance(idx, Tensor):
-         # does not take bool tensor into account since it's currently not supported
-         idx = F.select(idx < 0, idx + data_shape[cur_dim], idx)
+         tensor_dtype = const_utils.get_index_tensor_dtype(idx.dtype)
+         if tensor_dtype == const_utils.INT_:
+             idx = F.select(idx < 0, idx + data_shape[cur_dim], idx)
+         elif tensor_dtype == const_utils.BOOL_:
+             # index with tensor(bool) type is processed in remove_expanded_dims()
+             pass
      return idx


+ @_primexpr
+ def _check_shape_mul(shape):
+     if F.shape_mul(shape) == 0:
+         raise ValueError('zero-size tensors are not supported.')
+
+
  def reduce_(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None, where=True, dtype=None):
      """
      Applies comparison based on cmp_fn and reduction based on reduce_fn.
@@ -1250,8 +1496,7 @@ def reduce_(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None,
              not isinstance(initial, (int, float, bool, Tensor))):
          const_utils.raise_type_error('initial must be scalar')

-     if F.shape_mul(shape) == 0:
-         const_utils.raise_value_error('zero-size tensors are not supported.')
+     _check_shape_mul(shape)

      if initial is not None:
          if isinstance(initial, Tensor):