mindspore 2.1.0__cp38-cp38-manylinux1_x86_64.whl → 2.2.11__cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (589)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  196. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  197. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  198. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  199. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  201. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  202. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  203. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  204. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  205. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  206. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  207. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  208. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  209. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  210. mindspore/nn/__init__.py +0 -2
  211. mindspore/nn/cell.py +313 -74
  212. mindspore/nn/dynamic_lr.py +21 -21
  213. mindspore/nn/layer/activation.py +22 -30
  214. mindspore/nn/layer/basic.py +15 -13
  215. mindspore/nn/layer/channel_shuffle.py +1 -1
  216. mindspore/nn/layer/container.py +271 -9
  217. mindspore/nn/layer/conv.py +323 -204
  218. mindspore/nn/layer/dense.py +8 -5
  219. mindspore/nn/layer/embedding.py +33 -27
  220. mindspore/nn/layer/flash_attention.py +61 -95
  221. mindspore/nn/layer/image.py +8 -6
  222. mindspore/nn/layer/math.py +16 -25
  223. mindspore/nn/layer/normalization.py +107 -66
  224. mindspore/nn/layer/padding.py +1 -1
  225. mindspore/nn/layer/pooling.py +131 -109
  226. mindspore/nn/layer/rnn_cells.py +27 -22
  227. mindspore/nn/layer/rnns.py +13 -16
  228. mindspore/nn/layer/thor_layer.py +1 -1
  229. mindspore/nn/layer/transformer.py +221 -154
  230. mindspore/nn/learning_rate_schedule.py +9 -1
  231. mindspore/nn/loss/loss.py +235 -174
  232. mindspore/nn/optim/ada_grad.py +2 -1
  233. mindspore/nn/optim/adadelta.py +1 -0
  234. mindspore/nn/optim/adafactor.py +2 -1
  235. mindspore/nn/optim/adam.py +7 -4
  236. mindspore/nn/optim/adamax.py +3 -2
  237. mindspore/nn/optim/adasum.py +2 -2
  238. mindspore/nn/optim/asgd.py +2 -3
  239. mindspore/nn/optim/ftrl.py +6 -5
  240. mindspore/nn/optim/lamb.py +7 -4
  241. mindspore/nn/optim/lars.py +1 -1
  242. mindspore/nn/optim/lazyadam.py +5 -3
  243. mindspore/nn/optim/momentum.py +2 -1
  244. mindspore/nn/optim/optimizer.py +53 -4
  245. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  246. mindspore/nn/optim/rmsprop.py +4 -3
  247. mindspore/nn/optim/rprop.py +23 -12
  248. mindspore/nn/optim/sgd.py +26 -11
  249. mindspore/nn/optim/thor.py +9 -7
  250. mindspore/nn/probability/bijector/bijector.py +5 -5
  251. mindspore/nn/probability/bijector/power_transform.py +27 -27
  252. mindspore/nn/probability/bijector/softplus.py +3 -3
  253. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  254. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  255. mindspore/nn/probability/distribution/beta.py +3 -3
  256. mindspore/nn/probability/distribution/categorical.py +7 -7
  257. mindspore/nn/probability/distribution/cauchy.py +0 -1
  258. mindspore/nn/probability/distribution/distribution.py +3 -3
  259. mindspore/nn/probability/distribution/gamma.py +3 -3
  260. mindspore/nn/probability/distribution/geometric.py +4 -4
  261. mindspore/nn/probability/distribution/gumbel.py +4 -4
  262. mindspore/nn/probability/distribution/log_normal.py +2 -2
  263. mindspore/nn/probability/distribution/logistic.py +2 -2
  264. mindspore/nn/probability/distribution/poisson.py +4 -4
  265. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  266. mindspore/nn/probability/distribution/uniform.py +6 -6
  267. mindspore/nn/wrap/__init__.py +4 -2
  268. mindspore/nn/wrap/cell_wrapper.py +87 -34
  269. mindspore/nn/wrap/grad_reducer.py +8 -5
  270. mindspore/nn/wrap/loss_scale.py +105 -42
  271. mindspore/numpy/array_creations.py +1 -2
  272. mindspore/numpy/array_ops.py +3 -2
  273. mindspore/numpy/utils_const.py +5 -5
  274. mindspore/offline_debug/convert_async.py +2 -2
  275. mindspore/ops/_grad_experimental/__init__.py +0 -5
  276. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  277. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  278. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  279. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  280. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  281. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  282. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  283. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  284. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  285. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  286. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  287. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  288. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  289. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  290. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  291. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  292. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  293. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  294. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  295. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  296. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  297. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  298. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  299. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  300. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  301. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  302. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  303. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  304. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  305. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  306. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  307. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  308. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  309. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  310. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  311. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  312. mindspore/ops/_primitive_cache.py +1 -1
  313. mindspore/ops/_tracefunc.py +45 -13
  314. mindspore/ops/_utils/utils.py +6 -1
  315. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  316. mindspore/ops/_vmap/vmap_base.py +3 -3
  317. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  318. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  319. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  320. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  321. mindspore/ops/arg_dtype_cast.py +54 -0
  322. mindspore/ops/composite/base.py +37 -10
  323. mindspore/ops/composite/math_ops.py +5 -4
  324. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  325. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  326. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  327. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  328. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  329. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  330. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  331. mindspore/ops/deprecated.py +304 -0
  332. mindspore/ops/function/__init__.py +4 -1
  333. mindspore/ops/function/array_func.py +174 -193
  334. mindspore/ops/function/clip_func.py +81 -13
  335. mindspore/ops/function/debug_func.py +1 -1
  336. mindspore/ops/function/grad/grad_func.py +18 -9
  337. mindspore/ops/function/image_func.py +10 -4
  338. mindspore/ops/function/linalg_func.py +5 -5
  339. mindspore/ops/function/math_func.py +575 -386
  340. mindspore/ops/function/nn_func.py +568 -260
  341. mindspore/ops/function/random_func.py +88 -57
  342. mindspore/ops/function/sparse_func.py +1 -1
  343. mindspore/ops/function/sparse_unary_func.py +14 -12
  344. mindspore/ops/function/vmap_func.py +6 -5
  345. mindspore/ops/functional.py +15 -10
  346. mindspore/ops/op_info_register.py +244 -25
  347. mindspore/ops/operations/__init__.py +31 -19
  348. mindspore/ops/operations/_grad_ops.py +71 -7
  349. mindspore/ops/operations/_inner_ops.py +350 -17
  350. mindspore/ops/operations/_quant_ops.py +4 -8
  351. mindspore/ops/operations/_sequence_ops.py +42 -0
  352. mindspore/ops/operations/array_ops.py +68 -282
  353. mindspore/ops/operations/comm_ops.py +107 -59
  354. mindspore/ops/operations/custom_ops.py +94 -70
  355. mindspore/ops/operations/debug_ops.py +8 -4
  356. mindspore/ops/operations/image_ops.py +18 -12
  357. mindspore/ops/operations/inner_ops.py +26 -3
  358. mindspore/ops/operations/math_ops.py +192 -144
  359. mindspore/ops/operations/nn_ops.py +857 -489
  360. mindspore/ops/operations/other_ops.py +0 -22
  361. mindspore/ops/operations/random_ops.py +53 -111
  362. mindspore/ops/operations/sparse_ops.py +3 -1
  363. mindspore/ops/primitive.py +24 -18
  364. mindspore/parallel/_auto_parallel_context.py +68 -8
  365. mindspore/parallel/_cost_model_context.py +2 -2
  366. mindspore/parallel/_offload_context.py +17 -3
  367. mindspore/parallel/_parallel_serialization.py +12 -5
  368. mindspore/parallel/_ps_context.py +12 -0
  369. mindspore/parallel/_tensor.py +18 -13
  370. mindspore/parallel/_transformer/layers.py +5 -3
  371. mindspore/parallel/_transformer/loss.py +1 -0
  372. mindspore/parallel/_transformer/moe.py +2 -2
  373. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  374. mindspore/parallel/_transformer/transformer.py +23 -3
  375. mindspore/parallel/_utils.py +11 -7
  376. mindspore/parallel/algo_parameter_config.py +85 -5
  377. mindspore/parallel/checkpoint_transform.py +19 -12
  378. mindspore/parallel/shard.py +21 -14
  379. mindspore/profiler/common/struct_type.py +3 -3
  380. mindspore/profiler/common/util.py +4 -2
  381. mindspore/profiler/envprofiling.py +1 -1
  382. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  383. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  384. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  385. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  386. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  387. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  388. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  389. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  390. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  391. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  392. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  393. mindspore/profiler/parser/flops_parser.py +15 -11
  394. mindspore/profiler/parser/framework_parser.py +38 -22
  395. mindspore/profiler/parser/hccl_parser.py +16 -12
  396. mindspore/profiler/parser/integrator.py +22 -11
  397. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  398. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  399. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  400. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  401. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  402. mindspore/profiler/parser/optime_parser.py +1 -1
  403. mindspore/profiler/parser/profiler_info.py +21 -2
  404. mindspore/profiler/parser/step_trace_parser.py +11 -14
  405. mindspore/profiler/profiling.py +179 -89
  406. mindspore/rewrite/api/node.py +102 -19
  407. mindspore/rewrite/api/node_type.py +5 -1
  408. mindspore/rewrite/api/pattern_engine.py +1 -1
  409. mindspore/rewrite/api/scoped_value.py +9 -17
  410. mindspore/rewrite/api/symbol_tree.py +131 -47
  411. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  412. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  413. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  414. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  415. mindspore/rewrite/common/rewrite_elog.py +5 -1
  416. mindspore/rewrite/namer.py +33 -24
  417. mindspore/rewrite/namespace.py +14 -5
  418. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  419. mindspore/rewrite/node/call_function.py +79 -0
  420. mindspore/rewrite/node/cell_container.py +135 -0
  421. mindspore/rewrite/node/control_flow.py +88 -0
  422. mindspore/rewrite/{node.py → node/node.py} +273 -234
  423. mindspore/rewrite/node/node_manager.py +254 -0
  424. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  425. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  426. mindspore/rewrite/parsers/assign_parser.py +216 -221
  427. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  428. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  429. mindspore/rewrite/parsers/constant_parser.py +9 -6
  430. mindspore/rewrite/parsers/container_parser.py +9 -7
  431. mindspore/rewrite/parsers/for_parser.py +42 -21
  432. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  433. mindspore/rewrite/parsers/if_parser.py +28 -24
  434. mindspore/rewrite/parsers/module_parser.py +196 -25
  435. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  436. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  437. mindspore/rewrite/parsers/return_parser.py +6 -6
  438. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  439. mindspore/rewrite/sparsify/utils.py +1 -1
  440. mindspore/rewrite/symbol_tree.py +523 -578
  441. mindspore/rewrite/symbol_tree_builder.py +9 -193
  442. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  443. mindspore/run_check/_check_version.py +6 -4
  444. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  445. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  446. mindspore/scipy/linalg.py +1 -1
  447. mindspore/scipy/ops.py +55 -5
  448. mindspore/scipy/optimize/__init__.py +3 -2
  449. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  450. mindspore/scipy/optimize/minimize.py +7 -3
  451. mindspore/train/_utils.py +7 -3
  452. mindspore/train/amp.py +323 -123
  453. mindspore/train/anf_ir_pb2.py +14 -2
  454. mindspore/train/callback/_backup_and_restore.py +2 -12
  455. mindspore/train/callback/_callback.py +29 -4
  456. mindspore/train/callback/_checkpoint.py +23 -8
  457. mindspore/train/callback/_early_stop.py +2 -2
  458. mindspore/train/callback/_landscape.py +4 -4
  459. mindspore/train/callback/_loss_monitor.py +2 -2
  460. mindspore/train/callback/_on_request_exit.py +2 -2
  461. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  462. mindspore/train/callback/_summary_collector.py +15 -8
  463. mindspore/train/callback/_time_monitor.py +58 -5
  464. mindspore/train/data_sink.py +5 -11
  465. mindspore/train/dataset_helper.py +84 -57
  466. mindspore/train/loss_scale_manager.py +2 -2
  467. mindspore/train/metrics/__init__.py +3 -3
  468. mindspore/train/metrics/cosine_similarity.py +1 -1
  469. mindspore/train/metrics/hausdorff_distance.py +3 -2
  470. mindspore/train/metrics/mean_surface_distance.py +3 -2
  471. mindspore/train/metrics/metric.py +39 -19
  472. mindspore/train/metrics/roc.py +2 -2
  473. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  474. mindspore/train/mind_ir_pb2.py +85 -36
  475. mindspore/train/model.py +187 -47
  476. mindspore/train/serialization.py +487 -161
  477. mindspore/train/summary/_summary_adapter.py +1 -1
  478. mindspore/train/summary/_writer_pool.py +3 -2
  479. mindspore/train/summary/summary_record.py +37 -17
  480. mindspore/train/train_thor/convert_utils.py +3 -3
  481. mindspore/train/train_thor/dataset_helper.py +1 -1
  482. mindspore/version.py +1 -1
  483. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  484. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
  485. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  486. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  487. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  488. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  489. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  490. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  491. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  492. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  493. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  494. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  495. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  496. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  497. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  498. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  499. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  500. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  501. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  502. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  503. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  504. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  505. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  506. mindspore/_extends/graph_kernel/expander.py +0 -80
  507. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  508. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  509. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  510. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  511. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  512. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  513. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  514. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  515. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  516. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  517. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  518. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  519. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  520. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  521. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  522. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  523. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  524. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  525. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  526. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  527. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  528. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  529. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  530. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  531. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  532. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  533. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  534. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  535. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  536. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  537. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  538. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  539. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  540. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  541. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  542. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  543. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  544. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  545. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  546. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  547. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  548. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  549. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  550. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  551. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  552. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  553. mindspore/dataset/datapreprocess/__init__.py +0 -20
  554. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  555. mindspore/include/api/net.h +0 -142
  556. mindspore/nn/lr_scheduler.py +0 -262
  557. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  558. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  559. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  560. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  561. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  562. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  563. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  564. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  565. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  566. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  567. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  568. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  569. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  570. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  571. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  575. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  577. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  578. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  579. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  580. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  582. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  583. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  584. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  585. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  586. mindspore/rewrite/node_visitor.py +0 -44
  587. mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  588. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  589. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -16,16 +16,18 @@
