mindspore 2.0.0a0__cp37-none-any.whl → 2.0.0rc1__cp37-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the package's registry page for more details.

Files changed (693):
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +9064 -0
  3. mindspore/__init__.py +4 -2
  4. mindspore/_akg/akg/composite/build_module.py +11 -0
  5. mindspore/_akg/akg/config/repository_cuda.json +11 -0
  6. mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -3
  7. mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
  8. mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
  9. mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
  10. mindspore/_check_jit_forbidden_api.py +102 -0
  11. mindspore/_checkparam.py +1066 -1001
  12. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
  13. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
  14. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
  15. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
  16. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  17. mindspore/_extends/parse/__init__.py +5 -3
  18. mindspore/_extends/parse/namespace.py +16 -1
  19. mindspore/_extends/parse/parser.py +107 -22
  20. mindspore/_extends/parse/resources.py +0 -7
  21. mindspore/_extends/parse/standard_method.py +885 -413
  22. mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
  23. mindspore/amp.py +52 -57
  24. mindspore/bin/cache_admin +0 -0
  25. mindspore/bin/cache_server +0 -0
  26. mindspore/boost/boost.py +2 -2
  27. mindspore/boost/boost_cell_wrapper.py +38 -20
  28. mindspore/boost/dim_reduce.py +3 -3
  29. mindspore/boost/group_loss_scale_manager.py +1 -1
  30. mindspore/common/__init__.py +4 -6
  31. mindspore/common/_decorator.py +2 -0
  32. mindspore/common/_register_for_adapter.py +55 -0
  33. mindspore/common/_stub_tensor.py +201 -0
  34. mindspore/common/_utils.py +41 -7
  35. mindspore/common/api.py +215 -141
  36. mindspore/common/dtype.py +8 -1
  37. mindspore/common/dump.py +2 -2
  38. mindspore/common/initializer.py +4 -2
  39. mindspore/common/jit_config.py +17 -13
  40. mindspore/common/mutable.py +33 -13
  41. mindspore/common/parameter.py +23 -21
  42. mindspore/common/seed.py +8 -24
  43. mindspore/common/sparse_tensor.py +62 -41
  44. mindspore/common/tensor.py +852 -1154
  45. mindspore/communication/__init__.py +2 -2
  46. mindspore/communication/_comm_helper.py +11 -4
  47. mindspore/communication/management.py +22 -21
  48. mindspore/config/op_info.config +501 -1008
  49. mindspore/config/super_bar_config.json +512 -0
  50. mindspore/context.py +201 -23
  51. mindspore/dataset/__init__.py +6 -6
  52. mindspore/dataset/audio/__init__.py +7 -7
  53. mindspore/dataset/audio/transforms.py +670 -30
  54. mindspore/dataset/audio/utils.py +47 -4
  55. mindspore/dataset/audio/validators.py +223 -1
  56. mindspore/dataset/callback/ds_callback.py +2 -2
  57. mindspore/dataset/core/config.py +210 -14
  58. mindspore/dataset/core/validator_helpers.py +2 -2
  59. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  60. mindspore/dataset/debug/debug_hook.py +65 -0
  61. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  62. mindspore/dataset/engine/__init__.py +7 -3
  63. mindspore/dataset/engine/cache_client.py +1 -1
  64. mindspore/dataset/engine/datasets.py +322 -66
  65. mindspore/dataset/engine/datasets_audio.py +80 -76
  66. mindspore/dataset/engine/datasets_standard_format.py +51 -38
  67. mindspore/dataset/engine/datasets_text.py +232 -118
  68. mindspore/dataset/engine/datasets_user_defined.py +41 -17
  69. mindspore/dataset/engine/datasets_vision.py +746 -225
  70. mindspore/dataset/engine/graphdata.py +75 -10
  71. mindspore/dataset/engine/iterators.py +45 -5
  72. mindspore/dataset/engine/offload.py +48 -28
  73. mindspore/dataset/engine/validators.py +117 -8
  74. mindspore/dataset/text/__init__.py +6 -5
  75. mindspore/dataset/text/transforms.py +86 -3
  76. mindspore/dataset/text/utils.py +6 -4
  77. mindspore/dataset/text/validators.py +25 -0
  78. mindspore/dataset/transforms/__init__.py +3 -2
  79. mindspore/dataset/transforms/c_transforms.py +1 -1
  80. mindspore/dataset/transforms/transforms.py +2 -2
  81. mindspore/dataset/utils/__init__.py +2 -1
  82. mindspore/dataset/utils/line_reader.py +121 -0
  83. mindspore/dataset/vision/__init__.py +2 -3
  84. mindspore/dataset/vision/c_transforms.py +9 -9
  85. mindspore/dataset/vision/py_transforms.py +5 -5
  86. mindspore/dataset/vision/py_transforms_util.py +2 -0
  87. mindspore/dataset/vision/transforms.py +160 -161
  88. mindspore/dataset/vision/utils.py +3 -3
  89. mindspore/experimental/map_parameter.py +38 -26
  90. mindspore/include/OWNERS +0 -1
  91. mindspore/include/api/callback/callback.h +9 -13
  92. mindspore/include/api/callback/ckpt_saver.h +2 -2
  93. mindspore/include/api/callback/loss_monitor.h +2 -2
  94. mindspore/include/api/callback/lr_scheduler.h +5 -5
  95. mindspore/include/api/callback/time_monitor.h +2 -2
  96. mindspore/include/api/callback/train_accuracy.h +4 -6
  97. mindspore/include/api/cfg.h +19 -6
  98. mindspore/include/api/context.h +44 -9
  99. mindspore/include/api/delegate.h +1 -1
  100. mindspore/include/api/metrics/accuracy.h +2 -2
  101. mindspore/include/api/metrics/metrics.h +4 -3
  102. mindspore/include/api/model.h +9 -4
  103. mindspore/include/api/model_parallel_runner.h +2 -2
  104. mindspore/include/api/net.h +12 -11
  105. mindspore/include/api/serialization.h +19 -3
  106. mindspore/include/api/types.h +3 -3
  107. mindspore/include/dataset/constants.h +7 -0
  108. mindspore/include/dataset/text.h +59 -0
  109. mindspore/include/mindapi/base/type_id.h +1 -0
  110. mindspore/lib/libdnnl.so.2 +0 -0
  111. mindspore/lib/libicudata.so.69 +0 -0
  112. mindspore/lib/libicui18n.so.69 +0 -0
  113. mindspore/lib/libicuuc.so.69 +0 -0
  114. mindspore/lib/libmindspore.so +0 -0
  115. mindspore/lib/libmindspore_backend.so +0 -0
  116. mindspore/lib/libmindspore_common.so +0 -0
  117. mindspore/lib/libmindspore_core.so +0 -0
  118. mindspore/lib/libmindspore_glog.so.0 +0 -0
  119. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  120. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  121. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  122. mindspore/lib/libmindspore_shared_lib.so +0 -0
  123. mindspore/lib/libmpi_adapter.so +0 -0
  124. mindspore/lib/libmpi_collective.so +0 -0
  125. mindspore/lib/libnnacl.so +0 -0
  126. mindspore/lib/libopencv_core.so.4.5 +0 -0
  127. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  128. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  129. mindspore/lib/libps_cache.so +0 -0
  130. mindspore/lib/plugin/ascend/libakg.so +0 -0
  131. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  132. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  133. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  134. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  135. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  136. mindspore/lib/plugin/cpu/libakg.so +0 -0
  137. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  138. mindspore/lib/plugin/{libmindspore_ascend.so → libmindspore_ascend.so.2} +0 -0
  139. mindspore/log.py +1 -1
  140. mindspore/mindrecord/filereader.py +18 -0
  141. mindspore/mindrecord/filewriter.py +197 -34
  142. mindspore/mindrecord/shardreader.py +9 -0
  143. mindspore/mindrecord/shardwriter.py +1 -1
  144. mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
  145. mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
  146. mindspore/mindrecord/tools/csv_to_mr.py +3 -3
  147. mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
  148. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  149. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  150. mindspore/nn/__init__.py +0 -4
  151. mindspore/nn/cell.py +204 -132
  152. mindspore/nn/dynamic_lr.py +1 -1
  153. mindspore/nn/grad/cell_grad.py +7 -6
  154. mindspore/nn/layer/__init__.py +5 -4
  155. mindspore/nn/layer/activation.py +40 -89
  156. mindspore/nn/layer/basic.py +255 -624
  157. mindspore/nn/layer/channel_shuffle.py +7 -6
  158. mindspore/nn/layer/combined.py +1 -1
  159. mindspore/nn/layer/container.py +41 -4
  160. mindspore/nn/layer/conv.py +64 -28
  161. mindspore/nn/layer/dense.py +9 -8
  162. mindspore/nn/layer/embedding.py +27 -25
  163. mindspore/nn/layer/image.py +53 -46
  164. mindspore/nn/layer/math.py +97 -105
  165. mindspore/nn/layer/normalization.py +117 -86
  166. mindspore/nn/layer/padding.py +185 -95
  167. mindspore/nn/layer/pooling.py +817 -414
  168. mindspore/nn/layer/rnn_cells.py +10 -15
  169. mindspore/nn/layer/rnns.py +37 -38
  170. mindspore/nn/layer/thor_layer.py +11 -12
  171. mindspore/nn/layer/timedistributed.py +5 -5
  172. mindspore/nn/layer/transformer.py +701 -0
  173. mindspore/nn/learning_rate_schedule.py +8 -8
  174. mindspore/nn/loss/__init__.py +5 -4
  175. mindspore/nn/loss/loss.py +334 -199
  176. mindspore/nn/optim/ada_grad.py +6 -6
  177. mindspore/nn/optim/adadelta.py +2 -3
  178. mindspore/nn/optim/adafactor.py +4 -5
  179. mindspore/nn/optim/adam.py +126 -62
  180. mindspore/nn/optim/adamax.py +3 -4
  181. mindspore/nn/optim/adasum.py +6 -6
  182. mindspore/nn/optim/asgd.py +2 -2
  183. mindspore/nn/optim/ftrl.py +67 -38
  184. mindspore/nn/optim/lamb.py +4 -5
  185. mindspore/nn/optim/lars.py +2 -2
  186. mindspore/nn/optim/lazyadam.py +43 -4
  187. mindspore/nn/optim/momentum.py +6 -5
  188. mindspore/nn/optim/optimizer.py +3 -1
  189. mindspore/nn/optim/proximal_ada_grad.py +2 -2
  190. mindspore/nn/optim/rmsprop.py +1 -1
  191. mindspore/nn/optim/rprop.py +8 -9
  192. mindspore/nn/optim/sgd.py +19 -13
  193. mindspore/nn/optim/thor.py +10 -15
  194. mindspore/nn/probability/__init__.py +0 -2
  195. mindspore/nn/probability/bijector/bijector.py +4 -4
  196. mindspore/nn/probability/bijector/invert.py +1 -1
  197. mindspore/nn/probability/bijector/softplus.py +2 -2
  198. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  199. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  200. mindspore/nn/probability/distribution/_utils/utils.py +9 -15
  201. mindspore/nn/probability/distribution/bernoulli.py +3 -3
  202. mindspore/nn/probability/distribution/beta.py +1 -1
  203. mindspore/nn/probability/distribution/categorical.py +5 -7
  204. mindspore/nn/probability/distribution/cauchy.py +3 -3
  205. mindspore/nn/probability/distribution/distribution.py +2 -2
  206. mindspore/nn/probability/distribution/exponential.py +2 -2
  207. mindspore/nn/probability/distribution/gamma.py +3 -3
  208. mindspore/nn/probability/distribution/geometric.py +1 -1
  209. mindspore/nn/probability/distribution/gumbel.py +3 -3
  210. mindspore/nn/probability/distribution/half_normal.py +15 -11
  211. mindspore/nn/probability/distribution/laplace.py +16 -13
  212. mindspore/nn/probability/distribution/logistic.py +2 -2
  213. mindspore/nn/probability/distribution/normal.py +1 -1
  214. mindspore/nn/probability/distribution/poisson.py +1 -1
  215. mindspore/nn/probability/distribution/student_t.py +20 -15
  216. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  217. mindspore/nn/probability/distribution/uniform.py +2 -2
  218. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  219. mindspore/nn/reinforcement/tensor_array.py +2 -2
  220. mindspore/nn/sparse/sparse.py +2 -2
  221. mindspore/nn/wrap/cell_wrapper.py +27 -10
  222. mindspore/nn/wrap/grad_reducer.py +2 -2
  223. mindspore/nn/wrap/loss_scale.py +40 -24
  224. mindspore/numpy/array_creations.py +33 -22
  225. mindspore/numpy/array_ops.py +35 -30
  226. mindspore/numpy/logic_ops.py +6 -27
  227. mindspore/numpy/math_ops.py +22 -19
  228. mindspore/numpy/utils.py +1 -1
  229. mindspore/numpy/utils_const.py +108 -58
  230. mindspore/ops/_constants.py +0 -6
  231. mindspore/ops/_grad/__init__.py +2 -1
  232. mindspore/ops/_grad/grad_array_ops.py +86 -117
  233. mindspore/ops/_grad/grad_base.py +23 -1
  234. mindspore/ops/_grad/grad_clip_ops.py +2 -3
  235. mindspore/ops/_grad/grad_comm_ops.py +34 -24
  236. mindspore/ops/_grad/grad_implementations.py +9 -45
  237. mindspore/ops/_grad/grad_inner_ops.py +47 -4
  238. mindspore/ops/_grad/grad_math_ops.py +142 -117
  239. mindspore/ops/_grad/grad_nn_ops.py +71 -165
  240. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  241. mindspore/ops/_grad/grad_sparse.py +7 -6
  242. mindspore/ops/_grad_experimental/__init__.py +1 -0
  243. mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
  244. mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
  245. mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
  246. mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
  247. mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
  248. mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
  249. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  250. mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
  251. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  252. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
  253. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  254. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  255. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
  256. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  257. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  258. mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
  259. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  260. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
  261. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  262. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  263. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  264. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  265. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  266. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  267. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  268. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  269. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  270. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  271. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  272. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  273. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  274. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  275. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  276. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  277. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  278. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  279. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
  280. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  281. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  282. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  283. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  284. mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
  285. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  286. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
  287. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  288. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  289. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  290. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  291. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  292. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
  293. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  294. mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
  295. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  296. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  297. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
  298. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  299. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  300. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
  301. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  302. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  303. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  304. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  305. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  306. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  307. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  308. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  309. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  310. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
  311. mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
  312. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  313. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  314. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  315. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  316. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  317. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  318. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  319. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  320. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  321. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  322. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  323. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  324. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  325. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  326. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  327. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  328. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  329. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  331. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  332. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  333. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  334. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  335. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  336. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  337. mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
  338. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
  339. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  340. mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
  341. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  342. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  343. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  345. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  346. mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
  347. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
  348. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  349. mindspore/ops/_op_impl/tbe/__init__.py +27 -611
  350. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  351. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  352. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  353. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
  354. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  355. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  356. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  357. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  358. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  359. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  360. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  361. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  362. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  363. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
  364. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  365. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  366. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  367. mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
  368. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  369. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  370. mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
  371. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
  372. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  373. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  374. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  375. mindspore/ops/_register_for_op.py +1 -0
  376. mindspore/ops/_utils/__init__.py +1 -2
  377. mindspore/ops/_utils/utils.py +19 -40
  378. mindspore/ops/_vmap/vmap_array_ops.py +116 -38
  379. mindspore/ops/_vmap/vmap_base.py +16 -9
  380. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  381. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  382. mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
  383. mindspore/ops/_vmap/vmap_image_ops.py +12 -5
  384. mindspore/ops/_vmap/vmap_math_ops.py +46 -5
  385. mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
  386. mindspore/ops/_vmap/vmap_random_ops.py +1 -1
  387. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  388. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  389. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  390. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  391. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  392. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  393. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  394. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  395. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
  396. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  397. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  398. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  399. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  400. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  401. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  402. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  403. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  404. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  405. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  406. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
  407. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
  408. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  409. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  410. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  411. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  412. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  413. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  414. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
  415. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  416. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
  417. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
  418. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  419. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  420. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  421. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  422. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  423. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  424. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  425. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  426. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  427. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  428. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  429. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  430. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  431. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  432. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  433. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  434. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  435. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  436. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  437. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
  438. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  439. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  440. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  441. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  442. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  443. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  444. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  445. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
  446. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
  447. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  448. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
  449. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  450. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  451. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  452. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  453. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  454. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  455. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  456. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  457. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  458. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
  459. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
  460. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  461. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  462. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  463. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  464. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
  465. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  466. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  467. mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
  468. mindspore/ops/composite/__init__.py +7 -8
  469. mindspore/ops/composite/base.py +101 -47
  470. mindspore/ops/composite/math_ops.py +188 -158
  471. mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
  472. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
  473. mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
  474. mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
  475. mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
  476. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  477. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  478. mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
  479. mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
  480. mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
  481. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
  482. mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
  483. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  484. mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
  485. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
  486. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
  487. mindspore/ops/function/__init__.py +152 -8
  488. mindspore/ops/function/array_func.py +2555 -674
  489. mindspore/ops/function/clip_func.py +209 -13
  490. mindspore/ops/function/debug_func.py +2 -2
  491. mindspore/ops/function/grad/__init__.py +2 -1
  492. mindspore/ops/function/grad/grad_func.py +147 -62
  493. mindspore/ops/function/image_func.py +54 -38
  494. mindspore/ops/function/linalg_func.py +167 -16
  495. mindspore/ops/function/math_func.py +4849 -1492
  496. mindspore/ops/function/nn_func.py +2573 -988
  497. mindspore/ops/function/other_func.py +115 -0
  498. mindspore/ops/function/parameter_func.py +3 -3
  499. mindspore/ops/function/random_func.py +790 -73
  500. mindspore/ops/function/sparse_func.py +98 -78
  501. mindspore/ops/function/sparse_unary_func.py +54 -53
  502. mindspore/ops/function/spectral_func.py +27 -24
  503. mindspore/ops/function/vmap_func.py +22 -2
  504. mindspore/ops/functional.py +97 -37
  505. mindspore/ops/op_info_register.py +70 -28
  506. mindspore/ops/operations/__init__.py +47 -14
  507. mindspore/ops/operations/_csr_ops.py +7 -7
  508. mindspore/ops/operations/_embedding_cache_ops.py +5 -5
  509. mindspore/ops/operations/_grad_ops.py +276 -187
  510. mindspore/ops/operations/_inner_ops.py +319 -113
  511. mindspore/ops/operations/_ms_kernel.py +10 -8
  512. mindspore/ops/operations/_ocr_ops.py +9 -9
  513. mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
  514. mindspore/ops/operations/_quant_ops.py +137 -102
  515. mindspore/ops/operations/_rl_inner_ops.py +121 -60
  516. mindspore/ops/operations/_scalar_ops.py +466 -0
  517. mindspore/ops/operations/_sequence_ops.py +1004 -2
  518. mindspore/ops/operations/_tensor_array.py +10 -11
  519. mindspore/ops/operations/_thor_ops.py +1 -1
  520. mindspore/ops/operations/array_ops.py +801 -466
  521. mindspore/ops/operations/comm_ops.py +51 -49
  522. mindspore/ops/operations/control_ops.py +2 -2
  523. mindspore/ops/operations/custom_ops.py +123 -44
  524. mindspore/ops/operations/debug_ops.py +24 -24
  525. mindspore/ops/operations/image_ops.py +240 -153
  526. mindspore/ops/operations/inner_ops.py +34 -50
  527. mindspore/ops/operations/linalg_ops.py +31 -9
  528. mindspore/ops/operations/math_ops.py +988 -757
  529. mindspore/ops/operations/nn_ops.py +965 -819
  530. mindspore/ops/operations/other_ops.py +51 -40
  531. mindspore/ops/operations/random_ops.py +204 -122
  532. mindspore/ops/operations/rl_ops.py +8 -9
  533. mindspore/ops/operations/sparse_ops.py +254 -93
  534. mindspore/ops/operations/spectral_ops.py +35 -3
  535. mindspore/ops/primitive.py +111 -9
  536. mindspore/parallel/_auto_parallel_context.py +189 -83
  537. mindspore/parallel/_offload_context.py +185 -0
  538. mindspore/parallel/_parallel_serialization.py +99 -7
  539. mindspore/parallel/_ps_context.py +9 -5
  540. mindspore/parallel/_recovery_context.py +1 -1
  541. mindspore/parallel/_tensor.py +7 -1
  542. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  543. mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
  544. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  545. mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
  546. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  547. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
  548. mindspore/parallel/_utils.py +1 -2
  549. mindspore/parallel/algo_parameter_config.py +1 -1
  550. mindspore/parallel/checkpoint_transform.py +37 -34
  551. mindspore/parallel/shard.py +17 -18
  552. mindspore/profiler/common/validator/validate_path.py +2 -2
  553. mindspore/profiler/envprofiling.py +69 -47
  554. mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
  555. mindspore/profiler/parser/base_timeline_generator.py +49 -56
  556. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
  557. mindspore/profiler/parser/hwts_log_parser.py +1 -1
  558. mindspore/profiler/parser/integrator.py +15 -14
  559. mindspore/profiler/parser/minddata_analyzer.py +2 -2
  560. mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
  561. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  562. mindspore/profiler/parser/optime_parser.py +17 -18
  563. mindspore/profiler/parser/profiler_info.py +2 -1
  564. mindspore/profiler/profiling.py +218 -186
  565. mindspore/rewrite/__init__.py +3 -1
  566. mindspore/rewrite/api/node.py +1 -114
  567. mindspore/rewrite/api/node_type.py +3 -0
  568. mindspore/rewrite/api/pattern_engine.py +31 -1
  569. mindspore/rewrite/api/scoped_value.py +4 -4
  570. mindspore/rewrite/api/symbol_tree.py +3 -78
  571. mindspore/rewrite/api/tree_node_helper.py +1 -1
  572. mindspore/rewrite/ast_creator_register.py +1 -0
  573. mindspore/rewrite/ast_helpers/__init__.py +2 -2
  574. mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
  575. mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
  576. mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
  577. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
  578. mindspore/rewrite/namespace.py +0 -2
  579. mindspore/rewrite/node.py +157 -11
  580. mindspore/rewrite/parsers/assign_parser.py +231 -53
  581. mindspore/rewrite/parsers/class_def_parser.py +187 -109
  582. mindspore/rewrite/parsers/for_parser.py +24 -14
  583. mindspore/rewrite/parsers/function_def_parser.py +21 -4
  584. mindspore/rewrite/parsers/if_parser.py +6 -2
  585. mindspore/rewrite/sparsify/__init__.py +0 -0
  586. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  587. mindspore/rewrite/sparsify/sparsify.py +109 -0
  588. mindspore/rewrite/sparsify/utils.py +173 -0
  589. mindspore/rewrite/symbol_tree.py +256 -133
  590. mindspore/rewrite/symbol_tree_builder.py +38 -1
  591. mindspore/run_check/_check_version.py +69 -63
  592. mindspore/run_check/run_check.py +2 -1
  593. mindspore/scipy/linalg.py +10 -114
  594. mindspore/scipy/ops.py +2 -2
  595. mindspore/scipy/ops_wrapper.py +1 -1
  596. mindspore/scipy/optimize/_bfgs.py +1 -1
  597. mindspore/scipy/optimize/_lagrange.py +200 -0
  598. mindspore/scipy/optimize/line_search.py +3 -2
  599. mindspore/scipy/optimize/minimize.py +41 -2
  600. mindspore/scipy/sparse/__init__.py +2 -2
  601. mindspore/scipy/sparse/linalg.py +4 -464
  602. mindspore/scipy/utils.py +1 -1
  603. mindspore/scipy/utils_const.py +7 -1
  604. mindspore/train/__init__.py +1 -1
  605. mindspore/train/_utils.py +28 -5
  606. mindspore/train/amp.py +273 -102
  607. mindspore/train/callback/_backup_and_restore.py +5 -5
  608. mindspore/train/callback/_callback.py +2 -2
  609. mindspore/train/callback/_checkpoint.py +3 -3
  610. mindspore/train/callback/_early_stop.py +3 -3
  611. mindspore/train/callback/_lambda_callback.py +2 -2
  612. mindspore/train/callback/_landscape.py +29 -31
  613. mindspore/train/callback/_loss_monitor.py +3 -3
  614. mindspore/train/callback/_on_request_exit.py +3 -3
  615. mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
  616. mindspore/train/callback/_summary_collector.py +23 -16
  617. mindspore/train/callback/_time_monitor.py +3 -3
  618. mindspore/train/checkpoint_pb2.py +68 -8
  619. mindspore/train/data_sink.py +15 -3
  620. mindspore/train/dataset_helper.py +10 -15
  621. mindspore/train/loss_scale_manager.py +8 -11
  622. mindspore/train/metrics/__init__.py +1 -1
  623. mindspore/train/metrics/bleu_score.py +1 -1
  624. mindspore/train/metrics/confusion_matrix.py +1 -1
  625. mindspore/train/metrics/cosine_similarity.py +1 -1
  626. mindspore/train/metrics/dice.py +2 -2
  627. mindspore/train/metrics/fbeta.py +1 -1
  628. mindspore/train/metrics/hausdorff_distance.py +4 -3
  629. mindspore/train/metrics/mean_surface_distance.py +2 -2
  630. mindspore/train/metrics/occlusion_sensitivity.py +1 -1
  631. mindspore/train/metrics/perplexity.py +1 -1
  632. mindspore/train/metrics/precision.py +1 -1
  633. mindspore/train/metrics/recall.py +1 -1
  634. mindspore/train/metrics/roc.py +2 -2
  635. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  636. mindspore/train/mind_ir_pb2.py +116 -37
  637. mindspore/train/model.py +45 -28
  638. mindspore/train/serialization.py +295 -188
  639. mindspore/train/summary/_summary_adapter.py +1 -1
  640. mindspore/train/summary/summary_record.py +43 -13
  641. mindspore/train/train_thor/convert_utils.py +2 -2
  642. mindspore/train/train_thor/dataset_helper.py +3 -3
  643. mindspore/version.py +1 -1
  644. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
  645. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +648 -574
  646. mindspore/compression/__init__.py +0 -19
  647. mindspore/compression/common/constant.py +0 -124
  648. mindspore/compression/export/__init__.py +0 -19
  649. mindspore/compression/export/quant_export.py +0 -515
  650. mindspore/compression/quant/__init__.py +0 -28
  651. mindspore/compression/quant/qat.py +0 -634
  652. mindspore/compression/quant/quant_utils.py +0 -462
  653. mindspore/compression/quant/quantizer.py +0 -68
  654. mindspore/nn/layer/quant.py +0 -1868
  655. mindspore/nn/layer/rnn_utils.py +0 -90
  656. mindspore/nn/probability/dpn/__init__.py +0 -22
  657. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  658. mindspore/nn/probability/dpn/vae/cvae.py +0 -140
  659. mindspore/nn/probability/dpn/vae/vae.py +0 -124
  660. mindspore/nn/probability/infer/__init__.py +0 -22
  661. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  662. mindspore/nn/probability/infer/variational/svi.py +0 -84
  663. mindspore/nn/probability/toolbox/__init__.py +0 -22
  664. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  665. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
  666. mindspore/nn/probability/transforms/__init__.py +0 -22
  667. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  668. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  669. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  670. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  671. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  672. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  673. mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
  674. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  675. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
  676. mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
  677. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
  678. mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
  679. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
  680. mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
  681. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
  682. mindspore/ops/composite/array_ops.py +0 -241
  683. mindspore/ops/composite/clip_ops.py +0 -134
  684. mindspore/ops/composite/random_ops.py +0 -426
  685. mindspore/ops/composite/vmap_ops.py +0 -38
  686. mindspore/parallel/nn/__init__.py +0 -42
  687. mindspore/parallel/nn/loss.py +0 -22
  688. mindspore/parallel/nn/moe.py +0 -21
  689. mindspore/parallel/nn/op_parallel_config.py +0 -22
  690. mindspore/parallel/nn/transformer.py +0 -31
  691. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  692. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  693. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/rewrite/node.py CHANGED
