mindspore-1.10.0-cp39-cp39-win_amd64.whl → mindspore-2.0.0rc1-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (966)
  1. mindspore/.commit_id +1 -1
  2. mindspore/ConcurrencyCheck.dll +0 -0
  3. mindspore/CppBuildInsights.dll +0 -0
  4. mindspore/CppCoreCheck.dll +0 -0
  5. mindspore/EnumIndex.dll +0 -0
  6. mindspore/EspXEngine.dll +0 -0
  7. mindspore/HResultCheck.dll +0 -0
  8. mindspore/KernelTraceControl.dll +0 -0
  9. mindspore/LocalESPC.dll +0 -0
  10. mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
  11. mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
  12. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  13. mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
  14. mindspore/Newtonsoft.Json.dll +0 -0
  15. mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
  16. mindspore/VariantClear.dll +0 -0
  17. mindspore/__init__.py +9 -4
  18. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  19. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  20. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  21. mindspore/_check_jit_forbidden_api.py +102 -0
  22. mindspore/_checkparam.py +1066 -1001
  23. mindspore/_extends/builtin_operations.py +32 -4
  24. mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
  25. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
  26. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
  27. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
  28. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
  29. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
  30. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  31. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
  32. mindspore/_extends/parse/__init__.py +5 -3
  33. mindspore/_extends/parse/namespace.py +17 -2
  34. mindspore/_extends/parse/parser.py +193 -34
  35. mindspore/_extends/parse/resources.py +7 -8
  36. mindspore/_extends/parse/standard_method.py +1780 -435
  37. mindspore/_extends/parse/trope.py +3 -1
  38. mindspore/amp.py +53 -58
  39. mindspore/atlprov.dll +0 -0
  40. mindspore/boost/adasum.py +3 -2
  41. mindspore/boost/boost.py +2 -2
  42. mindspore/boost/boost_cell_wrapper.py +46 -26
  43. mindspore/boost/dim_reduce.py +6 -5
  44. mindspore/boost/grad_accumulation.py +2 -1
  45. mindspore/boost/group_loss_scale_manager.py +1 -1
  46. mindspore/c1.dll +0 -0
  47. mindspore/c1xx.dll +0 -0
  48. mindspore/c2.dll +0 -0
  49. mindspore/cfgpersist.dll +0 -0
  50. mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
  51. mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
  52. mindspore/common/__init__.py +11 -10
  53. mindspore/common/_decorator.py +2 -0
  54. mindspore/common/_register_for_adapter.py +55 -0
  55. mindspore/common/_stub_tensor.py +201 -0
  56. mindspore/common/_utils.py +57 -0
  57. mindspore/common/api.py +582 -297
  58. mindspore/common/dtype.py +66 -18
  59. mindspore/common/dump.py +2 -2
  60. mindspore/common/initializer.py +38 -1
  61. mindspore/common/jit_config.py +25 -13
  62. mindspore/common/mutable.py +53 -24
  63. mindspore/common/parameter.py +60 -37
  64. mindspore/common/seed.py +8 -24
  65. mindspore/common/sparse_tensor.py +927 -0
  66. mindspore/common/tensor.py +1627 -3900
  67. mindspore/communication/__init__.py +10 -5
  68. mindspore/communication/_comm_helper.py +78 -214
  69. mindspore/communication/_hccl_management.py +2 -1
  70. mindspore/communication/management.py +136 -47
  71. mindspore/config/op_info.config +501 -1008
  72. mindspore/context.py +291 -56
  73. mindspore/d3dcompiler_47.dll +0 -0
  74. mindspore/dataset/__init__.py +12 -8
  75. mindspore/dataset/audio/__init__.py +9 -9
  76. mindspore/dataset/audio/transforms.py +1090 -228
  77. mindspore/dataset/audio/utils.py +87 -39
  78. mindspore/dataset/audio/validators.py +223 -1
  79. mindspore/dataset/callback/ds_callback.py +17 -15
  80. mindspore/dataset/core/config.py +246 -17
  81. mindspore/dataset/core/py_util_helpers.py +4 -3
  82. mindspore/dataset/core/validator_helpers.py +10 -10
  83. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  84. mindspore/dataset/debug/debug_hook.py +65 -0
  85. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  86. mindspore/dataset/engine/__init__.py +7 -3
  87. mindspore/dataset/engine/cache_client.py +9 -9
  88. mindspore/dataset/engine/datasets.py +648 -477
  89. mindspore/dataset/engine/datasets_audio.py +165 -167
  90. mindspore/dataset/engine/datasets_standard_format.py +93 -67
  91. mindspore/dataset/engine/datasets_text.py +492 -342
  92. mindspore/dataset/engine/datasets_user_defined.py +85 -50
  93. mindspore/dataset/engine/datasets_vision.py +1224 -699
  94. mindspore/dataset/engine/graphdata.py +134 -69
  95. mindspore/dataset/engine/iterators.py +50 -9
  96. mindspore/dataset/engine/offload.py +52 -31
  97. mindspore/dataset/engine/samplers.py +27 -24
  98. mindspore/dataset/engine/serializer_deserializer.py +14 -15
  99. mindspore/dataset/engine/validators.py +213 -52
  100. mindspore/dataset/text/__init__.py +10 -8
  101. mindspore/dataset/text/transforms.py +152 -57
  102. mindspore/dataset/text/utils.py +98 -49
  103. mindspore/dataset/text/validators.py +25 -0
  104. mindspore/dataset/transforms/__init__.py +4 -2
  105. mindspore/dataset/transforms/c_transforms.py +11 -13
  106. mindspore/dataset/transforms/py_transforms.py +2 -2
  107. mindspore/dataset/transforms/py_transforms_util.py +10 -0
  108. mindspore/dataset/transforms/transforms.py +13 -15
  109. mindspore/dataset/transforms/validators.py +7 -7
  110. mindspore/dataset/utils/__init__.py +2 -1
  111. mindspore/dataset/utils/browse_dataset.py +13 -13
  112. mindspore/dataset/utils/line_reader.py +121 -0
  113. mindspore/dataset/vision/__init__.py +8 -7
  114. mindspore/dataset/vision/c_transforms.py +125 -126
  115. mindspore/dataset/vision/py_transforms.py +37 -37
  116. mindspore/dataset/vision/py_transforms_util.py +23 -20
  117. mindspore/dataset/vision/transforms.py +316 -315
  118. mindspore/dataset/vision/utils.py +313 -17
  119. mindspore/dataset/vision/validators.py +6 -6
  120. mindspore/default_config.py +0 -1
  121. mindspore/dpcmi.dll +0 -0
  122. mindspore/{compression → experimental}/__init__.py +6 -5
  123. mindspore/experimental/map_parameter.py +275 -0
  124. mindspore/include/OWNERS +0 -1
  125. mindspore/include/api/callback/callback.h +9 -13
  126. mindspore/include/api/callback/ckpt_saver.h +2 -2
  127. mindspore/include/api/callback/loss_monitor.h +2 -2
  128. mindspore/include/api/callback/lr_scheduler.h +5 -5
  129. mindspore/include/api/callback/time_monitor.h +2 -2
  130. mindspore/include/api/callback/train_accuracy.h +4 -6
  131. mindspore/include/api/cfg.h +19 -6
  132. mindspore/include/api/context.h +70 -9
  133. mindspore/include/api/delegate.h +8 -1
  134. mindspore/include/api/dual_abi_helper.h +8 -24
  135. mindspore/include/api/metrics/accuracy.h +2 -2
  136. mindspore/include/api/metrics/metrics.h +4 -3
  137. mindspore/include/api/model.h +9 -4
  138. mindspore/include/api/model_group.h +68 -0
  139. mindspore/include/api/model_parallel_runner.h +17 -17
  140. mindspore/include/api/net.h +12 -11
  141. mindspore/include/api/serialization.h +20 -4
  142. mindspore/include/api/status.h +7 -1
  143. mindspore/include/api/types.h +25 -21
  144. mindspore/include/api/visible.h +4 -0
  145. mindspore/include/c_api/model_c.h +5 -0
  146. mindspore/include/c_api/status_c.h +1 -1
  147. mindspore/include/dataset/config.h +1 -1
  148. mindspore/include/dataset/constants.h +14 -0
  149. mindspore/include/dataset/text.h +59 -0
  150. mindspore/include/dataset/vision.h +56 -117
  151. mindspore/include/dataset/vision_lite.h +102 -0
  152. mindspore/jpeg62.dll +0 -0
  153. mindspore/log.py +28 -28
  154. mindspore/mindrecord/common/exceptions.py +2 -4
  155. mindspore/mindrecord/filereader.py +19 -1
  156. mindspore/mindrecord/filewriter.py +250 -88
  157. mindspore/mindrecord/mindpage.py +13 -13
  158. mindspore/mindrecord/shardheader.py +15 -15
  159. mindspore/mindrecord/shardreader.py +9 -0
  160. mindspore/mindrecord/shardwriter.py +29 -29
  161. mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
  162. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
  163. mindspore/mindrecord/tools/csv_to_mr.py +4 -4
  164. mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
  165. mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
  166. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  167. mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
  168. mindspore/mindspore_common.dll +0 -0
  169. mindspore/mindspore_core.dll +0 -0
  170. mindspore/mindspore_glog.dll +0 -0
  171. mindspore/mindspore_shared_lib.dll +0 -0
  172. mindspore/msobj140.dll +0 -0
  173. mindspore/mspdb140.dll +0 -0
  174. mindspore/mspdbcore.dll +0 -0
  175. mindspore/mspdbst.dll +0 -0
  176. mindspore/mspft140.dll +0 -0
  177. mindspore/msvcdis140.dll +0 -0
  178. mindspore/msvcp140_1.dll +0 -0
  179. mindspore/msvcp140_2.dll +0 -0
  180. mindspore/msvcp140_atomic_wait.dll +0 -0
  181. mindspore/msvcp140_codecvt_ids.dll +0 -0
  182. mindspore/nn/__init__.py +1 -5
  183. mindspore/nn/cell.py +297 -234
  184. mindspore/nn/dynamic_lr.py +1 -1
  185. mindspore/nn/grad/cell_grad.py +17 -42
  186. mindspore/nn/layer/__init__.py +7 -4
  187. mindspore/nn/layer/activation.py +131 -88
  188. mindspore/nn/layer/basic.py +313 -613
  189. mindspore/nn/layer/channel_shuffle.py +103 -0
  190. mindspore/nn/layer/combined.py +1 -1
  191. mindspore/nn/layer/container.py +52 -6
  192. mindspore/nn/layer/conv.py +112 -43
  193. mindspore/nn/layer/dense.py +10 -9
  194. mindspore/nn/layer/embedding.py +36 -34
  195. mindspore/nn/layer/image.py +123 -27
  196. mindspore/nn/layer/math.py +108 -107
  197. mindspore/nn/layer/normalization.py +212 -366
  198. mindspore/nn/layer/padding.py +370 -42
  199. mindspore/nn/layer/pooling.py +1443 -219
  200. mindspore/nn/layer/rnn_cells.py +11 -16
  201. mindspore/nn/layer/rnns.py +38 -39
  202. mindspore/nn/layer/thor_layer.py +24 -25
  203. mindspore/nn/layer/timedistributed.py +5 -5
  204. mindspore/nn/layer/transformer.py +701 -0
  205. mindspore/nn/learning_rate_schedule.py +8 -8
  206. mindspore/nn/loss/__init__.py +9 -6
  207. mindspore/nn/loss/loss.py +678 -142
  208. mindspore/nn/metrics.py +53 -0
  209. mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
  210. mindspore/nn/optim/ada_grad.py +8 -8
  211. mindspore/nn/optim/adadelta.py +2 -3
  212. mindspore/nn/optim/adafactor.py +18 -14
  213. mindspore/nn/optim/adam.py +429 -87
  214. mindspore/nn/optim/adamax.py +5 -6
  215. mindspore/nn/optim/adasum.py +10 -8
  216. mindspore/nn/optim/asgd.py +7 -7
  217. mindspore/nn/optim/ftrl.py +81 -11
  218. mindspore/nn/optim/lamb.py +7 -8
  219. mindspore/nn/optim/lars.py +4 -4
  220. mindspore/nn/optim/lazyadam.py +82 -7
  221. mindspore/nn/optim/momentum.py +8 -7
  222. mindspore/nn/optim/optimizer.py +19 -10
  223. mindspore/nn/optim/proximal_ada_grad.py +6 -5
  224. mindspore/nn/optim/rmsprop.py +3 -3
  225. mindspore/nn/optim/rprop.py +20 -16
  226. mindspore/nn/optim/sgd.py +21 -15
  227. mindspore/nn/optim/thor.py +23 -21
  228. mindspore/nn/probability/__init__.py +0 -2
  229. mindspore/nn/probability/bijector/bijector.py +7 -6
  230. mindspore/nn/probability/bijector/invert.py +4 -2
  231. mindspore/nn/probability/bijector/softplus.py +2 -2
  232. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  233. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  234. mindspore/nn/probability/distribution/__init__.py +6 -0
  235. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
  236. mindspore/nn/probability/distribution/_utils/utils.py +11 -17
  237. mindspore/nn/probability/distribution/bernoulli.py +6 -6
  238. mindspore/nn/probability/distribution/beta.py +1 -1
  239. mindspore/nn/probability/distribution/categorical.py +9 -9
  240. mindspore/nn/probability/distribution/cauchy.py +8 -8
  241. mindspore/nn/probability/distribution/distribution.py +12 -6
  242. mindspore/nn/probability/distribution/exponential.py +5 -5
  243. mindspore/nn/probability/distribution/gamma.py +3 -3
  244. mindspore/nn/probability/distribution/geometric.py +6 -5
  245. mindspore/nn/probability/distribution/gumbel.py +5 -5
  246. mindspore/nn/probability/distribution/half_normal.py +133 -0
  247. mindspore/nn/probability/distribution/laplace.py +128 -0
  248. mindspore/nn/probability/distribution/log_normal.py +0 -1
  249. mindspore/nn/probability/distribution/logistic.py +4 -5
  250. mindspore/nn/probability/distribution/normal.py +11 -15
  251. mindspore/nn/probability/distribution/poisson.py +6 -2
  252. mindspore/nn/probability/distribution/student_t.py +150 -0
  253. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  254. mindspore/nn/probability/distribution/uniform.py +5 -5
  255. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  256. mindspore/nn/reinforcement/tensor_array.py +2 -2
  257. mindspore/nn/sparse/sparse.py +8 -1
  258. mindspore/nn/wrap/cell_wrapper.py +55 -27
  259. mindspore/nn/wrap/grad_reducer.py +20 -11
  260. mindspore/nn/wrap/loss_scale.py +47 -30
  261. mindspore/numpy/array_creations.py +33 -22
  262. mindspore/numpy/array_ops.py +46 -42
  263. mindspore/numpy/logic_ops.py +6 -27
  264. mindspore/numpy/math_ops.py +26 -19
  265. mindspore/numpy/utils.py +1 -8
  266. mindspore/numpy/utils_const.py +112 -62
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/opencv_imgproc452.dll +0 -0
  270. mindspore/ops/__init__.py +6 -3
  271. mindspore/ops/_constants.py +0 -6
  272. mindspore/ops/_grad/__init__.py +2 -1
  273. mindspore/ops/_grad/grad_array_ops.py +209 -152
  274. mindspore/ops/_grad/grad_base.py +55 -17
  275. mindspore/ops/_grad/grad_clip_ops.py +11 -3
  276. mindspore/ops/_grad/grad_comm_ops.py +58 -47
  277. mindspore/ops/_grad/grad_implementations.py +21 -61
  278. mindspore/ops/_grad/grad_inner_ops.py +48 -6
  279. mindspore/ops/_grad/grad_math_ops.py +306 -161
  280. mindspore/ops/_grad/grad_nn_ops.py +192 -181
  281. mindspore/ops/_grad/grad_other_ops.py +1 -1
  282. mindspore/ops/_grad/grad_quant_ops.py +5 -5
  283. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  284. mindspore/ops/_grad/grad_sparse.py +15 -9
  285. mindspore/ops/_grad_experimental/__init__.py +1 -0
  286. mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
  287. mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
  288. mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
  289. mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
  290. mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
  291. mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
  292. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  293. mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
  294. mindspore/ops/_op_impl/__init__.py +3 -3
  295. mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
  296. mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
  297. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  298. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
  299. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  300. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  301. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
  302. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  303. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  304. mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
  305. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  306. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
  307. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  308. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  309. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  310. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  311. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  312. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  313. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  314. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  315. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  316. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  317. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  318. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  319. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  320. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  321. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  322. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  323. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  325. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
  326. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
  327. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  328. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  329. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  330. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  331. mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
  332. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  333. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  334. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  335. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  336. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  337. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  338. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  339. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  340. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  341. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  342. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  343. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  344. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  345. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  346. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  347. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  348. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  349. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  350. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  351. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  352. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
  353. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  354. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
  355. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  356. mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
  357. mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
  358. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  359. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  360. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  361. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  362. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  363. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  364. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  365. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
  366. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  367. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  368. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  369. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  370. mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
  371. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  372. mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
  373. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  374. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  375. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  376. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  377. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  378. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  379. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  380. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  381. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  382. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  383. mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
  384. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  385. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  386. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  387. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  388. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  389. mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
  390. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  391. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  392. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  393. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  394. mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
  395. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  396. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  397. mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
  398. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  399. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  400. mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
  401. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  402. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  403. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  404. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  405. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  406. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  407. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  408. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  409. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  410. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  411. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  412. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  413. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  414. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  415. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  416. mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
  417. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  418. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  419. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  420. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  421. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  422. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  423. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  424. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  425. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  426. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  427. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  428. mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
  429. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  430. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  431. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  432. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  433. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  434. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  435. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  436. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  437. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  438. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  439. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  440. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  441. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  442. mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
  443. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  444. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  445. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  446. mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
  447. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
  448. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
  449. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  450. mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
  451. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  452. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  453. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  454. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  455. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  456. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  457. mindspore/ops/_op_impl/cpu/__init__.py +1 -2
  458. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  459. mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
  460. mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
  461. mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
  462. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  463. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  464. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  465. mindspore/ops/_op_impl/tbe/__init__.py +27 -608
  466. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
  467. mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
  468. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  469. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  470. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  471. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
  472. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  473. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  474. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
  475. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
  476. mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
  477. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  478. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
  479. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  480. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  481. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  482. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  483. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  484. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
  485. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
  486. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  487. mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
  488. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
  489. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  490. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  491. mindspore/ops/_op_impl/tbe/greater.py +2 -0
  492. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  493. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
  494. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  495. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  496. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
  497. mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
  498. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
  499. mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
  500. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
  501. mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
  502. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
  503. mindspore/ops/_op_impl/tbe/slice.py +26 -15
  504. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  505. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  506. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
  507. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  508. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
  509. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
  510. mindspore/ops/_primitive_cache.py +3 -2
  511. mindspore/ops/_register_for_op.py +11 -0
  512. mindspore/ops/_utils/__init__.py +1 -1
  513. mindspore/ops/_utils/utils.py +20 -41
  514. mindspore/ops/_vmap/__init__.py +2 -2
  515. mindspore/ops/_vmap/vmap_array_ops.py +170 -78
  516. mindspore/ops/_vmap/vmap_base.py +24 -10
  517. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  518. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  519. mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
  520. mindspore/ops/_vmap/vmap_image_ops.py +52 -0
  521. mindspore/ops/_vmap/vmap_math_ops.py +77 -6
  522. mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
  523. mindspore/ops/_vmap/vmap_other_ops.py +3 -1
  524. mindspore/ops/_vmap/vmap_random_ops.py +55 -3
  525. mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
  526. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  527. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  528. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
  529. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
  530. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
  531. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
  532. mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
  533. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  534. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  535. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
  537. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
  539. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  541. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
  542. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
  543. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  546. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  547. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  548. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  549. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  550. mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
  551. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  552. mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
  553. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
  554. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  555. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
  556. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
  559. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
  560. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  561. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
  565. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  566. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  567. mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
  568. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
  569. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  570. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
  571. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
  572. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
  573. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
  574. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  575. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
  576. mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
  577. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  578. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  579. mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
  580. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
  582. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
  583. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
  584. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  585. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  586. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  587. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  588. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  589. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
  590. mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
  591. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
  592. mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
  593. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  594. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
  595. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
  596. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
  597. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  598. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  599. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  600. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  601. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  602. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  603. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  604. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  605. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  606. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  607. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  608. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
  609. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
  610. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
  611. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
  612. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  613. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  614. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  615. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  616. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  617. mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
  618. mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
  619. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  620. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  621. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
  622. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
  623. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
  624. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
  625. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  626. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
  627. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
  628. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
  629. mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
  630. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  631. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  632. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
  633. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
  634. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
  635. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  636. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  637. mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
  638. mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
  639. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  640. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  642. mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
  643. mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
  644. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  645. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  646. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  647. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  648. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  649. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
  650. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
  651. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  652. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  653. mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
  654. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
  655. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
  656. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
  657. mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
  658. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  659. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  660. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
  661. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
  662. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
  663. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  664. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  665. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
  666. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
  667. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
  668. mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
  669. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
  670. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  671. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  672. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
  673. mindspore/ops/bprop_mindir/__init__.py +1 -4
  674. mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
  675. mindspore/ops/composite/__init__.py +12 -13
  676. mindspore/ops/composite/base.py +261 -254
  677. mindspore/ops/composite/env_ops.py +41 -0
  678. mindspore/ops/composite/math_ops.py +197 -156
  679. mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
  680. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
  681. mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
  682. mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
  683. mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
  684. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
  685. mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
  686. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  687. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  688. mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
  689. mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
  690. mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
  691. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
  692. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  693. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
  694. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
  695. mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
  696. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  697. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  698. mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
  699. mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
  700. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
  701. mindspore/ops/function/__init__.py +323 -8
  702. mindspore/ops/function/array_func.py +3511 -780
  703. mindspore/ops/function/clip_func.py +329 -0
  704. mindspore/ops/function/debug_func.py +6 -6
  705. mindspore/ops/function/grad/__init__.py +5 -1
  706. mindspore/ops/function/grad/grad_func.py +736 -65
  707. mindspore/ops/function/image_func.py +270 -0
  708. mindspore/ops/function/linalg_func.py +268 -8
  709. mindspore/ops/function/math_func.py +8032 -3164
  710. mindspore/ops/function/nn_func.py +5619 -1855
  711. mindspore/ops/function/other_func.py +115 -0
  712. mindspore/ops/function/parameter_func.py +11 -10
  713. mindspore/ops/function/random_func.py +939 -77
  714. mindspore/ops/function/sparse_func.py +249 -84
  715. mindspore/ops/function/sparse_unary_func.py +2303 -0
  716. mindspore/ops/function/spectral_func.py +146 -0
  717. mindspore/ops/function/vmap_func.py +114 -0
  718. mindspore/ops/functional.py +182 -254
  719. mindspore/ops/op_info_register.py +79 -34
  720. mindspore/ops/operations/__init__.py +210 -118
  721. mindspore/ops/operations/_csr_ops.py +7 -7
  722. mindspore/ops/operations/_embedding_cache_ops.py +25 -15
  723. mindspore/ops/operations/_grad_ops.py +447 -322
  724. mindspore/ops/operations/_inner_ops.py +547 -176
  725. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  726. mindspore/ops/operations/_ms_kernel.py +29 -27
  727. mindspore/ops/operations/_ocr_ops.py +11 -11
  728. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  729. mindspore/ops/operations/_quant_ops.py +186 -101
  730. mindspore/ops/operations/_rl_inner_ops.py +122 -61
  731. mindspore/ops/operations/_scalar_ops.py +466 -0
  732. mindspore/ops/operations/_sequence_ops.py +1047 -0
  733. mindspore/ops/operations/_tensor_array.py +10 -11
  734. mindspore/ops/operations/_thor_ops.py +4 -4
  735. mindspore/ops/operations/array_ops.py +1428 -1226
  736. mindspore/ops/operations/comm_ops.py +180 -117
  737. mindspore/ops/operations/control_ops.py +4 -2
  738. mindspore/ops/operations/custom_ops.py +185 -98
  739. mindspore/ops/operations/debug_ops.py +92 -54
  740. mindspore/ops/operations/image_ops.py +406 -211
  741. mindspore/ops/operations/inner_ops.py +42 -53
  742. mindspore/ops/operations/linalg_ops.py +32 -29
  743. mindspore/ops/operations/math_ops.py +2076 -897
  744. mindspore/ops/operations/nn_ops.py +1282 -1252
  745. mindspore/ops/operations/other_ops.py +124 -278
  746. mindspore/ops/operations/random_ops.py +345 -178
  747. mindspore/ops/operations/rl_ops.py +8 -9
  748. mindspore/ops/operations/sparse_ops.py +502 -157
  749. mindspore/ops/operations/spectral_ops.py +107 -0
  750. mindspore/ops/primitive.py +192 -15
  751. mindspore/ops/vm_impl_registry.py +23 -2
  752. mindspore/parallel/__init__.py +6 -1
  753. mindspore/parallel/_auto_parallel_context.py +199 -92
  754. mindspore/parallel/_cell_wrapper.py +4 -2
  755. mindspore/parallel/_cost_model_context.py +3 -0
  756. mindspore/parallel/_dp_allreduce_fusion.py +2 -1
  757. mindspore/parallel/_offload_context.py +185 -0
  758. mindspore/parallel/_parallel_serialization.py +167 -28
  759. mindspore/parallel/_ps_context.py +9 -5
  760. mindspore/parallel/_recovery_context.py +1 -1
  761. mindspore/parallel/_tensor.py +9 -1
  762. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  763. mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
  764. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  765. mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
  766. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  767. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
  768. mindspore/parallel/_utils.py +47 -7
  769. mindspore/parallel/algo_parameter_config.py +5 -1
  770. mindspore/parallel/checkpoint_transform.py +329 -0
  771. mindspore/parallel/shard.py +229 -0
  772. mindspore/perf_msvcbuildinsights.dll +0 -0
  773. mindspore/pgodb140.dll +0 -0
  774. mindspore/pgort140.dll +0 -0
  775. mindspore/profiler/__init__.py +2 -1
  776. mindspore/profiler/common/util.py +4 -3
  777. mindspore/profiler/common/validator/validate_path.py +2 -2
  778. mindspore/profiler/envprofiling.py +249 -0
  779. mindspore/profiler/parser/aicpu_data_parser.py +38 -39
  780. mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
  781. mindspore/profiler/parser/base_timeline_generator.py +471 -0
  782. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
  783. mindspore/profiler/parser/framework_parser.py +42 -16
  784. mindspore/profiler/parser/hccl_parser.py +158 -158
  785. mindspore/profiler/parser/hwts_log_parser.py +7 -6
  786. mindspore/profiler/parser/integrator.py +18 -1579
  787. mindspore/profiler/parser/minddata_analyzer.py +8 -8
  788. mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
  789. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  790. mindspore/profiler/parser/optime_parser.py +17 -18
  791. mindspore/profiler/parser/profiler_info.py +108 -0
  792. mindspore/profiler/parser/step_trace_parser.py +1 -1
  793. mindspore/profiler/profiling.py +396 -194
  794. mindspore/rewrite/__init__.py +6 -2
  795. mindspore/rewrite/api/node.py +51 -110
  796. mindspore/rewrite/api/node_type.py +10 -6
  797. mindspore/rewrite/api/pattern_engine.py +51 -7
  798. mindspore/rewrite/api/scoped_value.py +64 -53
  799. mindspore/rewrite/api/symbol_tree.py +108 -61
  800. mindspore/rewrite/api/tree_node_helper.py +2 -3
  801. mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
  802. mindspore/rewrite/ast_helpers/__init__.py +6 -3
  803. mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
  804. mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
  805. mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
  806. mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
  807. mindspore/rewrite/ast_transformers/__init__.py +0 -1
  808. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
  809. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
  810. mindspore/rewrite/common/__init__.py +2 -0
  811. mindspore/rewrite/common/event.py +1 -1
  812. mindspore/rewrite/common/observable.py +1 -1
  813. mindspore/rewrite/common/observer.py +1 -1
  814. mindspore/rewrite/common/rewrite_elog.py +35 -0
  815. mindspore/rewrite/namer.py +2 -2
  816. mindspore/rewrite/namespace.py +14 -4
  817. mindspore/rewrite/node.py +161 -13
  818. mindspore/rewrite/parser.py +0 -1
  819. mindspore/rewrite/parser_register.py +0 -1
  820. mindspore/rewrite/parsers/arguments_parser.py +3 -2
  821. mindspore/rewrite/parsers/assign_parser.py +267 -67
  822. mindspore/rewrite/parsers/attribute_parser.py +56 -0
  823. mindspore/rewrite/parsers/class_def_parser.py +191 -108
  824. mindspore/rewrite/parsers/constant_parser.py +101 -0
  825. mindspore/rewrite/parsers/container_parser.py +88 -0
  826. mindspore/rewrite/parsers/for_parser.py +28 -15
  827. mindspore/rewrite/parsers/function_def_parser.py +21 -5
  828. mindspore/rewrite/parsers/if_parser.py +11 -28
  829. mindspore/rewrite/parsers/module_parser.py +9 -6
  830. mindspore/rewrite/parsers/return_parser.py +3 -2
  831. mindspore/rewrite/sparsify/__init__.py +0 -0
  832. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  833. mindspore/rewrite/sparsify/sparsify.py +109 -0
  834. mindspore/rewrite/sparsify/utils.py +173 -0
  835. mindspore/rewrite/symbol_tree.py +322 -109
  836. mindspore/rewrite/symbol_tree_builder.py +45 -8
  837. mindspore/rewrite/symbol_tree_dumper.py +0 -1
  838. mindspore/rewrite/topological_manager.py +1 -2
  839. mindspore/run_check/_check_version.py +209 -112
  840. mindspore/run_check/run_check.py +2 -1
  841. mindspore/tbbmalloc.dll +0 -0
  842. mindspore/tinyxml2.dll +0 -0
  843. mindspore/train/__init__.py +6 -4
  844. mindspore/train/_utils.py +28 -5
  845. mindspore/train/amp.py +321 -50
  846. mindspore/train/callback/__init__.py +3 -1
  847. mindspore/train/callback/_backup_and_restore.py +120 -0
  848. mindspore/train/callback/_callback.py +8 -8
  849. mindspore/train/callback/_checkpoint.py +12 -9
  850. mindspore/train/callback/_early_stop.py +13 -7
  851. mindspore/train/callback/_history.py +8 -8
  852. mindspore/train/callback/_lambda_callback.py +6 -6
  853. mindspore/train/callback/_landscape.py +36 -38
  854. mindspore/train/callback/_loss_monitor.py +12 -6
  855. mindspore/train/callback/_lr_scheduler_callback.py +2 -4
  856. mindspore/train/callback/_on_request_exit.py +212 -0
  857. mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
  858. mindspore/train/callback/_summary_collector.py +27 -19
  859. mindspore/train/callback/_time_monitor.py +13 -7
  860. mindspore/train/checkpoint_pb2.py +68 -8
  861. mindspore/train/data_sink.py +122 -33
  862. mindspore/train/dataset_helper.py +28 -87
  863. mindspore/train/loss_scale_manager.py +4 -7
  864. mindspore/{nn → train}/metrics/__init__.py +20 -20
  865. mindspore/{nn → train}/metrics/accuracy.py +12 -10
  866. mindspore/{nn → train}/metrics/auc.py +4 -4
  867. mindspore/{nn → train}/metrics/bleu_score.py +4 -4
  868. mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
  869. mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
  870. mindspore/{nn → train}/metrics/dice.py +6 -5
  871. mindspore/{nn → train}/metrics/error.py +7 -5
  872. mindspore/{nn → train}/metrics/fbeta.py +9 -7
  873. mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
  874. mindspore/{nn → train}/metrics/loss.py +4 -3
  875. mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
  876. mindspore/{nn → train}/metrics/metric.py +6 -5
  877. mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
  878. mindspore/{nn → train}/metrics/perplexity.py +5 -4
  879. mindspore/{nn → train}/metrics/precision.py +5 -4
  880. mindspore/{nn → train}/metrics/recall.py +5 -4
  881. mindspore/{nn → train}/metrics/roc.py +7 -6
  882. mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
  883. mindspore/{nn → train}/metrics/topk.py +7 -5
  884. mindspore/train/mind_ir_pb2.py +339 -32
  885. mindspore/train/model.py +113 -84
  886. mindspore/train/serialization.py +547 -167
  887. mindspore/train/summary/_summary_adapter.py +1 -1
  888. mindspore/train/summary/summary_record.py +43 -12
  889. mindspore/train/train_thor/convert_utils.py +7 -1
  890. mindspore/train/train_thor/dataset_helper.py +3 -3
  891. mindspore/train/train_thor/model_thor.py +0 -4
  892. mindspore/turbojpeg.dll +0 -0
  893. mindspore/vcmeta.dll +0 -0
  894. mindspore/vcruntime140.dll +0 -0
  895. mindspore/vcruntime140_1.dll +0 -0
  896. mindspore/version.py +1 -1
  897. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
  898. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
  899. mindspore/compression/common/constant.py +0 -124
  900. mindspore/compression/export/__init__.py +0 -19
  901. mindspore/compression/export/quant_export.py +0 -514
  902. mindspore/compression/quant/qat.py +0 -636
  903. mindspore/compression/quant/quant_utils.py +0 -462
  904. mindspore/compression/quant/quantizer.py +0 -68
  905. mindspore/libatomic-1.dll +0 -0
  906. mindspore/libgcc_s_seh-1.dll +0 -0
  907. mindspore/libgfortran-4.dll +0 -0
  908. mindspore/libgomp-1.dll +0 -0
  909. mindspore/libjpeg-62.dll +0 -0
  910. mindspore/libmindspore.dll +0 -0
  911. mindspore/libmindspore_common.dll +0 -0
  912. mindspore/libmindspore_core.dll +0 -0
  913. mindspore/libmindspore_glog.dll +0 -0
  914. mindspore/libnnacl.dll +0 -0
  915. mindspore/libopencv_core452.dll +0 -0
  916. mindspore/libopencv_imgcodecs452.dll +0 -0
  917. mindspore/libopencv_imgproc452.dll +0 -0
  918. mindspore/libquadmath-0.dll +0 -0
  919. mindspore/libsqlite3.dll +0 -0
  920. mindspore/libssp-0.dll +0 -0
  921. mindspore/libstdc++-6.dll +0 -0
  922. mindspore/libtinyxml2.dll +0 -0
  923. mindspore/libturbojpeg.dll +0 -0
  924. mindspore/libwinpthread-1.dll +0 -0
  925. mindspore/nn/layer/quant.py +0 -1868
  926. mindspore/nn/layer/rnn_utils.py +0 -90
  927. mindspore/nn/probability/dpn/__init__.py +0 -22
  928. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  929. mindspore/nn/probability/dpn/vae/cvae.py +0 -138
  930. mindspore/nn/probability/dpn/vae/vae.py +0 -122
  931. mindspore/nn/probability/infer/__init__.py +0 -22
  932. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  933. mindspore/nn/probability/infer/variational/svi.py +0 -84
  934. mindspore/nn/probability/toolbox/__init__.py +0 -22
  935. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  936. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
  937. mindspore/nn/probability/transforms/__init__.py +0 -22
  938. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  939. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  940. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  941. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  942. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  943. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  944. mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
  945. mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
  946. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
  947. mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
  948. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
  949. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
  950. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
  951. mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
  952. mindspore/ops/composite/array_ops.py +0 -210
  953. mindspore/ops/composite/clip_ops.py +0 -238
  954. mindspore/ops/composite/random_ops.py +0 -426
  955. mindspore/ops/composite/vmap_ops.py +0 -38
  956. mindspore/ops/operations/sponge_ops.py +0 -3531
  957. mindspore/ops/operations/sponge_update_ops.py +0 -2546
  958. mindspore/parallel/nn/__init__.py +0 -42
  959. mindspore/parallel/nn/loss.py +0 -22
  960. mindspore/parallel/nn/moe.py +0 -21
  961. mindspore/parallel/nn/op_parallel_config.py +0 -22
  962. mindspore/parallel/nn/transformer.py +0 -31
  963. mindspore/run_check/_check_deps_version.py +0 -84
  964. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  965. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  966. {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
1
+ # Copyright 2020-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,24 +15,32 @@
15
15
 
16
16
  """constexpr util"""
17
17
  from __future__ import absolute_import
18
+ from enum import IntEnum
19
+
18
20
 
19
21
  from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
20
22
  from mindspore.ops import functional as F
21
- from mindspore.ops import operations as P
23
+ from mindspore.ops import operations as P
22
24
  from mindspore.ops.composite import base
23
25
  from mindspore.ops._primitive_cache import _get_cache_prim
24
- from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem, DynamicBroadcastTo, TopTypeof
26
+ from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem, \
27
+ TopTypeof, issubclass_, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo
25
28
  from mindspore.common import dtype as mstype
26
29
  from mindspore.common._register_for_tensor import tensor_operator_registry
27
- from mindspore.common.tensor import Tensor, CSRTensor, COOTensor
28
- from mindspore.common._utils import is_shape_unknown
30
+ from mindspore.common.initializer import Zero
31
+ from mindspore.common import Tensor, CSRTensor, COOTensor
32
+ from mindspore.common import mutable
33
+ from mindspore import ops
34
+ from mindspore.ops.primitive import _primexpr
29
35
 
30
36
  slice_get_item = SliceGetItem()
31
37
  hyper_map = base.HyperMap()
32
38
  stack = P.Stack(axis=-1)
33
39
  copy_slice = TensorCopySlices()
34
- dynamic_broadcast_to = DynamicBroadcastTo()
35
40
  toptypeof = TopTypeof()
41
+ is_parameter = IsParameter()
42
+ getitem_tensor_index_info = GetitemTensorIndexInfo(const_utils.is_ascend())
43
+ setitem_tensor_index_info = SetitemTensorIndexInfo(const_utils.is_ascend())
36
44
 
37
45
 
38
46
  def strided_slice(data, begin_strides, end_strides, step_strides, begin_mask=0, end_mask=0, ellipsis_mask=0,
@@ -43,50 +51,138 @@ def strided_slice(data, begin_strides, end_strides, step_strides, begin_mask=0,
43
51
  return strided_slice_(data, begin_strides, end_strides, step_strides)
44
52
 
45
53
 
54
+ class ValueTransferType(IntEnum):
55
+ """Transfer op types of handling tensor getitem/setitem"""
56
+ kUnknown = 0
57
+ kTensorScatterUpdate = 1
58
+ kExpandDims = 2
59
+ kBroadCast = 3
60
+ kCast = 4
61
+ kSelect = 5
62
+ kGather = 6
63
+ kStrideSlice = 7
64
+ kStrideSliceWithMask = 8
65
+ kGatherND = 9
66
+ kScatterNdUpdate = 10
67
+ kReshape = 11
68
+ kScatterND = 12
69
+ kNumberToTensor = 13
70
+ kHandleSequenceValue = 14
71
+ kByPass = 15
72
+ kReSetItemByIndex = 16
73
+ kCopySlice = 17
74
+ kSetItemByBool = 18
75
+ kEmptyTensor = 19
76
+ kSetItemByEllipsis = 20
77
+ kRaiseIndexError = 21
78
+
79
+
80
+ def data_update(transfer_types, args, data, new_index, value=None):
81
+ """
82
+ We finally generate a new tensor when handling tensor getitem/setitem
83
+ by transfer data and value with index.
84
+ """
85
+ for transfer_type, arg in zip(transfer_types, args):
86
+ if transfer_type == ValueTransferType.kUnknown:
87
+ raise IndexError(f"Inlvaid transfer type {transfer_type}.")
88
+ if transfer_type <= ValueTransferType.kScatterND:
89
+ data = data_update_by_ops(transfer_type, arg, data, new_index, value)
90
+ if transfer_type == ValueTransferType.kSetItemByBool:
91
+ return tensor_setitem_by_bool(data, new_index, value)
92
+ if transfer_type == ValueTransferType.kCopySlice:
93
+ return copy_slice(data, value.astype(data.dtype), arg[0], arg[1], arg[2])
94
+ if transfer_type == ValueTransferType.kSetItemByEllipsis:
95
+ return tensor_setitem_by_ellipsis(data, new_index, value)
96
+ if transfer_type == ValueTransferType.kReSetItemByIndex:
97
+ data[new_index] = value
98
+ return data
99
+ if transfer_type == ValueTransferType.kEmptyTensor:
100
+ return handle_empty_tensor(arg, data)
101
+ if transfer_type == ValueTransferType.kRaiseIndexError:
102
+ raise IndexError(
103
+ f'index {arg[0]} is out of bounds for dimension with size {arg[1]}')
104
+ return data
105
+
106
+
107
+ def data_update_by_ops(transfer_type, arg, data, new_index, value=None):
108
+ """
109
+ Generate a new tensor when handling tensor getitem/setitem
110
+ by ops.
111
+ """
112
+ if transfer_type == ValueTransferType.kStrideSliceWithMask:
113
+ stride_info, mask_index = arg[0], arg[1]
114
+ data = strided_slice(data, stride_info[0], stride_info[1], stride_info[2],
115
+ mask_index[0], mask_index[1], 0, 0, mask_index[2])
116
+ elif transfer_type == ValueTransferType.kGatherND:
117
+ if isinstance(new_index, list):
118
+ new_index = handle_multi_dim_index_tensor(new_index, arg)
119
+ data = F.gather_nd(data, Tensor(new_index))
120
+ elif transfer_type == ValueTransferType.kTensorScatterUpdate:
121
+ if isinstance(new_index, list):
122
+ new_index = handle_multi_dim_index_tensor(new_index, arg)
123
+ data = F.tensor_scatter_update(data, new_index, value)
124
+ elif transfer_type == ValueTransferType.kScatterNdUpdate:
125
+ F.scatter_nd_update(data, new_index, value)
126
+ elif transfer_type == ValueTransferType.kSelect:
127
+ data = F.select(Tensor(new_index), value, data)
128
+ elif transfer_type == ValueTransferType.kReshape:
129
+ data = F.reshape(data, arg)
130
+ elif transfer_type == ValueTransferType.kGather:
131
+ data = F.gather(data, new_index, 0)
132
+ elif transfer_type == ValueTransferType.kExpandDims:
133
+ data = F.expand_dims(data, 0)
134
+ elif transfer_type == ValueTransferType.kStrideSlice:
135
+ data = F.strided_slice(data, arg[0], arg[1], arg[2])
136
+ else:
137
+ raise IndexError(f"Inlvaid transfer type {transfer_type}.")
138
+ return data
139
+
140
+
141
+ def value_update(transfer_types, args, data, value):
142
+ """Transfer value before set value to tensor when handling tensor setitem"""
143
+ for transfer_type, arg in zip(transfer_types, args):
144
+ if transfer_type == ValueTransferType.kByPass:
145
+ continue
146
+ if transfer_type == ValueTransferType.kNumberToTensor:
147
+ value = F.fill(F.dtype(data), (), value)
148
+ elif transfer_type == ValueTransferType.kHandleSequenceValue:
149
+ op_type, index = arg
150
+ if op_type == const_utils.SET_ITEM_BY_ONE_TENSOR:
151
+ index = Tensor(index)
152
+ value = _generate_updates_from_sequence(
153
+ data, index, value, op_type)
154
+ elif transfer_type == ValueTransferType.kExpandDims:
155
+ value = F.expand_dims(value, arg)
156
+ elif transfer_type == ValueTransferType.kBroadCast:
157
+ value = _broadcast(arg, value.astype(F.dtype(data)))
158
+ elif transfer_type == ValueTransferType.kCast:
159
+ value = F.cast(value, F.dtype(data))
160
+ elif transfer_type == ValueTransferType.kReshape:
161
+ value = F.reshape(value, arg)
162
+ elif transfer_type == ValueTransferType.kScatterND:
163
+ value = F.scatter_nd(arg[0], value, arg[1])
164
+ else:
165
+ raise IndexError(f"Inlvaid transfer type {transfer_type}.")
166
+ return value
167
+
168
+
46
169
  def _tensor_getitem(self, index):
47
170
  """Handle tensor getitem"""
48
- if isinstance(index, Tensor):
49
- return tensor_index_by_tensor(self, index)
50
- if isinstance(index, list):
51
- return tensor_index_by_list(self, index)
52
- if isinstance(index, tuple):
53
- return tensor_index_by_tuple(self, index)
54
- if isinstance(index, bool):
55
- return _tensor_index_by_bool(self, index)
56
- if isinstance(index, int):
57
- return _tensor_index_by_integer(self, index)
58
- if isinstance(index, slice):
59
- return tensor_index_by_slice(self, index)
60
- if index is None:
61
- return F.expand_dims(self, 0)
62
- if index is ...:
63
- return self
64
- raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor with int, "
65
- f"list and tuple ,but got {index} with type {type(index)}.")
171
+ new_index, tensor_update_types, tensor_update_args = getitem_tensor_index_info(
172
+ self, index)
173
+ return data_update(tensor_update_types, tensor_update_args, self, new_index)
66
174
 
67
175
 
68
176
  def _tensor_setitem(self, index, value):
69
177
  """Handle tensor setitem"""
70
- if not isinstance(value, (int, float, bool, list, tuple, Tensor)):
71
- raise ValueError(f"only support numbers, Tensor, tuple, list as value,"
72
- f"but got {value} with type {type(value)}.")
73
- if isinstance(index, list):
74
- index = format_list_indices(index, F.shape(self)[0])
75
- if isinstance(index, Tensor):
76
- return tensor_setitem_by_tensor(self, index, value)
77
- if isinstance(index, tuple):
78
- return tensor_setitem_by_tuple(self, index, value)
79
- if isinstance(index, bool):
80
- return tensor_setitem_by_bool(self, index, value)
81
- if isinstance(index, int):
82
- return tensor_setitem_by_number(self, index, value)
83
- if isinstance(index, slice):
84
- return tensor_setitem_by_slice(self, index, value)
85
- if index in (None, ...):
86
- return tensor_setitem_by_ellipsis(self, index, value)
87
-
88
- raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), bool, tensor, \
89
- list and tuple, but got {index} with type{type(index)}")
178
+ setitem_info = setitem_tensor_index_info(self, index, value)
179
+ new_index = setitem_info[0]
180
+ v_transfer_types = setitem_info[1]
181
+ v_transfer_args = setitem_info[2]
182
+ data_update_types = setitem_info[3]
183
+ data_update_args = setitem_info[4]
184
+ value = value_update(v_transfer_types, v_transfer_args, self, value)
185
+ return data_update(data_update_types, data_update_args, self, new_index, value)
90
186
 
91
187
 
92
188
  tensor_operator_registry.register("__getitem__", _tensor_getitem)
@@ -119,6 +215,10 @@ def _tensor_mul(self, other):
119
215
  return F.mul(self, other)
120
216
 
121
217
 
218
+ def _tensor_matmul(self, other):
219
+ return F.matmul(self, other)
220
+
221
+
122
222
  def _tensor_div(self, other):
123
223
  if isinstance(self, (tuple, list)):
124
224
  self = sequence_to_tensor(self, F.dtype(other))
@@ -158,6 +258,7 @@ def _tensor_floordiv(self, other):
158
258
  tensor_operator_registry.register('__add__', _tensor_add)
159
259
  tensor_operator_registry.register('__sub__', _tensor_sub)
160
260
  tensor_operator_registry.register('__mul__', _tensor_mul)
261
+ tensor_operator_registry.register('__matmul__', _tensor_matmul)
161
262
  tensor_operator_registry.register('__truediv__', _tensor_div)
162
263
  tensor_operator_registry.register('__mod__', _tensor_mod)
163
264
  tensor_operator_registry.register('__pow__', _tensor_pow)
@@ -165,6 +266,13 @@ tensor_operator_registry.register('__rpow__', _tensor_rpow)
165
266
  tensor_operator_registry.register('__floordiv__', _tensor_floordiv)
166
267
 
167
268
 
269
+ def _scalar_to_tensor(input_x):
270
+ if ops.isconstant(input_x):
271
+ return P.ScalarToTensor()(input_x, ops.dtype(input_x))
272
+ # use add Tensor([0]) cast scalar to tensor.
273
+ return ops.add(input_x, mutable(Tensor(0)))
274
+
275
+
168
276
  def tensor_item(data, *args):
169
277
  """Tensor getitem by index whose dtype is int or tuple with int."""
170
278
  # transform a.item(tuple(int)) -> a.item(int1,int2...intN)
@@ -239,13 +347,9 @@ def tensor_itemset_by_tuple_with_number(data, tuple_index, nubmer_value):
239
347
 
240
348
  def _broadcast(broadcast_shape, x):
241
349
  """Broadcast tensor to the required shape."""
242
- if not const_utils.check_two_shapes_need_broadcast(broadcast_shape, F.shape(x)):
350
+ if F.shape(x) == broadcast_shape:
243
351
  return x
244
- multiples = const_utils.compute_multiples(F.shape(x), broadcast_shape)
245
- if multiples:
246
- x = F.reshape(x, const_utils.expanded_shape(F.shape(x), len(multiples) - F.rank(x)))
247
- return F.tile(x, multiples)
248
- return x
352
+ return F.broadcast_to(x, broadcast_shape)
249
353
 
250
354
 
251
355
  def _transform_indexing_tensor(broadcast_shape, final_shape, new_shape, item):
@@ -285,6 +389,46 @@ def _transform_ellipsis_to_slice(data, tuple_index, op_name):
285
389
  return tuple_index_new
286
390
 
287
391
 
392
+ def handle_empty_tensor(arg, data):
393
+ """handle data update with empty tensor"""
394
+ if 0 in arg:
395
+ init_func = Zero()
396
+ init_func.__enable_zero_dim__ = True
397
+ return Tensor(shape=arg, dtype=data.dtype, init=init_func)
398
+ return const_utils.make_tensor([], data.dtype, arg)
399
+
400
+
401
+ def handle_multi_dim_index_tensor(new_index, arg):
402
+ """handle data update with multi dim index tensor"""
403
+ slice_cnt = 0
404
+ new_indies_tensor = []
405
+ if len(arg) == 1:
406
+ broadcast_shape = arg[0]
407
+ new_index = hyper_map(F.partial(Tensor), new_index)
408
+ broadcast_tensors = hyper_map(
409
+ F.partial(_broadcast, broadcast_shape), new_index)
410
+ new_broadcast_tensors = ()
411
+ for tensor in broadcast_tensors:
412
+ new_broadcast_tensors += (F.cast(tensor, mstype.int64),)
413
+ new_index = stack(new_broadcast_tensors)
414
+ return new_index
415
+ broadcast_shape, final_shape, index_tensor_new_shape, slice_shapes, tensor_positions, fancy_position = arg
416
+ for i, index in enumerate(new_index):
417
+ if i in tensor_positions:
418
+ transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
419
+ Tensor(index))
420
+ new_indies_tensor.append(F.cast(transform_tensor, mstype.int64))
421
+ else:
422
+ shape = const_utils.compute_slice_shape(
423
+ slice_shapes, len(broadcast_shape), slice_cnt, fancy_position)
424
+ array = Tensor(index).reshape(shape)
425
+ slice_index_tensor = _broadcast(final_shape, array)
426
+ new_indies_tensor.append(F.cast(slice_index_tensor, mstype.int64))
427
+ slice_cnt += 1
428
+ new_index = stack(new_indies_tensor)
429
+ return new_index
430
+
431
+
288
432
  def _expand_data_dims(data, tuple_index):
