mindspore 2.1.0__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the package registry's page for this release for more details.

Files changed (550)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +49 -16
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  20. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  21. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  22. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  23. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  24. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  25. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  26. mindspore/_check_jit_forbidden_api.py +3 -1
  27. mindspore/_checkparam.py +26 -32
  28. mindspore/_extends/graph_kernel/__init__.py +0 -1
  29. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  30. mindspore/_extends/graph_kernel/splitter.py +1 -9
  31. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  32. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  33. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  34. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  35. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
  36. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  37. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  38. mindspore/_extends/parse/__init__.py +12 -15
  39. mindspore/_extends/parse/namespace.py +7 -33
  40. mindspore/_extends/parse/parser.py +61 -71
  41. mindspore/_extends/parse/resources.py +1 -1
  42. mindspore/_extends/parse/standard_method.py +72 -95
  43. mindspore/_extends/parse/trope.py +1 -1
  44. mindspore/_extends/remote/kernel_build_server.py +24 -7
  45. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  46. mindspore/_install_custom.py +43 -0
  47. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  48. mindspore/amp.py +47 -11
  49. mindspore/bin/cache_admin +0 -0
  50. mindspore/bin/cache_server +0 -0
  51. mindspore/boost/boost.py +1 -8
  52. mindspore/boost/boost_cell_wrapper.py +3 -2
  53. mindspore/boost/grad_accumulation.py +1 -1
  54. mindspore/boost/group_loss_scale_manager.py +8 -7
  55. mindspore/common/__init__.py +5 -3
  56. mindspore/common/_jit_fallback_utils.py +6 -0
  57. mindspore/common/_register_for_adapter.py +2 -0
  58. mindspore/common/_register_for_tensor.py +2 -2
  59. mindspore/common/_stub_tensor.py +13 -0
  60. mindspore/common/_utils.py +13 -0
  61. mindspore/common/api.py +173 -258
  62. mindspore/common/auto_dynamic_shape.py +498 -0
  63. mindspore/common/dtype.py +18 -11
  64. mindspore/common/dump.py +6 -4
  65. mindspore/common/initializer.py +14 -14
  66. mindspore/common/jit_config.py +33 -15
  67. mindspore/common/lazy_inline.py +126 -7
  68. mindspore/common/mindir_util.py +101 -0
  69. mindspore/common/parameter.py +51 -41
  70. mindspore/common/seed.py +4 -4
  71. mindspore/common/sparse_tensor.py +13 -14
  72. mindspore/common/tensor.py +240 -145
  73. mindspore/communication/__init__.py +7 -4
  74. mindspore/communication/_comm_helper.py +83 -4
  75. mindspore/communication/management.py +152 -84
  76. mindspore/config/op_info.config +13 -2
  77. mindspore/config/super_bar_config.json +4 -2
  78. mindspore/context.py +143 -59
  79. mindspore/dataset/__init__.py +5 -5
  80. mindspore/dataset/audio/__init__.py +2 -2
  81. mindspore/dataset/audio/transforms.py +52 -52
  82. mindspore/dataset/callback/ds_callback.py +16 -2
  83. mindspore/dataset/core/config.py +68 -51
  84. mindspore/dataset/engine/cache_client.py +28 -5
  85. mindspore/dataset/engine/datasets.py +250 -112
  86. mindspore/dataset/engine/datasets_audio.py +43 -211
  87. mindspore/dataset/engine/datasets_standard_format.py +11 -35
  88. mindspore/dataset/engine/datasets_text.py +43 -67
  89. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  90. mindspore/dataset/engine/datasets_vision.py +219 -1029
  91. mindspore/dataset/engine/iterators.py +11 -4
  92. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  93. mindspore/dataset/engine/obs/util.py +3 -0
  94. mindspore/dataset/engine/samplers.py +1 -1
  95. mindspore/dataset/engine/validators.py +19 -5
  96. mindspore/dataset/text/__init__.py +3 -3
  97. mindspore/dataset/text/transforms.py +101 -127
  98. mindspore/dataset/text/utils.py +205 -138
  99. mindspore/dataset/transforms/__init__.py +1 -1
  100. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  101. mindspore/dataset/transforms/transforms.py +95 -40
  102. mindspore/dataset/utils/browse_dataset.py +8 -2
  103. mindspore/dataset/utils/line_reader.py +17 -19
  104. mindspore/dataset/vision/__init__.py +3 -3
  105. mindspore/dataset/vision/c_transforms.py +6 -3
  106. mindspore/dataset/vision/transforms.py +409 -287
  107. mindspore/dataset/vision/utils.py +13 -14
  108. mindspore/dataset/vision/validators.py +11 -1
  109. mindspore/experimental/map_parameter.py +14 -0
  110. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  111. mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
  112. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  113. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  114. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  115. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  116. mindspore/gen_ops.py +273 -0
  117. mindspore/include/OWNERS +0 -1
  118. mindspore/include/api/data_type.h +2 -1
  119. mindspore/include/api/graph.h +0 -15
  120. mindspore/include/api/kernel.h +2 -0
  121. mindspore/include/api/kernel_api.h +37 -12
  122. mindspore/include/api/model.h +0 -14
  123. mindspore/include/api/types.h +37 -4
  124. mindspore/include/c_api/ms/abstract.h +67 -0
  125. mindspore/include/c_api/ms/attribute.h +197 -0
  126. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  127. mindspore/include/c_api/ms/base/macros.h +32 -0
  128. mindspore/include/c_api/ms/base/status.h +33 -0
  129. mindspore/include/c_api/ms/base/types.h +282 -0
  130. mindspore/include/c_api/ms/context.h +102 -0
  131. mindspore/include/c_api/ms/graph.h +160 -0
  132. mindspore/include/c_api/ms/node.h +606 -0
  133. mindspore/include/c_api/ms/tensor.h +161 -0
  134. mindspore/include/c_api/ms/value.h +84 -0
  135. mindspore/include/dataset/constants.h +6 -5
  136. mindspore/include/dataset/execute.h +23 -13
  137. mindspore/include/dataset/text.h +26 -26
  138. mindspore/include/dataset/transforms.h +13 -13
  139. mindspore/include/dataset/vision.h +60 -60
  140. mindspore/include/dataset/vision_ascend.h +5 -6
  141. mindspore/include/dataset/vision_lite.h +17 -17
  142. mindspore/include/mindapi/base/type_id.h +1 -0
  143. mindspore/include/mindapi/base/types.h +1 -0
  144. mindspore/lib/libdnnl.so.2 +0 -0
  145. mindspore/lib/libjemalloc.so.2 +0 -0
  146. mindspore/lib/libmindspore.so +0 -0
  147. mindspore/lib/libmindspore_backend.so +0 -0
  148. mindspore/lib/libmindspore_common.so +0 -0
  149. mindspore/lib/libmindspore_core.so +0 -0
  150. mindspore/lib/libmindspore_glog.so.0 +0 -0
  151. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  152. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  153. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  154. mindspore/lib/libmindspore_shared_lib.so +0 -0
  155. mindspore/lib/libnnacl.so +0 -0
  156. mindspore/lib/libopencv_core.so.4.5 +0 -0
  157. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  158. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  159. mindspore/lib/libps_cache.so +0 -0
  160. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  161. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  163. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  164. mindspore/lib/plugin/ascend/libakg.so +0 -0
  165. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  166. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  167. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  168. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  169. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  170. mindspore/lib/plugin/cpu/libakg.so +0 -0
  171. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  172. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  173. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  174. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  175. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  176. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  177. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  178. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  179. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  180. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  181. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  182. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  183. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  184. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  185. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  186. mindspore/nn/__init__.py +0 -2
  187. mindspore/nn/cell.py +316 -74
  188. mindspore/nn/dynamic_lr.py +21 -21
  189. mindspore/nn/layer/activation.py +21 -28
  190. mindspore/nn/layer/basic.py +15 -13
  191. mindspore/nn/layer/channel_shuffle.py +1 -1
  192. mindspore/nn/layer/container.py +271 -9
  193. mindspore/nn/layer/conv.py +310 -207
  194. mindspore/nn/layer/dense.py +8 -5
  195. mindspore/nn/layer/embedding.py +33 -27
  196. mindspore/nn/layer/flash_attention.py +82 -41
  197. mindspore/nn/layer/image.py +8 -6
  198. mindspore/nn/layer/math.py +13 -18
  199. mindspore/nn/layer/normalization.py +107 -66
  200. mindspore/nn/layer/padding.py +1 -1
  201. mindspore/nn/layer/pooling.py +131 -109
  202. mindspore/nn/layer/rnn_cells.py +22 -17
  203. mindspore/nn/layer/rnns.py +13 -16
  204. mindspore/nn/layer/thor_layer.py +1 -1
  205. mindspore/nn/layer/transformer.py +221 -154
  206. mindspore/nn/learning_rate_schedule.py +9 -1
  207. mindspore/nn/loss/loss.py +235 -174
  208. mindspore/nn/optim/ada_grad.py +2 -1
  209. mindspore/nn/optim/adadelta.py +1 -0
  210. mindspore/nn/optim/adafactor.py +2 -1
  211. mindspore/nn/optim/adam.py +7 -4
  212. mindspore/nn/optim/adamax.py +3 -2
  213. mindspore/nn/optim/adasum.py +2 -2
  214. mindspore/nn/optim/asgd.py +2 -3
  215. mindspore/nn/optim/ftrl.py +6 -5
  216. mindspore/nn/optim/lamb.py +7 -4
  217. mindspore/nn/optim/lars.py +1 -1
  218. mindspore/nn/optim/lazyadam.py +5 -3
  219. mindspore/nn/optim/momentum.py +2 -1
  220. mindspore/nn/optim/optimizer.py +53 -4
  221. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  222. mindspore/nn/optim/rmsprop.py +4 -3
  223. mindspore/nn/optim/rprop.py +23 -12
  224. mindspore/nn/optim/sgd.py +26 -11
  225. mindspore/nn/optim/thor.py +9 -7
  226. mindspore/nn/probability/bijector/bijector.py +5 -5
  227. mindspore/nn/probability/bijector/power_transform.py +27 -27
  228. mindspore/nn/probability/bijector/softplus.py +3 -3
  229. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  230. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  231. mindspore/nn/probability/distribution/beta.py +3 -3
  232. mindspore/nn/probability/distribution/categorical.py +7 -7
  233. mindspore/nn/probability/distribution/cauchy.py +0 -1
  234. mindspore/nn/probability/distribution/distribution.py +3 -3
  235. mindspore/nn/probability/distribution/gamma.py +3 -3
  236. mindspore/nn/probability/distribution/geometric.py +4 -4
  237. mindspore/nn/probability/distribution/gumbel.py +4 -4
  238. mindspore/nn/probability/distribution/log_normal.py +2 -2
  239. mindspore/nn/probability/distribution/logistic.py +2 -2
  240. mindspore/nn/probability/distribution/poisson.py +4 -4
  241. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  242. mindspore/nn/probability/distribution/uniform.py +6 -6
  243. mindspore/nn/wrap/cell_wrapper.py +78 -34
  244. mindspore/nn/wrap/grad_reducer.py +8 -5
  245. mindspore/nn/wrap/loss_scale.py +105 -42
  246. mindspore/numpy/array_creations.py +1 -2
  247. mindspore/numpy/array_ops.py +3 -2
  248. mindspore/offline_debug/convert_async.py +2 -2
  249. mindspore/ops/_grad_experimental/__init__.py +0 -5
  250. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
  251. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  252. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  253. mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
  254. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  255. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
  256. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  257. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  258. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  259. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  260. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  261. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  262. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  263. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  264. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  265. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  266. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  267. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  268. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  269. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  270. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  271. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  272. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  273. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  274. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  275. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  276. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  277. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  278. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  279. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  280. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  281. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  282. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  283. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  284. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  285. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  286. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  287. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  288. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  289. mindspore/ops/_primitive_cache.py +1 -1
  290. mindspore/ops/_tracefunc.py +45 -13
  291. mindspore/ops/_utils/utils.py +4 -1
  292. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  293. mindspore/ops/_vmap/vmap_base.py +3 -3
  294. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  295. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  296. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  297. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  298. mindspore/ops/arg_dtype_cast.py +54 -0
  299. mindspore/ops/composite/base.py +37 -10
  300. mindspore/ops/composite/math_ops.py +5 -4
  301. mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
  302. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  303. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  304. mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
  305. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  306. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  308. mindspore/ops/deprecated.py +304 -0
  309. mindspore/ops/function/__init__.py +4 -1
  310. mindspore/ops/function/array_func.py +167 -189
  311. mindspore/ops/function/clip_func.py +81 -13
  312. mindspore/ops/function/debug_func.py +1 -1
  313. mindspore/ops/function/grad/grad_func.py +18 -8
  314. mindspore/ops/function/image_func.py +10 -4
  315. mindspore/ops/function/linalg_func.py +5 -5
  316. mindspore/ops/function/math_func.py +575 -386
  317. mindspore/ops/function/nn_func.py +470 -251
  318. mindspore/ops/function/random_func.py +86 -56
  319. mindspore/ops/function/sparse_func.py +1 -1
  320. mindspore/ops/function/sparse_unary_func.py +14 -12
  321. mindspore/ops/function/vmap_func.py +6 -5
  322. mindspore/ops/functional.py +15 -10
  323. mindspore/ops/op_info_register.py +235 -19
  324. mindspore/ops/operations/__init__.py +25 -17
  325. mindspore/ops/operations/_grad_ops.py +52 -7
  326. mindspore/ops/operations/_inner_ops.py +213 -12
  327. mindspore/ops/operations/_quant_ops.py +4 -8
  328. mindspore/ops/operations/_sequence_ops.py +42 -0
  329. mindspore/ops/operations/array_ops.py +64 -280
  330. mindspore/ops/operations/comm_ops.py +105 -57
  331. mindspore/ops/operations/custom_ops.py +10 -3
  332. mindspore/ops/operations/debug_ops.py +8 -4
  333. mindspore/ops/operations/image_ops.py +18 -12
  334. mindspore/ops/operations/math_ops.py +185 -138
  335. mindspore/ops/operations/nn_ops.py +716 -492
  336. mindspore/ops/operations/other_ops.py +0 -22
  337. mindspore/ops/operations/random_ops.py +53 -111
  338. mindspore/ops/operations/sparse_ops.py +3 -1
  339. mindspore/ops/primitive.py +24 -18
  340. mindspore/parallel/_auto_parallel_context.py +68 -8
  341. mindspore/parallel/_cost_model_context.py +2 -2
  342. mindspore/parallel/_offload_context.py +17 -3
  343. mindspore/parallel/_parallel_serialization.py +2 -2
  344. mindspore/parallel/_ps_context.py +12 -0
  345. mindspore/parallel/_tensor.py +14 -12
  346. mindspore/parallel/_transformer/layers.py +5 -3
  347. mindspore/parallel/_transformer/loss.py +1 -0
  348. mindspore/parallel/_transformer/moe.py +2 -2
  349. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  350. mindspore/parallel/_transformer/transformer.py +23 -3
  351. mindspore/parallel/_utils.py +11 -7
  352. mindspore/parallel/algo_parameter_config.py +85 -5
  353. mindspore/parallel/checkpoint_transform.py +6 -10
  354. mindspore/parallel/shard.py +4 -4
  355. mindspore/profiler/common/struct_type.py +3 -3
  356. mindspore/profiler/common/util.py +3 -2
  357. mindspore/profiler/envprofiling.py +1 -1
  358. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  359. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  360. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  361. mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
  362. mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
  363. mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
  364. mindspore/profiler/parser/ascend_op_generator.py +5 -5
  365. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  366. mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
  367. mindspore/profiler/parser/base_timeline_generator.py +9 -7
  368. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
  369. mindspore/profiler/parser/flops_parser.py +15 -11
  370. mindspore/profiler/parser/framework_parser.py +37 -21
  371. mindspore/profiler/parser/hccl_parser.py +16 -12
  372. mindspore/profiler/parser/integrator.py +22 -11
  373. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  374. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  375. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  376. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  377. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  378. mindspore/profiler/parser/optime_parser.py +1 -1
  379. mindspore/profiler/parser/profiler_info.py +2 -2
  380. mindspore/profiler/parser/step_trace_parser.py +11 -14
  381. mindspore/profiler/profiling.py +139 -71
  382. mindspore/rewrite/api/node.py +102 -19
  383. mindspore/rewrite/api/node_type.py +5 -1
  384. mindspore/rewrite/api/scoped_value.py +9 -17
  385. mindspore/rewrite/api/symbol_tree.py +131 -47
  386. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  387. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  388. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  389. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  390. mindspore/rewrite/common/rewrite_elog.py +5 -1
  391. mindspore/rewrite/namer.py +33 -24
  392. mindspore/rewrite/namespace.py +14 -5
  393. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  394. mindspore/rewrite/node/call_function.py +79 -0
  395. mindspore/rewrite/node/cell_container.py +135 -0
  396. mindspore/rewrite/node/control_flow.py +88 -0
  397. mindspore/rewrite/{node.py → node/node.py} +273 -234
  398. mindspore/rewrite/node/node_manager.py +254 -0
  399. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  400. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  401. mindspore/rewrite/parsers/assign_parser.py +216 -221
  402. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  403. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  404. mindspore/rewrite/parsers/constant_parser.py +9 -6
  405. mindspore/rewrite/parsers/container_parser.py +9 -7
  406. mindspore/rewrite/parsers/for_parser.py +36 -15
  407. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  408. mindspore/rewrite/parsers/if_parser.py +28 -24
  409. mindspore/rewrite/parsers/module_parser.py +196 -25
  410. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  411. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  412. mindspore/rewrite/parsers/return_parser.py +6 -6
  413. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  414. mindspore/rewrite/sparsify/utils.py +1 -1
  415. mindspore/rewrite/symbol_tree.py +525 -577
  416. mindspore/rewrite/symbol_tree_builder.py +9 -193
  417. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  418. mindspore/run_check/_check_version.py +2 -2
  419. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  420. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  421. mindspore/scipy/linalg.py +1 -1
  422. mindspore/scipy/optimize/minimize.py +7 -3
  423. mindspore/train/_utils.py +7 -3
  424. mindspore/train/amp.py +323 -123
  425. mindspore/train/anf_ir_pb2.py +14 -2
  426. mindspore/train/callback/_backup_and_restore.py +2 -12
  427. mindspore/train/callback/_callback.py +29 -4
  428. mindspore/train/callback/_checkpoint.py +23 -8
  429. mindspore/train/callback/_early_stop.py +2 -2
  430. mindspore/train/callback/_landscape.py +4 -4
  431. mindspore/train/callback/_loss_monitor.py +2 -2
  432. mindspore/train/callback/_on_request_exit.py +2 -2
  433. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  434. mindspore/train/callback/_summary_collector.py +14 -7
  435. mindspore/train/callback/_time_monitor.py +58 -5
  436. mindspore/train/data_sink.py +5 -11
  437. mindspore/train/dataset_helper.py +83 -57
  438. mindspore/train/loss_scale_manager.py +2 -2
  439. mindspore/train/metrics/__init__.py +3 -3
  440. mindspore/train/metrics/cosine_similarity.py +1 -1
  441. mindspore/train/metrics/hausdorff_distance.py +3 -2
  442. mindspore/train/metrics/mean_surface_distance.py +3 -2
  443. mindspore/train/metrics/metric.py +39 -19
  444. mindspore/train/metrics/roc.py +2 -2
  445. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  446. mindspore/train/mind_ir_pb2.py +85 -36
  447. mindspore/train/model.py +185 -45
  448. mindspore/train/serialization.py +390 -150
  449. mindspore/train/summary/_writer_pool.py +3 -2
  450. mindspore/train/summary/summary_record.py +14 -10
  451. mindspore/train/train_thor/convert_utils.py +3 -3
  452. mindspore/train/train_thor/dataset_helper.py +1 -1
  453. mindspore/version.py +1 -1
  454. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
  455. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +458 -518
  456. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  457. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  458. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  459. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  460. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  461. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  462. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  463. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  464. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  465. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  466. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  467. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  468. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  469. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  470. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  471. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  472. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  473. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  474. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  475. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  476. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  477. mindspore/_extends/graph_kernel/expander.py +0 -80
  478. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  479. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  480. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  481. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  482. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  483. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  484. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  485. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  486. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  487. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  488. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  489. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  490. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  491. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  492. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  493. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  494. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  495. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  496. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  497. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  498. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  499. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  500. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  501. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  502. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  503. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  504. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  505. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  506. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  507. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  508. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  509. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  510. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  511. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  512. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  513. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  514. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  515. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  516. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  517. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  518. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  519. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  520. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  521. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  522. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  523. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  524. mindspore/dataset/datapreprocess/__init__.py +0 -20
  525. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  526. mindspore/include/api/net.h +0 -142
  527. mindspore/nn/lr_scheduler.py +0 -262
  528. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  529. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  530. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  531. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  532. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  533. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  534. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  537. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  539. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  541. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  542. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  543. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  546. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  547. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  548. mindspore/rewrite/node_visitor.py +0 -44
  549. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  550. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -14,37 +14,39 @@