@@ -20,7 +20,7 @@ import inspect
20
20
  from mindspore.nn import Cell
21
21
  from mindspore.ops import Primitive
22
22
  from mindspore import log as logger
23
- from .._checkparam import Validator, Rel
23
+ from .. import _checkparam as Validator
24
24
  from .ast_helpers import AstModifier
25
25
  from .api.scoped_value import ScopedValue, ValueType
26
26
  from .api.node_type import NodeType
@@ -222,6 +222,32 @@ class Node:
222
222
  return cls(NodeType.Output, ast_node, None, ScopedValue.create_naming_value("return"), real_return_values, {},
223
223
  name, None)
224
224
 
225
+ @classmethod
226
+ def create_mathops_node(cls, ast_node: ast.AST, targets: [ScopedValue],
227
+ op_type: ScopedValue, args: [ScopedValue],
228
+ ops: {str: list}, name: str = ""):
229
+ """
230
+ Class method of Node. Instantiate an instance of node whose type is `MathOps` .
231
+ A mathops node is used to represent a node with mathematical operations, such as
232
+ `y = a + b` , `y = not a` , `y = 0 < a < 1`, `y = a or b` , etc.
233
+
234
+ Args:
235
+ ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast. The type of
236
+ node is ast.Assign, and the type of ast_node.value is one of ast.BinOp, ast.UnaryOp, ast.BoolOp and
237
+ ast.Compare.
238
+ targets (list[ScopedValue]): Targets of mathematical operations. A list of instance of `ScopedValue`.
239
+ See detail in docstring of Node class.
240
+ op_type (ScopedValue): The type of ast_node.value saved by string. A ScopedValue with NamingValue type.
241
+ args (list[ScopedValue]): Values participating in the mathematical operations. All values are saved
242
+ sequentially in the list.
243
+ ops (dict[str:ScopedValue]): Operators participating in the mathematical operations. All operators are
244
+ saved sequentially in the dict, and keys are numbers in string format, such as {'0':'add', '1':'sub'}.
245
+ name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`.
246
+ Name of node also used as field name in network class. The format of mathops node name
247
+ is 'AstNodeName_AstOpName_n'.
248
+ """
249
+ return cls(NodeType.MathOps, ast_node, targets, op_type, args, ops, name, None)
250
+
225
251
  @staticmethod
