mindspore 2.0.0a0__cp38-cp38-win_amd64.whl → 2.0.0rc1__cp38-cp38-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the package's registry page for more details.

Files changed (655)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -2
  3. mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +102 -0
  7. mindspore/_checkparam.py +1066 -1001
  8. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
  9. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
  10. mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
  11. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
  12. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
  13. mindspore/_extends/parse/__init__.py +5 -3
  14. mindspore/_extends/parse/namespace.py +16 -1
  15. mindspore/_extends/parse/parser.py +107 -22
  16. mindspore/_extends/parse/resources.py +0 -7
  17. mindspore/_extends/parse/standard_method.py +885 -413
  18. mindspore/amp.py +52 -57
  19. mindspore/boost/boost.py +2 -2
  20. mindspore/boost/boost_cell_wrapper.py +38 -20
  21. mindspore/boost/dim_reduce.py +3 -3
  22. mindspore/boost/group_loss_scale_manager.py +1 -1
  23. mindspore/common/__init__.py +4 -6
  24. mindspore/common/_decorator.py +2 -0
  25. mindspore/common/_register_for_adapter.py +55 -0
  26. mindspore/common/_stub_tensor.py +201 -0
  27. mindspore/common/_utils.py +41 -7
  28. mindspore/common/api.py +215 -141
  29. mindspore/common/dtype.py +8 -1
  30. mindspore/common/dump.py +2 -2
  31. mindspore/common/initializer.py +4 -2
  32. mindspore/common/jit_config.py +17 -13
  33. mindspore/common/mutable.py +33 -13
  34. mindspore/common/parameter.py +23 -21
  35. mindspore/common/seed.py +8 -24
  36. mindspore/common/sparse_tensor.py +62 -41
  37. mindspore/common/tensor.py +852 -1154
  38. mindspore/communication/__init__.py +2 -2
  39. mindspore/communication/_comm_helper.py +11 -4
  40. mindspore/communication/management.py +22 -21
  41. mindspore/config/op_info.config +501 -1008
  42. mindspore/context.py +201 -23
  43. mindspore/dataset/__init__.py +6 -6
  44. mindspore/dataset/audio/__init__.py +7 -7
  45. mindspore/dataset/audio/transforms.py +670 -30
  46. mindspore/dataset/audio/utils.py +47 -4
  47. mindspore/dataset/audio/validators.py +223 -1
  48. mindspore/dataset/callback/ds_callback.py +2 -2
  49. mindspore/dataset/core/config.py +210 -14
  50. mindspore/dataset/core/validator_helpers.py +2 -2
  51. mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
  52. mindspore/dataset/debug/debug_hook.py +65 -0
  53. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  54. mindspore/dataset/engine/__init__.py +7 -3
  55. mindspore/dataset/engine/cache_client.py +1 -1
  56. mindspore/dataset/engine/datasets.py +322 -66
  57. mindspore/dataset/engine/datasets_audio.py +80 -76
  58. mindspore/dataset/engine/datasets_standard_format.py +51 -38
  59. mindspore/dataset/engine/datasets_text.py +232 -118
  60. mindspore/dataset/engine/datasets_user_defined.py +41 -17
  61. mindspore/dataset/engine/datasets_vision.py +746 -225
  62. mindspore/dataset/engine/graphdata.py +75 -10
  63. mindspore/dataset/engine/iterators.py +45 -5
  64. mindspore/dataset/engine/offload.py +48 -28
  65. mindspore/dataset/engine/validators.py +117 -8
  66. mindspore/dataset/text/__init__.py +6 -5
  67. mindspore/dataset/text/transforms.py +86 -3
  68. mindspore/dataset/text/utils.py +6 -4
  69. mindspore/dataset/text/validators.py +25 -0
  70. mindspore/dataset/transforms/__init__.py +3 -2
  71. mindspore/dataset/transforms/c_transforms.py +1 -1
  72. mindspore/dataset/transforms/transforms.py +2 -2
  73. mindspore/dataset/utils/__init__.py +2 -1
  74. mindspore/dataset/utils/line_reader.py +121 -0
  75. mindspore/dataset/vision/__init__.py +2 -3
  76. mindspore/dataset/vision/c_transforms.py +9 -9
  77. mindspore/dataset/vision/py_transforms.py +5 -5
  78. mindspore/dataset/vision/py_transforms_util.py +2 -0
  79. mindspore/dataset/vision/transforms.py +160 -161
  80. mindspore/dataset/vision/utils.py +3 -3
  81. mindspore/experimental/map_parameter.py +38 -26
  82. mindspore/include/OWNERS +0 -1
  83. mindspore/include/api/callback/callback.h +9 -13
  84. mindspore/include/api/callback/ckpt_saver.h +2 -2
  85. mindspore/include/api/callback/loss_monitor.h +2 -2
  86. mindspore/include/api/callback/lr_scheduler.h +5 -5
  87. mindspore/include/api/callback/time_monitor.h +2 -2
  88. mindspore/include/api/callback/train_accuracy.h +4 -6
  89. mindspore/include/api/cfg.h +19 -6
  90. mindspore/include/api/context.h +44 -9
  91. mindspore/include/api/delegate.h +1 -1
  92. mindspore/include/api/metrics/accuracy.h +2 -2
  93. mindspore/include/api/metrics/metrics.h +4 -3
  94. mindspore/include/api/model.h +9 -4
  95. mindspore/include/api/model_parallel_runner.h +2 -2
  96. mindspore/include/api/net.h +12 -11
  97. mindspore/include/api/serialization.h +19 -3
  98. mindspore/include/api/types.h +3 -3
  99. mindspore/include/dataset/constants.h +7 -0
  100. mindspore/include/dataset/text.h +59 -0
  101. mindspore/jpeg62.dll +0 -0
  102. mindspore/log.py +1 -1
  103. mindspore/mindrecord/filereader.py +18 -0
  104. mindspore/mindrecord/filewriter.py +197 -34
  105. mindspore/mindrecord/shardreader.py +9 -0
  106. mindspore/mindrecord/shardwriter.py +1 -1
  107. mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
  108. mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
  109. mindspore/mindrecord/tools/csv_to_mr.py +3 -3
  110. mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
  111. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  112. mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
  113. mindspore/mindspore_backend.dll +0 -0
  114. mindspore/mindspore_common.dll +0 -0
  115. mindspore/mindspore_core.dll +0 -0
  116. mindspore/mindspore_glog.dll +0 -0
  117. mindspore/mindspore_shared_lib.dll +0 -0
  118. mindspore/nn/__init__.py +0 -4
  119. mindspore/nn/cell.py +204 -132
  120. mindspore/nn/dynamic_lr.py +1 -1
  121. mindspore/nn/grad/cell_grad.py +7 -6
  122. mindspore/nn/layer/__init__.py +5 -4
  123. mindspore/nn/layer/activation.py +40 -89
  124. mindspore/nn/layer/basic.py +255 -624
  125. mindspore/nn/layer/channel_shuffle.py +7 -6
  126. mindspore/nn/layer/combined.py +1 -1
  127. mindspore/nn/layer/container.py +41 -4
  128. mindspore/nn/layer/conv.py +64 -28
  129. mindspore/nn/layer/dense.py +9 -8
  130. mindspore/nn/layer/embedding.py +27 -25
  131. mindspore/nn/layer/image.py +53 -46
  132. mindspore/nn/layer/math.py +97 -105
  133. mindspore/nn/layer/normalization.py +117 -86
  134. mindspore/nn/layer/padding.py +185 -95
  135. mindspore/nn/layer/pooling.py +817 -414
  136. mindspore/nn/layer/rnn_cells.py +10 -15
  137. mindspore/nn/layer/rnns.py +37 -38
  138. mindspore/nn/layer/thor_layer.py +11 -12
  139. mindspore/nn/layer/timedistributed.py +5 -5
  140. mindspore/nn/layer/transformer.py +701 -0
  141. mindspore/nn/learning_rate_schedule.py +8 -8
  142. mindspore/nn/loss/__init__.py +5 -4
  143. mindspore/nn/loss/loss.py +334 -199
  144. mindspore/nn/optim/ada_grad.py +6 -6
  145. mindspore/nn/optim/adadelta.py +2 -3
  146. mindspore/nn/optim/adafactor.py +4 -5
  147. mindspore/nn/optim/adam.py +126 -62
  148. mindspore/nn/optim/adamax.py +3 -4
  149. mindspore/nn/optim/adasum.py +6 -6
  150. mindspore/nn/optim/asgd.py +2 -2
  151. mindspore/nn/optim/ftrl.py +67 -38
  152. mindspore/nn/optim/lamb.py +4 -5
  153. mindspore/nn/optim/lars.py +2 -2
  154. mindspore/nn/optim/lazyadam.py +43 -4
  155. mindspore/nn/optim/momentum.py +6 -5
  156. mindspore/nn/optim/optimizer.py +3 -1
  157. mindspore/nn/optim/proximal_ada_grad.py +2 -2
  158. mindspore/nn/optim/rmsprop.py +1 -1
  159. mindspore/nn/optim/rprop.py +8 -9
  160. mindspore/nn/optim/sgd.py +19 -13
  161. mindspore/nn/optim/thor.py +10 -15
  162. mindspore/nn/probability/__init__.py +0 -2
  163. mindspore/nn/probability/bijector/bijector.py +4 -4
  164. mindspore/nn/probability/bijector/invert.py +1 -1
  165. mindspore/nn/probability/bijector/softplus.py +2 -2
  166. mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
  167. mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
  168. mindspore/nn/probability/distribution/_utils/utils.py +9 -15
  169. mindspore/nn/probability/distribution/bernoulli.py +3 -3
  170. mindspore/nn/probability/distribution/beta.py +1 -1
  171. mindspore/nn/probability/distribution/categorical.py +5 -7
  172. mindspore/nn/probability/distribution/cauchy.py +3 -3
  173. mindspore/nn/probability/distribution/distribution.py +2 -2
  174. mindspore/nn/probability/distribution/exponential.py +2 -2
  175. mindspore/nn/probability/distribution/gamma.py +3 -3
  176. mindspore/nn/probability/distribution/geometric.py +1 -1
  177. mindspore/nn/probability/distribution/gumbel.py +3 -3
  178. mindspore/nn/probability/distribution/half_normal.py +15 -11
  179. mindspore/nn/probability/distribution/laplace.py +16 -13
  180. mindspore/nn/probability/distribution/logistic.py +2 -2
  181. mindspore/nn/probability/distribution/normal.py +1 -1
  182. mindspore/nn/probability/distribution/poisson.py +1 -1
  183. mindspore/nn/probability/distribution/student_t.py +20 -15
  184. mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
  185. mindspore/nn/probability/distribution/uniform.py +2 -2
  186. mindspore/nn/reinforcement/_tensors_queue.py +3 -3
  187. mindspore/nn/reinforcement/tensor_array.py +2 -2
  188. mindspore/nn/sparse/sparse.py +2 -2
  189. mindspore/nn/wrap/cell_wrapper.py +27 -10
  190. mindspore/nn/wrap/grad_reducer.py +2 -2
  191. mindspore/nn/wrap/loss_scale.py +40 -24
  192. mindspore/numpy/array_creations.py +33 -22
  193. mindspore/numpy/array_ops.py +35 -30
  194. mindspore/numpy/logic_ops.py +6 -27
  195. mindspore/numpy/math_ops.py +22 -19
  196. mindspore/numpy/utils.py +1 -1
  197. mindspore/numpy/utils_const.py +108 -58
  198. mindspore/opencv_core452.dll +0 -0
  199. mindspore/opencv_imgcodecs452.dll +0 -0
  200. mindspore/opencv_imgproc452.dll +0 -0
  201. mindspore/ops/_constants.py +0 -6
  202. mindspore/ops/_grad/__init__.py +2 -1
  203. mindspore/ops/_grad/grad_array_ops.py +86 -117
  204. mindspore/ops/_grad/grad_base.py +23 -1
  205. mindspore/ops/_grad/grad_clip_ops.py +2 -3
  206. mindspore/ops/_grad/grad_comm_ops.py +34 -24
  207. mindspore/ops/_grad/grad_implementations.py +9 -45
  208. mindspore/ops/_grad/grad_inner_ops.py +47 -4
  209. mindspore/ops/_grad/grad_math_ops.py +142 -117
  210. mindspore/ops/_grad/grad_nn_ops.py +71 -165
  211. mindspore/ops/_grad/grad_sequence_ops.py +296 -0
  212. mindspore/ops/_grad/grad_sparse.py +7 -6
  213. mindspore/ops/_grad_experimental/__init__.py +1 -0
  214. mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
  215. mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
  216. mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
  217. mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
  218. mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
  219. mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
  220. mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
  221. mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
  222. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
  223. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
  224. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
  225. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
  226. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
  227. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
  228. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
  229. mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
  230. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
  231. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
  232. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
  233. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
  234. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
  235. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
  236. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
  237. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
  238. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
  239. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
  240. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
  241. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
  242. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
  243. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
  244. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
  245. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  246. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
  247. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
  248. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
  249. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
  250. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
  251. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
  252. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
  253. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
  254. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
  255. mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
  256. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  257. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
  258. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  259. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  260. mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
  261. mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
  262. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  263. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
  264. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  265. mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
  266. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  267. mindspore/ops/_op_impl/aicpu/conj.py +11 -0
  268. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
  269. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  270. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  271. mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
  272. mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
  273. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  274. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  275. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
  276. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  277. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  278. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  279. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  280. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  281. mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
  282. mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
  283. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
  284. mindspore/ops/_op_impl/aicpu/mul.py +3 -1
  285. mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
  286. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  287. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  288. mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
  289. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  290. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  291. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  292. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  293. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  294. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  295. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
  296. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
  297. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  298. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  299. mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
  300. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
  301. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  302. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  303. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  304. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  305. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  306. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
  307. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  308. mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
  309. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
  310. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  311. mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
  312. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  313. mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
  314. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
  315. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
  316. mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
  317. mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
  318. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
  319. mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
  320. mindspore/ops/_op_impl/tbe/__init__.py +27 -611
  321. mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
  322. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  323. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
  324. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
  325. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  326. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
  327. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
  328. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
  329. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
  330. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
  331. mindspore/ops/_op_impl/tbe/cast.py +0 -2
  332. mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
  333. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
  334. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
  335. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
  336. mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
  337. mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
  338. mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
  339. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
  340. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
  341. mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
  342. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
  343. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  344. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
  345. mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
  346. mindspore/ops/_register_for_op.py +1 -0
  347. mindspore/ops/_utils/__init__.py +1 -2
  348. mindspore/ops/_utils/utils.py +19 -40
  349. mindspore/ops/_vmap/vmap_array_ops.py +116 -38
  350. mindspore/ops/_vmap/vmap_base.py +16 -9
  351. mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
  352. mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
  353. mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
  354. mindspore/ops/_vmap/vmap_image_ops.py +12 -5
  355. mindspore/ops/_vmap/vmap_math_ops.py +46 -5
  356. mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
  357. mindspore/ops/_vmap/vmap_random_ops.py +1 -1
  358. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  359. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  360. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
  361. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
  362. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  363. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  364. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  365. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
  366. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
  367. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  368. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
  369. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
  370. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
  371. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
  372. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
  373. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
  374. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
  375. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  376. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  377. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
  378. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
  379. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
  380. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  381. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  382. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  383. mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
  384. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  385. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
  386. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
  387. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
  388. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
  389. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  390. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
  391. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
  392. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  393. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
  394. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
  395. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
  396. mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
  397. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  398. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
  399. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
  400. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
  401. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
  402. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
  403. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  404. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
  405. mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
  406. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  407. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  408. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
  409. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  410. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  411. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  412. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
  413. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
  414. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
  415. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  416. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
  417. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
  418. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
  419. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
  420. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
  421. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
  422. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  423. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
  424. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
  425. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  426. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
  427. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
  428. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  429. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
  430. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
  431. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  432. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
  433. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  434. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  435. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
  436. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
  437. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
  438. mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
  439. mindspore/ops/composite/__init__.py +7 -8
  440. mindspore/ops/composite/base.py +101 -47
  441. mindspore/ops/composite/math_ops.py +188 -158
  442. mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
  443. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
  444. mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
  445. mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
  446. mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
  447. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
  448. mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
  449. mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
  450. mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
  451. mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
  452. mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
  453. mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
  454. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
  455. mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
  456. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
  457. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
  458. mindspore/ops/function/__init__.py +152 -8
  459. mindspore/ops/function/array_func.py +2555 -674
  460. mindspore/ops/function/clip_func.py +209 -13
  461. mindspore/ops/function/debug_func.py +2 -2
  462. mindspore/ops/function/grad/__init__.py +2 -1
  463. mindspore/ops/function/grad/grad_func.py +147 -62
  464. mindspore/ops/function/image_func.py +54 -38
  465. mindspore/ops/function/linalg_func.py +167 -16
  466. mindspore/ops/function/math_func.py +4849 -1492
  467. mindspore/ops/function/nn_func.py +2573 -988
  468. mindspore/ops/function/other_func.py +115 -0
  469. mindspore/ops/function/parameter_func.py +3 -3
  470. mindspore/ops/function/random_func.py +790 -73
  471. mindspore/ops/function/sparse_func.py +98 -78
  472. mindspore/ops/function/sparse_unary_func.py +54 -53
  473. mindspore/ops/function/spectral_func.py +27 -24
  474. mindspore/ops/function/vmap_func.py +22 -2
  475. mindspore/ops/functional.py +97 -37
  476. mindspore/ops/op_info_register.py +70 -28
  477. mindspore/ops/operations/__init__.py +47 -14
  478. mindspore/ops/operations/_csr_ops.py +7 -7
  479. mindspore/ops/operations/_embedding_cache_ops.py +5 -5
  480. mindspore/ops/operations/_grad_ops.py +276 -187
  481. mindspore/ops/operations/_inner_ops.py +319 -113
  482. mindspore/ops/operations/_ms_kernel.py +10 -8
  483. mindspore/ops/operations/_ocr_ops.py +9 -9
  484. mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
  485. mindspore/ops/operations/_quant_ops.py +137 -102
  486. mindspore/ops/operations/_rl_inner_ops.py +121 -60
  487. mindspore/ops/operations/_scalar_ops.py +466 -0
  488. mindspore/ops/operations/_sequence_ops.py +1004 -2
  489. mindspore/ops/operations/_tensor_array.py +10 -11
  490. mindspore/ops/operations/_thor_ops.py +1 -1
  491. mindspore/ops/operations/array_ops.py +801 -466
  492. mindspore/ops/operations/comm_ops.py +51 -49
  493. mindspore/ops/operations/control_ops.py +2 -2
  494. mindspore/ops/operations/custom_ops.py +123 -44
  495. mindspore/ops/operations/debug_ops.py +24 -24
  496. mindspore/ops/operations/image_ops.py +240 -153
  497. mindspore/ops/operations/inner_ops.py +34 -50
  498. mindspore/ops/operations/linalg_ops.py +31 -9
  499. mindspore/ops/operations/math_ops.py +988 -757
  500. mindspore/ops/operations/nn_ops.py +965 -819
  501. mindspore/ops/operations/other_ops.py +51 -40
  502. mindspore/ops/operations/random_ops.py +204 -122
  503. mindspore/ops/operations/rl_ops.py +8 -9
  504. mindspore/ops/operations/sparse_ops.py +254 -93
  505. mindspore/ops/operations/spectral_ops.py +35 -3
  506. mindspore/ops/primitive.py +111 -9
  507. mindspore/parallel/_auto_parallel_context.py +189 -83
  508. mindspore/parallel/_offload_context.py +185 -0
  509. mindspore/parallel/_parallel_serialization.py +99 -7
  510. mindspore/parallel/_ps_context.py +9 -5
  511. mindspore/parallel/_recovery_context.py +1 -1
  512. mindspore/parallel/_tensor.py +7 -1
  513. mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
  514. mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
  515. mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
  516. mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
  517. mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
  518. mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
  519. mindspore/parallel/_utils.py +1 -2
  520. mindspore/parallel/algo_parameter_config.py +1 -1
  521. mindspore/parallel/checkpoint_transform.py +37 -34
  522. mindspore/parallel/shard.py +17 -18
  523. mindspore/profiler/common/validator/validate_path.py +2 -2
  524. mindspore/profiler/envprofiling.py +69 -47
  525. mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
  526. mindspore/profiler/parser/base_timeline_generator.py +49 -56
  527. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
  528. mindspore/profiler/parser/hwts_log_parser.py +1 -1
  529. mindspore/profiler/parser/integrator.py +15 -14
  530. mindspore/profiler/parser/minddata_analyzer.py +2 -2
  531. mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
  532. mindspore/profiler/parser/msadvisor_parser.py +2 -4
  533. mindspore/profiler/parser/optime_parser.py +17 -18
  534. mindspore/profiler/parser/profiler_info.py +2 -1
  535. mindspore/profiler/profiling.py +218 -186
  536. mindspore/rewrite/__init__.py +3 -1
  537. mindspore/rewrite/api/node.py +1 -114
  538. mindspore/rewrite/api/node_type.py +3 -0
  539. mindspore/rewrite/api/pattern_engine.py +31 -1
  540. mindspore/rewrite/api/scoped_value.py +4 -4
  541. mindspore/rewrite/api/symbol_tree.py +3 -78
  542. mindspore/rewrite/api/tree_node_helper.py +1 -1
  543. mindspore/rewrite/ast_creator_register.py +1 -0
  544. mindspore/rewrite/ast_helpers/__init__.py +2 -2
  545. mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
  546. mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
  547. mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
  548. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
  549. mindspore/rewrite/namespace.py +0 -2
  550. mindspore/rewrite/node.py +157 -11
  551. mindspore/rewrite/parsers/assign_parser.py +231 -53
  552. mindspore/rewrite/parsers/class_def_parser.py +187 -109
  553. mindspore/rewrite/parsers/for_parser.py +24 -14
  554. mindspore/rewrite/parsers/function_def_parser.py +21 -4
  555. mindspore/rewrite/parsers/if_parser.py +6 -2
  556. mindspore/rewrite/sparsify/__init__.py +0 -0
  557. mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
  558. mindspore/rewrite/sparsify/sparsify.py +109 -0
  559. mindspore/rewrite/sparsify/utils.py +173 -0
  560. mindspore/rewrite/symbol_tree.py +256 -133
  561. mindspore/rewrite/symbol_tree_builder.py +38 -1
  562. mindspore/run_check/_check_version.py +69 -63
  563. mindspore/run_check/run_check.py +2 -1
  564. mindspore/tinyxml2.dll +0 -0
  565. mindspore/train/__init__.py +1 -1
  566. mindspore/train/_utils.py +28 -5
  567. mindspore/train/amp.py +273 -102
  568. mindspore/train/callback/_backup_and_restore.py +5 -5
  569. mindspore/train/callback/_callback.py +2 -2
  570. mindspore/train/callback/_checkpoint.py +3 -3
  571. mindspore/train/callback/_early_stop.py +3 -3
  572. mindspore/train/callback/_lambda_callback.py +2 -2
  573. mindspore/train/callback/_landscape.py +29 -31
  574. mindspore/train/callback/_loss_monitor.py +3 -3
  575. mindspore/train/callback/_on_request_exit.py +3 -3
  576. mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
  577. mindspore/train/callback/_summary_collector.py +23 -16
  578. mindspore/train/callback/_time_monitor.py +3 -3
  579. mindspore/train/checkpoint_pb2.py +68 -8
  580. mindspore/train/data_sink.py +15 -3
  581. mindspore/train/dataset_helper.py +10 -15
  582. mindspore/train/loss_scale_manager.py +8 -11
  583. mindspore/train/metrics/__init__.py +1 -1
  584. mindspore/train/metrics/bleu_score.py +1 -1
  585. mindspore/train/metrics/confusion_matrix.py +1 -1
  586. mindspore/train/metrics/cosine_similarity.py +1 -1
  587. mindspore/train/metrics/dice.py +2 -2
  588. mindspore/train/metrics/fbeta.py +1 -1
  589. mindspore/train/metrics/hausdorff_distance.py +4 -3
  590. mindspore/train/metrics/mean_surface_distance.py +2 -2
  591. mindspore/train/metrics/occlusion_sensitivity.py +1 -1
  592. mindspore/train/metrics/perplexity.py +1 -1
  593. mindspore/train/metrics/precision.py +1 -1
  594. mindspore/train/metrics/recall.py +1 -1
  595. mindspore/train/metrics/roc.py +2 -2
  596. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  597. mindspore/train/mind_ir_pb2.py +116 -37
  598. mindspore/train/model.py +45 -28
  599. mindspore/train/serialization.py +295 -188
  600. mindspore/train/summary/_summary_adapter.py +1 -1
  601. mindspore/train/summary/summary_record.py +43 -13
  602. mindspore/train/train_thor/convert_utils.py +2 -2
  603. mindspore/train/train_thor/dataset_helper.py +3 -3
  604. mindspore/turbojpeg.dll +0 -0
  605. mindspore/version.py +1 -1
  606. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
  607. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
  608. mindspore/compression/__init__.py +0 -19
  609. mindspore/compression/common/constant.py +0 -124
  610. mindspore/compression/export/__init__.py +0 -19
  611. mindspore/compression/export/quant_export.py +0 -515
  612. mindspore/compression/quant/__init__.py +0 -28
  613. mindspore/compression/quant/qat.py +0 -634
  614. mindspore/compression/quant/quant_utils.py +0 -462
  615. mindspore/compression/quant/quantizer.py +0 -68
  616. mindspore/nn/layer/quant.py +0 -1868
  617. mindspore/nn/layer/rnn_utils.py +0 -90
  618. mindspore/nn/probability/dpn/__init__.py +0 -22
  619. mindspore/nn/probability/dpn/vae/__init__.py +0 -25
  620. mindspore/nn/probability/dpn/vae/cvae.py +0 -140
  621. mindspore/nn/probability/dpn/vae/vae.py +0 -124
  622. mindspore/nn/probability/infer/__init__.py +0 -22
  623. mindspore/nn/probability/infer/variational/elbo.py +0 -70
  624. mindspore/nn/probability/infer/variational/svi.py +0 -84
  625. mindspore/nn/probability/toolbox/__init__.py +0 -22
  626. mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
  627. mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
  628. mindspore/nn/probability/transforms/__init__.py +0 -22
  629. mindspore/nn/probability/transforms/transform_bnn.py +0 -262
  630. mindspore/nn/probability/zhusuan/__init__.py +0 -18
  631. mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
  632. mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
  633. mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
  634. mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
  635. mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
  636. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  637. mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
  638. mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
  639. mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
  640. mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
  641. mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
  642. mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
  643. mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
  644. mindspore/ops/composite/array_ops.py +0 -241
  645. mindspore/ops/composite/clip_ops.py +0 -134
  646. mindspore/ops/composite/random_ops.py +0 -426
  647. mindspore/ops/composite/vmap_ops.py +0 -38
  648. mindspore/parallel/nn/__init__.py +0 -42
  649. mindspore/parallel/nn/loss.py +0 -22
  650. mindspore/parallel/nn/moe.py +0 -21
  651. mindspore/parallel/nn/op_parallel_config.py +0 -22
  652. mindspore/parallel/nn/transformer.py +0 -31
  653. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
  654. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
  655. {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,8 @@ from .api.node import Node
33
33
  from .api.node_type import NodeType
34
34
  from .api.pattern_engine import PatternEngine, PatternNode, VarNode, Replacement
35
35
  from .api.tree_node_helper import TreeNodeHelper
36
+ from .sparsify.sparsify import sparsify
37
+ from .sparsify.utils import ArgType, SparseFunc
36
38
 
37
39
  __all__ = ["SymbolTree", "Node", "NodeType", "ScopedValue", "ValueType", "PatternEngine", "PatternNode", "VarNode",
38
- "Replacement", "TreeNodeHelper"]
40
+ "Replacement", "TreeNodeHelper", "sparsify", "ArgType", "SparseFunc"]
@@ -18,7 +18,7 @@ from typing import Union, Optional
18
18
 