14
14
  # ============================================================================
15
15
  """Parse ast.Assign in construct function to node of SymbolTree."""
16
16
  from typing import Union
17
+ import os
17
18
  import ast
18
19
  import sys
19
20
  import inspect
20
- import astunparse
21
21
 
22
22
  from mindspore import log as logger
23
- from mindspore._extends.parse.namespace import CellNamespace
24
23
  from mindspore.nn import Cell, SequentialCell
25
- from mindspore.ops import operations as P
26
24
  from mindspore.ops import Primitive
27
- from mindspore.rewrite.parser_register import ParserRegister
25
+ from mindspore.rewrite.parsers.parser_register import ParserRegister, reg_parser
28
26
  from mindspore.rewrite.namespace import is_subtree, is_functional, get_functional
29
27
  from mindspore.rewrite.symbol_tree import SymbolTree
30
- from mindspore.rewrite.node import Node, TreeNode, CellContainer
31
- from mindspore.rewrite.parser import Parser
32
- from mindspore.rewrite.parser_register import reg_parser
28
+ from mindspore.rewrite.node.node import Node, TreeNode
29
+ from mindspore.rewrite.node.node_manager import NodeManager
30
+ from mindspore.rewrite.node.call_function import CallFunction
31
+ from mindspore.rewrite.node.cell_container import CellContainer
32
+ from mindspore.rewrite.parsers.parser import Parser
33
33
  from mindspore.rewrite.api.scoped_value import ScopedValue, ValueType