226
252
  def create_call_op(op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
227
253
  func: Union[ScopedValue, str], args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None,
@@ -624,7 +650,8 @@ class Node:
624
650
  """
625
651
  self._targets = targets
626
652
  if self._node_type in (NodeType.CallCell, NodeType.CallMethod, NodeType.CallPrimitive,
627
- NodeType.Tree, NodeType.CallFunction):
653
+ NodeType.Tree, NodeType.CallFunction, NodeType.CellContainer,
654
+ NodeType.MathOps):
628
655
  self._sync_assign_targets_to_ast()
629
656
 
630
657
  def get_func(self) -> ScopedValue:
@@ -721,12 +748,12 @@ class Node:
721
748
  ValueError: If `node` has multi-outputs while `out_idx` is None or `out_idx` is not offered.
722
749
  """
723
750
  Validator.check_value_type("node", node, [Node], "Node")
724
- Validator.check_int_range(arg_idx, 0, self._args_num, Rel.INC_LEFT, "arg_idx")
751
+ Validator.check_int_range(arg_idx, 0, self._args_num, Validator.INC_LEFT, "arg_idx")
725
752
  if out_idx is None:
726
753
  if len(node._targets) != 1:
727
754
  raise RuntimeError("node should has one output when out_idx is not provided")
728
755
  out_idx = 0
729
- Validator.check_int_range(out_idx, 0, len(node._targets), Rel.INC_LEFT, "arg_idx")
756
+ Validator.check_int_range(out_idx, 0, len(node._targets), Validator.INC_LEFT, "arg_idx")
730
757
  new_arg = node._targets[out_idx]
731
758
  self._normalized_args[self._normalized_args_keys[arg_idx]] = new_arg
732
759
  self._sync_arg()
@@ -743,7 +770,7 @@ class Node:
743
770
  Raises:
744
771
  ValueError: If `index` is out of range.
745
772
  """
746
- Validator.check_int_range(index, 0, self._args_num, Rel.INC_LEFT, "index")
773
+ Validator.check_int_range(index, 0, self._args_num, Validator.INC_LEFT, "index")
747
774
  Validator.check_value_type("arg", arg, [ScopedValue, str], "Node")
748
775
  if isinstance(arg, str):
749
776
  arg = ScopedValue.create_naming_value(arg)
@@ -763,7 +790,7 @@ class Node:
763
790
  Raises:
764
791
  TypeError: Element of new argument is not an instance of ScopedValue.
765
792
  """
766
- Validator.check_int_range(len(args), 0, self._args_num, Rel.INC_LEFT, "Length of args")
793
+ Validator.check_int_range(len(args), 0, self._args_num, Validator.INC_LEFT, "Length of args")
767
794
  Validator.check_element_type_of_iterable("args", args, [ScopedValue], "Node")
768
795
  for arg_index, arg in enumerate(args):
769
796
  if not isinstance(arg, ScopedValue):
@@ -783,7 +810,7 @@ class Node:
783
810
  TypeError: Value of new argument is not an instance of ScopedValue.
784
811
  RuntimeError: Length of new arguments is not equal to length of old arguments.
785
812
  """
786
- Validator.check_int_range(len(kwargs), 0, self._kwargs_num, Rel.INC_LEFT, "Length of kwargs")
813
+ Validator.check_int_range(len(kwargs), 0, self._kwargs_num, Validator.INC_LEFT, "Length of kwargs")
787
814
  Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "Node")
788
815
  for key, arg in kwargs.items():
789
816
  if key not in self._normalized_args.keys() or key not in self._normalized_args_keys:
@@ -1099,7 +1126,7 @@ class Node:
1099
1126
  elt.id = scoped_value.value
1100
1127
  elif isinstance(elt, ast.Attribute) and isinstance(elt.value, ast.Name):
1101
1128
  elt.value.id = scoped_value.scope
1102
- elt.value = scoped_value.value
1129
+ elt.attr = scoped_value.value
1103
1130
  else:
1104
1131
  raise RuntimeError("Only support constant or symbol in tuple now")
1105
1132
  else:
@@ -1133,14 +1160,50 @@ class Node:
1133
1160
  raise RuntimeError("Unsupported return value type: ", return_value_ast)
1134
1161
  ast.fix_missing_locations(return_ast)
1135
1162
 
1163
+ def _sync_mathops_node_args_to_ast(self):
1164
+ """
1165
+ Sync values from self._normalized_args to the ast node for mathematical operations.
1166
+ """
1167
+ if self._ast_node is None:
1168
+ return
1169
+ if not isinstance(self._ast_node, ast.Assign):
1170
+ raise TypeError(f"type of node should be ast.Assign, but got {type(self._ast_node)}")
1171
+ mathops_node = self._ast_node.value
1172
+ if isinstance(mathops_node, ast.BinOp):
1173
+ left = mathops_node.left
1174
+ right = mathops_node.right
1175
+ AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[0]), left)
1176
+ AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[1]), right)
1177
+ elif isinstance(mathops_node, ast.UnaryOp):
1178
+ operand = mathops_node.operand
1179
+ AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[0]), operand)
1180
+ elif isinstance(mathops_node, ast.BoolOp):
1181
+ values = mathops_node.values
1182
+ for arg_index in range(self._args_num):
1183
+ arg_value = self._normalized_args.get(self._normalized_args_keys[arg_index])
1184
+ AstModifier.update_arg_value(arg_value, values[arg_index])
1185
+ elif isinstance(mathops_node, ast.Compare):
1186
+ left = mathops_node.left
1187
+ AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[0]), left)
1188
+ comparators = mathops_node.comparators
1189
+ for arg_index in range(1, self._args_num):
1190
+ arg_value = self._normalized_args.get(self._normalized_args_keys[arg_index])
1191
+ AstModifier.update_arg_value(arg_value, comparators[arg_index - 1])
1192
+ else:
1193
+ raise TypeError("The type of 'mathops_node' must be one of (ast.BinOp, ast.UnaryOp, "
1194
+ "ast.BoolOp, ast.Compare), but got ", type(mathops_node))
1195
+
1136
1196
  def _sync_arg(self):
1137
1197
  """Sync _normalized_args to corresponding ast node when updated."""
1138
- if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree):
1198
+ if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree,\
1199
+ NodeType.CellContainer, NodeType.CallFunction):
1139
1200
  self._sync_call_cell_args_to_ast()
1140
1201
  elif self._node_type == NodeType.Output:
1141
1202
  self._sync_return_node_to_ast()
1142
1203
  elif self._node_type == NodeType.CallMethod:
1143
1204
  self._sync_call_method_args_to_ast()
1205
+ elif self._node_type == NodeType.MathOps:
1206
+ self._sync_mathops_node_args_to_ast()
1144
1207
 
1145
1208
 
1146
1209
  class TreeNode(Node):
@@ -1188,8 +1251,6 @@ class TreeNode(Node):
1188
1251
  instance: Object in network corresponding to this node.
1189
1252
  """
1190
1253
 
1191
- if not isinstance(instance, Cell):
1192
- raise ValueError("Argument instance should be a Cell: ", type(instance))
1193
1254
  non_custom_args = Node._handle_custom_obj_in_args(args)
1194
1255
  non_custom_kwargs = Node._handle_custom_obj_in_kwargs(kwargs)
1195
1256
  new_targets = Node._handle_targets(targets)
@@ -1198,3 +1259,88 @@ class TreeNode(Node):
1198
1259
  if ast_node is None:
1199
1260
  ast_node = AstModifier.create_call_assign(new_targets, func, non_custom_args, non_custom_kwargs)
1200
1261
  return cls(tree, ast_node, new_targets, func, args, kwargs, name, instance)
1262
+
1263
+
1264
+ class CellContainer(Node):
1265
+ """ Container for saving cell-objects node. """
1266
+ class _Visitor():
1267
+ """ A iterator of CellContainer nodes. """
1268
+ def __init__(self, cellcontainer):
1269
+ self._cellcontainer = cellcontainer
1270
+
1271
+ def __len__(self):
1272
+ """ Get the number of nodes. """
1273
+ return self._cellcontainer.node_count
1274
+
1275
+ def __iter__(self):
1276
+ """Create an iterator over the CellContainer."""
1277
+ count = len(self._cellcontainer.node_list)
1278
+ i = 0
1279
+ while i < count:
1280
+ curr = self._cellcontainer.node_list[i]
1281
+ if curr.valid:
1282
+ yield curr
1283
+ i += 1
1284
+
1285
+ def __init__(self, ast_node: ast.AST, targets: [ScopedValue], func: ScopedValue,
1286
+ args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance):
1287
+ """Constructor of CellContainer.
1288
+
1289
+ Args:
1290
+ ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
1291
+ targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
1292
+ func ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
1293
+ args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
1294
+ kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
1295
+ name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
1296
+ Name of node also used as field name in network class.
1297
+ instance: Object in network corresponding to this node.
1298
+ """
1299
+ if isinstance(func, str):
1300
+ func = ScopedValue.create_naming_value(func)
1301
+ super().__init__(NodeType.CellContainer, ast_node, targets, func, args, kwargs, name, instance)
1302
+ self._node_list = list()
1303
+ self._node_count = 0
1304
+
1305
+ @property
1306
+ def node_count(self):
1307
+ """Number of nodes."""
1308
+ return len(self._node_list)
1309
+
1310
+ @property
1311
+ def node_list(self):
1312
+ """ Get node list. """
1313
+ return self._node_list
1314
+
1315
+ def append(self, node):
1316
+ """ Append new node to node list. """
1317
+ setattr(node, "container", self)
1318
+ setattr(node, "valid", True)
1319
+ node.set_belong_symbol_tree(self.get_belong_symbol_tree())
1320
+ self._node_list.append(node)
1321
+ # when creating a cell_container, node instance is already in SequentialCell cell_list
1322
+ # so here we need to write a if judgement
1323
+ if node.get_instance() not in self.get_instance().cell_list:
1324
+ self.get_instance().append(node.get_instance())
1325
+
1326
+ def erase(self, node):
1327
+ """Erase node form container."""
1328
+ index_node = self.node_list.index(node)
1329
+ index_instance = self.get_instance().cell_list.index(node.get_instance())
1330
+ if index_node != index_instance:
1331
+ raise RuntimeError("In MindSpore Rewrite CellContainer, erasing a node raises index error!!!")
1332
+ setattr(node, "valid", False)
1333
+ del self.get_instance()[index_node]
1334
+ del self._node_list[index_node]
1335
+
1336
+ def insert(self, index, node):
1337
+ """Insert node into container"""
1338
+ self.node_list.insert(index, node)
1339
+ setattr(node, "container", self)
1340
+ setattr(node, "valid", True)
1341
+ node.set_belong_symbol_tree(self.get_belong_symbol_tree())
1342
+ self.get_instance()._insert(index, node.get_instance())
1343
+
1344
+ def nodes(self):
1345
+ """ Return a iterator of node."""
1346
+ return self._Visitor(self)
@@ -15,21 +15,23 @@
15
15
  """Parse ast.Assign in construct function to node of SymbolTree."""
16
16
  from typing import Union
17
17
  import ast
18
+ import sys
19
+ import inspect
18
20
  import astunparse
19
21
 
20
22
  from mindspore import log as logger
21
23
  from mindspore._extends.parse.namespace import CellNamespace
22
- from mindspore.nn import Cell
24
+ from mindspore.nn import Cell, SequentialCell
23
25
  from mindspore.ops import operations as P
24
26
  from mindspore.ops import Primitive
25
27
  from mindspore.rewrite.parser_register import ParserRegister
26
28
  from mindspore.rewrite.namespace import is_subtree, is_functional, get_functional
27
29
  from mindspore.rewrite.symbol_tree import SymbolTree
28
- from mindspore.rewrite.node import Node, TreeNode
30
+ from mindspore.rewrite.node import Node, TreeNode, CellContainer
29
31
  from mindspore.rewrite.parser import Parser
30
32
  from mindspore.rewrite.parser_register import reg_parser
31
33
  from mindspore.rewrite.api.scoped_value import ScopedValue, ValueType
32
- from mindspore.rewrite.symbol_tree_builder import SymbolTreeBuilder
34
+ from mindspore.rewrite.symbol_tree_builder import SymbolTreeBuilder, FunctionSymbolTreeBuilder
33
35
  from mindspore.rewrite.ast_helpers import AstReplacer, AstModifier
34
36
  from mindspore.rewrite.common.event import Event
35
37
  from ..common import error_str
@@ -65,7 +67,7 @@ class AssignParser(Parser):
65
67
  tuple_elts = node.elts
66
68
  tuple_values = []
67
69
  for tuple_elt in tuple_elts:
68
- if not isinstance(tuple_elt, (ast.Constant, ast.Name)):
70
+ if not isinstance(tuple_elt, (ast.Constant, ast.Name, ast.Attribute)):
69
71
  raise RuntimeError(f"Only support ast.Constant or ast.Name as elts of ast.Tuple, "
70
72
  f"but got ast type {type(tuple_elt).__name__}",
71
73
  child_node=tuple_elt, father_node=node)
@@ -73,6 +75,8 @@ class AssignParser(Parser):
73
75
  tuple_values.append(tuple_elt.value)
74
76
  elif isinstance(tuple_elt, ast.Name):
75
77
  tuple_values.append(tuple_elt.id)
78
+ elif isinstance(tuple_elt, ast.Attribute):
79
+ tuple_values.append("".join([tuple_elt.value.id, '.', tuple_elt.attr]))
76
80
  return ScopedValue.create_variable_value(tuple(tuple_values))
77
81
 
78
82
  @staticmethod
@@ -281,15 +285,15 @@ class AssignParser(Parser):
281
285
  if len(body.targets) > 1:
282
286
  raise NotImplementedError(error_str("not support multi-targets in assign now!", father_node=body))
283
287
  target = body.targets[0]
284
- if not isinstance(target, ast.Attribute) or not (target.value, ast.Name) or target.value.id != "self":
288
+ if not isinstance(target, ast.Attribute) or not isinstance(target.value, ast.Name):
285
289
  continue
286
- if target.attr != func_name:
290
+ if target.value.id != "self" or target.attr != func_name:
287
291
  continue
288
292
  changed = True
289
- global_vars_key = "_".join([func_name, "args"])
290
- stree.add_global_vars(global_vars_key, sub_tree.get_global_vars())
291
- args_call = AstModifier.create_call(ScopedValue.create_naming_value("get", "global_vars"),
292
- [ScopedValue.create_variable_value(global_vars_key)])
293
+ setattr(stree.get_origin_network(), func_name, sub_tree.get_origin_network())
294
+ args_call = AstModifier.create_call(ScopedValue(ValueType.NamingValue, "", "getattr"),
295
+ [ScopedValue(ValueType.NamingValue, "", "obj"),
296
+ ScopedValue(ValueType.StringValue, "", func_name)])
293
297
  body.value = ast.Call(func=ast.Name(class_name, ast.Store()), args=[args_call], keywords=[])
294
298
  break
295
299
  return changed
@@ -308,6 +312,91 @@ class AssignParser(Parser):
308
312
  call_args = [AssignParser._create_scopedvalue(arg) for arg in father_ast_node.value.args]
309
313
  return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, {})