19
19
  from mindspore.nn import Cell
20
20
  from mindspore.ops.primitive import Primitive
21
- from ..._checkparam import Validator
21
+ from mindspore import _checkparam as Validator
22
22
  from ..node import Node as NodeImpl
23
23
  from ..symbol_tree import SymbolTree as SymbolTreeImpl
24
24
  from .node_type import NodeType
@@ -99,12 +99,6 @@ class Node:
99
99
  args, kwargs, name, is_sub_net))
100
100
 
101
101
  def get_handler(self) -> NodeImpl:
102
- """
103
- Get handler of node implementation.
104
-
105
- Returns:
106
- An instance of `NodeImpl`.
107
- """
108
102
  return self._node
109
103
 
110
104
  def get_inputs(self) -> ['Node']:
@@ -181,7 +175,6 @@ class Node:
181
175
 
182
176
  Raises:
183
177
  RuntimeError: If `src_node` is not belong to current `SymbolTree`.
184
- RuntimeError: If current node and `src_node` is not belong to same `SymbolTree`.
185
178
  TypeError: If `arg_idx` is not a `int` number.
186
179
  ValueError: If `arg_idx` is out of range.
187
180
  TypeError: If `src_node` is not a `Node` instance.
@@ -209,27 +202,6 @@ class Node:
209
202
  belong_symbol_tree.set_node_arg_by_node(self._node, arg_idx, src_node.get_handler(), out_idx)