34
- from mindspore.rewrite.symbol_tree_builder import SymbolTreeBuilder, FunctionSymbolTreeBuilder
34
+ from mindspore.rewrite.symbol_tree_builder import SymbolTreeBuilder
35
+ from mindspore.rewrite.ast_transformers.flatten_recursive_stmt import FlattenRecursiveStmt
35
36
  from mindspore.rewrite.ast_helpers import AstReplacer
36
- from mindspore.rewrite.common.event import Event
37
37
  from ..common import error_str
38
38
 
39
+ if sys.version_info >= (3, 9):
40
+ import ast as astunparse # pylint: disable=reimported, ungrouped-imports
41
+ else:
42
+ import astunparse
43
+
39
44
 
40
45
  class AssignParser(Parser):
41
46
  """Parse ast.Assign in construct function to node of SymbolTree."""
42
47
 
43
- def __init__(self):
44
- """Constructor"""
45
- super(AssignParser, self).__init__()
46
- self._cell_namespce = CellNamespace('mindspore.nn')
47
- self._primitive_namespce = CellNamespace('mindspore.ops.operations')
48
+ # Types for creating Cell Container node
49
+ types_for_cell_container = [SequentialCell,]
48
50
 
49
51
  def target(self):
50
52
  """Parse target type."""
@@ -68,9 +70,9 @@ class AssignParser(Parser):
68
70
  tuple_values = []
69
71
  for tuple_elt in tuple_elts:
70
72
  if not isinstance(tuple_elt, (ast.Constant, ast.Name, ast.Attribute)):
71
- raise RuntimeError(f"Only support ast.Constant or ast.Name as elts of ast.Tuple, "
72
- f"but got ast type {type(tuple_elt).__name__}",
73
- child_node=tuple_elt, father_node=node)
73
+ raise RuntimeError(error_str(f"Only support ast.Constant or ast.Name as elts of ast.Tuple, "
74
+ f"but got ast type {type(tuple_elt).__name__}",
75
+ child_node=tuple_elt, father_node=node))
74
76
  if isinstance(tuple_elt, ast.Constant):
75
77
  tuple_values.append(tuple_elt.value)
76
78
  elif isinstance(tuple_elt, ast.Name):
@@ -116,12 +118,12 @@ class AssignParser(Parser):
116
118
  father_node=node))
117
119
 
118
120
  @staticmethod
119
- def _get_func_name(ast_node: ast.Call) -> str:
121
+ def _get_func_name(ast_call: ast.Call) -> str:
120
122
  """
121
123
  Get the func name from ast.Call.
122
124
 
123
125
  Args:
124
- ast_node (ast.Call): Input ast.Call node.
126
+ ast_call (ast.Call): Input ast.Call node.
125
127
 
126
128
  Returns:
127
129
  Func name.
@@ -129,7 +131,7 @@ class AssignParser(Parser):
129
131
  Raises:
130
132
  RuntimeError: Func of input ast node is not ast.Name or ast.Attribute.
131
133
  """
132
- func = ast_node.func
134
+ func = ast_call.func
133
135
  if isinstance(func, ast.Name):
134
136
  return func.id
135
137
  if isinstance(func, ast.Attribute):
@@ -137,15 +139,16 @@ class AssignParser(Parser):
137
139
  if isinstance(func, ast.Call):
138
140
  return AssignParser._get_func_name(func)
139
141
  raise RuntimeError(error_str(f"funcValue should be Name or a Attribute or a Call, but got ast type "
140
- f"'{type(func).__name__}'", child_node=func, father_node=ast_node))
142
+ f"'{type(func).__name__}'", child_node=func, father_node=ast_call))
141
143
 
142
144
  @staticmethod
143
- def _get_func_scope(ast_node: ast.Call) -> str:
145
+ def _get_func_scope(ast_call: ast.Call, node_manager: NodeManager = None) -> str:
144
146
  """
145
147
  Get the func scope from ast.Call.
146
148
 
147
149
  Args:
148
- ast_node (ast.Call): Input ast.Call node.
150
+ ast_call (ast.Call): Input ast.Call node.
151
+ node_manager (NodeManager): NodeManager those asts belong to.
149
152
 
150
153
  Returns:
151
154
  Func scope.
@@ -154,17 +157,17 @@ class AssignParser(Parser):
154
157
  RuntimeError: FuncValue is not an ast.Name when func is an ast.Attribute.
155
158
  RuntimeError: Func of input ast node is not ast.Name or ast.Attribute.
156
159
  """
157
- func = ast_node.func
160
+ func = ast_call.func
158
161
  if isinstance(func, ast.Name):
159
162
  return ""
160
163
  if isinstance(func, ast.Attribute):
161
164
  parser = ParserRegister.instance().get_parser(type(func))
162
- value = parser.process(None, func)
165
+ value = parser.process(None, func, node_manager)
163
166
  return value.rsplit(".", 1)[0]
164
167
  if isinstance(func, ast.Call):
165
- return AssignParser._get_func_scope(func)
168
+ return AssignParser._get_func_scope(func, node_manager)
166
169
  raise RuntimeError(error_str(f"funcValue should be Name or a Attribute or a Call, but got ast type "
167
- f"'{type(func).__name__}'", child_node=func, father_node=ast_node))
170
+ f"'{type(func).__name__}'", child_node=func, father_node=ast_call))
168
171
 
169
172
  @staticmethod
170
173
  def _get_symbol_object(symbol_name, origin_net):
@@ -205,9 +208,9 @@ class AssignParser(Parser):
205
208
  return results
206
209
 
207
210
  @staticmethod
208
- def _find_op_and_type(func_scope, func_name, stree: SymbolTree):
211
+ def _get_call_instance(func_scope, func_name, stree: SymbolTree):
209
212
  """
210
- Get the func scope from ast.Call.
213
+ Get object instance from ast.Call with type of Cell or Primitive.
211
214
 
212
215
  Args:
213
216
  func_scope (str): Func scope.
@@ -215,21 +218,21 @@ class AssignParser(Parser):
215
218
  stree (SymbolTree): Belong SymbolTree.
216
219
 
217
220
  Returns:
218
- A type represents type of op and an instance represents operator instance.
221
+ An instance represents operator instance.
219
222
  """
220
-
221
223
  if func_scope != "self":
222
- logger.warning("Not support parse operator which is instantiated at runtime now: %s; name: %s", func_scope,
223
- func_name)
224
+ return None
224
225
  var_dict = stree.get_origin_network().__dict__
226
+ # Instance is of type Cell
225
227
  for key, value in var_dict["_cells"].items():
226
228
  if key == func_name:
227
- return type(value), value
228
-
229
+ return value
230
+ # Instance is of type Primitive
229
231
  for key, value in var_dict["_primitives"].items():
230
232
  if key == func_name:
231
- return type(value), value
232
- return type(None), None
233
+ return value
234
+ # Instance is of other type.
235
+ return None
233
236
 
234
237
  @staticmethod
235
238
  def _get_targets(all_targets: ScopedValue) -> [Union[ScopedValue, str]]:
@@ -240,7 +243,7 @@ class AssignParser(Parser):
240
243
  if not isinstance(single_target, ScopedValue) and not isinstance(single_target.value, str):
241
244
  raise RuntimeError(f"For MindSpore Rewrite, only support str target in tuple, but got type "
242
245
  f"{type(single_target).__name__}")