16
16
  from typing import Optional, Union
17
17
  import ast
18
18
  import inspect
19
+ from types import FunctionType
19
20
 
20
21
  from mindspore.nn import Cell
21
22
  from mindspore.ops import Primitive
22
23
  from mindspore import log as logger
23
- from .. import _checkparam as Validator
24
- from .ast_helpers import AstModifier
25
- from .api.scoped_value import ScopedValue, ValueType
26
- from .api.node_type import NodeType
27
- from .namespace import is_subtree
28
- from .ast_helpers.ast_replacer import AstReplacer
24
+ from ... import _checkparam as Validator
25
+ from ..ast_helpers import AstModifier
26
+ from ..api.scoped_value import ScopedValue, ValueType
27
+ from ..api.node_type import NodeType
28
+ from ..namespace import is_subtree
29
+ from ..ast_helpers.ast_replacer import AstReplacer
30
+ from ..ast_creator_register import ast_creator_registry
29
31
 
30
32
  PASS_THROUGH_METHOD = ScopedValue.create_naming_value("PassThrough")
31
33
 
@@ -36,35 +38,33 @@ class Node:
36
38
  invoking in forward which could be an instance of Cell, an instance of Primitive or a callable method. Fields of
37
39
  Node has different meaning in different type of node:
38
40
 