210
203
 
211
204
  def get_targets(self) -> [ScopedValue]:
212
- """
213
- Get targets of current node.
214
-
215
- - When node_type of current node is `CallCell`, `CallPrimitive`, `CallMethod` or `Tree`, `targets` are strings
216
- represents invoke result of the cell-op or primitive-op or function-call which are corresponding to targets of
217
- ast.Assign.
218
- - When node_type of current node is Input, `targets` should have only one element which is a string represents
219
- parameter of function.
220
- - When node_type of current node is `Python` or `Output`, `targets` are don't-care.
221
-
222
- Returns:
223
- A list of instances of ScopedValue as targets of node.
224
-
225
- Examples:
226
- >>> from mindspore.rewrite import SymbolTree
227
- >>> from lenet import Lenet
228
- >>> net = Lenet()
229
- >>> stree = SymbolTree.create(net)
230
- >>> node = stree.get_node("conv1")
231
- >>> targets = node.get_targets()
232
- """
233
205
  return self._node.get_targets()
234
206
 
235
207
  def get_name(self) -> str:
@@ -284,106 +256,21 @@ class Node:
284
256
  return self._node.get_instance_type()
285
257
 
286
258
  def get_instance(self):
287
- """
288
- Get the instance of current node.
289
-
290
- - When node_type of current node is `CallCell`, instance is an instance of Cell.
291
- - When node_type of current node is `CallPrimitive`, instance is an instance of primitive.
292
- - When node_type of current node is `Tree`, instance is an instance of network-cell.
293
- - When node_type of current node is `Python`, `Input`, `Output` or `CallMethod`, instance should be None.
294
-
295
- Returns:
296
- A object represents corresponding instance of current node.
297
- """
298
259
  return self._node.get_instance()