243
- if single_target.type == ValueType.StringValue:
246
+ if single_target.type == ValueType.ConstantValue and isinstance(single_target.value, str):
244
247
  single_target.type = ValueType.NamingValue
245
248
  targets.append(single_target)
246
249
  else:
@@ -251,18 +254,7 @@ class AssignParser(Parser):
251
254
  def _update_field_in_init(func_scope, func_name, stree: SymbolTree, sub_tree: SymbolTree) -> bool:
252
255
  """
253
256
  When node is an invoking to sub-network, update value of ast.Assign of corresponding field in `__init__` method.
254
-
255
- Update from:
256
-
257
- .. code-block::
258
-
259
- self.field = getattr(self._handler, "field")
260
-
261
- to:
262
-
263
- .. code-block::
264
-
265
- self.field = SubNetwork(global_vars.get("field_args"))
257
+ Add the code like: `self.field = SubNetwork(self.field)`
266
258
 
267
259
  Args:
268
260
  func_scope (str): A string represents scope of function symbol.
@@ -278,39 +270,24 @@ class AssignParser(Parser):
278
270
  logger.warning("Not support parse operator which is instantiated at runtime now: %s; name: %s", func_scope,
279
271
  func_name)
280
272
  init_func_ast = stree.get_init_func_ast()
281
- class_name = sub_tree.get_opt_cls_name()
282
- setattr(stree.get_origin_network(), func_name, sub_tree.get_origin_network())
273
+ sub_net_obj = sub_tree.get_origin_network()
274
+ sub_net_opt_name = sub_tree.get_opt_cls_name()
283
275
  # Add .to_float(mindspore.float16) if origin subnet has this attribute
284
- if hasattr(sub_tree.get_origin_network(), "to_float_fp16")\
285
- and sub_tree.get_origin_network().to_float_fp16:
286
- new_code = f"self.{func_name} = {class_name}(getattr(self, '{func_name}')).to_float(mindspore.float16)"
287
- else:
288
- new_code = f"self.{func_name} = {class_name}(getattr(self, '{func_name}'))"
276
+ new_code = f"{func_scope}.{func_name} = {sub_net_opt_name}({func_scope}.{func_name})"
277
+ if hasattr(sub_net_obj, "fp16") and sub_net_obj.fp16:
278
+ new_code = f"{new_code}.to_float(mindspore.float16)"
279
+ elif hasattr(sub_net_obj, "bf16") and sub_net_obj.bf16:
280
+ new_code = f"{new_code}.to_float(mindspore.bfloat16)"
289
281
  new_ast = ast.parse(new_code).body[0]
290
282
  init_func_ast.body.append(new_ast)
291
- return True
292
-
293
- @staticmethod
294
- def _convert_ast_binop_to_node(ast_node: ast.BinOp, father_ast_node: ast.Assign) -> Node:
295
- """convert ast.BinOp to Node"""
296
-
297
- # only support ast.Add now
298
- op = P.Add()
299
- func_ast = ast.Attribute(value=ast.Name(id='F', ctx=ast.Load()), attr='add', ctx=ast.Load())
300
- func = ScopedValue.create_naming_value('add', 'F')
301
-
302
- father_ast_node.value = ast.Call(func=func_ast, args=[ast_node.left, ast_node.right], keywords=[])
303
- targets = AssignParser._get_targets(AssignParser._create_scopedvalue(father_ast_node.targets[0]))
304
- call_args = [AssignParser._create_scopedvalue(arg) for arg in father_ast_node.value.args]
305
- return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, {})
306
283
 
307
284
  @staticmethod
308
- def _create_inputs_for_cell_container(father_ast_node) -> ['Node']:
285
+ def _create_inputs_for_cell_container(ast_assign) -> ['Node']:
309
286
  """Create inputs for cell container first node."""
310
- call_ast_node = father_ast_node.value
287
+ call_ast_node = ast_assign.value
311
288
  if not isinstance(call_ast_node, ast.Call):
312
289
  raise RuntimeError(error_str(f"when creating input node for cellcontainer, value of input father ast node"
313
- "is not ast.Call!'", child_node=call_ast_node, father_node=father_ast_node))
290
+ "is not ast.Call!'", child_node=call_ast_node, father_node=ast_assign))
314
291
  first_node_inputs: ['Node'] = []
315
292
  exist_param_name = []
316
293
  for arg in call_ast_node.args:
@@ -330,30 +307,52 @@ class AssignParser(Parser):
330
307
 
331
308
  if call_ast_node.keywords:
332
309
  raise RuntimeError(error_str(f"Not support keyword input for cellcontainer now.",
333
- child_node=call_ast_node, father_node=father_ast_node))
310
+ child_node=call_ast_node, father_node=ast_assign))
334
311
 
335
312
  return first_node_inputs
336
313
 
337
- def _cell_container_process(self, ast_node, stree, targets, func, call_args, call_kwargs, op_name, container_obj):
314
+ @staticmethod
315
+ def _update_cell_container_in_init(stree, container_name, container_idx, subnet_opt_name):
316
+ """
317
+ When nn.SequentialCell include sub-symboltree, the new class definition will be used to create object.
318
+ So the assign code will be got from origin code first, and then be modified to new class name.
319
+
320
+ Codes like:
321
+
322
+ `self.container = nn.SequentialCell([ReLU(), MyNet()])`
323
+
324
+ will be updated by add codes:
325
+
326
+ `self.container[1] = MyNetOpt(self.container[1])`
327
+
328
+ """
329
+ new_code = f"{container_name}[{container_idx}] = {subnet_opt_name}({container_name}[{container_idx}])"
330
+ new_ast = ast.parse(new_code).body[0]
331
+ stree.get_init_func_ast().body.append(new_ast)
332
+
333
+ @staticmethod
334
+ def cell_container_process(ast_assign, stree, targets, func_scope_name, call_args, call_kwargs,
335
+ op_name, container_obj):
338
336
  """ parse cell container object."""
339
- cell_container = CellContainer(ast_node, targets, func, call_args, call_kwargs, op_name, container_obj)
340
- cell_container.set_belong_symbol_tree(stree)
341
- first_node_inputs = AssignParser._create_inputs_for_cell_container(ast_node)
337
+ cell_container = CellContainer(ast_assign, targets, func_scope_name, call_args, call_kwargs,
338
+ op_name, stree, container_obj)
339
+ first_node_inputs = AssignParser._create_inputs_for_cell_container(ast_assign)
342
340
  for i, cell in enumerate(container_obj):
343
- is_sub_tree = is_subtree(type(cell).__name__)
341
+ cell_name = type(cell).__name__
342
+ is_sub_tree = is_subtree(cell)
344
343
  if is_sub_tree:
345
344
  stb = SymbolTreeBuilder(cell)
346
345
  new_stree = stb.build()
347
- replacer = AstReplacer(new_stree.get_class_ast())
348
- replacer.replace_all(new_stree.get_ori_cls_name(), new_stree.get_opt_cls_name())
349
- sub_node = TreeNode.create_tree_node(new_stree, ast_node, targets, func, call_args, call_kwargs,
350
- type(cell).__name__, cell)
346
+ sub_node = TreeNode.create_tree_node(new_stree, None, targets, cell_name, call_args,
347
+ call_kwargs, cell_name, cell)
348
+ AssignParser._update_cell_container_in_init(stree, func_scope_name, i, new_stree.get_opt_cls_name())
351
349
  else:
352
- sub_node = Node.create_call_buildin_op(cell, ast_node, targets, func, call_args, call_kwargs,
353
- type(cell).__name__)
350
+ sub_node = Node.create_call_buildin_op(cell, None, targets, cell_name, call_args,
351
+ call_kwargs, cell_name)
354
352
  # add sub node to cell_container
355
- cell_container.append(sub_node)
356
- # set node inputs
353
+ cell_container.append(sub_node, False)
354
+ # set node inputs, those input nodes are NOT inserted in container, only
355
+ # topological relationship is updated.
357
356
  if i == 0:
358
357
  for idx, arg_provider in enumerate(first_node_inputs):
359
358
  sub_node.set_arg_providers(idx, (arg_provider, 0))
@@ -361,43 +360,61 @@ class AssignParser(Parser):
361
360
  sub_node.set_arg_providers(0, (cell_container.node_list[i-1], 0))
362
361
  return cell_container
363
362
 
364
- def _process_external_function(self, stree, func_name):
365
- """Process external function."""
363
+ @staticmethod
364
+ def process_external_function(stree, func_name, file_path):
365
+ """
366
+ Process external function.
367
+ Ast of external function defined in specifical file_path will be saved to generate codes.
368
+ """
366
369
  for k, m in sys.modules.items():
367
370
  if k in ("_ast", "ast"):
368
371
  continue
369
372
  if hasattr(m, func_name):
370
373
  func = getattr(m, func_name)
374
+ if not inspect.isfunction(func):
375
+ continue
376
+ func_source_code_file = inspect.getfile(func)
377
+ if func_source_code_file != file_path:
378
+ continue
371
379
  source_code = inspect.getsource(func)
372
380
  ast_root: ast.Module = ast.parse(source_code)
373
- stree._external_func_ast.append(ast_root.body[0]) # pylint: disable=protected-access
381
+ stree.get_external_ast().append(ast_root.body[0])
374
382
  return func, ast_root.body[0]
375
- return None, None
383
+ logger.info(f"Cannot get ast of function {func_name} from {file_path}.")
384
+ return None, None
376
385
 
377
386
  def _process_internal_function(self, stree: SymbolTree, func_name):
378
387
  """Process internal function."""
379
- func = getattr(stree._origin_network, func_name) # pylint: disable=protected-access
380
- ast_node = None
381
- for body in stree._class_ast.body: # pylint: disable=protected-access
388
+ func_inst = getattr(stree.get_origin_network(), func_name)
389
+ ast_functiondef = None
390
+ for body in stree.get_class_ast().body:
382
391
  if isinstance(body, ast.FunctionDef) and func_name == body.name:
383
- ast_node = body
384
- return func, ast_node
385
-
386
- def _create_func_subtree(self, op, targets, father_ast_node, ast_node, call_args, call_kwargs, func_name):
387
- """Create subtree of function."""
388
- stb = FunctionSymbolTreeBuilder(op, ast_node)
389
- new_stree = stb.build()
390
- return TreeNode.create_tree_node(new_stree, father_ast_node, targets, func_name, call_args, call_kwargs,
391
- func_name, op)
392
-
393
- def _convert_ast_call_to_node(self, ast_node: ast.Call, father_ast_node: ast.Assign, stree: SymbolTree) -> Node:
392
+ ast_functiondef = body
393
+ return func_inst, ast_functiondef
394
+
395
+ def _create_callfunction_node(self, targets: [ScopedValue], func_scope_name: ScopedValue, args: [ScopedValue],
396
+ kwargs: {str: ScopedValue}, node_name: str, ast_assign: ast.Assign,
397
+ ast_functiondef: ast.FunctionDef, stree: SymbolTree, instance):
398
+ """Create CallFunction node for class internal function."""
399
+ node = CallFunction(targets, func_scope_name, args, kwargs, node_name, ast_assign, ast_functiondef,
400
+ stree, instance)
401
+ # expand ast codes
402
+ ast_functiondef = FlattenRecursiveStmt().transform(ast_functiondef, [func_scope_name.value], stree)
403
+ # parse ast codes into CallFunction Node
404
+ parser = ParserRegister.instance().get_parser(ast.FunctionDef)
405
+ parser.process(stree, ast_functiondef, node_manager=node)
406
+ return node
407
+
408
+ def _convert_ast_call_to_node(self, ast_call: ast.Call, ast_assign: ast.Assign, stree: SymbolTree,
409
+ node_manager: NodeManager) -> Node:
394
410
  """
395
411
  Convert ast.Call to a symbol tree node.
396
412
 
397
413
  Args:
398
- ast_node (ast.Call): An ast.Call of assign node in construct.
399
- father_ast_node (ast.Assign): Assign node in construct.
414
+ ast_call (ast.Call): An ast.Call of assign node in construct.
415
+ ast_assign (ast.Assign): Assign node in construct.
400
416
  stree (SymbolTree): Symbol Tree under parsing.
417
+ node_manager (NodeManager): NodeManager those asts belong to.
401
418
 
402
419
  Returns:
403
420
  An instance of Node in Symbol Tree.
@@ -405,86 +422,63 @@ class AssignParser(Parser):
405
422
  Raises:
406
423
  RuntimeError: If operator instance invoked by assign is undefined.
407
424
  """
408
- targets = AssignParser._get_targets(AssignParser._create_scopedvalue(father_ast_node.targets[0]))
409
- func_name = AssignParser._get_func_name(ast_node)
425
+ targets = AssignParser._get_targets(AssignParser._create_scopedvalue(ast_assign.targets[0]))
426
+ func_name = AssignParser._get_func_name(ast_call)
410
427
  if func_name is None or func_name == "":
411
428
  raise RuntimeError("function name not exist")
412
- func_scope = AssignParser._get_func_scope(ast_node)
413
- func = ScopedValue.create_naming_value(func_name, func_scope)
414
- call_args = [AssignParser._create_scopedvalue(arg) for arg in ast_node.args]
415
- call_kwargs = AssignParser._create_kwargs(ast_node.keywords)
416
-
417
- _, op = AssignParser._find_op_and_type(func_scope, func_name, stree)
418
- if op is None:
419
- if is_functional(func_name):
420
- parser = ParserRegister.instance().get_parser(type(ast_node.func))
421
- func_name = parser.process(stree, ast_node.func)
422
- func = get_functional(func_name.split(".")[-1])
423
- node = stree.inner_create_call_function(func_name, father_ast_node, func_name, func, targets,
424
- call_args, call_kwargs)
425
- elif hasattr(stree._origin_network, func_name): # pylint: disable=protected-access
426
- func, ast_node = self._process_internal_function(stree, func_name)
427
- node = self._create_func_subtree(func, targets, father_ast_node, ast_node, call_args, call_kwargs,
428
- func_name)
429
+ func_scope = AssignParser._get_func_scope(ast_call, node_manager)
430
+ func_scope_name = ScopedValue.create_naming_value(func_name, func_scope)
431
+ call_args = [AssignParser._create_scopedvalue(arg) for arg in ast_call.args]
432
+ call_kwargs = AssignParser._create_kwargs(ast_call.keywords)
433
+
434
+ func_inst = AssignParser._get_call_instance(func_scope, func_name, stree)
435
+ if func_inst is None:
436
+ # Function is not Cell and Primitive
437
+ if func_scope in ('self', stree.get_opt_cls_name()) and hasattr(stree.get_origin_network(), func_name):
438
+ # Function defined in current class
439
+ func_inst, ast_functiondef = self._process_internal_function(stree, func_name)
440
+ if ast_functiondef is None:
441
+ raise RuntimeError(f"Find ast of function {func_scope}.{func_name} in symbol tree class failed.")
442
+ node = self._create_callfunction_node(targets, func_scope_name, call_args, call_kwargs, func_name,
443
+ ast_assign, ast_functiondef, stree, func_inst)
444
+ elif is_functional(func_name):
445
+ # Function defined in mindspore.ops.functional
446
+ parser = ParserRegister.instance().get_parser(type(ast_call.func)) # ast.Name or ast.Attribute
447
+ func_name = parser.process(stree, ast_call.func, node_manager).split(".")[-1]
448
+ func_inst = get_functional(func_name)
449
+ node = Node.inner_create_call_function(func_name, ast_assign, func_name, func_inst, targets,
450
+ call_args, call_kwargs)
429
451
  else:
430
- func, ast_node = self._process_external_function(stree, func_name)
431
- node = self._create_func_subtree(func, targets, father_ast_node, ast_node, call_args, call_kwargs,
432
- func_name)
452
+ origin_net_file = inspect.getfile(type(stree.get_origin_network()))
453
+ if not os.path.exists(origin_net_file):
454
+ raise RuntimeError(f"For MindSpore Rewrite, in assign parser, origin_net_file "
455
+ f"{origin_net_file} not exist")
456
+ func_inst, ast_functiondef = AssignParser.process_external_function(stree, func_name, origin_net_file)
457
+ node = Node.inner_create_call_function(func_name, ast_assign, func_name, func_inst, targets,
458
+ call_args, call_kwargs)
433
459
  return node
434
- if isinstance(op, SequentialCell):
435
- node = self._cell_container_process(father_ast_node, stree, targets, func, call_args, call_kwargs,
436
- func_name, op)
460
+ if isinstance(func_inst, tuple(AssignParser.types_for_cell_container)):
461
+ node = AssignParser.cell_container_process(ast_assign, stree, targets, func_scope_name, call_args,
462
+ call_kwargs, func_name, func_inst)
437
463
  return node
438
- if isinstance(op, Primitive):
439
- return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, call_kwargs, func_name)
440
- if isinstance(op, Cell):
441
- is_sub_tree = is_subtree(type(op).__name__)
442
- if is_sub_tree:
443
- stb = SymbolTreeBuilder(op)
464
+ if isinstance(func_inst, Primitive):
465
+ return Node.create_call_buildin_op(func_inst, ast_assign, targets, func_scope_name, call_args, call_kwargs,
466
+ func_name)
467
+ if isinstance(func_inst, Cell):
468
+ if is_subtree(func_inst):
469
+ # Instance of function is user custom network, create sub-symboltree
470
+ stb = SymbolTreeBuilder(func_inst)
444
471
  new_stree = stb.build()
445
- changed = AssignParser._update_field_in_init(func_scope, func_name, stree, new_stree)
446
- if changed:
447
- # class SubSubNet:
448
- # def __init__(self, global_vars):
449
- # self._handler = global_vars.get("handler")
450
- #
451
- # class SubNet:
452
- # def __init__(self, global_vars):
453
- # self._handler = global_vars.get("handler")
454
- # self._subsubnet = None
455
- # if xxx:
456
- # self._subsubnet = SubSubNet(xxx)
457
- #
458
- # Assuming there are two instance of SubNet A and B. "if xxx" in A is True, and in B is False.
459
- # So self._subsubnet in A is an instance of SubSubNet, and in B is None.
460
- # So After rewrite, A's code:
461
- # class SubNetA:
462
- # def __init__(self, global_vars):
463
- # self._handler = global_vars.get("handler")
464
- # self._subsubnet = SubSubNet(global_vars.get("subsubnet_args"))
465
- # while B's code:
466
- # class SubNetB:
467
- # def __init__(self, global_vars):
468
- # self._handler = global_vars.get("handler")
469
- # self._subsubnet = getattr(self._handler, "_subsubnet")
470
- # So SubNet should use SubNetA as its code when _update_field_in_init return True.
471
- # So SubNet should use SubNetB as its code when _update_field_in_init return False or undefined
472
- # error will occur to "global_vars.get("subsubnet_args")".
473
- stree.on_change(Event.CodeChangeEvent)
474
- # Sub-network in main-network is expressed as:
475
- # self._subnet = SubNet(global_vars.get("subnet_args"))
476
- # when subnet is changed, its class will change, take SubNet1 as new class-name, so code main-network
477
- # also need to change:
478
- # self._subnet = SubNet1(global_vars.get("subnet_args"))
479
- # so a change in sub-network should also be identified as a change in main-network.
480
- # so main-network should observe sub-network
472
+ AssignParser._update_field_in_init(func_scope, func_name, stree, new_stree)
481
473
  replacer = AstReplacer(new_stree.get_class_ast())
482
474
  replacer.replace_all(new_stree.get_ori_cls_name(), new_stree.get_opt_cls_name())
483
- return TreeNode.create_tree_node(new_stree, father_ast_node, targets, func, call_args, call_kwargs,
484
- func_name, new_stree.get_origin_network())
485
- return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, call_kwargs, func_name)
486
- raise RuntimeError("For MindSpore Rewrite, only support Primitive or Cell operator or Primitive operator, got ",
487
- type(op).__name__)
475
+ return TreeNode.create_tree_node(new_stree, ast_assign, targets, func_scope_name, call_args,
476
+ call_kwargs, func_name, new_stree.get_origin_network())
477
+ # Instance of function is buildin cells
478
+ return Node.create_call_buildin_op(func_inst, ast_assign, targets, func_scope_name, call_args, call_kwargs,
479
+ func_name)
480
+ raise RuntimeError("For MindSpore Rewrite, unsupported operation in ast.Call found: ",
481
+ type(func_inst).__name__)
488
482
 
489
483
  @staticmethod
490
484
  def _tuple_elts_support_scopledvalue(value: ast.Tuple) -> bool:
@@ -499,62 +493,62 @@ class AssignParser(Parser):
499
493
  return True
500
494
 
501
495
  @staticmethod
502
- def _convert_ast_mathops_to_node(ast_node: Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare],
503
- father_ast_node: ast.Assign) -> Node:
496
+ def _convert_ast_mathops_to_node(ast_op: Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare],
497
+ ast_assign: ast.Assign) -> Node:
504
498
  """