39
- - CallCell: a call-cell node represents an assign statement whose value is a calling to cell in mindspore. `targets`
40
- is corresponding to targets of ast.Assign which means return values of this cell-op. `args` and `kwargs` are
41
- corresponding to args and keywords of ast.Call which mean arguments to invoke cell-op's forward method. `func` is
42
- corresponding to func of call expression which means symbol of the cell-op.
41
+ - CallCell: a call-cell node represents an assign statement whose value is a calling to cell in mindspore.
42
+ `targets` is corresponding to targets of ast.Assign which means return values of this cell-op. `args` and
43
+ `kwargs` are corresponding to args and keywords of ast.Call which mean arguments to invoke cell-op's forward
44
+ method. `func` is corresponding to func of call expression which means symbol of the cell-op.
43
45
  - CallPrimitive: a call-primitive node represents an ast.Assign whose value is a calling to operator in mindspore.
44
- `targets`, `args`, `kwargs` and `func` are as previous.
46
+ `targets`, `args`, `kwargs` and `func_name` are as previous.
45
47
  - CallMethod: a call-method node represents an ast.Assign whose value is a calling to python-method such as `len`.
46
- `targets` is corresponding to targets of ast.Assign which means return values of this method. `func` represents
47
- the string name of method. `args` and `kwargs` are corresponding to args and keywords to invoke the method. When
48
- value of ast.Assign is an ast.Name or ast.Attribute, it means a simplest assign which would also be mapped to
49
- CallMethod node whose `func` is "PassThrough".
50
- - GetAttr: retrieves a parameter from the SymbolTree hierarchy. `func` represents which parameter in SymbolTree
51
- hierarchy. `targets` is corresponding to targets of ast.Assign which means what symbol to accept the result of
52
- get-attr. `args` and `kwargs` are don't-care.
48
+ `targets` is corresponding to targets of ast.Assign which means return values of this method. `func_name`
49
+ represents the string name of method. `args` and `kwargs` are corresponding to args and keywords to invoke the
50
+ method. When value of ast.Assign is an ast.Name or ast.Attribute, it means a simplest assign which would also be
51
+ mapped to CallMethod node whose `func_name` is "PassThrough".
53
52
  - Python: a python node holds an ast-node which is not parsed. a python node means some python statement is not
54
- supported by Rewrite or ignored by Rewrite. `targets`, `args`, `kwargs` and `func` are don't-care.
53
+ supported by Rewrite or ignored by Rewrite. `targets`, `args`, `kwargs` and `func_name` are don't-care.
55
54
  - Input: an input node represents an input of current network which also a parameter of forward method of Cell.
56
55
  `targets` is corresponding to arg-name of parameter of forward function. `args` means default-value of parameter
57
- of forward function. `kwargs` and `func` are don't-care.
56
+ of forward function. `kwargs` and `func_name` are don't-care.
58
57
  - Output: an output node represents the output of current network which is corresponding to return statement of
59
- forward method of Cell. `args` represents return values. `func` are always be "return". `targets` and `kwargs` are
60
- don't-care.
58
+ forward method of Cell. `args` represents return values. `func_name` are always be "return". `targets` and
59
+ `kwargs` are don't-care.
61
60
  - Tree: a tree node represents a sub-network call in current network. A sub-network is also a Cell in mindspore, so
62
- `targets`, `args`, `kwargs` and `func` are same as a call-cell node. `symbol_tree` is a handler of a SymbolTree
63
- instance.
61
+ `targets`, `args`, `kwargs` and `func_name` are same as a call-cell node. `symbol_tree` is a handler of a
62
+ SymbolTree instance.
64
63
  """
65
64
 
66
65
  def __init__(self, node_type: NodeType, ast_node: Optional[ast.AST], targets: [ScopedValue],
67
- func: Optional[ScopedValue], args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance):
66
+ func_name: Optional[ScopedValue], args: [ScopedValue], kwargs: {str: ScopedValue}, name: str,
67
+ instance):
68
68
  """
69
69
  Constructor of Node. Rewrite recommend invoking class method of Node to instantiate an instance of Node such
70
70
  as `create_call_op`, `create_call_method`, `create_python_node`, `create_input_node` and
@@ -75,7 +75,7 @@ class Node:
75
75
  ast_node (ast.AST, optional): An instance of ast.AST represents corresponding node in ast. `ast_node` should
76
76
  not be None except when node type is Unknown.
77
77
  targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
78
- func (ScopedValue, optional): An instance of ScopedValue. See detail in docstring of Node class.
78
+ func_name (ScopedValue, optional): An instance of ScopedValue. See detail in docstring of Node class.
79
79
  args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
80
80
  kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
81
81
  name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
@@ -89,7 +89,7 @@ class Node:
89
89
  self._attribute = Node._get_cell_or_prim_op_attribute(instance)
90
90
  self._instance = instance
91
91
  self._name = name
92
- self._func: Optional[ScopedValue] = func
92
+ self._func_name: Optional[ScopedValue] = func_name
93
93
  self._targets: [ScopedValue] = targets
94
94
  self._args_num = len(args) if args is not None else 0