289
433
  """expand the data's dim with 'None' and 'Boolean' in tuple_index"""
290
434
  indexes_types = hyper_map(toptypeof, tuple_index)
@@ -307,12 +451,34 @@ def _expand_data_dims(data, tuple_index):
307
451
  return data, tuple_index_new
308
452
 
309
453
 
454
+ def convert_variable_to_tensor_slice(slice_index):
455
+ """convert mutable scalar to tensor"""
456
+ start = slice_get_item(slice_index, "start")
457
+ stop = slice_get_item(slice_index, "stop")
458
+ step = slice_get_item(slice_index, "step")
459
+ find_mutable_scalar = False
460
+ if isinstance(start, int) and not F.isconstant(start):
461
+ start = ops.Cast()(start, mstype.int64)
462
+ find_mutable_scalar = True
463
+ if isinstance(stop, int) and not F.isconstant(stop):
464
+ stop = ops.Cast()(stop, mstype.int64)
465
+ find_mutable_scalar = True
466
+ if isinstance(step, int) and not F.isconstant(step):
467
+ step = ops.Cast()(step, mstype.int64)
468
+ find_mutable_scalar = True
469
+ if find_mutable_scalar:
470
+ return F.make_slice(start, stop, step)
471
+ return slice_index
472
+
473
+
310
474
  def tensor_index_by_slice(data, slice_index):