505
499
  Convert ast node of math operations(ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare) to
506
500
  a symbol tree node.
507
501
 
508
502
  Args:
509
- ast_node (Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare]): An assign node with mathematival
503
+ ast_op (Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare]): An assign node with mathematival
510
504
  operation in construct function.
511
- father_ast_node (ast.Assign): Assign node in construct.
505
+ ast_assign (ast.Assign): Assign node in construct.
512
506
 
513
507
  Returns:
514
508
  An instance of Node in Symbol Tree.
515
509
 
516
510
  Raises:
517
- TypeError: The type of parameter 'ast_node' is not in (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare).
511
+ TypeError: The type of parameter 'ast_op' is not in (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare).
518
512
 
519
513
  """
520
- if not isinstance(ast_node, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)):
521
- raise TypeError("The type of parameter 'ast_node' must be one of (ast.BinOp, ast.UnaryOp, "
522
- "ast.BoolOp, ast.Compare), but got ", type(ast_node))
514
+ if not isinstance(ast_op, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)):
515
+ raise TypeError("The type of parameter 'ast_op' must be one of (ast.BinOp, ast.UnaryOp, "
516
+ "ast.BoolOp, ast.Compare), but got ", type(ast_op))
523
517
 
524
- targets = AssignParser._get_targets(AssignParser._create_scopedvalue(father_ast_node.targets[0]))
518
+ targets = AssignParser._get_targets(AssignParser._create_scopedvalue(ast_assign.targets[0]))
525
519
  args = []
526
- op_type_str = type(ast_node).__name__
520
+ op_type_str = type(ast_op).__name__
527
521
  op_type = ScopedValue.create_naming_value(op_type_str)
528
522
  ops = {}
529
523
  name = op_type_str
530
- if isinstance(ast_node, ast.BinOp):
531
- op = type(ast_node.op).__name__
524
+ if isinstance(ast_op, ast.BinOp):
525
+ op = type(ast_op.op).__name__
532
526
  name = f'{name}_{op}'
533
527
  ops['0'] = ScopedValue.create_naming_value(op)
534
- args.append(AssignParser._create_scopedvalue(ast_node.left))
535
- args.append(AssignParser._create_scopedvalue(ast_node.right))
536
- elif isinstance(ast_node, ast.UnaryOp):
537
- op = type(ast_node.op).__name__
528
+ args.append(AssignParser._create_scopedvalue(ast_op.left))
529
+ args.append(AssignParser._create_scopedvalue(ast_op.right))
530
+ elif isinstance(ast_op, ast.UnaryOp):
531
+ op = type(ast_op.op).__name__
538
532
  name = f'{name}_{op}'
539
533
  ops['0'] = ScopedValue.create_naming_value(op)
540
- args.append(AssignParser._create_scopedvalue(ast_node.operand))
541
- elif isinstance(ast_node, ast.BoolOp):
542
- op = type(ast_node.op).__name__
534
+ args.append(AssignParser._create_scopedvalue(ast_op.operand))
535
+ elif isinstance(ast_op, ast.BoolOp):
536
+ op = type(ast_op.op).__name__
543
537
  name = f'{name}_{op}'
544
538
  ops['0'] = ScopedValue.create_naming_value(op)
545
- for value in ast_node.values:
539
+ for value in ast_op.values:
546
540
  args.append(AssignParser._create_scopedvalue(value))
547
- elif isinstance(ast_node, ast.Compare):
548
- args.append(AssignParser._create_scopedvalue(ast_node.left))
549
- for idx, ast_op in enumerate(ast_node.ops):
550
- op = type(ast_op).__name__
541
+ elif isinstance(ast_op, ast.Compare):
542
+ args.append(AssignParser._create_scopedvalue(ast_op.left))
543
+ for idx, ast_cmp_op in enumerate(ast_op.ops):
544
+ op = type(ast_cmp_op).__name__
551
545
  name = f'{name}_{op}'
552
546
  ops[str(idx)] = ScopedValue.create_naming_value(op)
553
- args.append(AssignParser._create_scopedvalue(ast_node.comparators[idx]))
547
+ args.append(AssignParser._create_scopedvalue(ast_op.comparators[idx]))
554
548
  name = name.lower()
555
- return Node.create_mathops_node(father_ast_node, targets, op_type, args, ops, name)
549
+ return Node.create_mathops_node(ast_assign, targets, op_type, args, ops, name)
556
550
 
557
- def process(self, stree: SymbolTree, node: ast.Assign):
551
+ def process(self, stree: SymbolTree, node: ast.Assign, node_manager: NodeManager):
558
552
  """