299
260
 
300
261
  def get_args(self) -> [ScopedValue]:
301
- """
302
- Get the arguments of current node.
303
-
304
- - When `node_type` of current node is `CallCell`, `CallPrimitive` or `Tree`, arguments are corresponding to args
305
- of ast.Call which represents arguments to invoke forward method of cell-op or primitive-op.
306
- - When `node_type` of current node is `Input`, arguments represents default-value of argument of function.
307
- - When `node_type` of current node is `Output`, arguments represents the return values of network.
308
- - When `node_type` of current node is `Python`, arguments are don't-care.
309
-
310
- Returns:
311
- A list of instances of `ScopedValue`.
312
-
313
- Examples:
314
- >>> from mindspore.rewrite import SymbolTree
315
- >>> from lenet import Lenet
316
- >>> net = Lenet()
317
- >>> stree = SymbolTree.create(net)
318
- >>> node = stree.get_node("conv1")
319
- >>> args = node.get_args()
320
- """
321
262
  return self._node.get_args()
322
263
 
323
264
  def get_kwargs(self) -> {str: ScopedValue}:
324
- """
325
- Get the keyword arguments of current node.
326
-
327
- - When node_type of current node is `CallCell`, `CallPrimitive` or `Tree`, keyword arguments are corresponding
328
- to kwargs of ast.Call which represents arguments to invoke forward method of cell-op or primitive-op.
329
- - When node_type of current node is `Python`, `Input` or `Output`, keyword arguments are don't-care.
330
-
331
- Returns:
332
- A dict of str to instance of `ScopedValue`.
333
-
334
- Examples:
335
- >>> from mindspore.rewrite import SymbolTree
336
- >>> from lenet import Lenet
337
- >>> net = Lenet()
338
- >>> stree = SymbolTree.create(net)
339
- >>> node = stree.get_node("conv1")
340
- >>> kwargs = node.get_kwargs()
341
- """
342
265
  return self._node.get_kwargs()