310
314
 
315
+ @staticmethod
316
+ def _create_inputs_for_cell_container(father_ast_node) -> ['Node']:
317
+ """Create inputs for cell container first node."""
318
+ call_ast_node = father_ast_node.value
319
+ if not isinstance(call_ast_node, ast.Call):
320
+ raise RuntimeError(error_str(f"when creating input node for cellcontainer, value of input father ast node"
321
+ "is not ast.Call!'", child_node=call_ast_node, father_node=father_ast_node))
322
+ first_node_inputs: ['Node'] = []
323
+ exist_param_name = []
324
+ for arg in call_ast_node.args:
325
+ if isinstance(arg, ast.Name):
326
+ param_name = arg.id
327
+ elif isinstance(arg, ast.arg):
328
+ param_name = arg.arg
329
+ else:
330
+ raise RuntimeError(error_str(f"only support ast.arg, ast.arg in arguments arg, but got "
331
+ f"'{type(arg).__name__}'", child_node=arg, father_node=call_ast_node))
332
+ if param_name in exist_param_name:
333
+ raise RuntimeError(error_str(f"Cellcontianer has duplicate input names", child_node=arg,
334
+ father_node=call_ast_node))
335
+ exist_param_name.append(param_name)
336
+ node = Node.create_input_node(arg, param_name, name=f"input_{param_name}")
337
+ first_node_inputs.append(node)
338
+
339
+ if call_ast_node.keywords:
340
+ raise RuntimeError(error_str(f"Not support keyword input for cellcontainer now.",
341
+ child_node=call_ast_node, father_node=father_ast_node))
342
+
343
+ return first_node_inputs
344
+
345
+ def _cell_container_process(self, ast_node, stree, targets, func, call_args, call_kwargs, op_name, container_obj):
346
+ """ parse cell container object."""
347
+ cell_container = CellContainer(ast_node, targets, func, call_args, call_kwargs, op_name, container_obj)
348
+ cell_container.set_belong_symbol_tree(stree)
349
+ first_node_inputs = AssignParser._create_inputs_for_cell_container(ast_node)
350
+ for i, cell in enumerate(container_obj):
351
+ is_sub_tree = is_subtree(type(cell).__name__)
352
+ if is_sub_tree:
353
+ stb = SymbolTreeBuilder(cell)
354
+ new_stree = stb.build()
355
+ replacer = AstReplacer(new_stree.get_class_ast())
356
+ replacer.replace_all(new_stree.get_ori_cls_name(), new_stree.get_opt_cls_name())
357
+ sub_node = TreeNode.create_tree_node(new_stree, ast_node, targets, func, call_args, call_kwargs,
358
+ type(cell).__name__, cell)
359
+ else:
360
+ sub_node = Node.create_call_buildin_op(cell, ast_node, targets, func, call_args, call_kwargs,
361
+ type(cell).__name__)
362
+ # add sub node to cell_container
363
+ cell_container.append(sub_node)
364
+ # set node inputs
365
+ if i == 0:
366
+ sub_node.set_inputs(first_node_inputs)
367
+ else:
368
+ sub_node.set_inputs([cell_container.node_list[i-1]])
369
+ return cell_container
370
+
371
+ def _process_external_function(self, stree, func_name):
372
+ """Process external function."""
373
+ for k, m in sys.modules.items():
374
+ if k in ("_ast", "ast"):
375
+ continue
376
+ if hasattr(m, func_name):
377
+ func = getattr(m, func_name)
378
+ source_code = inspect.getsource(func)
379
+ ast_root: ast.Module = ast.parse(source_code)
380
+ stree._external_func_ast.append(ast_root.body[0]) # pylint: disable=protected-access
381
+ return func, ast_root.body[0]
382
+ return None, None
383
+
384
+ def _process_internal_function(self, stree: SymbolTree, func_name):
385
+ """Process internal function."""
386
+ func = getattr(stree._origin_network, func_name) # pylint: disable=protected-access
387
+ ast_node = None
388
+ for body in stree._class_ast.body: # pylint: disable=protected-access
389
+ if isinstance(body, ast.FunctionDef) and func_name == body.name:
390
+ ast_node = body
391
+ return func, ast_node
392
+
393
+ def _create_func_subtree(self, op, targets, father_ast_node, ast_node, call_args, call_kwargs, func_name):
394
+ """Create subtree of function."""
395
+ stb = FunctionSymbolTreeBuilder(op, ast_node)
396
+ new_stree = stb.build()
397
+ return TreeNode.create_tree_node(new_stree, father_ast_node, targets, func_name, call_args, call_kwargs,
398
+ func_name, op)
399
+
311
400
  def _convert_ast_call_to_node(self, ast_node: ast.Call, father_ast_node: ast.Assign, stree: SymbolTree) -> Node:
312
401
  """
313
402
  Convert ast.Call to a symbol tree node.
@@ -340,9 +429,19 @@ class AssignParser(Parser):
340
429
  func = get_functional(func_name.split(".")[-1])
341
430
  node = stree.inner_create_call_function(func_name, father_ast_node, func_name, func, targets,
342
431
  call_args, call_kwargs)
343
- return node
344
- raise RuntimeError(error_str(f"operator instance undefined.",
345
- child_node=ast_node.func, father_node=ast_node))
432
+ elif hasattr(stree._origin_network, func_name): # pylint: disable=protected-access
433
+ func, ast_node = self._process_internal_function(stree, func_name)
434
+ node = self._create_func_subtree(func, targets, father_ast_node, ast_node, call_args, call_kwargs,
435
+ func_name)
436
+ else:
437
+ func, ast_node = self._process_external_function(stree, func_name)
438
+ node = self._create_func_subtree(func, targets, father_ast_node, ast_node, call_args, call_kwargs,
439
+ func_name)
440
+ return node
441
+ if isinstance(op, SequentialCell):
442
+ node = self._cell_container_process(father_ast_node, stree, targets, func, call_args, call_kwargs,
443
+ func_name, op)
444
+ return node
346
445
  if isinstance(op, Primitive):
347
446
  return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, call_kwargs, func_name)
348
447
  if isinstance(op, Cell):
@@ -394,6 +493,74 @@ class AssignParser(Parser):
394
493
  raise RuntimeError("For MindSpore Rewrite, only support Primitive or Cell operator or Primitive operator, got ",
395
494
  type(op).__name__)
396
495
 
496
@staticmethod
def _tuple_elts_support_scopledvalue(value: ast.Tuple) -> bool:
    """
    Check whether every element of a tuple can be converted into a scoped value.

    Supported elements are names, attributes, constants (ast.Constant as well as the legacy
    ast.Num / ast.Str / ast.Bytes / ast.NameConstant nodes) and nested tuples. The elements of a
    nested tuple must themselves be supported.

    Args:
        value (ast.Tuple): The tuple expression to check.

    Returns:
        bool, True if all elements (including those of nested tuples) are supported, otherwise False.

    Raises:
        RuntimeError: If `value` is not an ast.Tuple.
    """
    if not isinstance(value, ast.Tuple):
        raise RuntimeError("For AssignParser._tuple_elts_support_scopledvalue(), the type of value should be "
                           f"Tuple, but got {type(value).__name__}")

    # Before Python 3.8, None/True/False parse as ast.NameConstant, which is not an ast.Constant there.
    leaf_types = (ast.Name, ast.Attribute, ast.Constant, ast.Num, ast.Str, ast.Bytes, ast.NameConstant)
    pending = list(value.elts)
    while pending:
        elt = pending.pop()
        if isinstance(elt, ast.Tuple):
            # A nested tuple is only supported when its own elements are supported too.
            pending.extend(elt.elts)
        elif not isinstance(elt, leaf_types):
            return False
    return True
507
+
508
@staticmethod
def _convert_ast_mathops_to_node(ast_node: Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare],
                                 father_ast_node: ast.Assign) -> Node:
    """
    Convert ast node of math operations(ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare) to
    a symbol tree node.

    The node name is the lower-cased ast class name followed by the operator class name(s), joined by
    underscores, e.g. ``binop_add`` for ``a + b`` or ``compare_lt_lte`` for ``a < b <= c``.

    Args:
        ast_node (Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare]): The math-operation expression
            assigned in the construct function.
        father_ast_node (ast.Assign): Assign node in construct that contains `ast_node`.

    Returns:
        An instance of Node in Symbol Tree.

    Raises:
        TypeError: The type of parameter 'ast_node' is not in (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare).
    """
    if not isinstance(ast_node, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)):
        # Build a single message string; passing the type as a second argument made str(err) a tuple repr.
        raise TypeError("The type of parameter 'ast_node' must be one of (ast.BinOp, ast.UnaryOp, "
                        f"ast.BoolOp, ast.Compare), but got {type(ast_node).__name__}")

    targets = AssignParser._get_targets(AssignParser._create_scopedvalue(father_ast_node.targets[0]))
    op_type_str = type(ast_node).__name__
    op_type = ScopedValue.create_naming_value(op_type_str)
    # Maps the operator position (as a string) to a naming value of the operator class name.
    ops = {}
    args = []
    name_parts = [op_type_str]
    if isinstance(ast_node, ast.Compare):
        # A chained comparison such as `a < b <= c` has one operator per comparator.
        args.append(AssignParser._create_scopedvalue(ast_node.left))
        for idx, ast_op in enumerate(ast_node.ops):
            op = type(ast_op).__name__
            name_parts.append(op)
            ops[str(idx)] = ScopedValue.create_naming_value(op)
            args.append(AssignParser._create_scopedvalue(ast_node.comparators[idx]))
    else:
        # BinOp, UnaryOp and BoolOp each carry exactly one operator in `.op`.
        op = type(ast_node.op).__name__
        name_parts.append(op)
        ops['0'] = ScopedValue.create_naming_value(op)
        if isinstance(ast_node, ast.BinOp):
            operands = [ast_node.left, ast_node.right]
        elif isinstance(ast_node, ast.UnaryOp):
            operands = [ast_node.operand]
        else:
            operands = ast_node.values
        for operand in operands:
            args.append(AssignParser._create_scopedvalue(operand))
    name = '_'.join(name_parts).lower()
    return Node.create_mathops_node(father_ast_node, targets, op_type, args, ops, name)
563
+
397
564
  def process(self, stree: SymbolTree, node: ast.Assign):
398
565
  """