95
95
  self._kwargs_num = len(kwargs) if kwargs is not None else 0
@@ -101,48 +101,17 @@ class Node:
101
101
  self._next: Optional[Node] = None
102
102
  # A handler of SymbolTree current node belonging to
103
103
  self._belong_tree = None
104
- # A dict that records which target of which Node current Node's argument came from
104
+ # A handler of NodeManager current node belonging to
105
+ self._node_manager = None
106
+ # A dict that records which target of which Node current Node's argument come from
105
107
  self._arg_providers: {int: (Node, int)} = {}
106
108
  # A dict that records which argument of which Node uses current Node's target
107
109
  self._target_users: {int: [(Node, int)]} = {}
108
110
 
109
- @classmethod
110
- def create_call_buildin_op(cls, op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [ScopedValue],
111
- func: ScopedValue, args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None,
112
- name: str = ""):
113
- """
114
- Class method of Node. Instantiate an instance of node whose type is `CallCell` or `CallPrimitive`.
115
- A `CallCell` node represents an invoking to cell-op.
116
- A `CallPrimitive` node represents an invoking to primitive-op.
117
-
118
- Args:
119
- op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node.
120
- ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
121
- targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
122
- func ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
123
- args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
124
- kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
125
- class.
126
- name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`.
127
- Name of node also used as field name in network class.
128
- """
129
-
130
- if not isinstance(op, (Cell, Primitive)):
131
- raise ValueError("Input op is not a buildin op(Cell or Primitive): ", type(op))
132
- non_custom_args = Node._handle_custom_obj_in_args(args)
133
- non_custom_kwargs = Node._handle_custom_obj_in_kwargs(kwargs)
134
- if ast_node is None:
135
- ast_node = AstModifier.create_call_assign(targets, func, non_custom_args, non_custom_kwargs)
136
- if isinstance(op, Cell):
137
- node_type = NodeType.CallCell
138
- else:
139
- node_type = NodeType.CallPrimitive
140
- return cls(node_type, ast_node, targets, func, args, kwargs, name, op)
141
-
142
111
  @classmethod
143
112
  def create_call_method(cls, ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
144
- func: Union[ScopedValue, str], args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None,
145
- name: str = ""):
113
+ func_name: Union[ScopedValue, str], args: [ScopedValue] = None,
114
+ kwargs: {str: ScopedValue}=None, name: str = ""):
146
115
  """
147
116
  Class method of Node. Instantiate an instance of node whose type is CallCell. A CallCell node represents an
148
117
  invoking to cell-op.
@@ -151,7 +120,7 @@ class Node:
151
120
  ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast. `ast_node`
152
121
  should not be None currently.
153
122
  targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
154
- func ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
123
+ func_name ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
155
124
  args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
156
125
  kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
157
126
  name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
@@ -161,12 +130,12 @@ class Node:
161
130
  args = []
162
131
  if kwargs is None:
163
132
  kwargs = {}
164
- if isinstance(func, str):
165
- func = ScopedValue.create_naming_value(func)
133
+ if isinstance(func_name, str):
134
+ func_name = ScopedValue.create_naming_value(func_name)
166
135
  new_targets = Node._handle_targets(targets)
167
136
  if ast_node is None:
168
137
  raise RuntimeError("Input ast_node is None")
169
- return cls(NodeType.CallMethod, ast_node, new_targets, func, args, kwargs, name, None)
138
+ return cls(NodeType.CallMethod, ast_node, new_targets, func_name, args, kwargs, name, None)
170
139
 
171
140
  @classmethod
172
141
  def create_call_pass_through_method(cls, ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
@@ -189,7 +158,8 @@ class Node:
189
158
  return cls(NodeType.Python, ast_node, None, None, [], {}, name, instance)
190
159
 
191
160
  @classmethod
192
- def create_input_node(cls, ast_node: ast.AST, arg_name: str, default: Optional[ScopedValue] = None, name: str = ""):
161
+ def create_input_node(cls, ast_node: Optional[ast.AST], arg_name: str, default: Optional[ScopedValue] = None,
162
+ name: str = ""):
193
163
  """
194
164
  Class method of Node. Instantiate an instance of node whose type is Input. An Input node represents input of
195
165
  SymbolTree which is corresponding to parameters of forward function.
@@ -206,6 +176,8 @@ class Node:
206
176
  args = []
207
177
  else:
208
178
  args = [default]
179
+ if ast_node is None:
180
+ ast_node = ast.arg(arg_name)
209
181
  return cls(NodeType.Input, ast_node, [target], None, args, {}, name, None)
210
182
 
211
183
  @classmethod
@@ -243,17 +215,83 @@ class Node:
243
215
  args (list[ScopedValue]): Values participating in the mathematical operations. All values are saved
244
216
  sequentially in the list.
245
217
  ops (dict[str:ScopedValue]): Operators participating in the mathematical operations. All operators are