343
266
 
344
267
  def set_attribute(self, key: str, value):
345
- """
346
- Set attribute of current node.
347
-
348
- Args:
349
- key (str): Key of attribute.
350
- value (object): Value of attribute.
351
-
352
- Raises:
353
- TypeError: If `key` is not a `str`.
354
-
355
- Examples:
356
- >>> from mindspore.rewrite import SymbolTree
357
- >>> from lenet import Lenet
358
- >>> net = Lenet()
359
- >>> stree = SymbolTree.create(net)
360
- >>> node = stree.get_node("conv1")
361
- >>> node.set_attribute("channel", 3)
362
- """
363
268
  Validator.check_value_type("key", key, [str], "Node attribute")
364
269
  self._node.set_attribute(key, value)
365
270
 
366
271
  def get_attributes(self) -> {str: object}:
367
- """
368
- Get all attributes of current node.
369
-
370
- Returns:
371
- A dict of str to instance of object as attributes.
372
- """
373
272
  return self._node.get_attributes()
374
273
 
375
274
  def get_attribute(self, key: str):
376
- """
377
- Get attribute of current node by key.
378
-
379
- Args:
380
- key (str): Key of attribute.
381
-
382
- Returns:
383
- A object as attribute, can be any type.
384
-
385
- Raises:
386
- TypeError: If `key` is not a `str`.
387
- """
388
275
  Validator.check_value_type("key", key, [str], "Node attribute")