311
475
  """Tensor getitem by a slice."""
312
476
  min_data_dim, max_data_dim = 1, 8
313
477
  const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
314
478
  data_shape = F.shape(data)
315
- is_dynamic = (is_shape_unknown(data_shape)
479
+ slice_index = convert_variable_to_tensor_slice(slice_index)
480
+
481
+ is_dynamic = (F.is_sequence_value_unknown(data_shape)
316
482
  or isinstance(slice_get_item(slice_index, "start"), Tensor)
317
483
  or isinstance(slice_get_item(slice_index, "stop"), Tensor)
318
484
  or isinstance(slice_get_item(slice_index, "step"), Tensor))
@@ -335,6 +501,12 @@ def get_stride_info_from_slice(data, slice_index):
335
501
  data_shape = F.dyn_shape(data)
336
502
  begin_strides, end_strides, step_strides = [], [], []
337
503
  start, stop, step = get_slice_stride(slice_index, data_shape[0])
504
+ if start.ndim > 0:
505
+ start = start.item()
506
+ if stop.ndim > 0:
507
+ stop = stop.item()
508
+ if step.ndim > 0:
509
+ step = step.item()
338
510
  begin_strides.append(start)
339
511
  end_strides.append(stop)
340
512
  step_strides.append(step)
@@ -364,19 +536,10 @@ def _tensor_index_by_bool(data, bool_value):
364
536
  return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")
365
537
 
366
538
 
367
- def check_range(x, dim_size):
368
- """Check whether x is within the range of dim_size"""
369
- tensor_x = const_utils.make_tensor(x)
370
- if tensor_x >= dim_size or tensor_x < -dim_size:
371
- return tensor_x
372
- tensor_x = tensor_x % dim_size
373
- return tensor_x
374
-
375
-
376
539
  def get_stride_info_from_integer(tensor_int):
377
540
  """Convert integer to slice"""
378
541
  begin_strides = [tensor_int]
379
- end_strides = [tensor_int + const_utils.make_tensor(1)]
542
+ end_strides = [tensor_int + 1]
380
543
  step_strides = [const_utils.make_tensor(1)]
381
544
  begin_tensor = stack(begin_strides)
382
545
  end_tensor = stack(end_strides)
@@ -386,14 +549,15 @@ def get_stride_info_from_integer(tensor_int):
386
549
 
387
550
  def _tensor_index_by_integer(data, int_index):
388
551
  """Tensor getitem by a single integer number"""
552
+ data_shape = F.shape(data)
553
+ if not data_shape:
554
+ const_utils.raise_type_error("Cannot iterate over a scalar tensor.")
389
555
  if data.ndim < 1 or data.ndim > 8:
390
556
  const_utils.raise_value_error("Expect Tensor to have dimension between 1 and 8.")
391
557
 
392
- data_shape = F.shape(data)
393
- if is_shape_unknown(data_shape):
394
- data_shape = F.dyn_shape(data)
395
- transformed_tensor = check_range(int_index, data_shape[0])
396
- begin_strides, end_strides, step_strides = get_stride_info_from_integer(transformed_tensor)
558
+ if F.is_sequence_value_unknown(data_shape) or not F.isconstant(int_index):
559
+ tensor_index = _scalar_to_tensor(int_index)
560
+ begin_strides, end_strides, step_strides = get_stride_info_from_integer(tensor_index)
397
561
  else:
398
562
  transformed_number = const_utils.check_range(int_index, data_shape[0])
399
563
  begin_strides, end_strides, step_strides = \
@@ -401,22 +565,41 @@ def _tensor_index_by_integer(data, int_index):
401
565
  shrink_axis_mask = 1
402
566
  begin_mask = 0
403
567
  end_mask = 0
404
- for i in range(1, len(data_shape)):
568
+ for i in range(2, 8):
405
569
  begin_mask += 2 ** i
406
570
  end_mask += 2 ** i
407
571
  return strided_slice(data, begin_strides, end_strides, step_strides, begin_mask, end_mask, 0, 0, shrink_axis_mask)
408
572
 
409
573
 
574
+ def _check_dim_shape_valid(data, tensor_index):
575
+ """check dim and shape of tensor_index for tensor(bool) indexing"""
576
+ if data.ndim < tensor_index.ndim:
577
+ raise IndexError(f"The dim of index cannot be greater than indexed data, but got "
578
+ f"dim of index:{tensor_index.ndim}, dim of data:{data.ndim}")
579
+ if data.shape[:tensor_index.ndim] != tensor_index.shape[:]:
580
+ raise IndexError(f"The shape of index {tensor_index.shape} does not match the shape "
581
+ f"of the indexed data {data.shape}")
582
+
583
+
584
+ def tensor_index_by_bool_tensor(data, tensor_index):
585
+ """Tensor getitem by a bool tensor"""
586
+ _check_dim_shape_valid(data, tensor_index)
587
+ tensor_index = tensor_index.nonzero()
588
+ return F.gather_nd(data, tensor_index)
589
+
590
+
410
591
  def tensor_index_by_tensor(data, tensor_index):
411
592
  """Tensor getitem by a single tensor"""
412
593
  min_data_dim, max_data_dim = 0, 7
413
594
  const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
414
- valid = const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Int)
415
- if valid is False:
416
- exp_msg = const_utils.gen_exception_msg(
417
- "The tensor index must be int type, but got {}.", F.dtype(tensor_index))
418
- const_utils.raise_index_error(exp_msg)
419
- return F.gather(data, tensor_index, 0)
595
+ if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Int):
596
+ return F.gather(data, tensor_index, 0)
597
+ if const_utils.check_type_isinstance(F.dtype(tensor_index), mstype.Bool):
598
+ return tensor_index_by_bool_tensor(data, tensor_index)
599
+ exp_msg = const_utils.gen_exception_msg(
600
+ "The tensor index must be int or bool type, but got {}.", F.dtype(tensor_index))
601
+ const_utils.raise_index_error(exp_msg)
602
+ return data
420
603
 