246
- saved sequentially in the dict, and keys are numbers in string format, such as {'0':'add', '1':'sub'}.
218
+ saved sequentially in the dict, and keys are numbers in string format, such as {'0':'add', '1':'sub'}.
247
219
  name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`.
248
220
  Name of node also used as field name in network class. The format of mathops node name
249
221
  is 'AstNodeName_AstOpName_n'.
250
222
  """
251
223
  return cls(NodeType.MathOps, ast_node, targets, op_type, args, ops, name, None)
252
224
 
225
+ @staticmethod
226
+ def create_assign_node(targets, func_name, args, kwargs):
227
+ """Create a ast.Assign type node."""
228
+ # create targets
229
+ ast_targets = [ast_creator_registry.get("Name")(targets)]
230
+ # create call
231
+ ast_func = ast_creator_registry.get("Attribute")(func_name)
232
+ ast_args = ast_creator_registry.get("Args")(args)
233
+ ast_kwargs = ast_creator_registry.get("KwArgs")(kwargs) if kwargs else []
234
+ ast_value = ast_creator_registry.get("Call")(func=ast_func, args=ast_args, keywords=ast_kwargs)
235
+ # create assign
236
+ ast_node = ast_creator_registry.get("Assign")(targets=ast_targets, value=ast_value)
237
+ return ast_node
238
+
239
+ @staticmethod
240
+ def _create_call_function(function: FunctionType, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None,
241
+ kwargs: {str: ScopedValue}=None):
242
+ """
243
+ Create a node that corresponds to a function call.
244
+
245
+ Args:
246
+ function (FunctionType): The function to be called.
247
+ targets (list[str]): indicates output names. Used as targets of an assign statement in source code.
248
+ args (list[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
249
+ source code. Default: ``None`` , which indicates the `function` has no args inputs.
250
+ kwargs (dict): Type of key must be `str` and type of value must be `ScopedValue`.
251
+ Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source
252
+ code. Default: ``None`` , which indicates the `function` has no kwargs inputs.
253
+
254
+ Returns:
255
+ An instance of `Node`.
256
+ """
257
+ if args is None:
258
+ args = []
259
+ if kwargs is None:
260
+ kwargs = {}
261
+ targets = Node._handle_targets(targets)
262
+ _package = None
263
+ if isinstance(function, FunctionType):
264
+ _package = function.__globals__['__package__']
265
+ func_full_name = ".".join([_package, function.__name__]) if _package else function.__name__
266
+ func_scope = ''
267
+ func_name = func_full_name.split('.')[-1]
268
+ if func_full_name.count('.') > 0:
269
+ func_scope = func_full_name.rsplit('.')[0]
270
+ func_scope_name = ScopedValue.create_naming_value(func_name, func_scope)
271
+ node = Node.inner_create_call_function(func_name, None, func_scope_name, function, targets, args, kwargs)
272
+ return node
273
+
274
+ @classmethod
275
+ def inner_create_call_function(cls, node_name, ast_node, func_name, function, targets, args, kwargs):
276
+ '''
277
+ Instantiate an instance of node whose type is `CallFunction`.
278
+
279
+ Args:
280
+ node_name (str): Name of node.
281
+ func_name (str): Name of function.
282
+ ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
283
+ targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
284
+ function (Object): An instance of function. See detail in docstring of Node class.
285
+ args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
286
+ kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
287
+ class.
288
+ '''
289
+ return cls(NodeType.CallFunction, ast_node, targets, func_name, args, kwargs, node_name, function)
290
+
253
291
  @staticmethod
254
292
  def create_call_op(op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
255
- func: Union[ScopedValue, str], args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None,
256
- name: str = "", is_sub_net: bool = False):
293
+ args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None, node_name: str = "",
294
+ is_sub_net: bool = False):
257
295
  """
258
296
  Static method of Node. Instantiate an instance of node whose type is `CallCell` or `CallPrimitive`.
259
297
  If op is custom defined, it is treated by TreeNode.
@@ -264,12 +302,11 @@ class Node:
264
302
  op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node.
265
303
  ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
266
304
  targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
267
- func ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
268
305
  args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
269
306
  kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
270
307
  class.
271
- name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`.
272
- Name of node also used as field name in network class.
308
+ node_name (str): A string represents name of node. Name of node will be unique when inserted into
309
+ `SymbolTree`. Name of node also used as field name in network class.
273
310
  is_sub_net (bool): Indicate that is `cell` a network. If `is_sub_net` is true, Rewrite will try to parse the
274
311
  `cell` to a TreeNode, else a CallCell Node. Default is a False.
275
312
  """
@@ -277,29 +314,58 @@ class Node:
277
314
  if ast_node is not None:
278
315
  Validator.check_value_type("ast_node", ast_node, [ast.AST], "Node")
279
316
  Validator.check_element_type_of_iterable("targets", targets, [ScopedValue, str], "Node")
280
- Validator.check_value_type("func", func, [ScopedValue, str], "Node")
281
317
  if args is not None:
282
318
  Validator.check_element_type_of_iterable("args", args, [ScopedValue], "Node")
283
319
  if kwargs is not None:
284
320
  Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "Node")
285
- cls_name = type(op).__name__
286
-
287
321
  if args is None:
288
322
  args = []
289
323
  if kwargs is None:
290
324
  kwargs = {}
291
- if isinstance(func, str):
292
- func = ScopedValue.create_naming_value(func)
325
+ Validator.check_value_type("node_name", node_name, [str], "Node")
293
326
  new_targets = Node._handle_targets(targets)
294
- if is_sub_net and is_subtree(cls_name):
295
- from .symbol_tree_builder import SymbolTreeBuilder
327
+ if isinstance(node_name, str):
328
+ func_name = ScopedValue.create_naming_value(node_name)
329
+ else:
330
+ func_name = node_name
331
+ if is_sub_net and is_subtree(op):
332
+ from ..symbol_tree_builder import SymbolTreeBuilder
296
333
  stb = SymbolTreeBuilder(op)
297
334
  stree = stb.build()
298
335
  replacer = AstReplacer(stree.get_class_ast())
299
336
  replacer.replace_all(stree.get_ori_cls_name(), stree.get_opt_cls_name())
300
- return TreeNode.create_tree_node(stree, ast_node, new_targets, func, args, kwargs, name, op)
337
+ return TreeNode.create_tree_node(stree, ast_node, new_targets, func_name, args, kwargs, node_name, op)
301
338
 
302
- return Node.create_call_buildin_op(op, ast_node, new_targets, func, args, kwargs, name)
339
+ return Node.create_call_buildin_op(op, ast_node, new_targets, func_name, args, kwargs, node_name)
340
+
341
+ @classmethod
342
+ def create_call_buildin_op(cls, op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [ScopedValue],
343
+ func_name: ScopedValue, args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None,
344
+ node_name: str = ""):
345
+ """
346
+ Class method of Node. Instantiate an instance of node whose type is `CallCell` or `CallPrimitive`.
347
+ A `CallCell` node represents an invoking to cell-op.
348
+ A `CallPrimitive` node represents an invoking to primitive-op.
349
+
350
+ Args:
351
+ op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node.
352
+ ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
353
+ targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
354
+ func_name ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
355
+ args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
356
+ kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
357
+ class.
358
+ node_name (str): A string represents name of node. Name of node will be unique when inserted into
359
+ `SymbolTree`. Name of node also used as field name in network class.
360
+ """
361
+
362
+ if not isinstance(op, (Cell, Primitive)):
363
+ raise ValueError("Input op is not a buildin op(Cell or Primitive): ", type(op))
364
+ if isinstance(op, Cell):
365
+ node_type = NodeType.CallCell
366
+ else:
367
+ node_type = NodeType.CallPrimitive
368
+ return cls(node_type, ast_node, targets, func_name, args, kwargs, node_name, op)
303
369
 
304
370
  @staticmethod
305
371
  def _get_construct_arg_names(parameters):
@@ -508,21 +574,23 @@ class Node:
508
574
  """
509
575
  return self._next
510
576
 
511
- def has_same_ast(self, node: Union['Node', ast.AST]) -> bool:
577
+ def set_prev(self, node: 'Node'):
512
578
  """
513
- Check if other node holds same ast node with self.
579
+ Set previous node of current node.
514
580
 
515
581
  Args:
516
- node (Union[Node, ast.AST]): An instance of ast.AST or an instance of node to be compared.
582
+ node (Node): Node to be set as previous node of current node.
583
+ """
584
+ self._prev = node
517
585
 
518
- Returns:
519
- A bool.
586
+ def set_next(self, node: 'Node'):
587
+ """
588
+ Set next node of current node.
589
+
590
+ Args:
591
+ node (Node): Node to be set as next node of current node.
520
592
  """
521
- if isinstance(node, Node):
522
- return self.has_same_ast(node._ast_node)
523
- if isinstance(node, ast.AST):
524
- return id(self._ast_node) == id(node)
525
- return False
593
+ self._next = node
526
594
 
527
595
  def get_ast(self) -> Optional[ast.AST]:
528
596
  """
@@ -552,16 +620,24 @@ class Node:
552
620
  """Set the symbol tree to which node belongs."""
553
621
  self._belong_tree = symbol_tree
554
622
 
623
+ def get_node_manager(self):
624
+ """Get the NodeManager current node belongs to."""
625
+ return self._node_manager
626
+
627
+ def set_node_manager(self, node_manager):
628
+ """Set NodeManager current node belongs."""
629
+ self._node_manager = node_manager
630
+
555
631
  def isolate(self):
556
632
  """Link prev node to next node and isolate node from source code order list."""
557
- origin_prev: Optional[Node] = self._prev
558
- origin_next: Optional[Node] = self._next
633
+ origin_prev: Optional[Node] = self.get_prev()
634
+ origin_next: Optional[Node] = self.get_next()
559
635
  if origin_prev is not None:
560
- origin_prev._next = origin_next
636
+ origin_prev.set_next(origin_next)
561
637
  if origin_next is not None:
562
- origin_next._prev = origin_prev
563
- self._prev = None
564
- self._next = None
638
+ origin_next.set_prev(origin_prev)
639
+ self.set_prev(None)
640
+ self.set_next(None)
565
641
 
566
642
  def insert_before(self, node: 'Node'):
567
643
  """
@@ -571,12 +647,12 @@ class Node:
571
647
  node (Node): An instance of node to be inserted in.
572
648
  """
573
649
  node.isolate()
574
- origin_prev: Optional[Node] = self._prev
650
+ origin_prev: Optional[Node] = self.get_prev()
575
651
  if origin_prev is not None:
576
- origin_prev._next = node
577
- node._prev = origin_prev
578
- node._next = self
579
- self._prev = node
652
+ origin_prev.set_next(node)
653
+ node.set_prev(origin_prev)
654
+ node.set_next(self)
655
+ self.set_prev(node)
580
656
 
581
657
  def insert_after(self, node: 'Node'):
582
658
  """
@@ -586,12 +662,12 @@ class Node:
586
662
  node (Node): An instance of node to be inserted in.
587
663
  """
588
664
  node.isolate()
589
- origin_next: Optional[Node] = self._next
590
- self._next = node
591
- node._prev = self
592
- node._next = origin_next
665
+ origin_next: Optional[Node] = self.get_next()
666
+ self.set_next(node)
667
+ node.set_prev(self)
668
+ node.set_next(origin_next)
593
669
  if origin_next is not None:
594
- origin_next._prev = node
670
+ origin_next.set_prev(node)
595
671
 
596
672
  def get_inputs(self) -> ['Node']:
597
673
  """
@@ -651,26 +727,26 @@ class Node:
651
727
  NodeType.MathOps):
652
728
  self._sync_assign_targets_to_ast()
653
729
 
654
- def get_func(self) -> ScopedValue:
730
+ def get_func_name(self) -> ScopedValue:
655
731
  """
656
- Getter of `_func`. See detail in docstring of Node class for meaning of func.
732
+ Getter of `_func_name`. See detail in docstring of Node class for meaning of func.
657
733
 
658
734
  Returns:
659
735
  An instance of ScopedValue.
660
736
  """
661
- return self._func
737
+ return self._func_name
662
738
 
663
- def set_func(self, func: ScopedValue):
739
+ def set_func_name(self, func_name: ScopedValue):
664
740
  """
665
- Setter of `_func`. See detail in docstring of Node class for meaning of func.
741
+ Setter of `_func_name`. See detail in docstring of Node class for meaning of func.
666
742
 
667
743
  Note:
668
- When `_func` is updated, corresponding ast node would be updated also.
744
+ When `_func_name` is updated, corresponding ast node would be updated also.
669
745
 
670
746
  Args:
671
747
  func (ScopedValue): An instance of ScopedValue as new func.
672
748
  """
673
- self._func = func
749
+ self._func_name = func_name
674
750
  if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive):
675
751
  self._sync_assign_func_to_ast()
676
752
 
@@ -747,11 +823,11 @@ class Node:
747
823
  Validator.check_value_type("node", node, [Node], "Node")
748
824
  Validator.check_int_range(arg_idx, 0, self._args_num, Validator.INC_LEFT, "arg_idx")
749
825
  if out_idx is None:
750
- if len(node._targets) != 1:
826
+ if len(node.get_targets()) != 1:
751
827
  raise RuntimeError("node should has one output when out_idx is not provided")
752
828
  out_idx = 0
753
- Validator.check_int_range(out_idx, 0, len(node._targets), Validator.INC_LEFT, "arg_idx")
754
- new_arg = node._targets[out_idx]
829
+ Validator.check_int_range(out_idx, 0, len(node.get_targets()), Validator.INC_LEFT, "arg_idx")
830
+ new_arg = node.get_targets()[out_idx]
755
831
  self._normalized_args[self._normalized_args_keys[arg_idx]] = new_arg
756
832
  self._sync_arg()
757
833
 
@@ -943,18 +1019,36 @@ class Node:
943
1019
  def get_arg_providers(self) -> dict:
944
1020
  """
945
1021
  Getter of _arg_providers.
1022
+
1023
+ Return:
1024
+ dict, key is type of int indicating the index of args, and value is type of tuple, which includes
1025
+ the node and the index of node's targets who provides the argument.
946
1026
  """
947
1027
  return self._arg_providers
948
1028
 
949
1029
  def set_arg_providers(self, index: int, provider: tuple):
950
1030
  """
951
1031
  Setter of _arg_providers.
1032
+
1033
+ Args:
1034
+ index (int): Indicating provider of which argument need to be set.
1035
+ provider (tuple): A tuple includes the node and the index of node's targets who provides the argument.
952
1036
  """
953
1037
  self._arg_providers[index] = provider
954
1038
 
955
1039
  def get_target_users(self, index=-1) -> Union[dict, list]:
956
1040
  """
957
1041
  Getter of _target_users.
1042
+
1043
+ Args:
1044
+ index (int): Indicating users of which target need to be got. Default: -1, means all targets's users will
1045
+ be returned.
1046
+
1047
+ Return:
1048
+ Union[dict, list]. When index is not -1, a list of users of specified target will be returned.
1049
+ The type of elements in list is tuple, which includes the user node and the index of node's arguments
1050
+ who uses the target. When index is -1, a dict will be returned. The key is index of targets, and the
1051
+ value is list of users of corresponding target.
958
1052
  """
959
1053
  if index == -1:
960
1054
  return self._target_users
@@ -965,11 +1059,23 @@ class Node:
965
1059
  def append_target_users(self, index: int, provider: tuple):
966
1060
  """
967
1061
  Setter of _target_users.
1062
+
1063
+ Args:
1064
+ index (int): Indicating users of which target need to be append.
1065
+ provider (tuple): A tuple includes the node and the index of node's argument who uses the target.
1066
+
968
1067
  """
969
1068
  if index not in self._target_users.keys():
970
1069
  self._target_users[index] = []
971
1070
  self._target_users.get(index).append(provider)
972
1071
 
1072
+ def update_ast_node(self) -> ast.AST:
1073
+ """Update node's ast_node by current targets, func_name, args and kwargs."""
1074
+ ast_assign = AstModifier.create_call_assign(self.get_targets(), self.get_func_name(),
1075
+ self.get_args(), self.get_kwargs())
1076
+ self.set_ast(ast_assign)
1077
+ return ast_assign
1078
+
973
1079
  def _get_normalized_args(self, args: [ScopedValue], kwargs: {str: ScopedValue}) -> dict:
974
1080
  """
975
1081
  Merge args and kwargs to normalized args.
@@ -1010,6 +1116,10 @@ class Node:
1010
1116
  self._normalized_args_keys.append(arg_key)
1011
1117
  return normalized_args
1012
1118
 
1119
+ ##########################################################################################################
1120
+ # Synchronize rewrite node args to ast node
1121
+ ##########################################################################################################
1122
+
1013
1123
  def _sync_assign_func_to_ast(self):
1014
1124
  """Sync func of ast.Call of ast.Assign from self._name when NodeType is CallCell or CallPrimitive."""
1015
1125
  if self._ast_node is None:
@@ -1021,20 +1131,21 @@ class Node:
1021
1131
  if not isinstance(call_ast, ast.Call):
1022
1132
  raise TypeError("call_ast should be ast.Call, got: ", type(call_ast))
1023
1133
  func_ast = call_ast.func
1024
- if not self._func.value:
1134
+ if not self._func_name.value:
1025
1135
  if isinstance(func_ast, ast.Name):
1026
- func_ast.id = self._func.value
1136
+ func_ast.id = self._func_name.value
1027
1137
  else:
1028
- call_ast.func = ast.Name(self._func.value, ast.Store())
1138
+ call_ast.func = ast.Name(self._func_name.value, ast.Store())
1029
1139
  else:
1030
1140
  if isinstance(func_ast, ast.Attribute):
1031
1141
  func_value = func_ast.value
1032
1142
  if not isinstance(func_value, ast.Name):
1033
1143
  raise RuntimeError("Only support ast.Name as value of attribute ", type(func_ast.value))
1034
- func_value.id = self._func.scope
1035
- func_ast.attr = self._func.value
1144
+ func_value.id = self._func_name.scope
1145
+ func_ast.attr = self._func_name.value
1036
1146
  else:
1037
- call_ast.func = ast.Attribute(ast.Name(self._func.scope, ast.Load()), self._func.value, ast.Store())
1147
+ call_ast.func = ast.Attribute(ast.Name(self._func_name.scope, ast.Load()),
1148
+ self._func_name.value, ast.Store())
1038
1149
  ast.fix_missing_locations(assign_ast)
1039
1150
 
1040
1151
  def _sync_assign_targets_to_ast(self):
@@ -1050,7 +1161,7 @@ class Node:
1050
1161
  raise RuntimeError("self._targets should have the same length as targets_ast's elts")
1051
1162
  if not isinstance(targets_ast[0], ast.Tuple) and len(self._targets) != len(targets_ast):
1052
1163
  raise RuntimeError("self._targets should have targets_ast same length")
1053
- for i in range(0, len(self._targets)):
1164
+ for i, _ in enumerate(self._targets):
1054
1165
  target = self._targets[i]
1055
1166
  target_ast = targets_ast[0]
1056
1167
  if isinstance(target_ast, ast.Name):
@@ -1070,7 +1181,7 @@ class Node:
1070
1181
  return
1071
1182
  assign_ast = self._ast_node
1072
1183
  if not isinstance(assign_ast, ast.Assign):
1073
- raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
1184
+ raise TypeError(f"assign_ast should be ast.Assign, got: {type(assign_ast)}")
1074
1185
  assign_value = assign_ast.value
1075
1186
  if not isinstance(assign_value, ast.Call):
1076
1187
  return
@@ -1121,23 +1232,31 @@ class Node:
1121
1232
  if len(self._normalized_args_keys) != 1:
1122
1233
  raise RuntimeError("self._normalized_args_keys should have 1 elements")
1123
1234
  arg = self._normalized_args.get(self._normalized_args_keys[0])
1124
- if arg.type not in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue, ValueType.NoneValue):
1125
- raise RuntimeError("arg should be an IntValue, FloatValue, StringValue or NoneValue")
1235
+ if arg.type != ValueType.ConstantValue:
1236
+ raise RuntimeError("arg should be an ConstantValue")
1126
1237
  if arg.scope != "":
1127
1238
  raise RuntimeError("arg.scope should be empty")
1128
1239
  assign_value.value = arg.value
1129
1240
 
1130
1241
  def _sync_call_method_args_to_ast(self):
1131
- """Sync args of ast.Cell of ast.Assign from self._normalized_args when NodeType is CallMethod."""
1242
+ """
1243
+ Sync args to value of ast.Assign from self._normalized_args when NodeType is CallMethod.
1244
+
1245
+ For node with type of CallMethod, the value of ast.Assign is one of:
1246
+ - ast.Tuple
1247
+ - ast.Name
1248
+ - ast.Attribute
1249
+ - ...
1250
+ """
1132
1251
  if self._ast_node is None:
1133
1252
  return
1134
1253
  assign_ast = self._ast_node
1135
1254
  if not isinstance(assign_ast, ast.Assign):
1136
1255
  raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
1137
1256
  assign_value = assign_ast.value
1138
- if self._func == PASS_THROUGH_METHOD:
1257
+ if self._func_name == PASS_THROUGH_METHOD:
1139
1258
  self._sync_call_pass_through_method_args_to_ast(assign_value)
1140
- elif self._func.value == "tuple":
1259
+ elif self._func_name.value == "tuple":
1141
1260
  tuple_ast: ast.Tuple = assign_value
1142
1261
  if len(self._normalized_args_keys) != len(tuple_ast.elts):
1143
1262
  raise RuntimeError("size of self._normalized_args_keys should be equal to size of elements of tuple")
@@ -1157,10 +1276,16 @@ class Node:
1157
1276
  else:
1158
1277
  raise RuntimeError("Only support constant or symbol in tuple now")
1159
1278
  else:
1160
- raise RuntimeError("Only support pass_through or tuple method as call_method now, ", self._func.value)
1279
+ raise RuntimeError("Only support pass_through or tuple method as call_method now, ", self._func_name.value)
1161
1280
 
1162
1281
  def _sync_return_node_to_ast(self):
1163
- """Sync return value of ast.Return from self._normalized_args when NodeType is Output."""
1282
+ """
1283
+ Sync args to value of ast.Return from self._normalized_args when NodeType is Output.
1284
+
1285
+ For node with type of Output, the value of ast.Return is one of:
1286
+ - ast.Name
1287
+ - ast.Tuple
1288
+ """
1164
1289
  if self._ast_node is None:
1165
1290
  return
1166
1291
  return_ast = self._ast_node
@@ -1222,7 +1347,7 @@ class Node:
1222
1347
 
1223
1348
  def _sync_arg(self):
1224
1349
  """Sync _normalized_args to corresponding ast node when updated."""
1225
- if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree,\
1350
+ if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree, \
1226
1351
  NodeType.CellContainer, NodeType.CallFunction):
1227
1352
  self._sync_call_cell_args_to_ast()
1228
1353
  elif self._node_type == NodeType.Output:
@@ -1233,15 +1358,18 @@ class Node:
1233
1358
  self._sync_mathops_node_args_to_ast()
1234
1359
 
1235
1360
 
1361
+ ##########################################################################################################
1362
+ # Child classes
1363
+ ##########################################################################################################
1364
+
1236
1365
  class TreeNode(Node):
1237
1366
  """Tree type Node who holds a handler of SymbolTree."""
1238
1367
 
1239
1368
  def __init__(self, tree, ast_node: ast.AST, targets: [ScopedValue], func: ScopedValue,
1240
1369
  args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance):
1241
1370
  """
1242
- Constructor of Node. Rewrite recommend to invoking class method of Node to instantiate an instance of Node such
1243
- as `create_call_buildin_op`, `create_call_method`, `create_python_node`, `create_input_node` and
1244
- `create_output_node`, etc. rather than invoking constructor of Node directly.
1371
+ Constructor of TreeNode. Rewrite recommends invoking a class method of TreeNode such as
1372
+ `create_tree_node` to instantiate a TreeNode rather than invoking the constructor directly.
1245
1373
 
1246
1374
  Args:
1247
1375
  tree: An instance of SymbolTree represents a handler of sub-symbol-tree.
@@ -1260,8 +1388,9 @@ class TreeNode(Node):
1260
1388
  self.symbol_tree = tree
1261
1389
 
1262
1390
  @classmethod
1263
- def create_tree_node(cls, tree, ast_node: ast.AST, targets: Union[ScopedValue, str], func: Union[ScopedValue, str],
1264
- args: [ScopedValue], kwargs: {str: ScopedValue}, name: str = "", instance=None):
1391
+ def create_tree_node(cls, tree, ast_node: ast.AST, targets: Union[ScopedValue, str],
1392
+ func_name: Union[ScopedValue, str], args: [ScopedValue], kwargs: {str: ScopedValue},
1393
+ name: str = "", instance=None):
1265
1394
  """
1266
1395
  Class method of TreeNode. Instantiate a node whose type is Tree. A Tree node represents an invocation
1267
1396
  of a sub-network.
@@ -1270,104 +1399,14 @@ class TreeNode(Node):
1270
1399
  tree: An instance of SymbolTree represents a handler of sub-symbol-tree.
1271
1400
  ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
1272
1401
  targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
1273
- func ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
1402
+ func_name ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
1274
1403
  args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
1275
1404
  kwargs (dict{str: ScopedValue}): A dict of instance of ScopedValue. See detail in docstring of Node class.
1276
1405
  name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
1277
1406
  Name of node also used as field name in network class.
1278
1407
  instance: Object in network corresponding to this node.
1279
1408
  """
1280
-
1281
- non_custom_args = Node._handle_custom_obj_in_args(args)
1282
- non_custom_kwargs = Node._handle_custom_obj_in_kwargs(kwargs)
1283
1409
  new_targets = Node._handle_targets(targets)
1284
- if isinstance(func, str):
1285
- func = ScopedValue.create_naming_value(func)
1286
- if ast_node is None:
1287
- ast_node = AstModifier.create_call_assign(new_targets, func, non_custom_args, non_custom_kwargs)
1288
- return cls(tree, ast_node, new_targets, func, args, kwargs, name, instance)
1289
-
1290
-
1291
- class CellContainer(Node):
1292
- """ Container for saving cell-objects node. """
1293
- class _Visitor():
1294
- """ A iterator of CellContainer nodes. """
1295
- def __init__(self, cellcontainer):
1296
- self._cellcontainer = cellcontainer
1297
-
1298
- def __len__(self):
1299
- """ Get the number of nodes. """
1300
- return self._cellcontainer.node_count
1301
-
1302
- def __iter__(self):
1303
- """Create an iterator over the CellContainer."""
1304
- count = len(self._cellcontainer.node_list)
1305
- i = 0
1306
- while i < count:
1307
- curr = self._cellcontainer.node_list[i]
1308
- if curr.valid:
1309
- yield curr
1310
- i += 1
1311
-
1312
- def __init__(self, ast_node: ast.AST, targets: [ScopedValue], func: ScopedValue,
1313
- args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance):
1314
- """Constructor of CellContainer.
1315
-
1316
- Args:
1317
- ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
1318
- targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
1319
- func ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
1320
- args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
1321
- kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
1322
- name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
1323
- Name of node also used as field name in network class.
1324
- instance: Object in network corresponding to this node.
1325
- """
1326
- if isinstance(func, str):
1327
- func = ScopedValue.create_naming_value(func)
1328
- super().__init__(NodeType.CellContainer, ast_node, targets, func, args, kwargs, name, instance)
1329
- self._node_list = list()
1330
- self._node_count = 0
1331
-
1332
- @property
1333
- def node_count(self):
1334
- """Number of nodes."""
1335
- return len(self._node_list)
1336
-
1337
- @property
1338
- def node_list(self):
1339
- """ Get node list. """
1340
- return self._node_list
1341
-
1342
- def append(self, node):
1343
- """ Append new node to node list. """
1344
- setattr(node, "container", self)
1345
- setattr(node, "valid", True)
1346
- node.set_belong_symbol_tree(self.get_belong_symbol_tree())
1347
- self._node_list.append(node)
1348
- # when creating a cell_container, node instance is already in SequentialCell cell_list
1349
- # so here we need to write a if judgement
1350
- if node.get_instance() not in self.get_instance().cell_list:
1351
- self.get_instance().append(node.get_instance())
1352
-
1353
- def erase(self, node):
1354
- """Erase node form container."""
1355
- index_node = self.node_list.index(node)
1356
- index_instance = self.get_instance().cell_list.index(node.get_instance())
1357
- if index_node != index_instance:
1358
- raise RuntimeError("In MindSpore Rewrite CellContainer, erasing a node raises index error!!!")
1359
- setattr(node, "valid", False)
1360
- del self.get_instance()[index_node]
1361
- del self._node_list[index_node]
1362
-
1363
- def insert(self, index, node):
1364
- """Insert node into container"""
1365
- self.node_list.insert(index, node)
1366
- setattr(node, "container", self)
1367
- setattr(node, "valid", True)
1368
- node.set_belong_symbol_tree(self.get_belong_symbol_tree())
1369
- self.get_instance()._insert(index, node.get_instance())
1370
-
1371
- def nodes(self):
1372
- """ Return a iterator of node."""
1373
- return self._Visitor(self)
1410
+ if isinstance(func_name, str):
1411
+ func_name = ScopedValue.create_naming_value(func_name)
1412
+ return cls(tree, ast_node, new_targets, func_name, args, kwargs, name, instance)