389
276
  return self._node.get_attribute(key)
@@ -29,6 +29,7 @@ class NodeType(Enum):
29
29
  - Input: `Input` node represents input of `SymbolTree` corresponding to arguments of forward method.
30
30
  - Output: `Output` node represents output of SymbolTree corresponding to return statement of forward method.
31
31
  - Tree: `Tree` node represents sub-network invoking in forward method.
32
+ - MathOps: `MathOps` node represents a mathematical operation, such as adding or comparing in forward method.
32
33
 
33
34
  """
34
35
  Unknown = 0
@@ -43,3 +44,5 @@ class NodeType(Enum):
43
44
  Input = 7
44
45
  Output = 8
45
46
  Tree = 9
47
+ CellContainer = 10
48
+ MathOps = 11
@@ -20,7 +20,7 @@ import abc
20
20
  from mindspore.nn import Cell
21
21
  from mindspore.ops.primitive import Primitive
22
22
  from mindspore import log as logger
23
- from ..._checkparam import Validator
23
+ from mindspore import _checkparam as Validator
24
24
  from .node_type import NodeType
25
25
  from .node import Node
26
26
  from .symbol_tree import SymbolTree
@@ -308,6 +308,16 @@ class PatternEngine:
308
308
  queue.extend(inputs_dict.get(cur_node.get_name()))
309
309
  return new_root
310
310
 
311
+ @staticmethod
312
+ def _multi_replace_cellcontainer(stree, cellcontainer, node, matched_dict, new_nodes):
313
+ """Replace node in CellContainer."""
314
+ to_erase_list = list(matched_dict.values())
315
+ stree.replace(Node(node), new_nodes)
316
+ for n in reversed(to_erase_list):
317
+ if n.get_handler() is node:
318
+ continue
319
+ stree.erase_node(n)
320
+
311
321
  def apply(self, stree: SymbolTree) -> bool:
312
322
  """
313
323
  Apply current pattern to a `SymbolTree`.
@@ -359,6 +369,9 @@ class PatternEngine:
359
369
  visited.append(cur_node)
360
370
  queue.extend(cur_node.get_users())
361
371
  continue
372
+ if cur_node.get_node_type() == NodeType.CellContainer:
373
+ self._process_cellcontainer(stree, cur_node.get_handler())
374
+ continue
362
375
  visited.append(cur_node)
363
376
  matched, matched_dict = self._match(self._pattern, cur_node)
364
377
  # not matched
@@ -460,3 +473,20 @@ class PatternEngine:
460
473
  logger.debug("Check match failed, pattern leaked")
461
474
  return False
462
475
  return True
476
+
477
+ def _process_cellcontainer(self, stree, cellcontainer):
478
+ """Process CellContainer node."""
479
+ for node in cellcontainer.nodes():
480
+ if node.get_node_type() == NodeType.Tree:
481
+ subtree = node.symbol_tree
482
+ self.apply(SymbolTree(subtree))
483
+ continue
484
+ matched, matched_dict = self._match(self._pattern, Node(node))
485
+ if not matched:
486
+ continue
487
+ new_nodes = []
488
+ if self._replacement is not None:
489
+ new_nodes = self._replacement(self._pattern, self._is_chain, matched_dict)
490
+ if not new_nodes: # if replacement is empty, do nothing
491
+ continue
492
+ PatternEngine._multi_replace_cellcontainer(stree, cellcontainer, node, matched_dict, new_nodes)
@@ -15,7 +15,7 @@
15
15
  """Rewrite module api: ValueType and ScopedValue."""
16
16
  from enum import Enum
17
17
  from typing import Optional, Union
18
- from ..._checkparam import Validator
18
+ from mindspore import _checkparam as Validator
19
19
 
20
20
 
21
21
  class ValueType(Enum):
@@ -127,8 +127,8 @@ class ScopedValue:
127
127
  Create a list of naming `ScopedValue`.
128
128
 
129
129
  Args:
130
- names: (list[str] or tuple[str]): List or tuple of `str` represents names of referenced variables.
131
- scopes: (list[str] or tuple[str]): List or tuple of `str` represents scopes of referenced variables.
130
+ names (list[str] or tuple[str]): List or tuple of `str` represents names of referenced variables.
131
+ scopes (list[str] or tuple[str]): List or tuple of `str` represents scopes of referenced variables.
132
132
 
133
133
  Returns:
134
134
  An list of instance of `ScopedValue`.
@@ -140,7 +140,7 @@ class ScopedValue:
140
140
 
141
141
  Examples:
142
142
  >>> from mindspore.rewrite import ScopedValue