421
604
 
422
605
  def tensor_index_by_list(data, list_index):
@@ -427,10 +610,13 @@ def tensor_index_by_list(data, list_index):
427
610
  data_shape = F.shape(data)
428
611
  indexes_types = hyper_map(toptypeof, list_index)
429
612
  if const_utils.check_type_isinstance(indexes_types, (mstype.Bool, mstype.Int)):
430
- if data_shape[0] == -1 and all(isinstance(i, bool) for i in list_index):
431
- const_utils.raise_unimplemented_error(
432
- "Not supported to the dynamic shape tensor slice by using list of Boolean type")
433
- tensor_index = const_utils.sequence_to_index(list_index, data_shape[0])
613
+ if not F.isconstant(data_shape[0]):
614
+ if all(isinstance(i, bool) for i in list_index):
615
+ const_utils.raise_unimplemented_error(
616
+ "Not supported to the dynamic shape tensor slice by using list of Boolean type")
617
+ tensor_index = const_utils.sequence_to_index(list_index, None)
618
+ else:
619
+ tensor_index = const_utils.sequence_to_index(list_index, data_shape[0])
434
620
  if tensor_index is False:
435
621
  const_utils.raise_index_error("When tensor is indexed by list, the list can't be empty.")
436
622
  return F.gather(data, tensor_index, 0)
@@ -441,18 +627,28 @@ def tensor_index_by_list(data, list_index):
441
627
  return tensor_index_by_tuple(data, tuple_index_new)
442
628
 
443
629
 
630
+ def convert_tupleslice_to_tensor(tuple_index):
631
+ """convert mutable scalar in slice to tensor"""
632
+ new_tuple_index = []
633
+ for item in tuple_index:
634
+ if isinstance(item, slice):
635
+ item = convert_variable_to_tensor_slice(item)
636
+ new_tuple_index.append(item)
637
+ return tuple(new_tuple_index)
638
+
639
+
444
640
  def tensor_index_by_tuple(data, tuple_index):
445
641
  """Tensor getitem by tuple of various types with None"""
446
642
  if not tuple_index:
447
643
  return data
448
644
 
645
+ tuple_index = convert_tupleslice_to_tensor(tuple_index)
449
646
  op_name = const_utils.TENSOR_GETITEM
450
647
  tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
451
648
  data, tuple_index = _expand_data_dims(data, tuple_index)
452
649
 
453
650
  min_data_dim, max_data_dim = 1, 8
454
651
  const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
455
-
456
652
  indexes_types = hyper_map(toptypeof, tuple_index)
457
653
  contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
458
654
  if contain_type == const_utils.ALL_BASIC:
@@ -460,31 +656,6 @@ def tensor_index_by_tuple(data, tuple_index):
460
656
  return _tensor_getitem_by_tuple(data, tuple_index, op_name)
461
657
 
462
658
 
463
- def _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name):
464
- """Tensor getitem by a tuple of tensor."""
465
- data_shape = F.shape(data)
466
- tuple_index_len = len(tuple_index)
467
-
468
- indexes_types = hyper_map(F.dtype, tuple_index)
469
- const_utils.check_indexes_types_valid(indexes_types, mstype.int_type, op_name)
470
- tensor_index_shape = hyper_map(F.shape, tuple_index)
471
- broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name)
472
- if 0 in broadcast_shape:
473
- res_shape = broadcast_shape
474
- if tuple_index_len < len(data_shape):
475
- res_shape += data_shape[tuple_index_len:]
476
- res = const_utils.make_tensor([], data.dtype, res_shape)
477
- return res
478
-
479
- broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index)
480
- new_broadcast_tensors = ()
481
- for tensor in broadcast_tensors:
482
- new_broadcast_tensors += (F.cast(tensor, mstype.int64),)
483
- indices = stack(new_broadcast_tensors)
484
- result = F.gather_nd(data, indices)
485
- return result
486
-
487
-
488
659
  def get_slice_stride(slice_index, dim_size):
489
660
  """Get slice stride info"""
490
661
  start = slice_get_item(slice_index, "start")
@@ -498,13 +669,13 @@ def get_slice_stride(slice_index, dim_size):
498
669
  if step is None:
499
670
  step = const_utils.make_tensor(1)
500
671
 
501
- if F.issubclass_(F.typeof(start), mstype.number):
672
+ if issubclass_(F.typeof(start), mstype.number):
502
673
  start = const_utils.make_tensor(start)
503
674
 
504
- if F.issubclass_(F.typeof(stop), mstype.number):
675
+ if issubclass_(F.typeof(stop), mstype.number):
505
676
  stop = const_utils.make_tensor(stop)
506
677
 
507
- if F.issubclass_(F.typeof(step), mstype.number):
678
+ if issubclass_(F.typeof(step), mstype.number):
508
679
  step = const_utils.make_tensor(step)
509
680
 
510
681
  return start, stop, step
@@ -543,7 +714,7 @@ def _get_stride_info_from_tuple(data, tuple_index):
543
714
  step_strides.append(step)
544
715
  index_count = index_count + 1
545
716
  elif isinstance(index, int):
546
- int_tensor = check_range(index, dim_size)
717
+ int_tensor = _scalar_to_tensor(index)
547
718
  begin_strides.append(int_tensor)
548
719
  end_strides.append(int_tensor + const_utils.make_tensor(1))
549
720
  step_strides.append(const_utils.make_tensor(1))
@@ -577,7 +748,7 @@ def _get_stride_info_from_tuple(data, tuple_index):
577
748
  def _tensor_getitem_by_tuple_slice(data, tuple_index):
578
749
  """Tensor getitem by a tuple of slice"""
579
750
  data_shape = F.shape(data)
580
- is_dynamic = is_shape_unknown(data_shape)
751
+ is_dynamic = F.is_sequence_value_unknown(data_shape)
581
752
  for item in tuple_index:
582
753
  if isinstance(item, slice):
583
754
  is_dynamic = is_dynamic or isinstance(slice_get_item(item, "start"), Tensor) \
@@ -599,6 +770,39 @@ def _tensor_getitem_by_tuple_slice(data, tuple_index):
599
770
  return strided_slice(data, begin_v, end_v, step_v, begin_mask, end_mask, 0, 0, shrink_axis_mask)
600
771
 
601
772
 
773
+ @_primexpr
774
+ def _tensor_getitem_by_tuple_parse_bool_tensor_index(index, tuple_index_new, tensor_indexes,
775
+ tensor_positions_new):
776
+ """ parse index of bool tensor type """
777
+ indices = index.nonzero()
778
+ if indices.shape[0] == 0:
779
+ return None, tensor_indexes, tensor_positions_new
780
+ indices = F.cast(indices, mstype.int64)
781
+ indices = indices.T
782
+ for sub_index in indices:
783
+ tensor_positions_new.append(len(tuple_index_new))
784
+ tuple_index_new += (sub_index,)
785
+ tensor_indexes.append(sub_index)
786
+ return tuple_index_new, tensor_indexes, tensor_positions_new
787
+
788
+
789
+ def _tensor_getitem_by_tuple_parse_tensor_index(index, tuple_index_new, tensor_indexes, tensor_positions_new):
790
+ """ parse index of tensor type """
791
+ if F.dtype(index) in mstype.int_type:
792
+ tensor_index = F.cast(index, mstype.int64)
793
+ tensor_positions_new.append(len(tuple_index_new))
794
+ tuple_index_new += (tensor_index,)
795
+ tensor_indexes.append(tensor_index)
796
+ elif F.dtype(index) == mstype.bool_:
797
+ return _tensor_getitem_by_tuple_parse_bool_tensor_index(index, tuple_index_new, tensor_indexes,
798
+ tensor_positions_new)
799
+ else:
800
+ exp_msg = const_utils.gen_exception_msg(
801
+ "The tensor element in tuple index must be int or bool type, but got {}.", F.dtype(index))
802
+ const_utils.raise_index_error(exp_msg)
803
+ return tuple_index_new, tensor_indexes, tensor_positions_new
804
+
805
+
602
806
  def _tensor_getitem_by_tuple(data, tuple_index, op_name):
603
807
  """Tensor getitem by a tuple of mixed tensor."""
604
808
  slice_is_tensor = False
@@ -609,51 +813,49 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name):
609
813
  or isinstance(slice_get_item(item, "step"), Tensor)
610
814
  if slice_is_tensor:
611
815
  const_utils.raise_index_error("Not supported when slice has tensor")
612
- tuple_index_len = len(tuple_index)
613
- tensor_indexes, slice_indexes = [], []
816
+
614
817
  indexes_types = hyper_map(toptypeof, tuple_index)
615
818
  slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
616
819
  const_utils.get_pos_of_indexes_types(indexes_types, op_name)
617
- tuple_index_new, slice_shapes = (), ()
618
820
  data_shape = F.shape(data)
821
+ tensor_indexes, slice_indexes = [], []
822
+ tuple_index_new, slice_shapes = (), ()
823
+ slice_positions_new, tensor_positions_new = [], []
619
824
  for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
620
825
  if i in int_positions:
621
826
  int_index = const_utils.check_range(index, dim_size)
622
827
  tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
623
- if is_shape_unknown(data_shape):
624
- dyn_shape = F.dyn_shape(data)
625
- tensor_index = check_range(index, dyn_shape[i])
828
+ if F.is_sequence_value_unknown(data_shape):
829
+ tensor_index = _scalar_to_tensor(int_index)
626
830
  tensor_index = F.cast(tensor_index, mstype.int64)
831
+ tensor_positions_new.append(len(tuple_index_new))
627
832
  tuple_index_new += (tensor_index,)
628
833
  tensor_indexes.append(tensor_index)
629
- tensor_positions += (i,)
630
834
  elif i in sequence_positions:
631
835
  tensor_index = const_utils.sequence_to_index(index, dim_size)
632
836
  if tensor_index is False:
633
837
  const_utils.raise_index_error("The sequence element(tuple/list) in tuple index can't be empty.")
838
+ tensor_positions_new.append(len(tuple_index_new))
634
839
  tuple_index_new += (tensor_index,)
635
840
  tensor_indexes.append(tensor_index)
636
- tensor_positions += (i,)
637
841
  elif i in tensor_positions:
638
- invalid = const_utils.check_type_invalid(F.dtype(index), mstype.int_type)
639
- if invalid:
640
- exp_msg = const_utils.gen_exception_msg(
641
- "The tensor element in tuple index must be int type, but got {}.", F.dtype(index))
642
- const_utils.raise_index_error(exp_msg)
643
- tensor_index = F.cast(index, mstype.int64)
644
- tuple_index_new += (tensor_index,)
645
- tensor_indexes.append(tensor_index)
842
+ tuple_index_new, tensor_indexes, tensor_positions_new = \
843
+ _tensor_getitem_by_tuple_parse_tensor_index(index, tuple_index_new,
844
+ tensor_indexes, tensor_positions_new)
845
+ if tuple_index_new is None:
846
+ return Tensor([])
646
847
  elif i in slice_positions:
647
848
  slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size)
648
849
  slice_shapes += (len(slice_ele_list_index),)
850
+ slice_positions_new.append(len(tuple_index_new))
649
851
  tuple_index_new += (slice_ele_list_index,)
650
852
  slice_indexes.append(slice_ele_list_index)
651
-
652
853
  tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
653
854
  broadcast_shape, index_tensor_new_shape, final_shape, fancy_position = \
654
- const_utils.generate_index_info_from_tuple_of_mixed_tensors(tensor_positions, tensor_indexes_shapes,
855
+ const_utils.generate_index_info_from_tuple_of_mixed_tensors(tensor_positions_new, tensor_indexes_shapes,
655
856
  slice_shapes, op_name)
656
857
 
858
+ tuple_index_len = len(tuple_index)
657
859
  if 0 in final_shape + data_shape:
658
860
  if tuple_index_len < len(data_shape):
659
861
  final_shape = final_shape + data_shape[tuple_index_len:]
@@ -662,11 +864,11 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name):
662
864
  final_index_tensors = []
663
865
  slice_cnt = 0
664
866
  for i, index in enumerate(tuple_index_new):
665
- if i in tensor_positions:
867
+ if i in tensor_positions_new:
666
868
  transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
667
869
  index)
668
870
  final_index_tensors.append(transform_tensor)
669
- elif i in slice_positions:
871
+ elif i in slice_positions_new:
670
872
  slice_index_tensor = convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape,
671
873
  slice_shapes, fancy_position)
672
874
  final_index_tensors.append(slice_index_tensor)
@@ -701,7 +903,6 @@ def _generate_indices_from_tuple(data, tuple_index, op_name, fancy_position):
701
903
  slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
702
904
  const_utils.get_pos_of_indexes_types(indexes_types, op_name)
703
905
  tuple_index_new, slice_shapes = (), ()
704
-
705
906
  for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
706
907
  if i in int_positions:
707
908
  int_index = const_utils.check_range(index, dim_size)
@@ -718,7 +919,7 @@ def _generate_indices_from_tuple(data, tuple_index, op_name, fancy_position):
718
919
  invalid = const_utils.check_type_invalid(F.dtype(index), mstype.int_type)
719
920
  if invalid:
720
921
  exp_msg = const_utils.gen_exception_msg(
721
- "The tensor element in tuple index must be int type, but got {}.", F.dtype(index))
922
+ "The tensor element in tuple index must be int or bool type, but got {}.", F.dtype(index))
722
923
  const_utils.raise_index_error(exp_msg)
723
924
  tensor_index = F.cast(index, mstype.int64)
724
925
  tuple_index_new += (tensor_index,)
@@ -783,11 +984,11 @@ def _generate_updates_from_sequence(data, index, value, op_type):
783
984
  def _generate_updates_from_tensor(data, index, value, op_type):
784
985
  """Generate an updates tensor from a tensor."""
785
986
  value = value.astype(data.dtype)
786
- if is_shape_unknown(F.shape(data)):
987
+ if F.is_sequence_value_unknown(F.shape(data)):
787
988
  data_shape = F.dyn_shape(data)
788
989
  index_shape = F.dyn_shape(index)
789
990
  updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type, True)
790
- updates = dynamic_broadcast_to(value, updates_shape)
991
+ updates = ops.broadcast_to(value, updates_shape)
791
992
  return updates
792
993
  updates_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type, False)
793
994
  need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value.shape)
@@ -807,6 +1008,7 @@ def tensor_setitem_by_tensor(self, index, value):
807
1008
 
808
1009
 
809
1010
  def tensor_setitem_by_tuple(self, index, value):
1011
+ index = convert_tupleslice_to_tensor(index)
810
1012
  if isinstance(value, (int, float, bool)):
811
1013
  index = format_tuple_indices(index)
812
1014
  return tensor_setitem_by_tuple_with_number(self, index, value)
@@ -824,6 +1026,7 @@ def tensor_setitem_by_number(self, index, value):
824
1026
 
825
1027
 
826
1028
  def tensor_setitem_by_slice(self, index, value):
1029
+ index = convert_variable_to_tensor_slice(index)
827
1030
  if isinstance(value, (int, float, bool)):
828
1031
  return tensor_setitem_by_slice_with_number(self, index, value)
829
1032
  if isinstance(value, Tensor):
@@ -844,28 +1047,29 @@ def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
844
1047
  if F.rank(index) == 0:
845
1048
  index = F.expand_dims(index, -1)
846
1049
  updates = _generate_updates_from_tensor(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
847
- index = F.select(index < 0, index + F.shape(data)[0], index)
1050
+ data_shape = F.shape(data)
1051
+ first_val = data_shape[0]
1052
+ if not F.isconstant(first_val):
1053
+ first_val = -1
1054
+ index = F.select(index < 0, index + first_val, index)
848
1055
  index = F.expand_dims(index, -1)
849
1056
  if F.rank(index) < 2:
850
1057
  index = F.expand_dims(index, 0)
851
1058
  updates = F.expand_dims(updates, 0)
1059
+ if is_parameter(data):
1060
+ F.scatter_nd_update(data, index, updates)
1061
+ return data
852
1062
  return F.tensor_scatter_update(data, index, updates)
853
1063
 
854
1064
 
855
1065
  def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
856
1066
  """Set a tensor item by a bool tensor with a tensor."""
857
- index_shape = F.shape(index)
858
- data_shape = F.shape(data)
859
- const_utils.check_equal(data_shape, index_shape,
860
- "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
861
- size = F.shape_mul(F.shape(value))
862
- const_utils.check_equal(1, size,
863
- "When assign value is a tensor, its size should be {}, but current size is {}.")
864
- dtype = F.dtype(data)
865
- u_cast = F.cast(value, dtype)
866
- one_data = F.ones_like(data)
867
- u = F.tensor_mul(one_data, u_cast)
868
- result = F.select(index, u, data)
1067
+ index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape)))
1068
+ index = F.broadcast_to(index, data.shape)
1069
+ value = F.cast(value, F.dtype(data))
1070
+ value = value.reshape(const_utils.generate_padding_shape(value.shape, len(data.shape)))
1071
+ value = F.broadcast_to(value, data.shape)
1072
+ result = F.select(index, value, data)
869
1073
  return result
870
1074
 
871
1075
 
@@ -876,7 +1080,7 @@ def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
876
1080
  if tensor_dtype == const_utils.INT_:
877
1081
  return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
878
1082
 
879
- if is_shape_unknown(F.shape(data)):
1083
+ if F.is_sequence_value_unknown(F.shape(data)):
880
1084
  const_utils.raise_unimplemented_error(
881
1085
  "Not supported to the dynamic shape tensor slice by using tensor of Boolean type")
882
1086
  return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
@@ -890,11 +1094,13 @@ def tensor_setitem_by_tensor_with_number(data, index, value):
890
1094
  def tensor_setitem_by_tensor_with_sequence(data, index, value):
891
1095
  """Assigns the tensor by tensor with tuple value."""
892
1096
  index_dtype = F.dtype(index)
893
- invalid = const_utils.check_type_invalid(index_dtype, (mstype.int32, mstype.int64))
894
- if invalid:
895
- exp_msg = const_utils.gen_exception_msg("The tensor index must be int type, but got {}.", index_dtype)
896
- const_utils.raise_index_error(exp_msg)
897
- return _tensor_setitem_by_tensor_with_sequence(data, index, value)
1097
+ if index_dtype in (mstype.int32, mstype.int64):
1098
+ return _tensor_setitem_by_tensor_with_sequence(data, index, value)
1099
+ if index_dtype == mstype.bool_:
1100
+ return _tensor_setitem_by_bool_tensor_with_sequence(data, index, value)
1101
+ exp_msg = const_utils.gen_exception_msg("The tensor index must be int or bool type, but got {}.", index_dtype)
1102
+ const_utils.raise_index_error(exp_msg)
1103
+ return None
898
1104
 
899
1105
 
900
1106
  def _tensor_setitem_by_tensor_with_sequence(data, index, value):
@@ -904,6 +1110,12 @@ def _tensor_setitem_by_tensor_with_sequence(data, index, value):
904
1110
  return F.tensor_scatter_update(data, index, updates)
905
1111
 
906
1112
 
1113
+ def _tensor_setitem_by_bool_tensor_with_sequence(data, index, value):
1114
+ """Set a tensor item by a bool tensor with a tuple."""
1115
+ value = sequence_to_tensor(value, F.dtype(data))
1116
+ return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value)
1117
+
1118
+
907
1119
  def tensor_setitem_by_slice_with_number(data, input_slice, value):
908
1120
  """Givens a scalar assign to tensor by slice"""
909
1121
  value = F.fill(F.dtype(data), (), value)
@@ -929,7 +1141,7 @@ def tensor_copy_slice_from_slice(data, input_slice, value):
929
1141
  if dim0_size >= data_shape[0]:
930
1142
  dim0_size = data_shape[0:1]
931
1143
  value_shape = P.Concat(-1)((dim0_size, data_shape[1:]))
932
- value = dynamic_broadcast_to(value, value_shape)
1144
+ value = ops.broadcast_to(value, value_shape)
933
1145
  return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)
934
1146
 
935
1147
 
@@ -941,7 +1153,7 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
941
1153
  data_shape = F.shape(data)
942
1154
  step = const_utils.get_step_from_slice(input_slice)
943
1155
  if step == 1 and not const_utils.is_ascend():
944
- if is_shape_unknown(data_shape):
1156
+ if F.is_sequence_value_unknown(data_shape):
945
1157
  return tensor_copy_slice_from_slice(data, input_slice, value)
946
1158
  start, stop, step = const_utils.normalize_slice(input_slice, data.shape[0])
947
1159
  dim0_size = stop - start
@@ -950,7 +1162,7 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
950
1162
  value_shape = (dim0_size,) + const_utils.tuple_slice(data.shape, 1, None)
951
1163
  value = _broadcast(value_shape, value)
952
1164
  return copy_slice(data, value.astype(data.dtype), (start,), (stop,), (step,))
953
- if is_shape_unknown(data_shape):
1165
+ if F.is_sequence_value_unknown(data_shape):
954
1166
  const_utils.raise_unimplemented_error(
955
1167
  "Not supported to take the subscript of dynamic shape tensor slice setitem")
956
1168
  indices = const_utils.slice2indices(input_slice, data_shape)
@@ -974,7 +1186,7 @@ def tensor_copy_slice_from_tuple(data, tuple_index, value):
974
1186
  dim1_start, dim1_stop, _ = get_slice_stride(tuple_index[1], data_shape[1])
975
1187
  if dim1_stop - dim1_start <= 0:
976
1188
  return data
977
- dim0_start = check_range(tuple_index[0], data_shape[0])
1189
+ dim0_start = _scalar_to_tensor(tuple_index[0])
978
1190
  dim0_stop = dim0_start + const_utils.make_tensor(1)
979
1191
  start = (dim0_start, dim1_start)
980
1192
  stop = (dim0_stop, dim1_stop)
@@ -986,7 +1198,7 @@ def tensor_copy_slice_from_tuple(data, tuple_index, value):
986
1198
  if dim1_size > data_shape[1]:
987
1199
  dim1_size = data_shape[1:2]
988
1200
  value_shape = P.Concat(-1)((dim1_size, data_shape[2:]))
989
- value = dynamic_broadcast_to(value, value_shape)
1201
+ value = ops.broadcast_to(value, value_shape)
990
1202
  return copy_slice(data, value.astype(data.dtype), start_tensor, stop_tensor, step_tensor)
991
1203
 
992
1204
 
@@ -996,7 +1208,7 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
996
1208
  tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
997
1209
 
998
1210
  if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
999
- if is_shape_unknown(F.shape(data)):
1211
+ if F.is_sequence_value_unknown(F.shape(data)):
1000
1212
  return tensor_copy_slice_from_tuple(data, tuple_index, value)
1001
1213
  dim1_start, dim1_stop, _ = const_utils.normalize_slice(tuple_index[1], data.shape[1])
1002
1214
  if dim1_stop - dim1_start <= 0:
@@ -1016,7 +1228,6 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
1016
1228
  if len(tuple_index) == 1:
1017
1229
  data[tuple_index[0]] = value
1018
1230
  return data
1019
-
1020
1231
  indexes_types = hyper_map(toptypeof, tuple_index)
1021
1232
  contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
1022
1233
 
@@ -1050,14 +1261,20 @@ def tensor_setitem_by_number_with_sequence(data, index, value):
1050
1261
  def tensor_setitem_by_number_with_tensor(data, index, value):
1051
1262
  """Assigns the tensor by number with tensor value."""
1052
1263
  data_shape = F.shape(data)
1053
- if is_shape_unknown(data_shape):
1054
- index = check_range(index, F.dyn_shape(data)[0])
1264
+ if F.is_sequence_value_unknown(data_shape):
1265
+ index = _scalar_to_tensor(index)
1055
1266
  index = F.expand_dims(index, -1)
1056
1267
  return _tensor_setitem_by_int_tensor_with_tensor(data, index, value)
1057
1268
 
1269
+ dim_size = data_shape[0]
1270
+ if index < -dim_size or index >= dim_size:
1271
+ raise IndexError(f'index {index} is out of bounds for axis 0 with size {dim_size}')
1058
1272
  index = const_utils.int_to_index(index, data_shape)
1059
1273
  value_shape = const_utils.tuple_slice(F.shape(index), None, -1)
1060
1274
  value = _broadcast(value_shape, value.astype(F.dtype(data)))
1275
+ if is_parameter(data):
1276
+ F.scatter_nd_update(data, index, value)
1277
+ return data
1061
1278
  return F.tensor_scatter_update(data, index, value)
1062
1279
 
1063
1280
 
@@ -1065,7 +1282,7 @@ def tensor_setitem_by_ellipsis_with_number(data, value):
1065
1282
  """Assigns the tensor by ellipsis with number value."""
1066
1283
  data_shape = F.shape(data)
1067
1284
  data_dtype = F.dtype(data)
1068
- if is_shape_unknown(data_shape):
1285
+ if F.is_sequence_value_unknown(data_shape):
1069
1286
  value = F.fill(F.dtype(data), (), value)
1070
1287
  return tensor_setitem_by_ellipsis_with_tensor(data, value)
1071
1288
  return F.fill(data_dtype, data_shape, value)
@@ -1077,9 +1294,9 @@ def tensor_setitem_by_ellipsis_with_tensor(data, value):
1077
1294
  data_dtype = F.dtype(data)
1078
1295
  value = value.astype(data_dtype)
1079
1296
 
1080
- if is_shape_unknown(data_shape):
1297
+ if F.is_sequence_value_unknown(data_shape):
1081
1298
  data_shape = F.dyn_shape(data)
1082
- data = dynamic_broadcast_to(value, data_shape)
1299
+ data = ops.broadcast_to(value, data_shape)
1083
1300
  return data
1084
1301
  value_shape = F.shape(value)
1085
1302
  source_shape = const_utils.get_source_shape(data_shape, value_shape)
@@ -1107,9 +1324,9 @@ def tensor_setitem_by_bool(data, index, value):
1107
1324
  elif isinstance(value, float):
1108
1325
  value = const_utils.make_tensor(value, mstype.float32)
1109
1326
 
1110
- if is_shape_unknown(data_shape) and index:
1327
+ if F.is_sequence_value_unknown(data_shape) and index:
1111
1328
  data_shape = F.dyn_shape(data)
1112
- data = dynamic_broadcast_to(value, data_shape)
1329
+ data = ops.broadcast_to(value, data_shape)
1113
1330
  return data
1114
1331
  value_shape = F.shape(value)
1115
1332
  source_shape = const_utils.get_source_shape(data_shape, value_shape)
@@ -1135,6 +1352,8 @@ def format_list_indices(list_indices, length):
1135
1352
  # If eyery element in list is bool, it's treated as 1-D bool tensor.
1136
1353
  # If every element in list is int(not all bool), it's treated as int tensor.
1137
1354
  if const_utils.judge_indexes_types(indices_types, mstype.int_type + (mstype.bool_,)):
1355
+ if not F.isconstant(length):
1356
+ return const_utils.sequence_to_index(list_indices, None)
1138
1357
  return const_utils.sequence_to_index(list_indices, length)
1139
1358
  # If list contains other types(.../list/tuple/None), it's treated as a tuple
1140
1359
  return const_utils.deep_tuple(list_indices)
@@ -1154,11 +1373,34 @@ def format_tuple_indices(tuple_indices):
1154
1373
  return res
1155
1374
 
1156
1375
 
1376
@_primexpr
def remove_expanded_dims_parse_bool_tensor_index(index_out, indices_out, shapes, cur_dim):
    """
    Convert a boolean tensor index into per-column integer index tensors.

    The positions selected by `index_out` are obtained with ``nonzero()``. Each column of
    that coordinate matrix becomes its own index tensor, which is appended to `indices_out`.
    Its shape is appended to `shapes` (mutated in place), and `cur_dim` is advanced by one
    per column.

    Args:
        index_out: boolean tensor used as an index.
        indices_out: tuple of index tensors collected so far.
        shapes: list of index shapes collected so far (appended to in place).
        cur_dim: running dimension counter.

    Returns:
        tuple: ``(indices_out, shapes, cur_dim)``. ``indices_out`` is ``None`` when the
        boolean index selects no element at all.
    """
    coords = index_out.nonzero()
    if coords.shape[0] == 0:
        # Empty selection: tell the caller there is nothing to update.
        return None, shapes, cur_dim
    column_count = coords.shape[1]
    for col in range(column_count):
        column = coords[:, col]
        indices_out += (column,)
        shapes.append(F.shape(column))
    # One dimension is consumed per coordinate column.
    cur_dim += column_count
    return indices_out, shapes, cur_dim
1388
+
1389
+
1390
def remove_expanded_dims_parse_tensor_index(index_out, indices_out, shapes, cur_dim):
    """
    Record one tensor index while removing expanded dimensions.

    Boolean tensors are delegated to ``remove_expanded_dims_parse_bool_tensor_index``.
    Any other tensor is appended unchanged to `indices_out`, its shape is appended to
    `shapes` (mutated in place), and `cur_dim` advances by one.

    Args:
        index_out: tensor used as an index.
        indices_out: tuple of index tensors collected so far.
        shapes: list of index shapes collected so far.
        cur_dim: running dimension counter.

    Returns:
        tuple: ``(indices_out, shapes, cur_dim)``. For boolean input, ``indices_out`` may be
        ``None`` (see the boolean helper).
    """
    is_mask = index_out.dtype == mstype.bool_
    if is_mask:
        return remove_expanded_dims_parse_bool_tensor_index(index_out, indices_out, shapes, cur_dim)
    index_shape = F.shape(index_out)
    indices_out += (index_out,)
    shapes.append(index_shape)
    return indices_out, shapes, cur_dim + 1
1398
+
1399
+
1157
1400
  def remove_expanded_dims(tuple_index, data_shape, value):
1158
1401
  """Removes expanded dimensions in tuple_index and value."""
1159
- op_name = const_utils.TENSOR_SETITEM
1160
1402
  not_expanded_dim = ()
1161
- shapes = ()
1403
+ shapes = []
1162
1404
  has_true = False
1163
1405
  has_false = False
1164
1406
  has_sequence = False
@@ -1185,17 +1427,18 @@ def remove_expanded_dims(tuple_index, data_shape, value):
1185
1427
  idx_advanced = 0
1186
1428
  idx_tensor = i
1187
1429
  if isinstance(index_out, Tensor):
1188
- if F.rank(index_out) > 0:
1430
+ indices_out, shapes, cur_dim = \
1431
+ remove_expanded_dims_parse_tensor_index(index_out, indices_out, shapes, cur_dim)
1432
+ if indices_out is None:
1433
+ return False, value, 0
1434
+ if index_out.dtype != mstype.bool_ and F.rank(index_out) > 0:
1189
1435
  has_sequence = True
1190
- indices_out += (index_out,)
1191
- shapes += (F.shape(index_out),)
1192
- cur_dim += 1
1193
1436
  has_true = has_true or index_out is True
1194
1437
  has_false = has_false or index_out is False
1195
1438
  else:
1196
1439
  const_utils.raise_index_error('invalid index type')
1197
1440
 
1198
- broadcast_shape = const_utils.generate_broadcast_shape(shapes, op_name)
1441
+ broadcast_shape = const_utils.generate_broadcast_shape(shapes, const_utils.TENSOR_SETITEM)
1199
1442
  if has_false:
1200
1443
  if F.shape_mul(broadcast_shape) != 1:
1201
1444
  const_utils.raise_index_error('unable to broadcast indices')
@@ -1222,11 +1465,21 @@ def format_index(idx, data_shape, cur_dim):
1222
1465
  elif isinstance(idx, int) and not isinstance(idx, bool):
1223
1466
  idx = const_utils.make_tensor(idx, mstype.int64, None, data_shape[cur_dim])
1224
1467
  elif isinstance(idx, Tensor):
1225
- # does not take bool tensor into account since it's currently not supported
1226
- idx = F.select(idx < 0, idx + data_shape[cur_dim], idx)
1468
+ tensor_dtype = const_utils.get_index_tensor_dtype(idx.dtype)
1469
+ if tensor_dtype == const_utils.INT_:
1470
+ idx = F.select(idx < 0, idx + data_shape[cur_dim], idx)
1471
+ elif tensor_dtype == const_utils.BOOL_:
1472
+ # index with tensor(bool) type is processed in remove_expanded_dims()
1473
+ pass
1227
1474
  return idx
1228
1475
 
1229
1476
 
1477
@_primexpr
def _check_shape_mul(shape):
    """
    Reject shapes that describe a zero-size tensor.

    Args:
        shape: shape sequence whose element product is checked via ``F.shape_mul``.

    Raises:
        ValueError: if the product of the dimensions in `shape` is 0.
    """
    element_count = F.shape_mul(shape)
    if element_count == 0:
        raise ValueError('zero-size tensors are not supported.')
1481
+
1482
+
1230
1483
  def reduce_(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None, where=True, dtype=None):
1231
1484
  """
1232
1485
  Applies comparison based on cmp_fn and reduction based on reduce_fn.
@@ -1243,8 +1496,7 @@ def reduce_(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None,
1243
1496
  not isinstance(initial, (int, float, bool, Tensor))):
1244
1497
  const_utils.raise_type_error('initial must be scalar')
1245
1498
 
1246
- if F.shape_mul(shape) == 0:
1247
- const_utils.raise_value_error('zero-size tensors are not supported.')
1499
+ _check_shape_mul(shape)
1248
1500
 
1249
1501
  if initial is not None:
1250
1502
  if isinstance(initial, Tensor):