559
553
  Parse ast.Assign and create a node in symbol tree.
560
554
 
@@ -566,6 +560,7 @@ class AssignParser(Parser):
566
560
  Args:
567
561
  stree ([SymbolTree]): Symbol Tree under parsing.
568
562
  node ([ast.Assign]): An ast.Assign node.
563
+ node_manager (NodeManager): NodeManager those asts belong to.
569
564
 
570
565
  Raises:
571
566
  RuntimeError: Only support one target in assign now.
@@ -576,18 +571,18 @@ class AssignParser(Parser):
576
571
  try:
577
572
  if len(targets) != 1:
578
573
  raise RuntimeError(
579
- error_str(f"only support one target in assign now.", child_node=targets, father_node=node))
574
+ error_str(f"only support one target in assign now.", targets, node))
580
575
  value = node.value
581
576
  if isinstance(value, ast.Call):
582
- node_ = self._convert_ast_call_to_node(value, node, stree)
583
- stree.append_origin_field(node_)
577
+ node_ = self._convert_ast_call_to_node(value, node, stree, node_manager)
578
+ stree.append_origin_field(node_, node_manager)
584
579
  elif isinstance(value, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)):
585
580
  node_ = AssignParser._convert_ast_mathops_to_node(value, node)
586
- stree.append_origin_field(node_)
581
+ stree.append_origin_field(node_, node_manager)
587
582
  elif isinstance(value, ast.Subscript):
588
583
  logger.info(f"ops-call({astunparse.unparse(node)}) in assign will be supported in near feature, "
589
584
  f"ignored as a python node now")
590
- stree.try_append_python_node(node, node)
585
+ stree.try_append_python_node(node, node, node_manager)
591
586
  elif isinstance(value, (ast.Name, ast.Constant, ast.Attribute, ast.Num, ast.NameConstant,
592
587
  ast.Bytes, ast.Str)):
593
588
  if isinstance(value, ast.Name):
@@ -601,7 +596,7 @@ class AssignParser(Parser):
601
596
  targets = AssignParser._get_targets(AssignParser._create_scopedvalue(node.targets[0]))
602
597
  call_args = [AssignParser._create_scopedvalue(value)]
603
598
  node_ = Node.create_call_pass_through_method(node, targets, call_args, {}, node_name)
604
- stree.append_origin_field(node_)
599
+ stree.append_origin_field(node_, node_manager)
605
600
  elif isinstance(value, ast.Tuple):
606
601
  if AssignParser._tuple_elts_support_scopledvalue(value):
607
602
  # ensure that each element's type in tuple is supported by scopled value
@@ -611,14 +606,14 @@ class AssignParser(Parser):
611
606
  args.append(AssignParser._create_scopedvalue(elt))
612
607
  node_ = Node.create_call_method(node, targets, ScopedValue.create_naming_value("tuple"),
613
608
  args, {}, "tuple")
614
- stree.append_origin_field(node_)
609
+ stree.append_origin_field(node_, node_manager)
615
610
  else:
616
- logger.warning(f"some elements in Tuple of assign({astunparse.unparse(node)}) are not supported "
617
- "in rewrite, fallback to python")
618
- stree.try_append_python_node(node, node)
611
+ logger.info(f"some elements in Tuple of assign({astunparse.unparse(node)}) are not supported "
612
+ "in rewrite, fallback to python")
613
+ stree.try_append_python_node(node, node, node_manager)
619
614
  elif isinstance(value, (ast.List, ast.Dict)):
620
615
  # add these as callmethod node if necessary
621
- stree.try_append_python_node(node, node)
616
+ stree.try_append_python_node(node, node, node_manager)
622
617
  else:
623
618
  raise RuntimeError(
624
619
  error_str(f"only support (ast.Call, ast.BinOp, ast.BoolOp, ast.Subscript, ast.Name, ast.Constant, "
@@ -626,8 +621,8 @@ class AssignParser(Parser):
626
621
  f"ast.Dict) as value of ast.assign, but got ast type '{type(value).__name__}'",
627
622
  child_node=value, father_node=node))
628
623
  except RuntimeError:
629
- logger.info(f"ops-call({astunparse.unparse(node)}) not supported in rewrite, fallback to python")
630
- stree.try_append_python_node(node, node)
624
+ logger.info(f"ops-call({astunparse.unparse(node).strip()}) not supported in rewrite, fallback to python")
625
+ stree.try_append_python_node(node, node, node_manager)
631
626
 
632
627
 
633
628
  g_assign_parser = reg_parser(AssignParser())