143
- >>> variables = ScopedValue.create_name_values(["z", "z_1"]), name="subnet")
143
+ >>> variables = ScopedValue.create_name_values(["z", "z_1"], name="subnet")
144
144
  """
145
145
  Validator.check_element_type_of_iterable("names", names, [str], "ScopedValue")
146
146
  if scopes is not None:
@@ -18,7 +18,7 @@ from types import FunctionType
18
18
  import mindspore as ms
19
19
 
20
20
  from mindspore.nn import Cell
21
- from ..._checkparam import Validator
21
+ from mindspore import _checkparam as Validator
22
22
  from .node import Node
23
23
  from ..symbol_tree_builder import SymbolTreeBuilder
24
24
  from ..symbol_tree import Position, SymbolTree as SymbolTreeImpl
@@ -70,40 +70,7 @@ class SymbolTree:
70
70
  if v not in MsDtypes and not isinstance(v, ParamTypes):
71
71
  raise TypeError(f"For call-function Node, got unsupported kwarg value: {v}, type: {type(v)}")
72
72
 
73
- def create_call_function(self, func, targets, *args, **kwargs):
74
- r"""
75
- Create a Node object and generate the execution code to insert into the source code.
76
- The source code calls the 'func' function with 'args' and' kwargs' as parameters.
77
-
78
- Args:
79
- func (FunctionType): The function to be called.
80
- targets (list[str]): indicates the output name. As the output of the node in the source code.
81
- args (Union[MsDtypes, ParamTypes]): parameter name of the node. Used as a parameter to a code statement in
82
- source code. The default value is None, which means there is no parameter input in the cell.
83
- kwargs (dict{str,Union[MsDtypes, ParamTypes]}): The key type must be str,
84
- and the value must be value or type must be ParamTypes.
85
- The input parameter name used to describe the formal parameter with a keyword.
86
- Enter the name in the source code as the 'kwargs' in the statement expression.The default value is
87
- None, which means there is no 'kwargs' input.
88
-
89
- Returns:
90
- An instance of `Node`.
91
-
92
- Raises:
93
- TypeError: If `func` is not FunctionType.
94
- TypeError: If `targets` is not `list`.
95
- TypeError: If the type of `targets` is not str.
96
- TypeError: If arg in `args` is not ParamType.
97
- TypeError: If key of `kwarg` is not a str or value of kwarg in `kwargs` is not ParamType.
98
-
99
- Examples:
100
- >>> from mindspore.rewrite import SymbolTree
101
- >>> from lenet import Lenet
102
- >>> net = Lenet()
103
- >>> stree = SymbolTree.create(net)
104
- >>> node = stree.get_node("conv1")
105
- >>> new_node = stree.create_call_function(F.abs, ["x"], node)
106
- """
73
+ def create_call_function(self, func, targets, *args, **kwargs): # pylint: disable=C0111
107
74
  Validator.check_value_type("func", func, [FunctionType], "SymbolTree node")
108
75
  Validator.check_element_type_of_iterable("targets", targets, [str], "SymbolTree node")
109
76
  args_ = list(args)
@@ -115,22 +82,9 @@ class SymbolTree:
115
82
  for key, value in kwargs.items():
116
83
  if isinstance(value, Node):
117
84
  kwargs[key] = value.get_handler()
118
- return Node(self._symbol_tree.create_call_function(func, targets, args_, kwargs))
85
+ return Node(self._symbol_tree._create_call_function(func, targets, args_, kwargs)) # pylint: disable=W0212
119
86
 
120
87
  def get_handler(self) -> SymbolTreeImpl:
121
- """
122
- Get handler of `SymbolTree` implementation.
123
-
124
- Returns:
125
- An instance of `SymbolTree`.
126
-
127
- Examples:
128
- >>> from mindspore.rewrite import SymbolTree
129
- >>> from lenet import Lenet
130
- >>> net = Lenet()
131
- >>> stree = SymbolTree.create(net)
132
- >>> handler = stree.get_handler()
133
- """
134
88
  return self._symbol_tree
135
89
 
136
90
  def nodes(self):
@@ -152,25 +106,6 @@ class SymbolTree:
152
106
  yield Node(node)
153
107
 
154
108
  def get_node(self, node_name: str) -> Optional[Node]:
155
- """
156
- Get node by `node_name`.
157
-
158
- Args:
159
- node_name (str): A string represents name of node.
160
-
161
- Returns:
162
- An instance of node if find else None.
163
-
164
- Raises:
165
- TypeError: If `node_name` is not `str`.
166
-
167
- Examples:
168
- >>> from mindspore.rewrite import SymbolTree
169
- >>> from lenet import Lenet
170
- >>> net = Lenet()
171
- >>> stree = SymbolTree.create(net)
172
- >>> node = stree.get_node("conv1")
173
- """
174
109
  Validator.check_value_type("node_name", node_name, [str], "SymbolTree")
175
110
  node_impl = self._symbol_tree.get_node(node_name)
176
111
  if node_impl is None:
@@ -354,16 +289,6 @@ class SymbolTree:
354
289
  self._symbol_tree.dump()
355
290
 
356
291
  def print_node_tabulate(self):
357
- """
358
- Print node information of graph.
359
-
360
- Examples:
361
- >>> from mindspore.rewrite import SymbolTree
362
- >>> from lenet import Lenet
363
- >>> net = Lenet()
364
- >>> stree = SymbolTree.create(net)
365
- >>> stree.print_node_tabulate()
366
- """
367
292
  self._symbol_tree.print_node_tabulate()
368
293
 
369
294
  def get_code(self) -> str:
@@ -16,7 +16,7 @@
16
16
  from typing import Optional
17
17
 
18
18
  from mindspore import log as logger
19
- from ..._checkparam import Validator
19
+ from mindspore import _checkparam as Validator
20
20
  from .symbol_tree import SymbolTree
21
21
  from .node import Node
22
22
  from .node_type import NodeType
@@ -23,6 +23,7 @@ class Registry(UserDict):
23
23
  """Registry class for registry functions for creating ast node."""
24
24
 
25
25
  def register(self, obj_str, obj):
26
+ """Register object by str."""
26
27
  if isinstance(obj_str, str):
27
28
  self[obj_str] = obj
28
29
 
@@ -17,11 +17,11 @@
17
17
  Define some ast helpers for manipulating python ast.
18
18
  """
19
19
 
20
- from .ast_finder import AstFinder, StrChecker, FindConstValueInInit
20
+ from .ast_finder import AstFinder, StrChecker, CheckPropertyIsUsed, GetPropertyOfObj
21
21
  from .ast_replacer import AstReplacer
22
22
  from .ast_modifier import AstModifier
23
23
  from .ast_creator import ast_args_creator, ast_assign_creator, ast_attributer_creator, ast_call_creator, \
24
24
  ast_create_arg_value, ast_index_creator, ast_keyword_creator, ast_kwargs_creator, ast_name_creator, \
25
25
  ast_num_creator, ast_str_creator, ast_subscript_creator
26
26
 
27
- __all__ = ["AstFinder", "AstReplacer", "AstModifier", "StrChecker"]
27
+ __all__ = ["AstFinder", "AstReplacer", "AstModifier", "StrChecker", "CheckPropertyIsUsed", "GetPropertyOfObj"]
@@ -49,14 +49,13 @@ def ast_call_creator(func: ast.AST, args: list, keywords: list):
49
49
 
50
50
  def ast_create_arg_value(value):
51
51
  """Create arg node by type."""
52
- from mindspore.rewrite.node import Node
53
52
  if isinstance(value, (int, float)):
54
53
  ast_value = ast_num_creator(value)
55
54
  elif isinstance(value, str):
56
55
  ast_value = ast_str_creator(value)
57
56
  elif value in (ms.float16, ms.float32, ms.float64):
58
57
  ast_value = ast_attributer_creator(".".join(["mindspore", str(value).lower()]))
59
- elif isinstance(value, Node):
58
+ elif isinstance(value, ms.rewrite.node.Node):
60
59
  ast_value = ast_str_creator(value.get_targets()[0])
61
60
  else:
62
61
  raise TypeError("Unsupported arg type: ", type(value))
@@ -160,3 +160,68 @@ class FindConstValueInInit(ast.NodeVisitor):
160
160
  self._hit = False