399
566
  Parse ast.Assign and create a node in symbol tree.
@@ -413,52 +580,63 @@ class AssignParser(Parser):
413
580
  """
414
581
 
415
582
  targets = node.targets
416
- if len(targets) != 1:
417
- raise RuntimeError(
418
- error_str(f"only support one target in assign now.", child_node=targets, father_node=node))
419
- value = node.value
420
- if isinstance(value, ast.Call):
421
- node_ = self._convert_ast_call_to_node(value, node, stree)
422
- stree.append_origin_field(node_)
423
- elif isinstance(value, ast.BinOp):
424
- if isinstance(value.op, ast.Add):
425
- node_ = AssignParser._convert_ast_binop_to_node(value, node)
583
+ try:
584
+ if len(targets) != 1:
585
+ raise RuntimeError(
586
+ error_str(f"only support one target in assign now.", child_node=targets, father_node=node))
587
+ value = node.value
588
+ if isinstance(value, ast.Call):
589
+ stree.update_scope_for_unique(value)
590
+ node_ = self._convert_ast_call_to_node(value, node, stree)
426
591
  stree.append_origin_field(node_)
427
- else:
592
+ elif isinstance(value, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)):
593
+ node_ = AssignParser._convert_ast_mathops_to_node(value, node)
594
+ stree.append_origin_field(node_)
595
+ elif isinstance(value, ast.Subscript):
428
596
  logger.info(f"ops-call({astunparse.unparse(node)}) in assign will be supported in near feature, "
429
597
  f"ignored as a python node now")
430
598
  stree.try_append_python_node(node, node)
431
- elif isinstance(value, (ast.BoolOp, ast.Subscript)):
432
- logger.info(f"ops-call({astunparse.unparse(node)}) in assign will be supported in near feature, "
433
- f"ignored as a python node now")
434
- stree.try_append_python_node(node, node)
435
- elif isinstance(value, (ast.Name, ast.Constant, ast.Attribute, ast.Num, ast.NameConstant, ast.Bytes, ast.Str)):
436
- if isinstance(value, ast.Name):
437
- node_name = "name_assign"
438
- elif isinstance(value, ast.Constant):
439
- node_name = "constant_assign"
599
+ elif isinstance(value, (ast.Name, ast.Constant, ast.Attribute, ast.Num, ast.NameConstant,
600
+ ast.Bytes, ast.Str)):
601
+ if isinstance(value, ast.Name):
602
+ node_name = "name_assign"
603
+ elif isinstance(value, ast.Constant):
604
+ node_name = "constant_assign"
605
+ elif isinstance(value, ast.Attribute):
606
+ node_name = "attribute_assign"
607
+ stree.update_scope_for_unique(value)
608
+ else:
609
+ node_name = "other_assign"
610
+ targets = AssignParser._get_targets(AssignParser._create_scopedvalue(node.targets[0]))
611
+ call_args = [AssignParser._create_scopedvalue(value)]
612
+ node_ = Node.create_call_pass_through_method(node, targets, call_args, {}, node_name)
613
+ stree.append_origin_field(node_)
614
+ elif isinstance(value, ast.Tuple):
615
+ if AssignParser._tuple_elts_support_scopledvalue(value):
616
+ # ensure that each element's type in tuple is supported by scopled value
617
+ targets = AssignParser._get_targets(AssignParser._create_scopedvalue(node.targets[0]))
618
+ args = []
619
+ for elt in value.elts:
620
+ args.append(AssignParser._create_scopedvalue(elt))
621
+ node_ = Node.create_call_method(node, targets, ScopedValue.create_naming_value("tuple"),
622
+ args, {}, "tuple")
623
+ stree.append_origin_field(node_)
624
+ else:
625
+ logger.warning(f"some elements in Tuple of assign({astunparse.unparse(node)}) are not supported "
626
+ "in rewrite, fallback to python")
627
+ stree.try_append_python_node(node, node)
628
+ elif isinstance(value, (ast.List, ast.Dict)):
629
+ # add these as callmethod node if necessary
630
+ stree.try_append_python_node(node, node)
440
631
  else:
441
- node_name = "attribute_assign"
442
- targets = AssignParser._get_targets(AssignParser._create_scopedvalue(node.targets[0]))
443
- call_args = [AssignParser._create_scopedvalue(value)]
444
- node_ = Node.create_call_pass_through_method(node, targets, call_args, {}, node_name)
445
- stree.append_origin_field(node_)
446
- elif isinstance(value, ast.Tuple):
447
- targets = AssignParser._get_targets(AssignParser._create_scopedvalue(node.targets[0]))
448
- args = []
449
- for elt in value.elts:
450
- args.append(AssignParser._create_scopedvalue(elt))
451
- node_ = Node.create_call_method(node, targets, ScopedValue.create_naming_value("tuple"), args, {}, "tuple")
452
- stree.append_origin_field(node_)
453
- elif isinstance(value, (ast.List, ast.Dict)):
454
- # add these as callmethod node if necessary
632
+ raise RuntimeError(
633
+ error_str(f"only support (ast.Call, ast.BinOp, ast.BoolOp, ast.Subscript, ast.Name, ast.Constant, "
634
+ f"ast.Attribute, ast.Num, ast.NameConstant, ast.Bytes, ast.Str, ast.Tuple, ast.List, "
635
+ f"ast.Dict) as value of ast.assign, but got ast type '{type(value).__name__}'",
636
+ child_node=value, father_node=node))
637
+ except RuntimeError:
638
+ logger.info(f"ops-call({astunparse.unparse(node)}) not supported in rewrite, fallback to python")
455
639
  stree.try_append_python_node(node, node)
456
- else:
457
- raise RuntimeError(
458
- error_str(f"only support (ast.Call, ast.BinOp, ast.BoolOp, ast.Subscript, ast.Name, ast.Constant, "
459
- f"ast.Attribute, ast.Num, ast.NameConstant, ast.Bytes, ast.Str, ast.Tuple, ast.List, ast.Dict"
460
- f") as value of ast.assign, but got ast type '{type(value).__name__}'", child_node=value,
461
- father_node=node))
462
640
 
463
641
 
464
642
  # Create one AssignParser at import time and pass it to reg_parser; presumably this registers it as the
  # parser dispatched for ast.Assign nodes (TODO: confirm against reg_parser).
  g_assign_parser = reg_parser(AssignParser())