161
161
  self.generic_visit(self._context)
162
162
  return self._hit
163
+
164
+
165
+ class CheckPropertyIsUsed(ast.NodeVisitor):
166
+ """
167
+ Check whether a property is used.
168
+
169
+ Args:
170
+ node (ast.AST): An instance of ast node.
171
+ """
172
+ def __init__(self, node: ast.AST):
173
+ self._context = node
174
+ self._value = ""
175
+ self._attr = ""
176
+ self._hit = False
177
+
178
+ def visit_Attribute(self, node: ast.Attribute) -> Any: # pylint: disable=invalid-name
179
+ """Visit a node of type ast.Attribute."""
180
+ if isinstance(node.value, ast.Name) and node.value.id == self._value and node.attr == self._attr:
181
+ self._hit = True
182
+ return super(CheckPropertyIsUsed, self).generic_visit(node)
183
+
184
+ def generic_visit(self, node: ast.AST) -> Any:
185
+ """
186
+ An override method, iterating over all nodes and save target ast nodes.
187
+ """
188
+ if self._hit:
189
+ return
190
+ super(CheckPropertyIsUsed, self).generic_visit(node)
191
+
192
+ def check(self, value, attr) -> bool:
193
+ """
194
+ Check whether `value` and `attr` exists.
195
+ """
196
+ self._value = value
197
+ self._attr = attr
198
+ self._hit = False
199
+ self.generic_visit(self._context)
200
+ return self._hit
201
+
202
+
203
+ class GetPropertyOfObj(ast.NodeVisitor):
204
+ """
205
+ Check whether a property is used.
206
+
207
+ Args:
208
+ node (ast.AST): An instance of ast node.
209
+ """
210
+ def __init__(self, node: ast.AST):
211
+ self._context = node
212
+ self._property = set()
213
+
214
+ def visit_Assign(self, node: ast.Assign) -> Any: # pylint: disable=invalid-name
215
+ """Visit a node of type ast.Attribute."""
216
+ target = node.targets[0]
217
+ if isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name) and target.value.id == "self":
218
+ self._property.add(target.attr)
219
+ return super(GetPropertyOfObj, self).generic_visit(node)
220
+
221
+ def get(self):
222
+ """
223
+ Check whether `value` and `attr` exists.
224
+ """
225
+ self._property = set()
226
+ self.generic_visit(self._context)
227
+ return self._property
@@ -241,8 +241,10 @@ class AstModifier(ast.NodeTransformer):
241
241
  An instance of ast.Assign which has been appended to 'init_func'.
242
242
  """
243
243
  return AstModifier.insert_assign_to_function(init_func, targets=targets,
244
- args=[ScopedValue.create_variable_value(field)],
245
- expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"))
244
+ expr=ScopedValue(ValueType.NamingValue, "", "setattr"),
245
+ args=[ScopedValue(ValueType.NamingValue, "obj"),
246
+ ScopedValue.create_variable_value(field)])
247
+
246
248
 
247
249
  @staticmethod
248
250
  def create_call_assign(targets: [ScopedValue], expr: ScopedValue, args: [ScopedValue],
@@ -459,7 +461,7 @@ class AstModifier(ast.NodeTransformer):
459
461
 
460
462
  Args:
461
463
  src_argument (ScopedValue): An instance of ScopedValue represents new argument.
462
- dst_ast (ast.AST): Targets of ast.Assign.
464
+ dst_ast (ast.AST): Ast node to be updated by ScopedValue.
463
465
 
464
466
  Raises:
465
467
  TypeError: Input src_argument is not a ScopedValue
@@ -490,6 +492,12 @@ class AstModifier(ast.NodeTransformer):
490
492
  str(src_argument.type))
491
493
  dst_ast.n = src_argument.value
492
494
  return
495
+ if isinstance(dst_ast, ast.Str):
496
+ if src_argument.type not in [ValueType.StringValue]:
497
+ raise RuntimeError("src_argument should be a StringValue, but got:",
498
+ str(src_argument.type))
499
+ dst_ast.s = src_argument.value
500
+ return
493
501
  if isinstance(dst_ast, ast.Name):
494
502
  if src_argument.type not in [ValueType.NamingValue, ValueType.StringValue]:
495
503
  raise RuntimeError("src_argument.type should be ValueType.NamingValue or ValueType.StringValue.")
@@ -17,6 +17,7 @@
17
17
  from typing import Any, Tuple
18
18
  import ast
19
19
  from ast import FunctionDef
20
+ import astunparse
20
21
 
21
22
  from mindspore import log as logger
22
23
  from ..common import error_str
@@ -37,7 +38,8 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
37
38
  ast.Call: ["args"],
38
39
  ast.BinOp: ["left", "right"],
39
40
  ast.BoolOp: ["values"],
40
- ast.unaryop: ["operand"],
41
+ ast.UnaryOp: ["operand"],
42
+ ast.Compare: ["left", "comparators"],
41
43
  }
42
44
 
43
45
  @staticmethod
@@ -54,7 +56,7 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
54
56
  target_name = "function"
55
57
  elif isinstance(node, ast.Return):
56
58
  target_name = "return_value"
57
- elif isinstance(node, (ast.BinOp, ast.boolop, ast.UnaryOp)):
59
+ elif isinstance(node, (ast.BinOp, ast.BoolOp, ast.UnaryOp)):
58
60
  target_name = type(node.op).__name__.lower() + "_var"
59
61
  elif isinstance(node, ast.Tuple):
60
62
  target_name = type(node).__name__.lower() + "_var"
@@ -180,6 +182,20 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
180
182
  child = node.body[index]
181
183
  if isinstance(child, ast.Assign):
182
184
  stmt = child.value
185
+ elif isinstance(child, ast.If):
186
+ if isinstance(child.body[0], ast.Return) and not isinstance(child.test, ast.UnaryOp):
187
+ if isinstance(child.body[0].value, ast.Call):
188
+ if_body = child.body
189
+ if_func = if_body[0].value
190
+ expr = "x = " + astunparse.unparse(if_func)
191
+ if_body = ast.parse(expr)
192
+ if_body = if_body.body+ast.parse("return x").body
193
+ child.body = if_body
194
+ stmt = child
195
+ else:
196
+ stmt = child
197
+ else:
198
+ stmt = child
183
199
  elif isinstance(child, ast.Expr):
184
200
  stmt = child.value
185
201
  else:
@@ -24,8 +24,6 @@ _ms_functional_ns = CellNamespace('mindspore.ops.functional')
24
24
 
25
25
  def is_subtree(cls_name):
26
26
  """Determine whether 'cls_name' is a subtree."""
27
- if cls_name == "SequentialCell":
28
- return True
29
27
  if cls_name == "QuantizeWrapperCell":
30
28
  return False
31
29
  if cls_name in _ms_common_ns or cls_name in _ms_nn_ns or cls_name in _ms_ops_ns: